Dan Ciborowski - MSFT

Reputation: 7207

Reduce two Scala methods that differ only in one object type

I have the following two methods, using objects from Apache Spark.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

def SVMModelScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
  val model = SVMModel.load(sc, modelFileName)

  // Load the libSVM data, take the first random split, and pair each
  // prediction with its true label.
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    val score = model.predict(point.features)
    (score, point.label)
  }
}

def DecisionTreeScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
  val model = DecisionTreeModel.load(sc, modelFileName)

  // Identical to SVMModelScoring apart from the model type being loaded.
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    val score = model.predict(point.features)
    (score, point.label)
  }
}

My previous attempts to merge these functions have resulted in errors surrounding model.predict.
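
For example, a sketch of the kind of thing I tried (Saveable is, as far as I can tell, the only MLlib trait the two model classes share, and it declares nothing about predict):

import org.apache.spark.mllib.util.Saveable

// Does not compile: Saveable says nothing about predict.
def scoring(sc: SparkContext, scoringDataset: String, model: Saveable): RDD[(Double, Double)] =
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    (model.predict(point.features), point.label) // error: value predict is not a member of Saveable
  }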

Is there a way to pass the model as a weakly typed parameter in Scala?

Upvotes: 1

Views: 123

Answers (1)

millhouse

Reputation: 10007

Disclaimer - I've never used Apache Spark.

It looks to me like the only difference between the two methods is the way the model is instantiated. It's unfortunate that the two model types don't actually share a common trait that provides predict(...), but we can still make this work by pulling out the part that changes - the scorer:

import org.apache.spark.mllib.linalg.Vector

def scoreWith(sc: SparkContext, scoringDataset: String)(scorer: Vector => Double): RDD[(Double, Double)] = {
  // Everything the two original methods shared lives here; only the
  // scoring function, supplied in the second parameter list, varies.
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    val score = scorer(point.features)
    (score, point.label)
  }
}

Now we can get the previous functionality with:

def svmScorer(sc: SparkContext, scoringDataset: String, modelFileName: String) =
  scoreWith(sc, scoringDataset)(SVMModel.load(sc, modelFileName).predict)

def dtScorer(sc: SparkContext, scoringDataset: String, modelFileName: String) =
  scoreWith(sc, scoringDataset)(DecisionTreeModel.load(sc, modelFileName).predict)
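
If you'd rather pass the model itself instead of its predict method, a structural type can stand in for the missing common trait. This is only a sketch: the Predictor alias and scoreWithModel are hypothetical names of my own, and structural calls go through reflection at runtime, so the function-passing version above is the better default.

import scala.language.reflectiveCalls

// Any object exposing predict(Vector): Double matches this structural
// type - both SVMModel and DecisionTreeModel do.
type Predictor = { def predict(features: Vector): Double }

def scoreWithModel(sc: SparkContext, scoringDataset: String)(model: Predictor): RDD[(Double, Double)] =
  scoreWith(sc, scoringDataset)(features => model.predict(features))

// Hypothetical paths, for illustration only:
// val svmScores = scoreWithModel(sc, "data/scoring.libsvm")(SVMModel.load(sc, "models/svm"))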

Upvotes: 2
