Reputation: 585
I have the following labels:
>>> lab
array([3, 0, 3, 3, 1, 1, 2, 2, 3, 0, 1, 4])
I want to assign these labels to the rows of another numpy array, i.e.
>>> arr
array([[81, 1, 3, 87], # 3
[ 2, 0, 1, 0], # 0
[13, 6, 0, 0], # 3
[14, 0, 1, 30], # 3
[ 0, 0, 0, 0], # 1
[ 0, 0, 0, 0], # 1
[ 0, 0, 0, 0], # 2
[ 0, 0, 0, 0], # 2
[ 0, 0, 0, 0], # 3
[ 0, 0, 0, 0], # 0
[ 0, 0, 0, 0], # 1
[13, 2, 0, 11]]) # 4
and sum all rows that share the same label.
The output must be:
array([[108, 7, 4, 117], # 3
[ 2, 0, 1, 0], # 0
[ 0, 0, 0, 0], # 1
[ 0, 0, 0, 0], # 2
[ 13, 2, 0, 11]]) # 4
Upvotes: 0
Views: 118
Reputation: 231550
numpy doesn't have a groupby function like pandas, but it does have a reduceat method that performs fast array actions on groups of elements (rows). Its application in this case is a bit messy, though.
Start with our 2 arrays:
In [39]: arr
Out[39]:
array([[81, 1, 3, 87],
[ 2, 0, 1, 0],
[13, 6, 0, 0],
[14, 0, 1, 30],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[13, 2, 0, 11]])
In [40]: lbls
Out[40]: array([3, 0, 3, 3, 1, 1, 2, 2, 3, 0, 1, 4])
Find the indices that will sort lbls (and the rows of arr) into contiguous blocks:
In [41]: I=np.argsort(lbls)
In [42]: I
Out[42]: array([ 1, 9, 4, 5, 10, 6, 7, 0, 2, 3, 8, 11], dtype=int32)
In [43]: s_lbls=lbls[I]
In [44]: s_lbls
Out[44]: array([0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4])
In [45]: s_arr=arr[I,:]
In [46]: s_arr
Out[46]:
array([[ 2, 0, 1, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[81, 1, 3, 87],
[13, 6, 0, 0],
[14, 0, 1, 30],
[ 0, 0, 0, 0],
[13, 2, 0, 11]])
Find the boundaries of these blocks, i.e. where s_lbls jumps. np.diff is nonzero at the last element of each block, so add 1 to get the index where the next block starts:
In [47]: J=np.where(np.diff(s_lbls))[0]+1
In [48]: J
Out[48]: array([ 2, 5, 7, 11], dtype=int32)
Add the index of the start of the first block (see the reduceat docs):
In [49]: J1=[0]+J.tolist()
In [50]: J1
Out[50]: [0, 2, 5, 7, 11]
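A quick way to sanity-check block starts is to compare them with np.unique(..., return_index=True), which reports the first index at which each distinct value appears. The following is a minimal sketch, assuming the starts are taken as one past each position where the sorted labels change; the variable names starts and first_idx are illustrative only.

```python
import numpy as np

# Sorted labels from the example above.
s_lbls = np.array([0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4])

# One past each position where the label changes, plus 0 for the first block.
starts = np.concatenate(([0], np.flatnonzero(np.diff(s_lbls)) + 1))

# np.unique returns the first index at which each distinct value occurs.
values, first_idx = np.unique(s_lbls, return_index=True)

print(starts)     # [ 0  2  5  7 11]
print(first_idx)  # [ 0  2  5  7 11]
```

If the two arrays differ, the boundaries are off by one and reduceat will silently mix rows from neighbouring blocks.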
Apply add.reduceat:
In [51]: np.add.reduceat(s_arr,J1,axis=0)
Out[51]:
array([[ 2, 0, 1, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[108, 7, 4, 117],
[ 13, 2, 0, 11]], dtype=int32)
These are your numbers, sorted by lbls (for 0, 1, 2, 3, 4).
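The steps can be collected into one function. This is only a sketch: group_sum is a hypothetical helper name, not a numpy function, and it assumes lbls is a 1-D array with at least one element and one label per row of arr.

```python
import numpy as np

def group_sum(arr, lbls):
    """Sum the rows of arr that share a label (hypothetical helper)."""
    # Row order that makes equal labels contiguous.
    order = np.argsort(lbls, kind="stable")
    s_lbls = lbls[order]
    s_arr = arr[order]
    # np.diff is nonzero at the last row of each block,
    # so the next block starts one position later.
    starts = np.concatenate(([0], np.flatnonzero(np.diff(s_lbls)) + 1))
    # Return the label of each block alongside its row sum.
    return s_lbls[starts], np.add.reduceat(s_arr, starts, axis=0)

lbls = np.array([3, 0, 3, 3, 1, 1, 2, 2, 3, 0, 1, 4])
arr = np.zeros((12, 4), dtype=int)
arr[[0, 1, 2, 3, 11]] = [[81, 1, 3, 87],
                         [2, 0, 1, 0],
                         [13, 6, 0, 0],
                         [14, 0, 1, 30],
                         [13, 2, 0, 11]]

keys, sums = group_sum(arr, lbls)
print(keys)
print(sums)
```

Returning the labels together with the sums makes it explicit which output row belongs to which label, since the result is ordered by sorted label rather than by first appearance.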
With reduceat you could apply other ufuncs in the same way, such as np.maximum or np.multiply (product).
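As an illustration of swapping the ufunc, here is a small sketch on made-up data (not the arrays from the question), assuming the rows are already sorted into blocks and starts holds the first row of each block.

```python
import numpy as np

# Toy data: rows already grouped into three blocks (labels 0, 0, 1, 1, 2).
s_arr = np.array([[1, 5],
                  [4, 2],
                  [7, 0],
                  [3, 9],
                  [6, 6]])
starts = [0, 2, 4]  # first row of each block

# Column-wise maximum within each block.
block_max = np.maximum.reduceat(s_arr, starts, axis=0)
# Column-wise product within each block.
block_prod = np.multiply.reduceat(s_arr, starts, axis=0)

print(block_max)   # rows: [4 5], [7 9], [6 6]
print(block_prod)  # rows: [4 10], [21 0], [6 6]
```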
Upvotes: 1