pd shah

Reputation: 1406

Remove for loop in NumPy statement

There is an array like this:

x:
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69]],

       [[70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79]],

       [[80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89]],

       [[90, 91, 92, 93, 94],
        [95, 96, 97, 98, 99]]])
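The sample array appears to hold the integers 0 to 99 in shape (10, 2, 5). Assuming that, it can be reproduced like this:

```python
import numpy as np

# Integers 0..99 arranged as 10 blocks of shape (2, 5) along axis 0
x = np.arange(100).reshape(10, 2, 5)
print(x.shape)  # (10, 2, 5)
```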

The goal is to group each item i with the items that follow it, taking the window x[i:i+3] (items i, i+1 and i+2), and in each group check, element by element, whether all items along axis 0 are greater than 30.
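Written as a formula (a restatement of the goal, with N = x.shape[0] and a window length of 3), the result for window i is:

```latex
\mathrm{out}[i, j, k] = \bigwedge_{t=0}^{2} \bigl( x[i+t,\, j,\, k] > 30 \bigr), \qquad i = 0, 1, \dots, N-3
```

With N = 10 this gives 8 windows, i = 0 to 7.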

Grouping items i to i+2 (the slice x[i:i+3]):

for i in range(x.shape[0] - 3 + 1):
    print(repr(x[i:i+3]))
    print()

array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]])

array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]])

array([[[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49]]])

array([[[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

array([[[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69]]])

array([[[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69]],

       [[70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79]]])

array([[[60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69]],

       [[70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79]],

       [[80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89]]])

array([[[70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79]],

       [[80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89]],

       [[90, 91, 92, 93, 94],
        [95, 96, 97, 98, 99]]])

And finally, check the condition:

for i in range(x.shape[0] - 3 + 1):
    print(repr((x[i:i+3] > 30).all(axis=0)))
    print()


array([[False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)

array([[False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)

array([[False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)

array([[False,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]], dtype=bool)
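For comparison with the vectorized versions, the loop can be collected into a single boolean array of shape (number of windows, 2, 5). This is a minimal sketch; the helper name `window_all_above_loop` and its `L`/`thresh` parameters are hypothetical.

```python
import numpy as np

def window_all_above_loop(x, L=3, thresh=30):
    # Hypothetical helper: one boolean (2, 5) mask per window x[i:i+L]
    n_windows = x.shape[0] - L + 1
    return np.stack([(x[i:i + L] > thresh).all(axis=0) for i in range(n_windows)])

x = np.arange(100).reshape(10, 2, 5)
out = window_all_above_loop(x)
print(out.shape)  # (8, 2, 5)
```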

The question is: is there any way to remove the for loop, for better performance?
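One loop-free option that needs no stride manipulation (a different technique from the answer's as_strided approach): threshold once, then count the True values in each length-L window with a cumulative sum along axis 0. A window passes exactly when its count equals L. The helper name `window_all_above_cumsum` is hypothetical.

```python
import numpy as np

def window_all_above_cumsum(x, L=3, thresh=30):
    # Threshold every element once: 1 where x > thresh, else 0
    m = (x > thresh).astype(np.int64)
    # Prefix sums along axis 0 with a leading zero block, so c[i] == m[:i].sum(axis=0)
    c = np.concatenate([np.zeros((1,) + x.shape[1:], dtype=np.int64),
                        np.cumsum(m, axis=0)], axis=0)
    # Number of True values in each window m[i:i+L]
    counts = c[L:] - c[:-L]
    # A window is all-True exactly when all L of its entries are True
    return counts == L

x = np.arange(100).reshape(10, 2, 5)
out = window_all_above_cumsum(x)
print(out.shape)  # (8, 2, 5)
```

For a window of exactly 3, `m = x > 30` followed by `m[:-2] & m[1:-1] & m[2:]` gives the same result; the cumulative-sum form works for any window length without building the windows.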

Upvotes: 2

Views: 130

Answers (1)

Divakar

Reputation: 221714

Here's one efficient approach using np.lib.stride_tricks.as_strided -

(strided_axis0(x,3)>30).all(1)

The strides-based function strided_axis0 is from here.
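The linked strided_axis0 is not reproduced in the post. Below is a hypothetical reconstruction, assuming it returns a read-only view of shape (N - L + 1, L, ...) whose first two axes both step by x.strides[0]. That layout is what lets `.all(1)` reduce over each window. On NumPy 1.20+, `np.lib.stride_tricks.sliding_window_view` builds an equivalent view without hand-computed strides. It appends the window axis last, so the reduction there is over axis -1.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

def strided_axis0(a, L):
    # Hypothetical reconstruction, not the linked original.
    # Returns a view of shape (n, L, *a.shape[1:]) where view[i] is a[i:i+L]; no data is copied.
    n = a.shape[0] - L + 1
    shape = (n, L) + a.shape[1:]
    # Axis 0 (window start) and axis 1 (position in window) both advance one block of a
    strides = (a.strides[0],) + a.strides
    return as_strided(a, shape=shape, strides=strides, writeable=False)

x = np.arange(100).reshape(10, 2, 5)
out_strided = (strided_axis0(x, 3) > 30).all(1)

# Same result via sliding_window_view (NumPy >= 1.20); the window axis is last
out_swv = (sliding_window_view(x, 3, axis=0) > 30).all(axis=-1)

print(np.array_equal(out_strided, out_swv))  # True
```

Passing `writeable=False` guards against accidental writes through the view, where overlapping windows would otherwise alias the same memory.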

Upvotes: 3
