dcxst

Reputation: 192

importing jax fails on mac with m1 chip

With Python 3.8.8 on the new MacBook Air (with the M1 chip), import jax raises the following error, both in Jupyter notebooks and in the Python terminal:

Python 3.8.8 (default, Apr 13 2021, 12:59:45)
[Clang 10.0.0 ] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/steve/Documents/code/jax/jax/__init__.py", line 37, in <module>
    from . import config as _config_module
  File "/Users/steve/Documents/code/jax/jax/config.py", line 18, in <module>
    from jax._src.config import config
  File "/Users/steve/Documents/code/jax/jax/_src/config.py", line 26, in <module>
    from jax import lib
  File "/Users/steve/Documents/code/jax/jax/lib/__init__.py", line 63, in <module>
    cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.

I suspect it is caused by the M1 chip.
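One way to test that suspicion (a diagnostic sketch, not something from the original thread): the AVX error on an M1 Mac often means the interpreter is an x86_64 build, such as an Intel Anaconda, running under Rosetta 2. Rosetta 2 did not translate AVX instructions, at least on the macOS releases of that era. The snippet reports the interpreter's architecture with platform.machine() and, on macOS, reads the sysctl.proc_translated key in-process through sysctlbyname. The helper names rosetta_translated and describe_interpreter are hypothetical.

```python
import ctypes
import platform
import sys


def rosetta_translated():
    """Return True under Rosetta 2, False if native, None if it cannot be determined."""
    if sys.platform != "darwin":
        return None  # only meaningful on macOS
    try:
        libsystem = ctypes.CDLL("/usr/lib/libSystem.B.dylib")
    except OSError:
        return None
    value = ctypes.c_int(0)
    size = ctypes.c_size_t(ctypes.sizeof(value))
    # sysctlbyname returns 0 on success; the key does not exist on Intel Macs,
    # in which case the call fails and we report None.
    ret = libsystem.sysctlbyname(
        b"sysctl.proc_translated",
        ctypes.byref(value),
        ctypes.byref(size),
        None,
        ctypes.c_size_t(0),
    )
    if ret != 0:
        return None
    return value.value == 1


def describe_interpreter():
    """Summarize which architecture this Python process runs as."""
    return {
        "machine": platform.machine(),  # 'arm64' when native on M1, 'x86_64' for an Intel build
        "python": platform.python_version(),
        "rosetta_translated": rosetta_translated(),
    }


if __name__ == "__main__":
    print(describe_interpreter())
```

If this prints machine 'x86_64' (or rosetta_translated True) on an M1 Mac, the jaxlib being loaded is presumably an x86_64 build executing under Rosetta, and a native arm64 Python environment would be the thing to try.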

I first installed jax with pip install jax. Then, as the error message suggests, I built jaxlib from source by cloning the repository and following the instructions given here, but the same error message appears.

Upvotes: 9

Views: 14404

Answers (4)

Nicholas G Reich

Reputation: 1138

I had a similar problem. Since I already had Anaconda installed and didn't want to clutter my machine with Anaconda + miniconda + Homebrew versions of Python and package managers, I hunted around for a simple solution. What ended up working for me was uninstalling jax and jaxlib and then reinstalling both directly from conda-forge:

pip uninstall jax jaxlib
conda install -c conda-forge jaxlib
conda install -c conda-forge jax

Upvotes: 10

emil

Reputation: 331

As of now (January 2022), jax is available for M1 Macs. Uninstall the old jax and jaxlib first, then install the new packages via pip:

pip install --upgrade jax jaxlib

Afterwards, you can use jax without problems.
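To confirm that the new install actually works, a quick smoke test like the following can help (a minimal sketch; the function name jax_smoke_test is hypothetical). It imports jax, jit-compiles a tiny function, and returns its result, or a short message if the import itself fails, for example with the AVX RuntimeError from the question.

```python
def jax_smoke_test():
    """Import jax, jit-compile a tiny function, and return its result."""
    try:
        import jax
        import jax.numpy as jnp
    except Exception as exc:  # e.g. ImportError, or the RuntimeError from the CPU feature check
        return f"import failed: {exc!r}"
    sum_of_squares = jax.jit(lambda x: jnp.sum(x * x))
    # 0**2 + 1**2 + 2**2 + 3**2 = 14
    return float(sum_of_squares(jnp.arange(4.0)))


if __name__ == "__main__":
    print(jax_smoke_test())
```

On a working install this should print 14.0.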

--Edit-- I am running on a machine with the following specs:

ProductName:    macOS
ProductVersion: 12.1
BuildVersion:   21C52

and with Python 3.9.6 within a conda environment.

Upvotes: 1

dcxst

Reputation: 192

Thanks @jakevdp, I looked at the issue you linked and found a workaround:

Thanks to Noah, who mentioned in issue #5501 that you can simply use an older version of jax and jaxlib: for my purposes, jaxlib==0.1.60 and jax==0.2.10 work just fine!
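If you pin these versions, for example with pip install jax==0.2.10 jaxlib==0.1.60, a small check like the one below can confirm that the environment really picked them up. This is an illustrative sketch: check_pins and PINS are hypothetical names, and it relies only on the standard library's importlib.metadata (Python 3.8+).

```python
from importlib.metadata import PackageNotFoundError, version

# The combination reported to work in this answer.
PINS = {"jax": "0.2.10", "jaxlib": "0.1.60"}


def check_pins(pins=PINS):
    """Return {package: (installed, wanted)} for every package that does not match its pin."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # not installed in this environment
        if installed != wanted:
            mismatches[name] = (installed, wanted)
    return mismatches


if __name__ == "__main__":
    print(check_pins() or "jax/jaxlib match the pinned versions")
```

An empty result means both packages are at the pinned versions.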

Upvotes: 5

jakevdp

Reputation: 86328

JAX does not yet provide pre-built jaxlib wheels that are compatible with M1 chips. The best source of information I know of on building jaxlib on M1 is probably this GitHub issue, which also tracks progress on improving this support: https://github.com/google/jax/issues/5501

Hopefully M1 support will be improved in the near future, but it's taking a while for the scientific computing infrastructure up and down the stack to catch up with the requirements of the new chips.

Upvotes: 2
