jss367

Reputation: 5381

How to filter a polars DataFrame with list type columns

Update: This was fixed by pull/10857, so the direct comparison now works:

>>> df.filter(pl.col('a') == ['1'])
shape: (1, 2)
┌───────────┬────────────┐
│ a         ┆ b          │
│ ---       ┆ ---        │
│ list[str] ┆ list[str]  │
╞═══════════╪════════════╡
│ ["1"]     ┆ ["6", "7"] │
└───────────┴────────────┘

I'm working with a polars DataFrame that contains columns with list type values. Here's how my DataFrame looks:

import polars as pl

df = pl.DataFrame({
    'a': [['1', '3'], ['3'], ['1'], ['2']],
    'b': [['8'], ['6'], ['6', '7'], ['7']],
})

I want to keep only the rows where the 'a' column is exactly the list ['1']. I tried the following code:

df.filter(pl.col('a') == ['1'])

but I get this error:

ComputeError: cannot cast List type (inner: 'String', to: 'String')

The filter works when the DataFrame has non-list type values, as in:

df = pl.DataFrame({'a': ['3', '3', '1', '2'], 'b':['8','6','7','7']})
df.filter(pl.col('a') == '1')

How can I apply filtering to list type columns using polars? I'm not interested in solutions using pandas, only polars.

Upvotes: 3

Views: 1179

Answers (1)

Dean MacGregor

Reputation: 18341

It seems you can't do this directly because of how polars broadcasts and casts a list literal when it is compared against a column with a nested (List) type.

One workaround is to add a temporary column holding the list in question, compare the two columns, and then drop the temporary column:

>>> df.with_columns(c=['1']).filter(pl.col('a') == pl.col('c')).drop('c')
shape: (1, 2)
┌───────────┬────────────┐
│ a         ┆ b          │
│ ---       ┆ ---        │
│ list[str] ┆ list[str]  │
╞═══════════╪════════════╡
│ ["1"]     ┆ ["6", "7"] │
└───────────┴────────────┘

Alternatively, you can broadcast the list manually by repeating it once per row of the DataFrame:

>>> df.filter(pl.col('a') == pl.Series([['1']] * df.shape[0]))
shape: (1, 2)
┌───────────┬────────────┐
│ a         ┆ b          │
│ ---       ┆ ---        │
│ list[str] ┆ list[str]  │
╞═══════════╪════════════╡
│ ["1"]     ┆ ["6", "7"] │
└───────────┴────────────┘

Upvotes: 1
