BinderNet

Reputation: 47

Spark columnar performance

I'm a relative beginner to Spark. I have a wide DataFrame (1000 columns) to which I want to add columns based on whether the corresponding column has missing values.

so

+----+          
| A  |
+----+
| 1  |
+----+
|null|     
+----+
| 3  |
+----+

becomes

+----+-------+          
| A  | A_MIS |
+----+-------+
| 1  |   0   |
+----+-------+
|null|   1   |
+----+-------+
| 3  |   0   |
+----+-------+

This is part of a custom ML transformer, but the algorithm should be clear.

// Assumes: import org.apache.spark.sql.functions.{col, when}
override def transform(dataset: org.apache.spark.sql.Dataset[_]): org.apache.spark.sql.DataFrame = {
  var ds = dataset
  dataset.columns.foreach(c => {
    // Check each column for nulls and add an indicator column when any are found.
    if (dataset.filter(col(c).isNull).count() > 0) {
      ds = ds.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0))
    }
  })
  ds.toDF()
}

Loop over the columns; if a column has more than 0 nulls, create a new indicator column for it.

The dataset passed in is cached (using the .cache method) and the relevant config settings are the defaults. This is running on a single laptop for now, and takes on the order of 40 minutes for the 1000 columns even with a minimal number of rows. I thought the problem was due to hitting a database, so I tried with a Parquet file instead, with the same result. Looking at the Jobs UI, it appears to be doing file scans in order to do the count.

Is there a way I can improve this algorithm to get better performance, or tune the caching in some way? Increasing spark.sql.inMemoryColumnarStorage.batchSize just got me an OOM error.

Upvotes: 0

Views: 967

Answers (2)

BinderNet

Reputation: 47

Here's the code that fixes the problem: instead of running a separate count job per column, it computes the non-null counts for all columns in a single aggregation and then decides which indicator columns to add.

override def transform(dataset: Dataset[_]): DataFrame = {
  var ds = dataset
  val rowCount = dataset.count()
  // Count the non-null values of every column in a single aggregation pass.
  val exprs = dataset.columns.map(count(_))
  val colCounts = dataset.agg(exprs.head, exprs.tail: _*).toDF(dataset.columns: _*).first()
  dataset.columns.foreach(c => {
    // Add an indicator only for columns that have some nulls but are not entirely null.
    if (colCounts.getAs[Long](c) > 0 && colCounts.getAs[Long](c) < rowCount) {
      ds = ds.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0))
    }
  })
  ds.toDF()
}

Upvotes: 0

Alper t. Turker

Reputation: 35249

Remove the condition:

if (dataset.filter(col(c).isNull).count() > 0) 

and leave only the internal expression. As it is written, Spark requires one data scan per column (#columns scans in total).
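
For example, a rough Scala sketch of that suggestion (the helper name addMissingIndicators is just illustrative, not part of the question's transformer):

import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, when}

// Add an indicator column for every column, without any per-column count job.
def addMissingIndicators(dataset: Dataset[_]): DataFrame = {
  dataset.columns.foldLeft(dataset.toDF()) { (df, c) =>
    df.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0))
  }
}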

If you want to prune columns, compute the statistics once, as outlined in Count number of non-NaN entries in each column of Spark dataframe with Pyspark, and use a single drop call.
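
A rough sketch of that idea, assuming "prune" means dropping the indicators for columns that turn out to have no nulls (the helper name addAndPruneIndicators is mine, and the linked answer is in PySpark rather than Scala):

import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, count, when}

def addAndPruneIndicators(dataset: Dataset[_]): DataFrame = {
  // Add an indicator for every column first (no per-column count jobs).
  val withIndicators = dataset.columns.foldLeft(dataset.toDF()) { (df, c) =>
    df.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0))
  }

  // Compute the statistics once: count() ignores nulls, so one aggregation
  // gives the non-null count of every column.
  val rowCount = dataset.count()
  val exprs = dataset.columns.map(c => count(col(c)).alias(c))
  val nonNullCounts = dataset.agg(exprs.head, exprs.tail: _*).first()

  // Single drop call for the indicators of columns that had no nulls.
  val toDrop = dataset.columns
    .filter(c => nonNullCounts.getAs[Long](c) == rowCount)
    .map(_ + "_MIS")
  withIndicators.drop(toDrop: _*)
}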

Upvotes: 1
