Reputation: 1072
I have a dataset that looks like this:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data2 = [("James", "", "Smith", "36636", "M", 3000),
    ("Michael", "Rose", "", "40288", "M", 4000),
    ("Robert", "", "Williams", "42114", "M", 4000),
    ("Maria", "Anne", "Jones", "39192", "F", 4000),
    ("Jen", "Mary", "Brown", "", "F", -1)
]

schema = StructType([
    StructField("firstname", StringType(), True),
    StructField("middlename", StringType(), True),
    StructField("lastname", StringType(), True),
    StructField("id", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("salary", IntegerType(), True)
])

df = spark.createDataFrame(data=data2, schema=schema)
df.show()
+---------+----------+--------+-----+------+------+
|firstname|middlename|lastname|   id|gender|salary|
+---------+----------+--------+-----+------+------+
|    James|          |   Smith|36636|     M|  3000|
|  Michael|      Rose|        |40288|     M|  4000|
|   Robert|          |Williams|42114|     M|  4000|
|    Maria|      Anne|   Jones|39192|     F|  4000|
|      Jen|      Mary|   Brown|     |     F|    -1|
+---------+----------+--------+-----+------+------+
I want to calculate the average salary and add it as a new column avg_salary. I am able to do something like this:
import pyspark.sql.functions as F

# compute the overall average once, then join it back onto every row
# via a constant "id" key (which overwrites the original id column)
(
    df
    .withColumn("id", F.lit(1))
    .join(
        df
        .agg(F.avg(F.col("salary")).alias("avg_salary"))
        .withColumn("id", F.lit(1)),
        on=["id"]
    )
    .drop("id")
).show()
+---------+----------+--------+------+------+----------+
|firstname|middlename|lastname|gender|salary|avg_salary|
+---------+----------+--------+------+------+----------+
|    James|          |   Smith|     M|  3000|    2999.8|
|  Michael|      Rose|        |     M|  4000|    2999.8|
|   Robert|          |Williams|     M|  4000|    2999.8|
|    Maria|      Anne|   Jones|     F|  4000|    2999.8|
|      Jen|      Mary|   Brown|     F|    -1|    2999.8|
+---------+----------+--------+------+------+----------+
I get the desired output, but I wanted to know if there is a nicer way to do this. Any help would be highly appreciated. Thanks!
Upvotes: 0
Views: 895
Reputation: 2436
You can use a window function:
from pyspark.sql import Window
import pyspark.sql.functions as F
# an empty partitionBy() defines a single window spanning the whole DataFrame
w = Window.partitionBy()
df\
.withColumn('avg_salary', F.avg(F.col('salary')).over(w))\
.show()
# +---------+----------+--------+-----+------+------+----------+
# |firstname|middlename|lastname| id|gender|salary|avg_salary|
# +---------+----------+--------+-----+------+------+----------+
# | James| | Smith|36636| M| 3000| 2999.8|
# | Michael| Rose| |40288| M| 4000| 2999.8|
# | Robert| |Williams|42114| M| 4000| 2999.8|
# | Maria| Anne| Jones|39192| F| 4000| 2999.8|
# | Jen| Mary| Brown| | F| -1| 2999.8|
# +---------+----------+--------+-----+------+------+----------+
You can also use it for other analyses. For example, if you want the average salary by gender, you can partition the window by that column with partitionBy('gender').
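A minimal sketch of that variant, reusing df, F and Window from above (the name w_gender is just for illustration):

# window partitioned by gender: the average is computed within each gender group
w_gender = Window.partitionBy('gender')
df\
    .withColumn('avg_salary', F.avg(F.col('salary')).over(w_gender))\
    .show()

which produces: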
+---------+----------+--------+-----+------+------+------------------+
|firstname|middlename|lastname| id|gender|salary| avg_salary|
+---------+----------+--------+-----+------+------+------------------+
| Maria| Anne| Jones|39192| F| 4000| 1999.5|
| Jen| Mary| Brown| | F| -1| 1999.5|
| James| | Smith|36636| M| 3000|3666.6666666666665|
| Michael| Rose| |40288| M| 4000|3666.6666666666665|
| Robert| |Williams|42114| M| 4000|3666.6666666666665|
+---------+----------+--------+-----+------+------+------------------+
Upvotes: 2