JAX

From NixOS Wiki

JAX is a framework for program transformation, especially automatic differentiation and machine learning. It is available on Nix/NixOS as the python3Packages.{jax, jaxlib, jaxlibWithCuda} packages.

Cache: Using the cuda-maintainers cache is recommended! It will save you valuable time and electrons. Getting set up should be as simple as running cachix use cuda-maintainers. See the CUDA wiki page for more info.

NOTE: JAX requires Python 3.9, the current version of python3 in nixpkgs (as of 9/4/2021). JAX is currently only packaged for x86_64-linux (send a PR for your platform!).

Example shell.nix, CPU only

let
  # Last updated 01/31/2022. Check status.nixos.org for updates.
  pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
  buildInputs = with pkgs; [
    python3
    python3Packages.jax
    python3Packages.jaxlib
  ];
}

Example shell.nix with GPU support

JAX defers execution to the jaxlib library. To use GPU support you'll need an NVIDIA GPU and OpenGL. In your /etc/nixos/configuration.nix:

# NVIDIA drivers are unfree
nixpkgs.config.allowUnfree = true;
services.xserver.videoDrivers = [ "nvidia" ];
hardware.opengl.enable = true;

Then you can use the jaxlibWithCuda package (equivalent to setting the cudaSupport parameter):

let
  # Last updated 01/31/2022. Check status.nixos.org for updates.
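  pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {
    # jaxlibWithCuda depends on the unfree CUDA toolkit. The allowUnfree setting in
    # configuration.nix does not apply to this separate nixpkgs import, so set it here too.
    config.allowUnfree = true;
  };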
in pkgs.mkShell {
  buildInputs = with pkgs; [
    python3
    python3Packages.jax
    python3Packages.jaxlibWithCuda
  ];
}

You can test that JAX is using the GPU as intended with

python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"

It prints "cpu", "gpu", or "tpu". With working GPU support it should print "gpu".
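For a slightly more detailed check than the one-liner above, here is a minimal sketch that uses the public jax.default_backend() and jax.devices() calls instead of xla_bridge. The helper names gpu_in_use and report are hypothetical, and accepting "cuda" as a GPU platform name is an assumption made to stay safe across JAX versions.

```python
def gpu_in_use(platform: str) -> bool:
    """Return True if the backend platform name denotes a GPU."""
    # Assumption: "cuda" is accepted alongside "gpu"; the text above only lists "gpu".
    return platform.lower() in {"gpu", "cuda"}


def report() -> str:
    """Describe which backend JAX selected, without crashing if JAX is missing."""
    try:
        import jax  # imported lazily so gpu_in_use works without JAX installed
    except ImportError:
        return "jax is not importable; are you inside the nix-shell?"

    platform = jax.default_backend()
    devices = ", ".join(str(d) for d in jax.devices())
    status = "GPU in use" if gpu_in_use(platform) else "GPU not in use"
    return f"{status}: backend={platform}, devices=[{devices}]"


if __name__ == "__main__":
    print(report())
```

If this reports that the GPU is not in use inside the GPU shell, check that jaxlibWithCuda (and not plain jaxlib) is the package in buildInputs.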

Note that Hydra may not cache jaxlibWithCuda builds on cache.nixos.org since CUDA is "unfree". @samuela publishes builds to the public ploop cache on Cachix. These are periodically built and pushed from nixpkgs-upkeep.

FAQ

How do I package JAX libraries?

Never ever ever put jaxlib in propagatedBuildInputs. However, it may live happily in buildInputs or checkInputs. See https://github.com/NixOS/nixpkgs/pull/156808 for context.

RuntimeError: Unknown: no kernel image is available for execution on the device

This usually indicates that your driver is too old for the CUDA toolkit version the package was built with. The easiest fix is to set the environment variable XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1". Also consider upgrading your CUDA driver.

See https://github.com/google/jax/issues/5723#issuecomment-913038780.
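When JAX is started from a Python script rather than a shell, the flag can also be set from Python. The sketch below appends it to any XLA_FLAGS value already present instead of overwriting it. The helper name with_xla_flag is hypothetical, and setting the variable before jax is first imported is the conservative choice assumed here.

```python
import os


def with_xla_flag(current: str, flag: str) -> str:
    """Return an XLA_FLAGS value that contains `flag`, keeping existing flags."""
    flags = current.split()
    if flag not in flags:
        flags.append(flag)
    return " ".join(flags)


FLAG = "--xla_gpu_force_compilation_parallelism=1"

# Set the variable before importing jax so XLA sees it when it starts up.
os.environ["XLA_FLAGS"] = with_xla_flag(os.environ.get("XLA_FLAGS", ""), FLAG)

# import jax  # import JAX only after the environment variable is set
```

In a shell, the equivalent is to export XLA_FLAGS before launching python.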