ThR37

Using numpy.bmat with numba

I am trying to use np.bmat in my numba-optimized Python program. To do so, I have to define a jitted bmat function myself, since NumPy's native one is not supported by numba:

@njit
def _bmat_2d(matrices):
    arr_rows = []
    for row in matrices:
        arr_rows.append(np.concatenate(row, axis=-1))
    return np.array(np.concatenate(arr_rows, axis=0))

(This code is more or less a simplified copy of NumPy's own implementation.)
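To check that the nested-concatenate logic by itself reproduces np.bmat, here is the same function run as plain NumPy (no @njit) on an ordinary list of lists. The name `bmat_2d_py` and the test arrays are illustrative only; the list/tuple restrictions only come into play once @njit is added.

```python
import numpy as np


def bmat_2d_py(matrices):
    # Same logic as _bmat_2d, but plain Python/NumPy: join each block row
    # horizontally, then stack the resulting rows vertically.
    arr_rows = [np.concatenate(row, axis=-1) for row in matrices]
    return np.concatenate(arr_rows, axis=0)


A = np.arange(4.0).reshape(2, 2)
B = np.ones((2, 1))
C = np.zeros((1, 3))

result = bmat_2d_py([[A, B], [C]])
expected = np.asarray(np.bmat([[A, B], [C]]))  # np.bmat returns a matrix
print(np.array_equal(result, expected))  # True
```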

However:

  1. numba only accepts tuples as input to np.concatenate [1]
  2. numba cannot convert a list of arbitrary length to a tuple, because a tuple's length must be known at compile time [2]
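Limitation 2 follows from how Numba types tuples: the length is part of the compile-time type, while a list's length is only known at run time, so the conversion cannot happen inside jitted code. A minimal sketch of the usual workaround, assuming the blocks are collected in a Python list, is to call `tuple(...)` in the plain-Python caller. `hcat` is a hypothetical name; each distinct tuple length triggers a separate compilation.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback only so the sketch also runs without numba
    def njit(func):
        return func


@njit
def hcat(arrays):
    # `arrays` must be a tuple: its length is part of the Numba type,
    # which is what lets np.concatenate be compiled for it
    return np.concatenate(arrays, axis=1)


blocks = [np.ones((2, 1)), np.zeros((2, 2))]  # built dynamically as a list
out = hcat(tuple(blocks))  # list -> tuple conversion done outside jitted code
print(out.shape)  # (2, 3)
```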

Does anyone have an idea how to get around this?

Refs:


Answers (2)

Since np.hstack did not work with numba for me, I had to write my own solution. Maybe some of you will find it useful. It's not pretty, but it does the job.

This essentially does the same thing as J = np.bmat([[J_1, J_2], [J_3, J_4]]).

Just be sure to change J = np.zeros((8, len(J_1[0])*2)) to fit the output array you want:

import numpy as np
import numba

@numba.njit
def main():
    J_1 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
    J_2 = np.array([[-85.33333333, 34.13333333, 34.13333333, 17.06666667], [34.13333333, -34.13333333, 0., 0.], [34.13333333, 0., -34.13333333, 0.], [17.06666667, 0., 0., -870.4]])
    J_3 = np.array([[85.33333333, -34.13333333, -34.13333333, -17.06666667], [-34.13333333, 34.13333333, -0., -0.], [-34.13333333, -0., 34.13333333, -0.], [-17.06666667, -0., -0., 870.4]])
    J_4 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])

    # Output: 8 rows (4 from the top blocks + 4 from the bottom blocks),
    # and twice the width of one block
    J = np.zeros((8, len(J_1[0])*2))
    for idx, _ in enumerate(J_1[0]):
        # Left half, column idx: J_1 in rows 0-3, J_3 in rows 4-7
        J[0][idx], J[1][idx], J[2][idx], J[3][idx], J[4][idx], J[5][idx], J[6][idx], J[7][idx] = J_1[0][idx], J_1[1][idx], J_1[2][idx], J_1[3][idx], J_3[0][idx], J_3[1][idx], J_3[2][idx], J_3[3][idx]
        # Right half, column idx + block width: J_2 in rows 0-3, J_4 in rows 4-7
        J[0][idx+len(J_1[0])], J[1][idx+len(J_1[0])], J[2][idx+len(J_1[0])], J[3][idx+len(J_1[0])], J[4][idx+len(J_1[0])], J[5][idx+len(J_1[0])], J[6][idx+len(J_1[0])], J[7][idx+len(J_1[0])] = J_2[0][idx], J_2[1][idx], J_2[2][idx], J_2[3][idx], J_4[0][idx], J_4[1][idx], J_4[2][idx], J_4[3][idx]

    print(J)

if __name__ == '__main__':
    main()
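The element-by-element copy can also be written as four slice assignments, which Numba supports in nopython mode. This sketch (the name `bmat_2x2` is hypothetical) derives the output shape from the blocks, so the `np.zeros((8, ...))` line no longer needs hand editing. It assumes J_1/J_2 share a height, J_3/J_4 share a height, and J_1/J_3 and J_2/J_4 share a width.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback only so the sketch also runs without numba
    def njit(func):
        return func


@njit
def bmat_2x2(J_1, J_2, J_3, J_4):
    # Output size computed from the blocks instead of hard-coded
    n_top, n_left = J_1.shape
    n_rows = n_top + J_3.shape[0]
    n_cols = n_left + J_2.shape[1]
    J = np.zeros((n_rows, n_cols))
    # Copy each block into its quadrant
    J[:n_top, :n_left] = J_1
    J[:n_top, n_left:] = J_2
    J[n_top:, :n_left] = J_3
    J[n_top:, n_left:] = J_4
    return J


A = np.arange(6.0).reshape(2, 3)
B = np.ones((2, 1))
C = np.full((1, 3), 2.0)
D = np.full((1, 1), 3.0)
J = bmat_2x2(A, B, C, D)
print(np.array_equal(J, np.asarray(np.bmat([[A, B], [C, D]]))))  # True
```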

Edit:

Someone on another thread helped me with this simple replacement for np.bmat, which works inside numba.njit:

J = np.vstack((np.hstack((J_1, J_2)), np.hstack((J_3, J_4))))
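Wrapped in a jitted function, the one-liner might look like the sketch below; both np.hstack and np.vstack receive tuples, not lists. `bmat_stack` is a hypothetical name, and the comparison with np.bmat runs outside the jitted code.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback only so the sketch also runs without numba
    def njit(func):
        return func


@njit
def bmat_stack(J_1, J_2, J_3, J_4):
    # Join each block row horizontally, then stack the two rows vertically
    return np.vstack((np.hstack((J_1, J_2)), np.hstack((J_3, J_4))))


A = np.eye(2)
B = np.ones((2, 2))
J = bmat_stack(A, B, B, A)
print(np.array_equal(J, np.asarray(np.bmat([[A, B], [B, A]]))))  # True
```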


JoshAdel

Would the following work for your purposes?

import numpy as np
import numba as nb

@nb.njit
def _bmat_2d(m):
    out = np.hstack(m[0])
    for row in m[1:]:
        x = np.hstack(row)
        out = np.vstack((out, x))

    return out

A = np.random.randint(10, size=(3,2))
B = np.random.randint(10, size=(3,1))
C = np.random.randint(10, size=(3,3))
D = np.random.randint(10, size=(4,6))

a = np.bmat(((A, B, C), (D,)))
b = _bmat_2d(((A, B, C), (D,)))

print(np.allclose(a, b))  # True

Note that you have to pass in a tuple of tuples rather than a list of lists; otherwise you will get a "reflected list" error, since the current version of Numba cannot handle lists of lists.
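Each np.vstack call in the loop allocates a new array, so with many block rows a variant that preallocates the output once and copies every block into place may be preferable. This is a sketch under stated assumptions, not the answer's code: every block row holds the same number of blocks (so the argument is a homogeneous tuple of tuples that Numba can iterate), blocks in a row share a height, and all rows have the same total width. The names `bmat_prealloc` and `to_tuples` are hypothetical; `to_tuples` does the list-of-lists to tuple-of-tuples conversion in plain Python.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback only so the sketch also runs without numba
    def njit(func):
        return func


@njit
def bmat_prealloc(blocks):
    # Total height: first block of each row; total width: blocks of row 0
    n_rows = 0
    for row in blocks:
        n_rows += row[0].shape[0]
    n_cols = 0
    for blk in blocks[0]:
        n_cols += blk.shape[1]
    out = np.empty((n_rows, n_cols), dtype=blocks[0][0].dtype)
    # Copy each block into place, advancing column and row offsets
    r = 0
    for row in blocks:
        h = row[0].shape[0]
        c = 0
        for blk in row:
            w = blk.shape[1]
            out[r:r + h, c:c + w] = blk
            c += w
        r += h
    return out


def to_tuples(nested):
    # Plain-Python helper: list of lists -> tuple of tuples
    return tuple(tuple(row) for row in nested)


A = np.arange(4.0).reshape(2, 2)
B = np.ones((2, 3))
C = np.zeros((1, 4))
D = np.full((1, 1), 5.0)
res = bmat_prealloc(to_tuples([[A, B], [C, D]]))
print(np.array_equal(res, np.asarray(np.bmat([[A, B], [C, D]]))))  # True
```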

