Reputation: 886
I have a pyspark dataframe which contains duplicates of certain column values, like this:
showing deptDF
+--------+----+------------+----------+--------+-----------+
|quantity|cost|participants|activity |category|id |
+--------+----+------------+----------+--------+-----------+
|4 |10 |2 |skiing |outdoor |8589934592 |
|4 |13 |3 |golf |indoor |17179869184|
|4 |10 |5 |swimming |outdoor |25769803776|
|4 |10 |3 |basketball|indoor |34359738368|
|4 |11 |7 |pool |indoor |42949672960|
|4 |11 |12 |pool |outdoor |51539607552|
|4 |13 |15 |golf |indoor |60129542144|
+--------+----+------------+----------+--------+-----------+
I need to identify the duplicate quantity-category-activity combinations, take the row in each duplicate pair that has the lower number of participants, and set that row's cost to 0. My original strategy was to add an index column and then use the PySpark window functionality to create a new dataframe containing only the duplicated rows (and their duplicates, so that this dataframe would consist of pairs sharing the same quantity-category-activity combination). I then thought I could convert that dataframe into a Python list and iterate over it; whenever I came across a row with a lower number of participants than the previous row, I would identify that row in the original dataframe by its id and set its cost to 0.
My code so far:
from pyspark.sql import functions as f
from pyspark.sql.window import Window

# partition on the duplicate key: quantity-category-activity
w = Window.partitionBy('quantity', 'category', 'activity')
# keep only rows whose key occurs more than once
deptDF_duplicates = deptDF.select('*', f.count('quantity').over(w).alias('dupeCount'))\
    .where('dupeCount > 1')\
    .drop('dupeCount')
deptDF_duplicates.show()
duplicates_list = [list(row) for row in deptDF_duplicates.select('id', 'cost').collect()]
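The iteration step would look roughly like the sketch below (just an outline of the plan, not working code I'm committed to; the sort order and the ids_to_zero name are only for illustration):

# Rough sketch of the planned driver-side iteration; the sort order and the
# ids_to_zero name are only for illustration.
rows = (deptDF_duplicates
        .select('quantity', 'category', 'activity', 'participants', 'id')
        .orderBy('quantity', 'category', 'activity', 'participants')
        .collect())

ids_to_zero = []
prev = None
for row in rows:
    key = (row['quantity'], row['category'], row['activity'])
    if prev is not None and prev[0] == key and row['participants'] > prev[1]['participants']:
        # the previous row in this group has fewer participants, so zero its cost
        ids_to_zero.append(prev[1]['id'])
    prev = (key, row)

deptDF = deptDF.withColumn(
    'cost',
    f.when(f.col('id').isin(ids_to_zero), f.lit(0)).otherwise(f.col('cost'))
)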
Before I go down this road I wanted to check whether there is a way to do this more efficiently with dataframe operations, because depending on the size of my data, iterating over a plain Python list on the driver is probably much too slow for the job. I haven't found an example of this anywhere in the PySpark documentation or the tutorials I have looked at.
Any help would be much appreciated.
Upvotes: 0
Views: 58
Reputation: 2043
I don't think you have to use any iteration in this case; a single window function can achieve your goal:
df = spark.createDataFrame(
    [
        (4, 10, 1, 1),
        (4, 10, 1, 2)
    ],
    schema=['a', 'b', 'c', 'participants']
)
df.show(10, False)
+---+---+---+------------+
|a |b |c |participants|
+---+---+---+------------+
|4 |10 |1 |1 |
|4 |10 |1 |2 |
+---+---+---+------------+
Just check whether the participants value is the max in its group or not:
from pyspark.sql import functions as func
from pyspark.sql.window import Window

# keep b only for the row with the max participants in each a-b-c group
df.withColumn(
    'b',
    func.when(
        func.col('participants') == func.max(func.col('participants')).over(Window.partitionBy('a', 'b', 'c')),
        func.col('b')
    ).otherwise(func.lit(0))
).show(10, False)
+---+---+---+------------+
|a |b |c |participants|
+---+---+---+------------+
|4 |0 |1 |1 |
|4 |10 |1 |2 |
+---+---+---+------------+
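Applied to the columns in your deptDF, the same pattern would be something like the sketch below (column names taken from your example; every row that doesn't hold the maximum participants within its quantity-category-activity group gets its cost set to 0):

from pyspark.sql import functions as func
from pyspark.sql.window import Window

# zero the cost of every row that is not the max-participants row
# within its quantity-category-activity group
w = Window.partitionBy('quantity', 'category', 'activity')
deptDF = deptDF.withColumn(
    'cost',
    func.when(
        func.col('participants') == func.max('participants').over(w),
        func.col('cost')
    ).otherwise(func.lit(0))
)
deptDF.show(10, False)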
Upvotes: 1