collarblind

Reputation: 4739

Change to empty array if another column is false

I am trying to build a DataFrame in which a nested array-of-structs column is an empty array whenever another column is false. I created a dummy DataFrame to illustrate the problem.

import org.apache.spark.sql.functions._  // struct, collect_list, when
import spark.implicits._

val newDf = spark.createDataFrame(Seq(
  ("user1","true", Some(8), Some("usd"), Some("tx1")),
  ("user1", "true", Some(9), Some("usd"), Some("tx2")),
  ("user2", "false", None, None, None))).toDF("userId","flag", "amount", "currency", "transactionId")


val amountStruct = struct("amount"
                          ,"currency").alias("amount")

val transactionStruct = struct("transactionId"
                               , "amount").alias("transactions")

val dataStruct = struct("flag","transactions").alias("data")


val finalDf = newDf.
withColumn("amount", amountStruct).
withColumn("transactions", transactionStruct).
select("userId", "flag","transactions").
groupBy("userId", "flag").
agg(collect_list("transactions").alias("transactions")).
withColumn("data", dataStruct).
drop("transactions","flag")

This is the output:

+------+--------------------+
|userId|                data|
+------+--------------------+
| user2|  [false, [[, [,]]]]|
| user1|[true, [[tx1, [8,...|
+------+--------------------+

and schema:

root
 |-- userId: string (nullable = true)
 |-- data: struct (nullable = false)
 |    |-- flag: string (nullable = true)
 |    |-- transactions: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- transactionId: string (nullable = true)
 |    |    |    |-- amount: struct (nullable = false)
 |    |    |    |    |-- amount: integer (nullable = true)
 |    |    |    |    |-- currency: string (nullable = true)

The output I want:

+------+--------------------+
|userId|                data|
+------+--------------------+
| user2|         [false, []]|
| user1|[true, [[tx1, [8,...|
+------+--------------------+

I tried the following before calling collect_list, but with no luck:

import org.apache.spark.sql.functions.typedLit

val emptyArray = typedLit(Array.empty[(String, Array[(Int, String)])])

testDf.withColumn("transactions", when($"flag" === "false", emptyArray).otherwise($"transactions")).show()

Upvotes: 0

Views: 73

Answers (1)

Ihor Kaharlichenko

Reputation: 6260

You were moments from victory. The approach with collect_list is the way to go; it just needs a little nudge.

TL;DR Solution

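import org.apache.spark.sql.functions._  // struct, collect_list, when
import spark.implicits._                 // the $"..." column syntax

val newDf = spark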
  .createDataFrame(
    Seq(
      ("user1",  "true", Some(8), Some("usd"), Some("tx1")),
      ("user1",  "true", Some(9), Some("usd"), Some("tx2")),
      ("user2", "false", None,    None,        None)
    )
  )
  .toDF("userId", "flag", "amount", "currency", "transactionId")

val dataStruct = struct("flag", "transactions")

val finalDf2 = newDf
  .groupBy("userId", "flag")
  .agg(
    collect_list(
      when(
        $"transactionId".isNotNull && $"amount".isNotNull && $"currency".isNotNull,
        struct(
          $"transactionId",
          struct($"amount", $"currency").alias("amount")
        )
      )).alias("transactions"))
  .withColumn("data", dataStruct)
  .drop("transactions", "flag")

Explanation

SQL Aggregate Function Behavior

First of all, Spark follows SQL conventions here. SQL aggregate functions (and collect_list is one of them) ignore NULL inputs as if they were never there.

Let's take a look at how collect_list behaves:

Seq(
  ("a", Some(1)),
  ("a", Option.empty[Int]),
  ("a", Some(3)),
  ("b", Some(10)),
  ("b", Some(20)),
  ("b", Option.empty[Int])
)
  .toDF("col1", "col2")
  .groupBy($"col1")
  .agg(collect_list($"col2") as "col2_list")
  .show()

And the result is:

+----+---------+
|col1|col2_list|
+----+---------+
|   b| [10, 20]|
|   a|   [1, 3]|
+----+---------+

Tracking Down Nullability

So collect_list behaves properly. The reason you are seeing those blanks in your output is that the column passed to collect_list is not nullable: struct always produces a non-NULL struct, even when every field inside it is NULL.

To prove it let's examine the schema of the DataFrame just before it gets aggregated:

newDf
  .withColumn("amount", amountStruct)
  .withColumn("transactions", transactionStruct)
  .printSchema()

root
 |-- userId: string (nullable = true)
 |-- flag: string (nullable = true)
 |-- amount: struct (nullable = false)
 |    |-- amount: integer (nullable = true)
 |    |-- currency: string (nullable = true)
 |-- currency: string (nullable = true)
 |-- transactionId: string (nullable = true)
 |-- transactions: struct (nullable = false)
 |    |-- transactionId: string (nullable = true)
 |    |-- amount: struct (nullable = false)
 |    |    |-- amount: integer (nullable = true)
 |    |    |-- currency: string (nullable = true)

Note the transactions: struct (nullable = false) part. It proves the suspicion.
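
The same thing can be checked at the value level. The snippet below is only a sketch: the one-row DataFrame and the local SparkSession settings are made up for illustration, and it compares isNull on a NULL field with isNull on a struct wrapping that field.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for the demo; in spark-shell, getOrCreate() reuses the existing one.
val spark = SparkSession.builder().master("local[1]").appName("struct-null-demo").getOrCreate()
import spark.implicits._

// Made-up row whose transactionId is NULL, like user2 in the question.
val df = Seq(("user2", Option.empty[String])).toDF("userId", "transactionId")

val checked = df.select(
  $"transactionId".isNull.as("field_is_null"),          // the field itself is NULL
  struct($"transactionId").isNull.as("struct_is_null")  // the struct around it is not
)

checked.show()
```

The first column should come out true and the second false, which is why collect_list has nothing to skip for user2.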

If we translate all the nested nullability to Scala, here's what you have:

case class Row(
    transactions: Transactions,
    // other fields
)

case class Transactions(
    transactionId: Option[String],
    amount: Amount,                // the struct itself is never NULL (nullable = false)
)

case class Amount(
    amount: Option[Int],
    currency: Option[String]
)

And here's what you want instead:

case class Row(
    transactions: Option[Transactions], // this is optional now
    // other fields
)

case class Transactions(
    transactionId: String,              // while this is not optional
    amount: Amount,                     // neither is this
)

case class Amount(
    amount: Int,                        // neither is this
    currency: String                    // neither is this
)
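
To see why the second shape gives an empty array, here is a small plain-Scala model with no Spark involved. It is only an analogy: Option stands in for a nullable column, and collectNonNull is a hypothetical stand-in for collect_list, not a Spark API.

```scala
// Hypothetical model of the desired nullability, using Option for a nullable column.
object NullabilityModel {
  final case class Amount(amount: Int, currency: String)
  final case class Transaction(transactionId: String, amount: Amount)

  // Stand-in for collect_list: keep only the rows that are not NULL (None).
  def collectNonNull(rows: Seq[Option[Transaction]]): Seq[Transaction] =
    rows.flatten
}

import NullabilityModel._

// user1 has two real transactions; user2 has a single NULL "transactions" value.
val user1 = Seq(
  Some(Transaction("tx1", Amount(8, "usd"))),
  Some(Transaction("tx2", Amount(9, "usd")))
)
val user2 = Seq(Option.empty[Transaction])

println(collectNonNull(user1).map(_.transactionId)) // List(tx1, tx2)
println(collectNonNull(user2))                      // List()
```

Once the whole struct is optional, the "no transaction" row simply disappears from the collected list.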

Fixing the Nullability

Now the last step is simple. To make the column that is the input to collect_list "properly" nullable you have to check the nullability of all the amount, currency and transactionId columns.

The result will be NOT NULL if and only if all the input columns are NOT NULL.

You can use the same when API method to construct the result. The otherwise clause if omitted implicitly returns NULL which is exactly what you need.
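
A minimal sketch of that last point, reduced to one column. The data, the key/value column names and the local SparkSession settings are assumptions made up for this example; the structure mirrors the solution at the top.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for the demo; in spark-shell, getOrCreate() reuses the existing one.
val spark = SparkSession.builder().master("local[1]").appName("when-null-demo").getOrCreate()
import spark.implicits._

// Made-up data: key "b" has only a NULL value, like user2 in the question.
val df = Seq(("a", Some(1)), ("b", Option.empty[Int])).toDF("key", "value")

val result = df
  .groupBy($"key")
  .agg(
    // No otherwise(): rows failing the check become NULL, and collect_list skips them.
    collect_list(when($"value".isNotNull, struct($"value"))).as("values")
  )

// Show the size of each collected array.
result.select($"key", size($"values").as("n")).orderBy($"key").show()
```

Key a should end up with one element and key b with an empty array rather than an array holding a struct of NULLs.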

Upvotes: 1
