John

Reputation: 1591

Adding new columns based on aggregation over an existing column in a Spark DataFrame using Scala

I have a DataFrame like the one below, and I need to create new columns based on the existing ones.

col1 col2
a      1
a      2
b      1
c      1
d      1
d      2

The output DataFrame should look like this:

col1  col2 col3 col4
a      1   1      2
a      2   1      2
b      1   0      1
c      1   0      1
d      1   1      2
d      2   1      2

The logic: col3 should be 1 when the count of rows for the col1 value is greater than 1 (0 otherwise), and col4 should be the maximum value of col2 within each col1 group.

I know how to do this in SQL, but it's hard to find a solution with the DataFrame DSL. Any help would be appreciated. Thanks
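For reference, a window-function SQL sketch that expresses the same logic (assuming the DataFrame is registered as a temp view named t):

df.createOrReplaceTempView("t")
val result = spark.sql("""
  SELECT col1, col2,
         CASE WHEN COUNT(*) OVER (PARTITION BY col1) > 1 THEN 1 ELSE 0 END AS col3,
         MAX(col2) OVER (PARTITION BY col1) AS col4
  FROM t
""")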

Upvotes: 5

Views: 15257

Answers (4)

Milad Shahidi

Reputation: 713

To achieve this without a join, you can use count and max as window functions. This requires creating a window using Window and telling count and max to operate over this window. The code below is PySpark; the same approach carries over to the Scala DSL (see the sketch after the code).

from pyspark.sql import Window, functions as fn

# Sample data from the question
df = sc.parallelize([
    {'col1': 'a', 'col2': 1},
    {'col1': 'a', 'col2': 2},
    {'col1': 'b', 'col2': 1},
    {'col1': 'c', 'col2': 1},
    {'col1': 'd', 'col2': 1},
    {'col1': 'd', 'col2': 2}
]).toDF()

# All rows sharing a col1 value fall into the same window partition
col1_window = Window.partitionBy('col1')

# col3: 1 if the col1 value occurs more than once, else 0
df = df.withColumn('col3', fn.when(fn.count('col1').over(col1_window) > 1, 1).otherwise(0))

# col4: maximum of col2 within each col1 group
df = df.withColumn('col4', fn.max('col2').over(col1_window))

df.orderBy(['col1', 'col2']).show()
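Since the question asks for Scala, here is a sketch of the same window approach in the Scala DSL (assuming df holds the question's data):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, max, when}

// Partition the data so each col1 value forms its own window
val col1Window = Window.partitionBy("col1")

val result = df
  .withColumn("col3", when(count("col1").over(col1Window) > 1, 1).otherwise(0))
  .withColumn("col4", max("col2").over(col1Window))

result.orderBy("col1", "col2").show()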

Upvotes: 2

Fabich

Reputation: 3069

To add col3 you can use withColumn + when/otherwise, but the condition needs the per-group count rather than a row-level comparison (when($"col2" > 1, ...) would wrongly give 0 for the row (a, 1)):

val df2 = df.join(df.groupBy("col1").count(), "col1")
  .withColumn("col3", when($"count" > 1, 1).otherwise(0))
  .drop("count")

To add col4, the groupBy/max + join already mentioned does the job:

val df3 = df2.join(df.groupBy("col1").max("col2").withColumnRenamed("max(col2)", "col4"), "col1")

Upvotes: 2

Hari

Reputation: 451

Spark DataFrames have a method called withColumn, and you can add as many derived columns as you want. Note that the column is not added to the existing DataFrame; instead, withColumn returns a new DataFrame with the column added.

e.g. adding a static date to the data (the UDF must be defined before it is used):

// UDF that ignores its input and returns a fixed date string
val addBatchDate = udf { (batchDate: String) => "20160101" }
val myFormattedData = myData.withColumn("batchdate", addBatchDate(myData("batchdate")))
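Since the added value here is a constant, the same thing can be done without a UDF using the built-in lit function (a sketch; myData is the answer's DataFrame):

import org.apache.spark.sql.functions.lit

// Equivalent to the UDF above: withColumn returns a new DataFrame, myData is unchanged
val myFormattedData = myData.withColumn("batchdate", lit("20160101"))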

Upvotes: 2

Ashish Awasthi

Reputation: 1327

groupBy col1 and aggregate to get the count and the max of col2. Then you can join the result back with the original DataFrame to get your desired result:

import org.apache.spark.sql.functions.{count, max, when}

// per-group count and max of col2
val df2 = df1.groupBy("col1").agg(count("col1") as "cnt", max("col2") as "col4")

// join back to the original rows, then turn the count into the 0/1 flag col3
val df3 = df1.join(df2, "col1")
  .withColumn("col3", when($"cnt" > 1, 1).otherwise(0))
  .drop("cnt")
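A quick sanity check against the expected output in the question:

df3.select("col1", "col2", "col3", "col4").orderBy("col1", "col2").show()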

Upvotes: 6
