Reputation: 1160
I have the following dataframe in Spark (it has only one row):
df.show
+---+---+---+---+---+---+
| A| B| C| D| E| F|
+---+---+---+---+---+---+
| 1|4.4| 2| 3| 7|2.6|
+---+---+---+---+---+---+
I want to get the columns whose values are greater than 2.8 (the threshold is just an example). The expected output is:
List(B, D, E)
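For reference, a DataFrame like this could be built as follows (a sketch, assuming a spark-shell style SparkSession named `spark` and guessing the int/double split from the show output):
import spark.implicits._

val df = Seq((1, 4.4, 2, 3, 7, 2.6)).toDF("A", "B", "C", "D", "E", "F")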
Here is my own solution:
// Collect the single row to the driver, pair each value with its column
// name, and keep the names whose value exceeds the threshold.
val cols = df.columns
val threshold = 2.8
val values = df.rdd.collect.toList
val res = values
  .flatMap(x => x.toSeq)
  .map(x => x.toString.toDouble)
  .zip(cols)
  .filter(x => x._1 > threshold)
  .map(x => x._2)
// res: List(B, D, E)
Upvotes: 0
Views: 1450
Reputation: 41957
A simple udf function should give you the correct result:
import org.apache.spark.sql.functions.{array, col, udf}

val columns = df.columns
// Zip the row's values with the column names and keep names whose value > 2.8.
def getColumns = udf((cols: Seq[Double]) => cols.zip(columns).filter(_._1 > 2.8).map(_._2))
df.withColumn("columns > 2.8", getColumns(array(columns.map(col(_)): _*))).show(false)
This also works when you have multiple rows, as below:
+---+---+---+---+---+---+
|A |B |C |D |E |F |
+---+---+---+---+---+---+
|1 |4.4|2 |3 |7 |2.6|
|4 |2.7|2 |3 |1 |2.9|
+---+---+---+---+---+---+
You will get a result for each row:
+---+---+---+---+---+---+-------------+
|A |B |C |D |E |F |columns > 2.8|
+---+---+---+---+---+---+-------------+
|1 |4.4|2 |3 |7 |2.6|[B, D, E] |
|4 |2.7|2 |3 |1 |2.9|[A, D, F] |
+---+---+---+---+---+---+-------------+
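If you only need the plain List for the single-row case, the computed array column can be read back on the driver. A minimal sketch, where `res` is a hypothetical name for the DataFrame with the added column:
val res = df.withColumn("columns > 2.8", getColumns(array(columns.map(col(_)): _*)))

// Row.getAs reads the array column back as a Scala Seq.
val matching = res.head().getAs[Seq[String]]("columns > 2.8").toList
// matching: List(B, D, E)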
I hope the answer is helpful.
Upvotes: 2
Reputation: 27373
You could use the explode and array functions:
import org.apache.spark.sql.functions.{array, col, explode, lit, struct}
import spark.implicits._ // for the $ column syntax and encoders

df.select(
    explode(
      array(
        // One struct per column: (column name, value). The cast keeps the
        // struct types identical even if the columns mix ints and doubles.
        df.columns.map(c => struct(lit(c).alias("key"), col(c).cast("double").alias("val"))): _*
      )
    ).as("kv")
  )
  .where($"kv.val" > 2.8)
  .select($"kv.key")
  .show()
+---+
|key|
+---+
| B|
| D|
| E|
+---+
You could then collect this result. But I don't see any issue with collecting the DataFrame first, as it has only one row:
// Assumes all columns are of type Double.
df.columns.zip(df.first().toSeq.map(_.asInstanceOf[Double]))
  .collect { case (c, v) if v > 2.8 => c } // Array(B, D, E)
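For completeness, the exploded result can also be collected straight to a List. A sketch, where `keys` is a hypothetical name for the DataFrame produced by the chain above (spark.implicits._, imported earlier, also supplies the String encoder for .as[String]):
// Rebind the chain above, minus .show(), so it can be collected.
val keys = df.select(
    explode(array(df.columns.map(c =>
      struct(lit(c).alias("key"), col(c).cast("double").alias("val"))): _*)).as("kv"))
  .where($"kv.val" > 2.8)
  .select($"kv.key")

val matched = keys.as[String].collect().toList
// matched: List(B, D, E)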
Upvotes: 2
Reputation: 2091
import org.apache.spark.sql.functions.{col, when}
// Replace each value with its column name when it exceeds 2.8, else null.
val c = df.columns.foldLeft(df) { (a, b) => a.withColumn(b, when(col(b) > 2.8, b)) }
c.collect
Each collected row then holds either a column name or null; remove the nulls to get the matching columns.
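A sketch of that last step for the single-row case:
// when(...) without otherwise(...) yields null for non-matching columns,
// so each collected Row holds either a column name or null.
val names = c.collect().head.toSeq.collect { case s: String => s }
// names: Seq(B, D, E)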
Upvotes: 1