Marcos Dias

Reputation: 450

How to remove nulls from a pyspark dataframe based on conditions?

Let's say I have a table with the ids of my clients, which streaming platform they are subscribers to and how often they pay for their subscription:

user_info
+------+------------------+---------------------+
| id   | subscription_plan| payment_frequency   |
+------+------------------+---------------------+
| 3004 |   Netflix        | Monthly             |
| 3004 |   Disney +       | Monthly             |
| 3004 |   Netflix        | Null                |
| 3006 |   Star +         | Yearly              |
| 3006 |   Apple TV       | Yearly              |
| 3006 |   Netflix        | Monthly             |
| 3006 |   Star +         | Null                |
| 3009 |   Apple TV       | Null                |
| 3009 |   Star +         | Monthly             |
+------+------------------+---------------------+
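For reference, the user_info table above can be reproduced with something like this (assuming an active SparkSession named spark; the variable name df is just a placeholder):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# the user_info table from above
df = spark.createDataFrame(
    [
        (3004, "Netflix", "Monthly"),
        (3004, "Disney +", "Monthly"),
        (3004, "Netflix", None),
        (3006, "Star +", "Yearly"),
        (3006, "Apple TV", "Yearly"),
        (3006, "Netflix", "Monthly"),
        (3006, "Star +", None),
        (3009, "Apple TV", None),
        (3009, "Star +", "Monthly"),
    ],
    ["id", "subscription_plan", "payment_frequency"],
)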

The problem is that I have some duplicate values, and I need to get rid of the rows that are duplicated and where payment_frequency is null. If payment_frequency is null but the record is not duplicated, that's fine, as with ID 3009 for Apple TV.

I could simply remove all the nulls from the payment_frequency column, but that's not ideal: a null is only worthless to me when it comes from a duplicated id and subscription_plan pair. How do I make sure I only get rid of the nulls that match those conditions?

The result I need:

user_info
+------+------------------+---------------------+
| id   | subscription_plan| payment_frequency   |
+------+------------------+---------------------+
| 3004 |   Netflix        | Monthly             |
| 3004 |   Disney +       | Monthly             |
| 3006 |   Star +         | Yearly              |
| 3006 |   Apple TV       | Yearly              |
| 3006 |   Netflix        | Monthly             |
| 3009 |   Apple TV       | Null                |
| 3009 |   Star +         | Monthly             |
+------+------------------+---------------------+

Thanks

Upvotes: 0

Views: 125

Answers (2)

wwnde

Reputation: 26676

from pyspark.sql import Window
from pyspark.sql.functions import col, count

(df
 .withColumn('x', count('payment_frequency')  # running count of non-null payments per id
             .over(Window.partitionBy('id')
                   .orderBy('id', 'subscription_plan', df.payment_frequency.desc_nulls_last())))
 .where(  # filter
     ((col('x') == 0) & col('payment_frequency').isNull())  # keep nulls seen before any non-null payment
     | ((col('x') >= 0) & col('payment_frequency').isNotNull()))  # keep every non-null row
 .drop('x')  # drop the helper column
 ).show()


+----+-----------------+-----------------+
|  id|payment_frequency|subscription_plan|
+----+-----------------+-----------------+
|3004|          Monthly|         Disney +|
|3004|          Monthly|          Netflix|
|3006|           Yearly|         Apple TV|
|3006|          Monthly|          Netflix|
|3006|           Yearly|           Star +|
|3009|             null|         Apple TV|
|3009|          Monthly|           Star +|
+----+-----------------+-----------------+

Upvotes: 0

M_S

Reputation: 3733

I think you may try something like this.

I am partitioning by ("id", "subscription_plan") and sorting by payment_frequency descending, which moves nulls to the last position within each group. I then keep a null value only when it sits in the first row of its group (which means the group contains no non-null payment_frequency); all other nulls are dropped.

import pyspark.sql.functions as F
from pyspark.sql import Window

data = [
    {"id": 3004, "subscription_plan": "Netflix", "payment_frequency": "Monthly"},
    {"id": 3004, "subscription_plan": "Disney +", "payment_frequency": "Monthly"},
    {"id": 3004, "subscription_plan": "Netflix", "payment_frequency": None},
    {"id": 3006, "subscription_plan": "Star +", "payment_frequency": "Yearly"},
    {"id": 3006, "subscription_plan": "Apple TV", "payment_frequency": "Yearly"},
    {"id": 3006, "subscription_plan": "Netflix", "payment_frequency": "Monthly"},
    {"id": 3006, "subscription_plan": "Star +", "payment_frequency": None},
    {"id": 3009, "subscription_plan": "Apple TV", "payment_frequency": None},
    {"id": 3009, "subscription_plan": "Star +", "payment_frequency": "Monthly"},
]

df = spark.createDataFrame(data)

windowSpec = Window.partitionBy("id", "subscription_plan").orderBy(
    F.col("payment_frequency").desc()
)

dfWithRowNumber = df.withColumn("row_number", F.row_number().over(windowSpec))
dfWithRowNumber.filter(
    F.col("payment_frequency").isNotNull()
    | ((F.col("row_number") == F.lit(1)) & F.col("payment_frequency").isNull())
).drop("row_number").show()

Output:

+----+-----------------+-----------------+
|  id|payment_frequency|subscription_plan|
+----+-----------------+-----------------+
|3004|          Monthly|         Disney +|
|3004|          Monthly|          Netflix|
|3006|           Yearly|         Apple TV|
|3006|          Monthly|          Netflix|
|3006|           Yearly|           Star +|
|3009|             null|         Apple TV|
|3009|          Monthly|           Star +|
+----+-----------------+-----------------+
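A closely related variant of the same idea (just a sketch; the non_null_payments column name is arbitrary): count the non-null payment_frequency values per ("id", "subscription_plan") group and keep a null row only when that count is zero. Note that, unlike row_number, this would keep multiple all-null duplicates of the same group, which does not occur in the sample data.

group_window = Window.partitionBy("id", "subscription_plan")

(
    df.withColumn("non_null_payments", F.count("payment_frequency").over(group_window))
    .filter(
        F.col("payment_frequency").isNotNull()  # keep every row with a payment frequency
        | (F.col("non_null_payments") == 0)     # keep a null only if its group has no non-null value
    )
    .drop("non_null_payments")
    .show()
)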

Upvotes: 1
