I have datasets for two classes on which I have to perform binary classification. I chose random forest as the classifier because it gives me the best accuracy among the models I tried. Dataset-1 contains 462 datapoints and dataset-2 contains 735. Since I noticed a minor class imbalance in the data, I tried to optimise my training by retraining the model with class weights. I provided the following class weights:
cwt <- c(0.385, 0.614)  # class weights
ss  <- c(300, 300)      # sample sizes per class
I trained the model using the following code:
tr_forest <- randomForest(output ~ ., data = train,
                          ntree = nt, mtry = mt, importance = TRUE, proximity = TRUE,
                          maxnodes = mn, sampsize = ss, classwt = cwt,
                          keep.forest = TRUE, oob.prox = TRUE, oob.times = oobt,
                          replace = TRUE, nodesize = ns, do.trace = 1)
Using the chosen class weights has increased the accuracy of my model, but I am still doubtful whether my approach is correct or whether the improvement is just a coincidence. How can I make sure my class weight choice is perfect?
I calculated the class weights using the following formula:
Class weight for positive class = (no. of datapoints in dataset-1) / (total datapoints)
Class weight for negative class = (no. of datapoints in dataset-2) / (total datapoints)
For dataset-1: 462/1197 = 0.385
For dataset-2: 735/1197 = 0.614
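For reference, the same computation can be done directly from the training labels instead of by hand; a small sketch, assuming the train data frame and output column used in the model call above (the order of the weights follows the factor levels of output):
counts <- table(train$output)            # class counts, here 462 and 735
cwt <- as.numeric(counts / sum(counts))  # class proportions, i.e. the ratios above
cwt
# ~0.386 and ~0.614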
Is this an acceptable method? If not, why is it improving the accuracy of my model? Please help me understand the nuances of class weights.
Upvotes: 1
Views: 11752
Reputation: 60317
How can I make sure my class weight choice is perfect?
Well, you certainly cannot; "perfect" is absolutely the wrong word here. What we are looking for are useful heuristics, which both improve performance and make sense (i.e. they don't feel like magic).
Given that, we do have an independent way of cross-checking your choice (which does indeed seem sound), albeit in Python rather than R: the scikit-learn method compute_class_weight. We don't even need the exact data, only the sample counts for each class, which you have already provided:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_1 = np.ones(462)       # dataset-1, labelled as class 1
y_2 = np.ones(735) + 1   # dataset-2, labelled as class 2
y = np.concatenate([y_1, y_2])
len(y)
# 1197

classes = np.array([1, 2])
cw = compute_class_weight(class_weight='balanced', classes=classes, y=y)
cw
# array([1.29545455, 0.81428571])
Actually, these are your numbers multiplied by ~ 2.11, i.e.:
cw/2.11
# array([ 0.6139595, 0.3859174])
Looks good (multiplication by a constant does not affect the outcome), save for one detail: it seems that scikit-learn advises us to use your numbers transposed, i.e. a weight of 0.614 for class 1 and 0.386 for class 2, instead of the other way round as in your computation.
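If you prefer to stay in R, the same 'balanced' heuristic, i.e. n_samples / (n_classes * class_count), is easy to reproduce from the two class counts alone; a minimal sketch:
counts <- c(462, 735)        # dataset-1, dataset-2
n_samples <- sum(counts)     # 1197
n_samples / (2 * counts)     # n_samples / (n_classes * class_count)
# ~1.2954545 and ~0.8142857, i.e. the compute_class_weight output above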
We have just entered the subtleties of what a class weight actually is; its exact definition is not necessarily the same across frameworks and libraries. scikit-learn uses these weights to weight the misclassification cost differently, so it makes sense to assign a greater weight to the minority class; this was the very idea in a draft paper by Breiman (the inventor of RF) and Andy Liaw (the maintainer of the randomForest R package):
We assign a weight to each class, with the minority class given larger weight (i.e., higher misclassification cost).
Nevertheless, this does not seem to be what the classwt argument of the randomForest R method is; from the docs:
classwt: Priors of the classes. Need not add up to one. Ignored for regression.
"Priors of the classes" is in fact the analogy of the class presence, i.e. exactly what you have computed here; this usage seems to be the consensus of a related (and highly voted) SO thread, What does the parameter 'classwt' in RandomForest function in RandomForest package in R stand for?; additionally, Andy Liaw himself has stated that (emphasis mine):
The current "classwt" option in the randomForest package [...] is different from how the official Fortran code (version 4 and later) implements class weights.
where the official Fortran implementation, I guess, is the one described in the quotation from the draft paper above (i.e. scikit-learn-like).
I used RF for imbalanced data myself during my MSc thesis about 6 years ago and, as far as I can remember, I found the sampsize parameter much more useful than classwt, against which Andy Liaw (again...) has advised (emphasis mine):
Search in the R-help archive to see other options and why you probably shouldn't use classwt.
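For context, a common recipe along these lines is a "balanced" random forest that down-samples the majority class via sampsize (per the randomForest docs, a vector-valued sampsize makes the per-tree sampling stratified by class). A minimal sketch, assuming the same train data as in the question; the ntree value is purely illustrative:
nmin <- min(table(train$output))          # size of the minority class (462 here)
rf_bal <- randomForest(output ~ ., data = train,
                       ntree = 500,                # illustrative value
                       sampsize = c(nmin, nmin),   # equal draws per class
                       replace = TRUE)
rf_bal$confusion                          # OOB confusion matrix and per-class error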
What's more, in an already rather "dark" context as far as detailed explanations go, it is not at all clear what exactly the effect of using both the sampsize and classwt arguments together is, as you have done here...
To wrap up: I would suggest you experiment with the classwt and sampsize arguments in isolation (and not together), in order to be sure where your improved accuracy should be attributed.
Upvotes: 5