Vishal

Reputation: 2336

Cumulative sums and carryovers - vectorize with pandas

Can I do the following with pandas.Series arithmetic on a and b, without an explicit loop?

In [38]: a = pd.Series([4, 8, 3, 6, 2])

In [39]: b = pd.Series([3, 9, 5, 5, 4])

In [40]: alist = a.tolist()
    ...: blist = b.tolist()
    ...: for i in range(len(alist)):
    ...:     diff = max(0, alist[i] - blist[i])
    ...:     try:
    ...:         alist[i + 1] = alist[i + 1] + diff
    ...:     except IndexError:
    ...:         if diff > 0:
    ...:             alist.append(diff)
    ...:     blist[i] = max(0, blist[i] - alist[i])
    ...: 

In [41]: alist
Out[41]: [4, 9, 3, 6, 3]

In [42]: blist
Out[42]: [0, 0, 2, 0, 1]

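I'm incrementing the next value of a by the difference a - b whenever that difference is positive (carrying the surplus forward), and then replacing each b with b minus the updated a, floored at zero. So it is a cumulative-sum-like calculation with carryovers.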
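Written as a recurrence, the loop carries c[i + 1] = max(0, c[i] + a[i] - b[i]) into position i + 1, starting from c[0] = 0. Because diff is computed from the already-updated alist[i], a surplus can cascade over several positions. This has the same form as Lindley's recursion from queueing theory, which has a closed form in terms of cumulative sums: with S the prefix sum of a - b (S[0] = 0), c[i] = S[i] minus the smallest of S[0], S[1], ..., S[i]. Below is a minimal sketch under that reading. It assumes equal-length 1-D integer inputs. carryover_loop and carryover_vectorized are hypothetical names, and the final carry (what the loop would append) is returned separately rather than appended:

```python
import numpy as np
import pandas as pd


def carryover_loop(a, b):
    """The question's loop, returning the final carry instead of appending it (hypothetical helper)."""
    alist, blist = list(a), list(b)
    carry = 0
    for i in range(len(alist)):
        alist[i] += carry                       # add the surplus carried in from row i - 1
        carry = max(0, alist[i] - blist[i])     # surplus to carry into row i + 1
        blist[i] = max(0, blist[i] - alist[i])  # what is left of b after using the updated a
    return alist, blist, carry                  # carry > 0 is what the loop would append


def carryover_vectorized(a, b):
    """Same result via prefix sums and a running minimum (hypothetical helper)."""
    a = np.asarray(a)
    b = np.asarray(b)
    # prefix sums of a - b with a leading 0: s[0] = 0, s[i] = sum of (a - b)[:i]
    s = np.concatenate(([0], np.cumsum(a - b)))
    # carry into each position: s[i] minus the running minimum of s[0..i]
    carry = s - np.minimum.accumulate(s)
    alist = a + carry[:-1]
    blist = np.maximum(0, b - alist)
    return alist, blist, int(carry[-1])


a = pd.Series([4, 8, 3, 6, 2])
b = pd.Series([3, 9, 5, 5, 4])
alist, blist, tail = carryover_vectorized(a, b)
print(alist.tolist(), blist.tolist(), tail)  # [4, 9, 3, 6, 3] [0, 0, 2, 0, 1] 0

# a case where the surplus of row 0 makes row 1 overflow as well
print(carryover_vectorized([5, 1, 0], [0, 0, 10])[0].tolist())  # [5, 6, 6]
```

The running minimum is what lets the carry chain: once the prefix sum drops to a new low, the accumulated carry has been fully absorbed by b and resets to zero.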

Upvotes: 3

Views: 63

Answers (4)

harpan

Reputation: 8631

Consider the code below, which uses .shift() and then np.roll():

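import numpy as np
import pandas as pd

df = pd.DataFrame({
    'a': a,
    'b': b
})
# a[i + 1] plus row i's surplus; the last row is NaN, so fill it with a[0]
# and let np.roll(..., 1) rotate that value back to the front
alist = np.roll((df['a'].shift(-1) + (df['a'] - df['b']).clip(lower=0)).fillna(df.iloc[0]['a']), 1).astype(int).tolist()
blist = (df['b'] - alist).clip(lower=0).tolist()
print(alist)
print(blist)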

Output:

[4, 9, 3, 6, 3]
[0, 0, 2, 0, 1]
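To see why the one-liner works, here is the same computation split into steps (the intermediate names are illustrative, not from the answer). shift(-1) lines each row up with the next a value, the NaN left in the last row is filled with a[0], and np.roll(..., 1) rotates that value back to position 0, so each position i > 0 ends up as a[i] plus the surplus of row i - 1. The surplus is computed from the original a values, so it is carried forward by one position only. The last lines show the same result written directly with shift:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'a': [4, 8, 3, 6, 2], 'b': [3, 9, 5, 5, 4]})

surplus = (df['a'] - df['b']).clip(lower=0)   # [1, 0, 0, 1, 0]
next_plus = df['a'].shift(-1) + surplus        # a[i + 1] + surplus[i]; last row is NaN
filled = next_plus.fillna(df.iloc[0]['a'])     # park a[0] in the NaN slot
rolled = np.roll(filled, 1)                    # rotate a[0] back to position 0
alist = rolled.astype(int).tolist()

# equivalent: shift the surplus down one row and add it to a
direct = (df['a'] + surplus.shift(1, fill_value=0)).tolist()
print(alist, direct)  # [4, 9, 3, 6, 3] [4, 9, 3, 6, 3]
```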

Upvotes: 0

BENY

Reputation: 323306

IIUC, you need shift (it replaces the line alist[i + 1] = alist[i + 1] + diff):

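# carry each row's surplus (a - b, floored at 0) into the next row
alist = a.add((a - b).clip(lower=0).shift(), fill_value=0).astype(int)
blist = (b - alist).clip(lower=0)  # clip_lower(0) no longer exists in pandas 1.0+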
alist
Out[340]: 
0    4
1    9
2    3
3    6
4    3

blist
Out[341]: 
0    0
1    0
2    2
3    0
4    1
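One caveat on replacing alist[i + 1] = alist[i + 1] + diff with shift: the loop computes diff from the already-updated alist[i], whereas (a - b).clip(lower=0).shift() uses the original a. The two agree on the sample data because no surplus is carried more than one position, but they can differ when a carried surplus creates a new one. A small check with an input chosen (for illustration) to make the carry chain:

```python
import pandas as pd

# Input chosen so that the surplus of row 0 (5) makes row 1 produce a surplus of its own.
a = pd.Series([5, 1, 0])
b = pd.Series([0, 0, 10])

alist = a.add((a - b).clip(lower=0).shift(), fill_value=0).astype(int)
blist = (b - alist).clip(lower=0)

# The question's loop gives alist [5, 6, 6] and blist [0, 0, 4] for this input,
# because it carries 5 into row 1 and then 6 (= 5 + 1 - 0) into row 2.
print(alist.tolist(), blist.tolist())  # [5, 6, 1] [0, 0, 9]
```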

Upvotes: 2

cmaher

Reputation: 5215

Here's another numpy approach using where and roll:

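import numpy as np

mask = np.roll(a - b > 0, 1)
mask[0] = False  # np.roll wraps the last row to the front; row 0 has no previous row
alist = np.where(mask, a + np.roll(a - b, 1), a)
blist = np.maximum(b.values - alist, 0)

print(alist)
# [4 9 3 6 3]
print(blist)
# [0 0 2 0 1]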
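Unlike shift, np.roll wraps around, so the last row's a - b lands at position 0. When that last difference is positive, a bare np.roll-based mask adds it to a[0] even though nothing precedes row 0. A minimal sketch with an illustrative input whose last difference is positive, comparing the unguarded expression with one that clears the wrapped element:

```python
import numpy as np
import pandas as pd

# Illustrative input: the last row has a surplus (3 - 1 = 2) that should never reach row 0.
a = pd.Series([4, 8, 3])
b = pd.Series([3, 9, 1])

unguarded = np.where(np.roll(a - b > 0, 1), a + np.roll(a - b, 1), a)

mask = np.roll(a - b > 0, 1)
mask[0] = False  # discard the value that wrapped around from the last row
guarded = np.where(mask, a + np.roll(a - b, 1), a)

print(unguarded.tolist(), guarded.tolist())  # [6, 9, 3] [4, 9, 3]
```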

Upvotes: 1

jpp

Reputation: 164693

This is one way using numpy:

import numpy as np

a += np.maximum(0, a-b).shift().fillna(0).astype(int)
b = np.maximum(0, b - a)

print(a)

0    4
1    9
2    3
3    6
4    3
dtype: int64

print(b)

0    0
1    0
2    2
3    0
4    1
dtype: int64

Upvotes: 2
