atalantara

Reputation: 23

how to select first n row items based on multiple conditions in pyspark

Now I have data like this:

+----+----+
|col1|   d|
+----+----+
|   A|   4|
|   A|  10|
|   A|   3|
|   B|   3|
|   B|   6|
|   B|   4|
|   B| 5.5|
|   B|  13|
+----+----+

col1 is StringType and d is actually TimestampType; here I use DoubleType instead for simplicity. I want to select rows based on a list of condition tuples. Given the tuples [(A,3.5),(A,8),(B,3.5),(B,10)], I want the result to look like

+----+---+
|col1|  d|
+----+---+
|   A|  4|
|   A| 10|
|   B|  4|
|   B| 13|
+----+---+

That is, for each tuple we select from the PySpark dataframe the first row where d is larger than the tuple's number and col1 equals the tuple's string. What I've already written is:

df_res = spark_empty_dataframe  # an empty dataframe with the same schema as df
for (x, y) in tuples:
    # first row matching this tuple's string whose d exceeds the tuple's number
    dft = df.filter(df.col1 == x).filter(df.d > y).limit(1)
    df_res = df_res.union(dft)

But I think this might have an efficiency problem; I don't know if that's right.
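
In case it helps with reproducing this, here is a minimal sketch of how the sample data and condition tuples could be built (assuming a SparkSession named spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample data from the table above, with d as DoubleType
df = spark.createDataFrame(
    [("A", 4.0), ("A", 10.0), ("A", 3.0),
     ("B", 3.0), ("B", 6.0), ("B", 4.0), ("B", 5.5), ("B", 13.0)],
    ("col1", "d"),
)
tuples = [("A", 3.5), ("A", 8), ("B", 3.5), ("B", 10)]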

Upvotes: 2

Views: 508

Answers (1)

anky

Reputation: 75100

A possible approach that avoids loops is to create a dataframe from the tuples you have as input:

t = [('A', 3.5), ('A', 8), ('B', 3.5), ('B', 10)]
ref = spark.createDataFrame([(i[0], float(i[1])) for i in t], ("col1_y", "d_y"))
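
For reference, ref then holds one row per condition tuple; its expected contents (inferred from the input above) would be:

ref.show()

+------+----+
|col1_y| d_y|
+------+----+
|     A| 3.5|
|     A| 8.0|
|     B| 3.5|
|     B|10.0|
+------+----+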

Then we can join the input dataframe (df) to this reference on the condition, group by the tuple keys and values (which will be repeated) to get the first value in each group, and finally drop the extra columns:

import pyspark.sql.functions as F

(df.join(ref, (df.col1 == ref.col1_y) & (df.d > ref.d_y), how='inner')
   .orderBy("col1", "d")
   .groupBy("col1_y", "d_y")
   .agg(F.first("col1").alias("col1"), F.first("d").alias("d"))
   .drop("col1_y", "d_y")).show()

+----+----+
|col1|   d|
+----+----+
|   A|10.0|
|   A| 4.0|
|   B| 4.0|
|   B|13.0|
+----+----+

Note: if the order of the dataframe is important, you can assign an index column with monotonically_increasing_id, include it in the aggregation, and then orderBy that index column.
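
A sketch of that idea, assuming the same df and ref as above (the idx column name is just illustrative):

# assign a per-row index before joining, carry it through the aggregation,
# then restore the original order and drop the helper columns
df2 = df.withColumn("idx", F.monotonically_increasing_id())

(df2.join(ref, (df2.col1 == ref.col1_y) & (df2.d > ref.d_y), how='inner')
    .orderBy("col1", "d")
    .groupBy("col1_y", "d_y")
    .agg(F.first("col1").alias("col1"), F.first("d").alias("d"),
         F.first("idx").alias("idx"))
    .orderBy("idx")
    .drop("col1_y", "d_y", "idx")).show()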

EDIT: another way, instead of ordering and taking the first row, is to get the values directly with min:

(df.join(ref, (df.col1 == ref.col1_y) & (df.d > ref.d_y), how='inner')
   .groupBy("col1_y", "d_y")
   .agg(F.min("col1").alias("col1"), F.min("d").alias("d"))
   .drop("col1_y", "d_y")).show()

+----+----+
|col1|   d|
+----+----+
|   B| 4.0|
|   B|13.0|
|   A| 4.0|
|   A|10.0|
+----+----+

Upvotes: 2
