user395882

Reputation: 665

Not able to import Python package jax on Google TPU

I am working in a Linux console, where typing python takes me into the Python interpreter. When I run the following command on a TPU machine

import jax

it prints the following message and drops out of the Python prompt:

paramjeetsingh80@t1v-n-1c883486-w-0:~$ python3
Python 3.8.5 (default, Jan 27 2021, 15:41:15)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2021-07-08 17:41:39.660523: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped)
paramjeetsingh80@t1v-n-1c883486-w-0:~$

This issue is causing problems in my code, so I would like to figure out what it is and how to get rid of it.

Upvotes: 3

Views: 1168

Answers (2)

jakevdp

Reputation: 86328

It may be that your system does not have the correct version of libtpu. Try installing the version listed here.

You should be able to do this automatically with

$ pip install -U pip  # older pip may not support extra requirements
$ pip install -U jax  # newer jax required for [tpu] extras declaration
$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html
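
Because `import jax` aborts the interpreter here, it may help to check which versions are installed without importing jax at all. The following is a minimal sketch using only the standard library (Python 3.8+). The helper name `installed_versions` is hypothetical, and the distribution names `libtpu-nightly` and `libtpu` are assumptions about how libtpu is packaged on your machine.

```python
from importlib import metadata  # in the standard library since Python 3.8


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent.

    This reads package metadata only, so it never imports jax or loads libtpu.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # "libtpu-nightly" and "libtpu" are assumed distribution names;
    # adjust them if your environment packages libtpu differently.
    candidates = ["jax", "jaxlib", "libtpu-nightly", "libtpu"]
    for name, version in installed_versions(candidates).items():
        print(f"{name}: {version or 'not installed'}")
```

If jax is present but no libtpu distribution shows up, or the versions look mismatched, that would be consistent with the library-version explanation above.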

Upvotes: 3

user395882

Reputation: 665

The commands above gave me an error, but after some research the commands below worked for me. Your answer did point me in the right direction, though: it is a package issue.

pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
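
To confirm afterwards that the installed jax satisfies the `>=0.2.16` requirement, a rough check along these lines may be useful. It is only a sketch: `version_tuple` and `meets_minimum` are hypothetical helper names, and the comparison deliberately ignores pre-release suffixes such as `rc1`, so it is looser than pip's own version logic.

```python
import re


def version_tuple(version):
    """Return the leading numeric part of a version string as a tuple of ints.

    For example, '0.2.16' becomes (0, 2, 16). Anything after the leading
    dotted integers (such as '.dev1' or 'rc1') is ignored in this sketch.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric version found in {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(installed, minimum="0.2.16"):
    """True if `installed` is at least `minimum`, compared numerically."""
    return version_tuple(installed) >= version_tuple(minimum)


if __name__ == "__main__":
    # Comparing integers avoids the string-comparison trap where '0.2.9'
    # would sort after '0.2.16'.
    for candidate in ["0.2.9", "0.2.16", "0.2.17"]:
        print(candidate, meets_minimum(candidate))
```

The installed version string can be read with `importlib.metadata.version("jax")`, which does not import jax itself.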

Upvotes: 1
