Reputation: 1591
I have a DataFrame like below. I need to create new columns based on the existing columns.
col1 col2
a    1
a    2
b    1
c    1
d    1
d    2
The output DataFrame should look like this:
col1 col2 col3 col4
a    1    1    2
a    2    1    2
b    1    0    1
c    1    0    1
d    1    1    2
d    2    1    2
The logic: col3 is 1 if the count of rows for that col1 value is greater than 1 (otherwise 0), and col4 is the maximum value of col2 within each col1 group.
I am familiar with how to do this in SQL, but it's hard to find a solution with the DataFrame DSL. Any help would be appreciated. Thanks
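For reference, the SQL version of this logic would be something along these lines, using count and max as window functions (sqlite3 is used here only to keep the snippet runnable; in Spark the same query could go through spark.sql):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
CREATE TABLE t (col1 TEXT, col2 INTEGER);
INSERT INTO t VALUES ('a',1),('a',2),('b',1),('c',1),('d',1),('d',2);
""")

# col3: 1 when the col1 group has more than one row; col4: max of col2 per group
rows = con.execute("""
SELECT col1, col2,
       CASE WHEN COUNT(*) OVER (PARTITION BY col1) > 1 THEN 1 ELSE 0 END AS col3,
       MAX(col2) OVER (PARTITION BY col1) AS col4
FROM t
ORDER BY col1, col2
""").fetchall()

for r in rows:
    print(r)
```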
Upvotes: 5
Views: 15257
Reputation: 713
To achieve this without a join, you need to use count and max as window functions. This requires creating a window using Window and telling count and max to operate over this window.
from pyspark.sql import Window, functions as fn

df = sc.parallelize([
    {'col1': 'a', 'col2': 1},
    {'col1': 'a', 'col2': 2},
    {'col1': 'b', 'col2': 1},
    {'col1': 'c', 'col2': 1},
    {'col1': 'd', 'col2': 1},
    {'col1': 'd', 'col2': 2}
]).toDF()

# All rows with the same col1 value fall into the same window partition
col1_window = Window.partitionBy('col1')

df = df.withColumn('col3', fn.when(fn.count('col1').over(col1_window) > 1, 1).otherwise(0))
df = df.withColumn('col4', fn.max('col2').over(col1_window))
df.orderBy(['col1', 'col2']).show()
Upvotes: 2
Reputation: 3069
To add col4, the groupBy/max + join already mentioned does the job. col3 depends on the per-group row count, so it is easiest to pull the count out of the same aggregation and then use withColumn + when/otherwise:

// assumes: import org.apache.spark.sql.functions._ and import spark.implicits._
val df2 = df.join(df.groupBy("col1").agg(count("col1").as("cnt"), max("col2").as("col4")), "col1")
val df3 = df2.withColumn("col3", when($"cnt" > 1, 1).otherwise(0)).drop("cnt")
Upvotes: 2
Reputation: 451
Spark DataFrames have a method called withColumn. You can add as many derived columns as you want, but the column is not added to the existing DataFrame; instead a new DataFrame with the added column is returned.
e.g. adding a static date to the data:
val addBatchDate = udf { (batchDate: String) => "20160101" }
val myFormattedData = myData.withColumn("batchdate", addBatchDate(myData("batchdate")))
Upvotes: 2
Reputation: 1327
groupBy col1 and aggregate to get the count and max. Then you can join it back with the original DataFrame to get your desired result:
val df2 = df1.groupBy("col1").agg(when(count("col1") > 1, 1).otherwise(0).as("col3"), max("col2").as("col4"))
val df3 = df1.join(df2, "col1")
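The same aggregate-then-join-back shape can be sketched in pandas (not Spark; just to illustrate the pattern on the question's data):

```python
import pandas as pd

df1 = pd.DataFrame({'col1': list('aabcdd'), 'col2': [1, 2, 1, 1, 1, 2]})

# Aggregate per group: row count and max of col2
agg = df1.groupby('col1')['col2'].agg(cnt='count', col4='max').reset_index()

# Join the aggregates back onto the original rows
out = df1.merge(agg, on='col1').sort_values(['col1', 'col2']).reset_index(drop=True)

# col3 is the 0/1 indicator for "group has more than one row"
out['col3'] = (out['cnt'] > 1).astype(int)
out = out.drop(columns='cnt')
print(out)
```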
Upvotes: 6