PaulH

Reputation: 327

How can I pivot on multiple columns separately in PySpark

Is it possible to pivot on several columns at once in PySpark? I have a dataframe like this:

from pyspark.sql import functions as sf
import pandas as pd
sdf = spark.createDataFrame(
    pd.DataFrame([[1, 'str1', 'str4'], [1, 'str1', 'str4'], [1, 'str2', 'str4'], [1, 'str2', 'str5'],
        [1, 'str3', 'str5'], [2, 'str2', 'str4'], [2, 'str2', 'str4'], [2, 'str3', 'str4'],
        [2, 'str3', 'str5']], columns=['id', 'col1', 'col2'])
)
# +----+------+------+
# | id | col1 | col2 |
# +----+------+------+
# |  1 | str1 | str4 |
# |  1 | str1 | str4 |
# |  1 | str2 | str4 |
# |  1 | str2 | str5 |
# |  1 | str3 | str5 |
# |  2 | str2 | str4 |
# |  2 | str2 | str4 |
# |  2 | str3 | str4 |
# |  2 | str3 | str5 |
# +----+------+------+

I want to pivot it on multiple columns ("col1", "col2", ...) to have a dataframe that looks like this:

+----+-----------+-----------+-----------+-----------+-----------+
| id | col1_str1 | col1_str2 | col1_str3 | col2_str4 | col2_str5 |
+----+-----------+-----------+-----------+-----------+-----------+
|  1 |         2 |         2 |         1 |         3 |         2 |
|  2 |         0 |         2 |         2 |         3 |         1 |
+----+-----------+-----------+-----------+-----------+-----------+

I've found a solution that works:

sdf_pivot_col1 = (
    sdf
    .groupby('id')
    .pivot('col1')
    .agg(sf.count('id'))
)
sdf_pivot_col2 = (
    sdf
    .groupby('id')
    .pivot('col2')
    .agg(sf.count('id'))
)

sdf_result = (
    sdf
    .select('id').distinct()
    .join(sdf_pivot_col1, on='id', how='left')
    .join(sdf_pivot_col2, on='id', how='left')
)
# show() returns None, so call it separately to keep sdf_result a DataFrame
sdf_result.show()

# +---+----+----+----+----+----+
# | id|str1|str2|str3|str4|str5|
# +---+----+----+----+----+----+
# |  1|   2|   2|   1|   3|   2|
# |  2|null|   2|   2|   3|   1|
# +---+----+----+----+----+----+

But I'm looking for a more compact solution.

Upvotes: 5

Views: 7523

Answers (3)

ZygD

Reputation: 24356

What you want here is not pivoting on multiple columns (that is a different operation).
What you really want is to pivot on a single column, after first moving the values of both columns into it:

from pyspark.sql import functions as F

cols = [c for c in sdf.columns if c != 'id']
sdf = (sdf
    .withColumn('_pivot', F.explode(F.array(
        *[F.concat(F.lit(f'{c}_'), c) for c in cols]
    ))).groupBy('id').pivot('_pivot').count().fillna(0)
)

sdf.show()
# +---+---------+---------+---------+---------+---------+
# | id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
# +---+---------+---------+---------+---------+---------+
# |  1|        2|        2|        1|        3|        2|
# |  2|        0|        2|        2|        3|        1|
# +---+---------+---------+---------+---------+---------+
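
To make the idea behind this answer easier to follow, here is a minimal plain-Python sketch of the same three steps (label each value with its column name, count per id, fill missing combinations with 0). It does not use Spark at all; the names `rows`, `value_cols`, `long_rows`, `counts` and `pivoted` are made up for this illustration, and the data is copied from the question.

```python
from collections import Counter, defaultdict

# Same rows as the question's dataframe: (id, col1, col2).
rows = [
    (1, 'str1', 'str4'), (1, 'str1', 'str4'), (1, 'str2', 'str4'),
    (1, 'str2', 'str5'), (1, 'str3', 'str5'), (2, 'str2', 'str4'),
    (2, 'str2', 'str4'), (2, 'str3', 'str4'), (2, 'str3', 'str5'),
]
value_cols = ['col1', 'col2']  # hypothetical name for the non-id columns

# Step 1 (the "explode"): one (id, "<column>_<value>") pair per value column.
long_rows = [
    (row[0], f'{name}_{value}')
    for row in rows
    for name, value in zip(value_cols, row[1:])
]

# Step 2 (the "groupBy + pivot + count"): count each label per id.
counts = defaultdict(Counter)
for row_id, label in long_rows:
    counts[row_id][label] += 1

# Step 3 (the "fillna(0)"): every id gets the same sorted set of columns;
# a Counter returns 0 for labels it has never seen.
labels = sorted({label for _, label in long_rows})
pivoted = {
    row_id: [counts[row_id][label] for label in labels]
    for row_id in sorted(counts)
}

print(['id'] + labels)
for row_id, values in pivoted.items():
    print([row_id] + values)
```

Because every value is prefixed with its column name before counting, values from col1 and col2 can never collide, which is why a single pivot is enough.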

Upvotes: 0

Ala Tarighati

Reputation: 3817

Try this:

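from functools import reduce
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

cols = [x for x in sdf.columns if x != 'id']
# One two-column dataframe (id, '<column>_<value>') per column to pivot
df_array = [sdf.withColumn('col', F.concat(F.lit(x), F.lit('_'), F.col(x))).select('id', 'col') for x in cols]

# Stack them, then pivot once on the combined column
reduce(DataFrame.unionAll, df_array).groupby('id').pivot('col').agg(F.count('col')).show()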

Output:

+---+---------+---------+---------+---------+---------+
| id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
+---+---------+---------+---------+---------+---------+
|  1|        2|        2|        1|        3|        2|
|  2|     null|        2|        2|        3|        1|
+---+---------+---------+---------+---------+---------+

Upvotes: 1

PaulH

Reputation: 327

Using the link from @mrjoseph, I came up with the following solution. It works and it's cleaner, but I still don't like the joins...

def pivot_udf(df, *cols):
    mydf = df.select('id').drop_duplicates()
    for c in cols:
        mydf = mydf.join(
            df
            .withColumn('combcol', sf.concat(sf.lit('{}_'.format(c)), df[c]))
            .groupby('id').pivot('combcol').agg(sf.count(c)),
            how='left',
            on='id'
        )
    return mydf

pivot_udf(sdf, 'col1','col2').show()

+---+---------+---------+---------+---------+---------+
| id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
+---+---------+---------+---------+---------+---------+
|  1|        2|        2|        1|        3|        2|
|  2|     null|        2|        2|        3|        1|
+---+---------+---------+---------+---------+---------+

Upvotes: 1
