Boooooooooms

Reputation: 296

How to use PyTorch to print out the prediction accuracy of every class?

I am trying to use PyTorch to print the prediction accuracy of every class, building on the official PyTorch tutorial.

But something seems to go wrong. The code I wrote for this is as follows:

    for epoch in range(num_epochs):

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            ... (this is given by the tutorial)

        # (my code)

        class_correct = list(0. for i in range(3))
        class_total = list(0. for i in range(3))

        for data in dataloaders['val']:
            images, labels = data
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            c = (predicted == labels.data).squeeze()

            for i in range(4):
                label = labels.data[i]
                class_correct[label] += c[i]
                class_total[label] += 1

        for i in range(3):
            print('Accuracy of {} : {} / {} = {:.4f} %'.format(
                i, class_correct[i], class_total[i],
                100 * class_correct[i].item() / class_total[i]))

        print(file=f)
        print()

For example, here is the output of epoch 1/1: [screenshot of the output, not reproduced here]

I think the following equation should be satisfied:

running_corrects == class_correct[0] + class_correct[1] + class_correct[2]
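The reasoning behind this expectation: every validation sample has exactly one label, so the per-class correct counts should partition the overall correct count (what the tutorial accumulates as `running_corrects`). A plain-Python sketch of that invariant, using made-up predictions and labels and a hypothetical helper `per_class_counts`:

```python
from typing import List, Sequence, Tuple


def per_class_counts(
    predicted: Sequence[int], labels: Sequence[int], num_classes: int
) -> Tuple[List[int], List[int]]:
    """Tally correct predictions and sample counts for each class (hypothetical helper)."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for p, t in zip(predicted, labels):
        total[t] += 1              # each sample is counted under exactly one class
        correct[t] += int(p == t)  # and counts as correct only for its own class
    return correct, total


# Made-up predictions and ground-truth labels for 3 classes.
predicted = [0, 2, 1, 1, 0, 2]
labels = [0, 1, 1, 2, 0, 2]

correct, total = per_class_counts(predicted, labels, num_classes=3)
overall_correct = sum(int(p == t) for p, t in zip(predicted, labels))

# The per-class tallies add up to the overall tallies.
assert sum(correct) == overall_correct
assert sum(total) == len(labels)
```

The invariant can only hold when both tallies are computed from the same inputs and therefore the same predictions.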

But that is not what happens.

What's wrong here? I hope someone can point out my mistake and show me how to do this correctly.

Thanks!

Upvotes: 1

Views: 4057

Answers (1)

Boooooooooms

Reputation: 296

Finally, I solved this problem. First, I compared the parameters of the two models and found that they were identical, which confirmed that the model itself was not the problem. Then I checked the inputs given to each and, surprisingly, found that they were different.
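The answer does not say how these comparisons were done. One way to do them, assuming both models are `nn.Module` instances and the inputs are tensors, is to compare `state_dict()` entries and batches with `torch.equal`. The helper names `same_parameters` and `same_batch` below are hypothetical:

```python
import copy

import torch
from torch import nn


def same_parameters(model_a: nn.Module, model_b: nn.Module) -> bool:
    """True if both models have the same state_dict keys and identical tensors."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)


def same_batch(x: torch.Tensor, y: torch.Tensor) -> bool:
    """True if two input batches have the same shape and identical values."""
    return x.shape == y.shape and torch.equal(x, y)


torch.manual_seed(0)
net = nn.Linear(4, 3)
clone = copy.deepcopy(net)
identical_before = same_parameters(net, clone)  # a deep copy has identical weights

with torch.no_grad():
    clone.bias.add_(1.0)                        # perturb one parameter of the copy
identical_after = same_parameters(net, clone)

batch_a = torch.randn(2, 4)
batch_b = torch.randn(2, 4)
print(identical_before, identical_after,
      same_batch(batch_a, batch_a), same_batch(batch_a, batch_b))
```

`state_dict()` also covers buffers such as BatchNorm running statistics, which `parameters()` alone would miss.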

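Reviewing the inputs carefully, I found that the input passed to the second model was never updated: my loop unpacked each batch into `images` but called `model(inputs)`, so `inputs` still held the last batch left over from the earlier loop.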
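This bug raises no error because a Python `for` loop variable keeps its last value after the loop ends. A minimal sketch with plain tuples standing in for the dataloader's batches (the functions `buggy_eval` and `fixed_eval` are hypothetical and only record what would be fed to the model):

```python
def buggy_eval(val_batches):
    """Mimic the question: unpack into `images` but feed the stale `inputs`."""
    inputs = None
    # Tutorial-style loop: binds `inputs` to each batch in turn.
    for inputs, labels in val_batches:
        pass
    # After that loop, `inputs` still refers to the LAST batch.
    fed_to_model = []
    for data in val_batches:
        images, labels = data
        fed_to_model.append(inputs)  # stale value, same on every iteration
    return fed_to_model


def fixed_eval(val_batches):
    """Mimic the fix: unpack into the same name that is fed to the model."""
    fed_to_model = []
    for data in val_batches:
        inputs, labels = data        # rebinds `inputs` on every iteration
        fed_to_model.append(inputs)
    return fed_to_model


batches = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]
print(buggy_eval(batches))  # ['x2', 'x2', 'x2']
print(fixed_eval(batches))  # ['x0', 'x1', 'x2']
```

In the question's loop the labels were still fresh per batch while the predictions all came from one stale batch, so the per-class counts could not match `running_corrects`.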

Original code:

    for data in dataloaders['val']:
        images, labels = data
        outputs = model(inputs)

Changed to:

    for data in dataloaders['val']:
        inputs, labels = data
        outputs = model(inputs)
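A fuller per-class accuracy loop along the same lines is sketched below. It assumes a classifier that returns scores of shape `(batch, num_classes)` and a loader yielding `(inputs, labels)` pairs like the tutorial's `dataloaders['val']`; `per_class_accuracy` and `print_per_class` are hypothetical helper names, and `nn.Identity()` stands in for a real model in the toy check:

```python
import torch
from torch import nn


def per_class_accuracy(model, loader, num_classes, device="cpu"):
    """Count correct predictions and samples per class over every batch in `loader`."""
    correct = torch.zeros(num_classes, dtype=torch.long)
    total = torch.zeros(num_classes, dtype=torch.long)
    model.eval()                       # evaluation mode (dropout off, BatchNorm running stats)
    with torch.no_grad():              # gradients are not needed for evaluation
        for inputs, labels in loader:  # unpack into the same name that is fed to the model
            inputs = inputs.to(device)
            preds = model(inputs).argmax(dim=1).cpu()
            labels = labels.cpu()
            hits = preds == labels
            # bincount tallies per class for a batch of any size
            total += torch.bincount(labels, minlength=num_classes)
            correct += torch.bincount(labels[hits], minlength=num_classes)
    return correct, total


def print_per_class(correct, total):
    for i in range(len(total)):
        c, n = correct[i].item(), total[i].item()
        acc = 100.0 * c / n if n else float("nan")
        print("Accuracy of {} : {} / {} = {:.4f} %".format(i, c, n, acc))


# Toy check: nn.Identity() returns its input unchanged, so each row acts as class scores.
scores = torch.tensor([[0.9, 0.1, 0.0],
                       [0.1, 0.8, 0.1],
                       [0.7, 0.2, 0.1],
                       [0.0, 0.1, 0.9]])
targets = torch.tensor([0, 1, 1, 2])
toy_loader = [(scores[:2], targets[:2]), (scores[2:], targets[2:])]  # two batches of 2

correct, total = per_class_accuracy(nn.Identity(), toy_loader, num_classes=3)
print_per_class(correct, total)
```

Keeping the counts as integer tensors means `correct.sum()` can be compared directly with the overall number of correct predictions. Using `torch.bincount` also avoids the question's `for i in range(4)`, which assumes every batch holds exactly four samples and would fail on a smaller final batch.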

Done!

Upvotes: 1
