Reputation: 238
I use Spark MLlib to run an SVM classification on an RDD of LabeledPoints, and I want to cross-validate it. What is the best way to do that? Does anyone have example code? I did find the CrossValidator class, but it relies on a DataFrame.
My aim is to obtain the F-score.
Upvotes: 2
Views: 877
Reputation: 3433
I faced the same issue for over a month before I realized I had to use the ML API instead of the MLlib API (more about the differences between the two here). In the new API, the SVM is LinearSVC:
from pyspark.ml.classification import RandomForestClassifier, LinearSVC
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# SVM
crossval = CrossValidator(estimator=LinearSVC(),
                          estimatorParamMaps=ParamGridBuilder().build(),
                          evaluator=MulticlassClassificationEvaluator(metricName='f1'),
                          numFolds=5,
                          parallelism=4)
# Random Forest
crossval = CrossValidator(estimator=RandomForestClassifier(),
                          estimatorParamMaps=ParamGridBuilder().build(),
                          evaluator=MulticlassClassificationEvaluator(metricName='f1'),
                          numFolds=5,
                          parallelism=4)
In both cases you can then just fit the model on a training DataFrame (here `train_df`, with the default `label` and `features` columns):
cross_model: CrossValidatorModel = crossval.fit(train_df)
Upvotes: 1
Reputation: 15141
You can find a complete example in Spark's GitHub repository, though it uses logistic regression rather than SVM.
The easiest way is to convert your RDD into a DataFrame using the rdd.toDF() method.
Upvotes: 0