GJT Praveen

Reputation: 45

Multi Label Imbalanced dataset classification

I am currently working on a multi-label fashion item dataset that is highly imbalanced. I tried using class weights to tackle this, but the accuracy is still stuck at 0.7556 in every epoch. Is there any way to avoid this problem? Did I implement the class weights incorrectly? I also tried data augmentation.

I have 224 unique classes in the training set, and some of them have only one example, which is very frustrating.

I also tried to solve the problem by following this notebook, but I cannot reproduce its accuracy score. It looks like the notebook does not account for class imbalance in the dataset.

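import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.utils.class_weight import compute_class_weight

def calculating_class_weights(classes, df):
  # One [weight_for_0, weight_for_1] row per label column name in `classes`
  weights = np.empty([len(classes), 2])
  for i, label in enumerate(classes):
    # 'balanced' gives n_samples / (2 * count) for each of the values 0 and 1
    weights[i] = compute_class_weight(class_weight='balanced', classes=np.array([0., 1.]), y=df[label])
  return weights

def get_weighted_loss(weights):
  # Cast once so the weights match the float32 labels and predictions
  weights = tf.constant(weights, dtype=tf.float32)
  def weighted_loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    # w_neg where a label is 0, w_pos where it is 1, times per-label BCE, averaged over labels
    return K.mean((weights[:,0]**(1-y_true))*(weights[:,1]**(y_true))* K.binary_crossentropy(y_true, y_pred), axis=-1)
  return weighted_loss

# train_labels: list of the label column names in train_df
weights=calculating_class_weights(train_labels,train_df)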

train_dataGen = ImageDataGenerator(
                                  rescale=1./255,
                                  rotation_range=40,
                                  width_shift_range=0.2,
                                  height_shift_range=0.2,
                                  shear_range = 0.2,
                                  zoom_range=0.2,
                                  horizontal_flip=True,
                                  fill_mode='nearest',
                                  )
                
valid_dataGen = ImageDataGenerator(rescale=1./255)

model = keras.models.Sequential([
    keras.layers.Conv2D(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(256,256,3)),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding="same"),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same"),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same"),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same"),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    keras.layers.Flatten(),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(224, activation='sigmoid')
])

model.compile(loss=get_weighted_loss(weights), optimizer='adam', metrics=['accuracy'])

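# train_generator, valid_generator, tensorboard_cb and lrr are defined elsewhere (not shown in the question)
model.fit(train_generator,
          epochs=10,
          validation_data=valid_generator,
          callbacks=[tensorboard_cb,lrr])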
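To make the weighting in `calculating_class_weights` and `weighted_loss` concrete, here is a NumPy-only sketch of the same computation, assuming a 0/1 label matrix with one column per class. With `class_weight='balanced'`, scikit-learn assigns each value the weight `n_samples / (n_classes * count)`, so per label column this is `n / (2 * n_neg)` for 0 and `n / (2 * n_pos)` for 1. The helper names `balanced_label_weights` and `weighted_bce` are hypothetical and only mirror the Keras code for illustration.

```python
import numpy as np

def balanced_label_weights(y):
    """Per-label [w_neg, w_pos] rows, mirroring
    compute_class_weight('balanced', classes=[0, 1], y=column) for each column."""
    y = np.asarray(y, dtype=float)
    n = y.shape[0]
    n_pos = y.sum(axis=0)
    n_neg = n - n_pos
    # 'balanced' rule: n_samples / (n_classes * count), with n_classes = 2 (0 and 1).
    # A label with no positives (or no negatives) would divide by zero here.
    return np.stack([n / (2.0 * n_neg), n / (2.0 * n_pos)], axis=1)

def weighted_bce(y_true, y_pred, weights, eps=1e-7):
    """Per-sample loss: mean over labels of w_neg**(1-y) * w_pos**y * BCE."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions, as Keras does, to avoid log(0).
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    w = weights[:, 0] ** (1 - y_true) * weights[:, 1] ** y_true
    return (w * bce).mean(axis=-1)

# Toy labels: 8 samples, 3 labels; label 2 has a single positive example.
y = np.array([[1, 0, 0],
              [1, 1, 0],
              [0, 1, 0],
              [1, 0, 0],
              [0, 0, 0],
              [1, 1, 0],
              [0, 0, 1],
              [1, 0, 0]])
weights = balanced_label_weights(y)
print(weights[2])  # [8/14, 8/2]: the lone positive gets weight 4.0
print(weighted_bce(y, np.full(y.shape, 0.5), weights))
```

For a label with one positive among N training images, the positive weight is N / 2, so that single example dominates its label's term in the loss.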

Upvotes: 3

Views: 648

Answers (1)

Illustrati

Reputation: 321

First of all, metrics such as precision and recall focus on the positive class only, which avoids the problems that metrics covering all classes run into under class imbalance. On the other hand, if we rely only on these indicators, we may not get enough information about performance on the negative class. Haibo He et al. suggest the metrics below to rate both classes:

  1. Geometric Mean.
  2. F-Measure.
  3. Macro-Averaged Accuracy.
  4. Newer Combinations of Threshold Metrics: Mean-Class-Weighted Accuracy, Optimized Precision, Adjusted Geometric Mean, Index of Balanced Accuracy.
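A minimal NumPy sketch of the first three metrics, computed per label, assuming hard 0/1 predictions (for example sigmoid outputs thresholded at 0.5). Geometric mean is sqrt(TPR * TNR), the F-measure here is F1, and macro-averaged accuracy is (TPR + TNR) / 2. The function name `per_label_metrics` is hypothetical. The example shows why plain accuracy is misleading on a rare label: predicting all zeros still scores 0.9, while G-mean and F1 are 0.

```python
import numpy as np

def per_label_metrics(y_true, y_pred):
    """Per-label G-mean, F1, macro-averaged accuracy and plain accuracy
    for 0/1 matrices of shape (n_samples, n_labels)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = (y_true & y_pred).sum(axis=0)
    tn = (~y_true & ~y_pred).sum(axis=0)
    fp = (~y_true & y_pred).sum(axis=0)
    fn = (y_true & ~y_pred).sum(axis=0)

    def ratio(a, b):
        # Treat 0/0 as 0 so labels with no positives do not produce NaNs.
        return np.divide(a, b, out=np.zeros(np.shape(a), dtype=float), where=b > 0)

    tpr = ratio(tp, tp + fn)        # recall / sensitivity
    tnr = ratio(tn, tn + fp)        # specificity
    precision = ratio(tp, tp + fp)
    return {
        "g_mean": np.sqrt(tpr * tnr),
        "f1": ratio(2 * precision * tpr, precision + tpr),
        "macro_avg_acc": (tpr + tnr) / 2,
        "accuracy": (tp + tn) / y_true.shape[0],
    }

# Two labels over 10 samples: label 0 is rare (1 positive), label 1 is balanced.
y_true = np.zeros((10, 2), dtype=int)
y_true[0, 0] = 1
y_true[:5, 1] = 1
m = per_label_metrics(y_true, np.zeros_like(y_true))  # always predict "absent"
print(m["accuracy"])  # [0.9 0.5]
print(m["g_mean"])    # [0. 0.]
```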

My suggestions:

  1. Use the PR-curve and the F1-score.
  2. Try geometric transformations, photometric transformations, random occlusions (to avoid overfitting), SMOTE, Tomek links (for undersampling the majority classes), etc.
  3. Random undersampling may discard relevant information from your dataset, so analyze the data with KNN and similar techniques before deciding what to remove.
  4. Check this book: H. He and Y. Ma, Imbalanced Learning: Foundations, Algorithms, and Applications, Hoboken, New Jersey: Wiley-IEEE Press, 2013.
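One way to act on the PR-curve/F1 suggestion in a multi-label model is to replace the fixed 0.5 cutoff on every sigmoid output with a per-label threshold that maximizes F1 on the validation set. A minimal NumPy sketch, assuming validation labels and predicted probabilities as arrays of shape (n_samples, n_labels); the function names are hypothetical. Each distinct score is a point on that label's PR curve, which is the same set of candidates scikit-learn's `precision_recall_curve` reports. Thresholds for labels with only one or two validation positives will be noisy.

```python
import numpy as np

def best_f1_threshold(y_true, scores):
    """Sweep every distinct score of one label as a threshold (the points of its
    PR curve) and return (threshold, f1) for the highest F1 = 2TP / (2TP + FP + FN)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    best_t, best_f1 = 0.5, 0.0
    for t in np.unique(scores):
        pred = scores >= t
        tp = np.sum(pred & y_true)
        fp = np.sum(pred & ~y_true)
        fn = np.sum(~pred & y_true)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1

def tune_thresholds(Y_true, Y_scores):
    """One F1-optimal threshold per label column, chosen on validation data."""
    return np.array([best_f1_threshold(Y_true[:, j], Y_scores[:, j])[0]
                     for j in range(Y_true.shape[1])])

# A rare label whose positives never reach 0.5: the default cutoff predicts none.
y = np.array([0, 0, 1, 0, 1, 0])
s = np.array([0.10, 0.20, 0.35, 0.30, 0.40, 0.05])
print(best_f1_threshold(y, s))  # (0.35, 1.0)
```

At prediction time, compare each sigmoid output with its own tuned threshold instead of 0.5.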

Upvotes: 1
