JAX

From NixOS Wiki
Revision as of 18:59, 4 September 2021 by Samuela

JAX is a Python framework for program transformation, especially automatic differentiation, and is widely used for machine learning. On Nix/NixOS it is packaged as python3Packages.jax (the Python frontend) and python3Packages.jaxlib (the compiled backend that jax runs on).
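As a brief illustration of what "program transformation" means here, the sketch below assumes a working jax/jaxlib installation, such as one of the shells on this page. It uses jax.grad to turn a function into one that computes its gradient, and jax.jit to compile that gradient function with XLA. The function name loss is illustrative only.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Illustrative scalar function: the sum of squares of the entries of x.
    return jnp.sum(x ** 2)


# jax.grad transforms loss into a new function returning d(loss)/dx, which is 2 * x here.
grad_loss = jax.grad(loss)

# jax.jit transforms grad_loss into an XLA-compiled function with the same behavior.
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([1.0, 2.0, 3.0])
print(fast_grad_loss(x))  # prints the gradient 2 * x
```

Because each transformation returns an ordinary Python function, they compose: jax.jit(jax.grad(loss)) can itself be passed to further transformations.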

NOTE: JAX requires Python 3.9, which is the current default python3 in nixpkgs (as of 4 September 2021). JAX is currently packaged only for x86_64-linux (send a PR for your platform!).

NOTE: JAX has not yet landed in nixpkgs master (see https://github.com/NixOS/nixpkgs/pull/134894). However, it is available via a fork: https://github.com/samuela/nixpkgs/tree/scratch.

Example shell.nix, CPU only

let
  # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
  pkgs = import (fetchTarball "https://github.com/samuela/nixpkgs/archive/b7845603e0d986c40634688e908675f9e33adf47.tar.gz") {};
in pkgs.mkShell {
  buildInputs = with pkgs; [
    python3
    python3Packages.jax
    python3Packages.jaxlib
  ];
}
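After entering the shell with nix-shell, a short script such as the following can serve as a smoke test. It is a sketch only: the file name smoke_test.py and the function affine are hypothetical. jax provides the Python API while jaxlib compiles and runs the jitted, vmapped function through XLA, so a successful run exercises both packages.

```python
# smoke_test.py (hypothetical name): run with "python smoke_test.py" inside nix-shell.
import jax
import jax.numpy as jnp


@jax.jit
def affine(w, x, b):
    # Compiled by XLA (via jaxlib) on the first call.
    return w @ x + b


# vmap lifts affine from one vector x to a batch of vectors (the rows of xs);
# in_axes=None means w and b are shared across the whole batch.
batched_affine = jax.vmap(affine, in_axes=(None, 0, None))

w = jnp.array([[1.0, 2.0], [3.0, 4.0]])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = jnp.array([0.5, -0.5])

out = batched_affine(w, xs, b)
print(out.shape)  # (3, 2)
print(out)
```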

Example shell.nix with GPU support

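JAX delegates execution to the jaxlib library, so GPU support is provided by jaxlib. To use it you'll need an NVIDIA GPU with the NVIDIA driver and OpenGL enabled; on NixOS, hardware.opengl.enable is what makes the driver libraries (including libcuda) available under /run/opengl-driver. In your /etc/nixos/configuration.nix: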

# NVIDIA drivers are unfree
nixpkgs.config.allowUnfree = true;
services.xserver.videoDrivers = [ "nvidia" ];
hardware.opengl.enable = true;

Then build jaxlib with CUDA support by overriding its cudaSupport parameter:

let
  # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
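  pkgs = import (fetchTarball "https://github.com/samuela/nixpkgs/archive/b7845603e0d986c40634688e908675f9e33adf47.tar.gz") {
    # The CUDA toolkit used by jaxlib with cudaSupport is unfree. The allowUnfree
    # setting in configuration.nix does not apply to this separate nixpkgs import.
    config.allowUnfree = true;
  };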
in pkgs.mkShell {
  buildInputs = with pkgs; [
    python3
    python3Packages.jax
    (python3Packages.jaxlib.override { cudaSupport = true; })
  ];
}

You can check that JAX is using the GPU as intended with:

python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"

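It prints the platform of the default backend: "cpu", "gpu", or "tpu". With the CUDA-enabled jaxlib and a working NVIDIA setup, it should print "gpu".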
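If you prefer a script to the one-liner, the sketch below performs a similar check through jax.default_backend(), which recent JAX releases expose as a public alternative to importing xla_bridge, and also lists each visible device. The function name report_backend is hypothetical. On a working GPU setup the listed devices should report platform gpu.

```python
import jax


def report_backend() -> str:
    """Print each visible device and return the default backend's platform name."""
    for device in jax.devices():
        # platform is e.g. "cpu" or "gpu"; device_kind names the hardware model.
        print(device.platform, device.device_kind)
    return jax.default_backend()


if __name__ == "__main__":
    print("default backend:", report_backend())
```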