JAX: Difference between revisions
imported>Samuela No edit summary |
imported>Samuela No edit summary |
||
| Line 9: | Line 9: | ||
let | let | ||
# Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | ||
pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/ | pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {}; | ||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
| Line 16: | Line 16: | ||
python3Packages.jaxlib | python3Packages.jaxlib | ||
]; | ]; | ||
# See https://github.com/google/jax/issues/5723#issuecomment-913038780 | |||
XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1"; | |||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||
| Line 31: | Line 34: | ||
let | let | ||
# Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894. | ||
pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/ | pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {}; | ||
in pkgs.mkShell { | in pkgs.mkShell { | ||
buildInputs = with pkgs; [ | buildInputs = with pkgs; [ | ||
| Line 38: | Line 41: | ||
(python3Packages.jaxlib.override { cudaSupport = true; }) | (python3Packages.jaxlib.override { cudaSupport = true; }) | ||
]; | ]; | ||
# See https://github.com/google/jax/issues/5723#issuecomment-913038780 | |||
XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1"; | |||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||