Aesir

Reputation: 2483

pyspark counting number of nulls per group

I have a dataframe that contains time series data and some categorical data:

| cat | TS1  | TS2  | ... |
|-----|------|------|-----|
| A   | 1    | null | ... |
| A   | 1    | 20   | ... |
| B   | null | null | ... |
| A   | null | null | ... |
| B   | 1    | 100  | ... |

I would like to find out how many null values there are per column per group, so an expected output would look something like:

| cat | TS1 | TS2 |
|-----|-----|-----|
| A   | 1   | 2   |
| B   | 1   | 1   |

Currently I can do this for one of the groups with something like this:

df_null_cats = df.where(df.cat == "A") \
    .where(reduce(lambda x, y: x | y, (col(c).isNull() for c in df.columns))) \
    .select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns])

but I am struggling to get one that would work for the whole dataframe.

Upvotes: 2

Views: 2286

Answers (3)

Mohana B C

Reputation: 5487

You can use groupBy with aggregate functions to get the required output.

from pyspark.sql import *
from pyspark.sql.functions import *

spark = SparkSession.builder.master("local").getOrCreate()

# Sample dataframe
in_values = [("A", 1, None),
             ("A", 1, 20),
             ("B", None, None),
             ("A", None, None),
             ("B", 1, 100)]

in_df = spark.createDataFrame(in_values, "cat string, TS1 int, TS2 int")

columns = in_df.columns
# Ignore the groupBy column; aggregate only the remaining columns
columns.remove("cat")
agg_expression = [sum(when(in_df[x].isNull(), 1).otherwise(0)).alias(x) for x in columns]

in_df.groupby("cat").agg(*agg_expression).show()

+---+---+---+
|cat|TS1|TS2|
+---+---+---+
|  B|  1|  1|
|  A|  1|  2|
+---+---+---+
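
If NaN values should also be counted as missing (the question's snippet checks isnan as well as isNull), the same pattern extends naturally. A minimal sketch, assuming the in_df defined above:

from pyspark.sql.functions import col, isnan, sum, when

# count rows per group where the value is NULL or NaN
agg_expression = [
    sum(when(col(c).isNull() | isnan(c), 1).otherwise(0)).alias(c)
    for c in in_df.columns if c != "cat"
]

in_df.groupby("cat").agg(*agg_expression).show()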

Upvotes: 2

Ric S

Reputation: 9277

@Mohana's answer works, but the aggregation expression still has to be built explicitly for each column.
In the answer below we use applyInPandas (one of Spark's pandas-based APIs) to write a simple aggregation function in Pandas, which is then applied to each group of the PySpark dataframe.

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import *

spark = SparkSession.builder.getOrCreate()


in_values = [("A", 1, None),
             ("A", 1, 20),
             ("B", None, None),
             ("A", None, None),
             ("B", 1, 100)]
df = spark.createDataFrame(in_values, "cat string, TS1 int, TS2 int")


# define output schema: same column names, but we must ensure that the output type is integer
output_schema = StructType(
  [StructField('cat', StringType())] + \
  [StructField(col, IntegerType(), True) for col in [c for c in df.columns if c.startswith('TS')]]
)


# custom Python function to define aggregations in Pandas
def null_count(pdf):
  columns = [c for c in pdf.columns if c.startswith('TS')]
  result = pdf\
    .groupby('cat')[columns]\
    .agg(lambda x: x.isnull().sum())\
    .reset_index()
  return result


# use applyInPandas
df\
  .groupby('cat')\
  .applyInPandas(null_count, output_schema)\
  .show()

+---+---+---+
|cat|TS1|TS2|
+---+---+---+
|  A|  1|  2|
|  B|  1|  1|
+---+---+---+
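
If the value columns don't all share the TS prefix, the same function can be written against every column except the grouping key. A small variation on the snippet above, assuming the same df (value_cols and null_count_all are names introduced here just for illustration):

value_cols = [c for c in df.columns if c != 'cat']

output_schema = StructType(
  [StructField('cat', StringType())] +
  [StructField(c, IntegerType(), True) for c in value_cols]
)

def null_count_all(pdf):
  # count nulls in every non-grouping column of this group
  return pdf.groupby('cat')[value_cols].agg(lambda x: x.isnull().sum()).reset_index()

df.groupby('cat').applyInPandas(null_count_all, output_schema).show()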

Upvotes: 1

pasha701

Reputation: 7207

"Sum" function can be used with condition for null value. On Scala:

import org.apache.spark.sql.functions.{col, sum, when}
import spark.implicits._  // needed for toDF; pre-imported in spark-shell

val df = Seq(
  (Some("A"), Some(1), None),
  (Some("A"), Some(1), Some(20)),
  (Some("B"), None, None),
  (Some("A"), None, None),
  (Some("B"), Some(1), Some(100))
).toDF("cat", "TS1", "TS2")

val aggregatorColumns = df
  .columns
  .tail
  .map(columnName => sum(when(col(columnName).isNull, 1).otherwise(0)).alias(columnName))

df
  .groupBy("cat")
  .agg(
    aggregatorColumns.head, aggregatorColumns.tail: _*
  )
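
To display the result, call .show() on the aggregated DataFrame; the per-group counts should match the outputs shown in the other answers (row order may differ).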

Upvotes: 1
