Reputation: 7894
I want to transform a Spark SQL DataFrame like this:
animal value
------------
cat 8
cat 5
cat 6
dog 2
dog 4
dog 3
rat 7
rat 4
rat 9
into a DataFrame like this:
animal value previous-value
-----------------------------
cat 8 0
cat 5 8
cat 6 5
dog 2 0
dog 4 2
dog 3 4
rat 7 0
rat 4 7
rat 9 4
I sort of want to partition by animal, and then, for each animal, previous-value lags one row behind value (with a default value of 0), and then put the partitions back together again.
Upvotes: 0
Views: 329
Reputation: 3544
This piece of code would work:
import org.apache.spark.sql.functions.collect_list
import spark.implicits._

val df = spark.read.format("CSV").option("header", "true").load("/home/shivansh/Desktop/foo.csv")
// Collect each animal's values into a single list
val df2 = df.groupBy("animal").agg(collect_list("value") as "listValue")
// Zip each list with a "0"-prepended copy of itself, pairing every value with its predecessor
val desiredDF = df2.rdd.flatMap { row =>
  val animal = row.getAs[String]("animal")
  val valueList = row.getAs[Seq[String]]("listValue").toList
  val newlist = valueList zip ("0" :: valueList)
  newlist.map(a => (animal, a._1, a._2))
}.toDF("animal", "value", "previousValue")
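The core trick is the zip: prepending "0" shifts the list by one position, so each value gets paired with the value before it. A plain-Scala sketch of just that step (sample list taken from the cat rows above):
val values = List("8", "5", "6")
// ("0" :: values) is List("0", "8", "5", "6"); zip stops at the shorter list
val pairs = values zip ("0" :: values)   // List(("8","0"), ("5","8"), ("6","5"))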
In the Spark shell:
scala> val df=spark.read.format("CSV").option("header","true").load("/home/shivansh/Desktop/foo.csv")
df: org.apache.spark.sql.DataFrame = [animal: string, value: string]
scala> df.show()
+------+-----+
|animal|value|
+------+-----+
| cat| 8|
| cat| 5|
| cat| 6|
| dog| 2|
| dog| 4|
| dog| 3|
| rat| 7|
| rat| 4 |
| rat| 9|
+------+-----+
scala> val df2=df.groupBy("animal").agg(collect_list("value") as "listValue")
df2: org.apache.spark.sql.DataFrame = [animal: string, listValue: array<string>]
scala> df2.show()
+------+----------+
|animal| listValue|
+------+----------+
| rat|[7, 4 , 9]|
| dog| [2, 4, 3]|
| cat| [8, 5, 6]|
+------+----------+
scala> val desiredDF=df2.rdd.flatMap{row=>
| val animal=row.getAs[String]("animal")
| val valueList=row.getAs[Seq[String]]("listValue").toList
| val newlist=valueList zip "0"::valueList
| newlist.map(a=>(animal,a._1,a._2))
| }.toDF("animal","value","previousValue")
desiredDF: org.apache.spark.sql.DataFrame = [animal: string, value: string ... 1 more field]
scala> desiredDF.show()
+------+-----+-------------+
|animal|value|previousValue|
+------+-----+-------------+
| rat| 7| 0|
| rat| 4 | 7|
| rat| 9| 4 |
| dog| 2| 0|
| dog| 4| 2|
| dog| 3| 4|
| cat| 8| 0|
| cat| 5| 8|
| cat| 6| 5|
+------+-----+-------------+
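One thing to keep in mind: since the CSV was loaded without a schema, value and previousValue come out as strings. If numeric columns are needed, a cast along these lines (using the column names from above) should work:
// Hypothetical follow-up, not part of the original answer: cast the string columns to int
val numericDF = desiredDF
  .withColumn("value", $"value".cast("int"))
  .withColumn("previousValue", $"previousValue".cast("int"))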
Upvotes: 1
Reputation: 36
This can be accomplished using a window function.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import sqlContext.implicits._
val df = sc.parallelize(Seq(("cat", 8, "01:00"),("cat", 5, "02:00"),("cat", 6, "03:00"),("dog", 2, "02:00"),("dog", 4, "04:00"),("dog", 3, "06:00"),("rat", 7, "01:00"),("rat", 4, "03:00"),("rat", 9, "05:00"))).toDF("animal", "value", "time")
df.show
+------+-----+-----+
|animal|value| time|
+------+-----+-----+
| cat| 8|01:00|
| cat| 5|02:00|
| cat| 6|03:00|
| dog| 2|02:00|
| dog| 4|04:00|
| dog| 3|06:00|
| rat| 7|01:00|
| rat| 4|03:00|
| rat| 9|05:00|
+------+-----+-----+
I've added a "time" field to illustrate orderBy.
// Partition by animal and order rows within each partition by time
val w1 = Window.partitionBy($"animal").orderBy($"time")
// lag looks one row back within the window; the first row of each partition gets null
val previous_value = lag($"value", 1).over(w1)
val df1 = df.withColumn("previous", previous_value)
df1.show
+------+-----+-----+--------+
|animal|value| time|previous|
+------+-----+-----+--------+
| dog| 2|02:00| null|
| dog| 4|04:00| 2|
| dog| 3|06:00| 4|
| cat| 8|01:00| null|
| cat| 5|02:00| 8|
| cat| 6|03:00| 5|
| rat| 7|01:00| null|
| rat| 4|03:00| 7|
| rat| 9|05:00| 4|
+------+-----+-----+--------+
If you want to replace nulls with 0:
val df2 = df1.na.fill(0)
df2.show
+------+-----+-----+--------+
|animal|value| time|previous|
+------+-----+-----+--------+
| dog| 2|02:00| 0|
| dog| 4|04:00| 2|
| dog| 3|06:00| 4|
| cat| 8|01:00| 0|
| cat| 5|02:00| 8|
| cat| 6|03:00| 5|
| rat| 7|01:00| 0|
| rat| 4|03:00| 7|
| rat| 9|05:00| 4|
+------+-----+-----+--------+
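As a possible shortcut, lag also accepts a default value as a third argument, so the na.fill step can be folded into the window expression (same w1 as above):
// Sketch: lag with an explicit default of 0 instead of filling nulls afterwards
val df2b = df.withColumn("previous", lag($"value", 1, 0).over(w1))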
Upvotes: 2