yAsH

Reputation: 3405

Spark RDD - Replacing the missing columns with the average of other columns

I have an RDD like the one below:

RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, , 2), (003, 4, , 9, 2), (003, 2, 3, 0, 1) )

The first column is the contract id (001 and 003). I need to group the records by contract id, compute the average of every column other than the contract id, and substitute each missing value with the average of its column within that contract id. (As the expected output below shows, a missing value counts as 0 toward the sum, and the denominator is the total number of rows for that contract.)

So, the final output would be

RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, (1+3)/3 , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, (5+9+0)/4 , 2), (003, 4, (4+4+3)/4 , 9, 2), (003, 2, 3, 0, 1) )

I did a groupByKey using the contract id as the key, but got stuck after that. I'd really appreciate any suggestions.
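For reference, a minimal RDD-only sketch of the step after groupByKey (not from the original post; it assumes the missing entries are encoded as Option[Double], uses the shell's SparkContext sc, and follows the averaging convention above, where a missing value counts as 0 and the denominator is the group size):

import org.apache.spark.rdd.RDD

// Hypothetical encoding: contract id as key, the remaining columns as Option[Double]
val rdd: RDD[(String, Array[Option[Double]])] = sc.parallelize(Seq(
  ("001", Array(Some(1.0), Some(0.0), Some(3.0), Some(4.0))),
  ("001", Array(Some(3.0), Some(4.0), Some(1.0), Some(7.0))),
  ("001", Array(None, Some(0.0), Some(6.0), Some(4.0)))
  // ... remaining rows
))

val filled = rdd.groupByKey().flatMap { case (id, rows) =>
  val n = rows.size                       // denominator: all rows in the group
  // Column-wise sums, with a missing value counted as 0
  val sums = rows.foldLeft(Array.fill(rows.head.length)(0.0)) { (acc, r) =>
    acc.zip(r).map { case (s, v) => s + v.getOrElse(0.0) }
  }
  val avgs = sums.map(_ / n)
  // Replace each missing value with its column's group average
  rows.map(row => (id, row.zip(avgs).map { case (v, a) => v.getOrElse(a) }))
}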

Upvotes: 0

Views: 365

Answers (2)

stack0114106

Reputation: 8711

This can also be achieved with SQL window functions, without any joins. Since sum ignores nulls while count(*) counts every row, sum(x) over (partition by a) / count(*) over (partition by a) is exactly the average the question asks for (a missing value counts as zero). Check this out:

import spark.implicits._  // already in scope in spark-shell

val df = Seq(
  ("001", Some(1), Some(0), Some(3), Some(4)),
  ("001", Some(3), Some(4), Some(1), Some(7)),
  ("001", None, Some(0), Some(6), Some(4)),
  ("003", Some(1), Some(4), Some(5), Some(7)),
  ("003", Some(5), Some(4), None, Some(2)),
  ("003", Some(4), None, Some(9), Some(2)),
  ("003", Some(2), Some(3), Some(0), Some(1))
).toDF("a","b","c","d","e")
df.show(false)
df.createOrReplaceTempView("avg_temp")

spark.sql("""  select a, coalesce(b,sum(b) over(partition by a)/count(*) over(partition by a)) b1, coalesce( c, sum(c) over(partition by a)/count(*) over(partition by a)) c1,
            coalesce( d, sum(d) over(partition by a)/count(*) over(partition by a)) d1, coalesce( e, sum(e) over(partition by a)/count(*) over(partition by a)) e1 from avg_temp
""").show(false)

Results:

+---+----+----+----+---+
|a  |b   |c   |d   |e  |
+---+----+----+----+---+
|001|1   |0   |3   |4  |
|001|3   |4   |1   |7  |
|001|null|0   |6   |4  |
|003|1   |4   |5   |7  |
|003|5   |4   |null|2  |
|003|4   |null|9   |2  |
|003|2   |3   |0   |1  |
+---+----+----+----+---+
+---+------------------+----+---+---+
|a  |b1                |c1  |d1 |e1 |
+---+------------------+----+---+---+
|003|1.0               |4.0 |5.0|7.0|
|003|5.0               |4.0 |3.5|2.0|
|003|4.0               |2.75|9.0|2.0|
|003|2.0               |3.0 |0.0|1.0|
|001|1.0               |0.0 |3.0|4.0|
|001|3.0               |4.0 |1.0|7.0|
|001|1.3333333333333333|0.0 |6.0|4.0|
+---+------------------+----+---+---+
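The same thing can be written with the DataFrame API instead of a SQL string. A minimal sketch over the same df (w and filledDf are made-up names):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, col, count, lit, sum}

val w = Window.partitionBy("a")
// For each value column: keep the value, or fall back to
// group sum / group row count (so nulls count as 0, per the question)
val filledDf = Seq("b", "c", "d", "e").foldLeft(df) { (acc, c) =>
  acc.withColumn(c, coalesce(col(c), sum(col(c)).over(w) / count(lit(1)).over(w)))
}
filledDf.show(false)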

Upvotes: 1

clay

Reputation: 20370

// Create the exact input data provided as a Spark DataFrame/Dataset
val df = {
  import org.apache.spark.sql._
  import org.apache.spark.sql.types._
  import scala.collection.JavaConverters._

  val simpleSchema = StructType(
    StructField("a", StringType) ::
    StructField("b", IntegerType) ::
    StructField("c", IntegerType) ::
    StructField("d", IntegerType) ::
    StructField("e", IntegerType) :: Nil)

  val data = List(
    Row("001", 1, 0, 3, 4),
    Row("001", 3, 4, 1, 7),
    Row("001", null, 0, 6, 4),
    Row("003", 1, 4, 5, 7),
    Row("003", 5, 4, null, 2),
    Row("003", 4, null, 9, 2),
    Row("003", 2, 3, 0, 1)
  )

  spark.createDataFrame(data.asJava, simpleSchema)
}

// col(...) and coalesce(...) below come from the functions package
import org.apache.spark.sql.functions.{coalesce, col}

// fill replaces nulls with zero, which we need for the desired averaging:
// the avg denominator is then the full group size, as in the question.
val avgs = df.na.fill(0).groupBy(col("a")).avg("b", "c", "d", "e").as("avgs")

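// Join each row back to its contract's averages; coalesce keeps the original
// value and falls back to the group average only where the value is null.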
val output = df.as("df")
  .join(avgs, col("df.a") === col("avgs.a"))
  .select(
    col("df.a"),
    coalesce(col("df.b"), col("avg(b)")),
    coalesce(col("df.c"), col("avg(c)")),
    coalesce(col("df.d"), col("avg(d)")),
    coalesce(col("df.e"), col("avg(e)"))
  )

scala> df.show()
+---+----+----+----+---+
|  a|   b|   c|   d|  e|
+---+----+----+----+---+
|001|   1|   0|   3|  4|
|001|   3|   4|   1|  7|
|001|null|   0|   6|  4|
|003|   1|   4|   5|  7|
|003|   5|   4|null|  2|
|003|   4|null|   9|  2|
|003|   2|   3|   0|  1|
+---+----+----+----+---+


scala> avgs.show()
+---+------------------+------------------+------------------+------+
|  a|            avg(b)|            avg(c)|            avg(d)|avg(e)|
+---+------------------+------------------+------------------+------+
|003|               3.0|              2.75|               3.5|   3.0|
|001|1.3333333333333333|1.3333333333333333|3.3333333333333335|   5.0|
+---+------------------+------------------+------------------+------+


scala> output.show()
+---+----------------------+----------------------+----------------------+----------------------+
|  a|coalesce(df.b, avg(b))|coalesce(df.c, avg(c))|coalesce(df.d, avg(d))|coalesce(df.e, avg(e))|
+---+----------------------+----------------------+----------------------+----------------------+
|001|                   1.0|                   0.0|                   3.0|                   4.0|
|001|                   3.0|                   4.0|                   1.0|                   7.0|
|001|    1.3333333333333333|                   0.0|                   6.0|                   4.0|
|003|                   1.0|                   4.0|                   5.0|                   7.0|
|003|                   5.0|                   4.0|                   3.5|                   2.0|
|003|                   4.0|                  2.75|                   9.0|                   2.0|
|003|                   2.0|                   3.0|                   0.0|                   1.0|
+---+----------------------+----------------------+----------------------+----------------------+
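As a small add-on (not in the original answer), the auto-generated coalesce(...) column names can be replaced:

// Restore plain column names on the filled result
val renamed = output.toDF("a", "b", "c", "d", "e")
renamed.show()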

Upvotes: 2
