
Reputation: 59

Best way to build a mipmap in JAX

This is not really a question; rather, I was wondering whether anyone has a better way of building an occupancy grid (a mipmap) for a 3D grid in JAX (or in another language). Here is some working code. Does anyone have a better solution, or see any problems with my code?

import jax.numpy as jnp
from jax import lax
import numpy as np

def mipmap(mat):
    """Build a 0/1 occupancy pyramid for a cubic 3-D grid.

    A voxel of ``mat`` counts as occupied when its value is ``> 0``. Each
    pyramid level halves every spatial dimension (``2x2x2`` window, stride 2,
    ``'SAME'`` padding). A coarse voxel is marked occupied when any of the
    child voxels it covers is occupied.

    Args:
        mat: 3-D array (e.g. a NumPy array) whose three dimensions are equal.

    Returns:
        A list of ``int(log2(side))`` integer arrays holding 0/1, finest level
        first (for a side of 4: shapes ``(2, 2, 2)`` then ``(1, 1, 1)``).

    Raises:
        AssertionError: if ``mat`` is not 3-D or not cubic.
    """
    assert mat.ndim == 3
    xdim, ydim, zdim = mat.shape
    assert xdim == ydim
    assert ydim == zdim

    levels = int(jnp.log2(xdim))
    pyramid = []

    # Threshold up front. The first convolution must see the 0/1 occupancy
    # mask, not the raw values: summing raw values lets positive and negative
    # entries in one window cancel out, which would wrongly mark a block
    # containing an occupied voxel as empty.
    occupancy = (jnp.asarray(mat, dtype=jnp.float32) > 0).astype(jnp.float32)

    # lax convolutions expect batch and channel axes: (N, H, W, D, C).
    data = occupancy[None, :, :, :, None]
    # An all-ones 2x2x2 kernel with one input and one output channel counts
    # the occupied children inside each window.
    kernel = jnp.ones((2, 2, 2, 1, 1), dtype=jnp.float32)
    dn = lax.conv_dimension_numbers(
        data.shape, kernel.shape, ("NHWDC", "HWDIO", "NHWDC")
    )

    for _ in range(levels):
        counts = lax.conv_general_dilated(
            data,
            kernel,
            window_strides=(2, 2, 2),
            padding="SAME",
            lhs_dilation=(1, 1, 1),
            rhs_dilation=(1, 1, 1),
            dimension_numbers=dn,
        )
        # Re-threshold so that the next level again works on a 0/1 mask.
        data = (counts > 0).astype(jnp.float32)
        pyramid.append(data[0, :, :, :, 0].astype(int))

    return pyramid

# Example: a 4x4x4 grid with four occupied voxels at
# (0, 0, 0), (2, 0, 0), (3, 3, 3) and (0, 3, 3).
entry = np.zeros([4, 4, 4])
entry[[0, 2, 3, 0], [0, 0, 3, 3], [0, 0, 3, 3]] = 1

occupancy = mipmap(entry)

Thanks for reading and letting me know :)

Upvotes: 0

Views: 48

Answers (1)


Reputation: 59

Actually, I think I made a mistake: the occupancy has to be counted from the full-resolution grid at every level, because the 0/1 occupancy of one level does not carry the counts over to the next level. This version computes the counts:

def mipmap_compute(mat):
    """Count occupied voxels per block at successively coarser resolutions.

    A voxel of ``mat`` counts as occupied when its value is ``> 0``. Level
    ``i`` sums the 0/1 occupancy mask over ``2**i``-sided cubes (stride
    ``2**i``, ``'SAME'`` padding). Every level starts from the
    full-resolution mask, so the counts are exact at every level.

    Args:
        mat: 3-D array. Only ``mat.shape[0]`` determines the number of levels.

    Returns:
        A list of ``int(log2(mat.shape[0]))`` float32 JAX arrays, finest first.
        Level 0 is the occupancy mask itself.
        NOTE(review): the coarsest 1x1x1 total is not produced; iterate
        ``range(levels + 1)`` if that level is wanted.
    """
    # Unpacking also rejects inputs that are not 3-D.
    xdim, _ydim, _zdim = mat.shape

    levels = int(jnp.log2(xdim))
    pyramid = []

    occupancy = (np.asarray(mat) > 0).astype(np.float32)
    # lax convolutions expect batch and channel axes: (N, H, W, D, C).
    data = jnp.asarray(occupancy)[None, :, :, :, None]

    for i in range(levels):
        side = 2 ** i
        # An all-ones cube kernel (one input/output channel) sums the occupied
        # voxels inside each side x side x side window.
        kernel = jnp.ones((side, side, side, 1, 1), dtype=jnp.float32)
        dn = lax.conv_dimension_numbers(
            data.shape, kernel.shape, ("NHWDC", "HWDIO", "NHWDC")
        )
        counts = lax.conv_general_dilated(
            data,
            kernel,
            window_strides=(side, side, side),
            padding="SAME",
            lhs_dilation=(1, 1, 1),
            rhs_dilation=(1, 1, 1),
            dimension_numbers=dn,
        )
        pyramid.append(counts[0, :, :, :, 0])
        print("Finished level i = " + str(i))
    return pyramid

Upvotes: 0

Related Questions