Koushik Chandra

Reputation: 1491

Spark code for SQL CASE statement and ROW_NUMBER equivalent

I have a data set like the one below:

hduser@ubuntu:~$ hadoop fs -cat /user/hduser/test_sample/sample1.txt
Eid1,EName1,EDept1,100
Eid2,EName2,EDept1,102
Eid3,EName3,EDept1,101
Eid4,EName4,EDept2,110
Eid5,EName5,EDept2,121
Eid6,EName6,EDept3,99

I want to generate the output below using Spark code:

Eid1,EName1,IT,102,1
Eid2,EName2,IT,101,2
Eid3,EName3,IT,100,3
Eid4,EName4,ComSc,121,1
Eid5,EName5,ComSc,110,2
Eid6,EName6,Mech,99,1

which is equivalent to the following SQL:

SELECT emp_id, emp_name,
       CASE WHEN emp_dept = 'EDept1' THEN 'IT'
            WHEN emp_dept = 'EDept2' THEN 'ComSc'
            WHEN emp_dept = 'EDept3' THEN 'Mech'
       END AS dept_name,
       emp_sal,
       ROW_NUMBER() OVER (PARTITION BY emp_dept ORDER BY emp_sal DESC) AS rn
FROM emp
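For reference, the CASE mapping and the per-department numbering can be sketched on plain Scala collections (a hypothetical sketch over the sample rows above, not Spark itself):

```scala
// Hypothetical plain-Scala sketch of the SQL above, using the sample rows
val rows = Seq(
  ("Eid1", "EName1", "EDept1", 100),
  ("Eid2", "EName2", "EDept1", 102),
  ("Eid3", "EName3", "EDept1", 101),
  ("Eid4", "EName4", "EDept2", 110),
  ("Eid5", "EName5", "EDept2", 121),
  ("Eid6", "EName6", "EDept3", 99))

// CASE WHEN ... END: a plain dept-code -> dept-name lookup
val deptName = Map("EDept1" -> "IT", "EDept2" -> "ComSc", "EDept3" -> "Mech")

// row_number() over (partition by emp_dept order by emp_sal desc):
// group by department, sort each group by salary descending, number from 1
val ranked = rows
  .groupBy(_._3)
  .toSeq
  .sortBy(_._1)
  .flatMap { case (_, grp) =>
    grp.sortBy(-_._4).zipWithIndex.map { case ((id, name, dept, sal), i) =>
      (id, name, deptName(dept), sal, i + 1)
    }
  }
```

Note that with these row_number semantics the highest salary in each department gets row number 1.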

Can someone suggest how I can achieve this in Spark?

Upvotes: 0

Views: 974

Answers (1)

David Griffin

Reputation: 13927

You can use RDD.zipWithIndex, convert the result to a DataFrame, and then use min() and a join to get the results you want.

Like this:

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._  // for min()
import sqlContext.implicits._            // for toDF and the $"..." column syntax

// SORT BY added as per comment request
val test = sc.textFile("/user/hadoop/test.txt")
  .sortBy(_.split(",")(2)).sortBy(_.split(",")(3).toInt)

// Table to hold the dept name lookups
val deptDF = 
  sc.parallelize(Array(("EDept1", "IT"),("EDept2", "ComSc"),("EDept3", "Mech")))
  .toDF("deptCode", "dept")

val schema = StructType(Array(
  StructField("col1", StringType, false),
  StructField("col2", StringType, false),
  StructField("col3", StringType, false),
  StructField("col4", StringType, false),
  StructField("col5", LongType, false))
)

// join to deptDF added as per comment
val testDF = sqlContext.createDataFrame(
  test.zipWithIndex.map(tuple => Row.fromSeq(tuple._1.split(",") ++ Array(tuple._2))),
  schema
)
.join(deptDF, $"col3" === $"deptCode")
.select($"col1", $"col2", $"dept" as "col3", $"col4", $"col5")
.orderBy($"col5")

testDF.show

col1 col2   col3  col4 col5
Eid1 EName1 IT    100  0
Eid3 EName3 IT    101  1
Eid2 EName2 IT    102  2
Eid4 EName4 ComSc 110  3
Eid5 EName5 ComSc 121  4
Eid6 EName6 Mech  99   5

val result = testDF.join(
  testDF.groupBy($"col3").agg($"col3" as "g_col3", min($"col5") as "start"),
  $"col3" === $"g_col3"
)
.select($"col1", $"col2", $"col3", $"col4", $"col5" - $"start" + 1 as "index")

result.show

col1 col2   col3  col4 index
Eid4 EName4 ComSc 110  1
Eid5 EName5 ComSc 121  2
Eid6 EName6 Mech  99   1
Eid1 EName1 IT    100  1
Eid3 EName3 IT    101  2
Eid2 EName2 IT    102  3
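The subtraction works because zipWithIndex hands out one contiguous global index in the sorted order, so within each department the indices are consecutive; subtracting the group's minimum index (what the groupBy/min and join compute) and adding 1 yields a 1-based per-group row number. A hypothetical plain-Scala sketch of that arithmetic, using the intermediate testDF values shown above:

```scala
// Assumed data mirroring the intermediate testDF: (id, dept, sal, global index)
val indexed = Seq(
  ("Eid1", "IT",    100, 0L),
  ("Eid3", "IT",    101, 1L),
  ("Eid2", "IT",    102, 2L),
  ("Eid4", "ComSc", 110, 3L),
  ("Eid5", "ComSc", 121, 4L),
  ("Eid6", "Mech",  99,  5L))

// Per-department minimum index -- what groupBy + min computes before the join
val start = indexed.groupBy(_._2).map { case (dept, grp) => dept -> grp.map(_._4).min }

// col5 - start + 1: turn the global index into a 1-based per-group row number
val numbered = indexed.map { case (id, dept, sal, idx) =>
  (id, dept, sal, idx - start(dept) + 1)
}
```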

Upvotes: 1
