Rashida

Reputation: 491

Group by and aggregate multiple elements of an RDD object in PySpark

Here are the first three elements of my RDD object:

[('E7750A37CAB07D0DFF0AF7E3573AC141',
  0.03333333333333333,
  0.44,
  1.0,
  0.0,
  0.0,
  3.5),
 ('778C92B26AE78A9EBDF96B49C67E4007',
  0.03333333333333333,
  0.71,
  1.0,
  0.0,
  1.0,
  4.0),
 ('BE317B986700F63C43438482792C8654',
  0.03333333333333333,
  0.48,
  1.0,
  0.0,
  0.0,
  4.0)]

I want to group by the string element (such as 'BE317B986700F63C43438482792C8654') and add up the rest of the elements. I am new to PySpark.

Upvotes: 0

Views: 244

Answers (1)

Shanif Ansari

Reputation: 37

We can take your input as:

    input = [('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333, 0.44, 1.0, 0.0, 0.0, 3.5),
             ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333, 0.71, 1.0, 0.0, 1.0, 4.0),
             ('BE317B986700F63C43438482792C8654', 0.03333333333333333, 0.48, 1.0, 0.0, 0.0, 4.0)]

You can use the reduceByKey() function to add up the elements for each key.

To use it, however, you need a pair RDD: an RDD of (key, value) 2-tuples in which the first element is the key. (If your key is not already in that position, the keyBy() function can create such pairs from your RDD.)
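As a rough illustration of what a pair RDD looks like, here is a plain-Python sketch (no Spark needed) of what keyBy() produces for the sample rows: each row becomes a (key, row) 2-tuple. The helper key_by is hypothetical and only mimics rdd.keyBy(f) on an ordinary list.

```python
rows = [
    ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333, 0.44, 1.0, 0.0, 0.0, 3.5),
    ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333, 0.71, 1.0, 0.0, 1.0, 4.0),
    ('BE317B986700F63C43438482792C8654', 0.03333333333333333, 0.48, 1.0, 0.0, 0.0, 4.0),
]


def key_by(records, f):
    """Hypothetical list-based stand-in for rdd.keyBy(f): returns (f(r), r) pairs."""
    return [(f(r), r) for r in records]


# In PySpark the corresponding call would be: pairs = rdd.keyBy(lambda r: r[0])
pairs = key_by(rows, lambda r: r[0])

for key, row in pairs:
    # Every element is now a 2-tuple; the whole row (key included) is the value.
    print(key, row)
```

Note that keyBy keeps the key inside the value as well; in PySpark, a follow-up pairs.mapValues(lambda r: r[1:]) would drop that duplicate.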

First, turn the list into an RDD:

    # sc is the SparkContext (the pyspark shell creates it for you)
    input = sc.parallelize(input)  # creating an RDD

In our input, the first element of each tuple is the key. Now we want to pair each number in a tuple with that tuple's key, giving something like this:

('E7750A37CAB07D0DFF0AF7E3573AC141', 0.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.44), ('E7750A37CAB07D0DFF0AF7E3573AC141', 1.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 3.5), ('778C92B26AE78A9EBDF96B49C67E4007', 0.0), ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333), ('778C92B26AE78A9EBDF96B49C67E4007', 0.71), ('778C92B26AE78A9EBDF96B49C67E4007', 1.0), ('778C92B26AE78A9EBDF96B49C67E4007', 0.0), ('778C92B26AE78A9EBDF96B49C67E4007', 1.0), ('778C92B26AE78A9EBDF96B49C67E4007', 4.0), ('BE317B986700F63C43438482792C8654', 0.0), ('BE317B986700F63C43438482792C8654', 0.03333333333333333), ('BE317B986700F63C43438482792C8654', 0.48), ('BE317B986700F63C43438482792C8654', 1.0), ('BE317B986700F63C43438482792C8654', 0.0), ('BE317B986700F63C43438482792C8654', 0.0), ('BE317B986700F63C43438482792C8654', 4.0)

To achieve this, we can use a lambda function that is applied to each element of the RDD (for example, ('E7750A37CAB07D0DFF0AF7E3573AC141',0.03333333333333333,0.44,1.0,0.0,0.0,3.5)) and, inside it, a list comprehension that pairs the key with every value in the tuple, for example:

    lambda x: [(x[0], y) for y in x]

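However, this comprehension also pairs x[0] with itself, and the string key cannot be added to the floats later on. So use a conditional expression to replace that pair with (x[0], 0.0), which does not change the sum:

    lambda x: [(x[0], y) if y != x[0] else (x[0], 0.000) for y in x]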
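A simpler variant, offered here as a suggestion rather than as the answer's own code, is to slice off the key with x[1:] so it is never paired with itself and no 0.0 placeholder is needed. This plain-Python sketch checks, on the first sample row, that both versions give the same total; the function names are illustrative only.

```python
row = ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333, 0.44, 1.0, 0.0, 0.0, 3.5)


def with_placeholder(x):
    # The answer's version: the key's own position becomes a (key, 0.0) pair.
    return [(x[0], y) if y != x[0] else (x[0], 0.000) for y in x]


def sliced(x):
    # Slicing version: x[1:] skips the key, so no placeholder is needed.
    # Assumed PySpark usage: input.flatMap(lambda x: [(x[0], y) for y in x[1:]])
    return [(x[0], y) for y in x[1:]]


a = with_placeholder(row)
b = sliced(row)
print(len(a), len(b))  # 7 6
# The extra 0.0 pair does not change the total.
print(sum(v for _, v in a) == sum(v for _, v in b))  # True
```

Either function works with the steps that follow; the walkthrough continues with the placeholder version.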

Now, apply the conditional lambda to every element of the RDD with map():

    input2 = input.map(lambda x: [(x[0], y) if y != x[0] else (x[0], 0.000) for y in x])
    input2.collect()

[[('E7750A37CAB07D0DFF0AF7E3573AC141', 0.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.44), ('E7750A37CAB07D0DFF0AF7E3573AC141', 1.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.0), ('E7750A37CAB07D0DFF0AF7E3573AC141', 3.5)], [('778C92B26AE78A9EBDF96B49C67E4007', 0.0), ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333), ('778C92B26AE78A9EBDF96B49C67E4007', 0.71), ('778C92B26AE78A9EBDF96B49C67E4007', 1.0), ('778C92B26AE78A9EBDF96B49C67E4007', 0.0), ('778C92B26AE78A9EBDF96B49C67E4007', 1.0), ('778C92B26AE78A9EBDF96B49C67E4007', 4.0)], [('BE317B986700F63C43438482792C8654', 0.0), ('BE317B986700F63C43438482792C8654', 0.03333333333333333), ('BE317B986700F63C43438482792C8654', 0.48), ('BE317B986700F63C43438482792C8654', 1.0), ('BE317B986700F63C43438482792C8654', 0.0), ('BE317B986700F63C43438482792C8654', 0.0), ('BE317B986700F63C43438482792C8654', 4.0)]]

Each element of input2 is a whole list of pairs, so collect() returns a list of lists. reduceByKey needs individual (key, value) pairs, so we flatten it:

    input3 = input2.flatMap(lambda x: x)
    input3.collect()
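To see the difference between map() and flatMap() without a cluster, here is a plain-Python sketch: map keeps exactly one output element (here, a list) per input row, while flatMap splices each returned list into a single sequence of pairs. The helpers rdd_map and rdd_flat_map are hypothetical stand-ins, not Spark APIs.

```python
from itertools import chain

rows = [
    ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333, 0.44, 1.0, 0.0, 0.0, 3.5),
    ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333, 0.71, 1.0, 0.0, 1.0, 4.0),
    ('BE317B986700F63C43438482792C8654', 0.03333333333333333, 0.48, 1.0, 0.0, 0.0, 4.0),
]


def to_pairs(x):
    """The per-row function from the answer: one (key, value) pair per tuple position."""
    return [(x[0], y) if y != x[0] else (x[0], 0.000) for y in x]


def rdd_map(records, f):
    """Hypothetical stand-in for rdd.map(f): exactly one output element per input."""
    return [f(r) for r in records]


def rdd_flat_map(records, f):
    """Hypothetical stand-in for rdd.flatMap(f): each returned list is spliced in."""
    return list(chain.from_iterable(f(r) for r in records))


nested = rdd_map(rows, to_pairs)     # like input2: 3 lists of 7 pairs each
flat = rdd_flat_map(rows, to_pairs)  # like the one-step flatMap: 21 pairs
print(len(nested), len(flat))  # 3 21

# Flattening the nested result with an identity function gives the same pairs,
# mirroring input2.flatMap(lambda x: x).
print(rdd_flat_map(nested, lambda x: x) == flat)  # True
```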

We can also do the mapping and the flattening in one step with flatMap():

    input2 = input.flatMap(lambda x: [(x[0],y) if y != x[0] else (x[0],0.000) for y in x])
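Because every row belongs to exactly one key, another option (a suggestion, not part of the original answer) is to sum each row's numbers first and emit a single (key, row_total) pair per row, for example input.map(lambda x: (x[0], sum(x[1:]))).reduceByKey(add) in PySpark. That creates one intermediate pair per row instead of seven. The sketch below checks the per-key totals on plain lists, using a small dict-based stand-in for reduceByKey; reduce_by_key is a hypothetical helper for illustration only.

```python
from operator import add

rows = [
    ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333, 0.44, 1.0, 0.0, 0.0, 3.5),
    ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333, 0.71, 1.0, 0.0, 1.0, 4.0),
    ('BE317B986700F63C43438482792C8654', 0.03333333333333333, 0.48, 1.0, 0.0, 0.0, 4.0),
]


def reduce_by_key(pairs, func):
    """Hypothetical list-based stand-in for rdd.reduceByKey(func).

    Merges the values of each key with func; unlike Spark, it runs on one
    machine and returns a dict.
    """
    result = {}
    for key, value in pairs:
        result[key] = func(result[key], value) if key in result else value
    return result


# One (key, row_total) pair per row instead of one pair per number.
per_row = [(x[0], sum(x[1:])) for x in rows]
totals = reduce_by_key(per_row, add)

for key, total in totals.items():
    print(key, round(total, 6))
```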

Finally, use reduceByKey() to sum the values for each key (the order of the collected results is not guaranteed):

    from operator import add
    finalOutput = input2.reduceByKey(add)
    finalOutput.collect()

[('778C92B26AE78A9EBDF96B49C67E4007', 6.743333333333333), ('BE317B986700F63C43438482792C8654', 5.513333333333334), ('E7750A37CAB07D0DFF0AF7E3573AC141', 4.973333333333334)]
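The code above adds every number of a key into one grand total. If what you actually need is a separate total for each column (one summed tuple per key), a possible approach, sketched here under that assumption, is to key each row by its string and add the remaining values element-wise, e.g. input.map(lambda x: (x[0], x[1:])).reduceByKey(add_tuples) in PySpark. The function add_tuples is a name introduced for this example, and the loop below is a plain-Python stand-in for reduceByKey rather than a live RDD. The fourth row is made up, only to show two rows with the same key being merged.

```python
def add_tuples(a, b):
    """Element-wise sum of two equal-length tuples; usable as a reduceByKey function."""
    return tuple(i + j for i, j in zip(a, b))


rows = [
    ('E7750A37CAB07D0DFF0AF7E3573AC141', 0.03333333333333333, 0.44, 1.0, 0.0, 0.0, 3.5),
    ('778C92B26AE78A9EBDF96B49C67E4007', 0.03333333333333333, 0.71, 1.0, 0.0, 1.0, 4.0),
    ('BE317B986700F63C43438482792C8654', 0.03333333333333333, 0.48, 1.0, 0.0, 0.0, 4.0),
    # Made-up extra row with a repeated key, only to demonstrate merging.
    ('BE317B986700F63C43438482792C8654', 0.0, 0.5, 1.0, 0.0, 1.0, 1.0),
]

# Plain-Python stand-in for: input.map(lambda x: (x[0], x[1:])).reduceByKey(add_tuples)
column_totals = {}
for x in rows:
    key, values = x[0], x[1:]
    if key in column_totals:
        column_totals[key] = add_tuples(column_totals[key], values)
    else:
        column_totals[key] = values

for key, values in column_totals.items():
    print(key, values)
```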

Hope this helps!

Upvotes: 1
