Dr. Wilson

module 'jax' has no attribute 'tree_multimap' in AlphaFold2 CoLab

I am attempting to model a protein using an AlphaFold2 (AlphaFold v2.1.0) Colab notebook (https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb#scrollTo=pc5-mbsX9PZC).

This worked on 9/2/2022. However, since 9/7/2022 I have repeatedly run into errors when modelling a different peptide sequence.

I get the following warning when I run the search against the genetic databases:

/opt/conda/lib/python3.7/site-packages/haiku/_src/data_structures.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
  PyTreeDef = type(jax.tree_structure(None))

When I then run AlphaFold2 itself, I get several more FutureWarnings about other deprecated jax.tree_* functions.

The AlphaFold run then fails, apparently because of this error:

AttributeError: module 'jax' has no attribute 'tree_multimap'

I have tried substituting jax.tree_util.tree_structure with no success.

I found a similar question on Stack Overflow (AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'), but I am not sure how to apply its solution in the Colab environment.

How should I fix this issue so that AlphaFold2 will run properly?

Traceback shown below:

     44     processed_feature_dict = model_runner.process_features(np_example, random_seed=0)
---> 45     prediction = model_runner.predict(processed_feature_dict, random_seed=random.randrange(sys.maxsize))

/opt/conda/lib/python3.7/site-packages/haiku/_src/stateful.py in difference(before, after)
    310   params_before, params_after = box_and_fill_missing(before.params,
    311                                                      after.params)
--> 312   params_after = jax.tree_multimap(functools.partial(if_changed, is_new_param),
    313                                    params_before, params_after)

Answers (1)

jakevdp

jax.tree_multimap was deprecated in JAX version 0.3.5, and removed in JAX version 0.3.16.
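To see which side of that boundary a given Colab runtime is on, you can print the installed JAX version and check whether the old name still resolves. This is a minimal sketch that only inspects the environment; it changes nothing.

```python
import jax

# Installed JAX version string, e.g. "0.3.15".
version = jax.__version__

# False on JAX releases where jax.tree_multimap has been removed, which is
# the situation that produces the AttributeError in the traceback above.
has_multimap = hasattr(jax, "tree_multimap")

print(f"jax {version}; jax.tree_multimap available: {has_multimap}")
if not has_multimap:
    print("tree_multimap is gone: use jax.tree_util.tree_map, or install an older JAX")
```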

You can either change the calling code (in this traceback, Haiku's stateful.py) to use jax.tree_map as a drop-in replacement for jax.tree_multimap, or install an older version of JAX, e.g.:

!pip install "jax<0.3.16" "jaxlib<0.3.16"

Then restart the runtime so that the new version is picked up.
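For the first option, the change is mechanical: tree_map accepts one or more trees with the same structure and calls the function with the corresponding leaves, which is exactly what tree_multimap did. The sketch below uses jax.tree_util.tree_map, the canonical location of the same function (the top-level jax.tree_map alias was itself deprecated in later JAX releases). The parameter dicts and `leaf_delta` are hypothetical stand-ins for Haiku's params and `functools.partial(if_changed, is_new_param)`. The final lines show a swapped-in stopgap, monkeypatching the old name back onto the `jax` module, which is only an assumption about what may work when the calling code cannot easily be edited.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical per-leaf function, standing in for the partial that Haiku
# passes to jax.tree_multimap in stateful.py.
def leaf_delta(scale, before, after):
    return scale * (after - before)


# Hypothetical parameter trees with identical structure.
params_before = {"linear": {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}}
params_after = {"linear": {"w": jnp.array([1.5, 2.0]), "b": jnp.array(0.5)}}

# Old call (removed in JAX 0.3.16):
#   jax.tree_multimap(fn, params_before, params_after)
# Replacement: tree_map calls fn(leaf_before, leaf_after) at each leaf position.
fn = functools.partial(leaf_delta, 2.0)
delta = jax.tree_util.tree_map(fn, params_before, params_after)
print(delta)

# Stopgap (monkeypatch, assumption): restore the old name as an alias before
# Haiku/AlphaFold look it up at call time.
if not hasattr(jax, "tree_multimap"):
    jax.tree_multimap = jax.tree_util.tree_map
```

The alias only helps code that looks up `jax.tree_multimap` after it has been set; editing the call site or pinning JAX remains the more predictable fix.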
