Zanam

NumPy: replace one column per row using a list of column indexes that may contain NaN

I am trying the following:

import numpy as np

a = np.array([[1,2,3], [4,5,6], [7,8,9]])

print(repr(a))
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

a[np.arange(len(a)), [1,0,2]] = 20 #--Code1

print(repr(a))
array([[ 1, 20,  3],
       [20,  5,  6],
       [ 7,  8, 20]])
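Code1 works because two integer index arrays of the same length are paired element by element: position k selects `a[rows[k], cols[k]]`. The sketch below (the names `rows`, `cols`, `b`, `c` are illustrative, not from the question) checks that the vectorised assignment matches an explicit loop over those pairs.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
rows = np.arange(len(a))   # one entry per row: 0, 1, 2
cols = [1, 0, 2]           # column to overwrite in each row

# Vectorised form (Code1): rows[k] is paired with cols[k]
b = a.copy()
b[rows, cols] = 20

# Equivalent explicit loop over the same (row, column) pairs
c = a.copy()
for r, col in zip(rows, cols):
    c[r, col] = 20

print(np.array_equal(b, c))  # True
```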

However, if the index list contains a NaN:

a[np.arange(len(a)), [1,np.nan,2]] = 20  #--Code2

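It fails with an IndexError: the NaN turns the index list into a float array, and NumPy only accepts integer (or boolean) arrays as indices.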
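A minimal reproduction of the failure, assuming a reasonably recent NumPy (the exact error text may differ between versions). A single NaN is enough to make the whole index a `float64` array, which fancy indexing rejects.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx = [1, np.nan, 2]

# NaN is a float, so the converted index array is float64, not int
print(np.asarray(idx).dtype)  # float64

message = None
try:
    a[np.arange(len(a)), idx] = 20  # Code2
except IndexError as exc:
    # Recent NumPy versions refuse non-integer index arrays outright
    message = str(exc)
    print("IndexError:", message)
```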

What I want is: wherever the index is NaN, leave that row unchanged.

In other words, I want Code2 to produce:

array([[ 1, 20,  3],
       [ 4,  5,  6],
       [ 7,  8, 20]])


Answers (1)

Divakar

Use a boolean mask to keep only the non-NaN entries, then index with those:

m = ~np.isnan(idx) # Mask of non-NaNs
row = np.arange(a.shape[0])[m]
col = idx[m].astype(int)
a[row, col] = 20

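where idx is the indexing array. It must be a float NumPy array (for example np.array([1, np.nan, 2])), since NaN cannot be stored in an integer array and idx[m] needs an array rather than a list.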
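The same masking steps can be wrapped in a reusable function. This is a sketch; `set_per_row` is a hypothetical name, and it assumes `idx` has exactly one entry per row of `a`. It modifies `a` in place.

```python
import numpy as np

def set_per_row(a, idx, value):
    """Set a[i, idx[i]] = value for every row i whose idx[i] is not NaN.

    Hypothetical helper wrapping the masking approach; modifies `a` in place.
    Assumes len(idx) == a.shape[0].
    """
    idx = np.asarray(idx, dtype=float)  # NaN can only live in a float array
    m = ~np.isnan(idx)                  # True where a column index is present
    row = np.arange(a.shape[0])[m]      # rows that should be modified
    col = idx[m].astype(int)            # their column indices, as integers
    a[row, col] = value
    return a

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
set_per_row(a, [1, np.nan, 2], 20)
print(a)
```

If every entry is NaN, `row` and `col` are empty and the assignment is a no-op, so no special case is needed.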

Sample run:

In [161]: a = np.array([[1,2,3], [4,5,6], [7,8,9]])

In [162]: idx = np.array([1,np.nan,2])

In [163]: m = ~np.isnan(idx) # Mask of non-NaNs
     ...: row = np.arange(a.shape[0])[m]
     ...: col = idx[m].astype(int)
     ...: a[row, col] = 20
     ...: 

In [164]: a
Out[164]: 
array([[ 1, 20,  3],
       [ 4,  5,  6],
       [ 7,  8, 20]])
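An alternative technique, not the one used in the answer above: build a full rows-by-columns boolean mask by broadcasting `idx` against the column numbers. Because `NaN == k` is always False, NaN rows get no True entry and are left untouched without an explicit NaN check.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx = np.array([1, np.nan, 2])

# Compare each row's target index (shape (3, 1)) with every column
# number (shape (3,)); the result is a (3, 3) boolean mask.
# Comparisons with NaN are always False, so row 1 stays all-False.
mask = idx[:, None] == np.arange(a.shape[1])

a[mask] = 20
print(a)
```

The trade-off: this allocates a boolean array the size of `a`, and it silently ignores out-of-range or non-integer indices (they match no column), whereas the mask-and-index approach raises an IndexError for an out-of-range index.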
