f. c.

Reputation: 1135

Why does jax.numpy.dot() run slower than numpy.dot() on CPU?

I want to use JAX to accelerate my numpy code on CPU, later on GPU. Here is my example code running on my local computer (only CPU):

import jax.numpy as jnp
from jax import random, jit
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

I got a warning and the time costs are as follows:

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

Did I do anything wrong here? Should JAX also accelerate numpy code on CPUs?

Upvotes: 3

Views: 2323

Answers (1)

jkr

Reputation: 19260

There are probably performance differences between JAX and NumPy, but in the original post the time difference mostly comes down to a mistake in the array creation. The array used by JAX has shape 3000x3000, whereas the array used by NumPy is a 1D array of length 2. The first positional argument to numpy.random.normal is loc (i.e., the mean of the Gaussian to sample from), so np.random.normal((size, size)) draws two samples, each with mean 3000. The shape of the output array has to be passed through the size= keyword argument.

numpy.random.normal(loc=0.0, scale=1.0, size=None)
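A quick way to see the mismatch is to print the shapes returned by both calls. This is a small NumPy-only check; the variable names are just for illustration:

```python
import numpy as np

size = 3000

# Call from the question: (size, size) is taken positionally as `loc`,
# so NumPy draws one sample per entry of loc, i.e. two samples with mean 3000.
buggy = np.random.normal((size, size))
print(buggy.shape)  # (2,)

# Intended call: the output shape goes through the `size` keyword.
fixed = np.random.normal(size=(size, size))
print(fixed.shape)  # (3000, 3000)
```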

Once this change is made, the performance gap between JAX and NumPy shrinks considerably.

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

The output of one run is

Time of jnp: 2.3315 s
Time of np: 2.8811 s

When measuring timed performance, one should collect multiple runs because a function's performance is a spread of times instead of a single value. This can be done with the Python standard library timeit.timeit function or the %timeit magic in IPython and Jupyter Notebook.
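Outside IPython, where %timeit is not available, the same idea can be expressed with timeit.repeat from the standard library. This is a sketch under a few assumptions of my own: a smaller 1000x1000 matrix so it runs quickly, the same float32 data for both libraries so the same dtype is compared, and one untimed warm-up call so the one-time compilation on the first jnp.dot call is not counted. It reports the mean and standard deviation of 7 single-call runs, mirroring the %timeit output format.

```python
import timeit
from statistics import mean, stdev

import jax
import jax.numpy as jnp
import numpy as np

size = 1000  # smaller than 3000 so the sketch finishes quickly

# Same float32 values for both libraries, so the same dtype is compared.
key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, (size, size), dtype=jnp.float32)
xnp = np.asarray(xjnp)  # NumPy copy of the same values (float32)

def jnp_dot():
    # JAX dispatches work asynchronously; block_until_ready() waits for the
    # result so the timer measures the computation itself.
    return jnp.dot(xjnp, xjnp.T).block_until_ready()

def np_dot():
    return np.dot(xnp, xnp.T)

# Untimed warm-up: the first jnp.dot for a given shape/dtype triggers compilation.
jnp_dot()

# 7 runs of 1 call each, like "7 runs, 1 loop each" in the %timeit output.
jnp_times = timeit.repeat(jnp_dot, repeat=7, number=1)
np_times = timeit.repeat(np_dot, repeat=7, number=1)

print(f"jnp: {mean(jnp_times):.4f} s ± {stdev(jnp_times):.4f} s per loop")
print(f"np:  {mean(np_times):.4f} s ± {stdev(np_times):.4f} s per loop")
```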

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It seems that NumPy has an optimized dot operation for 32-bit floats, while the JAX timings are essentially the same for both dtypes.
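One caveat worth checking: by default JAX runs in 32-bit mode, and requesting dtype=jnp.float64 then gives a float32 array along with a warning. If that happened here, both "float64" JAX timings may really be float32 timings, which could explain why they are nearly identical. A minimal sketch, assuming the jax_enable_x64 flag is set at startup before any arrays are created, to confirm which dtypes are actually being compared:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable 64-bit mode before creating any arrays; without this, JAX
# truncates float64 requests to float32 (with a warning).
jax.config.update("jax_enable_x64", True)

size = 1000  # smaller than 3000 to keep the check quick

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, (size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size))

print(xjnp.dtype)  # float64 here; float32 if x64 mode is left disabled
print(xnp.dtype)   # float64 (NumPy's default)
```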

Upvotes: 3
