Reputation: 123
root
|-- id: string (nullable = true)
|-- elements: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- name: string (nullable = true)
| | |-- data: struct (nullable = true)
| | | |-- name: string (nullable = true)
| | | |-- surname: string (nullable = true)
| | |-- value: float (nullable = true)
| | |-- othername: string (nullable = true)
Given this DataFrame structure, I'm trying to keep only the elements whose value is greater than some X, e.g. 0.5. However, when I try to filter it:
df.where(col('elements.value') > 0.5)
it throws
cannot resolve '(spark_catalog.default.tempD.`elements`.`value` > 0.5D)' due to data type mismatch:
differing types in '(spark_catalog.default.tempD.`elements`.`value` > 0.5D)' (array<float> and double).;;
I can't figure out how to fix this. Wrapping the value with float(), e.g. float(0.5), changes nothing. I bet it is a simple fix, but I've been struggling with it for too many hours.
Upvotes: 1
Views: 254
Reputation: 42422
You can use a higher-order function to filter the array:
df2 = df.selectExpr('id', 'filter(elements, x -> x.value > 0.5) filtered')
A normal where filter doesn't work because it cannot be applied to an array: as the error says, elements.value resolves to array<float>, which cannot be compared with a double. Imagine your array contains two structs, one with value > 0.5 and the other with value < 0.5. It's not possible to determine whether that row should be included or not.
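To make the ambiguity concrete, here is a minimal plain-Python sketch (not Spark code; the sample row and the helper names are made up for illustration). It shows three different things "value > 0.5" could mean for a single row whose array holds one matching and one non-matching struct.

```python
# Plain-Python stand-in for one DataFrame row; the data is hypothetical.
row = {
    "id": "a",
    "elements": [
        {"name": "x", "value": 0.9},
        {"name": "y", "value": 0.1},
    ],
}

THRESHOLD = 0.5


def keep_matching_elements(row):
    """Trim the array but keep the row, like filter(elements, x -> x.value > 0.5)."""
    return [e for e in row["elements"] if e["value"] > THRESHOLD]


def any_element_matches(row):
    """Row-level test: at least one element is above the threshold."""
    return any(e["value"] > THRESHOLD for e in row["elements"])


def all_elements_match(row):
    """Row-level test: every element is above the threshold."""
    return all(e["value"] > THRESHOLD for e in row["elements"])


print(keep_matching_elements(row))
print(any_element_matches(row))
print(all_elements_match(row))
```

The three helpers give three different answers for the same row, which is why a plain comparison against the array has no single obvious meaning.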
If you want to keep only the rows where ALL values in the array are > 0.5, you can use
df.where('array_min(transform(elements, x -> x.value > 0.5))')
The clause is True only if every item in the array returns True.
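The reason taking the minimum of booleans behaves like "all" is that False sorts before True, so a single False drags the minimum down. A small plain-Python sketch of the same idea (again not Spark code; the function name is hypothetical, and returning None for an empty list is a choice made for this sketch, so check how your Spark version treats empty or null arrays):

```python
def min_of_flags(values, threshold=0.5):
    """Mimic array_min(transform(values, x -> x > threshold)) in plain Python."""
    # Step 1: turn each value into a boolean flag (the transform step).
    flags = [v > threshold for v in values]
    # Step 2: take the minimum; False < True, so any False wins.
    # An empty list has no minimum, so this sketch returns None for it.
    return min(flags) if flags else None


print(min_of_flags([0.9, 0.7]))
print(min_of_flags([0.9, 0.1]))
```

The first call has only True flags, so the minimum is True; the second contains one False, so the minimum is False and the row would be dropped.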
Upvotes: 1