Reputation: 51
Would anyone know how to implement the FedProx optimisation algorithm with TensorFlow Federated? The only implementation I can find online was written directly in TensorFlow. A TFF implementation would make it easier to compare against experiments that use FedAvg, which the framework already supports.
This is the link to the FedProx repo: https://github.com/litian96/FedProx
Link to the paper: https://arxiv.org/abs/1812.06127
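For reference, the paper's change to FedAvg is local to each client. At round t, client k approximately minimizes its own loss F_k plus a proximal term that penalizes distance from the current global model w^t, where μ ≥ 0 is a tunable hyperparameter (μ = 0 gives back FedAvg's local objective). The second expression is the gradient a client-side implementation has to produce:

```latex
\min_{w} \; h_k(w; w^t) = F_k(w) + \frac{\mu}{2}\,\lVert w - w^t \rVert^2,
\qquad
\nabla_w h_k(w; w^t) = \nabla F_k(w) + \mu\,(w - w^t)
```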
Upvotes: 1
Views: 1169
Reputation: 296
Below is my implementation of FedProx in TFF. I am not 100% sure it is correct; I am also posting this answer to discuss an actual code example.
I tried to follow the suggestions in Jakub Konečný's answer and comment.
Starting from simple_fedavg (in the TFF GitHub repo), I only modified the client_update method, specifically the value passed to the GradientTape when computing the gradient: instead of passing only outputs.loss, the tape computes the gradient of outputs.loss + proximal_term, where the proximal term is recomputed on every batch from the current model weights and the weights broadcast by the server.
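Because the proximal term is (mu/2)·||w − w_global||², its gradient with respect to the local weights is simply mu·(w − w_global). Each local step therefore differs from a FedAvg step only by this extra pull toward the server model. The NumPy sketch below is not TFF code: the least-squares loss and the function names (`fedprox_loss`, `fedprox_grad`, `fedprox_sgd_step`) are hypothetical illustrations. It checks that closed-form gradient against a finite-difference estimate of the full FedProx loss.

```python
import numpy as np

def local_loss(w, x, y):
    # Illustrative client loss: mean squared error of a linear model.
    residual = x @ w - y
    return 0.5 * np.mean(residual ** 2)

def local_grad(w, x, y):
    # Gradient of local_loss with respect to w.
    return x.T @ (x @ w - y) / len(y)

def fedprox_loss(w, w_global, x, y, mu):
    # FedProx local objective: task loss plus (mu / 2) * ||w - w_global||^2.
    return local_loss(w, x, y) + 0.5 * mu * np.sum((w - w_global) ** 2)

def fedprox_grad(w, w_global, x, y, mu):
    # The proximal term only adds mu * (w - w_global) to the usual gradient.
    return local_grad(w, x, y) + mu * (w - w_global)

def fedprox_sgd_step(w, w_global, x, y, mu, lr):
    # One plain SGD step on the FedProx objective.
    return w - lr * fedprox_grad(w, w_global, x, y, mu)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 3))
y = rng.normal(size=32)
w_global = rng.normal(size=3)
w = w_global + rng.normal(scale=0.1, size=3)
mu = 0.2

# Central finite differences of the full FedProx loss, one coordinate at a time.
eps = 1e-6
numeric = np.array([
    (fedprox_loss(w + eps * e, w_global, x, y, mu)
     - fedprox_loss(w - eps * e, w_global, x, y, mu)) / (2 * eps)
    for e in np.eye(3)
])
print("gradients match:", np.allclose(numeric, fedprox_grad(w, w_global, x, y, mu), atol=1e-6))
```

Note that at w == w_global (the first local step) the proximal gradient is exactly zero, so the first step is identical to a FedAvg step.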
import tensorflow as tf

# Modified client_update from the simple_fedavg example; ClientOutput is
# defined in that example, not here.
@tf.function
def client_update(model, dataset, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """

  def difference_model_norm_2_square(global_model, local_model):
    """Calculates the squared l2 norm of a model difference (i.e.
    local_model - global_model).

    Args:
      global_model: the model broadcast by the server
      local_model: the current, in-training model

    Returns: the squared norm
    """
    model_difference = tf.nest.map_structure(lambda a, b: a - b,
                                             local_model,
                                             global_model)
    # Sum of squares instead of tf.square(tf.linalg.global_norm(...)): the
    # sqrt inside global_norm has a NaN gradient when the difference is
    # exactly zero, which is the case on the first batch.
    squared_norm = tf.add_n([tf.reduce_sum(tf.square(d))
                             for d in tf.nest.flatten(model_difference)])
    return squared_norm

  model_weights = model.weights
  initial_weights = server_message.model_weights
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        initial_weights)

  # ------ FedProx ------
  # Proximal coefficient mu (hard-coded here; tune it per experiment).
  mu = tf.constant(0.2, dtype=tf.float32)

  num_examples = tf.constant(0, dtype=tf.int32)
  loss_sum = tf.constant(0, dtype=tf.float32)
  # Explicit use `iter` for dataset is a trick that makes TFF more robust in
  # GPU simulation and slightly more performant in the unconventional usage
  # of large number of small datasets.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)
      # The proximal term is computed inside the tape so that its gradient,
      # mu * (w - w_global), is added to the gradient of the task loss.
      prox_term = (mu / 2) * difference_model_norm_2_square(
          initial_weights.trainable, model_weights.trainable)
      fedprox_loss = outputs.loss + prox_term
    # Letting GradientTape deal with the FedProx loss
    grads = tape.gradient(fedprox_loss, model_weights.trainable)
    client_optimizer.apply_gradients(zip(grads, model_weights.trainable))

    batch_size = tf.shape(batch['x'])[0]
    num_examples += batch_size
    # Track only the task loss (without the proximal term), as in simple_fedavg.
    loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        model_weights.trainable,
                                        initial_weights.trainable)
  client_weight = tf.cast(num_examples, tf.float32)
  return ClientOutput(weights_delta, client_weight, loss_sum / client_weight)
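To see how the three `ClientOutput` fields are consumed, here is a NumPy toy of one full round. Each client runs one epoch of local FedProx SGD from the broadcast weights and returns its weights delta, its example count and its mean task loss; the server then adds the example-weighted average delta, which is roughly what simple_fedavg's server step does with a plain SGD server optimizer at learning rate 1.0. Everything here (the linear model, `client_update_np`, `run_round`, the synthetic data) is a hypothetical illustration, not TFF API.

```python
import numpy as np

def client_update_np(w_global, x, y, mu, lr, batch_size):
    """Hypothetical NumPy analogue of client_update: one local epoch of FedProx SGD."""
    w = w_global.copy()
    loss_sum = 0.0
    for start in range(0, len(y), batch_size):
        xb, yb = x[start:start + batch_size], y[start:start + batch_size]
        residual = xb @ w - yb
        # Task gradient plus the proximal gradient mu * (w - w_global).
        grad = xb.T @ residual / len(yb) + mu * (w - w_global)
        # Record the task loss (without the proximal term), weighted by batch size.
        loss_sum += 0.5 * np.mean(residual ** 2) * len(yb)
        w = w - lr * grad
    num_examples = len(y)
    # Mirrors ClientOutput(weights_delta, client_weight, mean loss).
    return w - w_global, float(num_examples), loss_sum / num_examples

def run_round(w_global, clients, mu, lr=0.1, batch_size=8):
    """Average client deltas weighted by example count, then apply them."""
    outputs = [client_update_np(w_global, x, y, mu, lr, batch_size) for x, y in clients]
    total = sum(weight for _, weight, _ in outputs)
    avg_delta = sum(weight * delta for delta, weight, _ in outputs) / total
    avg_loss = sum(weight * loss for _, weight, loss in outputs) / total
    return w_global + avg_delta, avg_loss

rng = np.random.default_rng(1)
w_true = np.array([1.0, -2.0, 0.5])
clients = []
for n in (20, 40, 60):  # clients with different amounts of data
    x = rng.normal(size=(n, 3))
    clients.append((x, x @ w_true + 0.01 * rng.normal(size=n)))

w = np.zeros(3)
for round_num in range(40):
    w, loss = run_round(w, clients, mu=0.2)
    if round_num % 10 == 0:
        print(f"round {round_num}: loss {loss:.4f}")
print("final weights:", np.round(w, 2))
```

Weighting by example count is what the `client_weight` field is for; larger clients pull the global model proportionally harder.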
Upvotes: 1
Reputation: 900
At the moment, a FedProx implementation is not available. I agree it would be a valuable algorithm to have.
If you are interested in contributing FedProx, the best place to start would be simple_fedavg, which is a minimal implementation of FedAvg meant as a starting point for extensions (see the readme there for more details).
I think the major change would need to happen in the client_update method, where you would add a proximal term that depends on model_weights and initial_weights to the loss computed in the forward pass.
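One design consequence worth keeping in mind when choosing mu: with mu = 0 the proximal term vanishes and the client update is FedAvg's, while a larger mu keeps the local weights closer to the broadcast model during local training. The NumPy sketch below (a toy linear model; `local_train` and the data are hypothetical) measures that client drift, ||w_local − w_global||, after several local epochs for a few values of mu, on a client whose optimum is far from the global model.

```python
import numpy as np

def local_train(w_global, x, y, mu, lr=0.05, epochs=5, batch_size=10):
    # Hypothetical local trainer: minibatch SGD on the FedProx objective.
    w = w_global.copy()
    for _ in range(epochs):
        for start in range(0, len(y), batch_size):
            xb, yb = x[start:start + batch_size], y[start:start + batch_size]
            grad = xb.T @ (xb @ w - yb) / len(yb) + mu * (w - w_global)
            w = w - lr * grad
    return w

rng = np.random.default_rng(2)
x = rng.normal(size=(100, 4))
# This client's optimum is far from the global model, mimicking non-IID data.
y = x @ np.array([3.0, -3.0, 3.0, -3.0])
w_global = np.zeros(4)

drifts = {}
for mu in (0.0, 1.0, 10.0):
    w_local = local_train(w_global, x, y, mu)
    drifts[mu] = np.linalg.norm(w_local - w_global)
    print(f"mu={mu:>4}: drift {drifts[mu]:.3f}")
```

The trade-off is that a very large mu also limits how much useful local progress each client makes per round, so mu is usually tuned per dataset.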
Upvotes: 1