DSmonster

Reputation: 23

Can I use class_weight in a Keras model in TensorFlow Federated (TFF)?

My dataset is class-imbalanced, so I want to use class_weight, which makes the classifier weight the minority class more heavily. In a standard (non-federated) setting, I can pass class weights to fit as below:

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
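The snippet above assumes a class_weight dict already exists. One common way to build it (the approach used in TensorFlow's imbalanced-data tutorial) is inverse class frequency, weight_c = total / (num_classes * count_c). A minimal sketch; inverse_frequency_weights is a hypothetical helper name:

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Map each class to total / (num_classes * count_of_that_class)."""
    counts = Counter(int(y) for y in labels)
    total = sum(counts.values())
    num_classes = len(counts)
    return {c: total / (num_classes * n) for c, n in sorted(counts.items())}


# Example: 90 negatives and 10 positives.
labels = [0] * 90 + [1] * 10
class_weight = inverse_frequency_weights(labels)
print(class_weight)  # {0: 0.5555555555555556, 1: 5.0}
```

With these weights the 10 positives contribute as much total loss as the 90 negatives (10 * 5.0 = 90 * 0.555... = 50).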

Is there any way to assign class_weight in TensorFlow Federated? My federated learning code is below:

import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model(output_bias=None):
    # output_bias is accepted but not used in this snippet.
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(12, activation='relu', input_shape=(5,)),
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.Dense(5, activation='relu'),
        tf.keras.layers.Dense(3, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')])

def model_fn():
    # TFF calls model_fn to build a fresh model inside its computations.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        # preprocessed_example_dataset is defined elsewhere in my code.
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.BinaryAccuracy()])

Upvotes: 2

Views: 380

Answers (1)

Jakub Konecny

Reputation: 900

Not directly. The main problem is that the tf.keras.Model.fit method does not conceptually map onto the idea of training from decentralized data.

If you want to make this work in TFF, the first step is to determine which algorithm should be executed. As far as I can see, this does not have an obvious answer. For example, how do you determine the class weights if you don't have direct access to the data?
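One possible answer, assuming clients are allowed to share per-class label counts: each client counts its own labels, the server sums those histograms (in TFF this kind of aggregation would typically be a federated sum), and the weights are derived from the global totals. The sketch below only simulates that in plain Python; the client data and all function names are hypothetical, and whether sharing label counts is acceptable depends on your privacy requirements.

```python
from collections import Counter


def local_label_counts(client_labels):
    """Runs on one client: count examples per class in its local data."""
    return Counter(int(y) for y in client_labels)


def aggregate_counts(per_client_counts):
    """Runs on the server: element-wise sum of the clients' histograms."""
    total = Counter()
    for counts in per_client_counts:
        total.update(counts)  # Counter.update adds counts rather than replacing them
    return total


def weights_from_counts(counts):
    """Inverse-frequency weights computed from the global class counts."""
    total = sum(counts.values())
    num_classes = len(counts)
    return {c: total / (num_classes * n) for c, n in sorted(counts.items())}


# Hypothetical labels held by three clients; only the counts leave each client.
clients = [[0, 0, 0, 1], [0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0]]
global_counts = aggregate_counts(local_label_counts(c) for c in clients)
class_weight = weights_from_counts(global_counts)
print(global_counts[0], global_counts[1])  # 14 2
print(class_weight)
```

Note that a client with no positives at all (the second one here) still benefits, because the weights come from the global counts rather than from any single client's data.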

But let's assume that you have that information available somehow and simply want to modify the local training procedure on the clients. Starting from examples/simple_fedavg, the way to do this would be to modify how gradients are computed in the loop over client batches in its client_update function.
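A stand-alone sketch of the kind of change meant (not the simple_fedavg code itself): inside the local loop, per-example weights derived from class_weight are passed to the loss as sample_weight before the gradient is taken. It assumes binary labels, a plain Keras model, and a dataset yielding (x, y) batches; weighted_client_update is a hypothetical name.

```python
import tensorflow as tf


def weighted_client_update(model, dataset, optimizer, class_weight):
    """Hypothetical client-side loop: local SGD with a class-weighted loss."""
    loss_fn = tf.keras.losses.BinaryCrossentropy()
    w0 = tf.constant(class_weight[0], tf.float32)
    w1 = tf.constant(class_weight[1], tf.float32)
    for x, y in dataset:
        y = tf.cast(tf.reshape(y, [-1, 1]), tf.float32)
        # One weight per example: w1 for positives, w0 for negatives.
        sample_weight = tf.where(tf.reshape(y, [-1]) > 0.5, w1, w0)
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            # The weighted loss is what changes the gradients below.
            loss = loss_fn(y, preds, sample_weight=sample_weight)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # In federated averaging these updated weights would be sent to the server.
    return model.get_weights()


# Usage with a tiny model and random, imbalanced toy data.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(1, activation='sigmoid')])
x = tf.random.normal([32, 5])
y = tf.cast(tf.random.uniform([32]) < 0.1, tf.float32)
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
new_weights = weighted_client_update(model, ds, opt, {0: 0.55, 1: 5.0})
```

The sample weights are computed outside the GradientTape because they are constants with respect to the model parameters; only the loss needs to be recorded.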

Upvotes: 2
