Reputation: 19308
I'd like to write an any_lambda function that checks if any of the elements in an ArrayType column meet a condition specified by a lambda function.
Here's the code I have that's not working:
def any_lambda(f, l):
    return any(list(map(f, l)))
spark.udf.register("any_lambda", any_lambda)
source_df = spark.createDataFrame(
    [
        ("jose", [1, 2, 3]),
        ("li", [4, 5, 6]),
        ("luisa", [10, 11, 12]),
    ],
    StructType([
        StructField("name", StringType(), True),
        StructField("nums", ArrayType(IntegerType(), True), True),
    ])
)
actual_df = source_df.withColumn(
    "any_num_greater_than_5",
    any_lambda(lambda n: n > 5, col("nums"))
)
This code raises TypeError: Column is not iterable.
How can I create an any_lambda function that works?
Upvotes: 1
Views: 3205
Reputation: 214927
A udf expects its arguments to be columns, and a lambda function is not a column. Instead, you can define any_lambda so that it takes a lambda function and returns a udf:
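import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

def any_lambda(f):
    # f is captured in a closure; the udf itself only receives the column
    @F.udf
    def temp_udf(l):
        return any(map(f, l))
    return temp_udf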
source_df = spark.createDataFrame(
    [
        ("jose", [1, 2, 3]),
        ("li", [4, 5, 6]),
        ("luisa", [10, 11, 12]),
    ],
    StructType([
        StructField("name", StringType(), True),
        StructField("nums", ArrayType(IntegerType(), True), True),
    ])
)
actual_df = source_df.withColumn(
    "any_num_greater_than_5",
    any_lambda(lambda n: n > 5)(col("nums"))
)
actual_df.show()
+-----+------------+----------------------+
| name|        nums|any_num_greater_than_5|
+-----+------------+----------------------+
| jose|   [1, 2, 3]|                 false|
|   li|   [4, 5, 6]|                  true|
|luisa|[10, 11, 12]|                  true|
+-----+------------+----------------------+
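To see why this works, it may help to look at the same pattern with Spark taken out. The following is only a plain-Python sketch (any_lambda_plain and check are hypothetical names, not part of PySpark): the predicate is bound first, and the function that comes back takes just the array, which is the shape a udf needs.

```python
def any_lambda_plain(f):
    # f is captured here, before any data is involved
    def check(values):
        # values plays the role of one row's array
        return any(map(f, values))
    return check

# Bind the predicate once, then apply the result to each "row"
greater_than_5 = any_lambda_plain(lambda n: n > 5)
results = [greater_than_5(nums) for nums in ([1, 2, 3], [4, 5, 6], [10, 11, 12])]
print(results)  # [False, True, True]
```

In the Spark version, F.udf wraps the inner function so that it is called once per row with that row's array instead of with the Column object.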
Or, as @Powers commented, we can be explicit about the returned column type by specifying the return type in the udf (without one, F.udf returns a string column):
from pyspark.sql.types import BooleanType

def any_lambda(f):
    def temp_udf(l):
        return any(map(f, l))
    return F.udf(temp_udf, BooleanType())
Now the schema looks like:
actual_df.printSchema()
root
 |-- name: string (nullable = true)
 |-- nums: array (nullable = true)
 |    |-- element: integer (containsNull = true)
 |-- any_num_greater_than_5: boolean (nullable = true)
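One thing to watch: the schema marks nums as nullable and its elements as containsNull = true, and any(map(f, l)) raises a TypeError when l is None. Below is a rough sketch of a null-tolerant inner function, assuming you want a null result for a null array and want null elements skipped (make_null_safe_any is a hypothetical name, and that null policy is my assumption, not something from the question).

```python
def make_null_safe_any(f):
    def check(values):
        # A null array gives a null result (assumed policy)
        if values is None:
            return None
        # Skip null elements so f never sees None
        return any(f(v) for v in values if v is not None)
    return check

check = make_null_safe_any(lambda n: n > 5)
print(check(None), check([1, None, 7]), check([None]))
```

The returned function should be usable in place of temp_udf, i.e. wrapped with F.udf(..., BooleanType()) as above.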
Upvotes: 7