Kane

Reputation: 924

How does TensorFlow Federated update the model received from the server?

I'm new to TensorFlow, so I'm not sure whether this question is specific to TensorFlow Federated.

I'm studying adversarial attacks on federated learning using this code, and I'm curious how the weights received from the server are updated at the client.

For example, here is the code for a 'benign' update:

@tf.function
def compute_benign_update():
  """Computes the benign update sent back to the server."""
  # Set the model's weights to the values received from the server.
  tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                        initial_weights)

  # Train on every local batch; reduce_fn returns the running example count.
  num_examples_sum = benign_dataset.reduce(initial_state=tf.constant(0),
                                           reduce_func=reduce_fn)

  # Update for the server: trainable weights after training minus server weights.
  weights_delta_benign = tf.nest.map_structure(lambda a, b: a - b,
                                               model_weights.trainable,
                                               initial_weights.trainable)

  aggregated_outputs = model.report_local_outputs()
  return weights_delta_benign, aggregated_outputs, num_examples_sum

I can see that the initial weights received from the server are assigned to model_weights, and then reduce_fn is used to train on the local client's data, one batch at a time.

@tf.function
def reduce_fn(num_examples_sum, batch):
  """Runs `tff.learning.Model.train_on_batch` on local client batch."""
  # Record the forward pass so the loss gradients can be computed.
  with tf.GradientTape() as tape:
    output = model.forward_pass(batch)
  gradients = tape.gradient(output.loss, model.trainable_variables)
  # One optimizer step on the model's trainable variables.
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Add this batch's size to the running example count.
  return num_examples_sum + tf.shape(output.predictions)[0]
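For context, tf.data.Dataset.reduce calls reduce_func once per dataset element (here, once per batch), passing the previous return value as the new state. The minimal sketch below shows only that counting mechanism, with a hypothetical dataset of ten integers and a hypothetical count_examples function; no training happens in it.

```python
import tensorflow as tf

# Hypothetical data: ten examples split into batches of 4, 4 and 2.
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(4)


def count_examples(num_examples_sum, batch):
    # Called once per batch; the value returned here becomes the state
    # passed to the next call.
    return num_examples_sum + tf.shape(batch)[0]


total = dataset.reduce(initial_state=tf.constant(0), reduce_func=count_examples)
print(int(total))  # 10
```

In compute_benign_update, reduce_fn follows the same pattern, except that each call also runs a training step as a side effect.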

Inside this function, training occurs and (I think) model.trainable_variables is updated. The part that doesn't make sense to me is how weights_delta_benign is calculated:

weights_delta_benign = tf.nest.map_structure(lambda a, b: a - b,
                                             model_weights.trainable,
                                             initial_weights.trainable)

It seems that the delta is the difference between model_weights.trainable and initial_weights.trainable, but didn't we set these to be equal in the first line of compute_benign_update()? I'm assuming reduce_fn alters initial_weights somehow, but I don't see the connection between model.trainable_variables, used in reduce_fn, and initial_weights.trainable.

Thanks, any help appreciated!

Upvotes: 2

Views: 242

Answers (1)

Jakub Konecny

Reputation: 900

In the code you point to, initial_weights is only a collection of values (tf.Tensor objects), and model_weights is a reference to the model's variables (tf.Variable objects). We use initial_weights to assign the initial value to the model's variables.
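A minimal pure-TensorFlow sketch of this distinction, assuming plain dicts as hypothetical stand-ins for initial_weights (tensors) and model_weights (variables): assign copies the tensor's value into the variable, so later changes to the variable leave the original tensor untouched.

```python
import tensorflow as tf

# Hypothetical stand-ins: server values are tensors, client weights are variables.
initial_weights = {"kernel": tf.constant([1.0, 2.0]), "bias": tf.constant(0.5)}
model_weights = {"kernel": tf.Variable([0.0, 0.0]), "bias": tf.Variable(0.0)}

# Copy the tensor values into the variables, like the first line of
# compute_benign_update.
tf.nest.map_structure(lambda var, value: var.assign(value),
                      model_weights, initial_weights)

# Simulate a training update: only the variables change.
model_weights["kernel"].assign_sub([0.5, 0.5])

print(initial_weights["kernel"].numpy())  # [1. 2.]
print(model_weights["kernel"].numpy())    # [0.5 1.5]
```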

Then, in the call to optimizer.apply_gradients(zip(gradients, model.trainable_variables)), you only modify the model's variables. (model.trainable_variables refers to the same objects as model_weights.trainable; I acknowledge this is a bit confusing.)
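The "same objects" point can be checked with a small sketch. It assumes a hypothetical TinyModel built on tf.Module rather than a real tff.learning.Model, and it swaps optimizer.apply_gradients for a plain SGD step via assign_sub; both update the variables in place, which is what matters here.

```python
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for a model with one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)


model = TinyModel()

# A second handle to the model's variables, playing the role of
# model_weights.trainable. It holds references, not copies.
trainable_view = list(model.trainable_variables)

# One gradient step on model.trainable_variables. Plain SGD via assign_sub
# stands in for optimizer.apply_gradients; both modify variables in place.
learning_rate = 0.1
with tf.GradientTape() as tape:
    loss = (model.w - 3.0) ** 2
gradients = tape.gradient(loss, model.trainable_variables)
for grad, var in zip(gradients, model.trainable_variables):
    var.assign_sub(learning_rate * grad)

print(trainable_view[0] is model.w)  # True: the very same tf.Variable
print(trainable_view[0].numpy())     # the updated value, roughly 1.4
```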

So the subsequent computation of weights_delta_benign computes the difference between the model's trainable variables at the end of the client's training procedure and their values at the start, which are still held unchanged in initial_weights.
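Putting the pieces together, here is a hedged end-to-end sketch of the assign, train, subtract sequence. The linear model, the data, LEARNING_RATE and local_step are all hypothetical; a Python loop replaces benign_dataset.reduce and a plain SGD step replaces the optimizer, so this illustrates the idea rather than reproducing the TFF code.

```python
import tensorflow as tf

# Hypothetical stand-ins: server weights are tensors, client weights are variables.
initial_weights = {"w": tf.constant([1.0, -1.0]), "b": tf.constant(0.0)}
model_weights = {"w": tf.Variable(tf.zeros([2])), "b": tf.Variable(0.0)}

# Hypothetical local dataset: four examples in two batches of two.
features = tf.constant([[1.0, 2.0], [3.0, 4.0], [0.5, 1.0], [2.0, 0.0]])
labels = tf.constant([1.0, 2.0, 0.0, 1.0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)

LEARNING_RATE = 0.01


def local_step(num_examples_sum, batch):
    """One SGD step on a linear model; returns the updated example count."""
    x, y = batch
    variables = [model_weights["w"], model_weights["b"]]
    with tf.GradientTape() as tape:
        predictions = tf.linalg.matvec(x, model_weights["w"]) + model_weights["b"]
        loss = tf.reduce_mean((predictions - y) ** 2)
    gradients = tape.gradient(loss, variables)
    for grad, var in zip(gradients, variables):
        var.assign_sub(LEARNING_RATE * grad)  # in-place update of the variable
    return num_examples_sum + tf.shape(y)[0]


# 1. Start from the server's values.
tf.nest.map_structure(lambda var, value: var.assign(value),
                      model_weights, initial_weights)

# 2. Train locally; this mutates model_weights but never touches initial_weights.
num_examples = tf.constant(0)
for batch in dataset:
    num_examples = local_step(num_examples, batch)

# 3. Delta = variables after training minus the server's starting values.
weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                      model_weights, initial_weights)
print(int(num_examples))  # 4
print({name: value.numpy() for name, value in weights_delta.items()})
```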

Upvotes: 2
