Reputation: 141
My current schema is as follows:
root
|-- product: array (nullable = true)
| |-- element: string (containsNull = true)
|-- Items: map (nullable = true)
| |-- key: string
| |-- value: struct (valueContainsNull = true)
| | |-- _1: string (nullable = true)
| | |-- _2: long (nullable = false)
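For reference, here is a minimal sketch of how a DataFrame with this shape can be built (hypothetical sample data; my real data comes from elsewhere, and spark is assumed to be an active SparkSession):

import spark.implicits._

// product is an array of strings; Items maps a string key to a (_1, _2) struct
val joined = Seq(
  (Seq("a", "b"), Map("a" -> ("x", 1L), "c" -> ("y", 2L)))
).toDF("product", "Items")
joined.printSchema() // prints the schema above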
I first want to check whether any element of product is a key in Items, and then check the _2 field in that entry's value to see whether it is smaller than some value. My code is as follows:
def has(product: Seq[String], items: Map[String, (String, Long, Long)]): Double = {
  var count = 0
  for (x <- product) {
    if (items.contains(x)) {
      val item = items.get(x)
      val iitem = item.get
      val (a, b, c) = iitem
      if (b <= rank) {
        count = count + 1
      }
    }
  }
  return count.toDouble
}

def hasId = udf((product: Seq[String], items: Map[String, (String, Long, Long)]) =>
  has(product, items) / items.size.toDouble
)

for (rank <- 0 to 47) {
  joined = joined.withColumn("hasId" + rank, hasId(col("product"), col("items")))
}
I am getting an error saying that
GenericRowWithSchema cannot be cast to scala.Tuple3
The error appears to be related to
val (a, b, c) = iitem
if (b <= rank)
but I am not able to figure out what I am doing wrong.
Upvotes: 2
Views: 1966
Reputation: 37852
When passing a MapType or ArrayType column as a UDF's input, struct values are actually passed as org.apache.spark.sql.Rows, not as the tuples your UDF declares. You'll have to modify your UDF to expect a Map[String, Row] as its second argument, and "convert" these Row values into tuples using pattern matching:
def hasId = udf((product: Seq[String], items: Map[String, Row]) =>
  has(product, items.mapValues {
    case Row(s: String, i1: Long, i2: Long) => (s, i1, i2)
  }) / items.size.toDouble
)
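Alternatively, if you'd rather not pattern match, the struct's fields can be read positionally via Row's getters; a roughly equivalent sketch:

def hasId = udf((product: Seq[String], items: Map[String, Row]) =>
  has(product, items.mapValues { row =>
    // getString / getLong read the struct fields by position
    (row.getString(0), row.getLong(1), row.getLong(2))
  }) / items.size.toDouble
)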
NOTE: somewhat unrelated to the question, but it looks like there are some other mistakes in the code - I assume rank should be passed as a parameter into has? And everything could be made more idiomatic by removing the usages of mutable vars. Altogether, I'm partly guessing this does what you need:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row

// counts the products that appear as keys in items with a _2 value <= rank
def has(products: Seq[String], items: Map[String, (String, Long, Long)], rank: Long): Double = products
  .flatMap(items.get) // keep only the products that are keys in items
  .map(_._2)          // extract the _2 field
  .count(_ <= rank)
  .toDouble

def hasId(rank: Long) = udf((product: Seq[String], items: Map[String, Row]) => {
  // convert the Row values into tuples before calling has
  val convertedItems = items.mapValues {
    case Row(s: String, i1: Long, i2: Long) => (s, i1, i2)
  }
  has(product, convertedItems, rank) / items.size.toDouble
})

// fold over the ranks instead of reassigning joined in a loop
val result = (0 to 47).foldLeft(joined) {
  (df, rank) => df.withColumn("hasId" + rank, hasId(rank)(col("product"), col("items")))
}
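To illustrate, applying this to a small hand-made DataFrame (hypothetical data; the struct values need three fields for the Row pattern to match, and import spark.implicits._ is assumed):

val sample = Seq(
  (Seq("a", "b", "c"), Map("a" -> ("x", 0L, 5L), "b" -> ("y", 10L, 7L)))
).toDF("product", "items")

val withIds = (0 to 47).foldLeft(sample) {
  (df, rank) => df.withColumn("hasId" + rank, hasId(rank)(col("product"), col("items")))
}

// hasId0: only "a" qualifies (_2 = 0 <= 0), so 1 / 2 = 0.5
// hasId10: "a" and "b" both qualify (_2 <= 10), so 2 / 2 = 1.0
withIds.select("hasId0", "hasId10").show()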
Upvotes: 3