Ivan

Fill PySpark DataFrame column null values with the average value from the same column

With a dataframe like this,

rdd_2 = sc.parallelize([
    (0, 10, 223, "201601"), (0, 10, 83, "2016032"), (1, 20, None, "201602"),
    (1, 20, 3003, "201601"), (1, 20, None, "201603"), (2, 40, 2321, "201601"),
    (2, 30, 10, "201602"), (2, 61, None, "201601"),
])

df_data = sqlContext.createDataFrame(rdd_2, ["id", "type", "cost", "date"])
df_data.show()

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|null| 201602|
|  1|  20|3003| 201601|
|  1|  20|null| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|null| 201601|
+---+----+----+-------+

I need to fill the null values with the average of the existing (non-null) values in the same column, so the expected result is:

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

where 1128 is the average of the non-null cost values: (223 + 83 + 3003 + 2321 + 10) / 5 = 5640 / 5 = 1128. I need to do this for several columns.

My current approach is to use na.fill:

# .rdd is needed on Spark 2.0+, where DataFrame itself has no flatMap
fill_values = {
    column: df_data.agg({column: "mean"}).rdd.flatMap(list).collect()[0]
    for column in df_data.columns
    if column not in ['date', 'id']
}
df_data = df_data.na.fill(fill_values)

+---+----+----+-------+
| id|type|cost|   date|
+---+----+----+-------+
|  0|  10| 223| 201601|
|  0|  10|  83|2016032|
|  1|  20|1128| 201602|
|  1|  20|3003| 201601|
|  1|  20|1128| 201603|
|  2|  40|2321| 201601|
|  2|  30|  10| 201602|
|  2|  61|1128| 201601|
+---+----+----+-------+

But this is very cumbersome. Any ideas?

Answers (1)

zero323

Well, one way or another you have to:

  • compute statistics
  • fill the blanks
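Those two steps can be modelled without Spark at all. The following is a plain-Python sketch, not PySpark code: `fill_with_mean_local` and the tuple rows are hypothetical stand-ins, with `None` playing the role of null. It computes each column's mean over the non-null values only (as Spark's `avg` does) and then replaces every `None`. The `int()` conversion is an assumption meant to roughly mirror how `na.fill` casts the fill value to an integer column's type.

```python
from statistics import mean

# Rows from the question; None stands in for Spark's null.
rows = [
    (0, 10, 223, "201601"), (0, 10, 83, "2016032"), (1, 20, None, "201602"),
    (1, 20, 3003, "201601"), (1, 20, None, "201603"), (2, 40, 2321, "201601"),
    (2, 30, 10, "201602"), (2, 61, None, "201601"),
]
columns = ["id", "type", "cost", "date"]


def fill_with_mean_local(rows, columns, exclude=()):
    """Hypothetical pure-Python model of 'compute statistics, fill the blanks'."""
    targets = [i for i, c in enumerate(columns) if c not in exclude]

    # Step 1: compute statistics. Like Spark's avg, the mean skips nulls.
    stats = {}
    for i in targets:
        present = [row[i] for row in rows if row[i] is not None]
        if present:  # an all-null column has no mean to fill with
            stats[i] = mean(present)

    # Step 2: fill the blanks. int() assumes the target columns are integers,
    # roughly like na.fill casting the mean to the column's type.
    return [
        tuple(int(stats[i]) if v is None and i in stats else v
              for i, v in enumerate(row))
        for row in rows
    ]


filled = fill_with_mean_local(rows, columns, exclude={"id", "date"})
for row in filled:
    print(row)
```

Every `None` in `cost` becomes 1128, while `type` (no nulls) and the excluded `id`/`date` columns pass through unchanged.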

Those two steps are unavoidable, which pretty much limits what you can improve here. Still, you can:

  • replace flatMap(list).collect()[0] with first()[0] or structure unpacking
  • compute all stats with a single action
  • use the built-in Row.asDict() method to extract a dictionary

The final result could look like this:

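from pyspark.sql.functions import avg

def fill_with_mean(df, exclude=()):
    # Compute the mean of every non-excluded column in a single action
    stats = df.agg(*(
        avg(c).alias(c) for c in df.columns if c not in exclude
    ))
    # Row.asDict() gives {column: mean}, which na.fill accepts directly
    return df.na.fill(stats.first().asDict())

fill_with_mean(df_data, ["id", "date"]).show()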

In Spark 2.2 or later you can also use Imputer. See Replace missing values with mean - Spark Dataframe.
