Christos Hadjinikolis

Reputation: 2158

How to apply a custom filtering function on a Spark DataFrame

I have a DataFrame of the form:

A_DF = |id_A: Int|concatCSV: String|

and another one:

B_DF = |id_B: Int|triplet: List[String]|

Examples of concatCSV could look like:

"StringD, StringB, StringF, StringE, StringZ"
"StringA, StringB, StringX, StringY, StringZ"
...

while a triplet is something like:

("StringA", "StringF", "StringZ")
("StringB", "StringU", "StringR")
...
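
For reference, DataFrames of this shape can be built with something like the following (the values here are just for illustration):

import sqlContext.implicits._   // assumes an existing SQLContext in scope

val A_DF = Seq(
  (14, "StringD, StringB, StringF, StringE, StringZ"),
  (18, "StringA, StringB, StringX, StringY, StringZ")
).toDF("id_A", "concatCSV")

val B_DF = Seq(
  (21, Seq("StringA", "StringF", "StringZ")),
  (45, Seq("StringB", "StringU", "StringR"))
).toDF("id_B", "triplet")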

I want to produce the cartesian set of A_DF and B_DF, e.g.:

| id_A: Int | concatCSV: String                             | id_B: Int | triplet: List[String]            |
|     14    | "StringD, StringB, StringF, StringE, StringZ" |     21    | ("StringA", "StringF", "StringZ")|
|     14    | "StringD, StringB, StringF, StringE, StringZ" |     45    | ("StringB", "StringU", "StringR")|
|     18    | "StringA, StringB, StringX, StringY, StringG" |     21    | ("StringA", "StringF", "StringZ")|
|     18    | "StringA, StringB, StringX, StringY, StringG" |     45    | ("StringB", "StringU", "StringR")|
|    ...    |                                               |           |                                  |

Then keep just the records that have at least two substrings (e.g. StringA, StringB) from A_DF("concatCSV") appearing in B_DF("triplet"), i.e. use filter to exclude the records that don't satisfy this condition.

First question is: can I do this without converting the DFs into RDDs?

Second question is: can I ideally do the whole thing in the join step, as a where condition?

I have tried experimenting with something like:

val cartesianRDD = A_DF
   .join(B_DF,"right")
   .where($"triplet".exists($"concatCSV".contains(_)))

but where cannot be resolved. I also tried filter instead of where, with no luck. And for some strange reason, the inferred type of cartesianRDD is SchemaRDD and not DataFrame. How did I end up with that? Finally, the attempt above is incomplete anyway, as it would keep records with just one substring from concatCSV found in triplet.

So, third question is: Should I just change to RDDs and solve it with a custom filtering function?

Finally, last question: Can I use a custom filtering function with DataFrames?

Thanks for the help.

Upvotes: 2

Views: 7762

Answers (1)

mtoto

Reputation: 24178

CROSS JOIN is supported in Hive, so you could first do the cross join using Hive SQL:

A_DF.registerTempTable("a")
B_DF.registerTempTable("b")

// sqlContext should really be a HiveContext here
val result = sqlContext.sql("SELECT * FROM a CROSS JOIN b") 
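
Note that, depending on your Spark version, a join with no join condition should also give the cartesian product (and Spark 2.1+ exposes an explicit crossJoin method); a rough alternative sketch:

val result = A_DF.join(B_DF)        // cartesian product when no join condition is supplied
// or, on Spark 2.1+:
// val result = A_DF.crossJoin(B_DF)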

Then you can filter down to your expected output using two UDFs. One converts the string into an array of words, and the other gives the length of the intersection between that array column and the existing "triplet" column:

import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}

// split the CSV string into an array of trimmed tokens
val splitArr = udf { (s: String) => s.split(",").map(_.trim) }

// number of elements the two arrays have in common
val commonLen = udf { (a: WrappedArray[String],
                       b: WrappedArray[String]) => a.intersect(b).length }

val temp = result
  .withColumn("concatArr", splitArr(col("concatCSV")))             // CSV string -> array
  .withColumn("comm", commonLen(col("triplet"), col("concatArr"))) // size of the intersection
  .filter(col("comm") >= 2)                                        // keep rows with at least 2 matches
  .drop("comm")
  .drop("concatArr")

temp.show
+----+--------------------+----+--------------------+
|id_A|           concatCSV|id_B|             triplet|
+----+--------------------+----+--------------------+
|  14|StringD, StringB,...|  21|[StringA, StringF...|
|  18|StringA, StringB,...|  21|[StringA, StringF...|
+----+--------------------+----+--------------------+
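
If you would rather not create the intermediate columns, the same UDFs can in principle be applied directly inside filter, which also covers the "do it as a where condition" part of the question; a sketch:

val filtered = result.filter(
  commonLen(col("triplet"), splitArr(col("concatCSV"))) >= 2)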

Upvotes: 3
