PySpark: How to groupBy with OR across columns

I want to do a groupBy in PySpark, but the value can appear in more than one column, so if it appears in any of the selected columns the row should be counted towards that group.

For example, if I have this table in PySpark:

   ID1  ID2  ID3  Visits  Investment
0    1    1    4     500        1000
1    2    1    3     100         200
2    5    3    2     200         400
3    1    3    4     200         200

I want to sum the visits and investments for each ID, so that the result would be:

   ID  Visits  Investment
0   1     800        1400
1   2     300         600
2   3     500         800
3   4     700        1200
4   5     200         400

Note that ID 1 is the sum of rows 0, 1 and 3, which have the value 1 in one of the first three columns [ID 1 Visits = 500 + 100 + 200 = 800]. ID 2 is the sum of rows 1 and 2, and so on.

OBS 1: For the sake of simplicity my example is a small dataframe, but the real one is a much larger df with many rows and many variables, and it involves other operations, not just "sum". It can't be handled in pandas because it is too large, so it has to be done in PySpark.

OBS 2: For illustration I printed the tables with pandas, but in reality they live in PySpark.

I appreciate all the help and thank you very much in advance

Upvotes: 1

Views: 2089

Answers (2)

SMaZ

Reputation: 2655

You can do something like below:

  1. Create an array of all ID columns -> the ids column below
  2. Explode the ids column
  3. Now you will get duplicates; to avoid duplicate aggregation, use distinct
  4. Finally, groupBy the ids column and perform all your aggregations

Note: If your dataset can have exact duplicate rows, then add a column with df.withColumn('uid', f.monotonically_increasing_id()) before creating the array, otherwise distinct will drop them (a sketch of this is shown after the example below).

Example for your dataset:

import pyspark.sql.functions as f

(df.withColumn('ids', f.explode(f.array('id1', 'id2', 'id3')))
   .distinct()  # drops duplicates created when the same id occurs more than once in a row
   .groupBy('ids').agg(f.sum('visits'), f.sum('investments'))
   .orderBy('ids').show())
+---+-----------+----------------+
|ids|sum(visits)|sum(investments)|
+---+-----------+----------------+
|  1|        800|            1400|
|  2|        300|             600|
|  3|        500|             800|
|  4|        700|            1200|
|  5|        200|             400|
+---+-----------+----------------+
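
And a minimal sketch of the duplicate-row guard from the note above, assuming the same lowercase column names as in the example; the uid column only exists to keep originally identical rows apart and is ignored by the aggregation:

import pyspark.sql.functions as f

# Tag every original row with a unique id so that exact duplicate rows
# survive distinct(); duplicates produced by explode within a single row
# (e.g. the same id in two of the id columns) are still removed.
(df.withColumn('uid', f.monotonically_increasing_id())
   .withColumn('ids', f.explode(f.array('id1', 'id2', 'id3')))
   .distinct()
   .groupBy('ids').agg(f.sum('visits'), f.sum('investments'))
   .orderBy('ids').show())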

Upvotes: 2

lukaszKielar

Reputation: 541

First of all, let's create our test DataFrame.

>>> import pandas as pd

>>> data = {
...     "ID1": [1, 2, 5, 1],
...     "ID2": [1, 1, 3, 3],
...     "ID3": [4, 3, 2, 4],
...     "Visits": [500, 100, 200, 200],
...     "Investment": [1000, 200, 400, 200]
... }
>>> df = spark.createDataFrame(pd.DataFrame(data))
>>> df.show()

+---+---+---+------+----------+
|ID1|ID2|ID3|Visits|Investment|
+---+---+---+------+----------+
|  1|  1|  4|   500|      1000|
|  2|  1|  3|   100|       200|
|  5|  3|  2|   200|       400|
|  1|  3|  4|   200|       200|
+---+---+---+------+----------+

Once we have a DataFrame to operate on, we have to define a function which will return the list of unique IDs from the columns ID1, ID2 and ID3.

>>> import pyspark.sql.functions as F
>>> from pyspark.sql.types import ArrayType, IntegerType

>>> @F.udf(returnType=ArrayType(IntegerType()))
... def ids_list(*cols):
...    return list(set(cols))

Now it's time to apply our UDF to the DataFrame.

>>> df = df.withColumn('ids', ids_list('ID1', 'ID2', 'ID3'))
>>> df.show()

+---+---+---+------+----------+---------+
|ID1|ID2|ID3|Visits|Investment|      ids|
+---+---+---+------+----------+---------+
|  1|  1|  4|   500|      1000|   [1, 4]|
|  2|  1|  3|   100|       200|[1, 2, 3]|
|  5|  3|  2|   200|       400|[2, 3, 5]|
|  1|  3|  4|   200|       200|[1, 3, 4]|
+---+---+---+------+----------+---------+

To make use of the ids column we have to explode it into separate rows and then drop the ids column.

>>> df = df.withColumn("ID", F.explode('ids')).drop('ids')
>>> df.show()

+---+---+---+------+----------+---+
|ID1|ID2|ID3|Visits|Investment| ID|
+---+---+---+------+----------+---+
|  1|  1|  4|   500|      1000|  1|
|  1|  1|  4|   500|      1000|  4|
|  2|  1|  3|   100|       200|  1|
|  2|  1|  3|   100|       200|  2|
|  2|  1|  3|   100|       200|  3|
|  5|  3|  2|   200|       400|  2|
|  5|  3|  2|   200|       400|  3|
|  5|  3|  2|   200|       400|  5|
|  1|  3|  4|   200|       200|  1|
|  1|  3|  4|   200|       200|  3|
|  1|  3|  4|   200|       200|  4|
+---+---+---+------+----------+---+

Finally, we group our DataFrame by the ID column and calculate the sums. The final result is ordered by ID.

>>> final_df = (
...    df.groupBy('ID')
...       .agg( F.sum('Visits'), F.sum('Investment') )
...       .orderBy('ID')
... )
>>> final_df.show()

+---+-----------+---------------+
| ID|sum(Visits)|sum(Investment)|
+---+-----------+---------------+
|  1|        800|           1400|
|  2|        300|            600|
|  3|        500|            800|
|  4|        700|           1200|
|  5|        200|            400|
+---+-----------+---------------+
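
As a side note, on Spark 2.4 or later the same result can likely be obtained without a Python UDF by deduplicating the array with the built-in array_distinct before exploding. A minimal sketch under that assumption, where df_src stands for the original DataFrame created at the top of this answer (before the UDF and explode steps were applied); skipping the Python UDF avoids serializing every row through Python, which usually helps on large data:

>>> # df_src is assumed to be the original 4-row DataFrame built above.
>>> alt_df = (
...     df_src.withColumn("ID", F.explode(F.array_distinct(F.array("ID1", "ID2", "ID3"))))
...           .groupBy("ID")
...           .agg(F.sum("Visits"), F.sum("Investment"))
...           .orderBy("ID")
... )
>>> alt_df.show()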

I hope you find it useful.

Upvotes: 3
