I am trying to use Keras to fit a CNN model to classify images. The dataset has many more images from certain classes, so it is imbalanced.
I have read different explanations of how to weight the loss to account for this in Keras, e.g. https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras, which is nicely explained. But they always explain it for the fit() function, not for fit_generator().
Indeed, the fit_generator() function doesn't have a 'class_weights' parameter; instead it has 'weighted_metrics', whose description I don't understand: "weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing."
How can I go from 'class_weights' to 'weighted_metrics'? Does anyone have a simple example?
We do have class_weight in fit_generator (Keras v2.2.2). According to the docs:
class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.
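To make "weighting the loss function" concrete, here is a plain NumPy sketch (not Keras code) of a binary cross-entropy in which each sample's loss is multiplied by the weight of its class. It assumes a two-class problem with labels 0 and 1 and made-up weights; Keras' own reduction differs in details such as masking, so treat this only as an illustration of the idea.

```python
import numpy as np

def weighted_binary_crossentropy(y_true, p_pred, class_weight, eps=1e-7):
    """Mean binary cross-entropy, each sample scaled by the weight of its class.

    Simplified sketch: Keras' internal reduction differs in details (e.g. masking).
    """
    y_true = np.asarray(y_true, dtype=float)
    # Clip probabilities so log() never sees exactly 0 or 1.
    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1 - eps)
    per_sample = -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
    # Look up one weight per sample from the class-index -> weight dict.
    weights = np.array([class_weight[int(c)] for c in y_true])
    return float(np.mean(weights * per_sample))

# Hypothetical batch: one class-0 and one class-1 sample, both predicted at 0.5.
y = [0, 1]
p = [0.5, 0.5]
print(weighted_binary_crossentropy(y, p, {0: 1.0, 1: 1.0}))  # unweighted
print(weighted_binary_crossentropy(y, p, {0: 1.0, 1: 3.0}))  # class 1 counts 3x
```

With weights {0: 1.0, 1: 3.0}, the same mistake on a class-1 sample contributes three times as much to the loss as on a class-0 sample, which is the "pay more attention" effect the docs describe.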
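Assuming you have two classes, negative and positive, with class indices 0 and 1, you can pass class_weight to fit_generator as a dictionary mapping each class index to its weight (a plain list does not match the documented type):
model.fit_generator(gen, class_weight={0: 0.7, 1: 1.3})  # class index -> loss weight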
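If you don't want to pick the weights by hand, one common heuristic is weight = n_samples / (n_classes * count_of_class), the same formula scikit-learn uses for class_weight='balanced'. The sketch below computes that dictionary from a list of integer labels; the function name and the 90/10 label split are hypothetical.

```python
from collections import Counter

def balanced_class_weight(labels):
    """Return {class_index: weight} with weight = n_samples / (n_classes * count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # Rarer classes get larger weights, so every class contributes equally in total.
    return {cls: n_samples / (n_classes * n) for cls, n in sorted(counts.items())}

labels = [0] * 90 + [1] * 10  # hypothetical imbalanced label list
weights = balanced_class_weight(labels)
print(weights)
```

If the generator comes from ImageDataGenerator.flow_from_directory, its classes attribute should hold the integer label of every image, so something like class_weight=balanced_class_weight(list(gen.classes)) could be passed to fit_generator; check this against your Keras version.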
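On the weighted_metrics part of the question: it is not a replacement for class_weight. In Keras 2.x it is an argument of compile() that lists metrics which should also be reported weighted by the same per-sample weights used for the loss (class_weight is expanded into such per-sample weights). The NumPy sketch below, which is not Keras code and uses made-up labels, predictions and weights, shows how the same predictions give a different accuracy once each sample is weighted by its class.

```python
import numpy as np

# Hypothetical labels and predictions: class 0 is the majority, class 1 the minority.
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 1])

# Illustrative class weights (assumed values, not tuned).
class_weight = {0: 1.0, 1: 3.0}

def to_sample_weights(labels, class_weight):
    """Expand a class-index -> weight dict into one weight per sample."""
    return np.array([class_weight[int(c)] for c in labels], dtype=float)

def accuracy(y_true, y_pred, sample_weight=None):
    """Plain accuracy, or a sample-weighted average of per-sample correctness."""
    correct = (y_true == y_pred).astype(float)
    if sample_weight is None:
        return float(correct.mean())
    return float(np.sum(correct * sample_weight) / np.sum(sample_weight))

w = to_sample_weights(y_true, class_weight)
print(accuracy(y_true, y_pred))     # unweighted accuracy
print(accuracy(y_true, y_pred, w))  # accuracy weighted by class weight
```

The missed minority sample costs more in the weighted version, which is why a weighted metric gives a better picture of how training with class_weight is going.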