Yuanli Wang

Reputation: 11

Does tensorflow-federated support dynamic batch size?

Does tensorflow-federated support assigning a different batch size to different simulated devices, and changing the batch size across epochs?

Upvotes: 1

Views: 472

Answers (1)

Keith Rush

Reputation: 1405

TFF does support dynamic batch sizes; this is encoded at the type-signature level, in the shape attribute of tff.TensorType. Any dimension whose shape entry is None will be dynamic. If you have a tff.learning.Model, its input_spec attribute should use a None-sized entry for any dimension you wish to be dynamic.

The exact type signature you need depends on what you are looking to do with these dynamic shapes. Here is a quick example that might illustrate a little more:

Suppose you have a Keras model model and a tff.simulation.ClientData object client_data. The input_spec argument to tff.learning.from_keras_model will populate the tff.learning.Model input_spec directly, so it is here you wish to specify that your batch dimension can vary:

import collections
import tensorflow as tf

# The leading None lets the batch dimension vary at runtime.
input_spec = collections.OrderedDict(
    x=tf.TensorSpec(dtype=tf.float32, shape=[None, 784]),
    y=tf.TensorSpec(dtype=tf.int64, shape=[None]),
)

def model_fn():
  return tff.learning.from_keras_model(
      keras_model=model,
      input_spec=input_spec,
      # other args, ...
  )

Then, inside your Python training loop, you can use different batch sizes across different rounds of training (or even within the same round, I suppose). Assume we wrote a function called _whatever_batch_size_you_want which takes the round number as an argument and returns whatever batch size is appropriate for that round:

import random

fedavg_process = tff.learning.build_federated_averaging_process(
    model_fn=model_fn,
    # other args, ...
)

state = fedavg_process.initialize()

for k in range(NUM_ROUNDS):
  batch_size = _whatever_batch_size_you_want(k)
  sampled_client_ids = random.choices(
      client_data.client_ids, k=NUM_CLIENTS_PER_ROUND)
  client_datasets = [
      client_data.create_tf_dataset_for_client(x) for x in sampled_client_ids]
  batched_client_datasets = [ds.batch(batch_size) for ds in client_datasets]
  state = fedavg_process.next(state, batched_client_datasets)

You can do even fancier things with dynamic shapes and input_spec arguments if desired; for example, you can train a sequence-processing model on variable-length sequences by specifying the sequence dimension as None as well.
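As a minimal tf.data sketch of that last point (illustrative only, not TFF-specific): with an input_spec whose sequence dimension is None, you can pad each batch only to the length of its longest element via padded_batch.

```python
import tensorflow as tf

# Hypothetical variable-length integer sequences.
sequences = [[1, 2, 3], [4, 5], [6]]
ds = tf.data.Dataset.from_generator(
    lambda: iter(sequences),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int64))

# padded_batch pads each batch to its own longest sequence, so both the
# batch and sequence dimensions vary, matching a [None, None] input_spec.
batched = ds.padded_batch(2)
shapes = [batch.shape.as_list() for batch in batched]
```

Here the first batch holds the length-3 and length-2 sequences (padded to 3), and the second holds only the length-1 sequence.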

Upvotes: 2
