JAX: Difference between revisions
imported>Samuela No edit summary |
imported>Samuela No edit summary |
||
Line 9: | Line 9: | ||
let | let | ||
# Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | ||
pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/ | pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {}; | ||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
Line 16: | Line 16: | ||
python3Packages.jaxlib | python3Packages.jaxlib | ||
]; | ]; | ||
# See https://github.com/google/jax/issues/5723#issuecomment-913038780 | |||
XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1"; | |||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||
Line 31: | Line 34: | ||
let | let | ||
# Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | ||
pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/ | pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {}; | ||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
Line 38: | Line 41: | ||
(python3Packages.jaxlib.override { cudaSupport = true; }) | (python3Packages.jaxlib.override { cudaSupport = true; }) | ||
]; | ]; | ||
# See https://github.com/google/jax/issues/5723#issuecomment-913038780 | |||
XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1"; | |||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> |
Revision as of 21:00, 4 September 2021
JAX is a framework for program transformation, esp. for automatic differentiation and machine learning. It's available on Nix/NixOS in the python3Packages.{jax, jaxlib}
packages.
NOTE: JAX requires Python 3.9, the current version of python3
in nixpkgs (as of 9/4/2021). JAX is currently only packaged for x86_64-linux (send a PR for your platform!).
NOTE: JAX has not yet landed in nixpkgs master (see https://github.com/NixOS/nixpkgs/pull/134894). However, it is available via a fork: https://github.com/samuela/nixpkgs/tree/scratch.
Example shell.nix, CPU only
let
# Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {};
in pkgs.mkShell {
buildInputs = with pkgs; [
python3
python3Packages.jax
python3Packages.jaxlib
];
# See https://github.com/google/jax/issues/5723#issuecomment-913038780
XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1";
}
Example shell.nix with GPU support
JAX defers execution to the jaxlib library for execution. In order to use GPU support you'll need a NVIDIA GPU and OpenGL. In your /etc/nixos/configuration.nix
:
# NVIDIA drivers are unfree
nixpkgs.config.allowUnfree = true;
services.xserver.videoDrivers = [ "nvidia" ];
hardware.opengl.enable = true;
Then you can use the jaxlib package by setting the cudaSupport
parameter:
let
# Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {};
in pkgs.mkShell {
buildInputs = with pkgs; [
python3
python3Packages.jax
(python3Packages.jaxlib.override { cudaSupport = true; })
];
# See https://github.com/google/jax/issues/5723#issuecomment-913038780
XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1";
}
You can test that JAX is using the GPU as intended with
python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"
It should print either "cpu", "gpu", or "tpu".