Reputation: 399
For two arrays, say, a = np.array([1,2,3,4])
and b = np.array([5,6])
, is there a way, if any, to obtain a 2d array of the following form without looping:
[[5 6 1 2 3 4]
[1 5 6 2 3 4]
[1 2 5 6 3 4]
[1 2 3 5 6 4]
[1 2 3 4 5 6]]
i.e. to insert b
in all possible places of a
.
And if loops are unavoidable, how to do it the most computationally efficient way (a
can be long, the length of b
is irrelevant)?
Example of how it can be done using loops is trivial:
# Reference loop-based solution: output row i is a[:i], then b, then a[i:].
a = np.array([1, 2, 3, 4])
b = np.array([5, 6])
rows = len(a) + 1          # one row per possible insertion point
cols = len(a) + len(b)     # every row holds all of a and all of b
res = np.empty([rows, cols])
for i in range(rows):
    res[i, i:len(b) + i] = b      # b starts at column i
    res[i, len(b) + i:] = a[i:]   # tail of a goes after b
    res[i, 0:i] = a[0:i]          # head of a goes before b
# Print the filled array; `rows` is a plain int and has no astype().
print(res.astype(int))
[[5 6 1 2 3 4]
[1 5 6 2 3 4]
[1 2 5 6 3 4]
[1 2 3 5 6 4]
[1 2 3 4 5 6]]
Upvotes: 4
Views: 327
Reputation: 2816
You can create a zero array with the expected shape, fill the target indices with the values of b, and finally fill the remaining zero entries with a tiled copy of the a array of the required length:
# Output dimensions: one row per insertion point, one column per element of a and b.
shape = a.shape[0] + 1
size = a.shape[0] + b.shape[0]
# Row i of ind lists the columns that b occupies in output row i.
ind = np.lib.stride_tricks.sliding_window_view(np.arange(size), window_shape=b.shape[0])
# For a = [1 2 3 4], b = [5 6]:
# [[0 1]
# [1 2]
# [2 3]
# [3 4]
# [4 5]]
zero_arr = np.zeros((shape, size), dtype=float)
row_ids = np.arange(shape).reshape(-1, 1)  # column vector, broadcasts against ind
zero_arr[row_ids, ind] = b
# [[5 6 0 0 0 0]
# [0 5 6 0 0 0]
# [0 0 5 6 0 0]
# [0 0 0 5 6 0]
# [0 0 0 0 5 6]]
# Every cell that is still 0 receives the next value of a, in row-major order.
zero_arr[zero_arr == 0] = np.tile(a, shape)
# [[5 6 1 2 3 4]
# [1 5 6 2 3 4]
# [1 2 5 6 3 4]
# [1 2 3 5 6 4]
# [1 2 3 4 5 6]]
This method beats tax evader's method in terms of performance on larger arrays, provided zero_arr is initialized with a value that is not contained in the b array. For example, if b contains 0, the zero_arr == 0 test will mislead the solution. One option is to use -np.ones together with zero_arr == -1 if b contains only positive values. We can also create this array with ndarray.fill or np.full if we know some value that is not in b, and then use the indexing method above; such a value could be found by generating candidates with np.arange or np.random and checking that they are not in b. A more comprehensive approach, which does not depend on a sentinel value, is shown below:
# Sentinel-free variant: compute explicit column indices for both b and a.
n_a, n_b = a.shape[0], b.shape[0]
shape, size = n_a + 1, n_a + n_b
# Row i of ind_0: the columns taken by b (i .. i + n_b - 1).
ind_0 = np.lib.stride_tricks.sliding_window_view(np.arange(size), n_b)
# Row i of ind_1: the columns right after b, wrapped around modulo the row length,
# i.e. exactly the columns b leaves free in that row.
ind_1 = np.lib.stride_tricks.sliding_window_view(np.arange(n_b, size + shape - 1), n_a) % size
arr = np.zeros((shape, size), dtype=np.float64)
arange_ = np.arange(shape)[:, None]
arr[arange_, ind_0] = b
# Sorting each row of ind_1 puts the free columns in ascending order so a keeps its order.
arr[arange_, np.sort(ind_1)] = np.broadcast_to(a, (shape, n_a))  # np.tile(a, (shape, 1)) works too
If you are not restricted from using other libraries (as you agreed to the numba one) and can run the code on a GPU, I believe the JAX library will beat numba on larger arrays. I have converted the NumPy code above into JAX jitted form to see how this library handles such matrix problems in terms of performance. Besides serving as benchmarks to evaluate and compare the performance of JAX and numba on this problem, these codes are also instructive about how to use JAX NumPy's where (jnp.where) under the JAX jit decorator, where sizes must be specified statically for it to work. Another aspect is how to create an equivalent of np.lib.stride_tricks.sliding_window_view or np.lib.stride_tricks.as_strided inside a JAX jitted function. The evader code is converted too, with some changes (which I think are needed for JAX usage); I don't know whether it could be written in a more compact (shorter) form. I think the code could be rewritten in a more optimized form, which would be even faster.
from functools import partial
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, vmap
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    """Return all contiguous windows of length ``size`` taken from 1-D ``a``.

    ``size`` is declared static (``static_argnums=(1,)``) because it fixes the
    output shape, which must be known when the function is traced.

    Returns an array of shape ``(len(a) - size + 1, size)`` whose row ``k`` is
    ``a[k:k + size]``.
    """
    starts = jnp.arange(len(a) - size + 1)
    # One dynamic_slice per start offset, vectorised over all offsets with vmap.
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)
@jit
def jax_initial(a, b):
    """Insert ``b`` at every position of ``a`` using a zero-filled array as marker.

    Builds a ``(len(a) + 1, len(a) + len(b))`` array, writes ``b`` along the
    shifted band, then overwrites every cell that still equals 0 with the
    values of ``a`` repeated once per row.

    Because free cells are detected with ``arr == 0``, the result is only
    correct when ``b`` contains no zeros.
    """
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    # Row i of ind holds the columns occupied by b in output row i.
    ind = moving_window(jnp.arange(size), b.shape[0])
    arr = jnp.zeros((shape, size), dtype=jnp.float64)
    arr = arr.at[jnp.arange(shape)[:, None], ind].set(b)
    broad = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))
    # Under jit, jnp.where needs a static result length: one slot per value of a.
    idx = jnp.where(arr == 0, size=broad.size)
    return arr.at[idx].set(broad)
