kush

Reputation: 173

Replace Null Values of a Column with mean of another Categorical Column in Spark Dataframe

I have a dataset like this

id    category     value
1     A            NaN
2     B            NaN
3     A            10.5
5     A            2.0
6     B            1.0

I want to fill the NaN values with the mean of their respective category, as shown below:

id    category     value
1     A            4.16
2     B            0.5
3     A            10.5
5     A            2.0
6     B            1.0

I first tried to calculate the mean value of each category using groupBy:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

val df2 = dataFrame.groupBy("category").agg(mean("value")).rdd.map {
  case r: Row => (r.getAs[String]("category"), r.get(1))
}.collect().toMap
println(df2)

I got a map of each category and its respective mean value. Output: Map(A -> 4.16, B -> 0.5). I then tried an update query in Spark SQL to fill the column, but it seems Spark SQL doesn't support update queries. I also tried to fill the null values within the DataFrame, but failed to do so. What can I do? We can do the same in pandas, as shown in Pandas: How to fill null values with mean of a groupby?, but how can I do it with a Spark DataFrame?
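For reference, a rough sketch of what I was trying to do next: collect the per-category means into a map typed as Double and apply them back with a UDF. The names categoryMeans and fillWithMean are just placeholders, and I filter out the NaN rows before averaging so they don't turn the mean itself into NaN. I'm not sure this is the idiomatic way:

import org.apache.spark.sql.functions._
import spark.implicits._

// Per-category means collected to the driver as Doubles (NaN rows excluded).
val categoryMeans: Map[String, Double] = dataFrame
  .filter(!isnan($"value"))
  .groupBy("category")
  .agg(mean("value").as("mean"))
  .collect()
  .map(r => r.getAs[String]("category") -> r.getAs[Double]("mean"))
  .toMap

// UDF that substitutes the category mean whenever the value is NaN.
val fillWithMean = udf { (category: String, value: Double) =>
  if (value.isNaN) categoryMeans.getOrElse(category, value) else value
}

val filled = dataFrame.withColumn("value", fillWithMean($"category", $"value"))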

Upvotes: 1

Views: 4099

Answers (3)

nileshg

Reputation: 23

I stumbled upon the same problem and came across this post, but tried a different solution, i.e. using window functions. The code below was tested on PySpark 2.4.3 (window functions are available from Spark 1.4). I believe this is a bit cleaner solution. This post is quite old, but hopefully this answer will be helpful for others.

from pyspark.sql import Window
from pyspark.sql.functions import *

df = spark.createDataFrame([(1,"A", None), (2,"B", None), (3,"A",10.5), (5,"A",2.0), (6,"B",1.0)], ['id', 'category', 'value'])

category_window = Window.partitionBy("category")
value_mean = mean("value0").over(category_window)

result = df\
  .withColumn("value0", coalesce("value", lit(0)))\
  .withColumn("value_mean", value_mean)\
  .withColumn("new_value", coalesce("value", "value_mean"))\
  .select("id", "category", "new_value")

result.show()

The output will be as expected in the question:

id   category   new_value
1    A          4.166666666666667
2    B          0.5
3    A          10.5
5    A          2
6    B          1
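Note that because nulls are coalesced to 0 before averaging, they count as zeros in the mean; that is why A's mean comes out as 4.17 (12.5 / 3) rather than the 6.25 you would get by averaging only the non-null values, which matches the expected output in the question.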

Upvotes: 0

Assaf Mendelson

Reputation: 13001

The simplest solution would be to use groupBy and join:

import org.apache.spark.sql.functions._
import spark.implicits._

val df2 = df.filter(!isnan($"value")).groupBy("category").agg(avg($"value").as("avg"))
df.join(df2, "category").withColumn("value", when(col("value").isNaN, $"avg").otherwise($"value")).drop("avg")

Note that if a category contains only NaN values, it will be removed from the result, since the filtered groupBy produces no row for it and the inner join then drops its rows.
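If those categories need to be kept (their values simply staying NaN), a rough sketch of one option is to use a left join instead and only substitute the average where it exists (avgs is just a local name here):

import org.apache.spark.sql.functions._
import spark.implicits._

val avgs = df.filter(!isnan($"value")).groupBy("category").agg(avg($"value").as("avg"))

// The left join keeps every row of df; for an all-NaN category "avg" is null,
// so the original NaN value is left in place.
df.join(avgs, Seq("category"), "left")
  .withColumn("value", when($"value".isNaN && $"avg".isNotNull, $"avg").otherwise($"value"))
  .drop("avg")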

Upvotes: 3

Tzach Zohar

Reputation: 37832

Indeed, you cannot update DataFrames, but you can transform them using functions like select and join. In this case, you can keep the grouping result as a DataFrame and join it (on the category column) to the original one, then perform the mapping that replaces NaNs with the mean values:

import org.apache.spark.sql.functions._
import spark.implicits._

// calculate mean per category:
val meanPerCategory = dataFrame.groupBy("category").agg(mean("value") as "mean")

// use join, select and "nanvl" function to replace NaNs with the mean values:
val result = dataFrame
  .join(meanPerCategory, "category")
  .select($"category", $"id", nanvl($"value", $"mean"))

result.show()

Upvotes: 2
