piRSquared

Fill the diagonal of each slice along an axis using strides

Consider the NumPy array a:

import numpy as np
a = np.arange(18).reshape(2, 3, 3)
print(a)

[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]]

I want to fill in the diagonal of each slice along axis=0.
I accomplished this with the integer-indexing approach from another answer:

a[:, np.arange(3), np.arange(3)] = -1

print(a)

[[[-1  1  2]
  [ 3 -1  5]
  [ 6  7 -1]]

 [[-1 10 11]
  [12 -1 14]
  [15 16 -1]]]

How could I have done this using strides?

Answers (1)

Divakar

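Here's an approach with strides. The sum of the row and column strides is the step from one diagonal element to the next, so a 2D strided view can address every diagonal directly -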

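import numpy as np

def fill_diag_strided(a, fillval):
    # Byte strides along the slice, row, and column axes
    m, n, r = a.strides
    # Stepping r + n bytes moves one row down and one column right (square slices assumed)
    a_diag = np.lib.stride_tricks.as_strided(a, shape=a.shape[:2], strides=(m, r + n))
    a_diag[:] = fillval  # writes through the view into a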
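
To see why the combined stride works, here is a minimal sketch that builds the same strided view on the small example array and cross-checks it against plain integer indexing. The names diag_view and expected are only illustrative and do not come from the answer.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.arange(18).reshape(2, 3, 3)
m, n, r = a.strides  # bytes to step along axis 0, axis 1, axis 2

# Element (i, j, j) sits at byte offset i*m + j*n + j*r = i*m + j*(n + r),
# so a 2D view with strides (m, n + r) lands on each slice's diagonal.
diag_view = as_strided(a, shape=a.shape[:2], strides=(m, n + r))

# Cross-check against integer indexing of the diagonal
idx = np.arange(a.shape[1])
expected = a[:, idx, idx]

assert np.array_equal(diag_view, expected)
assert np.shares_memory(diag_view, a)  # a view into a, not a copy
```

Because diag_view shares memory with a, assigning to it modifies a in place, which is what fill_diag_strided relies on.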

Sample run -

In [369]: a
Out[369]: 
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

In [370]: fill_diag_strided(a, fillval=-1)

In [371]: a
Out[371]: 
array([[[-1,  1,  2],
        [ 3, -1,  5],
        [ 6,  7, -1]],

       [[-1, 10, 11],
        [12, -1, 14],
        [15, 16, -1]]])
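
Note that shape=a.shape[:2] assumes every slice is square. For a slice with more rows than columns, the view would step past the true diagonal. A hedged variant for rectangular slices might cap the diagonal length at min(rows, cols); the name fill_diag_strided_rect is hypothetical and not part of the answer.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def fill_diag_strided_rect(a, fillval):
    # Hypothetical helper, not from the original answer.
    if a.ndim != 3:
        raise ValueError("expected a 3D array")
    m, n, r = a.strides
    # The diagonal of a (rows, cols) slice has min(rows, cols) elements
    k = min(a.shape[1], a.shape[2])
    diag = as_strided(a, shape=(a.shape[0], k), strides=(m, n + r))
    diag[:] = fillval  # writes through the view into a

# Slices with 5 rows and 3 columns: only 3 diagonal elements per slice
b = np.zeros((2, 5, 3))
fill_diag_strided_rect(b, -1)
```

For square slices this behaves the same as fill_diag_strided, since k then equals a.shape[1].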

Runtime test against an indexing-based approach -

In [35]: a = np.random.rand(100000,3,3)

In [36]: n = a.shape[1]

In [37]: %timeit a[:, np.arange(n), np.arange(n)] = -1
100 loops, best of 3: 4.45 ms per loop

In [38]: %timeit fill_diag_strided(a, fillval=-1)
1000 loops, best of 3: 1.76 ms per loop
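
Outside IPython, the comparison can be sketched with the standard timeit module. This is only a rough reproduction under assumptions: the wrapper fill_diag_indexed and the repeat count are choices made here, and absolute timings will vary by machine and NumPy version.

```python
import timeit

import numpy as np
from numpy.lib.stride_tricks import as_strided

def fill_diag_indexed(a, fillval):
    # Integer-indexing approach from the question
    idx = np.arange(a.shape[1])
    a[:, idx, idx] = fillval

def fill_diag_strided(a, fillval):
    # Strided-view approach from the answer (square slices assumed)
    m, n, r = a.strides
    as_strided(a, shape=a.shape[:2], strides=(m, n + r))[:] = fillval

rng = np.random.default_rng(0)
a1 = rng.random((100000, 3, 3))
a2 = a1.copy()

# Both approaches should produce identical arrays
fill_diag_indexed(a1, -1)
fill_diag_strided(a2, -1)
same = np.array_equal(a1, a2)

reps = 20  # arbitrary repeat count chosen for this sketch
t_indexed = timeit.timeit(lambda: fill_diag_indexed(a1, -1), number=reps)
t_strided = timeit.timeit(lambda: fill_diag_strided(a2, -1), number=reps)
print(f"indexed: {t_indexed / reps * 1e3:.2f} ms per call")
print(f"strided: {t_strided / reps * 1e3:.2f} ms per call")
```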
