junka

Reputation: 163

Stable diffusion: AttributeError: module 'jax.random' has no attribute 'KeyArray'

When I run the Stable Diffusion notebook on Colab (https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) without any modification, it fails on this line:

from diffusers import StableDiffusionPipeline

The error log is

AttributeError: module 'jax.random' has no attribute 'KeyArray'

How can I fix this? Any clues?

Expected behavior: the import should succeed and the notebook should run without errors.

Upvotes: 7

Views: 14800

Answers (3)

dbenton

Reputation: 313

# Change this
# !pip install diffusers==0.11.1

# To just
!pip install diffusers 

If you've already run pip install in your Colab runtime, you'll need to either disconnect and open a new runtime (my recommendation) or use --upgrade.

Diffusers v0.11.1 is now over 18 months old, and the notebook works with the current v0.29.0 without any other changes. Instead of using an old version of diffusers, which requires an old version of JAX, we can use the latest versions.
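To confirm which versions pip actually resolved after starting a fresh runtime, a small check like the one below can be run in a new cell. This is a minimal sketch using only the standard library's importlib.metadata; the helper name installed_versions is hypothetical, and it assumes the packages are installed under the distribution names diffusers, jax and jaxlib.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the current runtime.
            result[name] = None
    return result


if __name__ == "__main__":
    # Print what pip resolved for the packages involved in this error.
    for name, ver in installed_versions(["diffusers", "jax", "jaxlib"]).items():
        print(f"{name}: {ver or 'not installed'}")
```

A value of None means the package is not installed in the current runtime.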

Upvotes: 0

jakevdp

Reputation: 86443

jax.random.KeyArray was deprecated in JAX v0.4.16 and removed in JAX v0.4.24. Given this, it sounds like the HuggingFace stable diffusion code only works with JAX v0.4.23 or earlier.
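The version window above can be expressed as a small check. This is a hypothetical helper (keyarray_status is not part of JAX or diffusers) that only compares version strings against the 0.4.16 and 0.4.24 thresholds stated in this answer; it does not import JAX or inspect the jax.random module itself.

```python
import re


def _version_tuple(ver):
    """Turn '0.4.23' or '0.4.24.dev20240201' into (0, 4, 23) or (0, 4, 24)."""
    parts = []
    for piece in ver.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def keyarray_status(jax_version):
    """Classify jax.random.KeyArray by JAX version, using the thresholds above."""
    v = _version_tuple(jax_version)
    if v < (0, 4, 16):
        return "available"
    if v < (0, 4, 24):
        return "deprecated"  # deprecated but still present
    return "removed"
```

In a notebook you could pass it jax.__version__; a result of "removed" means code that still references jax.random.KeyArray will raise the AttributeError from the question.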

You can install JAX v0.4.23 with GPU support like this:

pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

or, if you prefer targeting a local CUDA installation, like this:

pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

For more information on GPU installation, see JAX Installation: NVIDIA GPU.

In the Colab tutorial, change the second cell to:

!pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install diffusers==0.11.1
!pip install transformers scipy ftfy accelerate

Upvotes: 10

junka

Reputation: 163

In the end, I had to downgrade JAX. I tried each version from the latest to earlier ones, and luckily it works with:

jax==0.4.23 jaxlib==0.4.23

Upvotes: 0
