I have a Spark dataframe like this:
+------+---------+---------+---------+---------+
| name | metric1 | metric2 | metric3 | metric4 |
+------+---------+---------+---------+---------+
| a | 1 | 2 | 3 | 4 |
| b | 1 | 2 | 3 | 4 |
| c | 3 | 1 | 5 | 4 |
| a | 3 | 3 | 3 | 3 |
+------+---------+---------+---------+---------+
For any name that appears more than once, I want to replace all of its rows with a single row whose metric columns are null, so the desired output is:
+------+---------+---------+---------+---------+
| name | metric1 | metric2 | metric3 | metric4 |
+------+---------+---------+---------+---------+
| a | null | null | null | null |
| b | 1 | 2 | 3 | 4 |
| c | 3 | 1 | 5 | 4 |
+------+---------+---------+---------+---------+
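The required semantics can be stated without Spark. Below is a minimal plain-Scala sketch over an in-memory list. It groups by name, keeps the single row for names that occur once, and replaces the metrics of repeated names with `None`, which plays the role of SQL null here. The `Record` case class and the `collapseDuplicates` function are hypothetical names used only for illustration.

```scala
// Hypothetical record type mirroring the example DataFrame's schema.
final case class Record(name: String, metrics: Seq[Option[Int]])

// Names seen once keep their row; repeated names collapse to one row of Nones.
def collapseDuplicates(rows: Seq[Record]): Seq[Record] =
  rows
    .groupBy(_.name)
    .toSeq
    .sortBy(_._1) // groupBy returns an unordered Map, so sort for stable output
    .map {
      case (_, Seq(single)) => single
      case (name, group)    => Record(name, Seq.fill(group.head.metrics.size)(None))
    }

val input = Seq(
  Record("a", Seq(1, 2, 3, 4).map(Some(_))),
  Record("b", Seq(1, 2, 3, 4).map(Some(_))),
  Record("c", Seq(3, 1, 5, 4).map(Some(_))),
  Record("a", Seq(3, 3, 3, 3).map(Some(_)))
)

collapseDuplicates(input).foreach(println)
```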
The following works:
import org.apache.spark.sql.functions._
import spark.implicits._ // needed for toDF; already in scope in spark-shell

val df = Seq(
  ("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3)
).toDF("name", "metric1", "metric2", "metric3", "metric4")

val newDf = df
  .groupBy(col("name"))
  .agg(
    min(col("metric1")).as("metric1"),
    min(col("metric2")).as("metric2"),
    min(col("metric3")).as("metric3"),
    min(col("metric4")).as("metric4"),
    count(col("name")).as("NumRecords")
  )
  // Null out every metric for names that occur more than once.
  .withColumn("metric1", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric1")))
  .withColumn("metric2", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric2")))
  .withColumn("metric3", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric3")))
  .withColumn("metric4", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric4")))
  .drop("NumRecords")
But surely there has to be a better way that doesn't spell out every metric column twice?
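One way to remove the repetition on both sides is to generate the expressions from `df.columns`. The sketch below assumes every column other than `name` is a metric; `metricCols` and `grouped` are names introduced here for illustration. It builds the `min` aggregations in a loop and nulls out repeated names in a single `select`. This relies on `when` without `otherwise` evaluating to null when the condition is false. It also assumes a local `SparkSession`; in `spark-shell`, skip the builder and implicits lines.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, lit, min, when}

val spark = SparkSession.builder().master("local[*]").appName("collapse-duplicates").getOrCreate()
import spark.implicits._

val df = Seq(
  ("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3)
).toDF("name", "metric1", "metric2", "metric3", "metric4")

// Assumption: every column except the grouping key is a metric.
val metricCols: Seq[String] = df.columns.toSeq.filterNot(_ == "name")

// One count plus one min(...) per metric, generated from the column list.
val grouped = df
  .groupBy(col("name"))
  .agg(count(lit(1)).as("NumRecords"), metricCols.map(c => min(col(c)).as(c)): _*)

// when(...) with no otherwise(...) yields null when the condition is false,
// so names with NumRecords > 1 get null metrics and NumRecords is not selected.
val newDf = grouped.select(
  (col("name") +: metricCols.map(c => when(col("NumRecords") === 1, col(c)).as(c))): _*
)

newDf.orderBy("name").show()
```

Because everything is derived from `metricCols`, adding a `metric5` column needs no code changes.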
scala> val df = Seq(("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3)).toDF("name", "metric1", "metric2", "metric3", "metric4")
scala> val newDf = df.groupBy(col("name")).agg(
         min(col("metric1")).as("metric1"),
         min(col("metric2")).as("metric2"),
         min(col("metric3")).as("metric3"),
         min(col("metric4")).as("metric4"),
         count(col("name")).as("NumRecords"))
scala> val colArr2 = df.columns.diff(Array("name"))
scala> val reqDF = colArr2.foldLeft(newDf) { (acc, colName) =>
         // Null out this metric for every name that occurs more than once.
         acc.withColumn(colName, when(col("NumRecords") =!= 1, lit(null)).otherwise(col(colName)))
       }.drop("NumRecords")
scala> reqDF.show
+----+-------+-------+-------+-------+
|name|metric1|metric2|metric3|metric4|
+----+-------+-------+-------+-------+
| c| 3| 1| 5| 4|
| b| 1| 2| 3| 4|
| a| null| null| null| null|
+----+-------+-------+-------+-------+
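Try the approach above. `foldLeft` passes the aggregated DataFrame through one `withColumn` call per metric column, so the null-out logic is written once instead of four times.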
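Spark's `Dataset.withColumn` documentation notes that calling it repeatedly, for example in a loop, adds one projection per call and can produce large query plans. On Spark 3.3 and later, `Dataset.withColumns(Map[String, Column])` applies all the replacements in one projection. The sketch below assumes Spark 3.3+ and a local `SparkSession`. The variable names follow the answer above, except `metricCols`, which is introduced here. It also generates the `min` aggregations from the same column list.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, lit, min, when}

val spark = SparkSession.builder().master("local[*]").appName("withColumns-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  ("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3)
).toDF("name", "metric1", "metric2", "metric3", "metric4")

// Every column except the grouping key is treated as a metric.
val metricCols: Seq[String] = df.columns.toSeq.diff(Seq("name"))

val newDf = df
  .groupBy(col("name"))
  .agg(count(lit(1)).as("NumRecords"), metricCols.map(c => min(col(c)).as(c)): _*)

// A single withColumns call (Spark 3.3+) replaces all metric columns in place.
val reqDF = newDf
  .withColumns(metricCols.map(c => c -> when(col("NumRecords") =!= 1, lit(null)).otherwise(col(c))).toMap)
  .drop("NumRecords")

reqDF.orderBy("name").show()
```

On Spark versions before 3.3, `withColumns` is unavailable, so the `foldLeft` above or a single `select` is the way to go.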