David

Reputation: 89

Determine if pyspark DataFrame row value is present in other columns

I'm working with a DataFrame in PySpark and need to evaluate, row by row, whether a value in one column is also present in other columns of the DataFrame. As an example, given this DataFrame:

df:

+---------+--------------+-------+-------+-------+
|Subject  |SubjectTotal  |TypeA  |TypeB  |TypeC  |
+---------+--------------+-------+-------+-------+
|Subject1 |10            |5      |3      |2      |
+---------+--------------+-------+-------+-------+
|Subject2 |15            |0      |15     |0      |
+---------+--------------+-------+-------+-------+
|Subject3 |5             |0      |0      |5      |
+---------+--------------+-------+-------+-------+
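
For reproducibility, this input can be built like this (assuming an active SparkSession named spark):

df = spark.createDataFrame(
    [
        ("Subject1", 10, 5, 3, 2),
        ("Subject2", 15, 0, 15, 0),
        ("Subject3", 5, 0, 0, 5),
    ],
    ["Subject", "SubjectTotal", "TypeA", "TypeB", "TypeC"],
)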

As output, I need to determine which Type column accounts for 100% of SubjectTotal. So my output would look like this:

df_output:

+---------+--------------+
|Subject  |Type          |
+---------+--------------+
|Subject2 |TypeB         |
+---------+--------------+
|Subject3 |TypeC         |
+---------+--------------+

Is it even possible?

Thanks!

Upvotes: 0

Views: 1480

Answers (3)

blackbishop

Reputation: 32640

You can use a when expression inside a list comprehension over all the TypeX columns, then coalesce the resulting list of expressions:

from pyspark.sql import functions as F

# For each Type column, when() returns the column name if its value equals
# SubjectTotal (and null otherwise); coalesce picks the first match per row.
df1 = df.select(
    F.col("Subject"),
    F.coalesce(
        *[F.when(F.col(c) == F.col("SubjectTotal"), F.lit(c)) for c in df.columns[2:]]
    ).alias("Type")
).filter("Type is not null")

df1.show()
#+--------+-----+
#| Subject| Type|
#+--------+-----+
#|Subject2|TypeB|
#|Subject3|TypeC|
#+--------+-----+
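
If the Type columns are not guaranteed to be the last ones in the schema, a small variation (my assumption, not part of the original answer) is to select them by name prefix rather than by position:

# Hypothetical variant: pick the candidate columns by name prefix
# instead of relying on their position in df.columns.
type_cols = [c for c in df.columns if c.startswith("Type")]

df1 = df.select(
    F.col("Subject"),
    F.coalesce(
        *[F.when(F.col(c) == F.col("SubjectTotal"), F.lit(c)) for c in type_cols]
    ).alias("Type")
).filter("Type is not null")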

Upvotes: 1

mck

Reputation: 42332

You can unpivot the DataFrame using stack and filter the rows where SubjectTotal equals the value in the type column:

# stack(3, ...) unpivots the three Type columns into (type, val) rows.
df2 = df.selectExpr(
    'Subject',
    'SubjectTotal',
    "stack(3, 'TypeA', TypeA, 'TypeB', TypeB, 'TypeC', TypeC) as (type, val)"
).filter('SubjectTotal = val').select('Subject', 'type')

df2.show()
+--------+-----+
| Subject| type|
+--------+-----+
|Subject2|TypeB|
|Subject3|TypeC|
+--------+-----+
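
If the set of Type columns can change, the stack expression can also be built dynamically rather than hardcoded; a sketch, assuming every column to unpivot starts with 'Type':

# Build the stack() arguments from the column names so newly added
# Type columns are picked up automatically.
type_cols = [c for c in df.columns if c.startswith('Type')]
stack_expr = ", ".join(f"'{c}', {c}" for c in type_cols)

df2 = df.selectExpr(
    'Subject',
    'SubjectTotal',
    f"stack({len(type_cols)}, {stack_expr}) as (type, val)"
).filter('SubjectTotal = val').select('Subject', 'type')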

Upvotes: 1

Mageswaran

Reputation: 450

You can try the when().otherwise() PySpark SQL function, or a CASE statement in SQL:

import pyspark.sql.functions as F

df = spark.createDataFrame(
    [
        ("Subject1", 10, 5, 3, 2),
        ("Subject2", 15, 0, 15, 0),
        ("Subject3", 5, 0, 0, 5),
    ],
    ("subject", "subjectTotal", "TypeA", "TypeB", "TypeC"),
)
df.show()

+--------+------------+-----+-----+-----+
| subject|subjectTotal|TypeA|TypeB|TypeC|
+--------+------------+-----+-----+-----+
|Subject1|          10|    5|    3|    2|
|Subject2|          15|    0|   15|    0|
|Subject3|           5|    0|    0|    5|
+--------+------------+-----+-----+-----+

df.withColumn(
    "Type",
    F.when(F.col("subjectTotal") == F.col("TypeA"), "TypeA")
     .when(F.col("subjectTotal") == F.col("TypeB"), "TypeB")
     .when(F.col("subjectTotal") == F.col("TypeC"), "TypeC")
     .otherwise(None)
).show()

+--------+------------+-----+-----+-----+-----+
| subject|subjectTotal|TypeA|TypeB|TypeC| Type|
+--------+------------+-----+-----+-----+-----+
|Subject1|          10|    5|    3|    2| null|
|Subject2|          15|    0|   15|    0|TypeB|
|Subject3|           5|    0|    0|    5|TypeC|
+--------+------------+-----+-----+-----+-----+
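
The SQL CASE form mentioned above could look like this (the temp view name subjects is my own choice, not from the original answer):

# Register the DataFrame as a temp view so it can be queried with SQL.
df.createOrReplaceTempView("subjects")

spark.sql("""
    SELECT subject,
           CASE
               WHEN subjectTotal = TypeA THEN 'TypeA'
               WHEN subjectTotal = TypeB THEN 'TypeB'
               WHEN subjectTotal = TypeC THEN 'TypeC'
           END AS Type
    FROM subjects
""").show()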

Upvotes: 2
