artyomboyko

Reputation: 2861

Spark merge sets of common elements

I have a DataFrame that looks like this:

+-----------+-----------+
|  Package  | Addresses |
+-----------+-----------+
| Package 1 | address1  |
| Package 1 | address2  |
| Package 1 | address3  |
| Package 2 | address3  |
| Package 2 | address4  |
| Package 2 | address5  |
| Package 2 | address6  |
| Package 3 | address7  |
| Package 3 | address8  |
| Package 4 | address9  |
| Package 5 | address9  |
| Package 5 | address1  |
| Package 6 | address10 |
| Package 7 | address8  |
+-----------+-----------+

I need to find all the addresses that were seen together across different packages. Example output:

+----+------------------------------------------------------------------------+
| Id |                               Addresses                                |
+----+------------------------------------------------------------------------+
|  1 | [address1, address2, address3, address4, address5, address6, address9] |
|  2 | [address7, address8]                                                   |
|  3 | [address10]                                                            |
+----+------------------------------------------------------------------------+

So I have this DataFrame. First, I group it by package (using combineByKey instead of a groupBy):

val rdd = packages.select($"package", $"address")
  .map(x => (x(0).toString, x(1).toString))
  .rdd
  .combineByKey(
    // start a new set for the first address seen for a package
    (address: String) => Set(address),
    // add each further address of the same package
    (acc: Set[String], address: String) => acc + address,
    // merge the per-partition sets of the same package
    (acc1: Set[String], acc2: Set[String]) => acc1 ++ acc2
  )
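
(For what it's worth, the same per-package address sets could also be built with the DataFrame API; this is just a rough equivalent of the combineByKey step above, assuming the same packages DataFrame and spark.implicits._ in scope.)

import org.apache.spark.sql.functions.collect_set

// equivalent of the combineByKey step: one row per package with its set of addresses
val grouped = packages
  .groupBy($"package")
  .agg(collect_set($"address").as("addresses"))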

Then, I'm merging rows that have common addresses:

val result = rdd.treeAggregate(
  Set.empty[Set[String]]
)(
  (acc: Set[Set[String]], row) => {
    val sets = acc + row._2

    // merge intersecting sets, adapted from https://stackoverflow.com/a/25623014/772249
    sets.foldLeft(Set.empty[Set[String]]) { (cum, cur) =>
      val (hasCommon, rest) = cum.partition(s => (s & cur).nonEmpty)
      rest + (cur ++ hasCommon.flatten)
    }
  },
  (acc1, acc2) => {
    val sets = acc1 ++ acc2

    // merge intersecting sets, adapted from https://stackoverflow.com/a/25623014/772249
    sets.foldLeft(Set.empty[Set[String]]) { (cum, cur) =>
      val (hasCommon, rest) = cum.partition(s => (s & cur).nonEmpty)
      rest + (cur ++ hasCommon.flatten)
    }
  },
  10
)

But no matter what I do, treeAggregate takes very long and I can't finish a single task. The raw data size is about 250 GB. I've tried different clusters, but treeAggregate is still too slow.

Everything before treeAggregate works fine, but it gets stuck after that.

I've tried different values of spark.sql.shuffle.partitions (default, 2000, 10000), but it doesn't seem to matter.

I've tried different depths for treeAggregate, but didn't notice any difference.

Related questions:

  1. Merge Sets of Sets that contain common elements in Scala
  2. Spark complex grouping

Upvotes: 0

Views: 1071

Answers (1)

Nazarii Bardiuk

Reputation: 4342

Take a look at your data as if it were a graph where addresses are vertices, and two addresses are connected if some package contains both of them. The solution to your problem is then the connected components of that graph.

Spark's GraphX library has an optimized function for finding connected components. It labels each vertex with the smallest vertex id in its connected component, so you can treat that label as the id of the component.

Then, given a component id, you can collect all the other addresses that belong to it if needed.
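
A rough sketch of that idea could look like the following (assuming the packages DataFrame from the question with package and address columns, spark.implicits._ in scope, and zipWithIndex used just as one way to give every address a Long vertex id):

import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.rdd.RDD

// (package, address) pairs, as in the question
val pairs: RDD[(String, String)] =
  packages.select($"package", $"address").as[(String, String)].rdd

// give every distinct address a Long vertex id
val addressIds: RDD[(String, Long)] = pairs.values.distinct().zipWithIndex()
val vertices: RDD[(Long, String)] = addressIds.map(_.swap)

// connect addresses that share a package: a "star" around the first
// address of each package is enough to make the group connected
val edges: RDD[Edge[Int]] = pairs
  .map(_.swap)                                  // (address, package)
  .join(addressIds)                             // (address, (package, id))
  .map { case (_, (pkg, id)) => (pkg, id) }
  .groupByKey()
  .flatMap { case (_, ids) =>
    val list = ids.toSeq
    list.tail.map(dst => Edge(list.head, dst, 1))
  }

val graph = Graph(vertices, edges)

// each vertex is labelled with the smallest vertex id in its component
val componentByVertex = graph.connectedComponents().vertices  // (vertexId, componentId)

// group the addresses by component id to get the merged sets
val merged = componentByVertex
  .join(vertices)                               // (vertexId, (componentId, address))
  .map { case (_, (componentId, address)) => (componentId, address) }
  .groupByKey()
  .mapValues(_.toSet)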

Have a look at this article on how they use graphs to achieve the same kind of grouping.

Upvotes: 3
