Every framework has thousands of functions and classes. You won’t memorize them; you’ll look them up.
Two Python builtins do most of the work:

- dir(module): what’s in here?
- help(thing) (or ?thing in Jupyter): how do I use it?

Plus the official docs: pytorch.org, jax.dev, tensorflow.org, mxnet.apache.org.

dir: discovering the API

Standard import:

import jax
dir(...) lists every name defined in a module, including private ones that start with an underscore. The listing below filters those out and shows only a short prefix of jax.random; in a notebook you can inspect the full list interactively:
['PRNGKey', 'ball', 'bernoulli', 'beta', 'binomial', 'bits', 'categorical', 'cauchy', 'chisquare', 'choice', 'clone', 'dirichlet', 'double_sided_maxwell', ...
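The listing above can be reproduced with a sketch like the following. The filter (drop names with a leading underscore) and the prefix length of 13 are our own choices for illustration; the exact names returned depend on your installed JAX version.

```python
import jax

# dir() returns every attribute name of the module, sorted alphabetically,
# including dunder names such as __name__ and other private helpers.
all_names = dir(jax.random)

# Keep only the public API (names without a leading underscore).
public = [name for name in all_names if not name.startswith("_")]

# Show a short prefix, as on the slide; drop the slice to see everything.
print(public[:13])
```

Because dir() already sorts its result and the list comprehension preserves order, the filtered list stays alphabetical.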
help: usage details

Once you have the name, help(...) prints the docstring with arguments, defaults, and usage examples. The output for jnp.ones (with jax.numpy imported as jnp), abridged:
Help on function ones in module jax.numpy:
ones(shape: Any, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, NoneType] = None, *, device: jaxlib._jax.Device | ...
Create an array full of ones.
JAX implementation of :func:`numpy.ones`.
...
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True,  True,  True],
       [ True,  True,  True]], dtype=bool)
.. _explicit sharding: https://docs.jax.dev/en/latest/parallel.html
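To see this help text yourself and run the two docstring examples, here is a minimal sketch, assuming a standard JAX install with the default 32-bit floating-point mode. The call jnp.ones(4) is our own choice to produce the four-element result shown, since the docstring line that generated it is elided above.

```python
import jax.numpy as jnp

# Print the docstring: signature, parameter descriptions, and examples.
help(jnp.ones)

# First example: a 1-D array of four ones (float32 unless x64 mode is on).
x = jnp.ones(4)
print(x)

# Second example: a 2x3 array of booleans, all True.
mask = jnp.ones((2, 3), dtype=bool)
print(mask)
```

Running the examples after reading the docstring is a quick way to confirm you understood the shape and dtype arguments.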

- dir(module): list its contents.
- help(symbol) (or symbol? in Jupyter): show the docstring.
- Tab completion (type a prefix, then press Tab) is your fastest discovery tool.
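These habits can be combined in plain Python. In the minimal sketch below, find_names and summary are hypothetical helpers introduced for illustration, not part of JAX. find_names filters dir() by prefix, which roughly approximates the candidates Tab completion offers for module.<prefix>; summary returns the first line of a docstring via inspect.getdoc, like the top of help(...).

```python
import inspect

import jax.numpy as jnp


def find_names(module, prefix):
    """Hypothetical helper: public names in `module` starting with `prefix`."""
    return [
        name
        for name in dir(module)
        if name.startswith(prefix) and not name.startswith("_")
    ]


def summary(obj):
    """Hypothetical helper: first docstring line of `obj`, or '' if none."""
    doc = inspect.getdoc(obj) or ""
    return doc.splitlines()[0] if doc else ""


# List the jnp names that start with "ones" together with a one-line summary.
for name in find_names(jnp, "ones"):
    print(name, "->", summary(getattr(jnp, name)))
```

A quick search like this is handy when you remember part of a name but not the exact spelling; use help(...) on the match for the full details.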