Raphael Roth

Reputation: 27373

Spark: replace all NaNs with null in DataFrame API

I have a dataframe with many double (and/or float) columns that contain NaNs. I want to replace all NaNs (i.e. Float.NaN and Double.NaN) with null.

For a single column, e.g. x, I can do this:

val newDf = df.withColumn("x", when($"x".isNaN,lit(null)).otherwise($"x"))

This works, but I'd like to do it for all columns at once. I recently discovered DataFrameNaFunctions (df.na) and its fill method, which sounds like exactly what I need. Unfortunately, I failed to make it work. fill should replace all NaNs and nulls with a given value, so I do:

df.na.fill(null.asInstanceOf[java.lang.Double]).show

which gives me a NullPointerException.

There is also a promising replace method, but I can't even compile the code:

df.na.replace("x", Map(java.lang.Double.NaN -> null.asInstanceOf[java.lang.Double])).show

Strangely, this gives me:

Error:(57, 34) type mismatch;
 found   : scala.collection.immutable.Map[scala.Double,java.lang.Double]
 required: Map[Any,Any]
Note: Double <: Any, but trait Map is invariant in type A.
You may wish to investigate a wildcard type such as `_ <: Any`. (SLS 3.2.10)
    df.na.replace("x", Map(java.lang.Double.NaN -> null.asInstanceOf[java.lang.Double])).show
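The type mismatch itself comes from Scala's Map being invariant in both type parameters: the literal infers Map[scala.Double, java.lang.Double], which is not a Map[Any, Any]. A minimal sketch of the annotation that makes the literal typecheck (this only addresses the compile error; whether na.replace then honours a null replacement at runtime depends on the Spark version):

```scala
// Map is invariant, so spell out the element types explicitly instead of
// letting the compiler infer Map[scala.Double, java.lang.Double].
val replacements: Map[Any, Any] = Map[Any, Any](java.lang.Double.NaN -> null)

// The literal now has the Map[T, T] shape (with T = Any) that
// df.na.replace("x", replacements) expects.
println(replacements.size)  // prints 1
```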

Upvotes: 3

Views: 10074

Answers (2)

Suraj Bansal

Reputation: 11

To replace all NaNs with a given value in a Spark DataFrame using the PySpark API, you can do the following:

col_list = ["column1", "column2"]
df = df.na.fill(replace_by_value, col_list)

Upvotes: 1

himanshuIIITian

Reputation: 6085

To replace all NaN(s) with null in Spark, you just have to create a Map of replacement values for every column, like this:

val map = df.columns.map((_, "null")).toMap

Then you can use fill to replace NaN(s) with null values:

df.na.fill(map)

For example:

scala> val df = List((Float.NaN, Double.NaN), (1f, 0d)).toDF("x", "y")
df: org.apache.spark.sql.DataFrame = [x: float, y: double]

scala> df.show
+---+---+
|  x|  y|
+---+---+
|NaN|NaN|
|1.0|0.0|
+---+---+

scala> val map = df.columns.map((_, "null")).toMap
map: scala.collection.immutable.Map[String,String] = Map(x -> null, y -> null)

scala> df.na.fill(map).printSchema
root
 |-- x: float (nullable = true)
 |-- y: double (nullable = true)


scala> df.na.fill(map).show
+----+----+
|   x|   y|
+----+----+
|null|null|
| 1.0| 0.0|
+----+----+
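An alternative, shown here as a sketch and assuming every column is numeric (isNaN is not defined for other types), is to rebuild each column with when/otherwise, generalising the single-column expression from the question:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, when}

val spark = SparkSession.builder.master("local[1]").appName("nan-to-null").getOrCreate()
import spark.implicits._

val df = List((Float.NaN, Double.NaN), (1f, 0d)).toDF("x", "y")

// Rebuild every column: NaN becomes null, everything else passes through.
// The alias keeps the original column name.
val cleaned = df.select(df.columns.map(c =>
  when(col(c).isNaN, lit(null)).otherwise(col(c)).alias(c)
): _*)

cleaned.show()
```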

I hope this helps!

Upvotes: 5
