Arturo Gatto

Reputation: 53

Scala - Spark: in a DataFrame, retrieve for each row the column name that has the max value

I have a DataFrame:

name     column1  column2  column3  column4
first    2        1        2.1      5.4
test     1.5      0.5      0.9      3.7
choose   7        2.9      9.1      2.5

I want a new DataFrame with a column that contains, for each row, the name of the column with the max value:

| name   | max_column |
|--------|------------|
| first  | column4    |
| test   | column4    |
| choose | column3    |

Thank you very much for your support.

Upvotes: 3

Views: 8637

Answers (3)

Arturo Gatto

Reputation: 53

I want to post my final solution:

// UDF: zip column names with their values, keep only the Double ones,
// and return the name of the column holding the maximum value
val maxValAsMap = udf((keys: Seq[String], values: Seq[Any]) => {
  val valueMap: Map[String, Double] = (keys zip values)
    .filter(_._2.isInstanceOf[Double])
    .map { case (x, y) => (x, y.asInstanceOf[Double]) }
    .toMap

  if (valueMap.isEmpty) "not computed" else valueMap.maxBy(_._2)._1
})

// keys and values are array columns (the column names and the column values);
// "cookie_id" is the id column in my real data ("name" in the example above)
val finalDf = originalDf
  .withColumn("max_column", maxValAsMap(keys, values))
  .select("cookie_id", "max_column")

It works very fast.
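For reference, keys and values above are array columns. Here is a minimal sketch of how they could be built for the example schema from the question (the valueCols list and the result variable are illustrative, not part of the original code):

import org.apache.spark.sql.functions.{array, col, lit}

// value columns taken from the example in the question
val valueCols = Seq("column1", "column2", "column3", "column4")

// an array of literal column names and an array of the corresponding values;
// Spark hands array columns to the UDF as Seq
val keys   = array(valueCols.map(lit): _*)
val values = array(valueCols.map(col): _*)

val result = originalDf
  .withColumn("max_column", maxValAsMap(keys, values))
  .select("name", "max_column")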

Upvotes: 0

Wilmerton

Reputation: 1538

You can get the job done by making a detour to an RDD and using getValuesMap.

// assumes spark.implicits._ is in scope for .toDF
val dfIn = Seq(
  ("first", 2.0, 1.0, 2.1, 5.4),
  ("test", 1.5, 0.5, 0.9, 3.7),
  ("choose", 7.0, 2.9, 9.1, 2.5)
).toDF("name", "column1", "column2", "column3", "column4")

The simple solution is

val dfOut = dfIn.rdd
  .map(r => (
       r.getString(0),
       // getValuesMap returns a Map(columnName -> value) for the listed fields
       r.getValuesMap[Double](r.schema.fieldNames.filter(_ != "name"))
     ))
  .map { case (n, m) => (n, m.maxBy(_._2)._1) }
  .toDF("name", "max_column")

But if you want to keep all the columns from the original DataFrame (as in Scala/Spark dataframes: find the column name corresponding to the max), you have to play a bit with merging rows and extending the schema:

import org.apache.spark.sql.types.{StructType, StructField, StringType}
import org.apache.spark.sql.Row

val dfOut = sqlContext.createDataFrame(
  dfIn.rdd
    .map(r => (r, r.getValuesMap[Double](r.schema.fieldNames.drop(1))))
    // append the name of the winning column to the original row
    .map { case (r, m) => Row.merge(r, Row(m.maxBy(_._2)._1)) },
  // extend the original schema with the new string column
  dfIn.schema.add(StructField("max_column", StringType))
)
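The result keeps every original column plus the new max_column string column: Row.merge simply appends the extra value to each row, and schema.add leaves the new field nullable by default.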

Upvotes: 4

mrsrinivas

Reputation: 35444

There might be a better way of writing the UDF, but this could be a working solution:

val spark: SparkSession = SparkSession.builder.master("local").getOrCreate

// implicits for magic functions like .toDF
import spark.implicits._

import org.apache.spark.sql.functions.udf

// We have to hard-code the number of params, as UDFs don't support a variable
// number of args (a UDF-free sketch that avoids this limitation follows the output below)
val maxval = udf((c1: Double, c2: Double, c3: Double, c4: Double) =>
  if(c1 >= c2 && c1 >= c3 && c1 >= c4)
    "column1"
  else if(c2 >= c1 && c2 >= c3 && c2 >= c4)
    "column2"
  else if(c3 >= c1 && c3 >= c2 && c3 >= c4)
    "column3"
  else
    "column4"
)

// create schema class
case class Record(name: String,
                  column1: Double,
                  column2: Double,
                  column3: Double,
                  column4: Double)

val df = Seq(
  Record("first", 2.0, 1.0, 2.1, 5.4),
  Record("test", 1.5, 0.5, 0.9, 3.7),
  Record("choose", 7.0, 2.9, 9.1, 2.5)
).toDF()

df.withColumn("max_column", maxval($"column1", $"column2", $"column3", $"column4"))
  .select("name", "max_column").show

Output

+------+----------+
|  name|max_column|
+------+----------+
| first|   column4|
|  test|   column4|
|choose|   column3|
+------+----------+
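A UDF-free sketch for the same four columns, using greatest and chained when expressions (the cols list and the maxColumn name are just illustrative, and ties resolve to the last column in the list):

import org.apache.spark.sql.functions.{col, greatest, lit, when}

val cols = Seq("column1", "column2", "column3", "column4")

// row-wise maximum across the value columns
val rowMax = greatest(cols.map(col): _*)

// chain when(...) expressions: the column equal to the row maximum wins
val maxColumn = cols.foldLeft(lit(null).cast("string")) { (acc, c) =>
  when(col(c) === rowMax, lit(c)).otherwise(acc)
}

df.withColumn("max_column", maxColumn)
  .select("name", "max_column")
  .show()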

Upvotes: 4
