Reputation: 71
Could anyone please help me fix the following error, raised in check_not_jax_transformed(f) in /usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py? Thanks a lot.
"AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'"
Upvotes: 2
Views: 8571
Reputation: 86328
jaxlib.xla_extension.PmapFunction was added in jaxlib version 0.1.72, so it sounds like you have an older version of jaxlib installed. You can update it with:
pip install -U jaxlib
Note: if you're using GPU/TPU, you should instead use the appropriate accelerator-specific installation command found at https://github.com/google/jax#installation.
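To confirm that an outdated jaxlib is the cause before reinstalling anything, a quick check like the one below may help. It is a minimal sketch that assumes the 2021-era jaxlib layout, in which jaxlib.xla_extension is importable as a module. The helper names version_tuple and has_pmap_function and the MIN_JAXLIB constant are hypothetical, and the version parser only handles plain numeric versions such as 0.1.72 (optionally followed by a +local tag like +cuda111).

```python
import importlib

# Hypothetical constant: first jaxlib release that provides PmapFunction.
MIN_JAXLIB = "0.1.72"


def version_tuple(version):
    """Turn '0.1.72' (or '0.1.72+cuda111') into (0, 1, 72) for comparison.

    Simplified: assumes dot-separated numeric parts and drops any '+local' tag.
    """
    return tuple(int(part) for part in version.split("+")[0].split("."))


def has_pmap_function():
    """Return True/False depending on whether PmapFunction exists,
    or None if jaxlib.xla_extension cannot be imported at all."""
    try:
        xla_extension = importlib.import_module("jaxlib.xla_extension")
    except ImportError:
        return None
    return hasattr(xla_extension, "PmapFunction")


if __name__ == "__main__":
    status = has_pmap_function()
    if status is None:
        print("jaxlib.xla_extension could not be imported (is jaxlib installed?)")
    elif status:
        print("PmapFunction is present; jaxlib is new enough")
    else:
        print("jaxlib is older than", MIN_JAXLIB, "- try: pip install -U jaxlib")
```

Checking for the attribute directly, rather than only comparing version strings, tests exactly the thing the AttributeError complains about.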
If this does not work, please check your Python version. jaxlib began requiring Python 3.7 or newer in version 0.1.70, so if you are using Python 3.6, you will need to upgrade Python before you can upgrade to a more recent jaxlib.
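A small sketch of that Python version check, assuming the 3.7 minimum stated above for jaxlib 0.1.70 and later; python_supports_new_jaxlib is a hypothetical helper name, not part of jax or jaxlib.

```python
import sys

# Minimum Python (major, minor) for jaxlib 0.1.70+, per the requirement stated in the answer.
MIN_PYTHON = (3, 7)


def python_supports_new_jaxlib(version_info=sys.version_info):
    """Return True if the given (or running) Python meets the minimum version."""
    return tuple(version_info[:2]) >= MIN_PYTHON


if __name__ == "__main__":
    if python_supports_new_jaxlib():
        print("Python", sys.version.split()[0], "can install a recent jaxlib")
    else:
        print("Upgrade Python to 3.7+ before upgrading jaxlib")
```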
It appears the problematic line was added to the haiku package a few hours before you posted the question: https://github.com/deepmind/dm-haiku/commit/e6a13af352a8b46d355ac1b7131b64c615cfcf57
If you don't want to update jaxlib, another option is to install a stable release of dm-haiku rather than the development version:
pip install dm-haiku==0.0.5
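After pinning, you can confirm which releases pip actually installed. This sketch uses importlib.metadata (Python 3.8+) and falls back to the importlib_metadata backport on Python 3.7, which is assumed to be installed separately (pip install importlib_metadata); installed_version is a hypothetical helper name.

```python
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:
    # Assumption: on Python 3.7 the importlib_metadata backport is installed.
    from importlib_metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Report the packages involved in this error.
    for dist in ("dm-haiku", "jax", "jaxlib"):
        print(dist, installed_version(dist) or "not installed")
```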
Upvotes: 3