shammery

Reputation: 1072

How to create a new column with the average value of another column in PySpark

I have a dataset that looks like this:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
import pyspark.sql.functions as F
data2 = [("James","","Smith","36636","M",3000),
    ("Michael","Rose","","40288","M",4000),
    ("Robert","","Williams","42114","M",4000),
    ("Maria","Anne","Jones","39192","F",4000),
    ("Jen","Mary","Brown","","F",-1)
  ]

schema = StructType([
    StructField("firstname", StringType(), True),
    StructField("middlename", StringType(), True),
    StructField("lastname", StringType(), True),
    StructField("id", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("salary", IntegerType(), True)
])
 
spark = SparkSession.builder.getOrCreate()  # reuses an existing session if one is already running
df = spark.createDataFrame(data=data2, schema=schema)
df.show()
+---------+----------+--------+-----+------+------+
|firstname|middlename|lastname|   id|gender|salary|
+---------+----------+--------+-----+------+------+
|    James|          |   Smith|36636|     M|  3000|
|  Michael|      Rose|        |40288|     M|  4000|
|   Robert|          |Williams|42114|     M|  4000|
|    Maria|      Anne|   Jones|39192|     F|  4000|
|      Jen|      Mary|   Brown|     |     F|    -1|
+---------+----------+--------+-----+------+------+

I want to calculate the average salary and add it as a new column avg_salary. I am able to do something like this:

(
    df
    .withColumn("id",F.lit(1))
    .join(
        df
        .agg(
            F.avg(F.col("salary")).alias("avg_salary")
        )
        .withColumn("id",F.lit(1)),
        on=["id"]
    )
    .drop("id")
).show()
+---------+----------+--------+------+------+----------+
|firstname|middlename|lastname|gender|salary|avg_salary|
+---------+----------+--------+------+------+----------+
|    James|          |   Smith|     M|  3000|    2999.8|
|  Michael|      Rose|        |     M|  4000|    2999.8|
|   Robert|          |Williams|     M|  4000|    2999.8|
|    Maria|      Anne|   Jones|     F|  4000|    2999.8|
|      Jen|      Mary|   Brown|     F|    -1|    2999.8|
+---------+----------+--------+------+------+----------+

Even though I get the desired output, I wanted to know if there is a nicer way to do this. Any help would be highly appreciated. Thanks

Upvotes: 0

Views: 895

Answers (1)

Luiz Viola

Reputation: 2436

You can use a window function:

from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy()

df\
    .withColumn('avg_salary', F.avg(F.col('salary')).over(w))\
    .show()

# +---------+----------+--------+-----+------+------+----------+
# |firstname|middlename|lastname|   id|gender|salary|avg_salary|
# +---------+----------+--------+-----+------+------+----------+
# |    James|          |   Smith|36636|     M|  3000|    2999.8|
# |  Michael|      Rose|        |40288|     M|  4000|    2999.8|
# |   Robert|          |Williams|42114|     M|  4000|    2999.8|
# |    Maria|      Anne|   Jones|39192|     F|  4000|    2999.8|
# |      Jen|      Mary|   Brown|     |     F|    -1|    2999.8|
# +---------+----------+--------+-----+------+------+----------+
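
Note that Window.partitionBy() with no columns moves all rows into a single partition (Spark logs a performance warning when this happens), which is fine for a small DataFrame like this one but can be slow on large data.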

You can also use it for other analyses. For example, if you want the average by gender, you can use partitionBy('gender'), as in the sketch below:
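
A minimal sketch of that variant, reusing df, F, and Window from above:

w = Window.partitionBy('gender')

df\
    .withColumn('avg_salary', F.avg(F.col('salary')).over(w))\
    .show()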

+---------+----------+--------+-----+------+------+------------------+
|firstname|middlename|lastname|   id|gender|salary|        avg_salary|
+---------+----------+--------+-----+------+------+------------------+
|    Maria|      Anne|   Jones|39192|     F|  4000|            1999.5|
|      Jen|      Mary|   Brown|     |     F|    -1|            1999.5|
|    James|          |   Smith|36636|     M|  3000|3666.6666666666665|
|  Michael|      Rose|        |40288|     M|  4000|3666.6666666666665|
|   Robert|          |Williams|42114|     M|  4000|3666.6666666666665|
+---------+----------+--------+-----+------+------+------------------+

Upvotes: 2
