Reputation: 1712
I am trying to pass multiple columns to a udf as a StructType (using pyspark.sql.functions.struct()). Inside this udf I want to get the fields of the struct column that I passed as a list, so that I can iterate over the passed columns for every row.
Basically, I am looking for a PySpark version of the Scala code provided in this answer: Spark - pass full row to a udf and then get column name inside udf
Upvotes: 2
Views: 2808
Reputation: 43504
You can use the same method as in the post you linked, i.e. by using a pyspark.sql.Row. But instead of .schema.fieldNames, you can use .asDict() to convert the Row into a dictionary.
For example, here is a way to iterate over the column names and values simultaneously:
from pyspark.sql.functions import col, struct, udf

df = spark.createDataFrame([(1, 2, 3)], ["a", "b", "c"])

# Convert the Row to a dict and build a "key=value" string for every field
f = udf(lambda row: "; ".join(["=".join(map(str, [k, v])) for k, v in row.asDict().items()]))
df.select(f(struct(*df.columns)).alias("myUdfOutput")).show()
#+-------------+
#|  myUdfOutput|
#+-------------+
#|a=1; c=3; b=2|
#+-------------+
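If what you actually need inside the udf is just the list of field names (as in the linked Scala answer), the keys of asDict() give you that directly. Here is a minimal sketch, reusing the df from above (field_names is just an illustrative name; as with the output above, the key order follows your Python version's dict ordering):

from pyspark.sql.functions import struct, udf
from pyspark.sql.types import ArrayType, StringType

# Illustrative udf that returns only the struct's field names as a list
# (the order depends on your Python version's dict ordering)
field_names = udf(lambda row: list(row.asDict().keys()), ArrayType(StringType()))
df.select(field_names(struct(*df.columns)).alias("fields")).show()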
An alternative would be to build a MapType() of column name to value, and pass this to your udf.
from itertools import chain
from pyspark.sql.functions import create_map, lit

# The MapType column arrives inside the udf as a plain python dict
f2 = udf(lambda row: "; ".join(["=".join(map(str, [k, v])) for k, v in row.items()]))
df.select(
    f2(
        create_map(
            # Interleave literal keys with values: (lit("a"), col("a"), lit("b"), col("b"), ...)
            *chain.from_iterable([(lit(c), col(c)) for c in df.columns])
        )
    ).alias("myNewUdfOutput")
).show()
#+--------------+
#|myNewUdfOutput|
#+--------------+
#| a=1; c=3; b=2|
#+--------------+
This second method is arguably unnecessarily complicated, so the first option is the recommended approach.
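One last note: in both examples the order of the "key=value" pairs follows Python's dict ordering, which is why the output above is not sorted. If you need deterministic output, you could sort the items first; a small sketch based on the first example (f_sorted is an illustrative name):

from pyspark.sql.functions import struct, udf

# Sort the (key, value) pairs so the output no longer depends on dict ordering
f_sorted = udf(lambda row: "; ".join("=".join(map(str, kv)) for kv in sorted(row.asDict().items())))
df.select(f_sorted(struct(*df.columns)).alias("myUdfOutput")).show()
#+-------------+
#|  myUdfOutput|
#+-------------+
#|a=1; b=2; c=3|
#+-------------+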
Upvotes: 1