Reputation: 3257
I have a PySpark DataFrame. I need to group by one column and aggregate certain columns into a list so that I can apply a UDF to the DataFrame.
As an example, I have created a dataframe and then grouped by person.
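from pyspark.sql import functions as F
a = [("Bob", 85.8, "Food", "2017-09-13"), ("Bob", 7.8, "Household", "2017-09-13"), ("Bob", 6.52, "Food", "2017-06-13")]  # sample rows
df = spark.createDataFrame(a, ["Person", "Amount", "Budget", "Date"])
df = df.groupby("Person").agg(F.collect_list(F.struct("Amount", "Budget", "Date")).alias("data"))
df.show(truncate=False)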
+------+----------------------------------------------------------------------------+
|Person|data                                                                        |
+------+----------------------------------------------------------------------------+
|Bob   |[[85.8,Food,2017-09-13], [7.8,Household,2017-09-13], [6.52,Food,2017-06-13]]|
+------+----------------------------------------------------------------------------+
I have left out the UDF, but the DataFrame it returns is shown below.
+------+--------------------------------------------------------------+
|Person|res                                                           |
+------+--------------------------------------------------------------+
|Bob   |[[562,Food,June,1], [380,Household,Sept,4], [880,Food,Sept,2]]|
+------+--------------------------------------------------------------+
I need to convert the resulting DataFrame into rows, where each element in the list becomes a new row and each of its fields becomes a separate column, as shown below.
+------+------+---------+-----+-------+
|Person|Amount|Budget   |Month|Cluster|
+------+------+---------+-----+-------+
|Bob   |562   |Food     |June |1      |
|Bob   |380   |Household|Sept |4      |
|Bob   |880   |Food     |Sept |2      |
+------+------+---------+-----+-------+
Upvotes: 9
Views: 10797
Reputation: 2718
You can use explode and getItem as follows:
# starting from this form:
+------+--------------------------------------------------------------+
|Person|res                                                           |
+------+--------------------------------------------------------------+
|Bob   |[[562,Food,June,1], [380,Household,Sept,4], [880,Food,Sept,2]]|
+------+--------------------------------------------------------------+
import pyspark.sql.functions as F
# explode res to have one row for each item in res
exploded_df = df.select("*", F.explode("res").alias("exploded_data"))
exploded_df.show(truncate=False)
# then use getItem to create separate columns
exploded_df = exploded_df.withColumn(
"Amount",
F.col("exploded_data").getItem("Amount")  # struct fields are read by name; getItem(0) etc. only works if the elements are arrays rather than structs
)
exploded_df = exploded_df.withColumn(
"Budget",
F.col("exploded_data").getItem("Budget")
)
exploded_df = exploded_df.withColumn(
"Month",
F.col("exploded_data").getItem("Month")
)
exploded_df = exploded_df.withColumn(
"Cluster",
F.col("exploded_data").getItem("Cluster")
)
exploded_df.select("Person", "Amount", "Budget", "Month", "Cluster").show(10, False)
+------+------+---------+-----+-------+
|Person|Amount|Budget   |Month|Cluster|
+------+------+---------+-----+-------+
|Bob   |562   |Food     |June |1      |
|Bob   |380   |Household|Sept |4      |
|Bob   |880   |Food     |Sept |2      |
+------+------+---------+-----+-------+
You can then drop the intermediate columns you no longer need (res and exploded_data). Hope this helps, good luck!
Upvotes: 12