JAX: Difference between revisions

imported>Samuela
No edit summary
imported>Samuela
No edit summary
Line 9: Line 9:
let
let
   # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
   # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
   pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/b7845603e0d986c40634688e908675f9e33adf47.tar.gz")) {};
   pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {};
in pkgs.mkShell {
in pkgs.mkShell {
   buildInputs = with pkgs; [
   buildInputs = with pkgs; [
Line 16: Line 16:
     python3Packages.jaxlib
     python3Packages.jaxlib
   ];
   ];
  # See https://github.com/google/jax/issues/5723#issuecomment-913038780
  XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1";
}
}
</syntaxHighlight>
</syntaxHighlight>
Line 31: Line 34:
let
let
   # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
   # Fork with jax/jaxlib. See https://github.com/NixOS/nixpkgs/pull/134894.
   pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/b7845603e0d986c40634688e908675f9e33adf47.tar.gz")) {};
   pkgs = import (fetchTarball("https://github.com/samuela/nixpkgs/archive/dfc52fafe0cb7f574493e0e7c44c57200dd1a0fe.tar.gz")) {};
in pkgs.mkShell {
in pkgs.mkShell {
   buildInputs = with pkgs; [
   buildInputs = with pkgs; [
Line 38: Line 41:
     (python3Packages.jaxlib.override { cudaSupport = true; })
     (python3Packages.jaxlib.override { cudaSupport = true; })
   ];
   ];
  # See https://github.com/google/jax/issues/5723#issuecomment-913038780
  XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1";
}
}
</syntaxHighlight>
</syntaxHighlight>