Reputation: 31
I am trying to model a protein with the AlphaFold2 (AlphaFold v2.1.0) Colab notebook (https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb#scrollTo=pc5-mbsX9PZC).
I did this successfully on 9/2/2022. However, since 9/7/2022 I have repeatedly run into problems when modelling a different peptide sequence.
I get the following warning when I run the search against the genetic databases:
/opt/conda/lib/python3.7/site-packages/haiku/_src/data_structures.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
PyTreeDef = type(jax.tree_structure(None))
When I then run AlphaFold2 itself, I get several more FutureWarnings about other deprecated jax.tree_* functions.
The actual failure seems to be caused by this error:
AttributeError: module 'jax' has no attribute 'tree_multimap'
I have tried substituting jax.tree_util.tree_structure, with no success.
I found a similar question on Stack Overflow (AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'), but I don't know how best to apply its solution in the Colab environment.
How should I fix this issue so that AlphaFold2 will run properly?
The relevant part of the traceback is shown below:
44 processed_feature_dict = model_runner.process_features(np_example, random_seed=0)
---> 45 prediction = model_runner.predict(processed_feature_dict, random_seed=random.randrange(sys.maxsize))
/opt/conda/lib/python3.7/site-packages/haiku/_src/stateful.py in difference(before, after)
310 params_before, params_after = box_and_fill_missing(before.params,
311 after.params)
--> 312 params_after = jax.tree_multimap(functools.partial(if_changed, is_new_param),
313 params_before, params_after)
Upvotes: 3
Views: 7769
Reputation: 86328
jax.tree_multimap was deprecated in JAX version 0.3.5 and removed in JAX version 0.3.16.
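To confirm that the Colab runtime is on a JAX release without the old alias, a quick check like the following can be run in a notebook cell before the prediction step. This is a minimal sketch, not part of the AlphaFold notebook; it only relies on `jax.__version__` and a plain attribute lookup.

```python
import jax

# Report the installed JAX version.
print("jax version:", jax.__version__)

# hasattr() returns False when the module raises AttributeError,
# which is exactly the error Haiku hits when it calls jax.tree_multimap.
has_multimap = hasattr(jax, "tree_multimap")
print("jax.tree_multimap available:", has_multimap)
```

If this prints `False`, the installed JAX no longer provides the alias, and the Haiku code in the traceback will fail as shown.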
You can either change the source to use jax.tree_map as a drop-in replacement for jax.tree_multimap, or install an older version of JAX, e.g.:
!pip install "jax<0.3.16" "jaxlib<0.3.16"
Then restart the Colab runtime so that the newly installed version is picked up.
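The drop-in replacement works because `tree_map` already accepts several pytrees with the same structure and applies the function leaf-wise across them, which is what `tree_multimap` did. The sketch below uses `jax.tree_util.tree_map` rather than the top-level `jax.tree_map`, since later JAX releases deprecated that alias too. The `before`/`after` dicts are hypothetical stand-ins for the `params_before`/`params_after` trees in Haiku's `stateful.py`. The final lines show an assumed runtime stopgap (not an official fix) for when you cannot edit the installed Haiku source: re-create the missing alias on the `jax` module before `model_runner.predict` is called.

```python
import jax
import jax.numpy as jnp

# Two pytrees with identical structure (hypothetical example parameters).
before = {"linear": {"w": jnp.ones(3), "b": jnp.zeros(3)}}
after = {"linear": {"w": jnp.full(3, 2.0), "b": jnp.zeros(3)}}

# Old call (removed from JAX): jax.tree_multimap(f, before, after)
# Replacement: tree_map applies f to matching leaves of all given trees.
diff = jax.tree_util.tree_map(lambda a, b: b - a, before, after)
print(diff["linear"]["w"].tolist())  # [1.0, 1.0, 1.0]

# Hypothetical stopgap: Haiku looks up jax.tree_multimap at call time,
# so defining the alias at runtime (no files are modified) lets that code run.
if not hasattr(jax, "tree_multimap"):
    jax.tree_multimap = jax.tree_util.tree_map
```

Editing the source or pinning JAX is the cleaner fix; the runtime alias only patches the current session and has to be re-run after every runtime restart.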
Upvotes: 3