JAX: Difference between revisions
imported>Samuela No edit summary |
imported>Samuela No edit summary |
||
(9 intermediate revisions by 2 users not shown) | |||
Line 1: | Line 1: | ||
[https://github.com/google/jax JAX] is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available on Nix/NixOS in the <code>python3Packages.{jax, jaxlib}</code> packages. | [https://github.com/google/jax JAX] is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available on Nix/NixOS in the <code>python3Packages.{jax, jaxlib, jaxlibWithCuda}</code> packages. | ||
'''Cache''': Using the [https://app.cachix.org/cache/cuda-maintainers#pull cuda-maintainers cache] is recommended! It will save you valuable time and electrons. Getting set up should be as simple as <code>cachix use cuda-maintainers</code>. See the [[CUDA]] wiki page for more info. | |||
NOTE: JAX requires Python 3.9, the current version of <code>python3</code> in nixpkgs (as of 9/4/2021). JAX is currently only packaged for x86_64-linux (send a PR for your platform!). | NOTE: JAX requires Python 3.9, the current version of <code>python3</code> in nixpkgs (as of 9/4/2021). JAX is currently only packaged for x86_64-linux (send a PR for your platform!). | ||
== Example shell.nix, CPU only == | == Example shell.nix, CPU only == | ||
<syntaxHighlight lang=nix> | <syntaxHighlight lang=nix> | ||
let | let | ||
# | # Last updated 01/31/2022. Check status.nixos.org for updates. | ||
pkgs = import (fetchTarball("https://github.com/ | pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {}; | ||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
Line 16: | Line 16: | ||
python3Packages.jaxlib | python3Packages.jaxlib | ||
]; | ]; | ||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||
Line 30: | Line 27: | ||
hardware.opengl.enable = true; | hardware.opengl.enable = true; | ||
</syntaxHighlight> | </syntaxHighlight> | ||
Then you can use the | Then you can use the <code>jaxlibWithCuda</code> package (equivalent to setting the <code>cudaSupport</code> parameter): | ||
<syntaxHighlight lang=nix> | <syntaxHighlight lang=nix> | ||
let | let | ||
# | # Last updated 01/31/2022. Check status.nixos.org for updates. | ||
pkgs = import (fetchTarball("https://github.com/ | pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {}; | ||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
python3 | python3 | ||
python3Packages.jax | python3Packages.jax | ||
python3Packages.jaxlibWithCuda | |||
]; | ]; | ||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||
Line 51: | Line 45: | ||
It should print either "cpu", "gpu", or "tpu". | It should print either "cpu", "gpu", or "tpu". | ||
Note that hydra may not cache <code>jaxlibWithCuda</code> builds on cache.nixos.org since CUDA is "unfree." @samuela publishes builds on a public cachix [https://app.cachix.org/cache/ploop#pull ploop] cache. These are periodically built and pushed from [https://github.com/samuela/nixpkgs-upkeep/ nixpkgs-upkeep]. | |||
== FAQ == | |||
=== How do I package JAX libraries? === | |||
Never ever ever put <code>jaxlib</code> in <code>propagatedBuildInputs</code>. However, it may live happily in <code>buildInputs</code> or <code>checkInputs</code>. See https://github.com/NixOS/nixpkgs/pull/156808 for context. | |||
=== RuntimeError: Unknown: no kernel image is available for execution on the device === | |||
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable <code>XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"</code>. Also consider upgrading your CUDA driver. | |||
See https://github.com/google/jax/issues/5723#issuecomment-913038780. | |||
[[Category:Applications]] |
Latest revision as of 19:27, 31 March 2022
JAX is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available on Nix/NixOS in the python3Packages.{jax, jaxlib, jaxlibWithCuda}
packages.
Cache: Using the cuda-maintainers cache is recommended! It will save you valuable time and electrons. Getting set up should be as simple as cachix use cuda-maintainers
. See the CUDA wiki page for more info.
NOTE: JAX requires Python 3.9, the current version of python3
in nixpkgs (as of 9/4/2021). JAX is currently only packaged for x86_64-linux (send a PR for your platform!).
Example shell.nix, CPU only
let
# Last updated 01/31/2022. Check status.nixos.org for updates.
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
buildInputs = with pkgs; [
python3
python3Packages.jax
python3Packages.jaxlib
];
}
Example shell.nix with GPU support
JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a NVIDIA GPU and OpenGL. In your /etc/nixos/configuration.nix
:
# NVIDIA drivers are unfree
nixpkgs.config.allowUnfree = true;
services.xserver.videoDrivers = [ "nvidia" ];
hardware.opengl.enable = true;
Then you can use the jaxlibWithCuda
package (equivalent to setting the cudaSupport
parameter):
let
# Last updated 01/31/2022. Check status.nixos.org for updates.
pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
buildInputs = with pkgs; [
python3
python3Packages.jax
python3Packages.jaxlibWithCuda
];
}
You can test that JAX is using the GPU as intended with
python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"
It should print either "cpu", "gpu", or "tpu".
Note that hydra may not cache jaxlibWithCuda
builds on cache.nixos.org since CUDA is "unfree." @samuela publishes builds on a public cachix ploop cache. These are periodically built and pushed from nixpkgs-upkeep.
FAQ
How do I package JAX libraries?
Never ever ever put jaxlib
in propagatedBuildInputs
. However, it may live happily in buildInputs
or checkInputs
. See https://github.com/NixOS/nixpkgs/pull/156808 for context.
RuntimeError: Unknown: no kernel image is available for execution on the device
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"
. Also consider upgrading your CUDA driver.
See https://github.com/google/jax/issues/5723#issuecomment-913038780.