@jit
def jax_comp(a, b):
    """Insert ``b`` at every position of ``a`` using explicit column indices.

    Unlike the zero-marker approach, the columns for ``a`` are computed
    directly, so the result does not depend on which values ``a`` or ``b``
    contain.

    Returns an array of shape ``(len(a) + 1, len(a) + len(b))``.
    """
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    # Row i of ind_0: the columns taken by b.
    ind_0 = moving_window(jnp.arange(size), b.shape[0])
    # Row i of ind_1: the columns following b, wrapped modulo the row length,
    # which are exactly the columns b leaves free in that row.
    ind_1 = moving_window(jnp.arange(b.shape[0], size + shape - 1), a.shape[0])
    ind_1 = jnp.remainder(ind_1, size)  # ind_1 = ind_1 % size
    arr = jnp.zeros((shape, size), dtype=jnp.float64)
    arange_ = jnp.arange(shape)[:, None]
    arr = arr.at[arange_, ind_0].set(b)
    # Sort each row of free columns so a is written in its original order.
    arr = arr.at[arange_, jnp.sort(ind_1)].set(jnp.broadcast_to(a, (shape, a.shape[0])))
    return arr
@jit
def jax_evader(a, b):
    """Insert ``b`` at every position of ``a`` using a modular diagonal mask.

    A flat index frame taken modulo ``row_length + 1`` is smaller than
    ``len(b)`` exactly on the shifted band where ``b`` belongs; the masked
    cells receive ``b`` and all other cells receive ``a``.

    Returns an array of shape ``(len(a) + 1, len(a) + len(b))``.
    """
    shape, size = a.shape[0] + 1, a.shape[0] + b.shape[0]
    res = jnp.empty([shape, size])
    frame = jnp.reshape(jnp.arange(shape * size), (shape, size))
    diag_mask = (frame % (res.shape[1] + 1)) < (b.shape[0])
    res_0 = jnp.ravel(jnp.broadcast_to(b, (shape, b.shape[0])))  # jnp.tile(b, res.shape[0])
    res_1 = jnp.ravel(jnp.broadcast_to(a, (shape, a.shape[0])))  # jnp.tile(a, res.shape[0])
    # Under jit, jnp.where needs static result lengths: one slot per written value.
    idx = jnp.where(diag_mask, size=res_0.size)
    idx_v = jnp.where(~diag_mask, size=res_1.size)
    res = res.at[idx].set(res_0)
    return res.at[idx_v].set(res_1)
