Reputation: 3405
I have an RDD like the one below:
RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, , 2), (003, 4, , 9, 2), (003, 2, 3, 0, 1) )
The first column is the contract id (001 and 003). I need to group the records by contract id, compute the average of every column other than the contract id, and substitute each missing value with the average of that column for that contract id.
So, the final output would be (note that each average divides by the total number of rows for the contract id, i.e. a missing value counts as zero):
RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, (1+3)/3 , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, (5+9+0)/4 , 2), (003, 4, (4+4+3)/4 , 9, 2), (003, 2, 3, 0, 1) )
I did a groupByKey using the contract id as the key, but I got stuck after that. I'd appreciate any suggestions; a sketch of what I mean follows.
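For illustration, a minimal groupByKey-based sketch of the logic I'm after, modelling a missing cell as None (all names and types here are assumptions, and sc is the spark-shell context):

import org.apache.spark.rdd.RDD

// Rows keyed by contract id; a missing cell is None.
val rows: RDD[(String, Array[Option[Double]])] = sc.parallelize(Seq(
  ("001", Array(Option(1.0), Option(0.0), Option(3.0), Option(4.0))),
  ("001", Array(Option(3.0), Option(4.0), Option(1.0), Option(7.0))),
  ("001", Array(None, Option(0.0), Option(6.0), Option(4.0))),
  ("003", Array(Option(1.0), Option(4.0), Option(5.0), Option(7.0))),
  ("003", Array(Option(5.0), Option(4.0), None, Option(2.0))),
  ("003", Array(Option(4.0), None, Option(9.0), Option(2.0))),
  ("003", Array(Option(2.0), Option(3.0), Option(0.0), Option(1.0)))
))

val filled = rows.groupByKey().flatMap { case (id, grp) =>
  val n = grp.size
  // Per-column average over all rows in the group (a missing value counts as zero).
  val avgs = (0 until 4).map(i => grp.flatMap(_(i)).sum / n)
  grp.map(row => (id, row.indices.map(i => row(i).getOrElse(avgs(i)))))
}
filled.collect().foreach(println)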
Upvotes: 0
Views: 365
Reputation: 8711
This can also be achieved with window functions in Spark SQL, without any joins. Check this out:
import spark.implicits._

val df = Seq(
  ("001", Some(1), Some(0), Some(3), Some(4)),
  ("001", Some(3), Some(4), Some(1), Some(7)),
  ("001", None, Some(0), Some(6), Some(4)),
  ("003", Some(1), Some(4), Some(5), Some(7)),
  ("003", Some(5), Some(4), None, Some(2)),
  ("003", Some(4), None, Some(9), Some(2)),
  ("003", Some(2), Some(3), Some(0), Some(1))
).toDF("a", "b", "c", "d", "e")
df.show(false)
df.createOrReplaceTempView("avg_temp")
spark.sql(""" select a, coalesce(b,sum(b) over(partition by a)/count(*) over(partition by a)) b1, coalesce( c, sum(c) over(partition by a)/count(*) over(partition by a)) c1,
coalesce( d, sum(d) over(partition by a)/count(*) over(partition by a)) d1, coalesce( e, sum(e) over(partition by a)/count(*) over(partition by a)) e1 from avg_temp
""").show(false)
Results:
+---+----+----+----+---+
|a |b |c |d |e |
+---+----+----+----+---+
|001|1 |0 |3 |4 |
|001|3 |4 |1 |7 |
|001|null|0 |6 |4 |
|003|1 |4 |5 |7 |
|003|5 |4 |null|2 |
|003|4 |null|9 |2 |
|003|2 |3 |0 |1 |
+---+----+----+----+---+
+---+------------------+----+---+---+
|a |b1 |c1 |d1 |e1 |
+---+------------------+----+---+---+
|003|1.0 |4.0 |5.0|7.0|
|003|5.0 |4.0 |3.5|2.0|
|003|4.0 |2.75|9.0|2.0|
|003|2.0 |3.0 |0.0|1.0|
|001|1.0 |0.0 |3.0|4.0|
|001|3.0 |4.0 |1.0|7.0|
|001|1.3333333333333333|0.0 |6.0|4.0|
+---+------------------+----+---+---+
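As a side note, the same per-partition sum/count logic can be written with the DataFrame window API instead of a SQL string. A minimal sketch reusing the df above (the names w and result are assumptions):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("a")
// For each value column, replace nulls with sum/count over the partition
// (count(lit(1)) counts all rows, so a null contributes zero to the average).
val result = Seq("b", "c", "d", "e").foldLeft(df) { (acc, c) =>
  acc.withColumn(c, coalesce(col(c), sum(col(c)).over(w) / count(lit(1)).over(w)))
}
result.show(false)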
Upvotes: 1
Reputation: 20370
// Create the exact input data provided as a Spark DataFrame/DataSet
val df = {
  import org.apache.spark.sql._
  import org.apache.spark.sql.types._
  import scala.collection.JavaConverters._

  val simpleSchema = StructType(
    StructField("a", StringType) ::
    StructField("b", IntegerType) ::
    StructField("c", IntegerType) ::
    StructField("d", IntegerType) ::
    StructField("e", IntegerType) :: Nil)

  val data = List(
    Row("001", 1, 0, 3, 4),
    Row("001", 3, 4, 1, 7),
    Row("001", null, 0, 6, 4),
    Row("003", 1, 4, 5, 7),
    Row("003", 5, 4, null, 2),
    Row("003", 4, null, 9, 2),
    Row("003", 2, 3, 0, 1)
  )
  spark.createDataFrame(data.asJava, simpleSchema)
}
import org.apache.spark.sql.functions._

// fill replaces nulls with zero, which we need for the desired averaging:
// the average then divides by the total row count, not the non-null count.
val avgs = df.na.fill(0).groupBy(col("a")).avg("b", "c", "d", "e").as("avgs")
val output = df.as("df").join(avgs, col("df.a") === col("avgs.a")).select(
  col("df.a"),
  coalesce(col("df.b"), col("avg(b)")),
  coalesce(col("df.c"), col("avg(c)")),
  coalesce(col("df.d"), col("avg(d)")),
  coalesce(col("df.e"), col("avg(e)"))
)
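The four coalesce calls can also be generated by folding over a column list and aliased back to the original names, which keeps the output schema clean. A small sketch (cols, filledCols and output2 are illustrative names, not part of the answer above):

val cols = Seq("b", "c", "d", "e")
val filledCols = cols.map(c => coalesce(col(s"df.$c"), col(s"avg($c)")).alias(c))
val output2 = df.as("df")
  .join(avgs, col("df.a") === col("avgs.a"))
  .select(col("df.a").alias("a") +: filledCols: _*)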
scala> df.show()
+---+----+----+----+---+
| a| b| c| d| e|
+---+----+----+----+---+
|001| 1| 0| 3| 4|
|001| 3| 4| 1| 7|
|001|null| 0| 6| 4|
|003| 1| 4| 5| 7|
|003| 5| 4|null| 2|
|003| 4|null| 9| 2|
|003| 2| 3| 0| 1|
+---+----+----+----+---+
scala> avgs.show()
+---+------------------+------------------+------------------+------+
| a| avg(b)| avg(c)| avg(d)|avg(e)|
+---+------------------+------------------+------------------+------+
|003| 3.0| 2.75| 3.5| 3.0|
|001|1.3333333333333333|1.3333333333333333|3.3333333333333335| 5.0|
+---+------------------+------------------+------------------+------+
scala> output.show()
+---+----------------------+----------------------+----------------------+----------------------+
| a|coalesce(df.b, avg(b))|coalesce(df.c, avg(c))|coalesce(df.d, avg(d))|coalesce(df.e, avg(e))|
+---+----------------------+----------------------+----------------------+----------------------+
|001| 1.0| 0.0| 3.0| 4.0|
|001| 3.0| 4.0| 1.0| 7.0|
|001| 1.3333333333333333| 0.0| 6.0| 4.0|
|003| 1.0| 4.0| 5.0| 7.0|
|003| 5.0| 4.0| 3.5| 2.0|
|003| 4.0| 2.75| 9.0| 2.0|
|003| 2.0| 3.0| 0.0| 1.0|
+---+----------------------+----------------------+----------------------+----------------------+
Upvotes: 2