cian

Reputation: 51

FedProx with TensorFlow Federated

Would anyone know how to implement the FedProx optimisation algorithm with TensorFlow Federated? The only implementation that seems to be available online was developed directly with TensorFlow. A TFF implementation would enable an easier comparison with experiments that utilise FedAvg which the framework supports.

This is the link to the FedProx repo: https://github.com/litian96/FedProx

Link to the paper: https://arxiv.org/abs/1812.06127

Upvotes: 1

Views: 1169

Answers (2)

Alessio Mora

Reputation: 296

Below is my implementation of FedProx in TFF. I am not 100% sure it is the right implementation; I am also posting this answer to invite discussion on an actual code example.

I tried to follow the suggestions in Jakub Konecny's answer and comment.

Starting from simple_fedavg (in the TFF GitHub repo), I just modified the client_update method, specifically changing the input argument used for calculating the gradient with the GradientTape: instead of passing just outputs.loss, the tape computes the gradient of outputs.loss + proximal_term, with the proximal term computed beforehand at every step.
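For reference, the local objective that FedProx minimizes on client k (with w the local model, w^t the global model broadcast at round t, and F_k the client's local loss) is:

h_k(w; w^t) = F_k(w) + (mu/2) * ||w - w^t||^2

so the proximal term below is exactly (mu/2) times the squared l2 norm of the difference between the in-training weights and the broadcast weights.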

@tf.function
def client_update(model, dataset, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """

  def difference_model_norm_2_square(global_model, local_model):
    """Calculates the squared l2 norm of a model difference (i.e.
    local_model - global_model).

    Args:
      global_model: the model broadcast by the server.
      local_model: the current, in-training model.

    Returns:
      The squared norm.
    """
    model_difference = tf.nest.map_structure(lambda a, b: a - b,
                                             local_model,
                                             global_model)
    squared_norm = tf.square(tf.linalg.global_norm(model_difference))
    return squared_norm

  model_weights = model.weights
  initial_weights = server_message.model_weights
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        initial_weights)

  num_examples = tf.constant(0, dtype=tf.int32)
  loss_sum = tf.constant(0, dtype=tf.float32)
  # Explicit use of `iter` for the dataset is a trick that makes TFF more
  # robust in GPU simulation and slightly more performant in the
  # unconventional usage of a large number of small datasets.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)

      # ------ FedProx ------
      # Proximal term: (mu/2) * ||w - w^t||^2, computed inside the tape so
      # that it contributes to the gradient.
      mu = tf.constant(0.2, dtype=tf.float32)
      prox_term = (mu / 2) * difference_model_norm_2_square(
          model_weights.trainable, initial_weights.trainable)
      fedprox_loss = outputs.loss + prox_term

    # Let the GradientTape differentiate FedProx's regularized loss.
    grads = tape.gradient(fedprox_loss, model_weights.trainable)
    client_optimizer.apply_gradients(zip(grads, model_weights.trainable))

    batch_size = tf.shape(batch['x'])[0]
    num_examples += batch_size
    loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        model_weights.trainable,
                                        initial_weights.trainable)
  client_weight = tf.cast(num_examples, tf.float32)
  return ClientOutput(weights_delta, client_weight, loss_sum / client_weight)
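Note that mu is hard-coded to 0.2 here purely as an example value; in a fuller implementation it would presumably be exposed as a hyperparameter (e.g., carried in the BroadcastMessage or passed when building the iterative process). The FedProx paper selects mu per task, so it is worth tuning.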

Upvotes: 1

Jakub Konecny

Reputation: 900

At this moment, a FedProx implementation is not available in TFF. I agree it would be a valuable algorithm to have.

If you are interested in contributing FedProx, the best place to start would be simple_fedavg, which is a minimal implementation of FedAvg meant as a starting point for extensions -- see the README there for more details.

I think the major change would need to happen in the client_update method, where you would add the proximal term, depending on model_weights and initial_weights, to the loss computed in the forward pass. A rough sketch of that change follows.
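Here is a minimal sketch of the modification, assuming the structure of simple_fedavg's client_update (mu here is a hypothetical, hard-coded hyperparameter, not something simple_fedavg provides):

# Sketch only: inside simple_fedavg's per-batch training loop.
# `model_weights` is the in-training model; `initial_weights` holds the
# weights broadcast by the server; `mu` is a hypothetical proximal strength.
mu = 0.2
with tf.GradientTape() as tape:
  outputs = model.forward_pass(batch)
  # (mu/2) * squared l2 distance between local and broadcast weights.
  prox_term = (mu / 2.0) * tf.add_n([
      tf.reduce_sum(tf.square(w - w0))
      for w, w0 in zip(model_weights.trainable, initial_weights.trainable)
  ])
  loss = outputs.loss + prox_term
grads = tape.gradient(loss, model_weights.trainable)
client_optimizer.apply_gradients(zip(grads, model_weights.trainable))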

Upvotes: 1
