Marco

Reputation: 1235

PySpark: drop duplicates and keep the row with the highest value in a column

I have the following Spark dataset:

id    col1    col2    col3    col4
1      1        5       2      3
1      1        0       2      3
2      3        1       7      7
3      6        1       3      3
3      6        5       3      3

I would like to drop duplicates based on the column subset ['id','col1','col3','col4'] and, from each group of duplicates, keep the row with the highest value in col2. This is what the result should look like:

id    col1    col2    col3    col4
1      1        5       2      3
2      3        1       7      7
3      6        5       3      3
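To pin down the intended semantics independently of Spark, here is a small plain-Python sketch (an illustration only, not a Spark solution): rows that share (id, col1, col3, col4) count as duplicates, and the row with the largest col2 is kept.

```python
# Plain-Python sketch of the requested semantics (no Spark involved):
# rows sharing (id, col1, col3, col4) are duplicates; keep the one
# with the largest col2.
rows = [
    # (id, col1, col2, col3, col4)
    (1, 1, 5, 2, 3),
    (1, 1, 0, 2, 3),
    (2, 3, 1, 7, 7),
    (3, 6, 1, 3, 3),
    (3, 6, 5, 3, 3),
]

best = {}
for row in rows:
    rid, c1, c2, c3, c4 = row
    key = (rid, c1, c3, c4)
    # Store the row if it is the first one for this key
    # or if its col2 beats the col2 of the stored row.
    if key not in best or c2 > best[key][2]:
        best[key] = row

result = sorted(best.values())
print(result)
```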

How can I do that in PySpark?

Upvotes: 1

Views: 3296

Answers (2)

karatekraft

Reputation: 185

If you are more comfortable with SQL syntax than with the PySpark DataFrame APIs, you can use this approach:

Create the DataFrame (optional, since you already have the data):

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

# reuse the active Spark session (or start a local one)
spark = SparkSession.builder.getOrCreate()

data = [
  (1,      1,        5,       2,      3),
  (1,      1,        0,       2,      3),
  (2,      3,        1,       7,      7),
  (3,      6,        1,       3,      3),
  (3,      6,        5,       3,      3),
]

schema = StructType([
    StructField("id", IntegerType()),
    StructField("col1", IntegerType()),
    StructField("col2", IntegerType()),
    StructField("col3", IntegerType()),
    StructField("col4", IntegerType()),
])

df = spark.createDataFrame(data=data, schema=schema)
df.show()

Then register the DataFrame as a temporary view so it can be queried with SQL. The code below creates (or replaces) a temporary view named "tbl".

# create view from df called "tbl"
df.createOrReplaceTempView("tbl")

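Finally, query the view. Grouping by id, col1, col3 and col4 collapses each set of duplicates into a single row, and max(col2) keeps the highest col2 value in each group. This works because every column other than col2 is part of the grouping key.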

# query to group by id,col1,col3,col4 and select max col2
my_query = """
select 
  id, col1, max(col2) as col2, col3, col4
from tbl
group by id, col1, col3, col4
"""

new_df = spark.sql(my_query)
new_df.show()

Final output:

+---+----+----+----+----+
| id|col1|col2|col3|col4|
+---+----+----+----+----+
|  1|   1|   5|   2|   3|
|  2|   3|   1|   7|   7|
|  3|   6|   5|   3|   3|
+---+----+----+----+----+

Upvotes: 1

wwnde

Reputation: 26676

Another way: compute the maximum of col2 within each group using a window function, then keep only the rows where col2 equals that maximum. Unlike a plain dropDuplicates, this keeps every row that ties for the maximum.

from pyspark.sql import Window, functions as F
df.withColumn('max', F.max('col2').over(Window.partitionBy('id', 'col1', 'col3', 'col4'))).where(F.col('col2') == F.col('max')).drop('max').show()

Upvotes: 3
