SH Y.


List in the Case-When Statement in Spark SQL

I'm trying to convert a DataFrame from long to wide format, as suggested in "How to pivot DataFrame?". However, Spark SQL seems to interpret the values in my countries list as column names rather than as values. Below are the error message from the console, followed by the sample data and code from that question. Does anyone know how to resolve this?

Message from the Scala console:
scala> val myDF1 = sqlc2.sql(query)
org.apache.spark.sql.AnalysisException: cannot resolve 'US' given input columns id, tag, value;

id  tag  value
1   US    50
1   UK    100
1   Can   125
2   US    75
2   UK    150
2   Can   175
and I want to turn it into:

id  US  UK   Can
1   50  100  125
2   75  150  175
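Conceptually, this long-to-wide pivot groups the rows by id and, for each country, sums the value of the rows whose tag matches, defaulting to 0 when there is none. A minimal plain-Scala sketch of those semantics on the sample rows (no Spark involved; the Record case class and the pivot helper are hypothetical names used only for illustration):

```scala
// Hypothetical record type mirroring the sample table's columns
case class Record(id: Int, tag: String, value: Int)

object PivotSketch {
  val rows = List(
    Record(1, "US", 50), Record(1, "UK", 100), Record(1, "Can", 125),
    Record(2, "US", 75), Record(2, "UK", 150), Record(2, "Can", 175)
  )

  // For each id, sum the values per requested tag; a missing tag yields 0,
  // mirroring SUM(CASE WHEN tag = '...' THEN value ELSE 0 END)
  def pivot(rows: List[Record], tags: List[String]): Map[Int, List[Int]] =
    rows.groupBy(_.id).map { case (id, group) =>
      id -> tags.map(t => group.filter(_.tag == t).map(_.value).sum)
    }

  def main(args: Array[String]): Unit = {
    val wide = pivot(rows, List("US", "UK", "Can"))
    // Print one line per id, sorted, in the same layout as the table above
    wide.toList.sortBy(_._1).foreach { case (id, vs) =>
      println((id :: vs).mkString("  "))
    }
  }
}
```

Each output row is one id, and the column order follows the order of the tags list, just as the generated SQL columns follow the order of countries.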
I can create a list of the values I want to pivot on and then build a string containing the SQL query I need:

val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1

var query = "select *, "
for (i <- 0 to numCountries-1) {
  query += "case when tag = " + countries(i) + " then value else 0 end as " + countries(i) + ", "
}
query += "case when tag = " + countries.last + " then value else 0 end as " + countries.last + " from myTable"

myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)
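Printing the string that the loop builds makes the problem visible. A minimal sketch that reproduces only the string-building step (no Spark needed; QueryDebug and buildQuery are hypothetical names) shows that the country codes end up unquoted, so tag = US compares tag against a column named US:

```scala
object QueryDebug {
  val countries = List("US", "UK", "Can")

  // Same string as the loop above, written with map/mkString:
  // note that the country code is inserted without surrounding quotes
  def buildQuery(cs: List[String]): String =
    "select *, " +
      cs.map(c => "case when tag = " + c + " then value else 0 end as " + c)
        .mkString(", ") +
      " from myTable"

  def main(args: Array[String]): Unit =
    println(buildQuery(countries))
}
```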


Answers (1)

zero323

Country codes are string literals and must be enclosed in single quotes; otherwise the SQL parser treats them as column names:

val caseClause = countries.map(
    x => s"""CASE WHEN tag = '$x' THEN value ELSE 0 END as $x"""
).mkString(", ")

val aggClause = countries.map(x => s"""SUM($x) AS $x""").mkString(", ")

val query = s"""
   SELECT id, $aggClause
   FROM (SELECT id, $caseClause FROM myTable) tmp
   GROUP BY id"""

sqlContext.sql(query)
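As a variation (not the approach above), the subquery can be folded away by putting each CASE expression directly inside SUM. A minimal sketch of a reusable builder, assuming the tags contain no single quotes or backticks since no escaping is attempted; PivotQuery and pivotQuery are hypothetical names, and the backtick-quoted aliases assume Spark SQL's backtick identifier syntax:

```scala
object PivotQuery {
  // Hypothetical helper: single-level pivot query, SUM(CASE ...) per tag.
  // Rejects tags containing quotes or backticks because nothing is escaped.
  def pivotQuery(table: String, idCol: String, tagCol: String,
                 valueCol: String, tags: Seq[String]): String = {
    require(tags.nonEmpty, "need at least one tag")
    require(tags.forall(t => !t.contains("'") && !t.contains("`")),
      "tags must not contain quotes or backticks")
    val cols = tags.map { t =>
      // '...' marks a string literal; `...` marks a column alias
      s"SUM(CASE WHEN $tagCol = '$t' THEN $valueCol ELSE 0 END) AS `$t`"
    }.mkString(", ")
    s"SELECT $idCol, $cols FROM $table GROUP BY $idCol"
  }

  def main(args: Array[String]): Unit =
    println(pivotQuery("myTable", "id", "tag", "value", List("US", "UK", "Can")))
}
```

Passing the resulting string to sqlContext.sql(...) after registering myTable should produce the same wide result as the two-level query; the only difference is that the aggregation wraps the CASE expression instead of a subquery column.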

Why bother building SQL strings from scratch at all, though? The same pivot can be expressed directly with the DataFrame API:

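import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit, sum, when}
import sqlContext.implicits._ // enables the $"..." column syntax

// For a given country, emit value when tag matches and 0 otherwise
def genCase(x: String): Column = {
  when($"tag" <=> lit(x), $"value").otherwise(0).alias(x)
}

// Apply an aggregate function f (e.g. sum) to column x, keeping the name x
def genAgg(f: Column => Column)(x: String): Column = f(col(x)).alias(x)

val aggs = countries.map(genAgg(sum))

df
 .select($"id" :: countries.map(genCase): _*)
 .groupBy($"id")
 .agg(aggs.head, aggs.tail: _*)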
