kmartin

Reputation: 187

How to get last N rows RELATIVE to another row in pandas (vector solution)?

I've asked this question in the context of another longer one, but I think I tried to ask too many things at once. So, for simplicity:

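I have a DataFrame with one row per trial, recording which key was pressed on that trial. I want to add a column that holds the key_pressed values from the last N rows (including the current one). For example, if my data looks like this: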

trial sid  key_pressed        RT  
1     S04            x  0.502242        
2     S04            m  0.348620      
3     S04            m  0.312491       
4     S04            x  0.342541      
5     S04            n  0.419384       
6     S04            n  0.348211      
7     S04            z  0.376369   

then afterward it should look like this for N = 3 (computed separately for each sid):

trial sid  key_pressed        RT           last_3
1     S04            x  0.502242        NaN
2     S04            m  0.348620        NaN
3     S04            m  0.312491        [x, m, m]
4     S04            x  0.342541        [m, m, x]
5     S04            n  0.419384        [m, x, n]
6     S04            n  0.348211        [x, n, n]
7     S04            z  0.376369        [n, n, z]

Is there a vectorized solution to this? I can't figure out how to select rows relative to the current one. (I'm new to pandas and not great at thinking this way yet.)

UPDATE: Based on advice from the answers below, I ended up doing this:

keys = df.groupby('sid')['key_pressed']
df['shifted'] = keys.shift(2) + keys.shift(1) + keys.shift(0)

which creates strings such as xmm (the last three keys, oldest first), with NaN for the first two trials of each sid. That works better for me than lists.
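The update's pattern extends to any window length N by adding one grouped shift per extra trial. A minimal sketch, assuming the column names from the question; last_n_string is a hypothetical helper name, and it relies on pandas turning NaN + str into NaN, so the first N - 1 trials of each sid stay missing:

```python
import pandas as pd

# Sample data from the question (one sid, seven trials)
df = pd.DataFrame({
    "trial": [1, 2, 3, 4, 5, 6, 7],
    "sid": ["S04"] * 7,
    "key_pressed": ["x", "m", "m", "x", "n", "n", "z"],
})


def last_n_string(frame, n):
    """Concatenate the last n keys per sid, oldest first (hypothetical helper)."""
    keys = frame.groupby("sid")["key_pressed"]
    # Start with the oldest key in the window ...
    result = keys.shift(n - 1)
    # ... then append each newer key, ending with the current row (lag 0).
    # Missing values propagate, so incomplete windows stay NaN.
    for lag in range(n - 2, -1, -1):
        result = result + keys.shift(lag)
    return result


df["shifted"] = last_n_string(df, 3)
print(df)
```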

Upvotes: 2

Views: 973

Answers (4)

Alex Riley

Reputation: 176860

One way would be to use shift to move the relevant column down n rows and then concatenate the entries (they are strings so we can use +):

df['last_3'] = df.key_pressed.shift(1) + ', ' + df.key_pressed.shift(2) + ', ' + df.key_pressed.shift(3)

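This creates strings (not lists) of the previous three entries, most recent first, separated by a comma and a space. Note that it does not include the current row's key, and rows without three earlier entries come out as NaN. I'd avoid putting lists in DataFrames if possible, as things can get a little messy.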
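A variant of the same idea, assuming the window should include the current row (oldest key first, as in the question's desired output) and be computed within each sid: Series.str.cat joins the shifted columns with a separator, and with na_rep left at its default any row with a missing key becomes NaN. The two-sid sample data below is made up for illustration.

```python
import pandas as pd

# Made-up sample: two subjects, so the grouping matters
df = pd.DataFrame({
    "sid": ["S04"] * 4 + ["S05"] * 3,
    "key_pressed": ["x", "m", "m", "x", "n", "n", "z"],
})

keys = df.groupby("sid")["key_pressed"]
# Oldest key first; lag 0 is the current row's key
oldest, *newer = [keys.shift(lag) for lag in (2, 1, 0)]
# str.cat with na_rep=None leaves a row NaN if any of its parts is missing
df["last_3"] = oldest.str.cat(newer, sep=", ")
print(df)
```

Because shift is applied per group, S05's first two trials get NaN instead of borrowing keys from the end of S04.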

Upvotes: 2

Marius

Reputation: 60080

This solution avoids an explicit loop, but I'm not sure whether it really counts as 'vectorized': apply() with axis=1 calls a Python function once per row, so I think you lose most of the performance benefit of vectorization:

key_table = pd.concat(
    [df.key_pressed.shift(2), df.key_pressed.shift(1), df.key_pressed], 
    axis=1
)
df['last_3'] = key_table.apply(
    lambda row: ', '.join(str(k) for k in row),
    axis=1
)

Output:

   trial  sid key_pressed        RT       last_3
0      1  S04           x  0.502242  nan, nan, x
1      2  S04           m  0.348620    nan, x, m
2      3  S04           m  0.312491      x, m, m
3      4  S04           x  0.342541      m, m, x
4      5  S04           n  0.419384      m, x, n
5      6  S04           n  0.348211      x, n, n
6      7  S04           z  0.376369      n, n, z
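The literal "nan" strings in that output come from str(k) being applied to missing values. If you would rather have NaN for trials without a full window (as in the question), and also keep sids separate, one option is to mask the joined strings afterwards. A sketch under those assumptions, with made-up two-sid data; it still calls apply() row by row, so the performance caveat above still holds:

```python
import pandas as pd

df = pd.DataFrame({
    "sid": ["S04"] * 4 + ["S05"] * 3,
    "key_pressed": ["x", "m", "m", "x", "n", "n", "z"],
})

keys = df.groupby("sid")["key_pressed"]
# One column per lag, oldest first, shifted within each sid
key_table = pd.concat([keys.shift(lag) for lag in (2, 1, 0)], axis=1)
joined = key_table.apply(lambda row: ", ".join(str(k) for k in row), axis=1)
# Keep the string only where all three keys are present; other rows become NaN
df["last_3"] = joined.where(key_table.notna().all(axis=1))
print(df)
```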

Upvotes: 0

Dan Allan

Reputation: 35255

What do you want to do with those lists? Storing lists inside Series/DataFrames is not usually very convenient. Anyway, this would get you close. You have to handle the nans, and then you're done.

In [6]: pd.concat([df.key_pressed.shift(i) for i in [0, 1, 2]], axis=1).apply(tuple, axis=1).map(list)
Out[6]: 
0    [x, nan, nan]
1      [m, x, nan]
2        [m, m, x]
3        [x, m, m]
4        [n, x, m]
5        [n, n, x]
6        [z, n, n]
dtype: object

Notice that we have to convert to a tuple and then a list, to avoid pandas automatically taking our list and making it back into a Series. Try this and you'll see why it doesn't work:

pd.concat([df.key_pressed.shift(i) for i in [0, 1, 2]], axis=1).apply(list, axis=1)
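One way to handle the NaNs, assuming you want NaN (rather than a partial list) for trials with fewer than three keys behind them, and the oldest key first as in the question: build the tuples and lists as above, then blank out incomplete rows with Series.where. Grouping by sid is added here so windows do not cross subjects; the sample data is made up.

```python
import pandas as pd

df = pd.DataFrame({
    "sid": ["S04"] * 4 + ["S05"] * 3,
    "key_pressed": ["x", "m", "m", "x", "n", "n", "z"],
})

keys = df.groupby("sid")["key_pressed"]
# Lags 2, 1, 0 so each list reads oldest -> current
key_table = pd.concat([keys.shift(i) for i in [2, 1, 0]], axis=1)
lists = key_table.apply(tuple, axis=1).map(list)
# Rows with any missing key become NaN instead of a list containing NaN
df["last_3"] = lists.where(key_table.notna().all(axis=1))
print(df)
```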

Upvotes: 1

kmartin

Reputation: 187

Oh - perhaps this is the best solution. One can "shift" the data by a certain amount:

df['shifted'] = df.groupby('sid')['key_pressed'].shift(2)

Then I could create lists from this shifted data.
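Turning the shifted data into lists could look like the sketch below, which works for any window length N. last_n_lists is a hypothetical helper name; it assumes the question's column names, puts the oldest key first, and leaves NaN for trials that have fewer than N keys in their sid so far. The two-sid data is made up.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "sid": ["S04"] * 4 + ["S05"] * 3,
    "key_pressed": ["x", "m", "m", "x", "n", "n", "z"],
})


def last_n_lists(frame, n):
    """Return a Series of lists holding the last n keys per sid (hypothetical helper)."""
    keys = frame.groupby("sid")["key_pressed"]
    # shifted[0] is the oldest lag (n - 1), shifted[-1] is the current key (lag 0)
    shifted = [keys.shift(lag) for lag in range(n - 1, -1, -1)]
    # zip walks the shifted columns in parallel, one window per row
    windows = [
        list(window) if all(pd.notna(k) for k in window) else np.nan
        for window in zip(*shifted)
    ]
    # dtype=object keeps each list as a single cell value
    return pd.Series(windows, index=frame.index, dtype=object)


df["last_3"] = last_n_lists(df, 3)
print(df)
```

The comprehension runs in Python, so like apply() it is not truly vectorized; only the grouped shifts are.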

Upvotes: 0
