yatu

Reputation: 88236

Keep groups where at least one element satisfies condition in pyspark

I've been trying to reproduce in pyspark something that is fairly easy to do in Pandas, but I've been struggling for a while now. Say I have the following dataframe:

import pandas as pd

df = pd.DataFrame({'a':[1,2,2,1,1,2], 'b':[12,5,1,19,2,7]})
print(df)
   a   b
0  1  12
1  2   5
2  2   1
3  1  19
4  1   2
5  2   7

And the list

l = [5,1]

What I'm trying to do is group by a, and if any of the elements in b are in the list, return True for all rows in the group. We could then use the result to index the dataframe. The Pandas equivalent would be:

df[df.b.isin(l).groupby(df.a).transform('any')]

   a  b
1  2  5
2  2  1
5  2  7

Reproducible dataframe in pyspark:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = pd.DataFrame({'a':[1,2,2,1,1,2], 'b':[12,5,1,19,2,7]})
sparkdf = spark.createDataFrame(df)

I'm currently going in the direction of grouping by a and applying a pandas UDF, though there's surely a better way to do this using Spark only.
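For reference, a minimal sketch of that pandas-UDF direction might look like the following (this assumes Spark 3.x, where GroupedData.applyInPandas is available; the function name keep_if_any and the schema string are illustrative, not from the original question):

import pandas as pd

l = [5, 1]

def keep_if_any(pdf: pd.DataFrame) -> pd.DataFrame:
    # Return the whole group if any value of b is in l, otherwise an empty frame
    return pdf if pdf['b'].isin(l).any() else pdf.iloc[0:0]

sparkdf.groupby('a').applyInPandas(keep_if_any, schema='a long, b long').show()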

Upvotes: 2

Views: 747

Answers (1)

yatu

Reputation: 88236

I've figured out a simple enough solution. The first step is to keep the rows where the values in b are in the list, using isin with filter, and then select the distinct grouping keys (a).

Then, by joining back with the original dataframe on a, we keep only the groups that contain at least one value from the list:

from pyspark.sql import functions as f

unique_a = (sparkdf.filter(f.col('b').isin(l))
                   .select('a').distinct())
sparkdf.join(unique_a, 'a').show()

+---+---+
|  a|  b|
+---+---+
|  2|  5|
|  2|  1|
|  2|  7|
+---+---+
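As an aside (not part of the original answer), the same result can be obtained with a left-semi join, which by construction keeps only the columns of sparkdf and makes the "filter by membership in unique_a" intent explicit:

# Semi join: keep rows of sparkdf whose key a appears in unique_a,
# without adding any columns from unique_a to the result.
sparkdf.join(unique_a, 'a', 'left_semi').show()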

Upvotes: 3
