Denver

Reputation: 245

PySpark Dataframe identify distinct value on one column based on duplicate values in other columns

I have a PySpark DataFrame like the following, where c1, c2, c3, c4, c5, c6 are the columns:

   +----+----+----+----+----+----+
   | c1 | c2 | c3 | c4 | c5 | c6 |
   +----+----+----+----+----+----+
   |  a |  x |  y |  z |  g |  h |
   |  b |  m |  f |  l |  n |  o |
   |  c |  x |  y |  z |  g |  h |
   |  d |  m |  f |  l |  n |  o |
   |  e |  x |  y |  z |  g |  i |
   +----+----+----+----+----+----+

I want to extract the c1 values of rows that have the same c2, c3, c4, c5 values but different c1 values. For example, the 1st, 3rd and 5th rows have the same values for c2, c3, c4 and c5 but different c1 values, so the output should contain a, c and e.
(update) Similarly, the 2nd and 4th rows have the same values for c2, c3, c4 and c5 but different c1 values, so the output should also contain b and d.

How can I obtain such a result? I have tried applying groupBy, but I don't understand how to obtain the distinct c1 values for each group.

UPDATE:

The output should be a DataFrame of c1 values:

# +-------+
# |c1_dups|
# +-------+
# |  a,c,e|
# |    b,d|
# +-------+   

My Approach:

m = data.groupBy('c2','c3','c4','c5')

but I don't understand how to retrieve the values from m. I'm new to PySpark DataFrames, so I'm quite confused.

Upvotes: 2

Views: 5014

Answers (1)

eliasah

Reputation: 40380

This is actually very simple. Let's create some data first:

schema = ['c1', 'c2', 'c3', 'c4', 'c5', 'c6']

# build an RDD of comma-separated strings, split each into a list of fields
rdd = sc.parallelize(["a,x,y,z,g,h","b,x,y,z,l,h","c,x,y,z,g,h","d,x,f,y,g,i","e,x,y,z,g,i"]) \
        .map(lambda x: x.split(","))

df = sqlContext.createDataFrame(rdd, schema)
# +---+---+---+---+---+---+
# | c1| c2| c3| c4| c5| c6|
# +---+---+---+---+---+---+
# |  a|  x|  y|  z|  g|  h|
# |  b|  x|  y|  z|  l|  h|
# |  c|  x|  y|  z|  g|  h|
# |  d|  x|  f|  y|  g|  i|
# |  e|  x|  y|  z|  g|  i|
# +---+---+---+---+---+---+
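
As a side note, if you are on Spark 2.x or later you can skip the RDD step entirely. A minimal equivalent sketch, assuming a SparkSession is available as spark:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

rows = [("a", "x", "y", "z", "g", "h"),
        ("b", "x", "y", "z", "l", "h"),
        ("c", "x", "y", "z", "g", "h"),
        ("d", "x", "f", "y", "g", "i"),
        ("e", "x", "y", "z", "g", "i")]

# createDataFrame accepts a list of tuples plus the list of column names
df = spark.createDataFrame(rows, schema)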

Now the fun part: you'll just need to import some functions, group by, and explode, as follows:

from pyspark.sql.functions import col, collect_list, count, explode

# collect the c1 values of each group as a list and count them at the same time,
# then keep only the groups that contain more than one row (the dupes)
dupes = df.groupBy('c2', 'c3', 'c4', 'c5') \
          .agg(collect_list('c1').alias("c1s"), count('c1').alias("count")) \
          .filter(col('count') > 1)

# one output row per duplicated c1 value
df2 = dupes.select(explode("c1s").alias("c1_dups"))

df2.show()
# +-------+
# |c1_dups|
# +-------+
# |      a|
# |      c|
# |      e|
# +-------+
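
If you want the comma-joined format from your update (one row per duplicate group rather than one row per value), you can join the collected list instead of exploding it, for instance with concat_ws:

from pyspark.sql.functions import concat_ws

# concat_ws joins the elements of the array column with the given separator
# (note: the element order produced by collect_list is not guaranteed)
dupes.select(concat_ws(",", "c1s").alias("c1_dups")).show()
# +-------+
# |c1_dups|
# +-------+
# |  a,c,e|
# +-------+

With the sample data built above, only one group has duplicates, hence the single row.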

I hope this answers your question.

Upvotes: 6
