Vishal

Reputation: 1492

SparkSQL: How to select a column value on the basis of a column name

I am working with a dataframe with the following schema:

root
 |-- Id: integer (nullable = true)
 |-- defectiveItem: string (nullable = true)
 |-- item: struct (nullable = true)
 |    |-- gem1: integer (nullable = true)
 |    |-- gem2: integer (nullable = true)
 |    |-- gem3: integer (nullable = true)

The defectiveItem column contains one of the values gem1, gem2, or gem3, and item holds the counts for each item. Depending on the value of defectiveItem, I need to project the count of that defective item from item as a new column named count.

For example, if defectiveItem contains gem1 and item contains {"gem1":3,"gem2":4,"gem3":5}, the resulting count column should contain 3.

The resulting schema should be as follows:

root
 |-- Id: integer (nullable = true)
 |-- defectiveItem: string (nullable = true)
 |-- item: struct (nullable = true)
 |    |-- gem1: integer (nullable = true)
 |    |-- gem2: integer (nullable = true)
 |    |-- gem3: integer (nullable = true)
 |-- count: integer (nullable = true)
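
For reference, a dataframe with this input schema can be built as follows (the sample values here are made up):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// A case class gives the item struct its gem1/gem2/gem3 field names.
case class Item(gem1: Int, gem2: Int, gem3: Int)
val df = Seq(
  (1, "gem1", Item(3, 4, 5)),
  (2, "gem3", Item(1, 2, 9))
).toDF("Id", "defectiveItem", "item")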

Upvotes: 0

Views: 779

Answers (2)

Bartosz Konieczny

Reputation: 2033

You can also solve this with a more classical approach, using Spark's native when/otherwise expressions:

import sparkSession.implicits._
import org.apache.spark.sql.functions._

val defectiveItems = Seq(
  (1, "gem1", Map("gem1" -> 10, "gem2" -> 0, "gem3" -> 0)),
  (2, "gem1", Map("gem1" -> 15, "gem2" -> 0, "gem3" -> 0)),
  (3, "gem1", Map("gem1" -> 33, "gem2" -> 0, "gem3" -> 0)),
  (4, "gem3", Map("gem1" -> 0, "gem2" -> 0, "gem3" -> 2))
).toDF("Id", "defectiveItem", "item")

// Pick the count that matches the defective item with a chained when/otherwise.
val datasetWithCount = defectiveItems.withColumn("count",
  when($"defectiveItem" === "gem1", $"item.gem1")
    .otherwise(when($"defectiveItem" === "gem2", $"item.gem2")
      .otherwise($"item.gem3")))

println("All items=" + datasetWithCount.collectAsList())

It'll print:

All items=[[1,gem1,Map(gem1 -> 10, gem2 -> 0, gem3 -> 0),10], [2,gem1,Map(gem1 -> 15, gem2 -> 0, gem3 -> 0),15], [3,gem1,Map(gem1 -> 33, gem2 -> 0, gem3 -> 0),33], [4,gem3,Map(gem1 -> 0, gem2 -> 0, gem3 -> 2),2]]

By using native functions you can take advantage of Spark's internal optimizations of the execution plan.
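
If the set of gem keys grows, the when chain does not have to be hard-coded; here is a minimal sketch of building it programmatically, assuming the same defectiveItems dataframe as above:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

// Fold the known keys into a single chained when expression.
val gems = Seq("gem1", "gem2", "gem3")
val countCol: Column = gems.tail.foldLeft(
  when($"defectiveItem" === gems.head, $"item"(gems.head))
) { (acc, g) => acc.when($"defectiveItem" === g, $"item"(g)) }

val generalized = defectiveItems.withColumn("count", countCol)

Since item is a map column here, Spark 2.4+ also offers a direct lookup with element_at($"item", $"defectiveItem"), which avoids the chain entirely.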

Upvotes: 0

Ramesh Maharjan

Reputation: 41987

You can get your desired output dataframe by using a udf function, as follows:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Look up the struct field whose name is given by defectiveItem.
def getItemUdf = udf((defectItem: String, item: Row) => item.getAs[Int](defectItem))

df.withColumn("count", getItemUdf(col("defectiveItem"), col("item"))).show(false)
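
One caveat: getAs throws if defectiveItem ever names a field that does not exist in the struct. A defensive variant (hypothetical, not part of the original answer) could return null instead:

import scala.util.Try
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Safe variant: returns None (null in the dataframe) for unknown field names.
def getItemSafeUdf = udf((defectItem: String, item: Row) =>
  Try(item.getAs[Int](defectItem)).toOption)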

I hope the answer is useful.

Upvotes: 2
