Paul Reiners

Reputation: 7894

Spark SQL DataFrame transformation involving partitioning and lagging

I want to transform a Spark SQL DataFrame like this:

animal value
------------
cat        8
cat        5
cat        6
dog        2
dog        4
dog        3
rat        7
rat        4
rat        9

into a DataFrame like this:

animal value   previous-value
-----------------------------
cat        8                0
cat        5                8
cat        6                5
dog        2                0
dog        4                2
dog        3                4   
rat        7                0
rat        4                7
rat        9                4 

I sort of want to partition by animal, and then, for each animal, previous-value lags one row behind value (with a default value of 0), and then put the partitions back together again.

Upvotes: 0

Views: 329

Answers (2)

Shivansh

Reputation: 3544

This piece of code works (note that collect_list after a groupBy does not guarantee the original row order, so without an explicit ordering column the within-group order may vary):

import org.apache.spark.sql.functions.collect_list
import spark.implicits._

val df = spark.read.format("CSV").option("header", "true").load("/home/shivansh/Desktop/foo.csv")
val df2 = df.groupBy("animal").agg(collect_list("value") as "listValue")
val desiredDF = df2.rdd.flatMap { row =>
        val animal = row.getAs[String]("animal")
        val valueList = row.getAs[Seq[String]]("listValue").toList
        // Pair each value with the one before it: prepend the default "0",
        // then zip; zip drops the unmatched trailing element.
        val newList = valueList zip ("0" :: valueList)
        newList.map(a => (animal, a._1, a._2))
    }.toDF("animal", "value", "previousValue")
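The core trick here is plain Scala, independent of Spark: zipping a list against a copy of itself shifted right by one (with the default prepended) pairs each element with its predecessor. A minimal sketch using the cat values from the question:

```scala
// Pair each value with the previous one by zipping against a
// right-shifted copy that starts with the default "0".
val valueList = List("8", "5", "6")
val shifted   = "0" :: valueList        // List("0", "8", "5", "6")
val pairs     = valueList zip shifted   // zip drops shifted's unmatched tail
// pairs == List(("8", "0"), ("5", "8"), ("6", "5"))
```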

On the Spark shell:

scala> val df=spark.read.format("CSV").option("header","true").load("/home/shivansh/Desktop/foo.csv")
df: org.apache.spark.sql.DataFrame = [animal: string, value: string]

scala> df.show()
+------+-----+
|animal|value|
+------+-----+
|   cat|    8|
|   cat|    5|
|   cat|    6|
|   dog|    2|
|   dog|    4|
|   dog|    3|
|   rat|    7|
|   rat|   4 |
|   rat|    9|
+------+-----+


scala> val df2=df.groupBy("animal").agg(collect_list("value") as "listValue")
df2: org.apache.spark.sql.DataFrame = [animal: string, listValue: array<string>]

scala> df2.show()
+------+----------+
|animal| listValue|
+------+----------+
|   rat|[7, 4 , 9]|
|   dog| [2, 4, 3]|
|   cat| [8, 5, 6]|
+------+----------+


scala> val desiredDF=df2.rdd.flatMap{row=>
     | val animal=row.getAs[String]("animal")
     | val valueList=row.getAs[Seq[String]]("listValue").toList
     | val newlist=valueList zip "0"::valueList
     | newlist.map(a=>(animal,a._1,a._2))
     | }.toDF("animal","value","previousValue")
desiredDF: org.apache.spark.sql.DataFrame = [animal: string, value: string ... 1 more field]

scala> desiredDF.show()
+------+-----+-------------+                                                    
|animal|value|previousValue|
+------+-----+-------------+
|   rat|    7|            0|
|   rat|   4 |            7|
|   rat|    9|           4 |
|   dog|    2|            0|
|   dog|    4|            2|
|   dog|    3|            4|
|   cat|    8|            0|
|   cat|    5|            8|
|   cat|    6|            5|
+------+-----+-------------+

Upvotes: 1

sjstanley

Reputation: 36

This can be accomplished using a window function.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lag
import sqlContext.implicits._

val df = sc.parallelize(Seq(
  ("cat", 8, "01:00"), ("cat", 5, "02:00"), ("cat", 6, "03:00"),
  ("dog", 2, "02:00"), ("dog", 4, "04:00"), ("dog", 3, "06:00"),
  ("rat", 7, "01:00"), ("rat", 4, "03:00"), ("rat", 9, "05:00")
)).toDF("animal", "value", "time")

df.show
+------+-----+-----+
|animal|value| time|
+------+-----+-----+
|   cat|    8|01:00|
|   cat|    5|02:00|
|   cat|    6|03:00|
|   dog|    2|02:00|
|   dog|    4|04:00|
|   dog|    3|06:00|
|   rat|    7|01:00|
|   rat|    4|03:00|
|   rat|    9|05:00|
+------+-----+-----+

I've added a "time" field to illustrate orderBy; a window specification needs an explicit ordering for lag to be deterministic.

val w1 = Window.partitionBy($"animal").orderBy($"time")

val previous_value = lag($"value", 1).over(w1)
val df1 = df.withColumn("previous", previous_value)

df1.show
+------+-----+-----+--------+                                                   
|animal|value| time|previous|
+------+-----+-----+--------+
|   dog|    2|02:00|    null|
|   dog|    4|04:00|       2|
|   dog|    3|06:00|       4|
|   cat|    8|01:00|    null|
|   cat|    5|02:00|       8|
|   cat|    6|03:00|       5|
|   rat|    7|01:00|    null|
|   rat|    4|03:00|       7|
|   rat|    9|05:00|       4|
+------+-----+-----+--------+

If you want to replace nulls with 0:

val df2 = df1.na.fill(0)
df2.show
+------+-----+-----+--------+
|animal|value| time|previous|
+------+-----+-----+--------+
|   dog|    2|02:00|       0|
|   dog|    4|04:00|       2|
|   dog|    3|06:00|       4|
|   cat|    8|01:00|       0|
|   cat|    5|02:00|       8|
|   cat|    6|03:00|       5|
|   rat|    7|01:00|       0|
|   rat|    4|03:00|       7|
|   rat|    9|05:00|       4|
+------+-----+-----+--------+
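Alternatively, `lag` in `org.apache.spark.sql.functions` also accepts a third argument giving the default value, so the `na.fill` step can be folded into the window expression itself. A sketch reusing the `df` and `w1` defined above:

```scala
import org.apache.spark.sql.functions.lag

// 0 is substituted wherever there is no previous row in the partition,
// so no separate na.fill pass is needed.
val df3 = df.withColumn("previous", lag($"value", 1, 0).over(w1))
```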

Upvotes: 2
