snowleopard

Reputation: 739

Sorting values in numpy structured array based on field name value

I have the following structured array:

import numpy as np

x = np.rec.array([(22,2,200.,2000.), (44,2,400.,4000.), (55,5,500.,5000.), (33,3,400.,3000.)],
              dtype={'names':['subcase','id', 'vonmises','maxprincipal'], 'formats':['i4','i4','f4','f4']})

I am trying to get the maximum vonmises for each id.

For example, the maximum vonmises for id 2 would be 400. I also want the corresponding subcase and maxprincipal.

Here is what I have done so far:

print(repr(x[['subcase','id','vonmises']][(x['id']==2) & (x['vonmises']==max(x['vonmises'][x['id']==2]))]))

Here is the output:

array([(44, 2, 400.0)], 
  dtype=(numpy.record, [('subcase', '<i4'), ('id', '<i4'), ('vonmises', '<f4')]))

The issue I am having now is that I want this to work for all ids in the array, not just id=2.

That is, I want to get the following output:

array([(44, 2, 400.0),(55, 5, 500.0),(33, 3, 400.0)], 
  dtype=(numpy.record, [('subcase', '<i4'), ('id', '<i4'), ('vonmises', '<f4')]))

Is there a nice way to accomplish this without specifying each individual id?
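As a baseline, the single-id expression above generalizes directly by looping over `np.unique(x['id'])`. The sketch below (the names `idx` and `result` are just illustrative) picks the position of the largest vonmises for each id and then fancy-indexes the whole records, so subcase and maxprincipal come along. It works, but it scans the full array once per id; ties are resolved by `argmax`, which keeps the first occurrence.

```python
import numpy as np

x = np.rec.array([(22, 2, 200., 2000.), (44, 2, 400., 4000.),
                  (55, 5, 500., 5000.), (33, 3, 400., 3000.)],
                 dtype={'names': ['subcase', 'id', 'vonmises', 'maxprincipal'],
                        'formats': ['i4', 'i4', 'f4', 'f4']})

idx = []
for i in np.unique(x['id']):                    # unique ids, ascending
    members = np.flatnonzero(x['id'] == i)      # positions of this id's records
    # position (in x) of the record with the largest vonmises for this id
    idx.append(members[x['vonmises'][members].argmax()])

result = x[idx]                                 # full records, ordered by id
print(repr(result[['subcase', 'id', 'vonmises']]))
```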

Upvotes: 3

Views: 658

Answers (4)

Eelco Hoogendoorn

Reputation: 10759

Using the numpy_indexed package, this would be a simple one-liner:

import numpy_indexed as npi
ids, maxvonmises = npi.group_by(x.id).max(x.vonmises)

Performance is probably similar to pandas, but it is a lot more readable, and there is no need to convert your data to a different format for the problem at hand.
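numpy_indexed is a third-party package. For readers who cannot install it, here is a minimal plain-NumPy sketch of the same group-by-max reduction (not numpy_indexed's implementation): sort by id, find where each id's run starts with `np.unique(..., return_index=True)`, and reduce each run with `np.maximum.reduceat`. Like the one-liner, it returns only the ids and their maxima, not the matching subcase or maxprincipal.

```python
import numpy as np

x = np.rec.array([(22, 2, 200., 2000.), (44, 2, 400., 4000.),
                  (55, 5, 500., 5000.), (33, 3, 400., 3000.)],
                 dtype={'names': ['subcase', 'id', 'vonmises', 'maxprincipal'],
                        'formats': ['i4', 'i4', 'f4', 'f4']})

# bring records with equal id next to each other
order = np.argsort(x['id'], kind='stable')
ids_sorted = x['id'][order]
vm_sorted = x['vonmises'][order]

# on a sorted array, return_index gives the start offset of each id's run
ids, starts = np.unique(ids_sorted, return_index=True)

# reduce each run [starts[k], starts[k+1]) to its maximum
max_vonmises = np.maximum.reduceat(vm_sorted, starts)
```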

Upvotes: 2

Eric

Reputation: 97601

Here's an approach without groupby:

# sort as desired (in place): by id, then by vonmises ascending
x.sort(order=['id','vonmises'])

# the largest vonmises of each id is now the LAST element of its run:
# keep the last element, and every element whose id differs from the one after it
keep = np.empty(x.shape, dtype=bool)
keep[-1] = True
keep[:-1] = x[:-1].id != x[1:].id

x_filt = x[keep]

Which gives:

rec.array([(44, 2, 400.0, 4000.0), (33, 3, 400.0, 3000.0), (55, 5, 500.0, 5000.0)], 
      dtype=[('subcase', '<i4'), ('id', '<i4'), ('vonmises', '<f4'), ('maxprincipal', '<f4')])
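A variant of the same run-boundary idea that leaves `x` unmodified: `np.lexsort` sorts by its last key first, so passing `(-x['vonmises'], x['id'])` orders by id and, within each id, by descending vonmises. Keeping the first element of each run then selects the maximum. This is a sketch; negating the key assumes vonmises is a signed numeric field (true for `f4` here).

```python
import numpy as np

x = np.rec.array([(22, 2, 200., 2000.), (44, 2, 400., 4000.),
                  (55, 5, 500., 5000.), (33, 3, 400., 3000.)],
                 dtype={'names': ['subcase', 'id', 'vonmises', 'maxprincipal'],
                        'formats': ['i4', 'i4', 'f4', 'f4']})

# primary key: id (ascending); secondary key: -vonmises (largest vonmises first)
order = np.lexsort((-x['vonmises'], x['id']))
s = x[order]                     # sorted copy; x itself is not changed

# keep the first record of each id run, i.e. the one with the largest vonmises
keep = np.empty(s.shape, dtype=bool)
keep[0] = True
keep[1:] = s['id'][1:] != s['id'][:-1]

x_filt = s[keep]
```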

Upvotes: 2

hpaulj

Reputation: 231385

Here's an approach using np.sort (or argsort) followed by itertools.groupby. But this grouping tool produces an iterator of (key, group) pairs in which each group is itself an iterator, which is messier to work with.

In [29]: x = np.rec.array([(22,2,200.,2000.), (44,2,400.,4000.), (55,5,500.,5000.), (33,3,400.,3000.)],
              dtype={'names':['subcase','id', 'vonmises','maxprincipal'], 'formats':['i4','i4','f4','f4']})

In [30]: ind=x.argsort(order=['id','vonmises'])

In [31]: ind
Out[31]: 
rec.array([0, 1, 3, 2], 
          dtype=int32)

In [32]: x[ind]
Out[32]: 
rec.array([(22, 2, 200.0, 2000.0), (44, 2, 400.0, 4000.0), (33, 3, 400.0, 3000.0),
 (55, 5, 500.0, 5000.0)], 
          dtype=[('subcase', '<i4'), ('id', '<i4'), ('vonmises', '<f4'), ('maxprincipal', '<f4')])

In [33]: import itertools

In [34]: [list(v) for k,v in itertools.groupby(x[ind],lambda i:i['id'])]
Out[34]: 
[[(22, 2, 200.0, 2000.0), (44, 2, 400.0, 4000.0)],
 [(33, 3, 400.0, 3000.0)],
 [(55, 5, 500.0, 5000.0)]]

Then we have to fetch the last (or first for min) record of each group, and then reconstitute the recarray.

In [39]: mx=[list(v)[-1] for k,v in itertools.groupby(x[ind],lambda i:i['id'])]

In [43]: np.rec.fromrecords(mx,dtype=x.dtype)
Out[43]: 
rec.array([(44, 2, 400.0, 4000.0), (33, 3, 400.0, 3000.0), (55, 5, 500.0, 5000.0)], 
          dtype=[('subcase', '<i4'), ('id', '<i4'), ('vonmises', '<f4'), ('maxprincipal', '<f4')])

Elements of mx are np.record with the correct dtype, but mx itself is a list.

Or compactly:

g=itertools.groupby(np.sort(x,order=['id','vonmises']), lambda i:i['id'])
np.rec.fromrecords([list(v)[-1] for k,v in g], dtype=x.dtype)
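Since `itertools.groupby` only merges adjacent equal keys, the sort only has to be by id; the secondary sort on vonmises can be replaced by Python's built-in `max` with a key function on each group. A sketch of that variant (using `operator.itemgetter` in place of the lambdas):

```python
import itertools
import operator

import numpy as np

x = np.rec.array([(22, 2, 200., 2000.), (44, 2, 400., 4000.),
                  (55, 5, 500., 5000.), (33, 3, 400., 3000.)],
                 dtype={'names': ['subcase', 'id', 'vonmises', 'maxprincipal'],
                        'formats': ['i4', 'i4', 'f4', 'f4']})

# groupby needs equal ids to be adjacent, so sort by id only
s = np.sort(x, order='id')

# within each id group, pick the record with the largest vonmises directly
mx = [max(g, key=operator.itemgetter('vonmises'))
      for _, g in itertools.groupby(s, key=operator.itemgetter('id'))]

result = np.rec.fromrecords(mx, dtype=x.dtype)
```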

Upvotes: 3

Colonel Beauvel

Reputation: 31171

I'm not sure why you use this format, but here is a workaround with pandas:

import pandas as pd

df  = pd.DataFrame(x)
df_ = df.groupby('id')['vonmises'].max().reset_index()

In [213]: df_.merge(df, on=['id','vonmises'])[['id','vonmises','subcase']]

Out[213]:
array([[   2.,  400.,   44.],
       [   3.,  400.,   33.],
       [   5.,  500.,   55.]], dtype=float32)
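A pandas alternative to the merge (a different technique from the one above): `groupby('id')['vonmises'].idxmax()` returns, for each id, the index label of the row holding the largest vonmises, so `df.loc[...]` selects whole rows, including subcase and maxprincipal. Because `idxmax` returns only the first occurrence of the maximum, ties do not produce duplicate rows the way a merge on (id, vonmises) can. `to_records(index=False)` converts the result back to a NumPy record array.

```python
import numpy as np
import pandas as pd

x = np.rec.array([(22, 2, 200., 2000.), (44, 2, 400., 4000.),
                  (55, 5, 500., 5000.), (33, 3, 400., 3000.)],
                 dtype={'names': ['subcase', 'id', 'vonmises', 'maxprincipal'],
                        'formats': ['i4', 'i4', 'f4', 'f4']})

df = pd.DataFrame(x)

# for each id, the row label of the first row with the largest vonmises
best = df.loc[df.groupby('id')['vonmises'].idxmax()]

# back to a NumPy record array, dropping the pandas index
result = best.to_records(index=False)
```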

Upvotes: 2
