Weishi Z

Reputation: 1748

Scala Spark handles Double.NaN differently in DataFrame and Dataset

In tests, I'm trying to convert DataFrames/Datasets into sets and compare them, e.g.:

actualResult.collect.toSet should be(expectedResult.collect.toSet)

I noticed the following about Double.NaN values:

  1. In Scala, Double.NaN == Double.NaN returns false.
  2. In Spark, NaN == NaN is true (official doc).
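Point 1 can be reproduced in plain Scala without Spark. A minimal sketch; the object and method names are made up for illustration:

```scala
object NaNComparison {
  // Primitive == follows IEEE 754: NaN compares unequal to everything, itself included
  def primitiveEquals(a: Double, b: Double): Boolean = a == b

  // isNaN is the reliable way to detect a NaN value
  def bothNaN(a: Double, b: Double): Boolean = a.isNaN && b.isNaN

  def main(args: Array[String]): Unit = {
    println(primitiveEquals(Double.NaN, Double.NaN)) // false
    println(bothNaN(Double.NaN, Double.NaN))         // true
  }
}
```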

But I couldn't figure out why DataFrame and Dataset behave differently.

import org.apache.spark.sql.SparkSession

object Main extends App {
  val spark = SparkSession.builder().appName("Example").master("local").getOrCreate()
  import spark.implicits._

  val dataSet = spark.createDataset(Seq(Book("book 1", Double.NaN)))

  // Compare Set(Book(book 1,NaN)) to itself
  println(dataSet.collect.toSet == dataSet.collect.toSet) //false, why?

  // Compare Set([book 1,NaN]) to itself
  println(dataSet.toDF().collect.toSet == dataSet.toDF().collect.toSet) //true, why?
}

case class Book (title: String, price: Double)

Here are my questions; I'd appreciate any insights.

  1. How does this happen in the code? (Where does equals get overridden, etc.?)
  2. Is there a reason behind this design? Is there a better paradigm for asserting on Datasets/DataFrames in tests?

Upvotes: 2

Views: 629

Answers (1)

kavetiraviteja

Reputation: 2208

I have a few points I want to share on this topic.

  1. When you call dataSet.collect.toSet, you collect the data as a Set[Book], so comparing the two results means comparing two sets of Book objects.

Each Book object is compared using the equals method that the Scala compiler generates for the Book case class, and that method compares the Double field with ==. That is why println(dataSet.collect.toSet == dataSet.collect.toSet) returned false: Double.NaN == Double.NaN returns false.
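A minimal plain-Scala sketch of this first case, with no Spark involved. It assumes a Book case class with the same shape as the one in the question; the object name is illustrative only:

```scala
// Same shape as the Book case class in the question
case class Book(title: String, price: Double)

object CaseClassNaN {
  def main(args: Array[String]): Unit = {
    // Two distinct instances, as produced by calling collect twice
    val first = Set(Book("book 1", Double.NaN))
    val second = Set(Book("book 1", Double.NaN))

    // The compiler-generated equals compares the Double field with ==, so NaN != NaN
    println(Book("book 1", Double.NaN) == Book("book 1", Double.NaN)) // false

    // Set equality relies on element equals, so the sets differ as well
    println(first == second) // false
  }
}
```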

  2. When you call dataSet.toDF().collect.toSet, you collect the data as a Set[Row].

When you call toDF, Spark converts each Book into a Row (i.e., it serializes the Book and then deserializes it into a Row with Java-typed fields). During this process, RowEncoder also converts the individual fields.

All primitive fields are converted to Java types by the following code in RowEncoder.scala:

def apply(schema: StructType): ExpressionEncoder[Row] = {
  val cls = classOf[Row]
  val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
  val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
  val deserializer = deserializerFor(schema)
  new ExpressionEncoder[Row](
    schema,
    flat = false,
    serializer.asInstanceOf[CreateNamedStruct].flatten,
    deserializer,
    ClassTag(cls))
}

If you check the equals method in the source code of Double.java and Float.java, comparing two NaN values returns true. That is why comparing the Row objects returns true, and println(dataSet.toDF().collect.toSet == dataSet.toDF().collect.toSet) prints true.

The Javadoc of Double.equals states:

  1. If d1 and d2 both represent Double.NaN, then the equals method returns true, even though Double.NaN==Double.NaN has the value false.
  2. If d1 represents +0.0 while d2 represents -0.0, or vice versa, the equal test has the value false, even though +0.0==-0.0 has the value true.
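The two Javadoc rules can be checked directly, and the same rule can be reused in tests. The sketch below is my own assumption, not a Spark API: the helper names (boxedEquals, sameBookSet) are hypothetical, and it keys each Book on java.lang.Double.doubleToLongBits(price), which is the comparison Double.equals is documented to use.

```scala
object BoxedDoubleEquality {
  // java.lang.Double.equals treats two values as equal when doubleToLongBits
  // gives the same result: NaN equals NaN, while +0.0 differs from -0.0
  def boxedEquals(a: Double, b: Double): Boolean =
    java.lang.Double.valueOf(a).equals(java.lang.Double.valueOf(b))

  // Same shape as the Book case class in the question
  case class Book(title: String, price: Double)

  // Hypothetical test helper (not a Spark API): turn each Book into a key whose
  // equality follows the same doubleToLongBits rule as java.lang.Double.equals
  private def key(b: Book): (String, Long) =
    (b.title, java.lang.Double.doubleToLongBits(b.price))

  // Compares the values as sets, like collect.toSet in the question
  def sameBookSet(xs: Seq[Book], ys: Seq[Book]): Boolean =
    xs.map(key).toSet == ys.map(key).toSet

  def main(args: Array[String]): Unit = {
    println(boxedEquals(Double.NaN, Double.NaN)) // true
    println(boxedEquals(0.0, -0.0))              // false
    println(sameBookSet(Seq(Book("book 1", Double.NaN)), Seq(Book("book 1", Double.NaN)))) // true
  }
}
```

With a Dataset[Book], it might be called as sameBookSet(actual.collect.toSeq, expected.collect.toSeq). Like collect.toSet in the question, it ignores duplicates and row order.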

Sorry for any grammatical mistakes.

Upvotes: 3
