frank

Reputation: 3608

filter a list in pyspark dataframe

I have a list of sentences in a PySpark (v2.4.5) dataframe, with a matching set of scores. Both the sentences and the scores are stored as array columns.

df=spark.createDataFrame(
    [
        (1, ['foo1','foo2','foo3'],[0.1,0.5,0.6]), # create your data here, be consistent in the types.
        (2, ['bar1','bar2','bar3'],[0.5,0.7,0.7]),
        (3, ['baz1','baz2','baz3'],[0.1,0.2,0.3]),
    ],
    ['id', 'txt','score'] # add your columns label here
)
df.show()
+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|[foo1, foo2, foo3]|[0.1, 0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
|  3|[baz1, baz2, baz3]|[0.1, 0.2, 0.3]|
+---+------------------+---------------+

I want to keep only the sentences that have a score >= 0.5, dropping rows where nothing remains:

+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|      [foo2, foo3]|     [0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+

Any suggestions?

I tried the approach from pyspark dataframe filter or include based on list, but was not able to get it working in my case.

Upvotes: 3

Views: 3313

Answers (4)

anky

Reputation: 75140

With Spark 2.4+ you have access to higher-order functions, so you can filter a zipped array with a condition and then drop the rows left with empty arrays:

import pyspark.sql.functions as F

e = F.expr('filter(arrays_zip(txt,score),x-> x.score>=0.5)')
df.withColumn("txt",e.txt).withColumn("score",e.score).filter(F.size(e)>0).show()

+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|      [foo2, foo3]|     [0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
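The per-row zip-and-filter logic that `arrays_zip` + `filter` express can be sketched in plain Python (hypothetical helper, for illustration only):

```python
def filter_row(txt, score, threshold=0.5):
    # zip sentences with scores, keep pairs that pass the threshold, unzip
    kept = [(t, s) for t, s in zip(txt, score) if s >= threshold]
    new_txt = [t for t, _ in kept]
    new_score = [s for _, s in kept]
    return new_txt, new_score

print(filter_row(['foo1', 'foo2', 'foo3'], [0.1, 0.5, 0.6]))
# → (['foo2', 'foo3'], [0.5, 0.6])
```

A row whose scores all fail the threshold comes back as `([], [])`, which is what the `F.size(e) > 0` filter above removes.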

Upvotes: 4

Shubham Jain

Reputation: 5536

In Spark, user-defined functions are a black box to the Catalyst optimizer: it cannot optimize the code inside a UDF, so avoid UDFs where possible.

Here is an example without UDFs:

import pyspark.sql.functions as f

df.withColumn('combined', f.explode(f.arrays_zip('txt', 'score'))) \
  .filter(f.col('combined.score') >= 0.5) \
  .groupby('id') \
  .agg(f.collect_list('combined.txt').alias('txt'),
       f.collect_list('combined.score').alias('score')) \
  .show()

+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|      [foo2, foo3]|     [0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
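The explode-filter-regroup pattern above can be sketched in plain Python (hypothetical data, for illustration): each array element becomes its own record, failing records are dropped, and the survivors are collected back per id.

```python
from collections import defaultdict

rows = [
    (1, ['foo1', 'foo2', 'foo3'], [0.1, 0.5, 0.6]),
    (2, ['bar1', 'bar2', 'bar3'], [0.5, 0.7, 0.7]),
    (3, ['baz1', 'baz2', 'baz3'], [0.1, 0.2, 0.3]),
]

# "explode": one (id, txt, score) record per array element
exploded = [(i, t, s) for i, txts, scores in rows
            for t, s in zip(txts, scores)]

# filter, then "collect_list" back by id
grouped = defaultdict(lambda: ([], []))
for i, t, s in exploded:
    if s >= 0.5:
        grouped[i][0].append(t)
        grouped[i][1].append(s)

print(dict(grouped))
# → {1: (['foo2', 'foo3'], [0.5, 0.6]), 2: (['bar1', 'bar2', 'bar3'], [0.5, 0.7, 0.7])}
```

Note that id 3 simply never appears in the result, which is why no separate empty-array filter is needed with this approach.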

Hope it works.

Upvotes: -1

pissall

Reputation: 7419

Try this; I couldn't think of a way to do it without UDFs:

import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, BooleanType, StringType

# UDF building a boolean index over the score array
filter_udf = udf(lambda arr: [x >= 0.5 for x in arr], ArrayType(BooleanType()))

# UDF filtering an array by the boolean index
# (declared as ArrayType(StringType()), so the filtered score column comes
# back as strings; use a DoubleType variant if you need numeric scores)
filter_udf_bool = udf(lambda col_arr, bool_arr: [x for (x, y) in zip(col_arr, bool_arr) if y],
                      ArrayType(StringType()))

df2 = df.withColumn("test", filter_udf("score"))
df3 = df2.withColumn("txt", filter_udf_bool("txt", "test")) \
         .withColumn("score", filter_udf_bool("score", "test"))

Output:

# Further filtering for empty arrays:
df3.drop("test").filter(F.size(F.col("txt")) > 0).show()

+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|      [foo2, foo3]|     [0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+

You can also generalize this by combining it all into one UDF; I've split it up for simplicity's sake.

Upvotes: 1

Krish.Venkat

Reputation: 54

The score column is an array type, so each row needs to be checked with a predicate. Note that this keeps or drops whole rows; it does not remove individual elements from the arrays, as the output below shows.

Code snippet to filter the array column:

def score_filter(row):
    # keep the whole row if at least one score passes the threshold
    score_filtered = [s for s in row.score if s >= 0.5]
    if len(score_filtered) > 0:
        yield row


filtered = df.rdd.flatMap(score_filter).toDF()

filtered.show()

Output:

+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|[foo1, foo2, foo3]|[0.1, 0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+
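The row-level predicate this answer applies can be sketched in plain Python (hypothetical data, for illustration); the arrays themselves stay intact, which is why the output above still shows foo1 with its 0.1 score:

```python
rows = [
    (1, ['foo1', 'foo2', 'foo3'], [0.1, 0.5, 0.6]),
    (2, ['bar1', 'bar2', 'bar3'], [0.5, 0.7, 0.7]),
    (3, ['baz1', 'baz2', 'baz3'], [0.1, 0.2, 0.3]),
]

# keep a row if any of its scores passes the threshold
kept = [row for row in rows if any(s >= 0.5 for s in row[2])]

print([row[0] for row in kept])  # → [1, 2]
```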

Upvotes: -2
