DataDog

Reputation: 525

PySpark drop-dupes based on a column condition

Still new to Spark and I'm trying to do this final transformation as cleanly and efficiently as possible.

Say I have a dataframe that looks like the following

+------+--------+                  
|ID    | Hit    |                  
+------+--------+
|123   |   0    | 
|456   |   1    |
|789   |   0    |     
|123   |   1    |   
|123   |   0    | 
|789   |   1    |   
|1234  |   0    |
|1234  |   0    |   
+------+--------+

I'm trying to end up with a new dataframe (or two, depending on what's more efficient): if an ID has any row with a 1 in "Hit", all of that ID's rows with a 0 in "Hit" should be dropped; IDs that only ever have 0s should be reduced to a single distinct row per ID.

Here's one of the methods I tried, but I'm not sure whether it's 1. the most efficient way possible or 2. the cleanest way possible:

dfhits = df.filter(df.Hit == 1)
dfnonhits = df.filter(df.Hit == 0)
# isin() expects plain values rather than a DataFrame, so collect the hit IDs first
hit_ids = [row.ID for row in dfhits.select('ID').distinct().collect()]
dfnonhitsdistinct = dfnonhits.filter(~dfnonhits['ID'].isin(hit_ids)).dropDuplicates()

The end dataset would look like the following:

+------+--------+                  
|ID    | Hit    |                  
+------+--------+
|456   |   1    |    
|123   |   1    |   
|789   |   1    |   
|1234  |   0    |  
+------+--------+
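
For reference, the same split-and-filter idea can be sketched with a left anti join, so the hit IDs never have to be collected to the driver. This is only a sketch based on the sample data above (column names ID and Hit, both halves keeping the same column order for the final union):

from pyspark.sql.functions import col

dfhits = df.filter(col('Hit') == 1)       # every row whose Hit is 1
dfnonhits = df.filter(col('Hit') == 0)    # every row whose Hit is 0

# left_anti keeps only the 0-rows whose ID never appears among the hits
dfnonhitsdistinct = dfnonhits.join(dfhits.select('ID'), on='ID', how='left_anti').dropDuplicates()

result = dfhits.dropDuplicates().union(dfnonhitsdistinct)
result.show()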

Upvotes: 1

Views: 2160

Answers (1)

cph_sto

Reputation: 7585

# Creating the Dataframe.
from pyspark.sql.functions import col
df = sqlContext.createDataFrame([(123,0),(456,1),(789,0),(123,1),(123,0),(789,1),(500,0),(500,0)],
                                ['ID','Hit']) 
df.show()
+---+---+ 
| ID|Hit| 
+---+---+ 
|123|  0| 
|456|  1| 
|789|  0| 
|123|  1| 
|123|  0| 
|789|  1| 
|500|  0| 
|500|  0| 
+---+---+

The idea is to compute the total of Hit per ID. If that total is more than 0, the ID has at least one 1 in Hit, and in that case we remove all of its rows where Hit is 0.

# Registering the dataframe as a temporary view.
df.registerTempTable('table_view')
df=sqlContext.sql(
    'select ID, Hit, sum(Hit) over (partition by ID) as sum_Hit from table_view'
)
df.show()
+---+---+-------+ 
| ID|Hit|sum_Hit| 
+---+---+-------+ 
|789|  0|      1| 
|789|  1|      1| 
|500|  0|      0| 
|500|  0|      0| 
|123|  0|      1| 
|123|  1|      1| 
|123|  0|      1| 
|456|  1|      1| 
+---+---+-------+
df = df.filter(~((col('Hit')==0) & (col('sum_Hit')>0))).drop('sum_Hit').dropDuplicates()
df.show()
+---+---+ 
| ID|Hit|  
+---+---+ 
|789|  1| 
|500|  0| 
|123|  1| 
|456|  1|
+---+---+
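
For completeness, the same window logic can be written with the DataFrame API instead of a temporary view. This is just a sketch that assumes the df created at the top of the answer; it should return the same result as the SQL query above.

from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy('ID')

# Total of Hit per ID; when it is > 0, drop that ID's rows where Hit is 0
df_result = (
    df.withColumn('sum_Hit', F.sum('Hit').over(w))
      .filter(~((F.col('Hit') == 0) & (F.col('sum_Hit') > 0)))
      .drop('sum_Hit')
      .dropDuplicates()
)
df_result.show()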

Upvotes: 1
