LordBertson

Reputation: 554

PySpark - How to turn duplicate rows into new columns

Let's say we have a table of clients of the kind:

| id | Name |
|----|------|
| 1  | John |
| 2  | Bob  |
| 3  | Ella |
| 4  | Jim  |

and a table of vehicles for each client:

| id |client_id| vehicle |
|----|---------|---------|
| 1  |    1    |  car1   |
| 2  |    2    |  car2   |
| 3  |    2    |  car3   |
| 4  |    2    |  car4   |

Now we can see that Bob has 3 cars. I would like to add these vehicles to the clients table so that it gains a new column for each vehicle belonging to the respective client_id.

It should look something like this:

| id | Name | vehicle1 | vehicle2 | vehicle3 |
|----|------|----------|----------|----------|
| 1  | John |   car1   |   null   |   null   |
| 2  | Bob  |   car2   |   car3   |   car4   |
| 3  | Ella |   null   |   null   |   null   |
| 4  | Jim  |   null   |   null   |   null   |

Can this be achieved?
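
For reference, a minimal setup of the two example DataFrames (assuming an existing SparkSession named spark; the names clients_df and vehicles_df are just the ones used below):

clients_df = spark.createDataFrame(
    [(1, 'John'), (2, 'Bob'), (3, 'Ella'), (4, 'Jim')],
    ['id', 'Name'])

vehicles_df = spark.createDataFrame(
    [(1, 1, 'car1'), (2, 2, 'car2'), (3, 2, 'car3'), (4, 2, 'car4')],
    ['id', 'client_id', 'vehicle'])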

Upvotes: 0

Views: 226

Answers (2)

LordBertson

Reputation: 554

WARNING: This solution requires the data to which udf_numbering is applied to reside on a single executor, or at least that all rows for a given id land on the same executor. When run on multiple executors with the data dispersed randomly among them, each executor gets its own copy of id_dict and never sees the changes made on the others.

Figured it out. Terribly slow but does the job.

UDF:

from pyspark.sql.functions import udf, col, first

id_dict = {}

def numbering(id):
    # Count how many times this client_id has been seen so far
    # and return the running count as the column suffix.
    id = str(id)
    if id in id_dict:
        value = id_dict[id] + 1
        id_dict[id] = value
        return str(value)
    else:
        id_dict[id] = 1
        return str(1)

udf_numbering = udf(numbering)

And then for vehicles_df:

vehicles_df = vehicles_df.withColumn('number_repeated', udf_numbering(col('client_id')))
vehicles_df = vehicles_df.groupBy('client_id').pivot('number_repeated').agg(first('vehicle'))
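
Note that the pivot produces columns named '1', '2', '3'. If they should be called vehicle1, vehicle2, vehicle3 as in the question, they could be renamed, for instance (just a sketch):

vehicles_df = vehicles_df.select(
    'client_id',
    *[col(c).alias('vehicle' + c) for c in vehicles_df.columns if c != 'client_id'])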

and we join into clients_df:

clients_df = clients_df.join(vehicles_df, vehicles_df['client_id'] == clients_df['id'], 'left')
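
If the executor caveat above is a concern, the shared dictionary could be avoided entirely by numbering the vehicles with a window function instead of the UDF. A sketch, not part of the original solution (ordering by the vehicles' own id column is an assumption):

from pyspark.sql import Window
from pyspark.sql.functions import row_number

w = Window.partitionBy('client_id').orderBy('id')
vehicles_df = vehicles_df.withColumn('number_repeated', row_number().over(w))

The groupBy/pivot and join steps then stay the same.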

Upvotes: 0

SMaZ

Reputation: 2655

Another approach:

Just note that it can be a little slow, since we are evaluating the same dataset twice (first to find the max length and then to derive the final dataset using that max length).

import pyspark.sql.functions as f
df = df1.join(df2, [df1.id == df2.client_id], 'left_outer').groupBy(df1['id'],'Name').agg(f.collect_list('vehicle').alias('vehicle'))
df.show()
+---+----+------------------+
| id|Name|           vehicle|
+---+----+------------------+
|  1|John|            [car1]|
|  3|Ella|                []|
|  2|Bob |[car3, car4, car2]|
|  4|Jim |                []|
+---+----+------------------+

Find max length from all vehicles and derive final dataset

max_len = df.select(f.max(f.size('vehicle')).alias('max')).first()['max']


df.select('id', 'Name', *[df.vehicle[x] for x in range(max_len)]).show()
+---+----+----------+----------+----------+
| id|Name|vehicle[0]|vehicle[1]|vehicle[2]|
+---+----+----------+----------+----------+
|  1|John|      car1|      null|      null|
|  3|Ella|      null|      null|      null|
|  2|Bob |      car2|      car3|      car4|
|  4|Jim |      null|      null|      null|
+---+----+----------+----------+----------+
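
Because df is evaluated twice (once to compute max_len and once for the final select), caching it between the two steps may help, e.g.:

df = df.cache()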

Upvotes: 2
