Reputation: 41
I have launched a Google Cloud TPU VM instance and installed the latest version of JAX, but it cannot see my TPU. Following the instructions at https://cloud.google.com/tpu/docs/troubleshooting/trouble-jax I encounter the following:
>>> import jax
>>> jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
>>> TF_CPP_MIN_LOG_LEVEL=0
>>> jax.devices()
[CpuDevice(id=0)]
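One detail in the session above matters: typing `TF_CPP_MIN_LOG_LEVEL=0` at the `>>>` prompt only creates a Python variable. It does not set the environment variable that the warning refers to, so the second `jax.devices()` call runs with the same logging level as the first. Below is a minimal sketch that sets the variable through `os.environ` before `jax` is imported, assuming the backend reads it when the library is first loaded.

```python
import os

# The warning asks for an *environment variable*. Assigning
# TF_CPP_MIN_LOG_LEVEL=0 at the >>> prompt only binds a Python name,
# so set it in os.environ instead, and do so BEFORE importing jax
# (assumption: the C++ backend reads it when jax is first loaded).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax  # deliberately imported after the variable is set

# With the lower log level, any backend initialization messages
# (e.g. why the TPU could not be opened) should now be printed.
print(jax.devices())
```

Equivalently, from the shell: run `TF_CPP_MIN_LOG_LEVEL=0 python3` and then repeat the two lines in the new interpreter.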
All of the Google Search results I have seen for this error suggest installing JAX with CUDA support, but shouldn't that be unnecessary with TPUs?
Upvotes: 4
Views: 2225
Reputation: 1017
I recommend upgrading JAX with the TPU extra:
pip3 install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
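After reinstalling, a quick check helps confirm whether the TPU is actually picked up. The sketch below asks for TPU devices explicitly, so an initialization failure is reported instead of JAX silently falling back to CPU. It assumes a JAX version where `jax.devices()` accepts a backend name and raises `RuntimeError` when that backend is unknown or fails to initialize; `check_tpu` is a hypothetical helper name.

```python
import jax


def check_tpu() -> bool:
    """Hypothetical helper: request TPU devices explicitly.

    Unlike a bare jax.devices(), which falls back to CPU with only a
    warning, jax.devices("tpu") raises if the TPU backend cannot be
    initialized (assumption: it raises RuntimeError), which surfaces
    the underlying error message.
    """
    try:
        tpus = jax.devices("tpu")
    except RuntimeError as err:
        print("TPU backend unavailable:", err)
        return False
    print("Found TPU devices:", tpus)
    return True


if __name__ == "__main__":
    # On a working TPU VM the default backend should be "tpu".
    print("default backend:", jax.default_backend())
    check_tpu()
```

If this still reports the TPU backend as unavailable after upgrading, the printed error message is the useful detail to include in a bug report.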
EDIT: It looks like this is actually a bug; see:
https://github.com/google/jax/issues/13260
Upvotes: 2