Leonard

Reputation: 165

PySpark: Add a new column based on a condition and distinct values

I have a DataFrame:

df = spark.createDataFrame(
    [
        ['3', '2', '3', '30', '0040'],
        ['2', '5', '7', '6', '0012'],
        ['5', '8', '1', '73', '0062'],
        ['4', '2', '5', '2', '0005'],
        ['5', '2', '4', '12', '0002'],
        ['8', '3', '2', '23', '0025'],
        ['2', '2', '8', '23', '0004'],
        ['5', '5', '4', '12', '0002'],
        ['8', '2', '2', '23', '0042'],
        ['2', '2', '8', '23', '0004']
    ],
    ['col1', 'col2', 'col3', 'col4', 'col5']
)
df.show()

I want to add a new column based on the condition below: flag each row whose col2 value is among the distinct col2 values of the rows that satisfy the condition.

from pyspark.sql import functions as F

cond = F.substring(F.col('col5'), 3, 1) == '0'
df1 = df.where(cond)
d_list = df1.select('col2').rdd.map(lambda x: x[0]).distinct().collect()
df2 = df.withColumn('new_col', F.when(F.col('col2').isin(d_list), F.lit('1')).otherwise(F.lit('0')))
df2.show()

Result:

+----+----+----+----+----+-------+
|col1|col2|col3|col4|col5|new_col|
+----+----+----+----+----+-------+
|   3|   2|   3|  30|0040|      1|
|   2|   5|   7|   6|0012|      1|
|   5|   8|   1|  73|0062|      0|
|   4|   2|   5|   2|0005|      1|
|   5|   2|   4|  12|0002|      1|
|   8|   3|   2|  23|0025|      0|
|   2|   2|   8|  23|0004|      1|
|   5|   5|   4|  12|0002|      1|
|   8|   2|   2|  23|0042|      1|
|   2|   2|   8|  23|0004|      1|
+----+----+----+----+----+-------+

I think this approach will not scale to large datasets. I am looking for an improved or alternative way that avoids the collect() method, because of the warning: use of collect() can lead to poor Spark performance.

Upvotes: 2

Views: 1763

Answers (3)

anky

Reputation: 75080

You can also try setting new_col to 1 where the condition is True, and then taking the max of that flag over a window partitioned on col2:

from pyspark.sql import functions as F, Window

cond = F.substring(F.col('col5'), 3, 1) == '0'
out = (df.withColumn("new_col", F.when(cond, 1).otherwise(0))
         .withColumn("new_col", F.max("new_col").over(Window.partitionBy("col2"))))

out.show()

+----+----+----+----+----+-------+
|col1|col2|col3|col4|col5|new_col|
+----+----+----+----+----+-------+
|   3|   2|   3|  30|0040|      1|
|   4|   2|   5|   2|0005|      1|
|   5|   2|   4|  12|0002|      1|
|   2|   2|   8|  23|0004|      1|
|   8|   2|   2|  23|0042|      1|
|   2|   2|   8|  23|0004|      1|
|   8|   3|   2|  23|0025|      0|
|   2|   5|   7|   6|0012|      1|
|   5|   5|   4|  12|0002|      1|
|   5|   8|   1|  73|0062|      0|
+----+----+----+----+----+-------+

If the original row order matters, assign an id first and restore the order with orderBy afterwards:

cond = F.substring(F.col('col5'), 3, 1) == '0'

out = (df.withColumn("idx", F.monotonically_increasing_id())
       .withColumn("new_col", F.when(cond, 1).otherwise(0))
       .withColumn("new_col", F.max("new_col").over(Window.partitionBy("col2")))
       .orderBy("idx").drop("idx"))

out.show()

+----+----+----+----+----+-------+
|col1|col2|col3|col4|col5|new_col|
+----+----+----+----+----+-------+
|   3|   2|   3|  30|0040|      1|
|   2|   5|   7|   6|0012|      1|
|   5|   8|   1|  73|0062|      0|
|   4|   2|   5|   2|0005|      1|
|   5|   2|   4|  12|0002|      1|
|   8|   3|   2|  23|0025|      0|
|   2|   2|   8|  23|0004|      1|
|   5|   5|   4|  12|0002|      1|
|   8|   2|   2|  23|0042|      1|
|   2|   2|   8|  23|0004|      1|
+----+----+----+----+----+-------+

Upvotes: 2

ags29

Reputation: 2696

Here is another way:

from pyspark.sql import functions as F

# Aggregate the filtered frame (df1 from the question) to get the distinct values
df_distinct = df1.groupby('col2').count()

# Join back to the original DF
df = df.join(df_distinct, on='col2', how='left')

# Create the required column: matched rows have a non-null count
df = df.withColumn('new_col', F.when(F.col('count').isNotNull(), F.lit('1')).otherwise(F.lit('0')))

# Drop the extraneous count column
df = df.drop('count')

You do not say how many distinct values there may be in col2, but if the number is sufficiently small, you could use a broadcast join to improve the performance.
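For example (a minimal sketch, assuming df_distinct from the snippet above is small enough to fit in executor memory):

# Broadcast the small distinct-values frame so Spark ships it to every
# executor, turning the shuffle join into a map-side broadcast join
df = df.join(F.broadcast(df_distinct), on='col2', how='left')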

Upvotes: 3

mck

Reputation: 42332

You can build the d_list as a column using collect_set, and use array_contains to check whether col2 is in it:

from pyspark.sql import functions as F, Window

df2 = df.withColumn(
    'new_col', 
    F.array_contains(
        F.collect_set(
            F.when(
                F.substring(F.col('col5'), 3, 1) == '0', 
                F.col('col2')
            )
        ).over(Window.partitionBy(F.lit(1))), 
        F.col('col2')
    ).cast('int')
)

df2.show()
+----+----+----+----+----+-------+
|col1|col2|col3|col4|col5|new_col|
+----+----+----+----+----+-------+
|   3|   2|   3|  30|0040|      1|
|   2|   5|   7|   6|0012|      1|
|   5|   8|   1|  73|0062|      0|
|   4|   2|   5|   2|0005|      1|
|   5|   2|   4|  12|0002|      1|
|   8|   3|   2|  23|0025|      0|
|   2|   2|   8|  23|0004|      1|
|   5|   5|   4|  12|0002|      1|
|   8|   2|   2|  23|0042|      1|
|   2|   2|   8|  23|0004|      1|
+----+----+----+----+----+-------+
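
One thing to keep in mind: Window.partitionBy(F.lit(1)) defines a single global window, so Spark moves every row into one partition to compute the collect_set; on very large data that step can itself become a bottleneck.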

Upvotes: 3
