Installing the stack
jax-ai-stack is a metapackage that can be installed with the following command:
pip install jax-ai-stack
This pins specific versions of the component projects, which the integration tests in this repository verify work correctly together. Packages include:
JAX: the core JAX package, which includes array operations and program transformations like jit, vmap, grad, etc.
flax: build neural networks with JAX.
ml_dtypes: NumPy dtype extensions for machine learning.
optax: gradient processing and optimization in JAX.
orbax: checkpointing and persistence utilities for JAX.
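After installation, you may want to confirm which of these components are present and at which versions. The sketch below uses only the standard-library importlib.metadata module. The distribution names in COMPONENTS are assumptions, in particular orbax-checkpoint as the distribution behind orbax, and may not match exactly what a given release pins. installed_versions is a hypothetical helper, not part of jax-ai-stack.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names for the components listed above.
# "orbax-checkpoint" is an assumption about which orbax distribution is pulled in.
COMPONENTS = ["jax-ai-stack", "jax", "flax", "ml_dtypes", "optax", "orbax-checkpoint"]


def installed_versions(names):
    """Return {distribution name: installed version, or None if not installed}."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # distribution is not installed in this environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(COMPONENTS).items():
        print(f"{name:18} {ver or 'not installed'}")
```

Missing distributions map to None instead of raising, so the report works in a partially installed environment.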
Optional packages
Additionally, there are optional packages you can install with pip extras.
The following command:
pip install jax-ai-stack[grain]
will install a compatible version of the grain data loader (currently Linux-only).
Similarly, the following command:
pip install jax-ai-stack[tfds]
will install a compatible version of tensorflow and tensorflow-datasets.
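To check whether an extra is already usable in the current environment, one option is to look up its modules with importlib.util.find_spec, which locates a top-level module without importing it. This is a minimal sketch: the EXTRA_MODULES mapping and the import names grain, tensorflow and tensorflow_datasets are assumptions, and extra_available is a hypothetical helper.

```python
import importlib.util

# Assumed top-level import names for the modules each extra provides.
EXTRA_MODULES = {
    "grain": ["grain"],
    "tfds": ["tensorflow", "tensorflow_datasets"],
}


def extra_available(extra, modules=EXTRA_MODULES):
    """Return True if every module associated with `extra` can be found."""
    # find_spec returns None for a top-level module that is not on sys.path;
    # it does not import the module itself.
    return all(importlib.util.find_spec(name) is not None for name in modules[extra])


if __name__ == "__main__":
    for extra in EXTRA_MODULES:
        status = "available" if extra_available(extra) else "missing"
        print(f"jax-ai-stack[{extra}]: {status}")
```

Checking for importable modules rather than distribution metadata tells you whether `import` will succeed, which is usually what matters for a data-loading script.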
Pinned versions
The jax-ai-stack metapackage is released periodically, with date-based version strings. For example, if you’d like to pin the set of packages from November 2024, you can use this installation command:
pip install jax-ai-stack==2024.11.1
For the full list of released versions and the pinned packages, refer to the Change log.
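After installing a dated release, you can compare the versions pinned in the installed metapackage's metadata against what is actually installed, for example to detect a component that was later upgraded by hand. The sketch below assumes the metapackage declares its components as exact `==` pins. The regular expression is a deliberately simple stand-in for a full PEP 508 parser (packaging.requirements.Requirement handles the general case), requirements gated by an environment marker such as an extra are skipped, and versions are compared as plain strings, so "2.0" and "2.0.0" would be reported as different. parse_pins and check_pins are hypothetical helpers.

```python
import re
from importlib.metadata import PackageNotFoundError, requires, version

# Matches simple exact pins such as "flax==0.10.1" or 'grain==0.2.2; extra == "grain"'.
PIN_RE = re.compile(
    r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*==\s*([^\s;,]+)\s*(?:;(.*))?$"
)


def parse_pins(requirements):
    """Map distribution name -> pinned version for unconditional '==' requirements."""
    pins = {}
    for req in requirements:
        m = PIN_RE.match(req)
        if m and not m.group(3):  # skip marker-gated requirements (e.g. extras)
            pins[m.group(1)] = m.group(2)
    return pins


def check_pins(dist="jax-ai-stack"):
    """Return {name: (pinned, installed)} for each pin whose installed version differs.

    Returns None if `dist` itself is not installed.
    """
    try:
        reqs = requires(dist) or []
    except PackageNotFoundError:
        return None
    mismatches = {}
    for name, pinned in parse_pins(reqs).items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # pinned component is missing entirely
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    result = check_pins()
    if result is None:
        print("jax-ai-stack is not installed")
    elif not result:
        print("all pinned components match")
    else:
        for name, (pinned, installed) in result.items():
            print(f"{name}: pinned {pinned}, installed {installed}")
```

An empty result means every unconditional pin matches the installed version.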