Evan

Reputation: 383

Create a new spark dataframe that contains pairwise combinations of another dataframe?

Consider the following code:

question = spark.createDataFrame([{'A':1,'B':5},{'A':2,'B':5},
                             {'A':3,'B':5},{'A':3,'B':6}])
#+---+---+
#|  A|  B|
#+---+---+
#|  1|  5|
#|  2|  5|
#|  3|  5|
#|  3|  6|
#+---+---+

How can I create a Spark dataframe that looks as follows:

solution = spark.createDataFrame([{'C':1,'D':2},{'C':1,'D':3},
                             {'C':2,'D':3},{'C':5,'D':6}])
#+---+---+
#|  C|  D|
#+---+---+
#|  1|  2|
#|  1|  3|
#|  2|  3|
#|  5|  6|
#+---+---+

This is the notion of triadic closure, where I am connecting the third edge of the triangle based upon which edges are already connected.

I must have (1,2) since (1,5) and (2,5) are present, (1,3) since (1,5) and (3,5) are present, and (2,3) since (2,5) and (3,5) are present. I must also have (5,6) since (3,5) and (3,6) are present (edges count in both directions). There should NOT be an additional entry for (5,6): only one value in A maps to 6, so (5,6) is not added a second time.
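
As a sanity check on the expected output, the rule can be sketched in plain Python without Spark. This is only an illustration of the logic as described above: each row is treated as an undirected edge, and any two nodes that share a neighbour get connected. The function name `triadic_closure` is made up for this sketch.

```python
from collections import defaultdict
from itertools import combinations


def triadic_closure(edges):
    """Return the set of (smaller, larger) pairs of nodes that share a neighbour."""
    # Build an undirected adjacency map: each edge is recorded in both directions.
    neighbours = defaultdict(set)
    for a, b in edges:
        neighbours[a].add(b)
        neighbours[b].add(a)

    closing = set()
    for shared in neighbours.values():
        # Sorting makes every pair come out as (smaller, larger), so the set
        # keeps a single copy of each pair.
        closing.update(combinations(sorted(shared), 2))
    return closing


edges = [(1, 5), (2, 5), (3, 5), (3, 6)]
print(sorted(triadic_closure(edges)))  # prints [(1, 2), (1, 3), (2, 3), (5, 6)]
```

Node 5 has neighbours 1, 2 and 3, which gives the first three pairs; node 3 has neighbours 5 and 6, which gives (5,6) exactly once.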

Upvotes: 1

Views: 1858

Answers (2)

Sidd Singal

Reputation: 656

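import spark.implicits._  // needed outside spark-shell for toDF and the tuple encoders

val df = sc.parallelize(Array((1,5),(2,5),(3,5),(3,6),(1,7),(2,7))).toDF("A","B")
df.union(df.select("B","A"))              // add every edge in the reverse direction
  .groupByKey(r => r.getInt(0))           // collect the neighbours of each node
  .flatMapGroups({
    // sort the distinct neighbours so each pair is emitted as (smaller, larger)
    (K,Vs) => Vs.map(_.getInt(1)).toArray.distinct.sorted.combinations(2).map(a => (a(0), a(1)))
  })
  .dropDuplicates
  .show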

This is in Scala, not Python, but it should be easy to convert. I included some extra data points to show why dropDuplicates is necessary. I followed the steps I described in a comment above:
1) Append the original dataframe to itself, but with B and A switched.
2) Group by A.
3) Flat-map each group to all pairwise combinations (I think there are Scala functions for this).
4) Map the new column to separate C and D columns (I didn't actually do this).
5) Filter duplicates, if required.
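
For readers working in PySpark, the five steps above might translate roughly as follows. This is a sketch rather than a tested port of the Scala code: it drops to the RDD API for the grouping, and the helper names `pairs_in_group` and `triadic_pairs` are invented for illustration. It assumes `df` is a PySpark DataFrame with integer columns A and B.

```python
from itertools import combinations


def pairs_in_group(values):
    """All (smaller, larger) pairs among the distinct values of one group (step 3)."""
    return list(combinations(sorted(set(values)), 2))


def triadic_pairs(df):
    """Apply steps 1-5 to a PySpark DataFrame with integer columns A and B."""
    # Step 1: append the frame to itself with B and A switched.
    # union() matches columns by position, so the swapped copy lines up as (A, B).
    both_ways = df.select("A", "B").union(df.select("B", "A"))
    return (
        both_ways.rdd
        .map(lambda row: (row[0], row[1]))          # key each edge by its first node
        .groupByKey()                               # step 2: group by A
        .flatMap(lambda kv: pairs_in_group(kv[1]))  # step 3: pairwise combinations
        .distinct()                                 # step 5: filter duplicates
        .toDF(["C", "D"])                           # step 4: separate C and D columns
    )
```

With the `question` dataframe from the question, `triadic_pairs(question).show()` would be the call to try. Because every pair is emitted in sorted order, `distinct()` can collapse a pair that is reached through two different shared neighbours, which is the case the extra data points in the Scala example are meant to exercise.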

Upvotes: 0

mayank agrawal

Reputation: 2545

Try this,

import pyspark.sql.functions as F
from pyspark.sql.types import *
from itertools import combinations

df = spark.createDataFrame([{'A':1,'B':5},{'A':2,'B':5},
                         {'A':3,'B':5},{'A':3,'B':6}])

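def pairs(list_):
    # Deduplicate and sort so each pair comes out in a stable (smaller, larger) order.
    unique = sorted(set(list_))
    if len(unique) > 1:
        return [[int(x), int(y)] for x, y in combinations(unique, r=2)]
    else:
        # Groups with a single distinct value produce no pairs.
        return None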

triadic_udf = F.udf(pairs, ArrayType(ArrayType(IntegerType())))
cols = ['C','D']
splits = [F.udf(lambda val:val[0],IntegerType())\
         ,F.udf(lambda val:val[1],IntegerType())]

df1 = df.groupby('B').agg(F.collect_list('A').alias('A'))\
                 .withColumn('pairs',F.explode(triadic_udf(F.col('A'))))\
                 .dropna().select('pairs')

df2 = df.groupby('A').agg(F.collect_list('B').alias('B'))\
                 .withColumn('pairs',F.explode(triadic_udf(F.col('B'))))\
                 .dropna().select('pairs')

solution = df1.union(df2).select([s('pairs').alias(c) for s,c in zip(splits,cols)])

solution.show()

Upvotes: 1