some Benchmarks temporary link
We can parallelize the numba code, using explicit signatures, which on my system ran at least 3 times faster than the Pig one:
# Benchmark inputs: a long float64 array and a short int64 array,
# matching the explicit numba signature used for fill_parallel.
a = np.random.rand(10_000)
b = np.arange(5, 12, dtype=np.int64)
@nb.njit('float64[::1], int64[::1]', parallel=True)
def fill_parallel(a, b):
    """Insert ``b`` at every position of ``a``, one output row per position.

    Compiled eagerly for a contiguous float64 ``a`` and a contiguous int64
    ``b``. The rows are independent of each other, so the row loop uses
    ``nb.prange``.

    Returns an array of shape ``(a.size + 1, a.size + b.size)``.
    """
    rows = a.size + 1
    cols = a.size + b.size
    res = np.empty((rows, cols))
    for i in nb.prange(rows):
        res[i, i:b.size + i] = b      # b starts at column i
        res[i, b.size + i:] = a[i:]   # tail of a after b
        res[i, :i] = a[:i]            # head of a before b
    return res
numba parallelized code is the fastest code so far.
Upvotes: 1
Reputation: 7736
Consider using numba acceleration. This happens to be what numba is best at. For your example, it can speed up nearly 6 times:
from timeit import timeit
import numpy as np
from numba import njit
# Small sample inputs from the question, used for the timing comparison.
a = np.array([1, 2, 3, 4])
b = np.arange(5, 7)
def fill(a, b):
    """Insert ``b`` at every position of ``a``, one output row per position.

    Row ``i`` of the result is ``a[:i]``, followed by ``b``, followed by
    ``a[i:]``.

    Returns a float array of shape ``(a.size + 1, a.size + b.size)``.
    """
    rows = a.size + 1
    cols = a.size + b.size
    res = np.empty((rows, cols))
    for i in range(rows):
        res[i, i:b.size + i] = b      # b starts at column i
        res[i, b.size + i:] = a[i:]   # tail of a after b
        res[i, :i] = a[:i]            # head of a before b
    return res
if __name__ == '__main__':
    # Time the plain-Python implementation first.
    print('before:', timeit(lambda: fill(a, b)))
    # Rebind the name to the numba-compiled version; the lambda below looks
    # `fill` up at call time, so it picks up the compiled function.
    fill = njit(fill)
    print('after:', timeit(lambda: fill(a, b)))
Output:
before: 9.488150399993174
after: 1.6149254000047222
Upvotes: 1
Reputation: 4508
I think you can use masking to place the values of b along the shifted diagonal of the res array and fill the remaining cells of the array with the values from a:
import numpy as np
a = np.array([1, 2, 3, 4])
b = np.array([5, 6])
rows = a.size + 1
cols = a.size + b.size
res = np.empty((rows, cols))
# Frame for masking: the flat (row-major) index of every cell. Taken modulo
# cols + 1 it is the column offset from the main diagonal on and right of the
# diagonal, and at least b.size + 1 left of it.
frame = np.arange(rows * cols).reshape((rows, cols))
diag_mask = frame % (cols + 1) < b.size
# Boolean-mask assignment fills the selected cells in row-major order.
res[diag_mask] = np.tile(b, rows)
res[~diag_mask] = np.tile(a, rows)
Upvotes: 1