JAX: Difference between revisions
imported>Samuela No edit summary |
m →Example shell.nix with GPU support: syntax highlight code block, additional crosslinking and formatting |
||
| (6 intermediate revisions by 4 users not shown) | |||
| Line 1: | Line 1: | ||
[https://github.com/google/jax JAX] is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available | [https://github.com/google/jax JAX] is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available in [[Nixpkgs]] in the <code>python3Packages.{{{nixos:package|python3Packages.jax|jax}}, {{nixos:package|python3%20jaxlib|jaxlib}}, {{nixos:package|python3*.jaxlibWithCuda|jaxlibWithCuda}}}</code> packages. | ||
{{tip|1='''Cache''': Using the [https://app.cachix.org/cache/nix-community nix-community cache] is recommended! It will save you valuable time and electrons. Getting set up should be as simple as <code>cachix use nix-community</code>. See the [[CUDA]] wiki page for more info. | |||
}} | |||
== Example shell.nix, CPU only == | == Example shell.nix, CPU only == | ||
<syntaxHighlight lang=nix> | <syntaxHighlight lang=nix> | ||
let | let | ||
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/ | # Last updated 01/31/2022. Check status.nixos.org for updates. | ||
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {}; | |||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
| Line 17: | Line 19: | ||
== Example shell.nix with GPU support == | == Example shell.nix with GPU support == | ||
JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a NVIDIA GPU and OpenGL. In your <code>/etc/nixos/configuration.nix</code>: | JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a [[NVIDIA]] GPU and [[OpenGL]]. In your <code>/etc/nixos/configuration.nix</code>: | ||
<syntaxHighlight lang=nix> | <syntaxHighlight lang=nix> | ||
# NVIDIA drivers are unfree | # NVIDIA drivers are unfree | ||
| Line 27: | Line 29: | ||
<syntaxHighlight lang=nix> | <syntaxHighlight lang=nix> | ||
let | let | ||
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/ | # Last updated 01/31/2022. Check status.nixos.org for updates. | ||
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {}; | |||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
| Line 38: | Line 41: | ||
You can test that JAX is using the GPU as intended with | You can test that JAX is using the GPU as intended with | ||
python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)" | <syntaxHighlight lang=bash> | ||
# after version 0.8.0 | |||
python -c "import jax.extend as ex; print(ex.backend.get_backend().platform)" | |||
# before version 0.8.0 | |||
python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)" | |||
</syntaxHighlight> | |||
It should print either | It should print either <code>cpu</code>, <code>gpu</code>, or <code>tpu</code>. | ||
Note | {{Note| [[Hydra]] may not cache <code>jaxlibWithCuda</code> builds on cache.nixos.org since CUDA is [[unfree software]]. @samuela publishes builds on a public cachix [https://app.cachix.org/cache/ploop#pull ploop] cache. These are periodically built and pushed from [https://github.com/samuela/nixpkgs-upkeep/ nixpkgs-upkeep].}} | ||
== FAQ == | == FAQ == | ||
| Line 54: | Line 63: | ||
[[Category:Applications]] | [[Category:Applications]] | ||
[[Category:Python]] | |||
Latest revision as of 19:11, 9 December 2025
JAX is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available in Nixpkgs in the python3Packages.{ packages.
jax, jaxlib, jaxlibWithCuda}
cachix use nix-community. See the CUDA wiki page for more info.
Example shell.nix, CPU only
let
# Last updated 01/31/2022. Check status.nixos.org for updates.
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
buildInputs = with pkgs; [
python3
python3Packages.jax
python3Packages.jaxlib
];
}
Example shell.nix with GPU support
JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a NVIDIA GPU and OpenGL. In your /etc/nixos/configuration.nix:
# NVIDIA drivers are unfree
nixpkgs.config.allowUnfree = true;
services.xserver.videoDrivers = [ "nvidia" ];
hardware.opengl.enable = true;
Then you can use the jaxlibWithCuda package (equivalent to setting the cudaSupport parameter):
let
# Last updated 01/31/2022. Check status.nixos.org for updates.
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
buildInputs = with pkgs; [
python3
python3Packages.jax
python3Packages.jaxlibWithCuda
];
}
You can test that JAX is using the GPU as intended with
# after version 0.8.0
python -c "import jax.extend as ex; print(ex.backend.get_backend().platform)"
# before version 0.8.0
python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"
It should print either cpu, gpu, or tpu.
jaxlibWithCuda builds on cache.nixos.org since CUDA is unfree software. @samuela publishes builds on a public cachix ploop cache. These are periodically built and pushed from nixpkgs-upkeep.FAQ
How do I package JAX libraries?
Never ever ever put jaxlib in propagatedBuildInputs. However, it may live happily in buildInputs or checkInputs. See https://github.com/NixOS/nixpkgs/pull/156808 for context.
RuntimeError: Unknown: no kernel image is available for execution on the device
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1". Also consider upgrading your CUDA driver.
See https://github.com/google/jax/issues/5723#issuecomment-913038780.