sandbar

Reputation: 84

Retrieve column value given a column of column names (spark / scala)

I have a dataframe like the following:

+-----------+-----------+---------------+------+---------------------+                                        
|best_col   |A          |B              |  C   |<many more columns>  |
+-----------+-----------+---------------+------+---------------------+
|     A     |    14     |        26     |  32  |       ...           |
|     C     |    13     |        17     |  96  |       ...           |
|     B     |    23     |        19     |  42  |       ...           |
+-----------+-----------+---------------+------+---------------------+ 

I want to end up with a DataFrame like this:

+-----------+-----------+---------------+------+---------------------+----------+                                        
|best_col   |A          |B              |  C   |<many more columns>  | result   |
+-----------+-----------+---------------+------+---------------------+----------+
|     A     |    14     |        26     |  32  |       ...           |   14     |
|     C     |    13     |        17     |  96  |       ...           |   96     |
|     B     |    23     |        19     |  42  |       ...           |   19     |
+-----------+-----------+---------------+------+---------------------+----------+

Essentially, I want to add a result column that takes its value from whichever column is named in best_col. best_col only contains names of columns that are present in the DataFrame. Since I have dozens of columns, I want to avoid chaining a bunch of when statements that check col("best_col") === "A" and so on. I tried col(col("best_col").toString()), but this didn't work. Is there an easy way to do this?

Upvotes: 2

Views: 1055

Answers (1)

meysam

Reputation: 1794

Using map_filter introduced in Spark 3.0:

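import org.apache.spark.sql.functions._
import spark.implicits._  // assumes a SparkSession named spark, as in spark-shell

val df = Seq(
    ("A", 14, 26, 32),
    ("C", 13, 17, 96),
    ("B", 23, 19, 42)
).toDF("best_col", "A", "B", "C")

// Build a map of (column value -> "is this the column named in best_col?"),
// keep only the entry flagged true, then take its key as the result.
// Caveat: map keys must be distinct under the default spark.sql.mapKeyDedupPolicy,
// so this fails on rows where two columns hold the same value.
df.withColumn("result", map(df.columns.tail.flatMap(c => Seq(col(c), col("best_col") === lit(c))): _*))
    .withColumn("result", map_filter(col("result"), (_, isBest) => isBest))
    .withColumn("result", map_keys(col("result"))(0))
    .show()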

+--------+---+---+---+------+
|best_col|  A|  B|  C|result|
+--------+---+---+---+------+
|       A| 14| 26| 32|    14|
|       C| 13| 17| 96|    96|
|       B| 23| 19| 42|    19|
+--------+---+---+---+------+
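
If rows can contain the same value in more than one column, a variant that keys the map by column name instead of by value may be safer, since the names are always distinct. The following is only a sketch, not part of the original answer: it assumes all value columns share one type (here Int), and the names valueCols, lookup and result are made up for illustration. It also assumes it is pasted into spark-shell or run as a Scala script.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, map}

// Local session for illustration; in spark-shell a session already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("best-col-lookup")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("A", 14, 26, 32),
  ("C", 13, 17, 96),
  ("B", 23, 19, 42)
).toDF("best_col", "A", "B", "C")

// Every column except best_col is a candidate (hypothetical helper name).
val valueCols = df.columns.filter(_ != "best_col")

// Map of (column name -> column value). Keys are literal names, so they
// are unique even when several columns hold the same value in a row.
val lookup = map(valueCols.flatMap(c => Seq(lit(c), col(c))): _*)

// Look up the entry whose key equals the name stored in best_col.
val result = df.withColumn("result", lookup(col("best_col")))
result.show()
```

On the sample data this should give the same result column as above (14, 96, 19). This version also needs only map and a key lookup, so it does not depend on map_filter.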

Upvotes: 1
