ZygD

Reputation: 24498

Joining PySpark dataframes with conditional result column

I have these tables:

df1                  df2
+---+------------+   +---+---------+
| id|   many_cols|   | id|criterion|
+---+------------+   +---+---------+
|  1|lots_of_data|   |  1|    false|
|  2|lots_of_data|   |  1|     true|
|  3|lots_of_data|   |  1|     true|
+---+------------+   |  3|    false|
                     +---+---------+

I intend to create an additional column in df1:

+---+------------+------+
| id|   many_cols|result|
+---+------------+------+
|  1|lots_of_data|     1|
|  2|lots_of_data|  null|
|  3|lots_of_data|     0|
+---+------------+------+

result should be 1 if there is at least one corresponding true in df2
result should be 0 if the id is in df2 but none of its rows is true
result should be null if there is no corresponding id in df2
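To pin down the intended semantics, here is a small plain-Python model of the three rules. It is not Spark code, and `expected_result` is a hypothetical helper name. `criteria` stands for the list of df2 `criterion` values found for one id:

```python
from typing import Optional, Sequence


def expected_result(criteria: Sequence[Optional[bool]]) -> Optional[int]:
    """Plain-Python model of the rule for a single id.

    `criteria` holds every df2 `criterion` value for that id;
    an empty sequence means the id does not appear in df2 at all.
    """
    if not criteria:                      # no matching id in df2 -> null
        return None
    if any(c is True for c in criteria):  # at least one true -> 1
        return 1
    return 0                              # id present, but no true -> 0


# Sample data from the question: id -> criterion values in df2
df2_by_id = {1: [False, True, True], 3: [False]}
for id_ in (1, 2, 3):
    print(id_, expected_result(df2_by_id.get(id_, [])))
```

Read literally, the rule maps an id whose only criterion is null to 0, because it has no true value.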

I cannot think of an efficient way to do it. After a join, I can only get the 3rd condition to work:

df = df1.join(df2, 'id', 'full')
df.show()

#  +---+------------+---------+
#  | id|   many_cols|criterion|
#  +---+------------+---------+
#  |  1|lots_of_data|    false|
#  |  1|lots_of_data|     true|
#  |  1|lots_of_data|     true|
#  |  3|lots_of_data|    false|
#  |  2|lots_of_data|     null|
#  +---+------------+---------+

PySpark dataframes are created like this:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

df1cols = ['id', 'many_cols']
df1data = [(1, 'lots_of_data'),
           (2, 'lots_of_data'),
           (3, 'lots_of_data')]
df2cols = ['id', 'criterion']
df2data = [(1, False),
           (1, True),
           (1, True),
           (3, False)]
df1 = spark.createDataFrame(df1data, df1cols)
df2 = spark.createDataFrame(df2data, df2cols)

Upvotes: 2

Views: 558

Answers (4)

ZygD

Reputation: 24498

I had to combine ideas from the proposed answers to get the solution that suited me best.

# The `cond` variable is very useful, here it represents several complex conditions
cond = F.col('criterion') == True
df2_grp = df2.select(
    'id',
    F.when(cond, 1).otherwise(0).alias('c')
).groupBy('id').agg(F.max(F.col('c')).alias('result'))
df = df1.join(df2_grp, 'id', 'left')

df.show()
#+---+------------+------+
#| id|   many_cols|result|
#+---+------------+------+
#|  1|lots_of_data|     1|
#|  3|lots_of_data|     0|
#|  2|lots_of_data|  null|
#+---+------------+------+
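One difference from the answers that cast max(criterion) is worth knowing. With F.when(cond, 1).otherwise(0), a null criterion makes cond null, so that row falls through to otherwise(0). max(criterion), by contrast, skips nulls and returns null when an id has only null criteria. The plain-Python model below contrasts the two for such an id. It is illustrative only, not Spark code, and both function names are hypothetical:

```python
from typing import List, Optional


def when_otherwise_then_max(criteria: List[Optional[bool]]) -> Optional[int]:
    """Models F.when(criterion == True, 1).otherwise(0) followed by max()."""
    # A null comparison is not true, so the row falls through to otherwise(0)
    flags = [1 if c is True else 0 for c in criteria]
    # An id absent from df2 has no rows here; the left join makes it null
    return max(flags) if flags else None


def max_then_cast(criteria: List[Optional[bool]]) -> Optional[int]:
    """Models int(max(criterion)): the aggregate skips nulls."""
    non_null = [c for c in criteria if c is not None]
    return int(max(non_null)) if non_null else None


print(when_otherwise_then_max([None]))
print(max_then_cast([None]))
```

For an id whose only criterion is null, the first approach yields 0 and the second yields null. For the sample data in the question, which has no null criteria, both approaches produce the same result.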

Upvotes: 1

mck

Reputation: 42422

You can try a correlated subquery to get the maximum Boolean from df2, and cast that to an integer.

df1.createOrReplaceTempView('df1') 
df2.createOrReplaceTempView('df2') 

df = spark.sql("""
    select
        df1.*,
        (select int(max(criterion)) from df2 where df1.id = df2.id) as result
    from df1
""")

df.show()
+---+------------+------+
| id|   many_cols|result|
+---+------------+------+
|  1|lots_of_data|     1|
|  3|lots_of_data|     0|
|  2|lots_of_data|  null|
+---+------------+------+
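Conceptually, the scalar subquery is evaluated per df1 row. It aggregates the matching df2 rows and yields null when none match, because an aggregate over zero rows is null. The plain-Python model below captures only those semantics. Spark's optimizer typically rewrites a correlated scalar subquery into a join rather than running it row by row, and df.explain() shows the plan it actually chose. The names below are illustrative:

```python
from typing import List, Optional, Tuple

df1_rows: List[Tuple[int, str]] = [(1, "lots_of_data"), (2, "lots_of_data"), (3, "lots_of_data")]
df2_rows: List[Tuple[int, Optional[bool]]] = [(1, False), (1, True), (1, True), (3, False)]


def scalar_subquery(outer_id: int) -> Optional[int]:
    """Models: (select int(max(criterion)) from df2 where df1.id = df2.id)"""
    # Correlation: keep only df2 rows matching the outer id; max() skips nulls
    matches = [crit for id_, crit in df2_rows if id_ == outer_id and crit is not None]
    # max() over zero rows is null in SQL; Python orders False < True like SQL
    return int(max(matches)) if matches else None


result = [(id_, many_cols, scalar_subquery(id_)) for id_, many_cols in df1_rows]
print(result)
```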

Upvotes: 2

blackbishop

Reputation: 32720

A simple way is to group df2 by id to get the max criterion per id, then join the result with df1; this also reduces the number of rows to join. The max of a boolean column is true if there is at least one corresponding true value:

from pyspark.sql import functions as F

df2_group = df2.groupBy("id").agg(F.max("criterion").alias("criterion"))

result = df1.join(df2_group, ["id"], "left").withColumn(
    "result",
    F.col("criterion").cast("int")
).drop("criterion")

result.show()
#+---+------------+------+
#| id|   many_cols|result|
#+---+------------+------+
#|  1|lots_of_data|     1|
#|  3|lots_of_data|     0|
#|  2|lots_of_data|  null|
#+---+------------+------+
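The same group-then-join idea in plain Python may make the null handling easier to follow. This is a model of the semantics, not Spark code. df2 is aggregated once per id, and SQL's max on booleans orders false before true and skips nulls. A dictionary lookup then stands in for the left join, so ids missing from df2 get null:

```python
from collections import defaultdict
from typing import Dict, List, Optional

df1_ids = [1, 2, 3]
df2_rows = [(1, False), (1, True), (1, True), (3, False)]

# groupBy("id").agg(max("criterion"))
grouped: Dict[int, List[bool]] = defaultdict(list)
for id_, crit in df2_rows:
    vals = grouped[id_]       # every id present in df2 forms a group
    if crit is not None:      # SQL aggregates skip nulls
        vals.append(crit)
# Python orders False < True, matching SQL's boolean ordering
df2_group: Dict[int, Optional[bool]] = {
    id_: (max(vals) if vals else None) for id_, vals in grouped.items()
}

# Left join + cast("int"): ids absent from df2_group get None (null)
result: Dict[int, Optional[int]] = {}
for id_ in df1_ids:
    crit = df2_group.get(id_)
    result[id_] = None if crit is None else int(crit)

print(result)
```

Because df2 is collapsed to at most one row per id before the join, df1 rows are never duplicated by the join, so no second aggregation is needed afterwards.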

Upvotes: 2

kites

Reputation: 1405

Check out this solution. After joining, you can apply multiple condition checks based on your requirements, assign the value accordingly with a when clause, and then take the max of result, grouping by id and the other columns. If you partition by id only, you can also use a window function to calculate the max of result.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

df1cols = ['id', 'many_cols']
df1data = [(1, 'lots_of_data'),
           (2, 'lots_of_data'),
           (3, 'lots_of_data')]
df2cols = ['id', 'criterion']
df2data = [(1, False),
           (1, True),
           (1, True),
           (3, False)]
df1 = spark.createDataFrame(df1data, df1cols)
df2 = spark.createDataFrame(df2data, df2cols)

df2_mod = df2.withColumnRenamed("id", "id_2")

df3 = df1.join(df2_mod, on=df1.id == df2_mod.id_2, how='left')

cond1 = (F.col("id") == F.col("id_2")) & (F.col("criterion") == 1)
cond2 = (F.col("id") == F.col("id_2")) & (F.col("criterion") == 0)
cond3 = F.col("id_2").isNull()

df3.select("id", "many_cols", F.when(cond1, 1).when(cond2, 0).when(cond3, F.lit(None)).alias("result"))\
    .groupBy("id", "many_cols").agg(F.max(F.col("result")).alias("result")).orderBy("id").show()

Result:

+---+------------+------+
| id|   many_cols|result|
+---+------------+------+
|  1|lots_of_data|     1|
|  2|lots_of_data|  null|
|  3|lots_of_data|     0|
+---+------------+------+

Using a window function:

w = Window.partitionBy("id")

df3.select("id", "many_cols", F.when(cond1, 1).when(cond2, 0).when(cond3, F.lit(None)).alias("result"))\
    .select("id", "many_cols", F.max("result").over(w).alias("result")).drop_duplicates().show()
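The window variant keeps every joined row and stamps each one with the max result of its id partition. It therefore relies on drop_duplicates() to collapse the copies, whereas groupBy collapses them directly. Below is a plain-Python model of that behavior. It is illustrative only, not Spark code, and it starts from the per-row result values that the when chain produces for the sample data:

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# (id, many_cols, result) after the F.when(...) chain: one row per join match
rows: List[Tuple[int, str, Optional[int]]] = [
    (1, "lots_of_data", 0),
    (1, "lots_of_data", 1),
    (1, "lots_of_data", 1),
    (2, "lots_of_data", None),
    (3, "lots_of_data", 0),
]

# F.max("result").over(Window.partitionBy("id")): one max per partition
partitions: Dict[int, List[int]] = defaultdict(list)
for id_, _, res in rows:
    bucket = partitions[id_]  # every id forms a partition
    if res is not None:       # max() skips nulls
        bucket.append(res)
partition_max = {id_: (max(b) if b else None) for id_, b in partitions.items()}

# The window keeps every row, each stamped with its partition's max...
windowed = [(id_, cols, partition_max[id_]) for id_, cols, _ in rows]
# ...so duplicates must be dropped to get one row per id (order preserved)
deduped = list(dict.fromkeys(windowed))
print(deduped)
```

This dedup step only yields one row per id because many_cols is identical across an id's joined rows, which holds here since those values all come from the single matching df1 row.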

Upvotes: 1
