Mehdi Ben Hamida

Reputation: 1070

How to detect null column in pyspark

I have a dataframe with some null values, and some columns contain nothing but null values.

>> df.show()
+---+---+---+----+
|  A|  B|  C|   D|
+---+---+---+----+
|1.0|4.0|7.0|null|
|2.0|5.0|7.0|null|
|3.0|6.0|5.0|null|
+---+---+---+----+

In my case, I want to return a list of the column names that are filled entirely with null values. My idea was to detect the constant columns (since the whole column contains the same null value).

This is how I did it:

from pyspark.sql.functions import min, max

nullColumns = [c for c, const in df.select([(min(c) == max(c)).alias(c) for c in df.columns]).first().asDict().items() if const]

but this does not consider null columns as constant; it only works with actual values. How should I do it then?

Upvotes: 5

Views: 25395

Answers (4)

s510

Reputation: 2832

Time-effective (Spark > 3.1), if you have more not-null columns than null columns

Checking whether an entire column is null is generally a time-intensive operation; however, calling isNotNull() is usually faster than calling isNull() on not-null columns. The following trick with limit() can speed up the approach considerably if you expect more not-null columns than null ones.

import time

# col_a is a Null col
st = time.time()
print("Result:", sdf.where(sdf.col_a.isNotNull()).limit(1).count())
print("Time Taken:", time.time() - st)
# Output
# Result: 0
# Time Taken: 215.386

# col_b is a Not Null col
st = time.time()
print("Result:", sdf.where(sdf.col_b.isNotNull()).limit(1).count())
print("Time Taken:", time.time() - st)
# Output
# Result: 1
# Time Taken: 7.857

A result of 0 means the column is entirely null, and 1 means it is not. Now you can loop over all columns in the Spark dataframe, as in the sketch below.
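
For completeness, a minimal sketch of that loop (assuming the same dataframe name sdf as above; timings omitted):

null_cols = []
for c in sdf.columns:
    # limit(1) lets Spark stop scanning as soon as one non-null row is found
    if sdf.where(sdf[c].isNotNull()).limit(1).count() == 0:
        null_cols.append(c)

null_cols  # columns that contain only nulls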

Upvotes: 1

matt

Reputation: 823

How about this? In order to guarantee that a column is all nulls, two properties must be satisfied:

(1) The min value is equal to the max value

(2) The min or max is null

Or, equivalently

(1) The min AND max are both equal to None

Note that without property (2), a column with values [null, 1, null, 1] would be incorrectly reported as all-null, since its min and max are both 1.

import pyspark.sql.functions as F

def get_null_column_names(df):
    column_names = []

    for col_name in df.columns:

        min_ = df.select(F.min(col_name)).first()[0]
        max_ = df.select(F.max(col_name)).first()[0]

        # an all-null column has both min and max equal to None
        if min_ is None and max_ is None:
            column_names.append(col_name)

    return column_names

Here's an example in practice:

>>> rows = [(None, 18, None, None),
            (1, None, None, None),
            (1, 9, 4.0, None),
            (None, 0, 0., None)]

>>> schema = "a: int, b: int, c: float, d:int"

>>> df = spark.createDataFrame(data=rows, schema=schema)

>>> df.show()

+----+----+----+----+
|   a|   b|   c|   d|
+----+----+----+----+
|null|  18|null|null|
|   1|null|null|null|
|   1|   9| 4.0|null|
|null|   0| 0.0|null|
+----+----+----+----+

>>> get_null_column_names(df)
['d']
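
Not part of the answer above, but if the two jobs per column are a concern, the same min/max check can be done in a single aggregation over all columns; get_null_column_names_single_pass is just a hypothetical name for this sketch:

import pyspark.sql.functions as F

def get_null_column_names_single_pass(df):
    # compute the min and max of every column in one job instead of two jobs per column
    agg_row = df.select(
        *[F.min(c).alias("min_" + c) for c in df.columns],
        *[F.max(c).alias("max_" + c) for c in df.columns],
    ).first()
    # a column is all-null when both its min and max come back as None
    return [c for c in df.columns
            if agg_row["min_" + c] is None and agg_row["max_" + c] is None]

>>> get_null_column_names_single_pass(df)
['d']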

Upvotes: 4

desertnaut

Reputation: 60390

One way would be to do it implicitly: select each column, count its NULL values, and then compare this with the total number of rows. With your data, this would be:

spark.version
# u'2.2.0'

from pyspark.sql.functions import col

nullColumns = []
numRows = df.count()
for k in df.columns:
  nullRows = df.where(col(k).isNull()).count()
  if nullRows ==  numRows: # i.e. if ALL values are NULL
    nullColumns.append(k)

nullColumns
# ['D']

But there is a simpler way: it turns out that the function countDistinct, when applied to a column with all NULL values, returns zero (0):

from pyspark.sql.functions import countDistinct

df.agg(countDistinct(df.D).alias('distinct')).collect()
# [Row(distinct=0)]

So the for loop can now be:

nullColumns = []
for k in df.columns:
  if df.agg(countDistinct(df[k])).collect()[0][0] == 0:
    nullColumns.append(k)

nullColumns
# ['D']

UPDATE (after comments): It seems possible to avoid collect in the second solution; since df.agg returns a dataframe with only one row, replacing collect with take(1) will safely do the job:

nullColumns = []
for k in df.columns:
  if df.agg(countDistinct(df[k])).take(1)[0][0] == 0:
    nullColumns.append(k)

nullColumns
# ['D']
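
A further tweak along the same lines (a sketch, not part of the original answer): since agg accepts several expressions at once, all the per-column countDistinct calls can be folded into a single job:

from pyspark.sql.functions import countDistinct

# one aggregation over all columns instead of one job per column
distinctCounts = df.agg(*[countDistinct(df[k]).alias(k) for k in df.columns]).first().asDict()
nullColumns = [k for k, v in distinctCounts.items() if v == 0]

nullColumns
# ['D']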

Upvotes: 6

user8996419

Reputation: 89

Extend the condition to

from pyspark.sql.functions import min, max

((min(c).isNull() & max(c).isNull()) | (min(c) == max(c))).alias(c) 

or use eqNullSafe (PySpark 2.3+):

(min(c).eqNullSafe(max(c))).alias(c) 
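
Plugged back into the list comprehension from the question, this could look roughly as follows (a sketch; note that, like the original approach, it also flags constant non-null columns):

from pyspark.sql.functions import min, max

nullColumns = [c for c, const in df.select(
    [min(c).eqNullSafe(max(c)).alias(c) for c in df.columns]
).first().asDict().items() if const]

nullColumns
# ['D']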

Upvotes: 4
