Sparkan

Spark: setNumClasses() for a subset of labels in a multiclass LogisticRegressionModel

I have a database with ids (labels) that range from 1 to 1040. I am using multiclass logistic regression to predict the id. Now I want to train on only a subset of labels, say 800 to 810, but I get an error when I call setNumClasses(11) for those 11 classes. It seems I must always set this to the maximum label value, which is 1040. That way the model trains for all labels from 0 to 1040, which is very expensive and uses a lot of resources.

Am I understanding this right? How can I train my model on only a subset of labels while passing setNumClasses(count_of_classes)?

// The largest label in the subset is 810, so numClasses has to be 811
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
            .setNumClasses(811).run(train.rdd());
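The reason 11 is rejected is that MLlib's setNumClasses(k) expects every label to lie in 0..k-1, so k is driven by the largest label rather than by how many distinct labels occur. The plain-Java sketch below (no Spark involved; the class and method names are hypothetical) compares the two numbers for labels 800 to 810:

```java
import java.util.Arrays;

// Hypothetical helper, not part of Spark: compares the numClasses value
// needed when labels must lie in 0..k-1 (largest label + 1) with the
// number of distinct labels actually present.
class NumClassesCheck {

    // Labels must lie in 0..k-1, so k must exceed the largest label
    static int requiredNumClasses(double[] labels) {
        double max = Arrays.stream(labels).max().orElse(0.0);
        return (int) max + 1;
    }

    // Number of classes that actually appear in the data
    static long distinctLabels(double[] labels) {
        return Arrays.stream(labels).distinct().count();
    }

    public static void main(String[] args) {
        // Labels 800..810, as in the question
        double[] labels = new double[11];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = 800 + i;
        }
        System.out.println("required numClasses = " + requiredNumClasses(labels)); // 811
        System.out.println("distinct labels     = " + distinctLabels(labels));     // 11
    }
}
```

The gap between 811 and 11 is exactly the wasted work described in the question.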

Answers (2)

Md. Mahedi Kaysar

Based on the comments on the previous answer, the second-to-last comment is the real question. Setting setNumClasses(23) means that every label in the training set must be in the range 0 to 22. Check the docs, which say:

:: Experimental :: Set the number of possible outcomes for k classes classification problem in Multinomial Logistic Regression. By default, it is binary logistic regression so k will be set to 2.

That means that for binary logistic regression the classes are 0 and 1, so setNumClasses(2) is the default.

If the training set contains other labels such as 2, 3 or 4, binary classification will not work.
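The rule can be written as a simple pre-check on the labels before training. This is a hypothetical plain-Java helper, not Spark's own validation code; it only encodes the constraint described above (labels must be whole numbers in 0..k-1):

```java
// Hypothetical pre-check, not Spark code: for k-class logistic
// regression every label must be a whole number in 0..numClasses-1.
class LabelRangeCheck {

    static boolean labelsFit(double[] labels, int numClasses) {
        for (double label : labels) {
            boolean whole = label == Math.floor(label);
            if (!whole || label < 0 || label > numClasses - 1) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] binary = {0, 1, 1, 0};
        double[] multi = {0, 1, 2, 3, 4};
        System.out.println(labelsFit(binary, 2)); // true: the default binary case
        System.out.println(labelsFit(multi, 2));  // false: labels 2..4 need numClasses >= 5
        System.out.println(labelsFit(multi, 5));  // true
    }
}
```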

Proposed solution: if your training set (or subset) contains the labels 790 to 801 and 900 to 910, that is 23 distinct classes, remap them to 0 to 22 and call setNumClasses(23).
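A minimal sketch of that remapping, assuming integer-valued labels stored as doubles. The LabelRemapper class is hypothetical plain Java (not a Spark API): it sorts the distinct original labels, assigns them indices 0..n-1, and keeps the inverse mapping so predictions can be translated back.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;

// Hypothetical helper, not part of Spark: maps an arbitrary set of
// original labels to the contiguous range 0..n-1 and back again.
// Serializable so it could be captured in a Spark closure.
class LabelRemapper implements java.io.Serializable {
    private final Map<Double, Double> toIndex = new TreeMap<>();
    private final List<Double> toOriginal = new ArrayList<>();

    LabelRemapper(double[] originalLabels) {
        // Sort the distinct labels so the mapping is deterministic
        TreeSet<Double> distinct = new TreeSet<>();
        for (double label : originalLabels) {
            distinct.add(label);
        }
        for (double label : distinct) {
            toIndex.put(label, (double) toOriginal.size());
            toOriginal.add(label);
        }
    }

    // Value to pass to setNumClasses(...)
    int numClasses() {
        return toOriginal.size();
    }

    // Apply to each training label before training
    double encode(double originalLabel) {
        Double index = toIndex.get(originalLabel);
        if (index == null) {
            throw new IllegalArgumentException("Unknown label: " + originalLabel);
        }
        return index;
    }

    // Apply to a predicted index to recover the original label
    double decode(double index) {
        return toOriginal.get((int) index);
    }

    public static void main(String[] args) {
        // Labels 790..801 and 900..910, as in the example above
        double[] labels = new double[23];
        int i = 0;
        for (int l = 790; l <= 801; l++) labels[i++] = l;
        for (int l = 900; l <= 910; l++) labels[i++] = l;

        LabelRemapper remapper = new LabelRemapper(labels);
        System.out.println(remapper.numClasses());  // 23
        System.out.println(remapper.encode(900.0)); // 12.0
        System.out.println(remapper.decode(22.0));  // 910.0
    }
}
```

In a Spark job one would presumably map each LabeledPoint p to new LabeledPoint(remapper.encode(p.label()), p.features()) before calling run(...), and pass model.predict(...) results through decode; that wiring is an assumption and is not shown here.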

Mateusz Dymczyk

You cannot do it like this: you are supplying the full set of training data, and it probably fails somewhere in Spark's gradient descent code (I can't be sure, since you haven't posted the error message).

Also, how is Spark supposed to figure out which 800 labels it should train the model for?

What you should do is filter the RDD so it keeps only the rows with the labels you want to train the model on. For instance, say your labels range from 0 to 1040 and you only want to train on labels 0 to 800. You can do:

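import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}

// Keep only the rows whose label is in 0..800, then train on that subset
val actualTrainingRDD = train.filter(_.label < 801)
val model: LogisticRegressionModel = new LogisticRegressionWithLBFGS()
  .setNumClasses(801).run(actualTrainingRDD)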

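Edit: yes, it is of course possible to choose a different set of labels; that was just an example. Simply change the filter to the following (the labels keep their original values, so setNumClasses must still be larger than the largest one, i.e. 801):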

train.filter(row => row.label >= 790 && row.label < 801)

This is Scala; Java lambdas use -> instead of =>.
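For the Java side, the same predicate is written as a lambda with ->; on a JavaRDD<LabeledPoint> it would presumably read train.filter(p -> p.label() >= 790 && p.label() < 801). The runnable sketch below shows the same filter with plain Java streams instead of Spark; the Point class is a hypothetical stand-in for LabeledPoint.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Plain-Java stand-in for the Scala filter above. Point is a
// hypothetical substitute for Spark's LabeledPoint (label + features).
class LabelFilter {

    static final class Point {
        final double label;
        final double[] features;

        Point(double label, double[] features) {
            this.label = label;
            this.features = features;
        }
    }

    // Keep only points whose label lies in [from, to)
    static List<Point> keepLabels(List<Point> points, double from, double to) {
        return points.stream()
                .filter(p -> p.label >= from && p.label < to) // Java lambdas use ->
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Labels 0..1040, one dummy feature each
        List<Point> all = IntStream.rangeClosed(0, 1040)
                .mapToObj(l -> new Point(l, new double[] {1.0}))
                .collect(Collectors.toList());

        List<Point> subset = keepLabels(all, 790, 801);
        System.out.println(subset.size()); // 11 (labels 790..800)
    }
}
```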
