Reputation: 173
I have a dataset like this:
id  category  value
1   A         NaN
2   B         NaN
3   A         10.5
5   A         2.0
6   B         1.0
I want to fill the NaN values with the mean of their respective category, as shown below:
id  category  value
1   A         4.16
2   B         0.5
3   A         10.5
5   A         2.0
6   B         1.0
I first tried to calculate the mean value of each category using groupBy:
val df2 = dataFrame.groupBy("category").agg(mean("value")).rdd.map {
  case r: Row => (r.getAs[String]("category"), r.get(1))
}.collect().toMap
println(df2)
This gave me a map of each category to its respective mean value. Output: Map(A -> 4.16, B -> 0.5)
Now I tried an UPDATE query in Spark SQL to fill the column, but it seems Spark SQL doesn't support UPDATE queries. I also tried to fill the null values directly in the DataFrame, but failed to do so.
What can I do? We can do the same in pandas, as shown in Pandas: How to fill null values with mean of a groupby? But how can I do it using a Spark DataFrame?
Upvotes: 1
Views: 4099
Reputation: 23
I stumbled upon the same problem and came across this post, but tried a different solution, i.e. using window functions. The code below is tested on PySpark 2.4.3 (window functions have been available since Spark 1.4). I believe this is a slightly cleaner solution. This post is quite old, but I hope this answer will be helpful for others.
from pyspark.sql import Window
from pyspark.sql.functions import coalesce, lit, mean

df = spark.createDataFrame(
    [(1, "A", None), (2, "B", None), (3, "A", 10.5), (5, "A", 2.0), (6, "B", 1.0)],
    ["id", "category", "value"])

# per-category window over which the mean is computed
category_window = Window.partitionBy("category")
value_mean = mean("value0").over(category_window)

result = (df
          # nulls become 0 so that they count towards the per-category mean
          .withColumn("value0", coalesce("value", lit(0)))
          .withColumn("value_mean", value_mean)
          # keep the original value where present, otherwise use the mean
          .withColumn("new_value", coalesce("value", "value_mean"))
          .select("id", "category", "new_value"))
result.show()
The output is as expected (as in the question):
id  category  new_value
1   A         4.166666666666667
2   B         0.5
3   A         10.5
5   A         2.0
6   B         1.0
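One subtlety: coalescing nulls to 0 before averaging means the null rows themselves count towards the mean, which is exactly how the expected numbers in the question arise (for A, (0 + 10.5 + 2.0) / 3 ≈ 4.17); a plain mean("value") would skip the nulls and give 6.25 instead. For Scala users, a rough equivalent of the same window-function approach could look like the sketch below (untested, and it assumes the question's DataFrame df with a nullable double value column):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val categoryWindow = Window.partitionBy("category")

val result = df
  // nulls become 0 so that they count towards the per-category mean
  .withColumn("value0", coalesce($"value", lit(0.0)))
  .withColumn("value_mean", mean($"value0").over(categoryWindow))
  // keep the original value where present, otherwise fall back to the mean
  .withColumn("new_value", coalesce($"value", $"value_mean"))
  .select("id", "category", "new_value")

result.show()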
Upvotes: 0
Reputation: 13001
The simplest solution would be to use groupBy and join:
import org.apache.spark.sql.functions._
import spark.implicits._

val df2 = df.filter(!isnan($"value")).groupBy("category").agg(avg($"value").as("avg"))

df.join(df2, "category")
  .withColumn("value", when(col("value").isNaN, $"avg").otherwise($"value"))
  .drop("avg")
Note that if there is a category whose values are all NaN, it will be dropped from the result: the filter leaves no rows for it, so the inner join finds no matching mean. A left join keeps such categories; see the sketch below.
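If you would rather keep those categories (with their values left as NaN), a left outer join with a null check on the mean should work. A minimal sketch, assuming the same df and df2 as above:
df.join(df2, Seq("category"), "left")
  .withColumn("value",
    // only substitute the mean when one exists for the category
    when(col("value").isNaN && $"avg".isNotNull, $"avg").otherwise($"value"))
  .drop("avg")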
Upvotes: 3
Reputation: 37832
Indeed, you cannot update DataFrames, but you can transform them using functions like select and join. In this case, you can keep the grouping result as a DataFrame and join it (on the category column) to the original one, then perform the mapping that replaces the NaNs with the mean values:
import org.apache.spark.sql.functions._
import spark.implicits._

// calculate the mean per category:
val meanPerCategory = dataFrame.groupBy("category").agg(mean("value") as "mean")

// use join, select and the "nanvl" function to replace NaNs with the mean values:
val result = dataFrame
  .join(meanPerCategory, "category")
  .select($"category", $"id", nanvl($"value", $"mean") as "value")

result.show()
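One caveat: nanvl only substitutes where the value is the floating-point NaN. If the column holds SQL nulls instead (e.g. when the data was created with None, as in the PySpark answer above), coalesce does the analogous job. A sketch under that assumption, reusing meanPerCategory from above:
val resultForNulls = dataFrame
  .join(meanPerCategory, "category")
  .select($"category", $"id", coalesce($"value", $"mean") as "value")

resultForNulls.show()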
Upvotes: 2