jamiet

Reputation: 12334

Recursively create a Spark filter with multiple predicates in Scala

I have a dataframe containing a bunch of values

val df = List(
  (2017, 1, 1234),
  (2017, 2, 1234),
  (2017, 3, 1234),
  (2017, 4, 1234),
  (2018, 1, 12345),
  (2018, 2, 12346),
  (2018, 3, 12347),
  (2018, 4, 12348)
).toDF("year", "month", "employeeCount")

df: org.apache.spark.sql.DataFrame = [year: int, month: int, employeeCount: int]

I want to filter that dataframe by a list of (year, month) pairs:

val filterValues = List((2018, 1), (2018, 2))

I can easily cheat and write the code that achieves it:

df.filter(
  (col("year") === 2018 && col("month") === 1) || 
  (col("year") === 2018 && col("month") === 2)
).show

but of course that's not satisfactory because filterValues could change, and I want to base it on whatever is in that list.

Is it possible to dynamically build my filter_expression and then pass it to df.filter(filter_expression)? I can't figure out how.

Upvotes: 1

Views: 898

Answers (3)

Raphael Roth

Reputation: 27373

You can build up your filter_expression like this:

val df = List(
  (2017, 1, 1234),
  (2017, 2, 1234),
  (2017, 3, 1234),
  (2017, 4, 1234),
  (2018, 1, 12345),
  (2018, 2, 12346),
  (2018, 3, 12347),
  (2018, 4, 12348)
).toDF("year", "month", "employeeCount")

val filterValues = List((2018, 1), (2018, 2))

val filter_expression = filterValues
  .map{case (y,m) => col("year") === y and col("month") === m}
  .reduce(_ || _)

df
  .filter(filter_expression)
  .show()

+----+-----+-------------+
|year|month|employeeCount|
+----+-----+-------------+
|2018|    1|        12345|
|2018|    2|        12346|
+----+-----+-------------+
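
One caveat worth adding (not part of the original answer): reduce throws on an empty list, so if filterValues can ever be empty you could hedge with reduceOption and a constant fallback, for example:

// sketch assuming org.apache.spark.sql.functions._ is imported
val filter_expression = filterValues
  .map { case (y, m) => col("year") === y and col("month") === m }
  .reduceOption(_ || _)
  .getOrElse(lit(false)) // no pairs supplied: match no rows (use lit(true) to match all)

df.filter(filter_expression).show()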

Upvotes: 0

Alper t. Turker

Reputation: 35249

Based on your comment:

imagine someone calling this from the command-line with something like --filterColumns "year,month" --filterValues "2018|1,2018|2"

val filterValues = "2018|1,2018|2"
val filterColumns = "year,month"

You can get the list of column names:

val colnames = filterColumns.split(',')

Convert data to a local Dataset (add schema when needed):

val filter = spark.read.option("delimiter", "|")
  .csv(filterValues.split(',').toSeq.toDS)
  .toDF(colnames: _*)

and semi join:

df.join(filter, colnames, "left_semi").show
// +----+-----+-------------+             
// |year|month|employeeCount|
// +----+-----+-------------+
// |2018|    1|        12345|
// |2018|    2|        12346|
// +----+-----+-------------+
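
As a side note (an assumption on my part, since the answer starts from the raw command-line strings): if the pairs are already parsed into a Scala list, the lookup DataFrame can be built directly and the CSV round trip skipped:

// sketch assuming the pairs are already available as a List[(Int, Int)]
// and spark.implicits._ is in scope
val filterDf = List((2018, 1), (2018, 2)).toDF("year", "month")

df.join(filterDf, Seq("year", "month"), "left_semi").show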

An expression like this one should work as well:

import org.apache.spark.sql.functions._

val pred = filterValues
  .split(",")
  .map(x => colnames.zip(x.split('|'))
                    .map { case (c, v) => col(c) === v }
                    .reduce(_ && _))
  .reduce(_ || _)

df.where(pred).show
// +----+-----+-------------+
// |year|month|employeeCount|
// +----+-----+-------------+
// |2018|    1|        12345|
// |2018|    2|        12346|
// +----+-----+-------------+

but it will require more work if some type casting is required.
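
For example, one possible (illustrative, not from the original answer) way to add that casting is to look up each column's type in df.schema and cast the parsed string before comparing:

// cast each parsed string literal to the target column's type
val predTyped = filterValues
  .split(",")
  .map(x => colnames.zip(x.split('|'))
                    .map { case (c, v) => col(c) === lit(v).cast(df.schema(c).dataType) }
                    .reduce(_ && _))
  .reduce(_ || _)

df.where(predTyped).show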

Upvotes: 4

Ramesh Maharjan

Reputation: 41987

You can always do that using a udf function:

val filterValues = List((2018, 1), (2018, 2))

import org.apache.spark.sql.functions._
def filterUdf = udf((year:Int, month:Int) => filterValues.exists(x => x._1 == year && x._2 == month))

df.filter(filterUdf(col("year"), col("month"))).show(false)

Updated

You commented:

I mean that the list of columns to filter on (and the corresponding list of respective values) would be supplied from elsewhere at runtime.

For that you will have the list of column names provided too, so the solution would be something like the one below:

val filterValues = List((2018, 1), (2018, 2))
val filterColumns = List("year", "month")

import org.apache.spark.sql.functions._
// a row matches when every element of some tuple in filterValues
// equals the corresponding column value passed in as a Seq
def filterUdf = udf((unknown: Seq[Int]) =>
  filterValues.exists(x => !x.productIterator.toList.zip(unknown).map(y => y._1 == y._2).contains(false)))

df.filter(filterUdf(array(filterColumns.map(col): _*))).show(false)
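
As an aside (my addition, not part of the original answer): the same column-name-driven filter can also be built without a udf, which keeps the predicate visible to the optimizer. A rough sketch:

// one conjunction per tuple in filterValues, OR-ed together
val dynamicPred = filterValues
  .map(t => filterColumns.zip(t.productIterator.toSeq)
                         .map { case (c, v) => col(c) === v }
                         .reduce(_ && _))
  .reduce(_ || _)

df.filter(dynamicPred).show(false)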

Upvotes: 1
