Dave DeCaprio

Reputation: 2111

What's the most concise way to apply a filter to the elements of an array Column in Spark?

I have a Spark DataFrame where one of my columns is an array of objects. I'd like to do an operation that filters that array. In my example below, each parent has an array of children, and I'd like to keep only the adult children.

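import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.udf
import scala.collection.mutable
import spark.implicits._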

case class Child(name: String, age: Int)
case class Parent(name: String, children: Array[Child])

val rawData = Seq(Parent("Mom", Array(Child("Jane", 9))), Parent("Dad", Array(Child("Hubert", 28), Child("David", 27), Child("Jim", 25))))
val data = spark.createDataFrame(rawData)

The closest I have been able to come is:

val adultChildren = udf((children: mutable.WrappedArray[Child]) => {
  val rowArray = children.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
  val ret = rowArray.filter(c => c.getAs[Int]("age") > 18)
  ret.asInstanceOf[mutable.WrappedArray[Child]]
})
data.select(adultChildren($"children")).show()

This works, but it is annoying: the casts to and from GenericRowWithSchema are pure boilerplate. I guess the advantage is that Spark spends less time (de)serializing objects, but it is verbose.

Is there a more concise way to do this?

Upvotes: 0

Views: 624

Answers (2)

Justin Pihony

Reputation: 67115

If you can use Datasets, it becomes really simple:

data.as[Parent].map(_.children.filter(_.age > 18).toList)
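A fuller sketch of the Dataset route, assuming a local SparkSession and the Parent/Child case classes from the question. The DataFrame from the question has to be turned into a Dataset[Parent] with `.as[Parent]` before typed lambdas can be used, and copying each Parent keeps the parent's name next to the filtered children. The names DatasetFilterSketch, adultsOnly and run are illustrative, not from the original post.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Same shape as the case classes in the question.
case class Child(name: String, age: Int)
case class Parent(name: String, children: Array[Child])

object DatasetFilterSketch {
  // Keep every parent, replacing its children with only those older than 18.
  def adultsOnly(spark: SparkSession, parents: Dataset[Parent]): Dataset[Parent] = {
    import spark.implicits._ // provides the Encoder[Parent] that map needs
    parents.map(p => p.copy(children = p.children.filter(_.age > 18)))
  }

  // Builds the question's sample data locally and returns parent -> adult child names.
  def run(): Map[String, List[String]] = {
    val spark = SparkSession.builder().master("local[1]").appName("dataset-filter").getOrCreate()
    import spark.implicits._
    try {
      val data = spark.createDataFrame(Seq(
        Parent("Mom", Array(Child("Jane", 9))),
        Parent("Dad", Array(Child("Hubert", 28), Child("David", 27), Child("Jim", 25)))))
      adultsOnly(spark, data.as[Parent])
        .collect()
        .map(p => p.name -> p.children.map(_.name).toList)
        .toMap
    } finally spark.stop()
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Mom stays in the result with an empty children array, because the whole Parent row is mapped rather than exploded.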

But if you are beholden to DataFrames:

data.select($"name", explode($"children").as("child"))
    .where($"child.age" > 18)
    .groupBy($"name").agg(collect_list($"child"))
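On newer Spark versions the array can also be filtered in place with the built-in higher-order `filter` function, with no explode, groupBy or UDF. A minimal sketch, assuming Spark 3.0+ for the Scala `functions.filter(column, Column => Column)` overload; the `expr("filter(...)")` form only needs Spark 2.4+. Unlike the explode/groupBy version above, it keeps parents with no adult children (Mom ends up with an empty array) and does not rely on parent names being unique. HigherOrderFilterSketch and its methods are illustrative names.

```scala
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession, functions}
import org.apache.spark.sql.functions.{col, expr}

case class Child(name: String, age: Int)
case class Parent(name: String, children: Array[Child])

object HigherOrderFilterSketch {
  // Spark 3.0+: functions.filter takes the array column and a Column => Column predicate.
  def adultsTyped(df: DataFrame): DataFrame =
    df.select(col("name"),
      functions.filter(col("children"), (c: Column) => c.getField("age") > 18).as("children"))

  // Spark 2.4+: the same SQL higher-order function, written as an expression string.
  def adultsSql(df: DataFrame): DataFrame =
    df.select(col("name"), expr("filter(children, c -> c.age > 18)").as("children"))

  // Collects parent -> names of the children left in the filtered array.
  private def namesByParent(df: DataFrame): Map[String, List[String]] =
    df.collect().map { r =>
      r.getString(0) -> r.getSeq[Row](1).map(_.getString(0)).toList
    }.toMap

  def run(): Seq[Map[String, List[String]]] = {
    val spark = SparkSession.builder().master("local[1]").appName("hof-filter").getOrCreate()
    try {
      val data = spark.createDataFrame(Seq(
        Parent("Mom", Array(Child("Jane", 9))),
        Parent("Dad", Array(Child("Hubert", 28), Child("David", 27), Child("Jim", 25)))))
      Seq(namesByParent(adultsTyped(data)), namesByParent(adultsSql(data)))
    } finally spark.stop()
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

Both variants compile to the same Catalyst expression, so the choice between them is mostly about whether the predicate is checked by the Scala compiler or parsed from a string at runtime.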

Upvotes: 2

Dave DeCaprio

Reputation: 2111

One improvement is to encapsulate the boilerplate in a function:

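import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.udf
import scala.collection.mutable
import scala.reflect.runtime.universe._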
def arrayFilterUDF[T: TypeTag](f: (GenericRowWithSchema) => Boolean) = udf((a: mutable.WrappedArray[T]) => {
    val rowArray = a.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
    rowArray.filter(f).asInstanceOf[mutable.WrappedArray[T]]
})

This allows you to write:

val adultChildren = arrayFilterUDF[Child](_.getAs[Int]("age") > 18)
data.select(adultChildren($"children")).show()    
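The casts in both UDF versions exist because Spark hands an array of structs to a Scala UDF as rows, not as Child objects. A minimal sketch, assuming Spark on Scala 2.12 (as the question's mutable.WrappedArray suggests), that declares the input as Seq[Row] directly and rebuilds Child values on the way out, so Spark can derive the output schema from the return type. RowUdfSketch and run are illustrative names, and the sketch assumes the children array itself is never null.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}

case class Child(name: String, age: Int)
case class Parent(name: String, children: Array[Child])

object RowUdfSketch {
  // Input declared as Seq[Row]: no asInstanceOf needed on the way in.
  // Output rebuilt as Child values: Spark infers array<struct<name,age>> from Seq[Child].
  // Assumes the children array is never null (a null would throw here).
  val adultChildren = udf((children: Seq[Row]) =>
    children
      .filter(_.getAs[Int]("age") > 18)
      .map(r => Child(r.getAs[String]("name"), r.getAs[Int]("age"))))

  // Builds the question's sample data locally and returns parent -> adult child names.
  def run(): Map[String, List[String]] = {
    val spark = SparkSession.builder().master("local[1]").appName("row-udf").getOrCreate()
    try {
      val data = spark.createDataFrame(Seq(
        Parent("Mom", Array(Child("Jane", 9))),
        Parent("Dad", Array(Child("Hubert", 28), Child("David", 27), Child("Jim", 25)))))
      data.select(col("name"), adultChildren(col("children")).as("children"))
        .collect()
        .map(r => r.getString(0) -> r.getSeq[Row](1).map(_.getString(0)).toList)
        .toMap
    } finally spark.stop()
  }

  def main(args: Array[String]): Unit = println(run())
}
```

The trade-off is the explicit Row-to-Child mapping, which has to be kept in sync with the case class fields.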

Upvotes: 0
