JAX: Difference between revisions

imported>Samuela
No edit summary
imported>Samuela
No edit summary
Line 16: Line 16:
     python3Packages.jaxlib
     python3Packages.jaxlib
   ];
   ];
  # See https://github.com/google/jax/issues/5723#issuecomment-913038780
  XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1";
}
}
</syntaxHighlight>
</syntaxHighlight>
Line 41: Line 38:
     (python3Packages.jaxlib.override { cudaSupport = true; })
     (python3Packages.jaxlib.override { cudaSupport = true; })
   ];
   ];
  # See https://github.com/google/jax/issues/5723#issuecomment-913038780
  XLA_FLAGS = "--xla_gpu_force_compilation_parallelism=1";
}
}
</syntaxHighlight>
</syntaxHighlight>
Line 51: Line 45:


It should print either "cpu", "gpu", or "tpu".
It should print either "cpu", "gpu", or "tpu".
== FAQ ==
=== RuntimeError: Unknown: no kernel image is available for execution on the device ===
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable <code>XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"</code>. Also consider upgrading your CUDA driver.
See https://github.com/google/jax/issues/5723#issuecomment-913038780.