Reputation: 65
I'm trying to count the occurrences of each element in the FavouriteCities
column of the following DataFrame.
+-----------------+
| FavouriteCities |
+-----------------+
| [NY, Canada] |
+-----------------+
The schema is as follows:
scala> data.printSchema
root
|-- FavouriteCities: array (nullable = true)
| |-- element: string (containsNull = true)
The expected output should be something like:
+------------+-------------+
| City | Count |
+------------+-------------+
| NY | 1 |
| Canada | 1 |
+------------+-------------+
I have tried using agg() and count(), as in the following, but it fails to extract individual elements from the array and instead tries to find the most common set of elements in the column.
data.agg(count("FavouriteCities").alias("count"))
Can someone please guide me with this?
Upvotes: 0
Views: 4084
Reputation: 36
To match the schema you've shown:
scala> val data = Seq(Tuple1(Array("NY", "Canada"))).toDF("FavouriteCities")
data: org.apache.spark.sql.DataFrame = [FavouriteCities: array<string>]
scala> data.printSchema
root
|-- FavouriteCities: array (nullable = true)
| |-- element: string (containsNull = true)
Explode the array so each city becomes its own row, then group and count (the alias must apply to the result of explode, not to its argument, otherwise the output column is named col and groupBy("City") fails):
import org.apache.spark.sql.functions.explode
import spark.implicits._

val counts = data
  .select(explode($"FavouriteCities") as "City") // alias the exploded column, not the array
  .groupBy("City")
  .count
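For the sample row above, counts.show should produce the per-city counts the question asks for (row order may vary):

scala> counts.show
+------+-----+
|  City|count|
+------+-----+
|    NY|    1|
|Canada|    1|
+------+-----+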
Then aggregate to find the most frequent city:
scala> counts.as[(String, Long)].reduce((a, b) => if (a._2 > b._2) a else b)
res3: (String, Long) = (Canada,1)
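Alternatively, a sketch that stays in the DataFrame API: sort by count descending and take the first row (both cities have count 1 in the sample, so the tie is broken arbitrarily):

import org.apache.spark.sql.functions.desc

// keep only the row with the highest count
val mostFrequent = counts.orderBy(desc("count")).limit(1)
mostFrequent.show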
Upvotes: 2