xiaodai

Reputation: 16004

How to rename columns in PySpark, similar to using aliases in a Spark-compatible SQL PIVOT statement?

I ran this query in Spark

f"""
                    select
                        col_cate,
                        true as segment_true,
                        false as segment_false
                    from (
                        SELECT
                            {feature},
                            {target_col},
                            count(*) as cnt
                        from
                            table1
                        group by
                            {feature},
                            {target_col}
                    ) pivot (
                        sum(cnt)
                        for target_bool in (true, false)
                    )
                """)

The input data is this

+--------+-----------+
|col_cate|target_bool|
+--------+-----------+
|       A|       true|
|       A|      false|
|       B|       true|
|       B|      false|
|       A|       true|
|       A|      false|
|       B|       true|
|       B|      false|
+--------+-----------+

and the output data is

+--------+------------+-------------+
|col_cate|segment_true|segment_false|
+--------+------------+-------------+
|       A|        true|        false|
|       B|        true|        false|
+--------+------------+-------------+

However, when I tried to do the same in PySpark, I couldn't figure out how to rename the output columns from [col_cate, true, false] to [col_cate, segment_true, segment_false].

How do I do that?

I tried

output_df.groupBy(["col_cate", "target_bool"]).\
  count().\
  groupBy("col_cate").\
  pivot("target_bool").\
  sum("count")

but there doesn't seem to be a way to rename the columns within that chain. I know I can rename them afterwards, but that feels less elegant.
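For completeness, the rename-afterwards version I'm trying to avoid would look something like this (a sketch using withColumnRenamed on the true/false columns the pivot produces):

output_df.groupBy(["col_cate", "target_bool"]).\
  count().\
  groupBy("col_cate").\
  pivot("target_bool").\
  sum("count").\
  withColumnRenamed("true", "segment_true").\
  withColumnRenamed("false", "segment_false")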

Upvotes: 2

Views: 147

Answers (2)

Raghu

Reputation: 1712

You can do it using the agg and alias methods after pivoting.

import pyspark.sql.functions as F

# Sample data matching the question's input table
tst = sqlContext.createDataFrame(
    [('A', 'true', 1), ('A', 'false', 2), ('B', 'true', 3), ('B', 'false', 4),
     ('A', 'true', 5), ('A', 'false', 6), ('B', 'true', 7), ('B', 'false', 8)],
    schema=['col_cate', 'target_bool', 'id'])

tst_res = tst.groupBy(["col_cate", "target_bool"]).\
    count().\
    groupBy("col_cate").\
    pivot("target_bool").\
    agg(F.sum("count").alias("segment"), F.sum("count").alias("dummy"))

results:

+--------+-------------+-----------+------------+----------+
|col_cate|false_segment|false_dummy|true_segment|true_dummy|
+--------+-------------+-----------+------------+----------+
|       B|            2|          2|           2|         2|
|       A|            2|          2|           2|         2|
+--------+-------------+-----------+------------+----------+

The reason a second (dummy) aggregation is needed is not clear to me either, but the aliases are applied to the pivoted column names only when there is more than one aggregation. I am still investigating this, but the approach should work for you.

EDIT: as pointed out in the comments, you don't need to group twice. The code below does the same job:

tst_res1 = tst.groupby("col_cate").\
    pivot("target_bool").\
    agg(F.count("target_bool").alias('segment'),
        F.count("target_bool").alias('dummy'))
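As a side note (not part of the original answer): pivot also accepts an explicit list of values, which fixes the column order and saves Spark a pass to collect the distinct pivot values. A minimal sketch, assuming the same tst DataFrame as above:

# Passing the pivot values explicitly; Spark skips the
# distinct-value scan and emits the columns in this order.
tst_res2 = tst.groupby("col_cate").\
    pivot("target_bool", ["true", "false"]).\
    agg(F.count("target_bool").alias('segment'),
        F.count("target_bool").alias('dummy'))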

Upvotes: 1

Shubham Jain

Reputation: 5526

You can transform the column values before pivoting:

from pyspark.sql.functions import col, lit, when

df.withColumn('target_bool', when(col('target_bool')=='true', lit('segment_true')).otherwise(lit('segment_false'))).\
  groupBy(["col_cate", "target_bool"]).\
  count().\
  groupBy("col_cate").\
  pivot("target_bool").\
  sum("count").show()


+--------+-------------+------------+
|col_cate|segment_false|segment_true|
+--------+-------------+------------+
|       B|            2|           2|
|       A|            2|           2|
+--------+-------------+------------+

Or, the equivalent of your SQL version:

output_df.groupBy(["col_cate", "target_bool"]).\
  count().\
  groupBy("col_cate").\
  pivot("target_bool").\
  sum("count").select(col('col_cate'),col('true').alias('segment_true'),col('false').alias('segment_false'))

Upvotes: 2
