Dan Udwary

Reputation: 23

How can I combine rows in a pandas dataframe based on comparing values in two columns?

Consider a pandas DataFrame such as:

import pandas as pd

df = pd.DataFrame({'id': ['001', '001', '002', '002', '003', '003', '004', '004', '005', '005'],
                   'start': [1, 200, 200, 1, 1, 200, 200, 1, 1, 1000],
                   'end': [1000, 500, 500, 1000, 500, 1000, 1000, 500, 500, 2000]})

which prints as:
    id  start   end
0  001      1  1000
1  001    200   500
2  002    200   500
3  002      1  1000
4  003      1   500
5  003    200  1000
6  004    200  1000
7  004      1   500
8  005      1   500
9  005   1000  2000

I would like to end up with a DataFrame in which, for each id, any rows whose start-end intervals overlap are combined into a single row. (The index is unimportant here.) Is there a clever or efficient way to do this without resorting to a lot of complicated iteration? My actual data may have millions of rows.

The end result from the example above should be:

   id  start   end
  001      1  1000
  002      1  1000
  003      1  1000
  004      1  1000
  005      1   500
  005   1000  2000
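The question does not say whether intervals that merely touch (one ends exactly where the next begins) count as overlapping. A minimal sketch of the rule that the merge loops in both answers effectively apply, using a hypothetical helper named `overlaps` that is not part of the question:

```python
def overlaps(a, b):
    """Hypothetical helper: do two (start, end) intervals overlap?

    Touching intervals such as (1, 500) and (500, 800) count as
    overlapping, which matches the `<` test in the answers' merge loops.
    """
    # Order the two intervals by start (then by end).
    (a_start, a_end), (b_start, b_end) = sorted([a, b])
    # The later-starting interval overlaps iff it starts no later
    # than the earlier one ends.
    return b_start <= a_end


print(overlaps((1, 1000), (200, 500)))   # True: the two rows of id 001
print(overlaps((1, 500), (200, 1000)))   # True: the two rows of id 003
print(overlaps((1, 500), (1000, 2000)))  # False: the two rows of id 005
```

If touching intervals should stay separate, the comparison would become `b_start < a_end` (and, correspondingly, `<=` in place of `<` in the merge loops).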

Upvotes: 2

Views: 91

Answers (2)

jordiplam

Reputation: 66

One way to do this is to group by the id column and then apply a function that merges the intervals:

import pandas as pd

# Load or create the dataframe df.

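def merge_intervals(group):
    merged = []
    # Sweep the (start, end) pairs in order of start.
    for start, end in sorted(zip(group['start'], group['end'])):
        if not merged or merged[-1][1] < start:
            # No overlap with the last merged interval: open a new one.
            merged.append([start, end])
        else:
            # Overlap: extend the last merged interval if necessary.
            merged[-1][1] = max(merged[-1][1], end)
    starts, ends = zip(*merged)
    return pd.DataFrame({
        # group.name holds the key (the id) of the group being processed;
        # group['id'][0] is a label lookup and fails unless label 0 is in the group.
        'id':    group.name,
        'start': list(starts),
        'end':   list(ends)
    })

df_new = df.groupby(df['id'], as_index=False).apply(merge_intervals)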

For your example, the output looks like this:

      id  start   end
0 0  001      1  1000
1 0  002      1  1000
2 0  003      1  1000
3 0  004      1  1000
4 0  005      1   500
  1  005   1000  2000

The result has a MultiIndex (group number, row within the group) but the same columns as the input.
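Since the question says the index is unimportant, the MultiIndex can be replaced by a plain integer index with `reset_index(drop=True)`. A small sketch, using a hand-built stand-in for `df_new` (only three of its rows) with the same two-level index shape, so that it runs on its own:

```python
import pandas as pd

# Stand-in for df_new: same columns and the same kind of two-level index
# (group number, row within group) as the groupby/apply result.
df_new = pd.DataFrame(
    {"id": ["001", "005", "005"], "start": [1, 1, 1000], "end": [1000, 500, 2000]},
    index=pd.MultiIndex.from_tuples([(0, 0), (4, 0), (4, 1)]),
)

# Discard both index levels and replace them with a 0..n-1 RangeIndex.
flat = df_new.reset_index(drop=True)
print(flat)
```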

Thanks to @scott-boston for noticing the error, and to @henry-yik for his answer.

Upvotes: 1

Henry Yik

Reputation: 22503

One way is to write a function that merges the intervals, then group by id and apply it:

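def merge(l):
    # Sort the [start, end] pairs by start.
    l = sorted(l, key=lambda x: x[0])
    merged = []
    for i in l:
        if not merged or merged[-1][1] < i[0]:
            # No overlap with the previous interval: start a new one (as a copy).
            merged.append(list(i))
        else:
            # Overlap: extend the previous interval's end if needed.
            merged[-1][1] = max(merged[-1][1], i[1])
    return merged

# .tolist() turns the rows into plain Python lists of [start, end].
print(df.groupby("id").apply(lambda d: merge(d[["start", "end"]].values.tolist())).explode())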

id
001       [1, 1000]
002       [1, 1000]
003       [1, 1000]
004       [1, 1000]
005        [1, 500]
005    [1000, 2000]
dtype: object
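The question asks for a DataFrame rather than a Series of [start, end] pairs. A minimal sketch of the conversion, assuming the exploded Series above has been stored in a variable (called `s` here for illustration and rebuilt by hand so the snippet runs on its own):

```python
import pandas as pd

# Stand-in for the exploded Series printed above:
# index = id, values = [start, end] pairs.
s = pd.Series(
    [[1, 1000], [1, 1000], [1, 1000], [1, 1000], [1, 500], [1000, 2000]],
    index=pd.Index(["001", "002", "003", "004", "005", "005"], name="id"),
)

# Split each pair into two columns, then move "id" from the index to a column.
result = pd.DataFrame(s.tolist(), columns=["start", "end"], index=s.index).reset_index()
print(result)
```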

Upvotes: 1
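Both answers above run a Python loop inside each group, which may be slow for millions of rows. As a supplementary sketch (not taken from either answer), the same merge can be expressed with vectorized pandas operations: sort by id and start, compare each start with the running maximum end of the earlier rows of the same id, and label each merged block with a cumulative sum. It assumes `start <= end` on every row and has not been benchmarked on large data:

```python
import pandas as pd

df = pd.DataFrame({'id': ['001', '001', '002', '002', '003', '003', '004', '004', '005', '005'],
                   'start': [1, 200, 200, 1, 1, 200, 200, 1, 1, 1000],
                   'end': [1000, 500, 500, 1000, 500, 1000, 1000, 500, 500, 2000]})

# Within each id, order the intervals by start.
s = df.sort_values(["id", "start"]).reset_index(drop=True)

# Running max of `end` within each id, shifted down one row: the furthest
# end reached by all earlier intervals of the same id (NaN on an id's first row).
prev_max_end = s.groupby("id")["end"].cummax().groupby(s["id"]).shift()

# A row opens a new merged interval if it is the first row of its id, or if
# it starts after every earlier interval has ended. Using `<` means touching
# intervals are merged, like the answers' loops.
new_block = prev_max_end.isna() | (prev_max_end < s["start"])

# The cumulative sum of the flags gives one label per merged interval.
block = new_block.cumsum().rename("block")

result = (
    s.groupby(block)
     .agg(id=("id", "first"), start=("start", "min"), end=("end", "max"))
     .reset_index(drop=True)
)
print(result)
```

On the example data this reproduces the six rows of the expected result in the question. The only per-row work is done by pandas' built-in grouped `cummax`, `shift`, and `agg`, so the cost is dominated by the initial sort.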
