I tried to print client updates as mentioned.
```python
import tensorflow_federated as tff

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
    return {
        'num_examples': tff.federated_sum(metrics.num_examples),
        'loss': tff.federated_mean(metrics.loss, metrics.num_examples),
        'accuracy': tff.federated_mean(metrics.accuracy, metrics.num_examples),
        # federated_collect gathers each client's value into a sequence placed at the server
        'per_client/num_examples': tff.federated_collect(metrics.num_examples),
        'per_client/loss': tff.federated_collect(metrics.loss),
        'per_client/accuracy': tff.federated_collect(metrics.accuracy),
    }
```
But it did not work. After the first round, the per-client entries show no values, only a dataset description. Can you please look into it? Thanks.
```
round 1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', {'accuracy': 0.0990099, 'loss': 2.7817621, 'num_examples': 202.0, 'per_client/accuracy': <ConcatenateDataset shapes: (), types: tf.float32>, 'per_client/loss': <ConcatenateDataset shapes: (), types: tf.float32>, 'per_client/num_examples': <ConcatenateDataset shapes: (), types: tf.float32>})])
```
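Those per-client values are represented as `tf.data.Dataset` objects (a `ConcatenateDataset` in your output), not as tensors, so printing them shows only the dataset's shape and type. You can iterate through their values using the basic mechanics described at https://www.tensorflow.org/guide/data#basic_mechanics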
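As a minimal sketch of that iteration, assuming `metrics['train']` is the dict printed in the question and that you run in eager mode (where each dataset element is a `tf.Tensor` exposing `.numpy()`), the hypothetical helpers `dataset_values` and `per_client_metrics` below turn every `per_client/...` entry into a plain Python list. The names `iterative_process` and `federated_train_data` in the usage comment are also assumptions, borrowed from the usual TFF tutorial setup.

```python
from collections import OrderedDict


def dataset_values(dataset):
    """Iterate a tf.data.Dataset (or any iterable) and return its elements as a list.

    In eager mode each element of a tf.data.Dataset is a tf.Tensor, which
    exposes .numpy(); plain Python values are passed through unchanged.
    """
    values = []
    for element in dataset:
        values.append(element.numpy() if hasattr(element, "numpy") else element)
    return values


def per_client_metrics(train_metrics):
    """Hypothetical helper: materialize every 'per_client/...' entry as a list."""
    return OrderedDict(
        (name, dataset_values(value))
        for name, value in train_metrics.items()
        if name.startswith("per_client/")
    )


# Hypothetical usage after one round (variable names assumed):
#   state, metrics = iterative_process.next(state, federated_train_data)
#   for name, values in per_client_metrics(metrics['train']).items():
#       print(name, values)
```

On TF 2.1 or later, `list(dataset.as_numpy_iterator())` is a shorter way to get the same values; the explicit loop above just mirrors the "basic mechanics" pattern from the guide. Note that the sketch makes no claim about the order of clients within each dataset.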