Sharare Zehtabian

What exactly happens when we call IterativeProcess.next on federated training data?

I went through the Federated Learning tutorial, and I was wondering how the .next function works when we call it on an iterative process. Assume that we have training data which is a list of lists: the outer list is a list of clients, and each inner list holds the batches of data for one client. Then we create an iterative process, for example a federated averaging process, and initialize the state. What exactly happens when we call IterativeProcess.next on this training data? Does it sample from these data randomly in each round, or does it just take data from each client one batch at a time?

Assume that I have a list of tf.data.Datasets, each representing one client's data. How can I add some randomness to sampling from this list for the next iteration of federated learning?

My datasets are not necessarily the same length. When one of them has been completely iterated over, does that dataset wait for all the other datasets to finish iterating over their data?

Answers (2)

Zachary Garrett

Does (the iterative process) take from these data randomly in each round? Or just take data from each client one batch at a time?

The TFF tutorials all use tff.learning.build_federated_averaging_process, which constructs a tff.templates.IterativeProcess that implements the Federated Averaging algorithm (McMahan et al. 2017). In this algorithm, each "round" (one invocation of IterativeProcess.next()) processes as many batches of examples on each client as that client's tf.data.Dataset is set up to produce in one iteration. tf.data: Build TensorFlow input pipelines is a great guide for tf.data.Dataset.

The order in which examples are processed is determined by how the tf.data.Datasets passed as arguments to the next() method were constructed. For example, in the Federated Learning for Text Generation tutorial's section titled Load and Preprocess the Federated Shakespeare Data, each client dataset is set up with this preprocessing pipeline:

```python
# to_ids, split_input_target, SEQ_LENGTH, BUFFER_SIZE and BATCH_SIZE are
# defined earlier in the Shakespeare text-generation tutorial.
def preprocess(dataset):
  return (
      # Map ASCII chars to int64 indexes using the vocab
      dataset.map(to_ids)
      # Split into individual chars
      .unbatch()
      # Form example sequences of SEQ_LENGTH +1
      .batch(SEQ_LENGTH + 1, drop_remainder=True)
      # Shuffle and form minibatches
      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
      # And finally split into (input, target) tuples,
      # each of length SEQ_LENGTH.
      .map(split_input_target))
```

The next function iterates over each of these datasets in its entirety once per invocation of next(). In this case, since there is no call to tf.data.Dataset.repeat(), each call to next() has each client see every example its pipeline yields exactly once.
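
To make the "one pass per round" behaviour concrete, here is a minimal pure-Python model. It is not TFF code. Each client dataset is modelled as a list of batches, and the function simulate_round and its epochs parameter are hypothetical. Setting epochs greater than 1 plays the role that adding tf.data.Dataset.repeat(epochs) to the preprocessing pipeline would play.

```python
from typing import List, Sequence

Batch = List[int]  # a toy "batch" is just a list of example ids


def simulate_round(client_datasets: Sequence[Sequence[Batch]],
                   epochs: int = 1) -> List[int]:
    """Return how many batches each client processes in one round.

    Hypothetical model of one FedAvg round: every client iterates over its
    whole dataset `epochs` times (epochs > 1 mimics dataset.repeat(epochs)).
    """
    batches_seen = []
    for dataset in client_datasets:
        count = 0
        for _ in range(epochs):
            for _batch in dataset:  # one full pass over this client's data
                count += 1
        batches_seen.append(count)
    return batches_seen


clients = [
    [[0, 1], [2, 3], [4, 5]],  # client 0: 3 batches
    [[6, 7]],                  # client 1: 1 batch
]
print(simulate_round(clients))            # [3, 1]
print(simulate_round(clients, epochs=2))  # [6, 2]
```

Nothing in this model samples randomly. Any randomness in which examples a client sees comes from how its dataset pipeline is built.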

Assume that I have a list of tf.data.Datasets, each representing one client's data. How can I add some randomness to sampling from this list for the next iteration of federated learning?

To add randomness to each client's dataset, one could first use tf.data.Dataset.shuffle() to randomize the order of the yielded examples, and then use tf.data.Dataset.take() to keep only a sample from that new random ordering. This could be added to the preprocess() function above.
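
The shuffle-then-take idea can be sketched with the standard library alone, treating a client's examples as a plain Python list. The helper shuffle_and_take is hypothetical. It behaves like dataset.shuffle(buffer_size).take(n) only when buffer_size is at least the dataset size. With a smaller buffer, tf.data's shuffle draws from a sliding buffer and so only partially randomizes the order.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shuffle_and_take(examples: Sequence[T], n: int, seed: int) -> List[T]:
    """Randomly order one client's examples, then keep at most n of them.

    Stdlib analogue of dataset.shuffle(buffer_size).take(n) with
    buffer_size >= len(examples), i.e. a full uniform shuffle.
    """
    rng = random.Random(seed)  # per-call RNG so results are reproducible
    shuffled = list(examples)
    rng.shuffle(shuffled)
    return shuffled[:n]


client_examples = list(range(10))
sample = shuffle_and_take(client_examples, n=4, seed=0)
print(sample)  # four distinct examples, in a seed-dependent random order
```

Passing a different seed each round (for example, derived from the round number) gives a different subsample per round while keeping runs reproducible.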

Alternatively, randomness in the selection of clients (e.g. randomly picking which clients participate each round) can be done using any Python library to sub-sample the list of datasets, e.g. Python's random.sample.
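
A minimal sketch of per-round client selection with random.sample follows. The function sample_clients and its seeding scheme are hypothetical. Strings stand in for the tf.data.Datasets, since the selection logic does not depend on what the list items are.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_clients(client_datasets: Sequence[T], num_clients: int,
                   round_num: int, base_seed: int = 42) -> List[T]:
    """Pick num_clients distinct clients for one round.

    The RNG is seeded from (base_seed, round_num), so each round's choice is
    reproducible but differs from round to round.
    """
    rng = random.Random(base_seed * 1_000_003 + round_num)
    return rng.sample(list(client_datasets), num_clients)


all_clients = ["client_a", "client_b", "client_c", "client_d", "client_e"]
for round_num in range(3):
    print(round_num, sample_clients(all_clients, 2, round_num))
```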

My datasets are not necessarily the same length. When one of them has been completely iterated over, does that dataset wait for all the other datasets to finish iterating over their data?

Each dataset is only iterated over once on each invocation of .next(). This is in line with the synchronous communication "rounds" in McMahan et al. 2017. In some sense, yes, the datasets "wait" for each other.
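
The sense in which the datasets "wait" can be shown with a toy cost model. The model is an assumption, not a measurement of TFF: every batch costs the same time, and clients run in parallel. The server aggregates only after all clients finish, so a round lasts as long as its slowest client, and faster clients sit idle for the difference. The function round_timing is hypothetical.

```python
from typing import List, Sequence, Tuple


def round_timing(batches_per_client: Sequence[int],
                 seconds_per_batch: float = 1.0) -> Tuple[float, List[float]]:
    """Return (round wall time, idle time per client) for one synchronous round.

    Toy model: each client makes one pass over its data in parallel, and the
    round ends when the client with the most batches finishes.
    """
    finish = [n * seconds_per_batch for n in batches_per_client]
    round_time = max(finish)
    idle = [round_time - f for f in finish]
    return round_time, idle


round_time, idle = round_timing([3, 1, 5])
print(round_time)  # 5.0
print(idle)        # [2.0, 4.0, 0.0]
```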

Keith Rush

Any tff.Computation (like next) will always run the entire specified computation. If your tff.templates.IterativeProcess is, for example, the result of tff.learning.build_federated_averaging_process, its next function will represent one round of the federated averaging algorithm.

The federated averaging algorithm runs training for a fixed number of epochs (let's say 1 for simplicity) over each local dataset, and then averages the model updates at the server, weighted by the amount of data on each client, to complete a round. See Algorithm 1 in the original federated averaging paper (McMahan et al. 2017) for a specification of the algorithm.
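
The data-weighted averaging step can be worked through with plain lists standing in for model tensors. This sketch assumes each client reports a model update together with its example count n_k. The server then computes sum_k (n_k / n) * update_k, where n is the total example count. When all clients start from the same model, averaging updates this way is equivalent to averaging the resulting weights. The function weighted_average is hypothetical.

```python
from typing import List, Sequence


def weighted_average(updates: Sequence[Sequence[float]],
                     num_examples: Sequence[int]) -> List[float]:
    """Average client updates, weighting each by its share of the examples."""
    total = sum(num_examples)
    dim = len(updates[0])
    avg = [0.0] * dim
    for update, n_k in zip(updates, num_examples):
        weight = n_k / total  # this client's share of all examples
        for i in range(dim):
            avg[i] += weight * update[i]
    return avg


# Client 0 holds 3x as much data as client 1, so it gets weight 0.75.
print(weighted_average([[1.0, 0.0], [0.0, 4.0]], [30, 10]))  # [0.75, 1.0]
```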

Now consider how TFF represents and executes this algorithm. According to the documentation for build_federated_averaging_process, the next function has the type signature:

(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)

TFF's type system represents a dataset as a tff.SequenceType (this is the meaning of the * above), so the second element in the parameter of the type signature represents a set (technically a multiset) of datasets with elements of type B, placed at the clients.

What this means in your example is as follows. You have a list of tf.data.Datasets, each of which represents the local data on one client; you can think of the list as representing the federated placement. In this context, "TFF executes the entire specified computation" means that TFF will treat every item in the list as a client to be trained on in this round. In terms of the algorithm linked above, your list of datasets represents the set S_t.

TFF will faithfully execute one round of the federated averaging algorithm, with the Dataset elements of your list representing the clients selected for this round. Training will be run for a single epoch on each client (in parallel); if datasets have different amounts of data, you are correct that the training on each client is likely to finish at different times. However, this is the correct semantics of a single round of the federated averaging algorithm, as opposed to a parameterization of a similar algorithm like Reptile, which runs for a fixed number of steps for each client.
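
The difference between a fixed number of epochs per client and a fixed number of steps per client shows up in which batches each client visits. In this sketch, both functions are hypothetical and each client has at least one batch. Under fixed epochs, a small client does few steps and a large one does many. Under fixed steps, every client does the same number of steps, so a small client cycles through its data repeatedly while a large one may stop partway.

```python
from typing import List


def visited_batches_fixed_epochs(num_batches: int, epochs: int = 1) -> List[int]:
    """FedAvg-style: run `epochs` full passes, so steps scale with data size."""
    return [i % num_batches for i in range(num_batches * epochs)]


def visited_batches_fixed_steps(num_batches: int, steps: int) -> List[int]:
    """Fixed-step style: run exactly `steps` steps, wrapping around the data."""
    return [i % num_batches for i in range(steps)]


print(visited_batches_fixed_epochs(3))    # [0, 1, 2]
print(visited_batches_fixed_steps(3, 5))  # [0, 1, 2, 0, 1]: revisits data
print(visited_batches_fixed_steps(3, 2))  # [0, 1]: never reaches batch 2
```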

If you wish to select a subset of clients to run a round of training on, this should be done in Python, before calling into TFF, e.g.:

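```python
import random

state = iterative_process.initialize()

# ls is a list of datasets; N_CLIENTS is the number of clients to sample.
sampled_clients = random.sample(ls, N_CLIENTS)

# next() returns a (state, metrics) pair, per its type signature above.
state, metrics = iterative_process.next(state, sampled_clients)
```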

Generally, you can think of the Python runtime as an "experiment driver" layer: any selection of clients, for example, should happen at this layer. See the beginning of this answer for further detail on this.
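
The driver-layer pattern can be sketched end to end with a stand-in object. StubIterativeProcess is hypothetical and only mimics the shape described in this answer: initialize() returns a state, and next(state, client_data) returns a (state, metrics) pair. In real code it would be replaced by the process returned by tff.learning.build_federated_averaging_process, whose state and metrics are far richer than these toy values.

```python
import random
from typing import Dict, List, Sequence, Tuple


class StubIterativeProcess:
    """Hypothetical stand-in for a TFF IterativeProcess (illustration only)."""

    def initialize(self) -> int:
        return 0  # toy "state": number of rounds completed so far

    def next(self, state: int,
             client_data: Sequence[List[int]]) -> Tuple[int, Dict[str, int]]:
        metrics = {"num_clients": len(client_data),
                   "num_examples": sum(len(d) for d in client_data)}
        return state + 1, metrics


def run_experiment(process, all_clients, clients_per_round, num_rounds, seed=0):
    rng = random.Random(seed)
    state = process.initialize()
    history = []
    for _ in range(num_rounds):
        # Client selection happens here, in Python, before calling next().
        sampled = rng.sample(all_clients, clients_per_round)
        state, metrics = process.next(state, sampled)
        history.append(metrics)
    return state, history


clients = [[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]
final_state, history = run_experiment(StubIterativeProcess(), clients, 2, 3)
print(final_state)  # 3
print(history)
```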
