Allen Qin

Reputation: 19947

Numpy: Flatten some columns of a 2D array

Suppose I have a NumPy array like this:

a = np.asarray([[1,2,3],[1,4,3],[2,5,4],[2,7,5]])

array([[1, 2, 3],
       [1, 4, 3],
       [2, 5, 4],
       [2, 7, 5]])

How can I flatten columns 2 and 3 for each unique element in column 1, to get the following?

array([[1, 2, 3, 4, 3],
       [2, 5, 4, 7, 5]])

Thank you for your help.

Upvotes: 2

Views: 847

Answers (3)

Divakar

Reputation: 221524

Since, as posted in the comments, each unique element in column 0 has a fixed number of rows (which I take to mean the same number of rows for every element), we can solve this with a vectorized approach. We sort the rows by column 0 and look for shifts along it; a shift signals a group change and therefore gives the exact number of rows per unique element in column 0. Let's call it L. Finally, we slice the sorted array to select columns 1 and 2, and group every L rows together by reshaping. Thus, the implementation would be -

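# Sort rows by column 0; a stable sort keeps rows within a group in their input order
sa = a[a[:,0].argsort(kind='stable')]
# Index of the first occurrence of the second key = number of rows per key
L = np.unique(sa[:,0],return_index=True)[1][1]
# One key per group, followed by that group's L rows of columns 1,2 flattened
out = np.column_stack((sa[::L,0],sa[:,1:].reshape(-1,2*L)))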
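A self-contained sketch of the same idea, wrapped in a hypothetical helper `flatten_groups`. It uses `np.unique(..., return_counts=True)` both to get L and to check the equal-group-size assumption up front, rather than silently producing wrong rows. Replacing the hard-coded `2*L` with `(a.shape[1] - 1) * L` is an assumption beyond the answer, so that any number of trailing columns works.

```python
import numpy as np

def flatten_groups(a):
    """Group rows of `a` by column 0 and flatten the remaining columns per group.

    Assumes every key in column 0 appears the same number of times.
    """
    # Stable sort keeps each group's rows in their original relative order.
    sa = a[a[:, 0].argsort(kind="stable")]
    keys, counts = np.unique(sa[:, 0], return_counts=True)
    if not np.all(counts == counts[0]):
        raise ValueError("every key in column 0 must occur the same number of times")
    L = counts[0]
    # Each block of L consecutive rows (minus the key column) becomes one output row.
    return np.column_stack((keys, sa[:, 1:].reshape(-1, (a.shape[1] - 1) * L)))

a = np.asarray([[1, 2, 3], [1, 4, 3], [2, 5, 4], [2, 7, 5]])
print(flatten_groups(a))
```

The `keys` returned by `np.unique` are the same values as `sa[::L,0]` in the answer, since both come from the sorted column.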

For a further performance boost, we can compute L with np.diff instead, like so -

L = np.where(np.diff(sa[:,0])>0)[0][0]+1
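The np.diff line finds only the first group boundary, so it still relies on every group having the same size. If that assumption does not hold, all of the boundaries (via `np.flatnonzero`) can drive `np.split` instead. This is a minimal sketch with a hypothetical helper `flatten_groups_ragged`; it returns a list of 1-D arrays because rows of different lengths cannot form a rectangular array.

```python
import numpy as np

def flatten_groups_ragged(a):
    """Group rows by column 0 when groups may have different sizes.

    Returns a list of 1-D arrays, one per key, in ascending key order.
    """
    sa = a[a[:, 0].argsort(kind="stable")]
    # Positions where the key changes, i.e. the start of every group after the first.
    starts = np.flatnonzero(np.diff(sa[:, 0]) != 0) + 1
    out = []
    for g in np.split(sa, starts):
        # Prepend the group's key, then flatten its remaining columns row by row.
        out.append(np.concatenate(([g[0, 0]], g[:, 1:].ravel())))
    return out

a = np.asarray([[1, 2, 3], [1, 4, 3], [2, 5, 4], [2, 7, 5], [2, 9, 9]])
for row in flatten_groups_ragged(a):
    print(row)
```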

Sample run -

In [103]: a
Out[103]: 
array([[1, 2, 3],
       [3, 7, 8],
       [1, 4, 3],
       [2, 5, 4],
       [3, 8, 2],
       [2, 7, 5]])

In [104]: sa = a[a[:,0].argsort()]
     ...: L = np.unique(sa[:,0],return_index=True)[1][1]
     ...: out = np.column_stack((sa[::L,0],sa[:,1:].reshape(-1,2*L)))
     ...: 

In [105]: out
Out[105]: 
array([[1, 2, 3, 4, 3],
       [2, 5, 4, 7, 5],
       [3, 7, 8, 8, 2]])

Upvotes: 0

akuiper

Reputation: 214927

Another option using list comprehension:

np.array([np.insert(a[a[:,0] == k, 1:].flatten(), 0, k) for k in np.unique(a[:,0])])

# array([[1, 2, 3, 4, 3],
#        [2, 5, 4, 7, 5]])
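The one-liner above can be unrolled into an explicit loop to show each step: mask the rows for a key, flatten their trailing columns, and prepend the key. Note that each iteration rescans all of column 0, so the work grows with (number of keys × number of rows), and the final `np.array` call only yields a 2-D integer array when every group has the same size. The names `rows`, `mask`, `tail` and `result` are illustrative.

```python
import numpy as np

a = np.asarray([[1, 2, 3], [1, 4, 3], [2, 5, 4], [2, 7, 5]])

rows = []
for k in np.unique(a[:, 0]):            # sorted unique keys from column 0
    mask = a[:, 0] == k                 # boolean mask selecting this key's rows
    tail = a[mask, 1:].flatten()        # columns 1 and 2 of those rows, row by row
    rows.append(np.insert(tail, 0, k))  # prepend the key itself
result = np.array(rows)
print(result)
```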

Upvotes: 2

John1024

Reputation: 113824

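import numpy as np
a = np.asarray([[1,2,3],[1,4,3],[2,5,4],[2,7,5]])
d = {}
for row in a:
    # Append this row's columns 1 and 2 to its key's running array.
    # The [] default becomes a float array, so the results come out as floats.
    d[row[0]] = np.concatenate( (d.get(row[0], []), row[1:]) )
# Prepend each key to its flattened values (dicts keep insertion order in Python 3.7+).
r = np.array([np.concatenate(([key], d[key])) for key in d])
print(r)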

This prints:

[[ 1.  2.  3.  4.  3.]
 [ 2.  5.  4.  7.  5.]]
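If integer output is wanted, one variation (a sketch, not the answer's original code) is to collect each key's row slices in a plain list and concatenate once per key at the end. This avoids the float promotion caused by the empty `[]` default, and it also avoids re-copying a growing array on every row.

```python
import numpy as np

a = np.asarray([[1, 2, 3], [1, 4, 3], [2, 5, 4], [2, 7, 5]])

# Collect the trailing columns of each row under its key, in first-seen order.
groups = {}
for row in a:
    groups.setdefault(row[0], []).append(row[1:])

# Concatenate once per key; every piece is an integer array, so the dtype stays integer.
r = np.array([np.concatenate(([key], *parts)) for key, parts in groups.items()])
print(r)
```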

Upvotes: 2
