Reputation: 11
Does tensorflow-federated support assigning different batch sizes to different simulated devices, and changing the batch size across epochs?
Upvotes: 1
Views: 472
Reputation: 1405
TFF does support dynamic batch sizes. This is encoded at the type-signature level, in the shape attribute of tff.TensorType: any dimension whose associated shape is None will be dynamic. If you have a tff.learning.Model, its input_spec attribute should have a None-sized dimension for every dimension you wish to be dynamic.
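To see concretely what a None dimension buys you, here is a small standalone TensorFlow snippet (plain tf.TensorSpec, outside of TFF) showing that a spec with a None batch dimension is compatible with batches of any size, while a fixed dimension is not:

```python
import tensorflow as tf

# A spec with a None (dynamic) leading dimension accepts any batch size.
dynamic_spec = tf.TensorSpec(shape=[None, 784], dtype=tf.float32)
batch_of_32 = tf.zeros([32, 784])
batch_of_7 = tf.zeros([7, 784])
print(dynamic_spec.is_compatible_with(batch_of_32))  # True
print(dynamic_spec.is_compatible_with(batch_of_7))   # True

# A fixed leading dimension only accepts batches of exactly that size.
fixed_spec = tf.TensorSpec(shape=[32, 784], dtype=tf.float32)
print(fixed_spec.is_compatible_with(batch_of_7))     # False
```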
The exact type signature to specify depends on what you want to do with these dynamic shapes. Here is a quick example to illustrate:
Suppose you have a Keras model model and a tff.simulation.ClientData object client_data. The input_spec argument to tff.learning.from_keras_model populates the tff.learning.Model input_spec directly, so this is where you specify that your batch dimension can vary:
import collections
import tensorflow as tf
import tensorflow_federated as tff

# None in the leading position makes the batch dimension dynamic.
input_spec = collections.OrderedDict(
    x=tf.TensorSpec(dtype=tf.float32, shape=[None, 784]),
    y=tf.TensorSpec(dtype=tf.int64, shape=[None]),
)
def model_fn():
  return tff.learning.from_keras_model(
      keras_model=model,
      input_spec=input_spec,
      # other args, ...
  )
Then, inside your Python-driven training loop, you can use different batch sizes across different rounds of training (or even within the same round, I suppose). Assuming we wrote a function called _whatever_batch_size_you_want, which takes the round number as an argument and returns whatever batch size is appropriate for that round, it would look like this:
import random

fedavg_process = tff.learning.build_federated_averaging_process(
    model_fn=model_fn,
    # other args, ...
)
state = fedavg_process.initialize()
for k in range(NUM_ROUNDS):
  batch_size = _whatever_batch_size_you_want(k)
  sampled_client_ids = random.choices(
      client_data.client_ids, k=NUM_CLIENTS_PER_ROUND)
  client_datasets = [
      client_data.create_tf_dataset_for_client(x)
      for x in sampled_client_ids]
  # Batch each client's dataset with this round's batch size.
  batched_client_datasets = [ds.batch(batch_size) for ds in client_datasets]
  state = fedavg_process.next(state, batched_client_datasets)
You can do even fancier things with dynamic shapes and input_spec arguments if desired; for example, you can train a sequence-processing model that takes variable-length sequences by specifying the sequence dimension to have size None as well.
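As a sketch of that last point (the shapes and dtypes here are hypothetical, chosen for illustration), such a spec would leave both the batch and sequence-length dimensions as None:

```python
import collections
import tensorflow as tf

# Hypothetical spec for a sequence model: both the batch dimension and the
# sequence-length dimension are None, so clients can feed padded batches
# of any size and any sequence length.
sequence_input_spec = collections.OrderedDict(
    x=tf.TensorSpec(dtype=tf.int64, shape=[None, None]),
    y=tf.TensorSpec(dtype=tf.int64, shape=[None, None]),
)
```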
Upvotes: 2