Reputation: 505
I have a Spark dataframe with an array-type column:
scala> mydf.printSchema
root
 |-- arraycol: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- id: integer (nullable = false)
I now need to aggregate this dataframe by "id" and count the rows where a particular value exists in the array. I was trying to do this:
val aggdata = mydf.groupBy("id").
  agg(
    count(when($"arraycol" contains "someval", $"arraycol")).as("aggval"))
That doesn't seem to work. Any input on how I can do this?
Upvotes: 1
Views: 543
Reputation: 214927
There's the array_contains function to test the condition (when without an otherwise evaluates to null for rows that don't match, and count ignores nulls, so this counts the matching rows):
val df = Seq((1, Seq("a", "b")), (1, Seq("b")), (2, Seq("b"))).toDF("id", "arrayCol")
// df: org.apache.spark.sql.DataFrame = [id: int, arrayCol: array<string>]
df.show
+---+--------+
| id|arrayCol|
+---+--------+
| 1| [a, b]|
| 1| [b]|
| 2| [b]|
+---+--------+
df.groupBy("id").agg(
  count(when(array_contains($"arrayCol", "a"), $"arrayCol")).as("hasA")
).show
+---+----+
| id|hasA|
+---+----+
| 1| 1|
| 2| 0|
+---+----+
Or use sum:
df.groupBy("id").agg(
  sum(when(array_contains($"arrayCol", "a"), 1).otherwise(0)).as("hasA")
).show
+---+----+
| id|hasA|
+---+----+
| 1| 1|
| 2| 0|
+---+----+
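For completeness, the same aggregation can be expressed in Spark SQL; this is a minimal sketch, assuming the dataframe is registered under the (hypothetical) view name "arrays":
// Register the dataframe as a temporary view; the view name is arbitrary
df.createOrReplaceTempView("arrays")
// CASE WHEN mirrors the sum(when(...).otherwise(...)) version above
spark.sql("""
  SELECT id,
         SUM(CASE WHEN array_contains(arrayCol, 'a') THEN 1 ELSE 0 END) AS hasA
  FROM arrays
  GROUP BY id
""").show
This produces the same result as the DataFrame API versions above.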
Upvotes: 2