I am trying to create a DataFrame in which a nested array of structs is an empty array whenever another column is false. I created a dummy DataFrame to illustrate my problem.
import org.apache.spark.sql.functions._
import spark.implicits._

val newDf = spark.createDataFrame(Seq(
  ("user1", "true", Some(8), Some("usd"), Some("tx1")),
  ("user1", "true", Some(9), Some("usd"), Some("tx2")),
  ("user2", "false", None, None, None))).toDF("userId", "flag", "amount", "currency", "transactionId")

val amountStruct = struct("amount", "currency").alias("amount")
val transactionStruct = struct("transactionId", "amount").alias("transactions")
val dataStruct = struct("flag", "transactions").alias("data")

val finalDf = newDf.
  withColumn("amount", amountStruct).
  withColumn("transactions", transactionStruct).
  select("userId", "flag", "transactions").
  groupBy("userId", "flag").
  agg(collect_list("transactions").alias("transactions")).
  withColumn("data", dataStruct).
  drop("transactions", "flag")
This is the output:
+------+--------------------+
|userId|                data|
+------+--------------------+
| user2|  [false, [[, [,]]]]|
| user1|[true, [[tx1, [8,...|
+------+--------------------+
And the schema:
root
|-- userId: string (nullable = true)
|-- data: struct (nullable = false)
| |-- flag: string (nullable = true)
| |-- transactions: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- transactionId: string (nullable = true)
| | | |-- amount: struct (nullable = false)
| | | | |-- amount: integer (nullable = true)
| | | | |-- currency: string (nullable = true)
The output I want:
+------+--------------------+
|userId|                data|
+------+--------------------+
| user2|         [false, []]|
| user1|[true, [[tx1, [8,...|
+------+--------------------+
I've tried the following before calling collect_list, but with no luck:
import org.apache.spark.sql.functions.{typedLit, when}
import spark.implicits._
// testDf is not defined in the question; presumably it is the DataFrame just before the groupBy
val emptyArray = typedLit(Array.empty[(String, Array[(Int, String)])])
testDf.withColumn("transactions", when($"flag" === "false", emptyArray).otherwise($"transactions")).show()
You were moments from victory. The collect_list approach is the way to go; it just needs a little nudge.
import org.apache.spark.sql.functions._
import spark.implicits._

val newDf = spark
  .createDataFrame(
    Seq(
      ("user1", "true", Some(8), Some("usd"), Some("tx1")),
      ("user1", "true", Some(9), Some("usd"), Some("tx2")),
      ("user2", "false", None, None, None)
    )
  )
  .toDF("userId", "flag", "amount", "currency", "transactionId")

val dataStruct = struct("flag", "transactions")

val finalDf2 = newDf
  .groupBy("userId", "flag")
  .agg(
    collect_list(
      // no otherwise(...): rows with any NULL field yield NULL, which collect_list skips
      when(
        $"transactionId".isNotNull && $"amount".isNotNull && $"currency".isNotNull,
        struct(
          $"transactionId",
          struct($"amount", $"currency").alias("amount")
        )
      )
    ).alias("transactions")
  )
  .withColumn("data", dataStruct)
  .drop("transactions", "flag")
First of all, Spark follows SQL conventions here. Most SQL aggregate functions (and collect_list is an aggregate function) ignore NULL inputs as if they were never there; COUNT(*) is the notable exception, since it counts rows rather than values.
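To illustrate the convention with a few other built-in aggregates, here is a minimal sketch. It assumes a local SparkSession (master local[1], arbitrary app name) and is written as a script, for example to paste into spark-shell, where getOrCreate() simply returns the existing session. count(lit(1)) plays the role of SQL COUNT(*).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Assumption: a local session; in spark-shell getOrCreate() reuses the running one.
val spark = SparkSession.builder().master("local[1]").appName("agg-nulls").getOrCreate()
import spark.implicits._

// One group whose middle value is NULL.
val df = Seq(("a", Some(1)), ("a", Option.empty[Int]), ("a", Some(3))).toDF("col1", "col2")

val row = df
  .groupBy($"col1")
  .agg(
    count($"col2").as("count_col2"), // counts only non-NULL values -> 2
    count(lit(1)).as("count_rows"),  // counts rows, like COUNT(*) -> 3
    sum($"col2").as("sum_col2"),     // NULL is skipped -> 4
    max($"col2").as("max_col2")      // NULL is skipped -> 3
  )
  .head()
```

Only the row-counting aggregate sees the NULL row; the value-based ones behave as if it were absent.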
Let's look at how collect_list behaves:
import org.apache.spark.sql.functions.collect_list
import spark.implicits._

Seq(
  ("a", Some(1)),
  ("a", Option.empty[Int]),
  ("a", Some(3)),
  ("b", Some(10)),
  ("b", Some(20)),
  ("b", Option.empty[Int])
)
  .toDF("col1", "col2")
  .groupBy($"col1")
  .agg(collect_list($"col2") as "col2_list")
  .show()
And the result is:
+----+---------+
|col1|col2_list|
+----+---------+
|   b| [10, 20]|
|   a|   [1, 3]|
+----+---------+
So collect_list behaves as expected. The reason you see those blanks in your output is that the column passed to collect_list is never NULL: struct(...) always produces a struct, even when every one of its fields is NULL, so collect_list has no NULL to skip and keeps it.
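A minimal sketch of that behavior, assuming a local SparkSession and a one-row DataFrame whose amount and currency are both NULL (the data is illustrative). The plain struct is not NULL and survives collect_list; the same struct wrapped in when(...) without otherwise(...) is NULL and is dropped.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Assumption: a local session; in spark-shell getOrCreate() reuses the running one.
val spark = SparkSession.builder().master("local[1]").appName("struct-nulls").getOrCreate()
import spark.implicits._

val df = Seq(("user2", Option.empty[Int], Option.empty[String])).toDF("userId", "amount", "currency")

// struct(...) yields a non-NULL struct even though both fields are NULL
val plain = df.select(struct($"amount", $"currency").as("s"))

// when(...) without otherwise(...) yields NULL when the condition is not true
val guarded = df.select(
  when($"amount".isNotNull && $"currency".isNotNull, struct($"amount", $"currency")).as("s")
)

val plainIsNull = plain.select($"s".isNull).as[Boolean].head()     // false
val guardedIsNull = guarded.select($"s".isNull).as[Boolean].head() // true

// collect_list keeps the all-NULL-fields struct but skips the NULL value
val plainSize = plain.agg(size(collect_list($"s"))).as[Int].head()     // 1
val guardedSize = guarded.agg(size(collect_list($"s"))).as[Int].head() // 0
```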
To confirm this, let's examine the schema of the DataFrame just before it is aggregated:
newDf
.withColumn("amount", amountStruct)
.withColumn("transactions", transactionStruct)
.printSchema()
root
|-- userId: string (nullable = true)
|-- flag: string (nullable = true)
|-- amount: struct (nullable = false)
| |-- amount: integer (nullable = true)
| |-- currency: string (nullable = true)
|-- currency: string (nullable = true)
|-- transactionId: string (nullable = true)
|-- transactions: struct (nullable = false)
| |-- transactionId: string (nullable = true)
| |-- amount: struct (nullable = false)
| | |-- amount: integer (nullable = true)
| | |-- currency: string (nullable = true)
Note the transactions: struct (nullable = false) line; it confirms the suspicion.
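Instead of reading printSchema output, the same flag can be checked programmatically through StructField.nullable. A sketch under the same assumptions (local SparkSession, an illustrative two-column DataFrame):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Assumption: a local session; in spark-shell getOrCreate() reuses the running one.
val spark = SparkSession.builder().master("local[1]").appName("nullability").getOrCreate()
import spark.implicits._

val df = Seq(("tx1", Some(8)), ("tx2", Option.empty[Int])).toDF("transactionId", "amount")

// Plain struct: Spark marks the column as non-nullable
val plainField = df.select(struct($"transactionId", $"amount").as("t")).schema("t")

// when(...) with no otherwise(...): the column becomes nullable
val guardedField = df
  .select(
    when($"transactionId".isNotNull && $"amount".isNotNull, struct($"transactionId", $"amount")).as("t")
  )
  .schema("t")

println(s"plain nullable = ${plainField.nullable}")     // false
println(s"guarded nullable = ${guardedField.nullable}") // true
```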
If we translate all the nested nullability into Scala case classes (Option for nullable), here's what you have:
case class Row(
  transactions: Transactions,
  // other fields
)
case class Transactions(
  transactionId: Option[String],
  // not optional: the schema shows amount: struct (nullable = false)
  amount: Amount
)
case class Amount(
  amount: Option[Int],
  currency: Option[String]
)
And here's what you want instead:
case class Row(
transactions: Option[Transactions], // this is optional now
// other fields
)
case class Transactions(
transactionId: String, // while this is not optional
amount: Amount, // neither is this
)
case class Amount(
amount: Int, // neither is this
currency: String // neither is this
)
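The same reasoning can be mirrored in plain Scala, without Spark, as an analogy: a value exists only when every field exists, and collecting a group drops the missing ones. FlatRow and toTransactions are hypothetical names that stand in for the flat input columns and the when(...) guard; this is a model of the semantics, not Spark code.

```scala
// Hypothetical flat row mirroring the columns of newDf.
final case class FlatRow(
  userId: String,
  flag: String,
  amount: Option[Int],
  currency: Option[String],
  transactionId: Option[String]
)

final case class Amount(amount: Int, currency: String)
final case class Transactions(transactionId: String, amount: Amount)

// Some(...) only if all three fields are present: the analogue of
// when(all columns isNotNull, struct(...)) with no otherwise clause.
def toTransactions(r: FlatRow): Option[Transactions] =
  for {
    id  <- r.transactionId
    amt <- r.amount
    cur <- r.currency
  } yield Transactions(id, Amount(amt, cur))

val rows = Seq(
  FlatRow("user1", "true", Some(8), Some("usd"), Some("tx1")),
  FlatRow("user1", "true", Some(9), Some("usd"), Some("tx2")),
  FlatRow("user2", "false", None, None, None)
)

// flatMap over Options discards the Nones, like collect_list discards NULLs.
val grouped: Map[(String, String), Seq[Transactions]] =
  rows.groupBy(r => (r.userId, r.flag)).map { case (key, rs) => key -> rs.flatMap(toTransactions) }
```

With this model user2 ends up with an empty list, which is the shape the question asks for.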
Now the last step is simple. To make the column that feeds collect_list "properly" nullable, you have to check the nullability of all of the amount, currency and transactionId columns.
The result will be NOT NULL if and only if all the input columns are NOT NULL.
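That rule generalizes to any number of columns. The helper below, allNotNull, is hypothetical (it is not part of Spark's API); it simply ANDs isNotNull over the given column names and assumes at least one name is passed. The local SparkSession and sample data are also assumptions for the sketch.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions._

// Hypothetical helper: true only when every named column is non-NULL.
// reduce throws on an empty argument list, so pass at least one name.
def allNotNull(names: String*): Column =
  names.map(n => col(n).isNotNull).reduce(_ && _)

// Assumption: a local session; in spark-shell getOrCreate() reuses the running one.
val spark = SparkSession.builder().master("local[1]").appName("all-not-null").getOrCreate()
import spark.implicits._

val df = Seq(
  ("tx1", Some(8), Some("usd")),
  ("tx2", Some(9), Option.empty[String]),
  ("tx3", Option.empty[Int], Option.empty[String])
).toDF("transactionId", "amount", "currency")

val flags = df
  .select(allNotNull("transactionId", "amount", "currency").as("complete"))
  .as[Boolean]
  .collect()
  .toSeq
```

The resulting condition can be passed as the first argument of when(...) in place of the hand-written chain of isNotNull checks.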
You can use the same when API method as in your attempt to build the result. If the otherwise clause is omitted, when implicitly returns NULL for rows that don't match the condition, which is exactly what you need.
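A small sketch of that default, assuming a local SparkSession and an illustrative flag column: omitting otherwise produces the same NULLs as an explicit otherwise(lit(null)).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Assumption: a local session; in spark-shell getOrCreate() reuses the running one.
val spark = SparkSession.builder().master("local[1]").appName("when-default").getOrCreate()
import spark.implicits._

val df = Seq("true", "false").toDF("flag")

val result = df
  .select(
    when($"flag" === "true", lit("kept")).as("no_otherwise"),
    when($"flag" === "true", lit("kept")).otherwise(lit(null)).as("explicit_null")
  )
  .collect()
  // Option(...) turns the NULLs returned by getString into None
  .map(r => (Option(r.getString(0)), Option(r.getString(1))))
  .toSeq
```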