Reputation: 12717
How can the trainer.test method be used to get total accuracy over all batches?
I know I can implement model.test_step, but that is for a single batch only. I need the accuracy over the whole data set. I can use torchmetrics.Accuracy to accumulate accuracy, but what is the proper way to combine that and get the total accuracy out? What is model.test_step supposed to return anyway, since batch-wise test scores are not very useful? I could hack it somehow, but I'm surprised that I couldn't find any example on the internet demonstrating how to do this the native PyTorch Lightning way.
Upvotes: 7
Views: 7133
Reputation: 21
You can create a test_epoch_end method inside your LightningModule. Its single input argument, all_outputs, is a list of the return values from every test_step call, so you can calculate custom metrics over all of them at once.
For example:
def test_step(self, batch, batch_idx):
    targets, logits, loss = self(batch, batch_idx, "test")
    return targets, logits

def test_epoch_end(self, all_outputs):
    # all_outputs is a list of (targets, logits) tuples,
    # one per test_step call
    for targets, logits in all_outputs:
        ...  # accumulate predictions and targets
    # calculate metrics (e.g. acc) based on all outputs,
    # then log the results
    log = {'Accuracy': acc}
    self.log_dict(log)
Another reference that may help: Not able to print overall results from testing.
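As a concrete (hypothetical) sketch of the "calculate metrics based on all outputs" step, assuming each test_step returns a (targets, logits) tuple as above, the overall accuracy can be computed by concatenating all batches; the helper name epoch_accuracy is illustrative:

```python
import torch

def epoch_accuracy(all_outputs):
    # all_outputs: list of (targets, logits) tuples, one per batch,
    # as collected by Lightning from the test_step return values
    targets = torch.cat([t for t, _ in all_outputs])
    logits = torch.cat([l for _, l in all_outputs])
    preds = logits.argmax(dim=-1)
    # exact dataset-level accuracy, independent of batch sizes
    return (preds == targets).float().mean().item()
```

Inside test_epoch_end, the returned value could then be put into the dict passed to self.log_dict.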
Upvotes: 0
Reputation: 160
I am working on a notebook. I did some initial experimentation with the following code.
def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    # self.test_acc is a torchmetrics.Accuracy instance defined in __init__
    self.test_acc(logits, y)
    self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
This prints a nicely formatted results table after calling:
model = Cifar100Model()
trainer = pl.Trainer(max_epochs=1, accelerator='cpu')
trainer.test(model, test_dataloader)
This printed test_acc 0.008200000040233135
I tried to verify that the printed value is actually an average over the test data batches, by modifying test_step as follows:
def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    self.test_acc(logits, y)
    self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
    # additionally print the raw per-batch accuracy
    preds = logits.argmax(dim=-1)
    acc = (y == preds).float().mean()
    print(acc)
Then I ran trainer.test() again. This time the following values were printed:
tensor(0.0049)
tensor(0.0078)
tensor(0.0088)
tensor(0.0078)
tensor(0.0122)
Averaging them gives 0.0083, which is very close to the value logged by test_step above.
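The averaging above can be checked directly; this tiny snippet just replicates the author's arithmetic over the five printed batch accuracies:

```python
# Per-batch accuracies printed by the modified test_step
batch_accs = [0.0049, 0.0078, 0.0088, 0.0078, 0.0122]

# Simple mean over batches; this matches the sample-weighted epoch value
# only when all test batches have the same size (assumed here)
mean_acc = sum(batch_accs) / len(batch_accs)
print(round(mean_acc, 4))  # 0.0083
```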
The logic behind this solution: because I specified on_epoch=True in
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
and passed a TorchMetrics metric object, the epoch-level average is computed automatically by PyTorch Lightning via the metric's compute() method.
I'll try to post my full notebook shortly. You can check there too.
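For intuition, here is a minimal pure-PyTorch stand-in (an illustrative sketch, not the actual torchmetrics implementation) for how a TorchMetrics-style accuracy metric accumulates state across batches in update() and reduces it in compute():

```python
import torch

class RunningAccuracy:
    # Illustrative stand-in for a TorchMetrics-style accuracy metric:
    # accumulate counts in update(), reduce in compute().
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, logits, targets):
        preds = logits.argmax(dim=-1)
        self.correct += (preds == targets).sum().item()
        self.total += targets.numel()

    def compute(self):
        # Sample-weighted accuracy, exact even with unequal batch sizes
        return self.correct / self.total
```

The real torchmetrics class additionally handles state reset between epochs and synchronization across devices, which is why logging the metric object itself is convenient.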
Upvotes: 1
Reputation: 3476
You can see in the docs (https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging) that the on_epoch argument of log automatically accumulates and logs at the end of the epoch. The right way of doing this would be:
from torchmetrics import Accuracy

def validation_step(self, batch, batch_idx):
    x, y = batch
    preds = self.forward(x)
    loss = self.criterion(preds, y)
    # note: recent torchmetrics versions require a task argument,
    # e.g. Accuracy(task="multiclass", num_classes=...)
    accuracy = Accuracy()
    acc = accuracy(preds, y)
    # on_epoch=True accumulates across batches and logs the epoch average
    self.log('accuracy', acc, on_epoch=True)
    return loss
If you want a custom reduction function you can set it using the reduce_fx argument; the default is torch.mean. log() can be called from any method in your LightningModule.
Upvotes: 3