JAX: Difference between revisions

imported>Samuela
No edit summary
Pigs (talk | contribs)
Remove out of date note, add link to nixpkgs and other wiki pages
 
(6 intermediate revisions by 2 users not shown)
Line 1: Line 1:
[https://github.com/google/jax JAX] is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available on Nix/NixOS in the <code>python3Packages.{jax, jaxlib, jaxlibWithCuda}</code> packages.
[https://github.com/google/jax JAX] is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available in [[Nixpkgs]] in the <code>python3Packages.{{{nixos:package|python3Packages.jax|jax}}, {{nixos:package|python3%20jaxlib|jaxlib}}, {{nixos:package|python3*.jaxlibWithCuda|jaxlibWithCuda}}}</code> packages.


NOTE: JAX requires Python 3.9, the current version of <code>python3</code> in nixpkgs (as of 9/4/2021). JAX is currently only packaged for x86_64-linux (send a PR for your platform!).  
{{tip|1='''Cache''': Using the [https://app.cachix.org/cache/nix-community nix-community cache] is recommended! It will save you valuable time and electrons. Getting set up should be as simple as <code>cachix use nix-community</code>. See the [[CUDA]] wiki page for more info.
}}


== Example shell.nix, CPU only ==
== Example shell.nix, CPU only ==
<syntaxHighlight lang=nix>
<syntaxHighlight lang=nix>
let
let
   pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/eac6215607e4ccceb9631b01ee8f8903a6e82e02.tar.gz")) {};
  # Last updated 01/31/2022. Check status.nixos.org for updates.
   pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
in pkgs.mkShell {
   buildInputs = with pkgs; [
   buildInputs = with pkgs; [
Line 17: Line 19:


== Example shell.nix with GPU support ==
== Example shell.nix with GPU support ==
JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a NVIDIA GPU and OpenGL. In your <code>/etc/nixos/configuration.nix</code>:
JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a [[NVIDIA]] GPU and [[OpenGL]]. In your <code>/etc/nixos/configuration.nix</code>:
<syntaxHighlight lang=nix>
<syntaxHighlight lang=nix>
# NVIDIA drivers are unfree
# NVIDIA drivers are unfree
Line 24: Line 26:
hardware.opengl.enable = true;
hardware.opengl.enable = true;
</syntaxHighlight>
</syntaxHighlight>
Then you can use the jaxlib package by setting the <code>cudaSupport</code> parameter:
Then you can use the <code>jaxlibWithCuda</code> package (equivalent to setting the <code>cudaSupport</code> parameter):
<syntaxHighlight lang=nix>
<syntaxHighlight lang=nix>
let
let
   # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
   # Last updated 01/31/2022. Check status.nixos.org for updates.
   pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/eac6215607e4ccceb9631b01ee8f8903a6e82e02.tar.gz")) {};
   pkgs = import (fetchTarball("https://github.com/NixOS/nixpkgs/archive/376934f4b7ca6910b243be5fabcf3f4228043725.tar.gz")) {};
in pkgs.mkShell {
in pkgs.mkShell {
   buildInputs = with pkgs; [
   buildInputs = with pkgs; [
Line 43: Line 45:
It should print either "cpu", "gpu", or "tpu".
It should print either "cpu", "gpu", or "tpu".


Note that hydra may not cache `jaxlibWithCuda` builds on cache.nixos.org since CUDA is "unfree." @samuela publishes builds on a public cachix [https://app.cachix.org/cache/ploop#pull ploop] cache. These are periodically built and pushed from [https://github.com/samuela/nixpkgs-upkeep/ nixpkgs-upkeep].
Note that hydra may not cache <code>jaxlibWithCuda</code> builds on cache.nixos.org since CUDA is "unfree." @samuela publishes builds on a public cachix [https://app.cachix.org/cache/ploop#pull ploop] cache. These are periodically built and pushed from [https://github.com/samuela/nixpkgs-upkeep/ nixpkgs-upkeep].


== FAQ ==
== FAQ ==
=== How do I package JAX libraries? ===
Never ever ever put <code>jaxlib</code> in <code>propagatedBuildInputs</code>. However, it may live happily in <code>buildInputs</code> or <code>checkInputs</code>. See https://github.com/NixOS/nixpkgs/pull/156808 for context.
=== RuntimeError: Unknown: no kernel image is available for execution on the device ===
=== RuntimeError: Unknown: no kernel image is available for execution on the device ===
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable <code>XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"</code>. Also consider upgrading your CUDA driver.  
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable <code>XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"</code>. Also consider upgrading your CUDA driver.  
Line 52: Line 57:


[[Category:Applications]]
[[Category:Applications]]
[[Category:Python]]