Raghu

Reputation: 1712

pyspark: getting the field names of a struct datatype inside a udf

I am trying to pass multiple columns to a udf as a StructType (using pyspark.sql.functions.struct()).

Inside this udf I want to get the field names of the struct column I passed, as a list, so that I can iterate over the passed columns for every row.

Basically, I am looking for a pyspark version of the Scala code provided in this answer: "Spark - pass full row to a udf and then get column name inside udf".

Upvotes: 2

Views: 2808

Answers (1)

pault

Reputation: 43504

You can use the same approach as in the post you linked: the struct reaches your udf as a pyspark.sql.Row. A Python Row has no .schema attribute, though, so instead of .schema.fieldNames, use .asDict() to convert the Row into a dictionary that maps each field name to its value.

For example, here is a way to iterate over the column names and values simultaneously:

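from pyspark.sql import SparkSession
from pyspark.sql.functions import col, struct, udf

# `spark` already exists in the pyspark shell; create it when running as a script
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1, 2, 3)], ["a", "b", "c"])
# The struct arrives in the udf as a Row; asDict() maps field names to values
f = udf(lambda row: "; ".join(["=".join(map(str, [k, v])) for k, v in row.asDict().items()]))
df.select(f(struct(*df.columns)).alias("myUdfOutput")).show()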
#+-------------+
#|  myUdfOutput|
#+-------------+
#|a=1; c=3; b=2|
#+-------------+

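An alternative is to build a MapType() column that maps each column name to its value, and pass that to your udf. Because all values in a map share one type, Spark has to coerce the columns to a common type (and raises an error if it can't), so this works best when the columns already share a type.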

from itertools import chain
from pyspark.sql.functions import create_map, lit

# The map column arrives in the udf as a plain Python dict
f2 = udf(lambda d: "; ".join(["=".join(map(str, [k, v])) for k, v in d.items()]))
df.select(
    f2(
        create_map(
            *chain.from_iterable([(lit(c), col(c)) for c in df.columns])
        )
    ).alias("myNewUdfOutput")
).show()
#+--------------+
#|myNewUdfOutput|
#+--------------+
#| a=1; c=3; b=2|
#+--------------+
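`create_map` takes its arguments as one flat, alternating sequence (key1, value1, key2, value2, ...), which is what `chain.from_iterable` produces from the per-column `(lit(c), col(c))` pairs. The sketch below shows that flattening with plain Python stand-ins (hypothetical `(name, value)` tuples instead of Spark columns), plus a udf body that receives the map as an ordinary `dict`. The helper names `flatten_pairs` and `describe_map` are hypothetical.

```python
from itertools import chain


def flatten_pairs(pairs):
    """Flatten [(k1, v1), (k2, v2), ...] into [k1, v1, k2, v2, ...], the shape create_map expects."""
    return list(chain.from_iterable(pairs))


def describe_map(mapping):
    """Udf body for the MapType variant: the map arrives as a plain Python dict."""
    return "; ".join(f"{key}={value}" for key, value in mapping.items())


# Stand-ins for (lit(c), col(c)) pairs; real code would pass Column objects
pairs = [("a", 1), ("b", 2), ("c", 3)]
flat = flatten_pairs(pairs)
print(flat)  # ['a', 1, 'b', 2, 'c', 3]

# Pairing the flat list back up mimics the dict the map column hands to the udf
as_map = dict(zip(flat[::2], flat[1::2]))
print(describe_map(as_map))  # a=1; b=2; c=3
```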

This second method is arguably more complicated than it needs to be, so I'd recommend the first option.

Upvotes: 1
