Gere

Reputation: 12717

How to get total test accuracy for pytorch lightning?

How can the trainer.test method be used to get total accuracy over all batches?

I know I can implement model.test_step, but that is for a single batch only. I need the accuracy over the whole data set. I can use torchmetrics.Accuracy to accumulate accuracy, but what is the proper way to combine the two and get the total accuracy out? What is model.test_step supposed to return anyway, since batch-wise test scores are not very useful? I could hack it somehow, but I'm surprised that I couldn't find any example on the internet that demonstrates how to get accuracy the PyTorch Lightning native way.

Upvotes: 7

Views: 7133

Answers (3)

IcecreamArtist

Reputation: 21

You can create a test_epoch_end method inside your LightningModule. It takes a single argument, all_outputs, which is a list of the return values from every test_step call. You can then calculate customized metrics over all of these values at once.

For example:

def test_step(self, batch, batch_idx):
    targets, logits, loss = self(batch, batch_idx, "test")
    return targets, logits

def test_epoch_end(self, all_outputs):
    # all_outputs is a list of (targets, logits) tuples, one per test_step call.
    targets = torch.cat([t for t, _ in all_outputs])
    logits = torch.cat([l for _, l in all_outputs])
    # Calculate metrics over the whole test set, e.g. overall accuracy,
    # then log the results.
    acc = (logits.argmax(dim=-1) == targets).float().mean()
    log = {'Accuracy': acc}
    self.log_dict(log)

Another reference that may help: Not able to print overall results from testing.

Upvotes: 0

Abdul Mukit

Reputation: 160

I am working on a notebook and did some initial experimentation with the following code.

def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    # self.test_acc is a torchmetrics.Accuracy instance (see the sketch below);
    # logging the metric object itself lets Lightning handle the epoch reduction.
    self.test_acc(logits, y)
    self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
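
For reference, test_acc is a torchmetrics metric created once in the model's __init__, roughly like the sketch below. The task and num_classes arguments are assumptions: recent torchmetrics releases require them, while older releases take no arguments, and 100 classes is just CIFAR-100.

import pytorch_lightning as pl
import torchmetrics

class Cifar100Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # (forward, training_step, etc. omitted)
        # Metric object that test_step updates; Lightning accumulates and
        # reduces it per epoch when it is passed to self.log().
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=100)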

This prints out a nicely formatted result after calling

model = Cifar100Model()
trainer = pl.Trainer(max_epochs=1, accelerator='cpu')
trainer.test(model, test_dataloader)

This printed test_acc 0.008200000040233135

I tried to verify whether the printed value is actually an average over the test data batches by modifying test_step as follows:

def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    self.test_acc(logits, y)
    self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)

    # Manually compute and print the per-batch accuracy for comparison.
    preds = logits.argmax(dim=-1)
    acc = (y == preds).float().mean()
    print(acc)

Then I ran trainer.test() again. This time the following per-batch values were printed out:

tensor(0.0049)
tensor(0.0078)
tensor(0.0088)
tensor(0.0078)
tensor(0.0122)

Averaging them gives 0.0083, which is very close to the test_acc value logged above.
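
As a quick sanity check of that average:

# Average of the per-batch accuracies printed above.
batch_accs = [0.0049, 0.0078, 0.0088, 0.0078, 0.0122]
print(sum(batch_accs) / len(batch_accs))  # 0.0083, matching the logged test_acc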

The logic behind this solution is that I specified on_epoch=True in

self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)

and passed a TorchMetrics metric object, so the epoch-level average is computed by PL automatically via the metric's compute() function.
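
Outside of Lightning, that accumulation can be sketched standalone with fake data. This is only an illustration of the update/compute cycle, and the task and num_classes arguments are again assumptions for a recent torchmetrics release:

import torch
from torchmetrics import Accuracy

# Roughly what Lightning does when a metric object is logged with on_epoch=True:
# update() on every batch, compute() once at the end of the epoch, then reset().
metric = Accuracy(task="multiclass", num_classes=100)
for _ in range(5):                          # pretend we have 5 test batches
    logits = torch.randn(32, 100)           # fake model outputs
    targets = torch.randint(0, 100, (32,))  # fake labels
    metric.update(logits, targets)          # accumulate correct/total counts
print(metric.compute())                     # accuracy over all accumulated samples
metric.reset()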

I'll try to post my full notebook shortly. You can check there too.

Upvotes: 1

Mike B

Reputation: 3476

You can see here (https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging) that the on_epoch argument of log() automatically accumulates the metric and logs it at the end of the epoch. The right way of doing this would be:

from torchmetrics import Accuracy

def validation_step(self, batch, batch_idx):
    x, y = batch
    preds = self.forward(x)
    loss = self.criterion(preds, y)
    # Per-batch accuracy; logging it with on_epoch=True averages it over the epoch.
    accuracy = Accuracy()
    acc = accuracy(preds, y)
    self.log('accuracy', acc, on_epoch=True)
    return loss

If you want a custom reduction function you can set it using the reduce_fx argument; the default is torch.mean(). log() can be called from any method in your LightningModule.
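
For example, a rough sketch (assuming reduce_fx accepts a callable, as in recent Lightning versions) that keeps the maximum per-batch accuracy instead of the mean:

# Inside validation_step: reduce the logged values with max at epoch end.
self.log('accuracy', acc, on_epoch=True, reduce_fx=torch.max)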

Upvotes: 3
