S.K

Reputation: 367

Apache Spark - Finding Array/List/Set subsets

I have two DataFrames, each with an Array[String] column. For each row in DF2, I need to find the rows of DF1 whose labels form a subset of that row's labels, if any. Here is an example:

DF1:

+----------+--------------------------+
| id: Long | labels: Array[String]    |
+----------+--------------------------+
| 10       | [label1, label2, label3] |
| 11       | [label4, label5]         |
| 12       | [label6, label7]         |
+----------+--------------------------+

DF2:

+--------------+------------------------------------------+
| item: String | labels: Array[String]                    |
+--------------+------------------------------------------+
| item1        | [label1, label2, label3, label4, label5] |
| item2        | [label4, label5]                         |
| item3        | [label4, label5, label6, label7]         |
+--------------+------------------------------------------+

After the subset operation described above, the expected output is:

DF3:

+--------------+-----------------+
| item: String | id: Array[Long] |
+--------------+-----------------+
| item1        | [10, 11]        |
| item2        | [11]            |
| item3        | [11, 12]        |
+--------------+-----------------+

Every row in DF2 is guaranteed to be fully covered by subsets from DF1, so no labels are ever left over.

Can someone please help with the right approach here? It looks like, for each row in DF2, I need to scan DF1 and apply a subset check (or set subtraction) to the labels column until I have found all the subsets and exhausted the labels in that row, accumulating the matching ids along the way. How can I do this compactly and efficiently? Any help is greatly appreciated. In practice, DF1 may have hundreds of rows and DF2 thousands.
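A minimal plain-Scala sketch of the approach described above, run on local collections rather than DataFrames so the logic can be checked in isolation. The object name `GreedyCover` and the helper `cover` are hypothetical. It assumes, as in the example, that the DF1 label sets covering a DF2 row do not overlap; with overlapping sets, greedy subtraction could depend on the order in which DF1 is scanned.

```scala
// Local (non-Spark) sketch of the question's idea: for each DF2 row, scan DF1,
// subtract every label set fully contained in the labels still remaining,
// and stop as soon as nothing remains.
object GreedyCover {
  // Hypothetical local stand-ins for DF1 (id -> labels) and DF2 (item -> labels).
  val df1: List[(Long, Set[String])] = List(
    10L -> Set("label1", "label2", "label3"),
    11L -> Set("label4", "label5"),
    12L -> Set("label6", "label7")
  )
  val df2: List[(String, Set[String])] = List(
    "item1" -> Set("label1", "label2", "label3", "label4", "label5"),
    "item2" -> Set("label4", "label5"),
    "item3" -> Set("label4", "label5", "label6", "label7")
  )

  // Returns, in scan order, the ids whose label sets were subtracted.
  @annotation.tailrec
  def cover(remaining: Set[String],
            candidates: List[(Long, Set[String])],
            acc: List[Long] = Nil): List[Long] =
    candidates match {
      case _ if remaining.isEmpty => acc.reverse // all labels exhausted: stop early
      case Nil => acc.reverse
      case (id, subset) :: rest if subset.subsetOf(remaining) =>
        cover(remaining -- subset, rest, id :: acc)
      case _ :: rest => cover(remaining, rest, acc)
    }

  def main(args: Array[String]): Unit =
    df2.foreach { case (item, labels) => println(s"$item -> ${cover(labels, df1)}") }
}
```

Running `main` prints `item1 -> List(10, 11)`, `item2 -> List(11)` and `item3 -> List(11, 12)`, matching DF3. Stopping once `remaining` is empty is the main saving over testing every DF1 row.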

Upvotes: 1

Views: 1464

Answers (1)

Shaido

Reputation: 28392

I'm not aware of an efficient way to do this. However, here is one possible solution using a UDF together with a Cartesian join.

The UDF takes two sequences and checks whether every string in the first also exists in the second:

import org.apache.spark.sql.functions.udf

// true if every label in array1 also appears in array2
val matchLabel = udf((array1: Seq[String], array2: Seq[String]) => {
  array1.forall(x => array2.contains(x))
})
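Because the UDF body is ordinary Scala, the predicate can be tried out without Spark. The sketch below (object and method names are hypothetical) puts the answer's `forall`/`contains` check next to a swapped-in `Set`-based variant using `subsetOf`. Both ignore label order and duplicates. `Seq.contains` is a linear scan, so the first version does roughly |array1| × |array2| comparisons per pair of rows, while the `Set` version builds a hash set for each side once.

```scala
// Two local (non-Spark) versions of the predicate wrapped by matchLabel.
object SubsetCheck {
  // Same logic as the UDF body: every element of `small` appears in `big`.
  // Seq.contains is a linear scan, so this is O(|small| * |big|).
  def viaForall(small: Seq[String], big: Seq[String]): Boolean =
    small.forall(x => big.contains(x))

  // Set-based variant (not from the answer): hash lookups instead of scans.
  def viaSet(small: Seq[String], big: Seq[String]): Boolean =
    small.toSet.subsetOf(big.toSet)

  def main(args: Array[String]): Unit = {
    val item3 = Seq("label4", "label5", "label6", "label7")
    println(viaForall(Seq("label6", "label7"), item3))          // true
    println(viaSet(Seq("label1", "label2", "label3"), item3))   // false
  }
}
```

To try the `Set`-based variant in Spark, the UDF body could become `array1.toSet.subsetOf(array2.toSet)`; with only a handful of labels per row, the difference is likely small.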

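Cartesian joins have to be enabled explicitly because they are computationally expensive (this applies to Spark 2.x; since Spark 3.0, `spark.sql.crossJoin.enabled` defaults to true).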

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.crossJoin.enabled", true) // defaults to true since Spark 3.0

The two DataFrames are joined using the UDF as the join condition. The result is then grouped by the item column, and all matching ids are collected into a list. Using the same DF1 and DF2 as in the question:

import org.apache.spark.sql.functions.collect_list

val DF3 = DF2.join(DF1, matchLabel(DF1("labels"), DF2("labels")))
  .groupBy("item")
  .agg(collect_list("id").as("id"))
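To make the semantics of that join concrete, here is a local, non-Spark model using plain Scala collections (object and method names are hypothetical). A `for` comprehension forms every (DF2 row, DF1 row) pair, keeps the pairs that pass the same predicate as `matchLabel`, and then groups the ids by item, mirroring `groupBy("item")` plus `collect_list("id")`. It also shows why the cost grows with |DF1| × |DF2|: every pair is checked, with no early stop. One difference: Spark's `collect_list` does not guarantee the order of the collected ids after a shuffle, whereas this model keeps DF1's order.

```scala
// Local model of DF2.join(DF1, matchLabel(...)).groupBy("item").agg(collect_list("id")).
object CrossJoinModel {
  val df1 = Seq(
    (10L, Seq("label1", "label2", "label3")),
    (11L, Seq("label4", "label5")),
    (12L, Seq("label6", "label7"))
  )
  val df2 = Seq(
    ("item1", Seq("label1", "label2", "label3", "label4", "label5")),
    ("item2", Seq("label4", "label5")),
    ("item3", Seq("label4", "label5", "label6", "label7"))
  )

  // Same predicate as the UDF body.
  def matchLabel(array1: Seq[String], array2: Seq[String]): Boolean =
    array1.forall(x => array2.contains(x))

  def run(): Map[String, Seq[Long]] = {
    // Cartesian product filtered by the predicate: |df2| * |df1| checks.
    val joined = for {
      (item, itemLabels) <- df2
      (id, idLabels) <- df1
      if matchLabel(idLabels, itemLabels)
    } yield (item, id)
    // Equivalent of groupBy("item") + collect_list("id").
    joined.groupBy(_._1).map { case (item, pairs) => item -> pairs.map(_._2) }
  }

  def main(args: Array[String]): Unit =
    run().toSeq.sortBy(_._1).foreach { case (item, ids) => println(s"$item -> $ids") }
}
```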

The result is as follows:

+-----+--------+
| item|      id|
+-----+--------+
|item3|[11, 12]|
|item2|    [11]|
|item1|[10, 11]|
+-----+--------+

Upvotes: 0
