Toren

Reputation: 6856

Update columns when iterating over a DataFrame

I have the following DataFrame:

import sqlContext.implicits._

case class TestData(banana: String, orange: String, apple : String, feijoa: String)

var data = sc.parallelize((1 to 5).map(i => TestData("banana="+i.toString,
                    "orange="+i.toString,"apple="+i.toString,"feijoa="+i.toString))).toDF

data.registerTempTable("data")
data.show 

It looks like this:

+--------+--------+-------+--------+
|  banana|  orange|  apple|  feijoa|
+--------+--------+-------+--------+
|banana=1|orange=1|apple=1|feijoa=1|
|banana=2|orange=2|apple=2|feijoa=2|
|banana=3|orange=3|apple=3|feijoa=3|
|banana=4|orange=4|apple=4|feijoa=4|
|banana=5|orange=5|apple=5|feijoa=5|
+--------+--------+-------+--------+

In addition, there is a sorted list of results:

case class result(fruits: Set[String], weight: Double)

val results = List(
         result(Set("banana=1"), 200),
         result(Set("banana=3", "orange=3"), 180),
         result(Set("banana=2", "orange=2", "apple=3"), 170)
 )

I'd like to iterate over the results, compare each result with the rows in the DataFrame, and set 1 in the corresponding column if the row contains that result.

Update: each column in the DataFrame contains only one value per row, for example banana=1. Each result.fruits set is made up of these values.

1) I know how to iterate over the results:

(0 to results.size-1)
  .map(i => results(i).fruits)

2) I know how to add one column to the DataFrame for each result:

import org.apache.spark.sql.functions.lit

// Add one zero-filled column per result, named "1", "2", ...
data =
(1 to results.size)
 .foldLeft(data){ case(data,i) =>
  data.withColumn(i.toString(),lit(0) )
}


+--------+--------+-------+--------+-+-+-+
|  banana|  orange|  apple|  feijoa|1|2|3|
+--------+--------+-------+--------+-+-+-+
|banana=1|orange=1|apple=1|feijoa=1|0|0|0|
|banana=2|orange=2|apple=2|feijoa=2|0|0|0|
|banana=3|orange=3|apple=3|feijoa=3|0|0|0|
|banana=4|orange=4|apple=4|feijoa=4|0|0|0|
|banana=5|orange=5|apple=5|feijoa=5|0|0|0|
+--------+--------+-------+--------+-+-+-+

3) I need help understanding how to check whether a specific row contains result.fruits and, if so, set the value to 1 in the corresponding column: the first result goes in column 1, the second result in column 2, and so on.

Upvotes: 3

Views: 1420

Answers (1)

Assaf Mendelson

Reputation: 13001

Try something like this (a simple solution that you can generalize):

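import org.apache.spark.sql.functions.{array, udf}

// Collect the four fruit columns into a single array column.
data = data.withColumn("combined", array($"banana", $"orange", $"apple", $"feijoa"))

// Build a UDF that returns 1 when the row's array contains every element of resultSet.
def getFunc(resultSet: Set[String]) = {
    def f(x: Seq[String]): Int = {
        if (resultSet.forall(y => x.contains(y))) 1 else 0
    }
    udf(f _)
}

// Add one flag column per result, named "1", "2", ...
data = (1 to results.size).foldLeft(data) {
  (x, i) => x.withColumn(i.toString, getFunc(results(i - 1).fruits)($"combined"))
}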
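
To check what the UDF should produce before running it on a cluster, the same matching rule can be sketched with plain Scala collections. This is only an illustration under a few assumptions: the names MatchSketch, Result, flagsFor and rows are hypothetical, and each row is modelled as the sequence of its four column values, the same shape as the "combined" array column above.

```scala
// Plain-Scala sketch of the matching rule; no Spark required.
// MatchSketch, Result, flagsFor and rows are hypothetical names used for illustration.
object MatchSketch {
  final case class Result(fruits: Set[String], weight: Double)

  // The results list from the question.
  val results: List[Result] = List(
    Result(Set("banana=1"), 200),
    Result(Set("banana=3", "orange=3"), 180),
    Result(Set("banana=2", "orange=2", "apple=3"), 170)
  )

  // One flag per result: 1 if every fruit in the result appears in the row, else 0.
  def flagsFor(row: Seq[String]): List[Int] =
    results.map(r => if (r.fruits.forall(f => row.contains(f))) 1 else 0)

  // The five rows from the question, each as the sequence of its column values.
  val rows: Seq[Seq[String]] =
    (1 to 5).map(i => Seq(s"banana=$i", s"orange=$i", s"apple=$i", s"feijoa=$i"))
}
```

With the sample data, MatchSketch.flagsFor gives List(1, 0, 0) for the first row and List(0, 1, 0) for the third. The third result never matches, because no row holds banana=2 and apple=3 together, so column 3 should stay 0 for every row.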

Upvotes: 1
