avloss

Reputation: 2646

Comparison of a `float` to `np.nan` in Spark Dataframe

Is this expected behaviour? I thought about raising an issue with Spark, but this is such basic functionality that it's hard to imagine there's a bug here. What am I missing?

Python

>>> import numpy as np

>>> np.nan < 0.0
False

>>> np.nan > 0.0
False
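The same IEEE 754 rule that plain Python floats follow covers every comparison operator, not just `<` and `>`: any ordered comparison involving NaN is False, and only `!=` is True. A minimal check using the standard `operator` module:

```python
import operator

nan = float("nan")

# Apply each comparison operator to (NaN, 0.0) under Python's IEEE 754 float semantics
comparisons = {
    "<": operator.lt,
    "<=": operator.le,
    ">": operator.gt,
    ">=": operator.ge,
    "==": operator.eq,
    "!=": operator.ne,
}
results = {symbol: op(nan, 0.0) for symbol, op in comparisons.items()}
print(results)
# {'<': False, '<=': False, '>': False, '>=': False, '==': False, '!=': True}
```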

PySpark

import numpy as np
from pyspark.sql.functions import col

# `spark` is an active SparkSession, e.g. the one the pyspark shell provides
df = spark.createDataFrame([(np.nan, 0.0), (0.0, np.nan)])
df.show()
#+---+---+
#| _1| _2|
#+---+---+
#|NaN|0.0|
#|0.0|NaN|
#+---+---+

df.printSchema()
#root
# |-- _1: double (nullable = true)
# |-- _2: double (nullable = true)

df.select(col("_1") > col("_2")).show()
#+---------+
#|(_1 > _2)|
#+---------+
#|     true|
#|    false|
#+---------+

Upvotes: 8

Views: 3634

Answers (1)

user10938362

Reputation: 4151

That is both expected and documented behavior. To quote the NaN Semantics section of the official Spark SQL Guide:

There is special handling for not-a-number (NaN) when dealing with float or double types that does not exactly match standard floating point semantics. Specifically:

  • NaN = NaN returns true.
  • In aggregations, all NaN values are grouped together.
  • NaN is treated as a normal value in join keys.
  • NaN values go last when in ascending order, larger than any other numeric value.
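To make the ordering and grouping rules above concrete, here is a plain-Python sketch that emulates them without Spark. The helper names `spark_order_key`, `spark_gt` and `spark_group_key` are hypothetical and only mimic the documented semantics; this is not how Spark implements them internally. With this ordering, `NaN > 0.0` is true, which reproduces the `[true, false]` result from the question.

```python
import math
from collections import Counter


def spark_order_key(x: float) -> tuple:
    # Hypothetical helper: emulates the documented rule that NaN is larger than
    # any other numeric value (so it sorts last in ascending order).
    # Ordinary numbers map to (0, x); every NaN maps to (1, 0.0), so NaN itself
    # is never compared and the key gives a total order.
    return (1, 0.0) if math.isnan(x) else (0, x)


def spark_gt(a: float, b: float) -> bool:
    # Hypothetical helper: "a > b" under Spark's NaN ordering instead of IEEE 754
    return spark_order_key(a) > spark_order_key(b)


def spark_group_key(x: float):
    # Hypothetical helper: maps every NaN to one key, mirroring
    # "In aggregations, all NaN values are grouped together"
    return "NaN" if math.isnan(x) else x


nan = float("nan")

# Rows from the question: (NaN, 0.0) and (0.0, NaN)
rows = [(nan, 0.0), (0.0, nan)]
gt_results = [spark_gt(a, b) for a, b in rows]
print(gt_results)  # [True, False], matching the Spark output in the question

sorted_values = sorted([nan, 3.0, -1.0, float("inf")], key=spark_order_key)
print(sorted_values)  # [-1.0, 3.0, inf, nan]: NaN goes last, even after inf

# Two distinct NaN objects would otherwise be separate dict keys, since NaN != NaN
group_counts = Counter(spark_group_key(v) for v in [float("nan"), float("nan"), 1.0])
print(group_counts)  # Counter({'NaN': 2, 1.0: 1})
```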

As you can see, ordering is not the only difference compared to Python's NaN. In particular, Spark considers NaNs equal:

spark.sql("""
    WITH table AS (SELECT CAST('NaN' AS float) AS x, CAST('NaN' AS float) AS y)
    SELECT x = y, x != y FROM table
""").show()
+-------+-------------+
|(x = y)|(NOT (x = y))|
+-------+-------------+
|   true|        false|
+-------+-------------+

while plain Python

>>> float("NaN") == float("NaN"), float("NaN") != float("NaN")
(False, True)

and NumPy

>>> np.nan == np.nan, np.nan != np.nan
(False, True)

don't.

See the `eqNullSafe` docstring for additional examples.
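`eqNullSafe` (`<=>` in SQL) differs from `=` only in how NULL is treated: `=` returns NULL when either side is NULL, while `<=>` never returns NULL and treats two NULLs as equal. The plain-Python sketch below emulates both on top of Spark's NaN-equals-NaN rule. `spark_eq` and `spark_eq_null_safe` are hypothetical names, `None` stands in for SQL NULL, and this only mimics the documented behavior rather than Spark's implementation.

```python
import math
from typing import Optional


def spark_eq(a: Optional[float], b: Optional[float]) -> Optional[bool]:
    # Hypothetical helper emulating Spark SQL "=" on doubles:
    # NULL on either side yields NULL (None); NaN = NaN is true.
    if a is None or b is None:
        return None
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)
    return a == b


def spark_eq_null_safe(a: Optional[float], b: Optional[float]) -> bool:
    # Hypothetical helper emulating eqNullSafe / "<=>": never returns NULL;
    # two NULLs are equal, a NULL and a non-NULL are not.
    if a is None or b is None:
        return a is None and b is None
    return bool(spark_eq(a, b))


nan = float("nan")
print(nan == nan)                      # False (IEEE 754 / Python)
print(spark_eq(nan, nan))              # True  (Spark NaN semantics)
print(spark_eq(None, 1.0))             # None  (NULL propagates through "=")
print(spark_eq_null_safe(None, None))  # True
print(spark_eq_null_safe(None, nan))   # False
```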

So to get the desired result, you'll have to check for NaNs explicitly:

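from pyspark.sql.functions import col, isnan, when

# Return False whenever either side is NaN (as Python/NumPy do); otherwise compare normally
df.select(
    when(isnan("_1") | isnan("_2"), False).otherwise(col("_1") > col("_2"))
).show()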

Upvotes: 9
