orange
orange

Reputation: 8090

Adding rows when using .apply()

I'd like to apply an "aggregation" function to a groupby instance of a DataFrame by which the result is not reducing the final number of rows, but increasing it (not quite an "aggregation"). For instance, the below call should result in a duplication of cat=A, B rows depending on the result returned by the function (with cola and colb columns). Unfortunately, the added index columns are somehow added as columns in the result.

>>> df = pd.DataFrame({
  'date': pd.date_range('1/1/2018', periods=10, freq='10D'),
  'val': range(10),
  'cat': ['A'] * 7 + ['B'] * 3
})

>>> def func(x):
        data = range(20)
        # 2 x 10 = 20 rows
        index = pd.MultiIndex.from_product([
          [3, 4], pd.date_range('1/1/2018', periods=10, freq='10D')
        ], names=['cola', 'colb'])
        return pd.Series(data, index=index)

>>> df.groupby('cat').apply(func)
cola          3                                                         \
colb 2018-01-01 2018-01-11 2018-01-21 2018-01-31 2018-02-10 2018-02-20   
cat                                                                      
A             0          1          2          3          4          5   
B             0          1          2          3          4          5   

cola                                                      4             \
colb 2018-03-02 2018-03-12 2018-03-22 2018-04-01 2018-01-01 2018-01-11   
cat                                                                      
A             6          7          8          9         10         11   
B             6          7          8          9         10         11   

cola                                                                    \
colb 2018-01-21 2018-01-31 2018-02-10 2018-02-20 2018-03-02 2018-03-12   
cat                                                                      
A            12         13         14         15         16         17   
B            12         13         14         15         16         17   

cola                        
colb 2018-03-22 2018-04-01  
cat                         
A            18         19  
B            18         19 

Is there anything I can do to get this done or is .apply() just not geared to duplicate rows?

Upvotes: 0

Views: 41

Answers (1)

piRSquared
piRSquared

Reputation: 294488

IIUC: you don't want to return a Series. Return a DataFrame instead.

def func(x):
    data = range(20)
    # 2 x 10 = 20 rows
    index = pd.MultiIndex.from_product([
      [3, 4], pd.date_range('1/1/2018', periods=10, freq='10D')
    ], names=['cola', 'colb'])
    return pd.DataFrame(dict(Stuff=data), index=index)

df.groupby('cat').apply(func)

                     Stuff
cat cola colb             
A   3    2018-01-01      0
         2018-01-11      1
         2018-01-21      2
         2018-01-31      3
         2018-02-10      4
         2018-02-20      5
         2018-03-02      6
         2018-03-12      7
         2018-03-22      8
         2018-04-01      9
    4    2018-01-01     10
         2018-01-11     11
         2018-01-21     12
         2018-01-31     13
         2018-02-10     14
         2018-02-20     15
         2018-03-02     16
         2018-03-12     17
         2018-03-22     18
         2018-04-01     19
B   3    2018-01-01      0
         2018-01-11      1
         2018-01-21      2
         2018-01-31      3
         2018-02-10      4
         2018-02-20      5
         2018-03-02      6
         2018-03-12      7
         2018-03-22      8
         2018-04-01      9
    4    2018-01-01     10
         2018-01-11     11
         2018-01-21     12
         2018-01-31     13
         2018-02-10     14
         2018-02-20     15
         2018-03-02     16
         2018-03-12     17
         2018-03-22     18
         2018-04-01     19

Alternatively, you can keep your Series and use pd.concat

def func(x):
    data = range(20)
    # 2 x 10 = 20 rows
    index = pd.MultiIndex.from_product([
      [3, 4], pd.date_range('1/1/2018', periods=10, freq='10D')
    ], names=['cola', 'colb'])
    return pd.Series(data, index=index)

pd.concat({key: func(value) for key, value in df.groupby('cat')})

   cola  colb      
A  3     2018-01-01     0
         2018-01-11     1
         2018-01-21     2
         2018-01-31     3
         2018-02-10     4
         2018-02-20     5
         2018-03-02     6
         2018-03-12     7
         2018-03-22     8
         2018-04-01     9
   4     2018-01-01    10
         2018-01-11    11
         2018-01-21    12
         2018-01-31    13
         2018-02-10    14
         2018-02-20    15
         2018-03-02    16
         2018-03-12    17
         2018-03-22    18
         2018-04-01    19
B  3     2018-01-01     0
         2018-01-11     1
         2018-01-21     2
         2018-01-31     3
         2018-02-10     4
         2018-02-20     5
         2018-03-02     6
         2018-03-12     7
         2018-03-22     8
         2018-04-01     9
   4     2018-01-01    10
         2018-01-11    11
         2018-01-21    12
         2018-01-31    13
         2018-02-10    14
         2018-02-20    15
         2018-03-02    16
         2018-03-12    17
         2018-03-22    18
         2018-04-01    19
dtype: int64

Upvotes: 1

Related Questions