Reputation: 649
I've recently started experimenting with JAX, an interesting Python library that provides an accelerated NumPy as well as automatic differentiation. I wanted to try to create a crude "differentiable renderer" by writing a shader and loss function in Python, then using JAX's AD to find the gradient. We should then be able to inverse-render an image by running gradient descent on this loss. I have made it work fairly well with simple shaders, but I have run into problems when I use boolean expressions. This is the code of my shader, which generates a checkerboard pattern:
import jax.numpy as np

class CheckerShader:
    def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
        self.color1 = None
        self.color2 = None
        self.scale = None
        self.scale_min = 0
        self.scale_max = 20
        self.color1 = color1
        self.color2 = color2
        self.scale = scale * 20

    def checker(self, x: float, y: float) -> float:
        xi = np.abs(np.floor(x))
        yi = np.abs(np.floor(y))
        first_col = np.mod(xi, 2) == np.mod(yi, 2)
        return first_col

    def shade(self, x: float, y: float):
        x = x * self.scale
        y = y * self.scale
        first_col = self.checker(x, y)
        if first_col:
            return self.color1
        else:
            return self.color2
And this is my render function, which is the first place that JIT fails:
import jax.numpy as np
import numpy as onp
import jax

def render(scale, c1, c2):
    img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
    sh = CheckerShader(scale, c1, c2)
    jit_func = jax.jit(sh.shade)
    for y in range(HEIGHT):
        for x in range(WIDTH):
            val = jit_func(x / WIDTH, y / HEIGHT)
            img[y, x, :] = val
    return img
The error message I receive is:
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
and I guess it is because you can't run JIT on a function with booleans whose values depend on something that is not decided during compile time. But how can I rewrite it to work with JIT? Without JIT, it is painfully slow.
Another question I have is, is there something I can do to speed up Jax's Numpy in general? Rendering my image (100x100 pixels) with normal Numpy takes a few milliseconds, but with Jax's Numpy, it takes seconds! Thanks :D
Upvotes: 1
Views: 1728
Reputation: 688
But how can I rewrite it to work with JIT?
Ivo has a nice answer here - simply use np.where.
Another question I have is, is there something I can do to speed up Jax's Numpy in general?
There are probably three reasons why this is slow.
The first is the nature of JITing. It will be slow the first time you run your code, but if you run the same code multiple times the speed should increase. I would also try to JIT the entire render function if possible, if you plan to run this multiple times.
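To see the warm-up cost for yourself, here is a minimal timing sketch. The function `f` and the input size are made up for illustration and are not from the question; `block_until_ready()` is used so the timer waits for the asynchronous computation to finish.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Hypothetical stand-in for any numeric function you want to compile.
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.linspace(0.0, 1.0, 10_000)

t0 = time.perf_counter()
first = f(x).block_until_ready()   # traces and compiles, then runs
t1 = time.perf_counter()
second = f(x).block_until_ready()  # reuses the cached compiled function
t2 = time.perf_counter()

print(f"first call: {t1 - t0:.4f}s, second call: {t2 - t1:.4f}s")
```

The exact numbers depend on your machine and backend, but the second call should normally be much cheaper than the first.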
The second reason is that switching between numpy and jax.numpy will be very slow. You write
img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
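but it should be faster if you stay in jax.numpy and write
img = np.zeros((WIDTH, HEIGHT, CHANNELS))
Note that JAX arrays are immutable, so the in-place assignment img[y, x, :] = val would then need to become img = img.at[y, x, :].set(val).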
The third is that you are looping over width and height rather than using vectorized operations. I don't see why you can't do this in fully vectorized form.
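Something along these lines might work as a fully vectorized, jitted render. This is only a sketch under a few assumptions: I import jax.numpy as `jnp` rather than `np`, use a plain function instead of the `CheckerShader` class, take `scale` as the final multiplier (the class multiplies it by 20 first), and assume a 100x100 RGB image.

```python
import jax
import jax.numpy as jnp

WIDTH, HEIGHT = 100, 100  # assumed image size, taken from the question


@jax.jit
def render(scale, color1, color2):
    # Pixel coordinates normalized to [0, 1), as in the original loops.
    xs = jnp.arange(WIDTH) / WIDTH
    ys = jnp.arange(HEIGHT) / HEIGHT
    # x varies along columns, y along rows; both have shape (HEIGHT, WIDTH).
    x, y = jnp.meshgrid(xs, ys)
    xi = jnp.abs(jnp.floor(x * scale))
    yi = jnp.abs(jnp.floor(y * scale))
    first_col = jnp.mod(xi, 2) == jnp.mod(yi, 2)
    # Add a trailing axis so the mask broadcasts against the RGB colors.
    return jnp.where(first_col[..., None], color1, color2)


white = jnp.array([1.0, 1.0, 1.0])
black = jnp.array([0.0, 0.0, 0.0])
img = render(20.0, white, black)  # shape (HEIGHT, WIDTH, 3)
```

The whole image is computed in one compiled call, so there is no per-pixel dispatch and no copying between numpy and jax.numpy inside a loop.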
Upvotes: 2
Reputation: 3422
Replace
if first_col:
    return self.color1
else:
    return self.color2
with
return np.where(first_col, self.color1, self.color2)
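For a self-contained illustration, here is a rough sketch of the same idea as a standalone jitted function. The function signature, the `jnp` alias, and the sample colors are my own assumptions rather than the asker's class; `scale` is passed in directly.

```python
import jax
import jax.numpy as jnp


def checker(x, y):
    # True where the integer cell coordinates have the same parity.
    xi = jnp.abs(jnp.floor(x))
    yi = jnp.abs(jnp.floor(y))
    return jnp.mod(xi, 2) == jnp.mod(yi, 2)


@jax.jit
def shade(x, y, scale, color1, color2):
    first_col = checker(x * scale, y * scale)
    # jnp.where picks a color per element, so no Python `if` on a traced value.
    return jnp.where(first_col, color1, color2)


red = jnp.array([1.0, 0.0, 0.0])
blue = jnp.array([0.0, 0.0, 1.0])

same_parity = shade(0.125, 0.125, 20.0, red, blue)  # cell (2, 2)
diff_parity = shade(0.125, 0.375, 20.0, red, blue)  # cell (2, 7)
```

One thing to keep in mind for the inverse-rendering goal: gradients flow through `jnp.where` to `color1` and `color2`, but the gradient with respect to `scale` is zero, because `floor` and the comparison are piecewise constant.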
Upvotes: 3