Reputation: 133
I have the following function, which processes an image:
import numpy as np
import jax.numpy as jnp

# img       : RGB image of shape (512, 512, 3)
# kernel    : 5x5 filtering kernel
# dilation  : can take integer values 1, 2, 4, ...
# some_data : no conditions depend on the value of this argument, and it
#             remains constant across multiple invocations of this function
def filtering(img, kernel, dilation, some_data):
    h, w, _ = img.shape
    filtered_img = jnp.zeros(img.shape)
    radius = 2
    for i in range(radius, h - radius):
        for j in range(radius, w - radius):
            center_pos = np.array([i, j])
            sum = jnp.array([0.0, 0.0, 0.0])
            sum_w = 0.0
            for ii in range(-radius, radius + 1):
                for jj in range(-radius, radius + 1):
                    pos = center_pos + dilation * np.array([ii, jj])
                    # if not for the `compute_weight` function, this could
                    # have been a dilated convolution
                    weight = kernel[ii + radius, jj + radius] * compute_weight(center_pos, pos, some_data)
                    sum += img[pos[0], pos[1], :] * weight
                    sum_w += weight
            filtered_img = filtered_img.at[i, j].set(sum / sum_w)
    return filtered_img
The first function call (jit-compiled) takes approx. 6 hours to run (tried on both GPU and CPU). Since it is jit-compiled, subsequent runs may be faster, but the first run is prohibitively expensive.
I tried removing the compute_weight function and replacing the two innermost nested loops with jnp.sum(img[i-radius:i+radius+1, j-radius:j+radius+1] * kernel, axis=(0, 1)), and the first function invocation still takes around 30 minutes to run. Based on this observation and some of the other questions on SO, this seems to be due to the for loops in general.
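In other words, something like this (a sketch of what I tried; kernel[:, :, None] is needed so the 5x5 weights broadcast over the three channels):

    for i in range(radius, h - radius):
        for j in range(radius, w - radius):
            patch = img[i - radius:i + radius + 1, j - radius:j + radius + 1]
            # kernel[:, :, None] broadcasts the 5x5 weights over RGB
            filtered_img = filtered_img.at[i, j].set(
                jnp.sum(patch * kernel[:, :, None], axis=(0, 1)))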
Will rewriting this in a more functional way and using JAX constructs like its loop primitives help, or is this happening due to some other issue?
Upvotes: 2
Views: 1028
Reputation: 86443
The issue here is not execution time, the issue is compilation time. JAX's JIT compilation will flatten all Python control flow: what this means is that for your input, you are generating 512 * 512 * 5 * 5 copies of the jaxpr for the inner loop, and sending them to XLA for compilation. Since compilation costs scale as roughly the square of the length of the program, the result will be an extremely long compilation.
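You can see this flattening directly with jax.make_jaxpr on a toy function (my own illustration, not your code):

    import jax
    import jax.numpy as jnp

    def unrolled(x):
        # JIT tracing flattens this Python loop: each iteration becomes
        # a separate operation in the traced program.
        for i in range(5):
            x = x + i
        return x

    # Prints a jaxpr containing five separate adds, one per iteration;
    # with 512 * 512 * 5 * 5 iterations, the traced program (and hence
    # compilation time) explodes.
    print(jax.make_jaxpr(unrolled)(jnp.array(0.0)))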
Your best bet here is probably to rewrite this in terms of jax.lax.fori_loop, which will lower the loop logic directly to XLA without the large compilation cost.
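For example, the innermost 5x5 accumulation might look something like this (a rough sketch rather than a drop-in replacement: pixel_value and the flat-index decoding are my own construction, and it assumes compute_weight is itself traceable by JAX):

    import jax
    import jax.numpy as jnp

    def pixel_value(img, kernel, dilation, some_data, i, j, radius=2):
        # Accumulate all (2 * radius + 1)**2 taps with one fori_loop
        # instead of two Python loops, so the loop is lowered to XLA
        # rather than unrolled at trace time.
        center_pos = jnp.array([i, j])
        n = 2 * radius + 1

        def body(k, carry):
            acc, acc_w = carry
            ii = k // n - radius  # recover the 2D tap offset
            jj = k % n - radius   # from the flat loop index k
            pos = center_pos + dilation * jnp.array([ii, jj])
            w = kernel[ii + radius, jj + radius] * compute_weight(center_pos, pos, some_data)
            return acc + img[pos[0], pos[1], :] * w, acc_w + w

        acc, acc_w = jax.lax.fori_loop(0, n * n, body, (jnp.zeros(3), jnp.zeros(())))
        return acc / acc_w

The outer pair of pixel loops can then be replaced by jax.vmap over the row and column indices, so no Python loop remains at trace time.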
Even better, since it looks like what you're doing is some flavor of convolution, would be to express this in terms of something like jax.scipy.signal.convolve2d, which will be far faster than doing the looping manually.
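A sketch of what that might look like for the variant without compute_weight (filter_channels is a hypothetical helper, and it assumes the normalization reduces to dividing by the kernel's sum):

    import jax.numpy as jnp
    from jax.scipy.signal import convolve2d

    def filter_channels(img, kernel):
        # Per-channel 2D convolution; mode='same' keeps the (512, 512)
        # spatial shape. convolve2d flips the kernel, which is a no-op
        # for a symmetric kernel; dilation > 1 would need zeros inserted
        # between the kernel taps first.
        out = jnp.stack(
            [convolve2d(img[:, :, c], kernel, mode='same') for c in range(3)],
            axis=-1)
        return out / jnp.sum(kernel)  # matches the sum_w normalization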
Upvotes: 2