Reputation: 6194
I am trying to solve a diffusion kernel with JAX; this is my JAX port of existing GPU CUDA code. JAX gives the correct answer, but it is about 5x slower than CUDA. How can I speed this up further? I am not sure whether my implementation of the diff function below is the best approach; I tried to use the same formulation as the equivalent C++ code.
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from functools import partial
from timeit import default_timer as timer
# NumPy-style stencil operation: add visc * Laplacian(a) to the interior of `at`.
@partial(jit, static_argnums=(6, 7, 8))
def diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot):
    # Slices selecting the interior cells and their six neighbours
    # (west/east, south/north, bottom/top).
    i_c = jnp.s_[1:ktot-1, 1:jtot-1, 1:itot-1]
    i_w = jnp.s_[1:ktot-1, 1:jtot-1, 0:itot-2]
    i_e = jnp.s_[1:ktot-1, 1:jtot-1, 2:itot  ]
    i_s = jnp.s_[1:ktot-1, 0:jtot-2, 1:itot-1]
    i_n = jnp.s_[1:ktot-1, 2:jtot  , 1:itot-1]
    i_b = jnp.s_[0:ktot-2, 1:jtot-1, 1:itot-1]
    i_t = jnp.s_[2:ktot  , 1:jtot-1, 1:itot-1]

    # Second-order 7-point stencil applied to the interior points only.
    at_new = at.at[i_c].add(
        visc * (
            + ( (a[i_e] - a[i_c])
              - (a[i_c] - a[i_w]) ) * dxidxi
            + ( (a[i_n] - a[i_c])
              - (a[i_c] - a[i_s]) ) * dyidyi
            + ( (a[i_t] - a[i_c])
              - (a[i_c] - a[i_b]) ) * dzidzi
            )
        )

    return at_new
itot = 384
jtot = 384
ktot = 384
float_type = jnp.float32
nloop = 30
ncells = itot*jtot*ktot
dxidxi = float_type(0.1)
dyidyi = float_type(0.1)
dzidzi = float_type(0.1)
visc = float_type(0.1)
@jit
def init_a(index):
    # Smooth initial field based on the flattened cell index.
    return (index/(index+1))**2
## FIRST EXPERIMENT.
at = jnp.zeros((ktot, jtot, itot), dtype=float_type)
index = jnp.arange(ncells, dtype=float_type)
a = init_a(index)
del index
a = a.reshape(ktot, jtot, itot)

# Warm-up call so that JIT compilation happens outside the timed loop.
at = diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot).block_until_ready()
print("(first check) at={0}".format(at.flatten()[itot*jtot+itot+itot//2]))
# Time the loop
start = timer()
for i in range(nloop):
    at = diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot).block_until_ready()
end = timer()
print("Time/iter: {0} s ({1} iters)".format((end-start)/nloop, nloop))
print("at={0}".format(at.flatten()[itot*jtot+itot+itot//4]))
Upvotes: 0
Views: 87