JAX: Difference between revisions
m jax.lib.xla_bridge module will be removed in JAX v0.9.0; all its APIs were deprecated and removed by JAX v0.8.0 |
m →Example shell.nix with GPU support: syntax highlight code block, additional crosslinking and formatting |
||
| Line 41: | Line 41: | ||
You can test that JAX is using the GPU as intended with | You can test that JAX is using the GPU as intended with | ||
<syntaxHighlight lang=bash> | |||
# after version 0.8.0 | |||
python -c "import jax.extend as ex; print(ex.backend.get_backend().platform)" | |||
# before version 0.8.0 | |||
python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)" | |||
</syntaxHighlight> | |||
It should print either | It should print either <code>cpu</code>, <code>gpu</code>, or <code>tpu</code>. | ||
Note | {{Note| [[Hydra]] may not cache <code>jaxlibWithCuda</code> builds on cache.nixos.org since CUDA is [[unfree software]]. @samuela publishes builds on a public cachix [https://app.cachix.org/cache/ploop#pull ploop] cache. These are periodically built and pushed from [https://github.com/samuela/nixpkgs-upkeep/ nixpkgs-upkeep].}} | ||
== FAQ == | == FAQ == | ||