In Scala with Spark 2.4, I would like to filter the values inside the arrays of a column.
From
+---+------------+
| id|      letter|
+---+------------+
|  1|[x, xxx, xx]|
|  2|[yy, y, yyy]|
+---+------------+
To
+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+
I thought of using explode + filter:
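import org.apache.spark.sql.functions.{col, explode, length}
import spark.implicits._ // for toDF; assumes a SparkSession named spark, as in spark-shell
val res = Seq(("1", Array("x", "xxx", "xx")), ("2", Array("yy", "y", "yyy"))).toDF("id", "letter")
res.withColumn("tmp", explode(col("letter"))).filter(length(col("tmp")) < 3).drop(col("letter")).show()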
And I'm getting
+---+---+
| id|tmp|
+---+---+
|  1|  x|
|  1| xx|
|  2| yy|
|  2|  y|
+---+---+
How do I zip/group the values back by id?
Or is there a better, more optimised solution?
In Spark 2.4+, higher-order functions (filter) are the way to go. Alternatively, aggregate the exploded values back with collect_list:
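import org.apache.spark.sql.functions.{col, collect_list, explode, length}
import spark.implicits._ // for the $"..." column syntax

res.withColumn("tmp", explode(col("letter")))
  .filter(length(col("tmp")) < 3)
  .drop(col("letter"))
  // aggregate back
  .groupBy($"id")
  .agg(collect_list($"tmp").as("letter"))
  .show()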
gives:
+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+
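Two side effects of the explode + groupBy route are worth knowing. explode emits no rows for an empty array, and an id whose elements are all filtered out never reaches the groupBy, so it disappears from the result instead of ending up with an empty array. Spark's documentation also notes that collect_list is non-deterministic in ordering, because the row order after a shuffle may vary. The plain-Scala sketch below models both routes on ordinary collections (no Spark involved; the names ExplodeVsFilter, viaExplodeGroupBy and viaDirectFilter are hypothetical) to make the first difference concrete:

```scala
object ExplodeVsFilter {
  // (id, letters) pairs standing in for DataFrame rows
  type Record = (String, Seq[String])

  // Models explode + filter + groupBy/collect_list on plain collections:
  // one record per element, keep the short ones, then regroup by id.
  // An id whose elements are all filtered out never reaches the groupBy, so it vanishes.
  def viaExplodeGroupBy(rows: Seq[Record], maxLen: Int): Map[String, Seq[String]] =
    rows
      .flatMap { case (id, letters) => letters.map(letter => (id, letter)) } // "explode"
      .filter { case (_, letter) => letter.length <= maxLen }              // "filter"
      .groupBy { case (id, _) => id }                                      // "groupBy"
      .map { case (id, pairs) => id -> pairs.map(_._2) }                   // "collect_list"

  // Models the per-row approaches (UDF or higher-order filter):
  // every id is kept, possibly with an empty array.
  def viaDirectFilter(rows: Seq[Record], maxLen: Int): Map[String, Seq[String]] =
    rows.map { case (id, letters) => id -> letters.filter(_.length <= maxLen) }.toMap
}
```

With an extra row ("3", Seq("zzz")), viaExplodeGroupBy has no entry for "3", while viaDirectFilter maps it to an empty Seq. In Spark, the UDF and higher-order filter approaches behave like the latter.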
Because the groupBy introduces a shuffle, a UDF is the better option here:
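import org.apache.spark.sql.functions.udf

// Keeps only the strings whose length is at most maxLength
def filter_arr(maxLength: Int) = udf((arr: Seq[String]) => arr.filter(str => str.size <= maxLength))

res
  .select($"id", filter_arr(maxLength = 2)($"letter").as("letter"))
  .show()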
gives:
+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+
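One caveat with the UDF above: for a non-primitive input such as Seq[String], Spark passes a null column value straight into the Scala function, so arr.filter throws a NullPointerException on a row whose letter is null. A minimal null-tolerant sketch follows, assuming null arrays should pass through unchanged and null elements should be dropped. The names NullSafeUdf, keepShort and filterArrNullSafe are hypothetical:

```scala
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

object NullSafeUdf {
  // Pure helper, so the logic can be checked without a SparkSession.
  // A null array stays null; null elements are dropped along with the long ones.
  def keepShort(arr: Seq[String], maxLength: Int): Seq[String] =
    if (arr == null) null
    else arr.filter(str => str != null && str.length <= maxLength)

  // Hypothetical null-tolerant counterpart of filter_arr
  def filterArrNullSafe(maxLength: Int): UserDefinedFunction =
    udf((arr: Seq[String]) => keepShort(arr, maxLength))
}
```

Usage mirrors the original: res.select($"id", NullSafeUdf.filterArrNullSafe(2)($"letter").as("letter")).show(). Keeping the logic in a plain function makes it easy to unit-test separately from the Spark wiring.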
In Spark 2.4 you can filter the array without explode(), using the SQL higher-order function filter:
import org.apache.spark.sql.functions.expr
res.withColumn("letter", expr("filter(letter, x -> length(x) < 3)")).show()
Output:
+---+-------+
| id| letter|
+---+-------+
|  1|[x, xx]|
|  2|[yy, y]|
+---+-------+
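The threshold in that expression is hard-coded. Here is a sketch for passing it as a parameter. viaExpr interpolates the value into the SQL string, which works on Spark 2.4. viaDsl uses functions.filter, the Scala Column API version of the same higher-order function, which avoids assembling SQL strings. functions.filter only exists from Spark 3.0 onward, so this object as written needs Spark 3.0+. On 2.4, drop viaDsl and the filter import. The names ShortLetters, viaExpr and viaDsl are hypothetical:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, expr, filter, length}

object ShortLetters {
  // Spark 2.4+: SQL higher-order filter, threshold interpolated into the expression string.
  // Keeps the elements of `letter` whose length is at most maxLen.
  def viaExpr(df: DataFrame, maxLen: Int): DataFrame =
    df.withColumn("letter", expr(s"filter(letter, x -> length(x) <= $maxLen)"))

  // Spark 3.0+ only: the same higher-order function through the Scala Column API,
  // so no SQL string has to be built.
  def viaDsl(df: DataFrame, maxLen: Int): DataFrame =
    df.withColumn("letter", filter(col("letter"), x => length(x) <= maxLen))
}
```

With res from the question, ShortLetters.viaExpr(res, 2).show() should be equivalent to the hard-coded < 3 expression above. Neither variant shuffles or goes through a UDF, so Spark evaluates the filter row by row with its built-in expressions.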