Powers

Reputation: 19308

Writing a Pyspark UDF that functions like the Python any function

I'd like to write an any_lambda function that checks if any of the elements in an ArrayType column meet a condition specified by a lambda function.

Here's the code I have that's not working:

from pyspark.sql.functions import col
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

def any_lambda(f, l):
    return any(map(f, l))

spark.udf.register("any_lambda", any_lambda)

source_df = spark.createDataFrame(
    [
        ("jose", [1, 2, 3]),
        ("li", [4, 5, 6]),
        ("luisa", [10, 11, 12]),
    ],
    StructType([
        StructField("name", StringType(), True),
        StructField("nums", ArrayType(IntegerType(), True), True),
    ])
)

actual_df = source_df.withColumn(
    "any_num_greater_than_5",
    any_lambda(lambda n: n > 5, col("nums"))
)

This code raises TypeError: Column is not iterable.
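For reference, the helper itself behaves as expected on plain Python sequences, so the problem seems to be specific to passing a Column. A minimal sketch:

```python
def any_lambda(f, l):
    return any(map(f, l))

# Works fine on ordinary Python lists:
print(any_lambda(lambda n: n > 5, [1, 2, 3]))   # False
print(any_lambda(lambda n: n > 5, [4, 5, 6]))   # True

# But col("nums") is a pyspark.sql.Column -- a symbolic expression,
# not the row data -- so map() has nothing to iterate over.
```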

How can I create an any_lambda function that works?

Upvotes: 1

Views: 3205

Answers (1)

akuiper

Reputation: 214927

A udf expects its arguments to be Columns, and a lambda function is not a Column. What you can do instead is define any_lambda so that it takes a lambda function and returns a udf:

import pyspark.sql.functions as F

def any_lambda(f):
    @F.udf
    def temp_udf(l):
        return any(map(f, l))
    return temp_udf

source_df = spark.createDataFrame(
    [
        ("jose", [1, 2, 3]),
        ("li", [4, 5, 6]),
        ("luisa", [10, 11, 12]),
    ],
    StructType([
        StructField("name", StringType(), True),
        StructField("nums", ArrayType(IntegerType(), True), True),
    ])
)

actual_df = source_df.withColumn(
    "any_num_greater_than_5",
    any_lambda(lambda n: n > 5)(col("nums"))
)

actual_df.show()
+-----+------------+----------------------+
| name|        nums|any_num_greater_than_5|
+-----+------------+----------------------+
| jose|   [1, 2, 3]|                 false|
|   li|   [4, 5, 6]|                  true|
|luisa|[10, 11, 12]|                  true|
+-----+------------+----------------------+

Or, as @Powers commented, to be explicit about the returned column type, we can specify the return type when creating the udf, like so:

from pyspark.sql.types import BooleanType

def any_lambda(f):
    def temp_udf(l):
        return any(map(f, l))
    return F.udf(temp_udf, BooleanType())

Now the schema looks like:

actual_df.printSchema()
root
 |-- name: string (nullable = true)
 |-- nums: array (nullable = true)
 |    |-- element: integer (containsNull = true)
 |-- any_num_greater_than_5: boolean (nullable = true)

Upvotes: 7
