Reputation: 59
This is not really a question; rather, I was wondering whether anyone has a better way of building an occupancy grid (mipmap) for a 3D grid in JAX (or in another language). Here is some working code. Does anyone have a better solution, or see any problems with my code?
import jax.numpy as jnp
from jax import lax
import numpy as np
def mipmap(mat):
    """Build a binary occupancy mipmap of a cubic 3D grid.

    Level 0 marks every voxel of ``mat`` whose value is > 0. Each following
    level is produced from the previous one by a 2x2x2 all-ones convolution
    with stride 2 and 'SAME' padding, thresholded at > 0. A coarse cell is
    therefore 1 when any finer cell it covers is occupied.

    Args:
        mat: 3D array with an ``astype`` method (e.g. a NumPy array) whose
            three axes have equal length ``n``.

    Returns:
        List of integer JAX arrays, finest first. There are
        ``floor(log2(n)) + 1`` levels (edge lengths ``n, n/2, ..., 1`` when
        ``n`` is a power of two).

    Raises:
        ValueError: if ``mat`` is not 3D or its axes differ in length.
    """
    if mat.ndim != 3:
        raise ValueError(f"expected a 3D array, got ndim={mat.ndim}")
    xdim, ydim, zdim = mat.shape
    if not xdim == ydim == zdim:
        raise ValueError(f"expected a cubic grid, got shape {mat.shape}")
    # Integer floor(log2(xdim)): no float rounding, and xdim == 0 yields no
    # extra levels instead of int(-inf) raising OverflowError.
    num_levels = xdim.bit_length() - 1

    data = jnp.array(mat.astype(jnp.float32))
    occupancy = (data > 0).astype(jnp.float32)
    levels = [occupancy.astype(int)]

    # Convolve the 0/1 occupancy, not the raw values: summing raw values
    # lets negative entries cancel positive ones and hide occupied voxels.
    grid = occupancy[None, :, :, :, None]  # add batch and channel axes
    kernel = jnp.ones([2, 2, 2])[:, :, :, jnp.newaxis, jnp.newaxis]
    dn = lax.conv_dimension_numbers(
        grid.shape, kernel.shape, ('NHWDC', 'HWDIO', 'NHWDC')
    )
    for _ in range(num_levels):
        out = lax.conv_general_dilated(
            grid,       # lhs = image tensor
            kernel,     # rhs = conv kernel tensor
            (2, 2, 2),  # window strides
            'SAME',     # padding mode
            (1, 1, 1),  # lhs/image dilation
            (1, 1, 1),  # rhs/kernel dilation
            dn,         # dimension_numbers
        )
        grid = (out > 0).astype(jnp.float32)
        levels.append(grid[0, :, :, :, 0].astype(int))
    return levels
# Example: a 4x4x4 grid with four occupied voxels at
# (0, 0, 0), (2, 0, 0), (3, 3, 3) and (0, 3, 3).
entry = np.zeros([4, 4, 4])
entry[(0, 2, 3, 0), (0, 0, 3, 3), (0, 0, 3, 3)] = 1
occupancy = mipmap(entry)
Thanks for reading and letting me know :)
Upvotes: 0
Views: 48
Reputation: 59
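Actually, I think I made a mistake. I need the number of occupied voxels in each cell, counted from the full-resolution grid. Thresholding every level to 0/1 throws those counts away, so the coarser counts cannot be built from the thresholded finer levels. This version gives the correct answers: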
def mipmap_compute(mat):
    """Count occupied voxels in every cell of each mipmap level of a 3D grid.

    A voxel is occupied when its value is > 0. Level ``i`` sums the 0/1
    occupancy with a ``2**i`` all-ones kernel applied at stride ``2**i``
    with 'SAME' padding. Level 0 is therefore the occupancy itself. When
    ``n`` is a power of two, the last level is a single cell holding the
    total count, matching the level layout of ``mipmap``.

    Args:
        mat: 3D array (NumPy or JAX). The number of levels is derived from
            the length ``n`` of its first axis.

    Returns:
        List of ``floor(log2(n)) + 1`` float32 JAX arrays of counts, finest
        first.
    """
    xdim, _, _ = mat.shape
    # Integer floor(log2(xdim)) without float rounding.
    num_levels = xdim.bit_length() - 1

    # Threshold in NumPy before casting, as the original did, so tiny positive
    # float64 values are still counted as occupied.
    occupancy = jnp.asarray((np.asarray(mat) > 0).astype(np.float32))
    data = occupancy[None, :, :, :, None]  # add batch and channel axes

    counts = []
    # Include i == num_levels so the coarsest level is not dropped.
    for i in range(num_levels + 1):
        size = 2 ** i
        kernel = jnp.ones([size, size, size])[:, :, :, jnp.newaxis, jnp.newaxis]
        dn = lax.conv_dimension_numbers(
            data.shape, kernel.shape, ('NHWDC', 'HWDIO', 'NHWDC')
        )
        out = lax.conv_general_dilated(
            data,                # lhs = image tensor
            kernel,              # rhs = conv kernel tensor
            (size, size, size),  # window strides
            'SAME',              # padding mode
            (1, 1, 1),           # lhs/image dilation
            (1, 1, 1),           # rhs/kernel dilation
            dn,                  # dimension_numbers
        )
        counts.append(out[0, :, :, :, 0])
    return counts
Upvotes: 0