User12345

Reputation: 5480

Pivot values to existing columns in PySpark data frame

I have a data frame like the one below in PySpark.

+---+-------------+----+
| id|       device| val|
+---+-------------+----+
|  3|      mac pro|   1|          
|  1|       iphone|   2|
|  1|android phone|   2|
|  1|   windows pc|   2|
|  1|   spy camera|   2|
|  2|   spy camera|   3|
|  2|       iphone|   3|
|  3|   spy camera|   1|
|  3|         cctv|   1|
+---+-------------+----+

I want to populate some columns based on the lists below:

phone_list = ['iphone', 'android phone', 'nokia']
pc_list = ['windows pc', 'mac pro']
security_list = ['spy camera']
ucg_list = ['ipad']

I have done it like below:

from pyspark.sql.functions import col, when, lit
from pyspark.sql.types import IntegerType

df1 = df.withColumn('phones', lit(None).cast(IntegerType())) \
    .withColumn('pc', lit(None).cast(IntegerType())) \
    .withColumn('security', lit(None).cast(IntegerType())) \
    .withColumn('null', lit(None).cast(IntegerType())) \
    .withColumn('ucg', lit(None).cast(IntegerType()))

import pyspark.sql.functions as F

df1.withColumn('cat', 
    F.when(df1.device.isin(phone_list), 'phones').otherwise(
    F.when(df1.device.isin(pc_list), 'pc').otherwise(
    F.when(df1.device.isin(security_list), 'security')))
).groupBy('id', 'phones', 'pc', 'security', 'null', 'ucg').pivot('cat').agg(F.count('cat')).show()

The output I am receiving:

+---+------+----+--------+----+----+----+---+------+--------+
| id|phones|  pc|security|null| ucg|null| pc|phones|security|
+---+------+----+--------+----+----+----+---+------+--------+
|  3|  null|null|    null|null|null|   0|  1|     0|       1|
|  2|  null|null|    null|null|null|   0|  0|     1|       1|
|  1|  null|null|    null|null|null|   0|  1|     2|       1|
+---+------+----+--------+----+----+----+---+------+--------+

What I want is to create the columns first based on the list names and then populate the values.

expected output

+---+------+---+------+--------+----+
| id|   ucg| pc|phones|security|null|
+---+------+---+------+--------+----+
|  1|     0|  1|     2|       1|   0|
|  2|     0|  0|     1|       1|   0|
|  3|     0|  1|     0|       1|   1|
+---+------+---+------+--------+----+

How can I get what I want?

Edit

When I do the below

df1 = df.withColumn('cat', 
    F.when(df.device.isin(phone_list), 'phones').otherwise(
    F.when(df.device.isin(pc_list), 'pc').otherwise(
    F.when(df.device.isin(ucg_list), 'ucg').otherwise(
    F.when(df.device.isin(security_list), 'security')))))

The output is

+---+-------------+---+--------+
| id|       device|val|     cat|
+---+-------------+---+--------+
|  3|      mac pro|  1|      pc|
|  3|   spy camera|  1|security|
|  3|         cctv|  1|    null|
|  1|       iphone|  2|  phones|
|  1|android phone|  2|  phones|
|  1|   windows pc|  2|      pc|
|  1|   spy camera|  2|security|
|  2|   spy camera|  3|security|
|  2|       iphone|  3|  phones|
+---+-------------+---+--------+

In the output you can see that id 3 has a null value in the cat column, because 'cctv' is not in any of the lists.

Upvotes: 0

Views: 154

Answers (1)

Ramesh Maharjan

Reputation: 41987

Creating the 'phones', 'pc', 'ucg', 'security', and 'null' columns populated with None just for the groupBy doesn't make sense: grouping by id together with those all-null columns is the same as grouping by id alone.

What you can do instead is find the difference between the actual pivoted columns and the intended columns, and then create the missing ones populated with 0.

So the following should work for you:

phone_list = ['iphone', 'android phone', 'nokia']
pc_list = ['windows pc', 'mac pro']
security_list = ['spy camera']
ucg_list = ['ipad']

from pyspark.sql import functions as f
df = df.withColumn('cat',
               f.when(df.device.isin(phone_list), 'phones').otherwise(
                 f.when(df.device.isin(pc_list), 'pc').otherwise(
                   f.when(df.device.isin(ucg_list), 'ucg').otherwise(
                     f.when(df.device.isin(security_list), 'security'))))
               )\
    .groupBy('id').pivot('cat').agg(f.count('val'))\
    .na.fill(0)

columnList = ['phones', 'pc', 'ucg', 'security', 'null']
actualcolumnList = df.columns[1:]

diffColumns = [x for x in columnList if x not in actualcolumnList]

for y in diffColumns:
    df = df.withColumn(y, f.lit(0))

df.show(truncate=False)

which should give you:

+---+----+---+------+--------+---+
|id |null|pc |phones|security|ucg|
+---+----+---+------+--------+---+
|3  |1   |1  |0     |1       |0  |
|1  |0   |1  |2     |1       |0  |
|2  |0   |0  |1     |1       |0  |
+---+----+---+------+--------+---+

I hope the answer is helpful.

Upvotes: 1
