Drew Stevens

Reputation: 392

Using Apache Spark ML, how do you transform (for predictions) a dataset that doesn't have a label?

I'm certain I have a gap in my understanding of Spark ML's Pipelines.

I have a pipeline that trains against a set of data with a schema of "label", "comment" (both strings). My pipeline transforms "label", adding "indexedLabel", and vectorizes "comment" by tokenizing and then applying HashingTF (ending with "vectorizedComment"). The pipeline concludes with a LogisticRegression, with a label column of "indexedLabel" and a features column of "vectorizedComment".

And it works great! I can fit my pipeline and get a pipeline model that transforms datasets with "label", "comment" all day long! However, my goal is to be able to transform datasets that have only "comment", since "label" is only present for the purpose of training the model.

I'm confident that I've got a gap in understanding of how predictions with pipelines work - could someone point it out for me?

Upvotes: 3

Views: 472

Answers (1)

Shaido

Reputation: 28352

Transformations of the label can be done outside of the pipeline (i.e. before it). The label is only necessary during training, not when the fitted pipeline/model is used for prediction. If the label transformations are performed inside the pipeline, every dataframe passed to it is required to have a label column, which is undesirable.

Small example:

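import org.apache.spark.ml.feature.StringIndexer

// Index the string label before the pipeline, so that no pipeline stage reads "label"
val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

val df2 = indexer.fit(df).transform(df)

// Create pipeline with other stages and use df2 to fit it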
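To make the example above concrete, here is a minimal end-to-end sketch using the stages described in the question. The toy data, the local SparkSession, and the intermediate column name "words" are assumptions made up for illustration; only "label", "comment", "indexedLabel" and "vectorizedComment" come from the question.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, StringIndexer, Tokenizer}
import org.apache.spark.sql.SparkSession

// Assumption: a local session for experimenting; in spark-shell one already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("label-free-prediction")
  .getOrCreate()
import spark.implicits._

// Made-up training data with the schema from the question.
val df = Seq(
  ("spam", "win money now"),
  ("ham", "see you at lunch"),
  ("spam", "free money offer"),
  ("ham", "lunch at noon tomorrow")
).toDF("label", "comment")

// Index the label outside of the pipeline.
val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
val df2 = indexer.fit(df).transform(df)

// The pipeline only contains the feature stages and the classifier.
// "words" is a hypothetical intermediate column name.
val tokenizer = new Tokenizer()
  .setInputCol("comment")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("vectorizedComment")
val lr = new LogisticRegression()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("vectorizedComment")

val model = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))
  .fit(df2)

// At prediction time the input only has "comment"; no "label" column is needed.
val unlabeled = Seq("free lunch money").toDF("comment")
val predictions = model.transform(unlabeled)
predictions.select("comment", "prediction").show()
```

The fitted logistic regression model only reads the features column when transforming, so the label column set via setLabelCol matters only during fit.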

Alternatively, you could have two separate pipelines: one that includes the label transformations, used during training, and one without them, used for prediction. Make sure the other stages refer to the same objects in both pipelines.

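import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

// Create feature transformers and add to the pipelines

// Training pipeline: label indexer plus the shared stages
val pipelineTraining = new Pipeline().setStages(Array(indexer, ...))
// Usage pipeline: the shared stages only, no label indexer
val pipelineUsage = new Pipeline().setStages(Array(...))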
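One possible way to get a label-free model out of the training pipeline is sketched below. It is a variation on the idea above rather than exactly the code shown: it fits the training pipeline once, then reuses the fitted stages without the label indexer. The toy data, the local SparkSession, the "words" column, and the names trainingModel, usageStages and usageModel are assumptions for illustration. It relies on Pipeline.fit passing stages that are already Transformers through without retraining them, which is the behaviour I would expect but is worth verifying on your Spark version.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, StringIndexer, Tokenizer}
import org.apache.spark.sql.SparkSession

// Assumption: a local session for experimenting; in spark-shell one already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("two-pipelines")
  .getOrCreate()
import spark.implicits._

// Made-up training data with the schema from the question.
val df = Seq(
  ("spam", "win money now"),
  ("ham", "see you at lunch"),
  ("spam", "free money offer"),
  ("ham", "lunch at noon tomorrow")
).toDF("label", "comment")

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
val tokenizer = new Tokenizer()
  .setInputCol("comment")
  .setOutputCol("words") // hypothetical intermediate column name
val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("vectorizedComment")
val lr = new LogisticRegression()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("vectorizedComment")

// Training pipeline: label indexer first, then the feature stages and classifier.
val pipelineTraining = new Pipeline()
  .setStages(Array(indexer, tokenizer, hashingTF, lr))
val trainingModel: PipelineModel = pipelineTraining.fit(df)

// Drop the fitted label indexer (stage 0) and keep the remaining fitted stages.
val usageStages: Array[Transformer] = trainingModel.stages.drop(1)

// Wrap the fitted stages in a new PipelineModel. Since no stage is an Estimator,
// fit() should not train anything; it only validates the schema of its input.
val unlabeled = Seq("free lunch money").toDF("comment")
val usageModel: PipelineModel = new Pipeline()
  .setStages(usageStages)
  .fit(unlabeled)

val predictions = usageModel.transform(unlabeled)
predictions.select("comment", "prediction").show()
```

Compared with indexing the label up front, this keeps a single training pipeline that accepts the raw "label", "comment" data, at the cost of an extra step to derive the prediction-only model.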

Upvotes: 1
