Taylrl

Reputation: 3919

Joining 2 columns based on values in another using pyspark

I have a list of values in a column in a DataFrame which I want to use to filter another larger DataFrame that has 2 columns to match based upon.

Here is an example.

df1 = sqlContext.createDataFrame(
     [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")],
     ("ID", "label1"))

df2 = sqlContext.createDataFrame(
    [
        (1, 2, "x"),
        (2, 1, "y"),
        (3, 1, "z"),
        (4, 6, "s"),
        (7, 2, "t"),
        (8, 9, "z")
    ],
    ("ID1", "ID2", "label2")
)

What I would like to get in the end is a DataFrame containing the entries of df2 where both ID1 and ID2 appear in df1. For this example, that would look like this:

+---+---+------+
|ID1|ID2|label2|
+---+---+------+
|  1|  2|     x|
|  2|  1|     y|
|  3|  1|     z|
+---+---+------+

I have tried doing this via a join, as follows:

df = df1.join(df2, (df1.ID == df2.ID1) | (df1.ID == df2.ID2))

but this explodes my table and gives me:

+---+------+---+---+------+
| ID|label1|ID1|ID2|label2|
+---+------+---+---+------+
|  1|     a|  1|  2|     x|
|  1|     a|  2|  1|     y|
|  1|     a|  3|  1|     z|
|  2|     b|  1|  2|     x|
|  2|     b|  2|  1|     y|
|  2|     b|  7|  2|     t|
|  3|     c|  3|  1|     z|
|  4|     d|  4|  6|     s|
+---+------+---+---+------+

I then tried:

df = df1.join(df2, (df1.ID == df2.ID1) & (df1.ID == df2.ID2))

but that obviously isn't what I want either. Any help, folks?

Upvotes: 1

Views: 3781

Answers (3)

Ali Yesilli

Reputation: 2200

You can use intersect after filtering the data separately. Here is a solution using the core Spark API:

>>> df1.show()
+---+------+
| ID|label1|
+---+------+
|  1|     a|
|  2|     b|
|  3|     c|
|  4|     d|
|  5|     e|
+---+------+

>>> df2.show()
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
|  1|  2|     x|
|  2|  1|     y|
|  3|  1|     z|
|  4|  6|     s|
|  7|  2|     t|
|  8|  9|     z|
+---+---+------+

>>> df3 = df1.join(df2, (df1.ID == df2.ID1)).select(df2['*'])
>>> df4 = df1.join(df2, (df1.ID == df2.ID2)).select(df2['*'])
>>> df3.intersect(df4).show()
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
|  2|  1|     y|
|  3|  1|     z|
|  1|  2|     x|
+---+---+------+
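
The same result can also be obtained with two left semi joins instead of intersect; this is just a sketch of an equivalent alternative, not part of the original answer:

>>> # keep rows of df2 whose ID1 is in df1.ID, then of those keep rows whose ID2 is in df1.ID
>>> df5 = df2.join(df1, df2.ID1 == df1.ID, "leftsemi").join(df1, df2.ID2 == df1.ID, "leftsemi")
>>> df5.show()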

Upvotes: 0

pault

Reputation: 43494

Here is another approach, using Spark SQL:

First register your DataFrames as tables:

df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')

Next, run the following query:

df = sqlContext.sql(
    "SELECT * FROM df2 WHERE ID1 IN (SELECT ID FROM df1) AND ID2 IN (SELECT ID FROM df1)"
)
df.show()
#+---+---+------+
#|ID1|ID2|label2|
#+---+---+------+
#|  3|  1|     z|
#|  2|  1|     y|
#|  1|  2|     x|
#+---+---+------+
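
If you prefer to stay in the DataFrame API, roughly the same filter can be written with isin after collecting df1's IDs to the driver; this is only a sketch and assumes df1 is small enough to collect:

ids = [row.ID for row in df1.select("ID").collect()]   # pull the list of valid IDs to the driver
df = df2.filter(df2.ID1.isin(ids) & df2.ID2.isin(ids))  # keep rows where both IDs are in the list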

Upvotes: 1

gaw

Reputation: 1960

I think you can use your initial join statement and then group the resulting DataFrame, keeping only the rows that occur twice. Since both ID1 and ID2 have to be present in df1, the join duplicates each qualifying row of df2: once for the match on ID1 and once for the match on ID2.

The resulting statement looks like:

from pyspark.sql.functions import col

df2.join(
    df1,
    [(df1.ID == df2.ID1) | (df1.ID == df2.ID2)],
    how="left"
).groupBy("ID1", "ID2", "label2").count().filter(col("count") == 2).show()

The result is:

+---+---+------+-----+
|ID1|ID2|label2|count|
+---+---+------+-----+
|  2|  1|     y|    2|
|  3|  1|     z|    2|
|  1|  2|     x|    2|
+---+---+------+-----+

If you don't want the count column, you can append a select("ID1","ID2","label2") to the statement.
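
For example, the full statement with the select appended would read:

df2.join(
    df1,
    [(df1.ID == df2.ID1) | (df1.ID == df2.ID2)],
    how="left"
).groupBy("ID1", "ID2", "label2").count() \
 .filter(col("count") == 2) \
 .select("ID1", "ID2", "label2") \
 .show()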

Upvotes: 3
