lizarisk

Reputation: 7820

How to set class weights for OneVsRestClassifier in scikit-learn?

I need an SVM that works as a multilabel classifier, so I decided to use the OneVsRestClassifier wrapper. The problem is that the training set becomes highly unbalanced: for a given class there are many more negative examples than positive ones. The class_weight parameter could solve this, but when I use it in a classifier wrapped in OneVsRestClassifier, I get an error:

from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier

weights = {'ham': 1, 'eggs': 2}
svm = OneVsRestClassifier(LinearSVC(class_weight=weights))

X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]]
Y = [['ham'], [], ['eggs', 'spam'], ['spam'], ['eggs']]

svm.fit(X, Y)
Traceback (most recent call last):
  File "", line 1, in 
  File "/usr/local/lib/python2.7/site-packages/sklearn/multiclass.py", line 197, in fit
    n_jobs=self.n_jobs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/multiclass.py", line 87, in fit_ovr
    for i in range(Y.shape[1]))
  File "/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 514, in __call__
    self.dispatch(function, args, kwargs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 311, in dispatch
    job = ImmediateApply(func, args, kwargs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 135, in __init__
    self.results = func(*args, **kwargs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/multiclass.py", line 56, in _fit_binary
    estimator.fit(X, y)
  File "/usr/local/lib/python2.7/site-packages/sklearn/svm/base.py", line 681, in fit
    self.classes_, y)
  File "/usr/local/lib/python2.7/site-packages/sklearn/utils/class_weight.py", line 49, in compute_class_weight
    if classes[i] != c:
IndexError: index 2 is out of bounds for axis 0 with size 2

Upvotes: 4

Views: 6108

Answers (1)

Ando Saabas

Reputation: 1967

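The problem is that OneVsRestClassifier trains one binary LinearSVC per label, and each of those only sees the classes [0, 1]. Passing weights keyed by the original labels ('ham', 'eggs'), or by more than two classes (e.g. [0, 1, 2]), therefore fails. You can use 'auto' weights instead (called 'balanced' in scikit-learn 0.17 and later), which balance each binary problem automatically by weighting classes inversely to their frequency. This also works when the OneVsRest classifier is given multiclass targets: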

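from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier

# 'balanced' replaces the older 'auto' option (deprecated in scikit-learn 0.17)
svm = OneVsRestClassifier(LinearSVC(class_weight='balanced'))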

X = [[1, 2], [3, 4], [5, 4]]
Y = [0,1,2]

svm.fit(X, Y)
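
If you want explicit weights rather than automatic balancing, the same reasoning suggests keying the dictionary by 0 and 1, since that is what each per-label binary problem sees. Below is a minimal sketch for the multilabel data from the question. It assumes a scikit-learn version that provides MultiLabelBinarizer, and the 1:2 weight ratio and max_iter value are only illustrative choices, not recommendations.

```python
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import LinearSVC

X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]]
Y = [['ham'], [], ['eggs', 'spam'], ['spam'], ['eggs']]

# Turn the label lists into a 0/1 indicator matrix, one column per label.
mlb = MultiLabelBinarizer()
Y_bin = mlb.fit_transform(Y)

# Each column is fitted as its own binary problem, so the weight keys are
# 0 (label absent) and 1 (label present). The 1:2 ratio is an assumed,
# illustrative value; max_iter is raised only to help convergence.
clf = OneVsRestClassifier(
    LinearSVC(class_weight={0: 1, 1: 2}, max_iter=10000)
)
clf.fit(X, Y_bin)

# Predictions come back as an indicator matrix; map them to label tuples.
pred = clf.predict(X)
labels = mlb.inverse_transform(pred)
```

The same weights apply to every label here. If each label needs its own ratio, class_weight='balanced' is probably the simpler route, since it is computed separately for each binary problem.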

Upvotes: 6
