cccs31

Reputation: 188

Filter list using another list as a boolean mask in polars

I have a polars dataframe containing two columns where both columns are lists.

import polars as pl

df = pl.DataFrame({
    'a': [[True, False], [False, True]],
    'b': [['name1', 'name2'], ['name3', 'name4']]
})
shape: (2, 2)
┌───────────────┬────────────────────┐
│ a             ┆ b                  │
│ ---           ┆ ---                │
│ list[bool]    ┆ list[str]          │
╞═══════════════╪════════════════════╡
│ [true, false] ┆ ["name1", "name2"] │
│ [false, true] ┆ ["name3", "name4"] │
└───────────────┴────────────────────┘

I want to filter column b using column a as a boolean mask. The length of each list in column a is always the same as the length of each list in column b.

I can think of using an explode, then filtering, aggregating, and performing a join, but in some cases a join column is not available, and I would rather avoid this method for simplicity.

Are there any other ways to filter a list using another list as a boolean mask? I have tried using .list.eval, but it does not seem to accept operations involving other columns.
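For illustration, this is roughly the kind of call I tried; it fails because expressions inside .list.eval can only refer to the elements of the list being evaluated, not to other columns:

# Raises an error: other columns are not visible inside list.eval
df.with_columns(
    pl.col("b").list.eval(pl.element().filter(pl.col("a")))
)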

Any help would be appreciated!

Upvotes: 3

Views: 4886

Answers (2)

jqurious

Reputation: 21229

.list.gather() has since been added, which can take a list of indices.

There is no .list.arg_true() as of yet, but you can build the indices with .list.eval():

df.select(idxs = pl.col.a.list.eval(pl.element().arg_true()))
shape: (2, 1)
┌───────────┐
│ idxs      │
│ ---       │
│ list[u32] │
╞═══════════╡
│ [0]       │
│ [1]       │
└───────────┘

These indices can then be passed to .list.gather():

df.with_columns(
    pl.col.b.list.gather(pl.col.a.list.eval(pl.element().arg_true()))
      .alias("c")
)
shape: (2, 3)
┌───────────────┬────────────────────┬───────────┐
│ a             ┆ b                  ┆ c         │
│ ---           ┆ ---                ┆ ---       │
│ list[bool]    ┆ list[str]          ┆ list[str] │
╞═══════════════╪════════════════════╪═══════════╡
│ [true, false] ┆ ["name1", "name2"] ┆ ["name1"] │
│ [false, true] ┆ ["name3", "name4"] ┆ ["name4"] │
└───────────────┴────────────────────┴───────────┘
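For reuse, the two steps can be wrapped in a small helper that returns an expression (list_filter is a hypothetical name, not a polars API):

def list_filter(values: str, mask: str) -> pl.Expr:
    # gather the elements of `values` at the positions where `mask` is True
    return pl.col(values).list.gather(
        pl.col(mask).list.eval(pl.element().arg_true())
    )

df.with_columns(list_filter("b", "a").alias("c"))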

Upvotes: 0

ritchie46

Reputation: 14690

This is not the most ideal solution: we groom the data so that every list gets its own group, explode the lists to their elements, then group_by those groups again and apply the filter.

(df.with_row_index()
   .explode("a", "b")
   .group_by("index")
   .agg(
       pl.col("b").filter(pl.col("a"))
   )
)
shape: (2, 2)
┌───────┬───────────┐
│ index ┆ b         │
│ ---   ┆ ---       │
│ u32   ┆ list[str] │
╞═══════╪═══════════╡
│ 0     ┆ ["name1"] │
│ 1     ┆ ["name4"] │
└───────┴───────────┘
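Note that group_by does not guarantee output order. If the original row order matters, a variant of the query above can use the maintain_order flag and drop the helper index afterwards (a sketch against the same df):

(df.with_row_index()
   .explode("a", "b")
   .group_by("index", maintain_order=True)  # keep groups in original row order
   .agg(
       pl.col("b").filter(pl.col("a"))
   )
   .drop("index")
)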

Maybe we can come up with something better in polars. It would be nice if list.eval could access other columns. TBC!

Edit 02-06-2022

In polars-0.13.41 this will not be as expensive as you might think. Polars knows that the row index is sorted and maintains that sortedness throughout the whole query. The explodes are also free for the list columns.

When polars knows a groupby key is sorted, the groupby operation will be ~15x faster.

In the query above you would only pay for:

  • exploding the row index
  • grouping the sorted key (which is super fast)
  • traversing the list (a cost we would need to pay anyway).
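As an aside, the sorted fast path also applies to keys you build yourself: you can promise polars that a column is sorted with .set_sorted(). A minimal sketch (df_big and its columns are made-up illustration data; set_sorted is a promise, not a check, so the data must actually be sorted):

df_big = pl.DataFrame({"key": [0, 0, 1, 1, 2], "val": [1, 2, 3, 4, 5]})

(df_big.with_columns(pl.col("key").set_sorted())  # assert sortedness of the key
       .group_by("key")
       .agg(pl.col("val").sum())
)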

To verify that the query takes the fast path, you can run it with POLARS_VERBOSE=1. This will write the following text to stderr:

could fast explode column a
could fast explode column b
keys/aggregates are not partitionable: running default HASH AGGREGATION
groupby keys are sorted; running sorted key fast path
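If you are running from Python, one way to set that flag (assuming the variable is read when the query executes) is:

import os
os.environ["POLARS_VERBOSE"] = "1"  # set before running the query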

Upvotes: 3
