Reputation: 2158
I have a DataFrame of the form:
A_DF = |id_A: Int|concatCSV: String|
and another one:
B_DF = |id_B: Int|triplet: List[String]|
Examples of concatCSV could look like:
"StringD, StringB, StringF, StringE, StringZ"
"StringA, StringB, StringX, StringY, StringZ"
...
while a triplet is something like:
("StringA", "StringF", "StringZ")
("StringB", "StringU", "StringR")
...
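For concreteness, the two frames could be set up like this (hypothetical toy data taken from the examples in this post; assumes the usual sqlContext implicits are in scope):
import sqlContext.implicits._
// hypothetical toy data matching the schemas above
val A_DF = Seq(
  (14, "StringD, StringB, StringF, StringE, StringZ"),
  (18, "StringA, StringB, StringX, StringY, StringG")
).toDF("id_A", "concatCSV")
val B_DF = Seq(
  (21, List("StringA", "StringF", "StringZ")),
  (45, List("StringB", "StringU", "StringR"))
).toDF("id_B", "triplet")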
I want to produce the Cartesian product of A_DF and B_DF, e.g.:
| id_A: Int | concatCSV: String | id_B: Int | triplet: List[String] |
| 14 | "StringD, StringB, StringF, StringE, StringZ" | 21 | ("StringA", "StringF", "StringZ")|
| 14 | "StringD, StringB, StringF, StringE, StringZ" | 45 | ("StringB", "StringU", "StringR")|
| 18 | "StringA, StringB, StringX, StringY, StringG" | 21 | ("StringA", "StringF", "StringZ")|
| 18 | "StringA, StringB, StringX, StringY, StringG" | 45 | ("StringB", "StringU", "StringR")|
| ... | | | |
Then I want to keep just the records that have at least two substrings (e.g. StringA, StringB) from A_DF("concatCSV") that appear in B_DF("triplet"), i.e. use filter to exclude those that don't satisfy this condition.
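In plain Scala, the condition I'm after (per pair of rows) would be something like this throwaway helper, written here only to pin down the semantics:
// hypothetical helper, just to spell out the filtering condition
def atLeastTwoCommon(concatCSV: String, triplet: Seq[String]): Boolean =
  concatCSV.split(",").map(_.trim).toSet.intersect(triplet.toSet).size >= 2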
First question: can I do this without converting the DFs into RDDs?
Second question: can I ideally do the whole thing in the join step, as a where condition?
I have tried experimenting with something like:
val cartesianRDD = A_DF
  .join(B_DF, "right")
  .where($"triplet".exists($"concatCSV".contains(_)))
but where cannot be resolved. I tried it with filter instead of where, but still no luck. Also, for some strange reason, the type annotation for cartesianRDD is SchemaRDD and not DataFrame. How did I end up with that? Finally, what I am trying above (the short code I wrote) is incomplete, as it would keep records with just one substring from concatCSV found in triplet.
So, third question: should I just switch to RDDs and solve this with a custom filtering function?
And the last question: can I use a custom filtering function with DataFrames?
Thanks for the help.
Upvotes: 2
Views: 7762
Reputation: 24178
CROSS JOIN is implemented in Hive, so you can first do the cross join using Hive SQL:
A_DF.registerTempTable("a")
B_DF.registerTempTable("b")
// sqlContext should really be a HiveContext here
val result = sqlContext.sql("SELECT * FROM a CROSS JOIN b")
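Side note: if you are on Spark 2.1 or later, DataFrame has a built-in crossJoin method, so the Hive detour isn't needed there and the rest of the answer stays the same:
// Spark 2.1+ alternative, no HiveContext required
val result = A_DF.crossJoin(B_DF)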
Then you can filter down to your expected output using two udfs: one that converts your string into an array of words, and a second one that gives the length of the intersection of that array column and the existing "triplet" column:
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.functions.{col, udf}

// split the CSV string into a trimmed array of words
val splitArr = udf { (s: String) => s.split(",").map(_.trim) }

// number of elements the two arrays have in common
val commonLen = udf { (a: WrappedArray[String], b: WrappedArray[String]) =>
  a.intersect(b).length
}

val temp = result
  .withColumn("concatArr", splitArr(col("concatCSV")))
  .select(col("*"), commonLen(col("triplet"), col("concatArr")).alias("comm"))
  .filter(col("comm") >= 2)  // keep rows with at least two words in common
  .drop("comm")              // drop the helper columns again
  .drop("concatArr")
temp.show
+----+--------------------+----+--------------------+
|id_A| concatCSV|id_B| triplet|
+----+--------------------+----+--------------------+
| 14|StringD, StringB,...| 21|[StringA, StringF...|
| 18|StringA, StringB,...| 21|[StringA, StringF...|
+----+--------------------+----+--------------------+
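As an aside, on Spark 2.4+ you can express the whole filter with built-in functions, no udfs needed, which also answers the "can it all be a where condition" part of the question. A sketch, assuming the same column names as above (result24 is just an illustrative name):
import org.apache.spark.sql.functions.{array_intersect, col, size, split}
// Spark 2.4+: built-in split/array_intersect/size replace both udfs
val result24 = A_DF.crossJoin(B_DF)
  .where(size(array_intersect(split(col("concatCSV"), ",\\s*"), col("triplet"))) >= 2)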
Upvotes: 3