ignoring_gravity

Reputation: 10531

when / then / otherwise with values from numpy array

Say I have

df = pl.DataFrame({'group': [1, 1, 1, 3, 3, 3, 4, 4]})

I have a numpy array of values that I'd like to use to replace the rows where 'group' is 3:

values = np.array([9, 8, 7])

Here's what I've tried:

(
    df
    .with_columns(
        pl.when(pl.col('group')==3)
        .then(values)
        .otherwise(pl.col('group'))
        .alias('group')
    )
)

ShapeError: shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation

How can I do this correctly?

Upvotes: 1

Views: 451

Answers (3)

jqurious

Reputation: 21534

Not really simpler, but the map_elements call can be swapped out for a native expression:

df.with_columns(
    pl.when(pl.col.group == 3)
      .then(
         pl.lit(values).get(pl.int_range(pl.len()).over("group"))
      )
      .otherwise(pl.col.group)
      .alias("group")
)

shape: (8, 1)
┌───────┐
│ group │
│ ---   │
│ i64   │
╞═══════╡
│ 1     │
│ 1     │
│ 1     │
│ 9     │
│ 8     │
│ 7     │
│ 4     │
│ 4     │
└───────┘

Upvotes: 0

ritchie46

Reputation: 14730

A few things to consider.

  • One is that you should always convert your numpy arrays to polars Series, because polars uses the Arrow memory specification underneath, not numpy's.

  • Second is that when -> then -> otherwise operates on columns of equal length. We nudge the API in a direction where you define a logical statement based on columns in your DataFrame, so you should not need to know the indices (nor the length of a group) that you want to replace. This allows for many optimizations, because if you do not define indices to replace, we can push down a filter before that expression.

Anyway, your specific situation does know the length of the group, so we must use something different. We can first compute the indices where the conditional holds and then modify based on those indices.

import numpy as np
import polars as pl

df = pl.DataFrame({
    "group": [1, 1, 1, 3, 3, 3, 4, 4]
})

values = np.array([9, 8, 7])

# compute indices of the predicate
idx = df.select(
    pl.arg_where(pl.col("group") == 3)
).to_series()

# mutate on those locations
df.with_columns(
    df["group"].scatter(idx, pl.Series(values))
)

Upvotes: 3

ignoring_gravity

Reputation: 10531

Here's all I could come up with:

df.with_columns(
    pl.when(pl.col("group") == 3)
      .then(
         pl.col("group").cum_count().over("group")
           .map_elements(lambda x: values[x - 1])
      )
      .otherwise("group")
)

shape: (8, 1)
┌───────┐
│ group │
│ ---   │
│ i64   │
╞═══════╡
│ 1     │
│ 1     │
│ 1     │
│ 9     │
│ 8     │
│ 7     │
│ 4     │
│ 4     │
└───────┘

Surely there's a simpler way?

Upvotes: 1
