Deva

Reputation: 85

User-defined transformer in a PySpark pipeline

I am trying to create a PySpark pipeline to run a classification model. My dataset has a string column, so I am using 'StringIndexer' to convert it to a numeric index before applying the model in the pipeline.

My pipeline contains just two stages: StringIndexer and the classification model.

StringIndexer creates a new column with the index, but the old string column is also retained. I want to add a new transformer to the pipeline that drops the string column. Is this possible?

Is there any other way to drop the original columns in StringIndexer?

Thanks

Upvotes: 2

Views: 1421

Answers (1)

Haroun Mohammedi

Reputation: 2434

Yes, you can extend the abstract class Transformer and create your own transformer that drops the unnecessary columns.

It should look something like the following (in Scala):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

class Dropper(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("dropper"))

  // Name of the column to remove, e.g. new Dropper().setColumn("category")
  final val column: Param[String] =
    new Param[String](this, "column", "name of the column to drop")

  def setColumn(value: String): this.type = set(column, value)

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.drop($(column))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  // Result schema: the input schema without the dropped column
  override def transformSchema(schema: StructType): StructType = {
    StructType(schema.fields.filterNot(_.name == $(column)))
  }

}

I've been doing this for a while and it works just fine for me.

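Note that you can also extend the abstract class Estimator, for a stage that has to learn something from the data in fit() and then return a Model that does the actual transforming.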

Hope it helps. Best regards

Upvotes: 2
