Reputation: 1492
I am working with a dataframe with the following schema:
root
|-- Id: integer (nullable = true)
|-- defectiveItem: string (nullable = true)
|-- item: struct (nullable = true)
| |-- gem1: integer (nullable = true)
| |-- gem2: integer (nullable = true)
| |-- gem3: integer (nullable = true)
The defectiveItem column contains one of the values gem1, gem2, or gem3, and the item column holds the count for each item. Depending on the value of defectiveItem, I need to project the count of that particular item out of item into a new column named count. For example, if defectiveItem contains gem1 and item contains {"gem1":3,"gem2":4,"gem3":5}, the resulting count column should contain 3.
The resulting schema should be as follows:
root
|-- Id: integer (nullable = true)
|-- defectiveItem: string (nullable = true)
|-- item: struct (nullable = true)
| |-- gem1: integer (nullable = true)
| |-- gem2: integer (nullable = true)
| |-- gem3: integer (nullable = true)
|-- count: integer (nullable = true)
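For reference, here is a minimal way to build a dataframe with this input schema (the sample values are made up, and a SparkSession named spark is assumed):
import spark.implicits._

case class Item(gem1: Int, gem2: Int, gem3: Int)

// made-up sample rows matching the schema above
val df = Seq(
  (1, "gem1", Item(3, 4, 5)),
  (2, "gem3", Item(1, 2, 9))
).toDF("Id", "defectiveItem", "item")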
Upvotes: 0
Views: 779
Reputation: 2033
You can also solve this with a more classical approach, using the native SQL-style when/otherwise functions:
import sparkSession.implicits._
import org.apache.spark.sql.functions._

val defectiveItems = Seq(
  (1, "gem1", Map("gem1" -> 10, "gem2" -> 0, "gem3" -> 0)),
  (2, "gem1", Map("gem1" -> 15, "gem2" -> 0, "gem3" -> 0)),
  (3, "gem1", Map("gem1" -> 33, "gem2" -> 0, "gem3" -> 0)),
  (4, "gem3", Map("gem1" -> 0, "gem2" -> 0, "gem3" -> 2))
).toDF("Id", "defectiveItem", "item")

// pick the count that matches defectiveItem; the final otherwise covers gem3
val datasetWithCount = defectiveItems.withColumn("count",
  when($"defectiveItem" === "gem1", $"item.gem1")
    .when($"defectiveItem" === "gem2", $"item.gem2")
    .otherwise($"item.gem3"))

println("All items=" + datasetWithCount.collectAsList())
It'll print:
All items=[[1,gem1,Map(gem1 -> 10, gem2 -> 0, gem3 -> 0),10], [2,gem1,Map(gem1 -> 15, gem2 -> 0, gem3 -> 0),15], [3,gem1,Map(gem1 -> 33, gem2 -> 0, gem3 -> 0),33], [4,gem3,Map(gem1 -> 0, gem2 -> 0, gem3 -> 2),2]]
By using native functions you let Spark's Catalyst optimizer see the whole expression and optimize the execution plan, which it cannot do for opaque UDFs.
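As a side note, because item is a MapType in this example, the lookup can also be written without enumerating the gems at all. This is a sketch assuming Spark 2.4+, where element_at is available:
// look the defective item up directly by key in the map column
val viaElementAt = defectiveItems.withColumn("count", element_at($"item", $"defectiveItem"))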
Upvotes: 0
Reputation: 41987
You can get your desired output dataframe by using a udf function as follows:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// look up the struct field whose name matches defectiveItem
def getItemUdf = udf((defectItem: String, item: Row) => item.getAs[Int](defectItem))

df.withColumn("count", getItemUdf(col("defectiveItem"), col("item"))).show(false)
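Because the lookup happens by field name on a generic Row, the same udf keeps working unchanged if more gem fields are added to the struct later. For the example from the question (defectiveItem gem1, item {"gem1":3,"gem2":4,"gem3":5}) the new count column comes out as 3.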
I hope the answer is useful.
Upvotes: 2