Reputation: 517
This is on Spark 2.0.1
I'm trying to compile and use the SimpleIndexer example from here.
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
/** Params mixed into both the SimpleIndexer estimator and its model. */
trait SimpleIndexerParams extends Params {

  /** Param holding the name of the input column. */
  final val inputCol: Param[String] =
    new Param[String](this, "inputCol", "The input column")

  /** Param holding the name of the output column. */
  final val outputCol: Param[String] =
    new Param[String](this, "outputCol", "The output column")
}
/**
 * Estimator that collects the distinct string values of the input column and
 * hands them to a [[SimpleIndexerModel]].
 *
 * @param uid unique identifier of this stage
 */
class SimpleIndexer(override val uid: String)
    extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {

  /** Creates an indexer with a randomly generated uid. */
  def this() = this(Identifiable.randomUID("simpleindexer"))

  /** Sets the name of the column to index. */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Sets the name of the column that will hold the index. */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def copy(extra: ParamMap): SimpleIndexer = defaultCopy(extra)

  /**
   * Validates that the input column is a string column and appends the output column.
   *
   * @param schema schema of the incoming dataset
   * @return the schema with the output column added
   * @throws Exception if the input column is not of StringType
   */
  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
    }
    // The model maps labels to Double indices, so the declared output type
    // has to be DoubleType (it was IntegerType, which did not match).
    schema.add(StructField($(outputCol), DoubleType, nullable = false))
  }

  /**
   * Collects the distinct values of the input column and builds the model from them.
   *
   * @param dataset dataset containing the input column
   * @return a model holding the distinct values found
   */
  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    import dataset.sparkSession.implicits._
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct.collect()
    // NOTE: the model is created without carrying over this estimator's param
    // values; this is the version that produces the
    // "Failed to find a default value for inputCol" error quoted below.
    new SimpleIndexerModel(uid, words)
  }
}
/**
 * Model that maps each known string of the input column to a Double index.
 *
 * @param uid   unique identifier of this stage
 * @param words known labels; a label's position in this array becomes its index
 */
class SimpleIndexerModel(override val uid: String, words: Array[String])
    extends Model[SimpleIndexerModel] with SimpleIndexerParams {

  // Lookup table from label to its position in `words`, stored as a Double.
  private val labelToIndex: Map[String, Double] =
    words.zipWithIndex.map { case (word, position) => (word, position.toDouble) }.toMap

  /**
   * Copies this model together with its param values.
   *
   * defaultCopy cannot be used here: it instantiates the class through a
   * constructor taking only the uid, and this class has no such constructor.
   */
  override def copy(extra: ParamMap): SimpleIndexerModel = {
    val copied = new SimpleIndexerModel(uid, words).setParent(parent)
    copyValues(copied, extra)
  }

  /**
   * Validates that the input column is a string column and appends the output column.
   *
   * @param schema schema of the incoming dataset
   * @return the schema with the output column added
   * @throws Exception if the input column is not of StringType
   */
  override def transformSchema(schema: StructType): StructType = {
    // Check that the input type is a string
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != StringType) {
      throw new Exception(s"Input type ${field.dataType} did not match input type StringType")
    }
    // transform emits the Double values of labelToIndex, so the declared
    // output type has to be DoubleType (it was IntegerType).
    schema.add(StructField($(outputCol), DoubleType, nullable = false))
  }

  /**
   * Appends the output column holding the index of each input label.
   *
   * A label that is not a key of labelToIndex makes the lookup throw
   * NoSuchElementException when the UDF is evaluated.
   *
   * @param dataset dataset containing the input column
   * @return all original columns plus the output column
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    val indexer = udf { label: String => labelToIndex(label) }
    val indexed = indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol))
    dataset.select(col("*"), indexed)
  }
}
However, I'm getting an error during transformation:
// Reproduction: fit the indexer on a four-row DataFrame, then transform it.
val df = Seq((10, "hello"), (20, "World"), (30, "goodbye"), (40, "sky"))
  .toDF("id", "phrase")
val si = new SimpleIndexer()
  .setInputCol("phrase")
  .setOutputCol("phrase_idx")
  .fit(df)
si.transform(df).show(false)
// java.util.NoSuchElementException: Failed to find a default value for inputCol
Any idea how to fix it?
Upvotes: 0
Views: 1528
Reputation: 517
Okay, I figured it out by going through the source code for `CountVectorizer`. It looks like I need to replace `new SimpleIndexerModel(uid, words)` with `copyValues(new SimpleIndexerModel(uid, words).setParent(this))`. So the new `fit` method becomes:
// Collects the distinct values of the input column and builds the model from
// them, carrying this estimator's param values over to the model.
override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
  import dataset.sparkSession.implicits._
  val inputValues = dataset.select(dataset($(inputCol)).as[String])
  val words = inputValues.distinct.collect()
  // setParent links the model to this estimator; copyValues is what makes
  // inputCol/outputCol available on the model.
  val model = new SimpleIndexerModel(uid, words).setParent(this)
  copyValues(model)
}
With this, the params are recognized, and transform happens neatly.
// Refit with the corrected estimator, inspect the params, then transform.
val si = new SimpleIndexer()
  .setInputCol("phrase")
  .setOutputCol("phrase_idx")
  .fit(df)
si.explainParams
// res3: String =
// inputCol: The input column (current: phrase)
// outputCol: The output column (current: phrase_idx)
si.transform(df).show(false)
// +---+-------+----------+
// |id |phrase |phrase_idx|
// +---+-------+----------+
// |10 |hello |1.0 |
// |20 |World |0.0 |
// |30 |goodbye|3.0 |
// |40 |sky |2.0 |
// +---+-------+----------+
Upvotes: 1
Reputation: 4045
The SimpleIndexer transform method appears to accept a Dataset as the parameter - rather than a DataFrame (which is what you are passing in).
// Row type mirroring the "id" and "phrase" columns of df.
case class Phrase(id: Int, phrase: String)
si.transform(df.as[Phrase])....
See docs for more info: https://spark.apache.org/docs/2.0.1/sql-programming-guide.html
EDIT:
The problem appears to be that the SimpleIndexerModel cannot access the "phrase" column via the expression `$(inputCol)`. I think this is because the param gets set in the SimpleIndexer class (where the above expression works fine), but its value is not accessible in SimpleIndexerModel.
One solution is to manually set the col names:
indexer(dataset.col("phrase").cast(StringType)).as("phrase_idx"))
But it might be nicer to pass in the col names when instantiating the SimpleIndexerModel:
class SimpleIndexerModel(override val uid: String, words: Array[String], inputColName: String, outputColName: String)
....
new SimpleIndexerModel(uid, words, $(inputCol), $(outputCol))
Results:
+---+-------+----------+
|id |phrase |phrase_idx|
+---+-------+----------+
|10 |hello |1.0 |
|20 |World |0.0 |
|30 |goodbye|3.0 |
|40 |sky |2.0 |
+---+-------+----------+
Upvotes: 1