Zorobay

Problems with Jax's JIT and Numpy restrictions

I've recently started experimenting with the Python library JAX, which provides an accelerated NumPy as well as automatic differentiation. What I want to build is a crude "differentiable renderer": I write a shader and a loss function in Python, then use JAX's AD to find the gradient of the loss. I should then be able to inverse render an image by running gradient descent on that gradient. This works fairly well with simple shaders, but I have run into problems when I use boolean expressions. This is the code of my shader, which generates a checkerboard pattern:

import jax.numpy as np

class CheckerShader:

    def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
        self.scale_min = 0
        self.scale_max = 20
        self.color1 = color1
        self.color2 = color2
        self.scale = scale * 20

    def checker(self, x: float, y: float) -> float:
        xi = np.abs(np.floor(x))
        yi = np.abs(np.floor(y))

        first_col = np.mod(xi, 2) == np.mod(yi, 2)
        return first_col

    def shade(self, x: float, y: float):
        x = x * self.scale
        y = y * self.scale

        first_col = self.checker(x, y)
        if first_col:
            return self.color1
        else:
            return self.color2

And this is my render function, which is the first place that JIT fails:

import jax.numpy as np
import numpy as onp
import jax

def render(scale, c1, c2):
    img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
    sh = CheckerShader(scale, c1, c2)
    jit_func = jax.jit(sh.shade)

    for y in range(HEIGHT):
        for x in range(WIDTH):
            val = jit_func(x / WIDTH, y / HEIGHT)
            img[y, x, :] = val

    return img

The error message I receive is:

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

I guess this is because JIT can't compile a function that branches on a Python boolean whose value depends on the inputs, since those values aren't known at compile time. But how can I rewrite it to work with JIT? Without JIT, it is painfully slow.
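A minimal reproduction of this error, together with the two workarounds the message alludes to, might look like the following sketch (the function name `pick` is hypothetical). Note that `static_argnums` only helps when the condition comes from an argument that takes few distinct values; in the shader, `first_col` is computed from `x` and `y`, so making it static would mean recompiling for every pixel, and `jnp.where` is the better fit.

```python
import jax
import jax.numpy as jnp

def pick(flag, a, b):
    # A Python `if` needs a concrete bool; under jit, `flag` is an abstract tracer.
    if flag:
        return a
    return b

a = jnp.array([1.0, 0.0, 0.0])
b = jnp.array([0.0, 0.0, 1.0])

try:
    jax.jit(pick)(True, a, b)
    raised = False
except TypeError:  # JAX raises a TypeError subclass for bool() on a tracer
    raised = True

# Workaround 1: mark `flag` as static. It must be a hashable Python value, and
# jit compiles a separate version for each distinct value it sees.
pick_static = jax.jit(pick, static_argnums=0)

# Workaround 2: keep `flag` traced and select with jnp.where inside the compiled code.
pick_where = jax.jit(lambda flag, a, b: jnp.where(flag, a, b))
```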

Another question: is there anything I can do to speed up JAX's NumPy in general? Rendering my image (100x100 pixels) with normal NumPy takes a few milliseconds, but with JAX's NumPy it takes seconds! Thanks :D

Answers (2)

Nick McGreivy

But how can I rewrite it to work with JIT?

Ivo has a nice answer here: simply use np.where.

Another question I have is, is there something I can do to speed up Jax's Numpy in general?

There are probably three reasons why this is slow.

The first is the nature of JIT compilation. The first call to a jitted function is slow because JAX traces it and compiles it with XLA; later calls with arguments of the same shapes and dtypes reuse the compiled version and should be much faster. If you plan to run this multiple times, I would also try to JIT the entire render function if possible.
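A small sketch of the compile-once behavior, using a hypothetical `brightness` function and a global counter as an assumption for illustration: Python side effects run only while JAX traces a function, so the counter shows exactly when a trace-and-compile happens.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: how many times JAX traces the function

@jax.jit
def brightness(pixels):
    global trace_count
    # This line runs only during tracing, not on cached calls.
    trace_count += 1
    return jnp.mean(jnp.sin(pixels) ** 2)

x = jnp.linspace(0.0, 1.0, 100)
brightness(x).block_until_ready()             # first call: trace + compile (slow)
brightness(x + 1.0).block_until_ready()       # same shape/dtype: cached, no retrace
brightness(jnp.ones(50)).block_until_ready()  # new shape: traced and compiled again
```

When timing JAX code, calling `.block_until_ready()` matters because JAX dispatches work asynchronously; without it you may only be measuring the dispatch.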

The second reason is that converting between NumPy and JAX arrays is slow: every img[y, x, :] = val has to wait for the JAX result and copy it into the NumPy array. You write

img = onp.zeros((WIDTH, HEIGHT, CHANNELS))

but it'll be much faster if you write

img = np.zeros((WIDTH, HEIGHT, CHANNELS))
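One caveat when following this suggestion: JAX arrays are immutable, so the loop's `img[y, x, :] = val` raises a `TypeError` once `img` is a `jax.numpy` array. A minimal sketch of the functional update JAX provides instead (the small sizes are hypothetical) follows. Outside `jit`, each `.at[...].set(...)` builds a new array, so an un-jitted per-pixel loop would copy the whole image on every iteration; inside `jit`, XLA can perform the update in place when the old array is not reused.

```python
import jax.numpy as jnp

HEIGHT, WIDTH, CHANNELS = 3, 4, 3  # hypothetical small sizes for the sketch

img = jnp.zeros((HEIGHT, WIDTH, CHANNELS))  # rows (y) first, matching img[y, x, :]
val = jnp.array([1.0, 0.5, 0.0])

try:
    img[1, 2, :] = val  # NumPy-style in-place assignment
    inplace_ok = True
except TypeError:       # JAX arrays reject item assignment
    inplace_ok = False

# .at[...].set(...) returns a new array with the update applied; rebind img to it.
img = img.at[1, 2, :].set(val)
```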

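The third is that you are looping over the width and height rather than using vectorized operations. Each call to jit_func carries a fixed dispatch overhead, and a 100x100 image means 10,000 calls. I don't see why you can't do this in fully vectorized form.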
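A fully vectorized version might look like the sketch below. It assumes a 100x100 RGB image, takes `scale` as the already-multiplied value (`self.scale` in `CheckerShader`), and evaluates the checker test for all pixels at once, so the whole image is a single jitted call. The loss and target at the end are hypothetical and only illustrate the inverse-rendering idea from the question. Note that the gradient with respect to `scale` would be zero almost everywhere, because `floor` and the boolean mask are piecewise constant; here only the colors receive useful gradients.

```python
import jax
import jax.numpy as jnp

WIDTH, HEIGHT = 100, 100  # assumed image size, matching the question's 100x100 render

@jax.jit
def render(scale, color1, color2):
    # Normalized pixel coordinates (the loop's x / WIDTH and y / HEIGHT),
    # laid out as (HEIGHT, WIDTH) grids so that img[y, x] indexing still holds.
    ys, xs = jnp.meshgrid(jnp.arange(HEIGHT) / HEIGHT,
                          jnp.arange(WIDTH) / WIDTH, indexing="ij")
    # The CheckerShader test, evaluated for every pixel at once.
    xi = jnp.abs(jnp.floor(xs * scale))
    yi = jnp.abs(jnp.floor(ys * scale))
    first_col = jnp.mod(xi, 2) == jnp.mod(yi, 2)  # (HEIGHT, WIDTH) booleans
    # Broadcast the mask against the (3,) colors to get a (HEIGHT, WIDTH, 3) image.
    return jnp.where(first_col[..., None], color1, color2)

white = jnp.array([1.0, 1.0, 1.0])
black = jnp.array([0.0, 0.0, 0.0])
img = render(2.0, white, black)  # scale 2.0 gives a 2x2 checkerboard

# Hypothetical inverse-rendering step: a loss comparing against a target image.
target = render(2.0, jnp.array([0.8, 0.2, 0.2]), black)

def loss(color1):
    return jnp.mean((render(2.0, color1, black) - target) ** 2)

grad_c1 = jax.grad(loss)(white)  # points from the target color toward white
```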

Ivo Danihelka

Replace

if first_col:
    return self.color1
else:
    return self.color2

with

return np.where(first_col, self.color1, self.color2)
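Applied to the question's shader, a sketch might look like this. It imports `jax.numpy` as `jnp` rather than `np` to avoid confusion with plain NumPy and otherwise follows the question's class. With the Python branch gone, `sh.shade` itself can be jitted.

```python
import jax
import jax.numpy as jnp

class CheckerShader:
    def __init__(self, scale: float, color1: jnp.ndarray, color2: jnp.ndarray):
        self.color1 = color1
        self.color2 = color2
        self.scale = scale * 20

    def checker(self, x, y):
        xi = jnp.abs(jnp.floor(x))
        yi = jnp.abs(jnp.floor(y))
        # A traced boolean scalar: fine to compute, but never fed to a Python `if`.
        return jnp.mod(xi, 2) == jnp.mod(yi, 2)

    def shade(self, x, y):
        x = x * self.scale
        y = y * self.scale
        # jnp.where selects between the colors without needing a concrete bool.
        return jnp.where(self.checker(x, y), self.color1, self.color2)

sh = CheckerShader(0.1, jnp.array([1.0, 1.0, 1.0]), jnp.array([0.0, 0.0, 0.0]))
shade = jax.jit(sh.shade)  # compiles once, then reused for every pixel
```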
