Reputation: 55
I am trying to implement a custom aggregation using TFF by changing the code from this tutorial. I would like to rewrite next_fn so that all the client weights are placed at the server for further computations. As federated_collect was removed from tff-nightly, I am trying to do that using federated_aggregate.
This is what I have so far:
def accumulate(x, y):
    x.append(y)
    return x

def merge(x, y):
    x.extend(y)
    return y

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(
        server_state.trainable_weights)
    client_deltas = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))
    z = []
    agg_result = tff.federated_aggregate(client_deltas, z,
                                         accumulate=tff.tf_computation(accumulate),
                                         merge=tff.tf_computation(merge),
                                         report=tff.tf_computation(lambda x: x))
    new_weights = do_smth_with_result(agg_result)
    server_state = tff.federated_map(
        server_update_fn, (server_state, new_weights))
    return server_state
However, this results in the following exception:
  File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 351, in <module>
    def next_fn(server_state, federated_dataset):
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 494, in __call__
    wrapped_func = self._strategy(
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 222, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 358, in next_fn
    agg_result = tff.federated_aggregate(client_deltas, z,
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/federated_context/intrinsics.py", line 140, in federated_aggregate
    raise TypeError(
TypeError: Expected parameter `accumulate` to be of type (<<<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>), but received (<<>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>) instead.
Upvotes: 2
Views: 253
Reputation: 900
Try using tff.aggregators.federated_sample with max_num_samples equal to the number of clients you have. That should be a simple drop-in replacement for how you would previously have used tff.federated_collect.
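To see why setting max_num_samples to the number of clients returns every client value, here is a toy plain-Python model of a sampling aggregation. It is not TFF's code: it assumes a keep-the-top-N-by-random-key strategy, and the helper names (sample_accumulate, sample_merge, simulate_sample) are made up for illustration.

```python
import random


def sample_accumulate(acc, value, max_num_samples, rng):
    # Tag the incoming value with a random key, then keep only the
    # max_num_samples entries with the largest keys.
    acc = acc + [(rng.random(), value)]
    acc.sort(key=lambda pair: pair[0], reverse=True)
    return acc[:max_num_samples]


def sample_merge(a, b, max_num_samples):
    # Combine two partial accumulators and apply the same cut-off.
    merged = sorted(a + b, key=lambda pair: pair[0], reverse=True)
    return merged[:max_num_samples]


def simulate_sample(values, max_num_samples, seed=0):
    # Accumulate two halves of the clients separately, then merge them,
    # mimicking a two-level aggregation tree.
    rng = random.Random(seed)
    half = len(values) // 2
    left, right = [], []
    for v in values[:half]:
        left = sample_accumulate(left, v, max_num_samples, rng)
    for v in values[half:]:
        right = sample_accumulate(right, v, max_num_samples, rng)
    merged = sample_merge(left, right, max_num_samples)
    # Report step: drop the random keys and return only the values.
    return [value for _, value in merged]
```

With five client values and max_num_samples=5, nothing is ever cut off, so all five values come back (in arbitrary order). With a smaller max_num_samples only that many values survive.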
In your accumulate, the issue is that you are changing the number of tensors the accumulator contains, so you get an error when accumulating more than a single accumuland. If you do want to go this way, though, for a rank-1 accumuland with k elements, you could probably do something like the following instead:
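import tensorflow as tf
import tensorflow_federated as tff
# k: number of elements in each rank-1 accumuland (must be defined beforehand).
@tff.tf_computation(tff.types.TensorType(tf.float32, [None, k]),
                    tff.types.TensorType(tf.float32, [k]))
def accumulate(accumulator, accumuland):
    # Append the accumuland as a new row; the accumulator stays rank-2 with k columns.
    return tf.concat([accumulator, tf.expand_dims(accumuland, axis=0)], axis=0)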
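The row-stacking idea can be checked without TFF. The following is a plain-Python sketch, not TFF code: nested lists stand in for tensors, K is a hypothetical element count, and simulate_aggregate is an invented helper that applies accumulate within groups of clients and then merge across groups. Note that merge returns the combined value, and that the accumulator keeps the same structure (rows of length K) at every step.

```python
from functools import reduce

K = 3  # hypothetical number of elements in each client's rank-1 value


def zero():
    # Stand-in for an empty float32[0, K] tensor.
    return []


def accumulate(accumulator, accumuland):
    # Analogue of tf.concat([accumulator, tf.expand_dims(accumuland, 0)], axis=0):
    # add one row, leaving the row length unchanged.
    assert len(accumuland) == K
    return accumulator + [list(accumuland)]


def merge(a, b):
    # Analogue of tf.concat([a, b], axis=0): stack the rows of both
    # partial accumulators and return the combined result.
    return a + b


def report(accumulator):
    return accumulator


def simulate_aggregate(client_values, group_size=2):
    # Accumulate each group of clients separately, then merge the partials.
    partials = []
    for start in range(0, len(client_values), group_size):
        group = client_values[start:start + group_size]
        partials.append(reduce(accumulate, group, zero()))
    return report(reduce(merge, partials, zero()))
```

Calling simulate_aggregate on three client vectors of length 3 yields a 3-by-3 nested list holding all client values, which is the "collect at server" result the question asks for.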
Upvotes: 1