colin byrne

Reputation: 81

TensorFlow Federated image classification example: the number of epochs has a major effect. Is the model overfitting?

I've been trying to characterize the learning process (accuracy and loss) on the Federated Learning for Image Classification notebook tutorial with TF Federated.

I'm seeing major improvements in the speed of convergence by modifying the epochs hyperparameter, changing it from 5 to 10, 20, etc. But I'm also seeing a major increase in training accuracy. I suspect overfitting is occurring, though when I evaluate on the test set, accuracy is still high.

I'm wondering what is going on.

My understanding is that the epochs param controls the number of forward/backward passes on each client per round of training. Is this correct? So, for example, 10 rounds of training on 10 clients with 10 epochs would be 10 epochs × 10 clients × 10 rounds = 1,000 local epochs in total. I realise a larger range of clients is needed, etc., but I was expecting to see poorer accuracy on the test set.
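For context, as I understand it the tutorial applies the epoch count by repeating each client's tf.data.Dataset during preprocessing, so "epochs" here means local passes over a client's data per round. Roughly (constant values are illustrative; the structure follows the tutorial):

    import tensorflow as tf

    NUM_EPOCHS = 10       # local passes over each client's data per round
    BATCH_SIZE = 20
    SHUFFLE_BUFFER = 100

    def batch_format_fn(element):
        # Flatten each 28x28 EMNIST image in the batch into a 784-vector.
        return (tf.reshape(element['pixels'], [-1, 784]),
                tf.reshape(element['label'], [-1, 1]))

    def preprocess(dataset):
        # repeat(NUM_EPOCHS) is what makes each client take NUM_EPOCHS
        # local passes over its data in every federated round.
        return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
            BATCH_SIZE).map(batch_format_fn)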

What can I do to see what's going on? Could I use the evaluation check with something like learning curves to see if overfitting is occurring?
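For example, I think I could record the training metrics from each round alongside the federated evaluation metrics and plot the two curves, something like this (assuming the tutorial's iterative_process, evaluation, federated_train_data and federated_test_data are already defined; the exact metric keys seem to depend on the TFF version):

    import matplotlib.pyplot as plt

    NUM_ROUNDS = 50
    state = iterative_process.initialize()
    train_acc, test_acc = [], []

    for round_num in range(NUM_ROUNDS):
        # Training metrics are computed on the clients' local training data.
        state, metrics = iterative_process.next(state, federated_train_data)
        train_acc.append(metrics['train']['sparse_categorical_accuracy'])

        # Held-out metrics from the federated evaluation computation.
        test_metrics = evaluation(state.model, federated_test_data)
        test_acc.append(test_metrics['sparse_categorical_accuracy'])

    plt.plot(train_acc, label='train accuracy')
    plt.plot(test_acc, label='test accuracy')
    plt.xlabel('federated round')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()

A widening gap between the two curves would suggest overfitting.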

    test_metrics = evaluation(state.model, federated_test_data)

only appears to give a single aggregate result. How can I get the individual test accuracy for each test example evaluated?
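One idea I had was to copy the trained weights from state.model into a standalone Keras model and run predictions myself to get per-example results, though I'm not sure this is the intended route (the weight-assignment call appears to differ between TFF releases, and create_keras_model is the tutorial's model-building function):

    import numpy as np

    keras_model = create_keras_model()
    # Copy the trained federated weights into the Keras model. The exact
    # call depends on the TFF release; older versions use e.g.
    #   tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    state.model.assign_weights_to(keras_model)

    # Pool all preprocessed test clients into flat arrays (illustrative).
    x_test = np.concatenate(
        [x.numpy() for ds in federated_test_data for x, y in ds])
    y_test = np.concatenate(
        [y.numpy() for ds in federated_test_data for x, y in ds])

    # One prediction per test example instead of a single aggregate metric.
    probs = keras_model.predict(x_test)
    correct = np.argmax(probs, axis=1) == y_test.flatten()
    print('accuracy per example:', correct.astype(int))
    print('overall accuracy:', correct.mean())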

Upvotes: 1

Views: 348

Answers (1)

Zachary Garrett

Reputation: 2941

Increasing the number of client epochs can indeed increase the per-round convergence rate, but you're absolutely right that there is a risk of overfitting.

In the Federated Averaging algorithm, the number of client epochs determines the amount of "sequential progress" (or learning) each client makes before updating the global model. More epochs result in more local progress each round, which can manifest as a much faster per-round convergence rate. However, plotting accuracy against the total number of examples seen across all clients may instead show a similar convergence rate, as the small calculation below illustrates.
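To make that comparison concrete, you can convert a round count into the number of examples processed; a small illustration with made-up sizes:

    # Made-up sizes, for illustration only.
    clients_per_round = 10
    examples_per_client = 500
    client_epochs = 10

    examples_per_round = clients_per_round * examples_per_client * client_epochs
    print(examples_per_round)  # 50000 examples processed per round

    # Plotting accuracy against round_num * examples_per_round (rather than
    # round_num) puts runs with different epoch settings on a comparable axis.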

In the federated optimization setting, there is a new risk of overfitting that may be correlated with how non-IID the client datasets are. If each client dataset has the same distribution as the global data distribution, the same practices used for non-federated optimization can be used. The less similar each client dataset is to the "global" dataset, the more likely there will be "drift" (clients converging to different local optima) when using a high number of client epochs during later rounds. Training accuracy can still appear high in this setting, since each client fits its own local data well during local training. However, test accuracy is less likely to improve, because the averaged global model update will tend to be very small (the different client-local optima cancelling each other out, as the toy example below illustrates). Praneeth et al. have some discussion of this.
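The cancelling-out effect is easy to see with a toy average: when clients drift toward different local optima, their model deltas point in different directions and the averaged update is much smaller than any individual one. A contrived numpy illustration:

    import numpy as np

    # Contrived per-client model deltas after many local epochs, pulling
    # in nearly opposite directions due to non-IID client data.
    delta_a = np.array([1.0, -0.5])
    delta_b = np.array([-0.9, 0.6])

    avg_delta = (delta_a + delta_b) / 2
    print(np.linalg.norm(delta_a))    # ~1.12: large local progress
    print(np.linalg.norm(avg_delta))  # ~0.07: the average nearly cancels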

Upvotes: 1
