rosenloev


Error when fitting a CrossValidator object to training data with PySpark

I've found a lot of questions here about extracting the best model parameters from a fitted CrossValidator object, but nothing about this particular error. I'm trying to create a CrossValidator object, fit it to my training data, and later compare its metrics against my initial linear regression model. My dataset has already been transformed into the appropriate format, split into training and test sets, and used to train that first linear regression model successfully.

This is my first attempt at cross-validation and I'm clearly getting something wrong: the last line of the code below raises

    IllegalArgumentException: label does not exist. Available: PE, features, CrossValidator_3fda633cd32d_rand, prediction

where 'PE' is my labelCol.

    from pyspark.ml.regression import LinearRegression
    from pyspark.ml.evaluation import RegressionEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lrCV = LinearRegression(featuresCol='features', labelCol='PE', maxIter=50)

    # Creating a grid of parameter values that the cross-validation will use

    paramGrid = ParamGridBuilder() \
      .addGrid(lrCV.regParam, [1, 0.1, 0.01]) \
      .addGrid(lrCV.elasticNetParam, [0.0, 0.5, 1.0, 2]) \
      .addGrid(lrCV.fitIntercept, [True, False]) \
      .build()

    # Create an instance of the CrossValidator object and enter our predefined parameters

    crossVal = CrossValidator(estimator=lrCV,
                              estimatorParamMaps=paramGrid,
                              evaluator=RegressionEvaluator(),
                              numFolds=5)

    lrModelCV = crossVal.fit(train)

Does anybody have any suggestions? I'm guessing it's something really simple that I've overlooked, but I can't for the life of me find it. Thanks in advance.


Answers (1)

A.B


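By default, RegressionEvaluator expects the label to be in a column named "label". CrossValidator uses that evaluator to score every candidate model, and setting labelCol='PE' on LinearRegression does not carry over to the evaluator, hence the error.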

  1. You can either rename your label column to "label" at the beginning, which avoids the mess altogether,

or

  2. or you can pass labelCol='PE' to RegressionEvaluator() as well:

    RegressionEvaluator(labelCol='PE')
    

