Reputation: 3919
I have a list of values in a column of one DataFrame that I want to use to filter another, larger DataFrame that has two columns to match against.
Here is an example.
df1 = sqlContext.createDataFrame(
    [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")],
    ("ID", "label1"))
df2 = sqlContext.createDataFrame(
    [
        (1, 2, "x"),
        (2, 1, "y"),
        (3, 1, "z"),
        (4, 6, "s"),
        (7, 2, "t"),
        (8, 9, "z")
    ],
    ("ID1", "ID2", "label2")
)
What I would like to get in the end is a DataFrame with the entries from df2 where both ID1 and ID2 appear in df1. For this example, that would look like this:
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
| 1| 2| x|
| 2| 1| y|
| 3| 1| z|
+---+---+------+
I have tried doing this via a join, as follows:
df = df1.join(df2, (df1.ID == df2.ID1) | (df1.ID == df2.ID2))
but this explodes my table and gives me
+---+------+---+---+------+
| ID|label1|ID1|ID2|label2|
+---+------+---+---+------+
| 1| a| 1| 2| x|
| 1| a| 2| 1| y|
| 1| a| 3| 1| z|
| 2| b| 1| 2| x|
| 2| b| 2| 1| y|
| 2| b| 7| 2| t|
| 3| c| 3| 1| z|
| 4| d| 4| 6| s|
+---+------+---+---+------+
Then,
df = df1.join(df2, (df1.ID == df2.ID1) & (df1.ID == df2.ID2))
obviously isn't what I want either. Any help, folks?
Upvotes: 1
Views: 3781
Reputation: 2200
You can use intersect after filtering the data separately. Here is a solution using the core Spark API:
>>> df1.show()
+---+------+
| ID|label1|
+---+------+
| 1| a|
| 2| b|
| 3| c|
| 4| d|
| 5| e|
+---+------+
>>> df2.show()
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
| 1| 2| x|
| 2| 1| y|
| 3| 1| z|
| 4| 6| s|
| 7| 2| t|
| 8| 9| z|
+---+---+------+
>>> df3 = df1.join(df2, (df1.ID == df2.ID1)).select(df2['*'])
>>> df4 = df1.join(df2, (df1.ID == df2.ID2)).select(df2['*'])
>>> df3.intersect(df4).show()
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
| 2| 1| y|
| 3| 1| z|
| 1| 2| x|
+---+---+------+
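If you would rather avoid the intersect, roughly the same result can be obtained with two left-semi joins. This is only a sketch against the df1/df2 defined in the question, not the answer's original code:
# Keep rows of df2 whose ID1 appears in df1, without pulling in df1's columns.
step1 = df2.join(df1, df2.ID1 == df1.ID, "leftsemi")
# From those, keep only the rows whose ID2 also appears in df1.
result = step1.join(df1, step1.ID2 == df1.ID, "leftsemi")
result.show()
A left-semi join acts as a filter on the left side, so no columns from df1 are added and no deduplication is needed afterwards.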
Upvotes: 0
Reputation: 43494
This is another approach, using Spark SQL:
First register your DataFrames as tables:
df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')
Next, run the following query:
df = sqlContext.sql(
    "SELECT * FROM df2 WHERE ID1 IN (SELECT ID FROM df1) AND ID2 IN (SELECT ID FROM df1)"
)
df.show()
#+---+---+------+
#|ID1|ID2|label2|
#+---+---+------+
#| 3| 1| z|
#| 2| 1| y|
#| 1| 2| x|
#+---+---+------+
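For comparison, roughly the same filter can be written with the DataFrame API by collecting the IDs from df1 into a Python list first. This is only a sketch and assumes df1 is small enough to collect to the driver:
from pyspark.sql.functions import col

# Pull the ID values of df1 to the driver (only reasonable for a small df1).
ids = [row.ID for row in df1.select("ID").collect()]

# Keep rows of df2 where both ID1 and ID2 are in that list.
df = df2.filter(col("ID1").isin(ids) & col("ID2").isin(ids))
df.show()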
Upvotes: 1
Reputation: 1960
I think you can use your initial join statement, then group the DataFrame and keep the rows that occur twice, since both ID1 and ID2 should be present in df1. Those rows appear twice in the joined result because the join duplicates each row of df2 once for each of its two ID values that match in df1.
The resulting statement looks like:
from pyspark.sql.functions import col
df2.join(
    df1,
    [(df1.ID == df2.ID1) | (df1.ID == df2.ID2)],
    how="left"
).groupBy("ID1", "ID2", "label2").count().filter(col("count") == 2).show()
The result is:
+---+---+------+-----+
|ID1|ID2|label2|count|
+---+---+------+-----+
|  2|  1|     y|    2|
|  3|  1|     z|    2|
|  1|  2|     x|    2|
+---+---+------+-----+
If you don't like the count column, you can append a select("ID1", "ID2", "label2") to the statement.
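For example, the full statement with the select appended would look something like this (same logic as above, just reformatted):
from pyspark.sql.functions import col

# Join, count how many IDs of each df2 row matched, keep rows matching twice,
# then drop the count column from the output.
(df2.join(df1, (df1.ID == df2.ID1) | (df1.ID == df2.ID2), how="left")
    .groupBy("ID1", "ID2", "label2")
    .count()
    .filter(col("count") == 2)
    .select("ID1", "ID2", "label2")
    .show())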
Upvotes: 3