Chiel
Chiel

Reputation: 6194

JAX 3d convolution kernel speedup

I am trying to implement a diffusion kernel in JAX; this is my JAX port of existing CUDA GPU code. JAX gives the correct answer, but it is about 5x slower than CUDA. How can I speed this up further? I am not sure whether my implementation of the `diff` function is optimal; I tried to keep the same formulation as the equivalent C++ code.

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from functools import partial

from timeit import default_timer as timer


# Pure array (NumPy-style) stencil: no explicit loops over cells.
@partial(jit, static_argnums=(6, 7, 8))
def diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot):
    """Add the viscous diffusion tendency of ``a`` to ``at`` on interior cells.

    For every cell that is not on an outer face of the grid this adds::

        visc * (d2i(a) * dxidxi + d2j(a) * dyidyi + d2k(a) * dzidzi)

    where ``d2i(a) = (a[i+1] - a[i]) - (a[i] - a[i-1])`` along the last axis,
    and ``d2j`` / ``d2k`` are the same second difference along the middle /
    first axis.

    Args:
        at: Tendency array indexed ``[k, j, i]``. It is not modified in place
            (JAX arrays are immutable); an updated copy is returned.
        a: Field being diffused, indexed ``[k, j, i]`` (assumed to have the
            same shape as ``at``).
        visc: Viscosity multiplying the whole tendency.
        dxidxi, dyidyi, dzidzi: Coefficients applied to the second
            differences along i, j and k respectively.
        itot, jtot, ktot: Extents along i, j and k. They are static jit
            arguments, so changing them triggers recompilation.

    Returns:
        ``at`` plus the tendency on the interior, unchanged on the faces.
    """

    def shifted(dk, dj, di):
        # Interior window of `a`, offset by (dk, dj, di) cells.
        return a[1 + dk:ktot - 1 + dk,
                 1 + dj:jtot - 1 + dj,
                 1 + di:itot - 1 + di]

    centre = shifted(0, 0, 0)

    def second_diff(lower, upper):
        # Written as a difference of one-sided differences to mirror the
        # original C++/CUDA formulation term for term.
        return (upper - centre) - (centre - lower)

    tendency = visc * (
        second_diff(shifted(0, 0, -1), shifted(0, 0, 1)) * dxidxi
        + second_diff(shifted(0, -1, 0), shifted(0, 1, 0)) * dyidyi
        + second_diff(shifted(-1, 0, 0), shifted(1, 0, 0)) * dzidzi
    )

    return at.at[1:ktot - 1, 1:jtot - 1, 1:itot - 1].add(tendency)


# Grid extents: number of cells along i (x), j (y) and k (z). Arrays are laid
# out as [k, j, i].
itot = 384
jtot = 384
ktot = 384

# Precision used for every array and scalar in the benchmark.
float_type = jnp.float32

nloop = 30  # number of timed diff() iterations
ncells = itot * jtot * ktot  # total number of grid cells

# Coefficients that diff() multiplies the second differences along x, y and z
# by (presumably 1/dx**2, 1/dy**2, 1/dz**2 per the naming), and the viscosity.
dxidxi = float_type(0.1)
dyidyi = float_type(0.1)
dzidzi = float_type(0.1)
visc = float_type(0.1)

@jit
def init_a(index):
    """Return ``(x / (x + 1)) ** 2`` for every element ``x`` of *index*.

    Used to build a smooth, deterministic initial field from a flat index.
    """
    ratio = index / (index + 1)
    return ratio ** 2


## FIRST EXPERIMENT.
at = jnp.zeros((ktot, jtot, itot), dtype=float_type)
# NOTE(review): float32 represents integers exactly only up to 2**24, so for
# ncells = 384**3 the larger index values are rounded. Confirm this matches
# the initialisation in the CUDA reference.
index = jnp.arange(ncells, dtype=float_type)
a = init_a(index)
del index  # drop the reference to the large temporary
a = a.reshape(ktot, jtot, itot)

# Warm-up call: triggers JIT compilation so it is excluded from the timing.
at = diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot).block_until_ready()
# Index the 3-D array directly rather than via at.flatten()[...]: flattening
# creates a full ncells-element array just to read one value. Cell
# (1, 1, itot//2) is flat index itot*jtot + itot + itot//2 in [k, j, i] layout.
print(f"(first check) at={at[1, 1, itot // 2]}")

# Time the loop. JAX dispatches work asynchronously, so block only once after
# the last iteration; blocking inside the loop adds a host/device round trip
# per step and under-reports throughput.
start = timer()
for _ in range(nloop):
    at = diff(at, a, visc, dxidxi, dyidyi, dzidzi, itot, jtot, ktot)
at.block_until_ready()
end = timer()

print(f"Time/iter: {(end - start) / nloop} s ({nloop} iters)")
# Cell (1, 1, itot//4) is flat index itot*jtot + itot + itot//4.
print(f"at={at[1, 1, itot // 4]}")

Upvotes: 0

Views: 87

Answers (0)

Related Questions