Lubu

Reputation: 123

Filtering array as a column in dataframe

root
 |-- id: string (nullable = true)
 |-- elements: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- data: struct (nullable = true)
 |    |    |    |-- name: string (nullable = true)
 |    |    |    |-- surname: string (nullable = true)
 |    |    |-- value: float (nullable = true)
 |    |    |-- othername: string (nullable = true)

Given that dataframe structure, I'm trying to filter for elements whose value is greater than some threshold X, e.g. 0.5. However, when I try to filter it:

df.where(col('elements.value') > 0.5)

it throws

cannot resolve '(spark_catalog.default.tempD.`elements`.`value` > 0.5D)' due to data type mismatch:
 differing types in '(spark_catalog.default.tempD.`elements`.`value` > 0.5D)' (array<float> and double).;;

I can't figure out how to fix this. Wrapping the value with float(), e.g. float(0.5), changes nothing. I bet it's a simple fix, but I've been struggling with it for too many hours.

Upvotes: 1

Views: 254

Answers (1)

mck

Reputation: 42422

You can try a higher-order function to filter the array:

df2 = df.selectExpr('id', 'filter(elements, x -> x.value > 0.5) filtered')

A normal where filter doesn't work because the comparison cannot be applied to an array as a whole. Imagine your array contains two structs, one with value > 0.5 and the other with value < 0.5: there is no way to decide whether that row should be kept or dropped, so Spark refuses the comparison with a type-mismatch error (array<float> vs. double).

If you want to filter the rows where ALL values in the array are > 0.5, you can use

df.where('array_min(transform(elements, x -> x.value > 0.5))')

The clause is only true if every item in the array returns true, because array_min over the transformed boolean array yields false as soon as any element fails the predicate (false sorts before true).

Upvotes: 1
