HHH

Reputation: 6465

How to find columns with many nulls in Spark/Scala

I have a DataFrame in Spark/Scala that has hundreds of columns. Many of the columns contain a lot of null values. I'd like to find the columns that are more than 90% null and then drop them from my DataFrame. How can I do that in Spark/Scala?

Upvotes: 0

Views: 2851

Answers (2)

emesday

Reputation: 6186

org.apache.spark.sql.functions.array and a udf will help.

import spark.implicits._
import org.apache.spark.sql.functions._

// the explicit type parameter is needed so the null literals are typed as String
val df = sc.parallelize[(String, String, String, String, String, String, String, String, String, String)](
  Seq(
    ("a", null, null, null, null, null, null, null, null, null), // 90%
    ("b", null, null, null, null, null, null, null, null, ""), // 80%
    ("c", null, null, null, null, null, null, null, "", "") // 70%
  )
).toDF("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10")

// true if at least 90% of the values in a row are null
val check_90_null = udf { xs: Seq[String] =>
  xs.count(_ == null) >= (xs.length * 0.9)
}

// all columns as array
val columns = array(df.columns.map(col): _*)

// drop the rows that are at least 90% null
df.where(not(check_90_null(columns)))
  .show()

which shows:

+---+----+----+----+----+----+----+----+----+---+
| c1|  c2|  c3|  c4|  c5|  c6|  c7|  c8|  c9|c10|
+---+----+----+----+----+----+----+----+----+---+
|  b|null|null|null|null|null|null|null|null|   |
|  c|null|null|null|null|null|null|null|    |   |
+---+----+----+----+----+----+----+----+----+---+

The row starting with "a" is excluded, since 9 of its 10 values are null. (Note that this filters out rows that are mostly null rather than dropping columns.)
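For reference, the same row filter can be written without a udf, using only built-in column functions so Catalyst can optimize the expression. This is a sketch along the same lines, not part of the original answer:

import org.apache.spark.sql.functions._

// build a per-row null count by summing one 0/1 indicator per source column
val nullCount = df.columns
  .map(c => when(col(c).isNull, 1).otherwise(0))
  .reduce(_ + _)

// keep rows that are less than 90% null
df.where(nullCount < lit(df.columns.length * 0.9)).show()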

Upvotes: 2

akuiper

Reputation: 214937

Suppose you have a data frame like this:

// use None (rather than null) so the Option element types are inferred cleanly
val df = Seq((Some(1.0), Some(2), Some("a")),
             (None, Some(3), None),
             (Some(2.0), Some(4), Some("b")),
             (None, None, Some("c"))
            ).toDF("A", "B", "C")

df.show
+----+----+----+
|   A|   B|   C|
+----+----+----+
| 1.0|   2|   a|
|null|   3|null|
| 2.0|   4|   b|
|null|null|   c|
+----+----+----+

Count the nulls in each column using the agg function, then keep only the columns whose null count is at or below a threshold (set to 1 here):

val null_thresh = 1                 // absolute null count; for a 90% threshold,
                                    // use val null_thresh = df.count() * 0.9

val to_keep = df.columns.filter(
    c => df.agg(
        sum(when(df(c).isNull, 1).otherwise(0)).alias(c)
    ).first().getLong(0) <= null_thresh
)

df.select(to_keep.head, to_keep.tail: _*).show

And you get:

+----+----+
|   B|   C|
+----+----+
|   2|   a|
|   3|null|
|   4|   b|
|null|   c|
+----+----+
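
One caveat worth noting (an addition, not part of the original answer): the filter above launches a separate Spark job for each column. All the null counts can be gathered in a single pass instead, for example:

// compute every column's null count in one aggregation
val nullCounts = df.select(
  df.columns.map(c => sum(when(col(c).isNull, 1).otherwise(0)).alias(c)): _*
).first()

val to_keep = df.columns.filter(c => nullCounts.getAs[Long](c) <= null_thresh)
df.select(to_keep.head, to_keep.tail: _*).show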

Upvotes: 2
