Laura K.

Reputation: 71

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'

Could anyone please help me fix the following error, which is raised in "/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)"? Thanks a lot.

"AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'"

Upvotes: 2

Views: 8571

Answers (1)

jakevdp

Reputation: 86328

jaxlib.xla_extension.PmapFunction was added in jaxlib version 0.1.72; it sounds like you have an older jaxlib version installed. You should update it using:

pip install -U jaxlib
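If you want to confirm that an outdated jaxlib is the cause before upgrading, a minimal sketch like the one below compares the installed jaxlib version against 0.1.72, the release stated above as introducing PmapFunction. The helper names (parse_version, jaxlib_too_old, installed_jaxlib) are hypothetical, and the sketch reads the version from installed package metadata rather than importing jaxlib itself.

```python
# Hypothetical helper: check whether the installed jaxlib predates 0.1.72,
# the version stated above as adding jaxlib.xla_extension.PmapFunction.
import re
from typing import Optional, Tuple

MIN_JAXLIB = (0, 1, 72)


def parse_version(v: str) -> Tuple[int, ...]:
    """Turn '0.1.71' into (0, 1, 71); local suffixes like '+cuda11' are ignored."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
    return tuple(parts)


def jaxlib_too_old(installed: Optional[str]) -> bool:
    """True if jaxlib is not installed or is older than MIN_JAXLIB."""
    if installed is None:
        return True
    return parse_version(installed) < MIN_JAXLIB


def installed_jaxlib() -> Optional[str]:
    """Return the installed jaxlib version string, or None if it is absent."""
    try:
        from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
    except ImportError:  # Python 3.7 has no importlib.metadata; use setuptools
        import pkg_resources
        try:
            return pkg_resources.get_distribution("jaxlib").version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version("jaxlib")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jaxlib()
    if jaxlib_too_old(v):
        print(f"jaxlib {v} is older than 0.1.72; try: pip install -U jaxlib")
    else:
        print(f"jaxlib {v} should provide xla_extension.PmapFunction")
```

Comparing integer tuples avoids the classic string-comparison pitfall where "0.1.9" would sort after "0.1.72".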

Note: if you're using GPU/TPU, you should instead use the appropriate accelerator-specific installation command found at https://github.com/google/jax#installation.

If this does not work, please check your Python version. jaxlib began requiring Python 3.7 or newer in version 0.1.70, so if you are using Python 3.6, you will need to upgrade Python before you can upgrade to a more recent jaxlib.
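A quick way to check this from the interpreter you are actually running is sketched below. It only encodes the rule stated above (jaxlib 0.1.70 and later need Python 3.7+); the function name can_upgrade_jaxlib is hypothetical.

```python
# Hypothetical check: per the answer, jaxlib >= 0.1.70 requires Python >= 3.7,
# so on older interpreters Python must be upgraded before jaxlib can be.
import sys
from typing import Optional, Tuple


def can_upgrade_jaxlib(py: Optional[Tuple[int, int]] = None) -> bool:
    """True if the given (major, minor) Python version is 3.7 or newer."""
    if py is None:
        py = tuple(sys.version_info[:2])  # the running interpreter
    return tuple(py) >= (3, 7)


if __name__ == "__main__":
    if can_upgrade_jaxlib():
        print("Python is new enough; pip install -U jaxlib should work.")
    else:
        print("Python < 3.7: upgrade Python before upgrading jaxlib.")
```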

It appears the problematic line was added to the haiku package a few hours before you posted the question: https://github.com/deepmind/dm-haiku/commit/e6a13af352a8b46d355ac1b7131b64c615cfcf57

Another option, if you don't want to update jaxlib, is to install a stable release of dm-haiku rather than the development version:

pip install dm-haiku==0.0.5

Upvotes: 3
