LucieCBurgess

Reputation: 799

How to pivot on arbitrary column?

I use Apache Spark 2.2.0 and Scala.

I'm following the question as a guide to pivot a dataframe without using the pivot function.

I need to pivot the dataframe without using the pivot function as I have non-numerical data and pivot works with an aggregation function like sum, min, max on numerical data only. I've got a non-numerical column I'd like to use in pivot aggregation.

Here's my data:

+---+-------------+----------+-------------+----------+-------+
|Qid|     Question|AnswerText|ParticipantID|Assessment| GeoTag|
+---+-------------+----------+-------------+----------+-------+
|  1|Question1Text|       Yes|       abcde1|         0|(x1,y1)|
|  2|Question2Text|        No|       abcde1|         0|(x1,y1)|
|  3|Question3Text|         3|       abcde1|         0|(x1,y1)|
|  1|Question1Text|        No|       abcde2|         0|(x2,y2)|
|  2|Question2Text|       Yes|       abcde2|         0|(x2,y2)|
+---+-------------+----------+-------------+----------+-------+

I want to group by ParticipantID, Assessment and GeoTag, "pivot" on the Question column, and take the values from the AnswerText column. The output should look as follows:

+-------------+-----------+----------+-------+-----+------+
|ParticipantID|Assessment |GeoTag    |Qid_1  |Qid_2|Qid_3 |
+-------------+-----------+----------+-------+-----+------+
|abcde1       |0          |(x1,y1)   |Yes    |No   |3     |
|abcde2       |0          |(x2,y2)   |No     |Yes  |null  |
+-------------+-----------+----------+-------+-----+------+

I have tried this:

val questions: Array[String] = df.select("Q_id")
      .distinct()
      .collect()
      .map(_.getAs[String]("Q_id"))
      .sortWith(_<_)

val df2: DataFrame = questions.foldLeft(df) {
      case (data, question) => data.selectExpr("*", s"IF(Q_id = '$question', AnswerText, 0) AS $question")
    }

[followed by a GroupBy expression]

But I'm getting the following error, which must have something to do with the syntax of the final clause, AS $question:

17/12/08 16:13:12 INFO SparkSqlParser: Parsing command: *
17/12/08 16:13:12 INFO SparkSqlParser: Parsing command: IF(Q_id_string_new_2 = '101_Who_is_with_you_right_now?', AnswerText, 0) AS 101_Who_is_with_you_right_now?


extraneous input '?' expecting <EOF>(line 1, pos 104)

== SQL ==
IF(Q_id_string_new_2 = '101_Who_is_with_you_right_now?', AnswerText, 0) AS 101_Who_is_with_you_right_now?
--------------------------------------------------------------------------------------------------------^^^

org.apache.spark.sql.catalyst.parser.ParseException: 
extraneous input '?' expecting <EOF>(line 1, pos 104)

== SQL ==
IF(Q_id_string_new_2 = '101_Who_is_with_you_right_now?', AnswerText, 0) AS 101_Who_is_with_you_right_now?
--------------------------------------------------------------------------------------------------------^^^

    at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:217)

Any ideas where I am going wrong? Is there a better way? I thought about reverting to Pandas and Python outside Spark if necessary, but I'd rather write all the code within the same framework if possible.

Upvotes: 2

Views: 905

Answers (2)

Alper t. Turker

Reputation: 35249

Because $question is interpolated into the SQL string, the alias ends up containing '?'. That character is not valid in an unquoted identifier, so you have to at least quote the alias with backticks:

s"IF(Q_id = '$question', AnswerText, 0) AS `$question`"

or use select / withColumn:

import org.apache.spark.sql.functions.when

case (data, question) => 
  data.withColumn(question, when($"Q_id" === question, $"AnswerText"))

or sanitize the strings first, using regexp_replace.
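
The last option can be sketched on the driver side, applied to the collected header strings before they are used as aliases. This is only a minimal sketch: it uses plain Scala String.replaceAll rather than Spark's regexp_replace, the object name HeaderNames and the helper sanitize are hypothetical, and the choice to keep only letters, digits and underscores is an assumption.

```scala
object HeaderNames {
  // Hypothetical helper: replace every character outside [A-Za-z0-9_]
  // with an underscore so the result can be used as an unquoted alias.
  def sanitize(name: String): String =
    name.replaceAll("[^A-Za-z0-9_]", "_")

  def main(args: Array[String]): Unit = {
    // The header from the error message in the question
    val raw = "101_Who_is_with_you_right_now?"
    println(sanitize(raw)) // prints 101_Who_is_with_you_right_now_
  }
}
```

Two different questions can collapse to the same sanitized name (for example if they differ only in punctuation), so it may be worth checking the sanitized headers for duplicates before using them.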

Regarding "I need to pivot the dataframe without using the pivot function as I have non-numerical data and df.pivot only works with an aggregation function like sum, min, max on numerical data":

You can use first as the aggregation; it also works on non-numeric columns. See: How to use pivot and calculate average on a non-numeric column (facing AnalysisException "is not a numeric column")?

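import org.apache.spark.sql.functions.first
// The question targets Spark 2.2, where pivot takes the column name as a String
data.groupBy($"ParticipantID", $"Assessment", $"GeoTag")
  .pivot("Question", questions).agg(first($"AnswerText"))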
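
For reference, here is a minimal end-to-end sketch of this approach on the sample data from the question. It assumes spark-sql is on the classpath and uses a local SparkSession; the object name PivotWithFirst and the method pivotAnswers are made up for illustration. The pivot column is passed by name, as a String.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, first}

object PivotWithFirst {
  // Hypothetical helper: one output column per distinct Question value,
  // filled with the first AnswerText seen in each group.
  def pivotAnswers(df: DataFrame): DataFrame = {
    // Collect the distinct pivot values up front (this runs a Spark job)
    val questions: Seq[String] =
      df.select("Question").distinct().collect().map(_.getString(0)).sorted.toSeq

    df.groupBy(col("ParticipantID"), col("Assessment"), col("GeoTag"))
      .pivot("Question", questions)
      .agg(first(col("AnswerText")))
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration only (assumption)
    val spark = SparkSession.builder().master("local[*]").appName("pivot-first").getOrCreate()
    import spark.implicits._

    // Sample rows copied from the question
    val df = Seq(
      (1, "Question1Text", "Yes", "abcde1", 0, "(x1,y1)"),
      (2, "Question2Text", "No",  "abcde1", 0, "(x1,y1)"),
      (3, "Question3Text", "3",   "abcde1", 0, "(x1,y1)"),
      (1, "Question1Text", "No",  "abcde2", 0, "(x2,y2)"),
      (2, "Question2Text", "Yes", "abcde2", 0, "(x2,y2)")
    ).toDF("Qid", "Question", "AnswerText", "ParticipantID", "Assessment", "GeoTag")

    pivotAnswers(df).show()
    spark.stop()
  }
}
```

Groups that have no row for a given question get null in that column, which matches the expected output in the question.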

Upvotes: 3

Jacek Laskowski

Reputation: 74789

Just a note to the accepted answer by @user8371915 to make the query a bit faster.


There is a way to avoid the expensive scan to generate questions with the headers.

Simply generate the headers (in the same job and stage!) followed by pivot on the column.

import org.apache.spark.sql.functions.{concat, lit}
// It's a simple and cheap map-like transformation
val qid_header = input.withColumn("header", concat(lit("Qid_"), $"Qid"))

scala> qid_header.show
+---+-------------+----------+-------------+----------+-------+------+
|Qid|     Question|AnswerText|ParticipantID|Assessment| GeoTag|header|
+---+-------------+----------+-------------+----------+-------+------+
|  1|Question1Text|       Yes|       abcde1|         0|(x1,y1)| Qid_1|
|  2|Question2Text|        No|       abcde1|         0|(x1,y1)| Qid_2|
|  3|Question3Text|         3|       abcde1|         0|(x1,y1)| Qid_3|
|  1|Question1Text|        No|       abcde2|         0|(x2,y2)| Qid_1|
|  2|Question2Text|       Yes|       abcde2|         0|(x2,y2)| Qid_2|
+---+-------------+----------+-------------+----------+-------+------+

With the headers as part of the dataset, let's pivot.

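import org.apache.spark.sql.functions.first
val solution = qid_header
  .groupBy('ParticipantID, 'Assessment, 'GeoTag)
  .pivot("header") // column name as a String, which is what pivot takes in Spark 2.2
  .agg(first('AnswerText))

scala> solution.show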
+-------------+----------+-------+-----+-----+-----+
|ParticipantID|Assessment| GeoTag|Qid_1|Qid_2|Qid_3|
+-------------+----------+-------+-----+-----+-----+
|       abcde1|         0|(x1,y1)|  Yes|   No|    3|
|       abcde2|         0|(x2,y2)|   No|  Yes| null|
+-------------+----------+-------+-----+-----+-----+

Upvotes: 0
