JAX: Difference between revisions
imported>Samuela No edit summary |
imported>Samuela No edit summary |
||
| Line 16: | Line 16: | ||
python3Packages.jaxlib | python3Packages.jaxlib | ||
]; | ]; | ||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||
| Line 41: | Line 38: | ||
(python3Packages.jaxlib.override { cudaSupport = true; }) | (python3Packages.jaxlib.override { cudaSupport = true; }) | ||
]; | ]; | ||
} | } | ||
</syntaxHighlight> | </syntaxHighlight> | ||
| Line 51: | Line 45: | ||
It should print either "cpu", "gpu", or "tpu". | It should print either "cpu", "gpu", or "tpu". | ||
== FAQ == | |||
=== RuntimeError: Unknown: no kernel image is available for execution on the device === | |||
This usually indicates that you have a driver version that is too old for the CUDA toolkit version the package is built with. The easiest fix is to set the environment variable <code>XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"</code>. Also consider upgrading your CUDA driver. | |||
See https://github.com/google/jax/issues/5723#issuecomment-913038780. | |||