mrbrahman
mrbrahman

Reputation: 517

Spark-ML writing custom Model, Transformer

This is on Spark 2.0.1

I'm trying to compile and use the SimpleIndexer example from here.

import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

/** Params shared by [[SimpleIndexer]] (the estimator) and [[SimpleIndexerModel]]. */
trait SimpleIndexerParams extends Params {
  /** Name of the input column holding the string labels to be indexed. */
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  /** Name of the output column the numeric indices are written to. */
  final val outputCol = new Param[String](this, "outputCol", "The output column")
}

/**
 * Estimator that learns a string-to-index mapping from the distinct values of
 * `inputCol` and produces a [[SimpleIndexerModel]].
 */
class SimpleIndexer(override val uid: String) extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {

  def this() = this(Identifiable.randomUID("simpleindexer"))

  def setInputCol(value: String): this.type = set(inputCol, value)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def copy(extra: ParamMap): SimpleIndexer = {
    defaultCopy(extra)
  }

  /** Validates that `inputCol` is a StringType column and appends `outputCol`. */
  override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match required type StringType")
    }
    schema.add(StructField($(outputCol), IntegerType, false))
  }

  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    import dataset.sparkSession.implicits._
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct
      .collect()
    // BUG FIX: the model must receive this estimator's param values.
    // `new SimpleIndexerModel(uid, words)` alone leaves inputCol/outputCol
    // unset on the model, causing
    //   java.util.NoSuchElementException: Failed to find a default value for inputCol
    // at transform time. copyValues() propagates the params and setParent()
    // links the model back to this estimator (same pattern as CountVectorizer).
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }
}

/**
 * Model produced by [[SimpleIndexer]]: maps each known label in `inputCol`
 * to a Double index in `outputCol`.
 */
class SimpleIndexerModel(
  override val uid: String, words: Array[String]) extends Model[SimpleIndexerModel] with SimpleIndexerParams {

  // Label -> index lookup; index order follows the position in `words`.
  // NOTE(review): lookup is a plain Map apply, so a label unseen at fit time
  // will throw NoSuchElementException inside the udf.
  private val labelToIndex: Map[String, Double] =
    words.zipWithIndex.map { case (word, position) => word -> position.toDouble }.toMap

  override def copy(extra: ParamMap): SimpleIndexerModel = defaultCopy(extra)

  /** Validates that `inputCol` is a StringType column and appends `outputCol`. */
  override def transformSchema(schema: StructType): StructType = {
    val inputField = schema.fields(schema.fieldIndex($(inputCol)))
    if (inputField.dataType != StringType) {
      throw new Exception(s"Input type ${inputField.dataType} did not match input type StringType")
    }
    schema.add(StructField($(outputCol), IntegerType, false))
  }

  /** Appends the index column; all original columns are passed through. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val indexer = udf { label: String => labelToIndex(label) }
    val indexed = indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol))
    dataset.select(col("*"), indexed)
  }
}

However, I'm getting an error during transformation:

val df = Seq(
  (10, "hello"),
  (20, "World"),
  (30, "goodbye"),
  (40, "sky")
).toDF("id", "phrase")

val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)

si.transform(df).show(false)

// java.util.NoSuchElementException: Failed to find a default value for inputCol

Any idea how to fix it?

Upvotes: 0

Views: 1528

Answers (2)

mrbrahman
mrbrahman

Reputation: 517

Okay, I figured it out by looking at the source code for CountVectorizer. It turns out I need to replace new SimpleIndexerModel(uid, words) with copyValues(new SimpleIndexerModel(uid, words).setParent(this)). So the new fit method becomes

  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    import dataset.sparkSession.implicits._
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct
      .collect()
    //new SimpleIndexerModel(uid, words)
    // copyValues copies this estimator's param values (inputCol/outputCol)
    // onto the model, and setParent links the model back to this estimator —
    // without copyValues the model fails with "Failed to find a default
    // value for inputCol" at transform time.
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }

With this, the params are recognized, and transform happens neatly.

val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)

si.explainParams
// res3: String =
// inputCol: The input column (current: phrase)
// outputCol: The output column (current: phrase_idx)

si.transform(df).show(false)
// +---+-------+----------+
// |id |phrase |phrase_idx|
// +---+-------+----------+
// |10 |hello  |1.0       |
// |20 |World  |0.0       |
// |30 |goodbye|3.0       |
// |40 |sky    |2.0       |
// +---+-------+----------+

Upvotes: 1

jsdeveloper
jsdeveloper

Reputation: 4045

The SimpleIndexer transform method appears to accept a Dataset as the parameter - rather than a DataFrame (which is what you are passing in).

case class Phrase(id: Int, phrase:String)
si.transform(df.as[Phrase])....

See docs for more info: https://spark.apache.org/docs/2.0.1/sql-programming-guide.html

EDIT: The problem appears to be that the SimpleIndexerModel cannot access the "phrase" column via the expression $(inputCol). I think this is because it gets set in the SimpleIndexer class (and the above expression works fine) but is not accessible in SimpleIndexerModel.

One solution is to manually set the col names:

indexer(dataset.col("phrase").cast(StringType)).as("phrase_idx"))

But it might be nicer to pass in the col names when instantiating the SimpleIndexerModel:

class SimpleIndexerModel(override val uid: String, words: Array[String], inputColName: String, outputColName: String)
....

new SimpleIndexerModel(uid, words, $(inputCol), $(outputCol))

Results:

+---+-------+----------+
|id |phrase |phrase_idx|
+---+-------+----------+
|10 |hello  |1.0       |
|20 |World  |0.0       |
|30 |goodbye|3.0       |
|40 |sky    |2.0       |
+---+-------+----------+

Upvotes: 1

Related Questions