Zeynep Akkalyoncu

Reputation: 875

Spark Scala SQL: Take average of non-null columns

How do I take the average of the columns in an array cols over the non-null values only, in a DataFrame df? I can compute the average over all columns, but it gives null whenever any of the values is null.

val cols = Array($"col1", $"col2", $"col3")
df.withColumn("avgCols", cols.foldLeft(lit(0)){(x, y) => x + y} / cols.length)

I don't want to use na.fill because I want to preserve the true average.
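
To illustrate with made-up values (assuming spark.implicits._ is in scope):

val sample = Seq((Option(2.0), Option.empty[Double], Option(4.0))).toDF("col1", "col2", "col3")
// 2.0 + null + 4.0 evaluates to null, so the fold above gives null for this row,
// while na.fill(0) would give (2.0 + 0 + 4.0) / 3 = 2.0 instead of the true average (2.0 + 4.0) / 2 = 3.0.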

Upvotes: 1

Views: 374

Answers (2)

ZygD

Reputation: 24386

The aggregate function can do it without a UDF, by carrying a running sum and a count of the non-null values:

import org.apache.spark.sql.functions._

val cols = Array($"col1", $"col2", $"col3")
df.withColumn(
    "avgCols",
    aggregate(
        array(cols: _*),  // aggregate expects a single array column, not an Array[Column]
        struct(lit(0.0).alias("sum"), lit(0).alias("count")),
        (acc, x) => struct(
            (acc("sum") + coalesce(x, lit(0))).alias("sum"),
            // count a value only when it is non-null (a plain cast to boolean would drop zeros)
            (acc("count") + when(x.isNotNull, 1).otherwise(0)).alias("count")),
        s => s("sum") / s("count")
    )
)
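
As a quick sanity check on hypothetical data (assuming spark.implicits._ is in scope and a Spark version where the aggregate DataFrame function exists):

val df = Seq(
  (Option(1.0), Option(3.0), Option.empty[Double]),
  (Option(2.0), Option.empty[Double], Option.empty[Double])
).toDF("col1", "col2", "col3")
// With the withColumn above, avgCols should be 2.0 for both rows:
// only the non-null values contribute to the sum and to the count.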

Upvotes: 0

Kobotan

Reputation: 112

I guess you can do something like this:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.{col, struct, udf}

    val cols = Array("col1", "col2", "col3")
    def countAvg =
      udf((data: Row) => {
        val notNullIndices = cols.indices.filterNot(i => data.isNullAt(i))
        // average over the non-null values only; assumes double-typed columns
        notNullIndices.map(i => data.getDouble(i)).sum / notNullIndices.length
      })

    df.withColumn("seqNull", struct(cols.map(col): _*))
      .withColumn("avg", countAvg(col("seqNull")))
      .show(truncate = false)

But be careful: here the average is computed only over the non-null elements.

If you need exactly the behaviour of your code (dividing by the total number of columns):

    val cols = Array("col1", "col2", "col3")
    def countAvg =
      udf((data: Row) => {
        val notNullIndices = cols.indices.filterNot(i => data.isNullAt(i))
        // divide by the total number of columns, not just the non-null ones
        notNullIndices.map(i => data.getDouble(i)).sum / cols.length
      })

    df.withColumn("seqNull", struct(cols.map(col): _*))
      .withColumn("avg", countAvg(col("seqNull")))
      .show(truncate = false)
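
For example, on a row with the hypothetical values col1 = 2.0, col2 = null, col3 = 4.0, the first variant gives (2.0 + 4.0) / 2 = 3.0, while this one gives (2.0 + 4.0) / 3 = 2.0, i.e. the same result the fold in your question would produce if nulls no longer poisoned the sum.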

Upvotes: 1
