WillyWonka

Reputation: 133

Nested for loops in jax

I have the following function, which processes an image:


# img : RGB image (512, 512, 3)
# kernel : 5x5 filtering kernel 
# dilation : can take integer values 1, 2, 4, ...
# some_data : no conditions depend on the value of this argument, and it
#             remains constant across multiple invocations of this function
def filtering(img, kernel, dilation, some_data):
    h, w, _ = img.shape
    filtered_img = jnp.zeros(img.shape)

    radius = 2
    for i in range(radius, h-radius):
        for j in range(radius, w-radius):

            center_pos = np.array([i, j])
            sum = jnp.array([0.0, 0.0, 0.0])
            sum_w = 0.0
            
            for ii in range(-radius, radius + 1):
                for jj in range(-radius, radius + 1):

                    pos = center_pos + dilation * np.array([ii, jj])
                    
                    # if not for the `compute_weight` function this could have been a dilated convolution
                    weight = kernel[ii + radius, jj + radius] * compute_weight(center_pos, pos, some_data)
                    sum += img[pos[0], pos[1], :] * weight
                    sum_w += weight

            filtered_img = filtered_img.at[i, j].set(sum/sum_w)

    return filtered_img

The first (jit-compiled) function call takes approx. 6 hours to run (tried on both GPU and CPU). Since it is jit-compiled, subsequent runs may be faster, but the first run is prohibitively expensive.

I tried removing the compute_weight function and replacing the two innermost nested loops with jnp.sum(img[i-radius:i+radius+1, j-radius:j+radius+1] * filter, axis=(0, 1)), and the first function invocation still takes around 30 minutes to run. Based on this observation and some other questions on SO, this seems to be caused by the Python for loops in general.

Will rewriting this in a more functional way, using jax constructs like loops, help, or is this happening due to some other issue?

Upvotes: 2

Views: 1028

Answers (1)

jakevdp

Reputation: 86443

The issue here is not execution time; the issue is compilation time. JAX's JIT compilation will flatten all Python control flow: what this means is that for your input, you are generating 512 * 512 * 5 * 5 copies of the jaxpr for the inner loop and sending them to XLA for compilation. Since compilation costs scale as roughly the square of the length of the program, the result is an extremely long compilation.
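To see the flattening concretely, you can trace a small Python loop with jax.make_jaxpr and count the primitives it emits; one add operation appears per iteration, which is why a 512 * 512 * 5 * 5 loop nest produces an enormous program:

```python
import jax
import jax.numpy as jnp

def unrolled(x):
    # A plain Python loop: JAX traces the body once per iteration.
    for _ in range(3):
        x = x + 1
    return x

jaxpr = jax.make_jaxpr(unrolled)(jnp.float32(0.0))
# The loop is gone at trace time; the jaxpr contains three
# separate `add` equations, one per iteration.
n_adds = sum(1 for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "add")
```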

Your best bet here is probably to rewrite this in terms of jax.lax.fori_loop, which will lower the loop logic directly to XLA without the large compilation cost.
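A minimal sketch of that rewrite, assuming compute_weight can be vectorized over the 5x5 window (the compute_weight below is a hypothetical Gaussian stand-in, not your real function). The pixel loop is driven by jax.lax.fori_loop, so the body is traced only once:

```python
import jax
import jax.numpy as jnp

def compute_weight(center, pos, some_data):
    # Hypothetical stand-in: Gaussian falloff on spatial distance.
    # Substitute your real, vectorized weight function here.
    return jnp.exp(-jnp.sum((pos - center) ** 2, axis=-1) / (2.0 * some_data))

def filtering(img, kernel, dilation, some_data):
    h, w, _ = img.shape
    radius = kernel.shape[0] // 2

    # All (2r+1)^2 tap offsets at once, scaled by the dilation factor.
    offs = jnp.arange(-radius, radius + 1)
    di, dj = jnp.meshgrid(offs, offs, indexing="ij")     # (5, 5) each

    def one_pixel(i, j):
        rows = i + dilation * di                          # (5, 5)
        cols = j + dilation * dj
        patch = img[rows, cols]                           # (5, 5, 3)
        pos = jnp.stack([rows, cols], axis=-1)            # (5, 5, 2)
        weight = kernel * compute_weight(jnp.array([i, j]), pos, some_data)
        return (patch * weight[..., None]).sum((0, 1)) / weight.sum()

    inner_w = w - 2 * radius

    def body(k, out):
        # Map the flat loop counter back to an interior (i, j) pixel.
        i = radius + k // inner_w
        j = radius + k % inner_w
        return out.at[i, j].set(one_pixel(i, j))

    n = (h - 2 * radius) * inner_w
    return jax.lax.fori_loop(0, n, body, jnp.zeros(img.shape))

img = jnp.ones((32, 32, 3))
kernel = jnp.ones((5, 5)) / 25.0
out = jax.jit(filtering)(img, kernel, 1, 10.0)
```

For a uniform image the weighted average of each interior window is the pixel value itself, which is an easy sanity check; border pixels stay zero, as in the original code.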

Even better, since it looks like what you're doing is some flavor of a convolution, would be to express this in terms of something like jax.scipy.signal.convolve2d, which will be far faster than doing the looping manually.
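If compute_weight really can be dropped, a sketch of the convolution route might look like the following: the 5x5 kernel is dilated with zeros so it covers the same footprint as the dilated loop, and each channel is filtered separately since convolve2d operates on 2D arrays. (Note convolve2d flips the kernel; for symmetric kernels this makes no difference.)

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def dilate_kernel(kernel, dilation):
    # Insert (dilation - 1) zeros between taps so the 5x5 kernel
    # matches the footprint of the dilated loop version.
    k = kernel.shape[0]
    size = (k - 1) * dilation + 1
    out = jnp.zeros((size, size))
    return out.at[::dilation, ::dilation].set(kernel)

def filtering(img, kernel, dilation):
    dk = dilate_kernel(kernel, dilation)
    # convolve2d takes 2D inputs, so filter each RGB channel separately.
    chans = [convolve2d(img[..., c], dk, mode="same") for c in range(img.shape[-1])]
    return jnp.stack(chans, axis=-1) / dk.sum()

img = jnp.ones((32, 32, 3))
kernel = jnp.ones((5, 5)) / 25.0
# The dilated kernel's shape depends on `dilation`, so mark it static.
out = jax.jit(filtering, static_argnums=2)(img, kernel, 2)
```

Unlike the original code, mode="same" zero-pads at the borders rather than leaving them untouched, so border behavior differs slightly.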

Upvotes: 2
