craigkrais

Reputation: 23

Pyspark: How to filter on list of two column value pairs?

So I have a PySpark DataFrame that I want to filter with a (long) list of valid value pairs for two of its columns.

Say our DataFrame is named df and the columns are col1 and col2:

col1   col2
1      A
2      B
3      1
null   2
A      null
2      null
1      null
B      C

and I have the valid pair list as: flist=[(1,A), (null,2), (1,null)]

When I try it with the .isin() function (as below), it tells me that .isin() does not work on tuples.

df.filter((df["col1"],df["col2"]).isin(flist))

There are workarounds that concatenate the two columns into a single string or write out a boolean expression for each pair, but my list of valid pairs is long (hard to turn into a boolean expression), and concatenation is not reliable because of the nulls. The plain Python check (df['col1'], df['col2']) in flist does not work either.
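
For illustration, a minimal sketch (assuming an active SparkSession named spark) of how both concat and concat_ws lose the distinction between pairs once nulls are involved:

import pyspark.sql.functions as F

demo = spark.createDataFrame([("1", None), (None, "1")], ["col1", "col2"])

# concat() returns null as soon as any input is null, so both rows collapse to null
demo.select(F.concat("col1", "col2").alias("key")).show()

# concat_ws() skips nulls, so ("1", null) and (null, "1") both become "1" and collide
demo.select(F.concat_ws("-", "col1", "col2").alias("key")).show()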

Is there a Pythonic/PySparkic way to do this?

Upvotes: 2

Views: 4292

Answers (3)

blackbishop

Reputation: 32660

You can create a filter_df from the list and do a join:

flist = [("1", "A"), (None, "2"), ("1", None)]
filter_df = spark.createDataFrame(flist, ["col1", "col2"])

df1 = df.join(filter_df, ["col1", "col2"])

df1.show()
#+----+----+
#|col1|col2|
#+----+----+
#|   1|   A|
#+----+----+

Note that null values never compare equal, so only the row matching the tuple ("1", "A") is returned here. To check for nulls, you need to use isNull() on the column:

df1 = df.alias("df").join(
    filter_df.alias("fdf"),
    ((F.col("df.col1") == F.col("fdf.col1")) |
     (col("df.col1").isNull() & F.col("fdf.col1").isNull())
     ) &
    ((F.col("df.col2") == F.col("fdf.col2")) |
     (col("df.col2").isNull() & F.col("fdf.col2").isNull())
     )
).select("df.*")

df1.show()

#+----+----+
#|col1|col2|
#+----+----+
#|   1|   A|
#|null|   2|
#|   1|null|
#+----+----+

Or, better, use eqNullSafe as suggested in @Chris's answer.
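
For completeness, a minimal sketch of that variant with the same df and filter_df as above, using eqNullSafe in place of the explicit isNull() checks:

# eqNullSafe treats two nulls as equal, so the null-containing pairs also match
df1 = df.join(
    filter_df,
    df["col1"].eqNullSafe(filter_df["col1"]) &
    df["col2"].eqNullSafe(filter_df["col2"])
).select(df["col1"], df["col2"])

df1.show()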

Upvotes: 2

Chris

Reputation: 1455

Building on @blackbishop's approach of creating a DataFrame from your filter criteria and joining, you can use the Column.eqNullSafe method to compare null values safely:

df = spark.createDataFrame(
    [('1', 'A', 1),
     ('2', 'B', 2),
     ('3', '1', 3),
     (None, '2', 4),
     ('A', None, 5),
     ('2', None, 6),
     ('1', None, 7),
     ('B', 'C', 8)], schema=['col1', 'col2', 'col3'])

flist = [("1", "A"), (None, "2"), ("1", None)]
filter_df = spark.createDataFrame(flist, ["col1", "col2"])

(df.join(filter_df,
         df["col1"].eqNullSafe(filter_df["col1"]) &
         df["col2"].eqNullSafe(filter_df['col2']))
 .select(df['col1'], df['col2'], df['col3'])
 .show())

Gives:

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|null|   7|
|null|   2|   4|
|   1|   A|   1|
+----+----+----+

Note that the join only acts like a filter provided your 'filter' DataFrame contains unique rows. You could add a distinct on that DataFrame before the join to be sure (if your filter criteria were large, for example), as sketched below.
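
A quick sketch of that safeguard, reusing filter_df from above:

# distinct() drops any duplicate pairs so the join cannot multiply rows
deduped_filter = filter_df.distinct()

(df.join(deduped_filter,
         df["col1"].eqNullSafe(deduped_filter["col1"]) &
         df["col2"].eqNullSafe(deduped_filter["col2"]))
 .select(df["col1"], df["col2"], df["col3"])
 .show())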

Upvotes: 1

mck

Reputation: 42352

Here's a way without joining: chain one condition per pair in flist and combine them with OR inside the filter. It handles nulls as well.

from functools import reduce
import pyspark.sql.functions as F

flist = [(1, 'A'), (None, 2), (1, None)]

df2 = df.filter(
    # OR together one condition per valid pair; use isNull() when the pair value is None
    reduce(
        lambda x, y: x | y,
        [
            ((F.col('col1') == col1) if col1 is not None else F.col('col1').isNull()) &
            ((F.col('col2') == col2) if col2 is not None else F.col('col2').isNull())
            for (col1, col2) in flist
        ]
    )
)

df2.show()
+----+----+
|col1|col2|
+----+----+
|   1|   A|
|null|   2|
|   1|null|
+----+----+

Upvotes: 1
