Mahdi Ghelichi

Reputation: 1160

How to get a subset of columns based on row values in Spark Dataframe?

I have the following dataframe in Spark (it has only one row):

df.show
+---+---+---+---+---+---+
|  A|  B|  C|  D|  E|  F|
+---+---+---+---+---+---+
|  1|4.4|  2|  3|  7|2.6|
+---+---+---+---+---+---+

I want to get the columns whose values are greater than 2.8 (just as an example). The outcome should be:

List(B, D, E)

Here is my own solution:

val cols = df.columns
val threshold = 2.8
val values = df.rdd.collect.toList       // safe to collect: the dataframe has only one row
val res = values
  .flatMap(x => x.toSeq)                 // flatten the single Row into its values
  .map(x => x.toString.toDouble)         // toString.toDouble copes with mixed Int/Double columns
  .zip(cols)
  .filter(x => x._1 > threshold)
  .map(x => x._2)                        // keep only the column names

Upvotes: 0

Views: 1450

Answers (3)

Ramesh Maharjan

Reputation: 41957

A simple udf function should give you the correct result:

import org.apache.spark.sql.functions.{array, col, udf}

val columns = df.columns

// zip each row's values with the column names and keep the names whose value exceeds 2.8
def getColumns = udf((cols: Seq[Double]) => cols.zip(columns).filter(_._1 > 2.8).map(_._2))

df.withColumn("columns > 2.8", getColumns(array(columns.map(col(_)): _*))).show(false)

So even if you have multiple rows, as below,

+---+---+---+---+---+---+
|A  |B  |C  |D  |E  |F  |
+---+---+---+---+---+---+
|1  |4.4|2  |3  |7  |2.6|
|4  |2.7|2  |3  |1  |2.9|
+---+---+---+---+---+---+

you will get a result for each row:

+---+---+---+---+---+---+-------------+
|A  |B  |C  |D  |E  |F  |columns > 2.8|
+---+---+---+---+---+---+-------------+
|1  |4.4|2  |3  |7  |2.6|[B, D, E]    |
|4  |2.7|2  |3  |1  |2.9|[A, D, F]    |
+---+---+---+---+---+---+-------------+
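If you need the names back as a plain Scala collection for the original single-row case, here is a minimal sketch, assuming df, columns and getColumns are as defined above (the generated column is named selected here only to avoid quoting a name that contains spaces):

import org.apache.spark.sql.functions.{array, col}

// sketch: collect the generated column of the single-row dataframe into a Scala Seq
val selected: Seq[String] = df
  .withColumn("selected", getColumns(array(columns.map(col(_)): _*)))
  .select("selected")
  .first()
  .getSeq[String](0)
// selected: Seq(B, D, E)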

I hope the answer is helpful.

Upvotes: 2

Raphael Roth

Reputation: 27373

You could use the explode and array functions:

import org.apache.spark.sql.functions.{array, col, explode, lit, struct}
import spark.implicits._   // for the $"..." column syntax

df.select(
    explode(
      array(
        // one struct(key = column name, val = column value) per column
        df.columns.map(c => struct(lit(c).alias("key"), col(c).alias("val"))): _*
      )
    ).as("kv")
  )
  .where($"kv.val" > 2.8)
  .select($"kv.key")
  .show()

+---+
|key|
+---+
|  B|
|  D|
|  E|
+---+

You could then collect this result. But I don't see any issue with collecting the dataframe first, as it has only 1 row:

df.columns.zip(df.first().toSeq.map(_.toString.toDouble))   // toString.toDouble copes with mixed Int/Double columns
  .collect { case (c, v) if v > 2.8 => c }                  // Array(B, D, E)
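If you do want to collect the exploded result from the first snippet instead, a minimal sketch (keysDF is a hypothetical name for the single-column dataframe produced by the explode query above; spark is assumed to be your active SparkSession):

import spark.implicits._

// keysDF: hypothetical name for the result of the explode/where/select query above
val keys: List[String] = keysDF.as[String].collect().toList
// keys: List(B, D, E)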

Upvotes: 2

Chandan Ray

Reputation: 2091

import org.apache.spark.sql.functions.{col, when}

val c = df.columns.foldLeft(df) { (a, b) => a.withColumn(b, when(col(b) > 2.8, b)) } // column name when > 2.8, otherwise null
c.collect

You can then remove the nulls to keep only the column names.
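A minimal sketch of that last step, assuming c is the dataframe built above (each cell now holds either the column name or null):

// keep only the non-null cells, which are exactly the column names that passed the threshold
val selected = c.first().toSeq.collect { case name: String => name }
// selected: Seq(B, D, E)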

Upvotes: 1
