Sledge

Reputation: 1345

Polars join two dataframes if column value in other column

I have two dataframes that I'd like to join if one column's value is contained in the other column. The dataframes look like this:

import polars as pl

df1 = pl.DataFrame({"col1": [1, 2, 3], "col2": ["x1, x2, x3", "x2, x3", "x3"]})
df2 = pl.DataFrame({"col3": [4, 5, 6], "col4": ["x1", "x2", "x3"]})

I tried to do:

model_data = df1.join(df2, on="col2")

Which does not produce the desired result. What I'd like to see is something like this:

shape: (6, 4)
┌──────┬────────────┬──────┬──────┐
│ col1 ┆ col2       ┆ col3 ┆ col4 │
│ ---  ┆ ---        ┆ ---  ┆ ---  │
│ i64  ┆ str        ┆ i64  ┆ str  │
╞══════╪════════════╪══════╪══════╡
│ 1    ┆ x1, x2, x3 ┆ 4    ┆ x1   │
│ 1    ┆ x1, x2, x3 ┆ 5    ┆ x2   │
│ 1    ┆ x1, x2, x3 ┆ 6    ┆ x3   │
│ 2    ┆ x2, x3     ┆ 5    ┆ x2   │
│ 2    ┆ x2, x3     ┆ 6    ┆ x3   │
│ 3    ┆ x3         ┆ 6    ┆ x3   │
└──────┴────────────┴──────┴──────┘

It's a question of how you do the join when one value is contained by another value. I could not find good examples of this in the docs.

Upvotes: 1

Views: 3436

Answers (2)

Dean MacGregor

Reputation: 18446

Another approach, which might counterintuitively be faster, is to do a cross join between them and then filter out the rows where col4 is not contained in col2.

That would look something like this:

import polars as pl

cj = df1.join(df2, how='cross')
# map_rows passes each row as a tuple: x[1] is col2, x[3] is col4
filt = cj.map_rows(lambda x: x[3] in x[1])

(
    cj.with_columns(filt.to_series().alias('filt'))
    .filter(pl.col('filt'))
    .select(pl.exclude('filt'))
)

Essentially, cj is the cross join: every row of df1 paired with every row of df2 (3 × 3 = 9 rows here). filt is a single boolean column with one True/False value per row of cj. You attach it as a helper column, filter on it, and then drop it with a select. Be careful with the index positions in the lambda: map_rows passes each row as a tuple, so x[1] is col2 and x[3] is col4 only because of the column order after the cross join.

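You'll have to benchmark this against @jqurious's method on your own data. If (big if, I don't know) this one is faster, it would be because str.split plus explode is less efficient than just pairing everything up. Keep in mind that the cross join produces len(df1) × len(df2) rows and map_rows calls a Python function once per row, so this can get expensive on large frames. Unfortunately, the Series.str.contains method takes a fixed regex or literal pattern, which is why this uses a lambda.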

Upvotes: 2

jqurious

Reputation: 51

You could .str.split() col2 on ", " into a list and .explode() that list, giving one row per element.

(df1.with_columns(pl.col("col2").str.split(", ").alias("col4")) 
    .explode("col4")
)
shape: (6, 3)
┌──────┬────────────┬──────┐
│ col1 ┆ col2       ┆ col4 │
│ ---  ┆ ---        ┆ ---  │
│ i64  ┆ str        ┆ str  │
╞══════╪════════════╪══════╡
│ 1    ┆ x1, x2, x3 ┆ x1   │
│ 1    ┆ x1, x2, x3 ┆ x2   │
│ 1    ┆ x1, x2, x3 ┆ x3   │
│ 2    ┆ x2, x3     ┆ x2   │
│ 2    ┆ x2, x3     ┆ x3   │
│ 3    ┆ x3         ┆ x3   │
└──────┴────────────┴──────┘

You can then .join() df2 on the new col4 column.

(df1.with_columns(pl.col("col2").str.split(", ").alias("col4")) 
    .explode("col4")
    .join(df2, on="col4")
)
shape: (6, 4)
┌──────┬────────────┬──────┬──────┐
│ col1 ┆ col2       ┆ col4 ┆ col3 │
│ ---  ┆ ---        ┆ ---  ┆ ---  │
│ i64  ┆ str        ┆ str  ┆ i64  │
╞══════╪════════════╪══════╪══════╡
│ 1    ┆ x1, x2, x3 ┆ x1   ┆ 4    │
│ 1    ┆ x1, x2, x3 ┆ x2   ┆ 5    │
│ 1    ┆ x1, x2, x3 ┆ x3   ┆ 6    │
│ 2    ┆ x2, x3     ┆ x2   ┆ 5    │
│ 2    ┆ x2, x3     ┆ x3   ┆ 6    │
│ 3    ┆ x3         ┆ x3   ┆ 6    │
└──────┴────────────┴──────┴──────┘

You can rearrange the column order if desired.

Upvotes: 5
