Peascod

Reputation: 51

Python-Polars: How to filter a categorical column with a list of strings

Update: df_cat.filter(pl.col('a_cat').is_in(['a', 'c'])) now works as expected in Polars.


I have a Polars DataFrame like the one below:

import polars as pl

df_cat = pl.DataFrame(
    [
        pl.Series("a_cat", ["c", "a", "b", "c", "b"], dtype=pl.Categorical),
        pl.Series("b_cat", ["F", "G", "E", "G", "G"], dtype=pl.Categorical),
    ]
)
print(df_cat)
shape: (5, 2)
┌───────┬───────┐
│ a_cat ┆ b_cat │
│ ---   ┆ ---   │
│ cat   ┆ cat   │
╞═══════╪═══════╡
│ c     ┆ F     │
│ a     ┆ G     │
│ b     ┆ E     │
│ c     ┆ G     │
│ b     ┆ G     │
└───────┴───────┘

The following filter runs perfectly fine:

print(df_cat.filter(pl.col('a_cat') == 'c'))
shape: (2, 2)
┌───────┬───────┐
│ a_cat ┆ b_cat │
│ ---   ┆ ---   │
│ cat   ┆ cat   │
╞═══════╪═══════╡
│ c     ┆ F     │
│ c     ┆ G     │
└───────┴───────┘

What I want is to filter with a list of strings so I can match several values in one step. When I tried that, I got the following error message:

print(df_cat.filter(pl.col('a_cat').is_in(['a', 'c'])))
---------------------------------------------------------------------------
ComputeError                              Traceback (most recent call last)
d:\GitRepo\Test2\stockEMD3.ipynb Cell 9 in <cell line: 1>()
----> 1 print(df_cat.filter(pl.col('a_cat').is_in(['c'])))

File c:\ProgramData\Anaconda3\envs\charm3.9\lib\site-packages\polars\internals\dataframe\frame.py:2185, in DataFrame.filter(self, predicate)
   2181 if _NUMPY_AVAILABLE and isinstance(predicate, np.ndarray):
   2182     predicate = pli.Series(predicate)
   2184 return (
-> 2185     self.lazy()
   2186     .filter(predicate)  # type: ignore[arg-type]
   2187     .collect(no_optimization=True, string_cache=False)
   2188 )

File c:\ProgramData\Anaconda3\envs\charm3.9\lib\site-packages\polars\internals\lazyframe\frame.py:660, in LazyFrame.collect(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, string_cache, no_optimization, slice_pushdown)
    650     projection_pushdown = False
    652 ldf = self._ldf.optimization_toggle(
    653     type_coercion,
    654     predicate_pushdown,
   (...)
    658     slice_pushdown,
    659 )
--> 660 return pli.wrap_df(ldf.collect())

ComputeError: joins/or comparisons on categorical dtypes can only happen if they are created under the same global string cache

From this Stack Overflow link I understand that "You need to set a global string cache to compare categoricals created in different columns/lists.", but my questions are:

  1. Why does the single-string == filter work?
  2. What is the proper way to filter a categorical column with a list of strings?

Thanks!

Upvotes: 1

Views: 4401

Answers (3)

ritchie46

Reputation: 14630

As the error states: ComputeError: joins/or comparisons on categorical dtypes can only happen if they are created under the same global string cache.

Comparisons of categorical values are only allowed under a global string cache. You really want to set this in such a case, as it speeds up comparisons and avoids expensive casts to strings.

Setting this at the start of your script will ensure the query runs:

import polars as pl
pl.enable_string_cache()

Upvotes: 2

user18559875


Actually, you don't need to set a global string cache to compare strings to Categorical variables. You can use cast to accomplish this.

Let's use this data. I've included the integer values that underlie the Categorical variables to demonstrate something later.

import polars as pl

df_cat = (
    pl.DataFrame(
        [
            pl.Series("a_cat", ["c", "a", "b", "c", "X"], dtype=pl.Categorical),
            pl.Series("b_cat", ["F", "G", "E", "S", "X"], dtype=pl.Categorical),
        ]
    )
    .with_columns(
        pl.all().to_physical().name.suffix('_phys')
    )
)
df_cat
shape: (5, 4)
┌───────┬───────┬────────────┬────────────┐
│ a_cat ┆ b_cat ┆ a_cat_phys ┆ b_cat_phys │
│ ---   ┆ ---   ┆ ---        ┆ ---        │
│ cat   ┆ cat   ┆ u32        ┆ u32        │
╞═══════╪═══════╪════════════╪════════════╡
│ c     ┆ F     ┆ 0          ┆ 0          │
│ a     ┆ G     ┆ 1          ┆ 1          │
│ b     ┆ E     ┆ 2          ┆ 2          │
│ c     ┆ S     ┆ 0          ┆ 3          │
│ X     ┆ X     ┆ 3          ┆ 4          │
└───────┴───────┴────────────┴────────────┘

Comparing a categorical variable to a string

If we cast a Categorical variable back to its string values, we can make any comparison we need. For example:

df_cat.filter(pl.col('a_cat').cast(pl.String).is_in(['a', 'c']))
shape: (3, 4)
┌───────┬───────┬────────────┬────────────┐
│ a_cat ┆ b_cat ┆ a_cat_phys ┆ b_cat_phys │
│ ---   ┆ ---   ┆ ---        ┆ ---        │
│ cat   ┆ cat   ┆ u32        ┆ u32        │
╞═══════╪═══════╪════════════╪════════════╡
│ c     ┆ F     ┆ 0          ┆ 0          │
│ a     ┆ G     ┆ 1          ┆ 1          │
│ c     ┆ S     ┆ 0          ┆ 3          │
└───────┴───────┴────────────┴────────────┘

Or, in a filter step, we can compare the string values of two Categorical variables that do not share the same string cache:

df_cat.filter(pl.col('a_cat').cast(pl.String) == pl.col('b_cat').cast(pl.String))
shape: (1, 4)
┌───────┬───────┬────────────┬────────────┐
│ a_cat ┆ b_cat ┆ a_cat_phys ┆ b_cat_phys │
│ ---   ┆ ---   ┆ ---        ┆ ---        │
│ cat   ┆ cat   ┆ u32        ┆ u32        │
╞═══════╪═══════╪════════════╪════════════╡
│ X     ┆ X     ┆ 3          ┆ 4          │
└───────┴───────┴────────────┴────────────┘

Notice that it is the string values being compared (not the integers underlying the two Categorical variables).
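To see why comparing the underlying integers would be wrong here, the following plain-Python sketch models each column getting its own local encoding, with codes assigned in order of first appearance. This is a simplified model chosen because it reproduces the a_cat_phys and b_cat_phys values in the table above; it is not Polars' actual implementation.

```python
def encode_local(values):
    """Simplified model (not Polars' code): assign integer codes to strings
    in order of first appearance, independently for each column."""
    mapping = {}
    # len(mapping) is evaluated before setdefault inserts, so a new string
    # gets the next free code and a repeated string reuses its old code.
    return [mapping.setdefault(v, len(mapping)) for v in values]


a = ["c", "a", "b", "c", "X"]
b = ["F", "G", "E", "S", "X"]
a_codes = encode_local(a)  # [0, 1, 2, 0, 3], like a_cat_phys
b_codes = encode_local(b)  # [0, 1, 2, 3, 4], like b_cat_phys

# Rows where the integer codes match vs. rows where the strings match.
same_code = [i for i, (x, y) in enumerate(zip(a_codes, b_codes)) if x == y]
same_string = [i for i, (x, y) in enumerate(zip(a, b)) if x == y]
print(same_code)    # [0, 1, 2]
print(same_string)  # [4]
```

Under independent encodings the codes agree on rows whose strings differ (c/F, a/G, b/E) and disagree on the one row whose strings are equal (X/X), which is presumably why Polars refuses such comparisons unless both sides come from the same string cache.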

The equality operator on Categorical variables

The following statements are equivalent:

df_cat.filter(pl.col('a_cat') == 'a')
df_cat.filter(pl.col('a_cat').cast(pl.String) == 'a')

The former is syntactic sugar for the latter, since it is such a common use case.

Upvotes: 2

Martin Zäch

Reputation: 23

This is an updated answer based on the one from @ritchie46.

The call is now:

import polars as pl
pl.enable_string_cache()

A StringCache() context manager can also be used; see the Polars documentation:

with pl.StringCache():
    print(df_cat.filter(pl.col('a_cat').is_in(['a', 'c'])))

Upvotes: 1
