lstbl

Reputation: 537

Changing specified values on numpy array in a certain dimension

Wondering if there is an easy way to do this:

Say I have a NumPy array with shape (2,3,2), for example:

x = 
[[[ 0, 1],
  [ 2, 3],
  [ 4, 5]],

 [[ 6, 7],
  [ 8, 9],
  [10,11]]]

If I wanted to replace all the entries at position 0 along axis=1 with zero, I could do this easily:

x[:,0,:] = 0
x = 
[[[ 0  0]
  [ 2  3]
  [ 4  5]]

 [[ 0  0]
  [ 8  9]
  [10 11]]]

However, what if I had a list of axes that I wanted to perform these operations on? Is there a built-in NumPy function for this? Ideally it'd look something like this:

array_replace(array=x,axis=1,pos=0,replace_val=0)

Which would give the same array as above.

I can think of a way to do this by flattening matrices and calculating where the positions of each variable would be based on the dimension of each array, but I'm wondering if there is already something built into numpy.

Upvotes: 2

Views: 100

Answers (1)

willeM_ Van Onsem

Reputation: 477676

You can construct a tuple of slices and place the position at the specified axis. You can define such a function as:

def array_replace(array, axis, pos, replace_val):
    array[(slice(None),) * axis + (pos,)] = replace_val

What we do is construct a 1-tuple containing a slice object: (slice(None),). A slice object is what Python generates behind the scenes for a colon :. So x[:,0,:] is simply shorthand for x[(slice(None),0,slice(None))].
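As a quick check of that equivalence, here is a minimal sketch (the variable names are only for illustration) comparing the colon form with the explicit tuple of slice objects, and showing that trailing slices can be left out:

```python
import numpy as np

x = np.arange(12).reshape(2, 3, 2)

# The colon form and the explicit tuple-of-slices form select the same entries.
a = x[:, 0, :]
b = x[(slice(None), 0, slice(None))]
same = np.array_equal(a, b)

# Trailing slices may be omitted: missing trailing indices are treated as ':'.
c = x[(slice(None), 0)]
same_without_trailing = np.array_equal(a, c)
```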

Next we repeat this slice axis times (once for each axis before the specified one), followed by the position we want. The remaining slices are optional, so we do not specify them here. We then use NumPy's broadcasting to assign replace_val to all the selected entries. Note that this relies on axis being non-negative.

This then generates:

>>> x
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]]])
>>> array_replace(array=x, axis=1, pos=0, replace_val=0)
>>> x
array([[[ 0,  0],
        [ 2,  3],
        [ 4,  5]],

       [[ 0,  0],
        [ 8,  9],
        [10, 11]]])
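Since the question mentions a list of axes, one possible extension is to loop over (axis, pos) pairs. The following is only a sketch: array_replace_many and its targets parameter are hypothetical names, not part of NumPy, and negative axes are normalized with a modulo so that -1 means the last axis.

```python
import numpy as np

def array_replace_many(array, targets, replace_val):
    """Hypothetical helper: assign replace_val in place for each (axis, pos) pair."""
    for axis, pos in targets:
        # Normalize negative axes such as -1 (assumes -ndim <= axis < ndim).
        axis = axis % array.ndim
        # One full slice per leading axis, then the position on the target axis.
        index = (slice(None),) * axis + (pos,)
        array[index] = replace_val
    return array

x = np.arange(12).reshape(2, 3, 2)
# Zero out position 0 along axis 1, then position 1 along the last axis.
array_replace_many(x, [(1, 0), (-1, 1)], 0)
```

The array is modified in place, as in the single-axis version; returning it as well is just a convenience for chaining.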

Upvotes: 2
