User12345

Reputation: 5480

Create new pyspark DataFrame column by concatenating values of another column based on a conditional

I have a DataFrame in PySpark like the one below:

df.show()

+-------+--------------------+--------------------+
| Dev_No|               model|              Tested|
+-------+--------------------+--------------------+
|BTA16C5|          Windows PC|                   N|
|BTA16C5|                 SRL|                   N|
|BTA16C5|     Hewlett Packard|                   N|
|CTA16C5|     Android Devices|                   Y|
|CTA16C5|     Hewlett Packard|                   N|
|4MY16A5|               Other|                   N|
|4MY16A5|               Other|                   N|
|4MY16A5|              Tablet|                   Y|
|4MY16A5|               Other|                   N|
|4MY16A5|           Cable STB|                   Y|
|4MY16A5|               Other|                   N|
|4MY16A5|          Windows PC|                   Y|
|4MY16A5|          Windows PC|                   Y|
|4MY16A5|         Smart Watch|                   Y|
+-------+--------------------+--------------------+

Now, using the above DataFrame, I want to create the DataFrame below with a new column called Tested_devices. For each Dev_No, the column should be populated with all of the model values where Tested is Y, joined as a comma-separated string.

df1.show()

+-------+--------------------+--------------------+------------------------------------------------------+
| Dev_No|               model|              Tested|                                        Tested_devices|
+-------+--------------------+--------------------+------------------------------------------------------+
|BTA16C5|          Windows PC|                   N|                                                      |
|BTA16C5|                 SRL|                   N|                                                      |
|BTA16C5|     Hewlett Packard|                   N|                                                      |
|CTA16C5|     Android Devices|                   Y|                                       Android Devices|
|CTA16C5|     Hewlett Packard|                   N|                                                      |
|4MY16A5|               Other|                   N|                                                      |
|4MY16A5|               Other|                   N|                                                      |
|4MY16A5|              Tablet|                   Y|Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|               Other|                   N|                                                      |
|4MY16A5|           Cable STB|                   Y|Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|               Other|                   N|                                                      |
|4MY16A5|          Windows PC|                   Y|Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|          Windows PC|                   Y|Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|         Smart Watch|                   Y|Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
+-------+--------------------+--------------------+------------------------------------------------------+

I tried something like the following to select Dev_No and model where Tested is Y:

a = df.select("Dev_No", "model"), when(df.Tested == 'Y')

I am unable to get the result; it gave me the error below:

TypeError: when() takes exactly 2 arguments (1 given)

How can I achieve what I want?

Upvotes: 2

Views: 2263

Answers (2)

pault

Reputation: 43494

Update

For Spark 1.6, you will need an alternative approach. One way to do this without using a udf or any window functions is to create a second temporary DataFrame with the collected values and then join it back to the original DataFrame.

First group by both Dev_No and Tested and aggregate using concat_ws and collect_list. After aggregation, filter the DataFrame for tested devices only.

import pyspark.sql.functions as f

# create temporary DataFrame
df2 = df.groupBy('Dev_No', 'Tested')\
    .agg(f.concat_ws(", ", f.collect_list('model')).alias('Tested_devices'))\
    .where(f.col('Tested') == 'Y')

df2.show(truncate=False)
#+-------+------+------------------------------------------------------+
#|Dev_No |Tested|Tested_devices                                        |
#+-------+------+------------------------------------------------------+
#|CTA16C5|Y     |Android Devices                                       |
#|4MY16A5|Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#+-------+------+------------------------------------------------------+

Now do a left join of df with df2 using both the Dev_No and Tested columns as the join keys:

df.join(df2, on=['Dev_No', 'Tested'], how='left')\
    .select('Dev_No', 'model', 'Tested', 'Tested_devices')\
    .show(truncate=False)

The purpose of the select at the end is to get the columns in the same order as the original DataFrame for display purposes; you can remove this step if you choose.
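As an aside, if you would rather not hard-code the column names in that select, one option is to reuse df.columns, which produces the same ordering:

# select the original columns first, then append the new one
df.join(df2, on=['Dev_No', 'Tested'], how='left')\
    .select(df.columns + ['Tested_devices'])\
    .show(truncate=False)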

This will result in the following output (the same output as the concat_ws version in the Original Answer below):

#+-------+---------------+------+------------------------------------------------------+
#|Dev_No |model          |Tested|Tested_devices                                        |
#+-------+---------------+------+------------------------------------------------------+
#|4MY16A5|Other          |N     |null                                                  |
#|4MY16A5|Other          |N     |null                                                  |
#|4MY16A5|Other          |N     |null                                                  |
#|4MY16A5|Other          |N     |null                                                  |
#|CTA16C5|Hewlett Packard|N     |null                                                  |
#|BTA16C5|Windows PC     |N     |null                                                  |
#|BTA16C5|SRL            |N     |null                                                  |
#|BTA16C5|Hewlett Packard|N     |null                                                  |
#|CTA16C5|Android Devices|Y     |Android Devices                                       |
#|4MY16A5|Tablet         |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Cable STB      |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Windows PC     |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Windows PC     |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Smart Watch    |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#+-------+---------------+------+------------------------------------------------------+
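If you also need empty strings instead of null here, to match the desired output in the question exactly, one option (a small sketch on top of the join above, not shown in the original output) is to fill the nulls afterwards:

# replace null with an empty string in the new column
df.join(df2, on=['Dev_No', 'Tested'], how='left')\
    .select('Dev_No', 'model', 'Tested', 'Tested_devices')\
    .na.fill({'Tested_devices': ''})\
    .show(truncate=False)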

Original Answer (for later versions of Spark):

You can achieve this by using two pyspark.sql.functions.when() statements, one of them within a call to pyspark.sql.functions.collect_list() over a Window, taking advantage of the fact that the default null value does not get added to the list:

from pyspark.sql import Window
import pyspark.sql.functions as f

df.select(
    "*",
    f.when(
        f.col("Tested") == "Y",
        f.collect_list(
            f.when(
                f.col("Tested") == "Y",
                f.col('model')
            )
        ).over(Window.partitionBy("Dev_No"))
    ).alias("Tested_devices")
).show(truncate=False)
#+-------+---------------+------+--------------------------------------------------------+
#|Dev_No |model          |Tested|Tested_devices                                          |
#+-------+---------------+------+--------------------------------------------------------+
#|BTA16C5|Windows PC     |N     |null                                                    |
#|BTA16C5|SRL            |N     |null                                                    |
#|BTA16C5|Hewlett Packard|N     |null                                                    |
#|4MY16A5|Other          |N     |null                                                    |
#|4MY16A5|Other          |N     |null                                                    |
#|4MY16A5|Tablet         |Y     |[Tablet, Cable STB, Windows PC, Windows PC, Smart Watch]|
#|4MY16A5|Other          |N     |null                                                    |
#|4MY16A5|Cable STB      |Y     |[Tablet, Cable STB, Windows PC, Windows PC, Smart Watch]|
#|4MY16A5|Other          |N     |null                                                    |
#|4MY16A5|Windows PC     |Y     |[Tablet, Cable STB, Windows PC, Windows PC, Smart Watch]|
#|4MY16A5|Windows PC     |Y     |[Tablet, Cable STB, Windows PC, Windows PC, Smart Watch]|
#|4MY16A5|Smart Watch    |Y     |[Tablet, Cable STB, Windows PC, Windows PC, Smart Watch]|
#|CTA16C5|Android Devices|Y     |[Android Devices]                                       |
#|CTA16C5|Hewlett Packard|N     |null                                                    |
#+-------+---------------+------+--------------------------------------------------------+

If instead you wanted your output exactly as you showed in your question (a string of comma-separated values instead of a list, and empty strings instead of null), you could modify this slightly as follows:

Use pyspark.sql.functions.concat_ws to concatenate the output of collect_list into a string. I'm using ", " as the separator, which is equivalent to doing ", ".join(some_list) in Python. Next, we add .otherwise(f.lit("")) to the end of the outer when() call to specify that we want to return a literal empty string if the condition is False.

df.select(
    "*",
    f.when(
        f.col("Tested") == "Y",
        f.concat_ws(
            ", ",
            f.collect_list(
                f.when(
                    f.col("Tested") == "Y",
                    f.col('model')
                )
            ).over(Window.partitionBy("Dev_No"))
        )
    ).otherwise(f.lit("")).alias("Tested_devices")
).show(truncate=False)
#+-------+---------------+------+------------------------------------------------------+
#|Dev_No |model          |Tested|Tested_devices                                        |
#+-------+---------------+------+------------------------------------------------------+
#|BTA16C5|Windows PC     |N     |                                                      |
#|BTA16C5|SRL            |N     |                                                      |
#|BTA16C5|Hewlett Packard|N     |                                                      |
#|4MY16A5|Other          |N     |                                                      |
#|4MY16A5|Other          |N     |                                                      |
#|4MY16A5|Tablet         |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Other          |N     |                                                      |
#|4MY16A5|Cable STB      |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Other          |N     |                                                      |
#|4MY16A5|Windows PC     |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Windows PC     |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|4MY16A5|Smart Watch    |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
#|CTA16C5|Android Devices|Y     |Android Devices                                       |
#|CTA16C5|Hewlett Packard|N     |                                                      |
#+-------+---------------+------+------------------------------------------------------+

Using pyspark-sql syntax, the first example above is equivalent to:

df.registerTempTable("df")
query = """
 SELECT *, 
        CASE 
          WHEN Tested = 'Y' 
          THEN COLLECT_LIST(
            CASE 
              WHEN Tested = 'Y' 
              THEN model
            END
          ) OVER (PARTITION BY Dev_No) 
        END AS Tested_devices
   FROM df
"""
sqlCtx.sql(query).show(truncate=False)
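On Spark 2.x and later, where registerTempTable and SQLContext are deprecated, the equivalent would be createOrReplaceTempView with a SparkSession (assuming one is available as spark):

# Spark 2.x+ equivalent, assuming an active SparkSession named `spark`
df.createOrReplaceTempView("df")
spark.sql(query).show(truncate=False)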

Upvotes: 1

Ramesh Maharjan

Reputation: 41957

The code below is commented for clarity and explanation.

pyspark > 1.6

#window function to group by Dev_No
from pyspark.sql import Window
windowSpec = Window.partitionBy("Dev_No")

from pyspark.sql import functions as f
from pyspark.sql import types as t
# udf function to turn the collected list into a string, and to check whether the Tested column is Y or N
@f.udf(t.StringType())
def populatedUdfFunc(tested, models):
    if tested == "Y":
        return ", ".join(models)
    else:
        return ""

# collect the models where Tested is Y, using the window function defined above
df.withColumn(
    "Tested_devices",
    populatedUdfFunc(
        f.col("Tested"),
        f.collect_list(
            f.when(f.col("Tested") == "Y", f.col("model")).otherwise(None)
        ).over(windowSpec)
    )
).show(truncate=False)

which should give you:

+-------+---------------+------+------------------------------------------------------+
|Dev_No |model          |Tested|Tested_devices                                        |
+-------+---------------+------+------------------------------------------------------+
|BTA16C5|Windows PC     |N     |                                                      |
|BTA16C5|SRL            |N     |                                                      |
|BTA16C5|Hewlett Packard|N     |                                                      |
|4MY16A5|Other          |N     |                                                      |
|4MY16A5|Other          |N     |                                                      |
|4MY16A5|Tablet         |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|Other          |N     |                                                      |
|4MY16A5|Cable STB      |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|Other          |N     |                                                      |
|4MY16A5|Windows PC     |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|Windows PC     |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|4MY16A5|Smart Watch    |Y     |Tablet, Cable STB, Windows PC, Windows PC, Smart Watch|
|CTA16C5|Android Devices|Y     |Android Devices                                       |
|CTA16C5|Hewlett Packard|N     |                                                      |
+-------+---------------+------+------------------------------------------------------+

spark = 1.6

For PySpark 1.6, collect_list won't work with a window function, and there is no collect_list function defined in SQLContext. So you will have to do it without a window function, and use HiveContext instead of SQLContext.

from pyspark.sql import functions as f
from pyspark.sql import types as t
# udf function to change the collected list to a string
def populatedUdfFunc(models):
    return ", ".join(models)

populateUdf = f.udf(populatedUdfFunc, t.StringType())

# collect the models where Tested is Y, grouped by Dev_No (no window function needed)
tempdf = df.groupBy("Dev_No").agg(
    populateUdf(
        f.collect_list(f.when(f.col("Tested") == "Y", f.col("model")).otherwise(None))
    ).alias("Tested_devices")
)

df.join(
    tempdf,
    (df["Dev_No"] == tempdf["Dev_No"]) & (df["Tested"] == f.lit("Y")),
    "left"
).show(truncate=False)

You would get much the same output as above, although note that this join keeps both Dev_No columns and leaves null rather than an empty string in Tested_devices for the untested rows.
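If you want to match the earlier output exactly, one possible cleanup (a sketch, assuming that is the intent, not part of the original answer) is to select the columns from the left-hand side and fill the nulls:

# keep only the left-hand Dev_No and replace null with an empty string
df.join(
    tempdf,
    (df["Dev_No"] == tempdf["Dev_No"]) & (df["Tested"] == f.lit("Y")),
    "left"
).select(df["Dev_No"], df["model"], df["Tested"], tempdf["Tested_devices"])\
    .na.fill({"Tested_devices": ""})\
    .show(truncate=False)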

Upvotes: 1
