Ethan

Reputation: 261

Spark with Scala: Filter RDD by values contained in another list

How would I filter by list.contains()? This is my current code: I have a Main class that gets input from command-line arguments and, according to that input, executes the corresponding dispatcher. In this case it's a RecommendDispatcher class that does all its magic in the constructor, training a model and generating recommendations for the users that are passed in:

import org.apache.commons.lang.StringUtils.indexOfAny
import java.io.{BufferedWriter, File, FileWriter}
import java.text.DecimalFormat
import Util.javaHash
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.Rating


import org.apache.spark.{SparkConf, SparkContext}

class RecommendDispatcher(master:String, inputFile:String, outputFile:String, userList: List[String]) extends java.io.Serializable {

  val format: DecimalFormat = new DecimalFormat("0.#####")
  val file = new File(outputFile)
  val bw = new BufferedWriter(new FileWriter(file))

  val conf = new SparkConf().setAppName("Movies").setMaster(master)
  val sparkContext = new SparkContext(conf)
  val sqlContext = new org.apache.spark.sql.SQLContext(sparkContext)
  val baseRdd = sparkContext.textFile(inputFile)

  val movieIds = baseRdd.map(line => line.split("\\s+")(1)).distinct().map(id => (javaHash(id), id))

  val userIds = baseRdd.map(line => line.split("\\s+")(3)).distinct()
                                        .filter(x => userList.contains(x))
                                        .map(id => (javaHash(id), id))


  val ratings = baseRdd.map(line => line.split("\\s+"))
    .map(tokens => (tokens(3), tokens(1), tokens(tokens.indexOf("review/score:") + 1).toDouble))
    .map(x => Rating(javaHash(x._1), javaHash(x._2), x._3))

  // Build the recommendation model using ALS
  val rank = 10
  val numIterations = 10
  val model = ALS.train(ratings, rank, numIterations, 0.01)

  val users = userIds.collect()
  val mids = movieIds.collect()

  users.foreach(u => {
    bw.write("Recommendations for " + u + ":\n")
    var ranked = List[(Double, Int)]()
    mids.foreach(x => {
      val movieId = x._1
      val prediction = (model.predict(u._1, movieId), movieId)
      ranked = ranked :+ prediction
    })
    // Sort in descending order
    ranked = ranked.sortBy(x => -1 * x._1)
    ranked.foreach(x => bw.write(x._1 + " ; " + x._2 + "\n"))
  })

  bw.close()

}

And this exception gets thrown on the .filter line:

Exception in thread "main" org.apache.spark.SparkException: Task not serializable

Upvotes: 2

Views: 3508

Answers (3)

Ethan

Reputation: 261

I tried making the RecommendDispatcher class serializable but still got the same exception, so I decided to put the code in the Main class instead, and that solved my problem.
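For reference, here is a minimal sketch of that restructuring; the argument layout is hypothetical, and javaHash is the same Util.javaHash helper imported in the question:

import Util.javaHash
import org.apache.spark.{SparkConf, SparkContext}

object Main {
  def main(args: Array[String]): Unit = {
    // Hypothetical argument layout: master, input file, then the user ids.
    val master = args(0)
    val inputFile = args(1)
    val userList = args.drop(2).toList

    val sparkContext = new SparkContext(new SparkConf().setAppName("Movies").setMaster(master))
    val baseRdd = sparkContext.textFile(inputFile)

    // userList is now a local val of main, so the filter closure serializes
    // only the list itself, not an enclosing non-serializable class.
    val userIds = baseRdd.map(line => line.split("\\s+")(3)).distinct()
      .filter(x => userList.contains(x))
      .map(id => (javaHash(id), id))
  }
}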

Upvotes: 0

Harel Gliksman

Reputation: 754

I am guessing Sim is right about this being closure "leakage", and that the sample code you provided is oversimplified.

If your main looks like this:

object Test
{
  def main(args: Array[String]): Unit = 
  {
    val sc = ...
    val rdd1 = ...
    val userList = ...
    val rdd2 = rdd1.filter { userList.contains( _ ) }
  } 
}

Then no serialization error occurs: "userList", which is serializable, has no problem being serialized over to the executors...

The problems begin when you start to split your "big" main into separate classes.

Here is an example of how things might go wrong:

import org.apache.spark.rdd.RDD

class FilterLogic
{
  val userList = List( 1 )
  def filterRDD( rdd : RDD[ Int ] ) : RDD[ Int ] =
  {
    rdd.filter { userList.contains( _ ) }
  }
}

object Test 
{
  def main(args: Array[String]): Unit = 
  {
    val sc = ...
    val rdd1 = ...
    val rdd2 = new FilterLogic().filterRDD( rdd1 ) // This will result in a serialization error!!!
  }
}

Now that userList is a val on the FilterLogic class, serializing it over to the executors demands that the entire wrapping class be serialized as well (why? because in Scala, userList is actually a getter on the FilterLogic instance, so the closure captures the whole instance through it).
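Conceptually (this is illustrative, not what the compiler literally emits), the closure in filterRDD behaves as if it were written like this:

def filterRDD( rdd : RDD[ Int ] ) : RDD[ Int ] =
{
  // userList is really this.userList, so the closure drags the entire
  // FilterLogic instance into the task that Spark must serialize.
  rdd.filter { x => this.userList.contains( x ) }
}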

A few ways to solve this problem:

1) userList can be created inside the filterRDD function; then it is not a val of FilterLogic (this works, but limits code sharing/modeling)

1.1) A similar idea is to use a temp val inside the filterRDD function, like so:

val userList_ = userList ; rdd.filter { userList_.contains( _ ) }

It works, but it's so ugly it's almost painful...

2) The FilterLogic class can be made Serializable (though sometimes making it serializable might not be possible); both fixes are sketched below.
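For concreteness, here is a minimal sketch of fixes 1) and 2), reusing the same hypothetical FilterLogic example from above:

import org.apache.spark.rdd.RDD

// Fix 1: build the list inside the method; the closure then captures a
// plain local val instead of a field on the enclosing class.
class FilterLogicLocal
{
  def filterRDD( rdd : RDD[ Int ] ) : RDD[ Int ] =
  {
    val userList = List( 1 )
    rdd.filter { userList.contains( _ ) }
  }
}

// Fix 2: make the whole class Serializable, so capturing `this` is harmless.
class FilterLogicSerializable extends Serializable
{
  val userList = List( 1 )
  def filterRDD( rdd : RDD[ Int ] ) : RDD[ Int ] =
  {
    rdd.filter { userList.contains( _ ) }
  }
}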

Lastly, using a broadcast might (or might not) have its benefits, but it is not related to the serialization error.

Upvotes: 0

Alberto Bonsanto

Reputation: 18022

I think that a good approach is to convert your userList into a broadcast variable.

val broadcastUserList = sparkContext.broadcast(userList)
val userIds = baseRdd.map(line => line.split("\\s+")(3)).distinct()
                                      .filter(x => broadcastUserList.value.contains(x))
                                      .map(id => (javaHash(id), id))
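This ships the list to each executor once instead of once per task, and the closure only needs the (serializable) broadcast handle. One caveat, per Harel Gliksman's answer: if broadcastUserList is itself a val on a non-serializable class, the closure still captures the enclosing instance, so the broadcast handle should live in a local val (or a serializable class).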

Upvotes: 3
