Michael Wolz

Reputation: 23

PyTorch model validation: The size of tensor a (32) must match the size of tensor b (13)

I am a complete beginner in machine learning. For learning purposes, I am trying to build a simple CNN to classify chess pieces. The net already works and I can train it, but I have a problem with my validation function.

I can't compare my prediction with the target data because their shapes differ: judging by the error below, the prediction has size [batch_size] (32) while target.data is [batch_size]x13. I can't figure out where my mistake is. Almost all of the PyTorch examples use this comparison to check the prediction against the target data.

It would be really great if anybody could help me out here.

You can look up the rest of the code here: https://github.com/michaelwolz/ChessML/blob/master/train.ipynb

def validate(model, validation_data, criterion):
    model.eval()
    loss = 0
    correct = 0

    for i in range(len(validation_data)):
        data, target = validation_data[i][0], validation_data[i][1]
        target = torch.Tensor(target)

        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()

        out = model(data)

        loss += criterion(out, target).item()

        _, prediction = torch.max(out.data, 1)
        correct += (prediction == target.data).sum().item()

    loss = loss / len(validation_data)
    print("###################################")
    print("Average loss:", loss)
    print("Accuracy:", 100. * correct / len(validation_data))
    print("###################################")

Error:

<ipython-input-6-6b21e2bfb8a6> in validate(model, validation_data, criterion)
     17 
     18         _, prediction = torch.max(out.data, 1)
---> 19         correct += (prediction == target.data).sum().item()
     20 
     21     loss = loss / len(validation_data)

RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 1

Edit: My labels look like this:

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Each index represents one class.
The output of torch.max() seems to be the index of the predicted class. I don't understand how to compare that index to the target label. I could write a function that checks whether there is a 1 at the predicted index, but I suspect that my mistake is somewhere else.

Upvotes: 0

Views: 1683

Answers (1)

dedObed

Reputation: 1363

Simply take the argmax of the target as well, so that both sides of the comparison are class indices:

_, target = torch.max(target.data, 1)
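A minimal sketch of that comparison, assuming a batch of 32 examples and 13 classes as in the error message. The random tensors stand in for real model outputs and labels, and `count_correct` is a hypothetical helper name, not something from the question's notebook.

```python
import torch
import torch.nn.functional as F


def count_correct(out, target_one_hot):
    """Count rows whose predicted class matches the class of the one-hot target."""
    prediction = out.argmax(dim=1)              # shape [batch_size]
    target_idx = target_one_hot.argmax(dim=1)   # shape [batch_size]
    # Both tensors now hold class indices, so the shapes line up.
    return (prediction == target_idx).sum().item()


torch.manual_seed(0)
batch_size, num_classes = 32, 13  # sizes taken from the error message

# Stand-in data: random class labels, one-hot encoded like the question's labels.
labels = torch.randint(0, num_classes, (batch_size,))
target = F.one_hot(labels, num_classes).float()   # shape [32, 13]
out = torch.randn(batch_size, num_classes)        # fake model output

correct = count_correct(out, target)
# Feeding the targets in as "outputs" must give a perfect score.
perfect = count_correct(target, target)
```

The original comparison fails because `prediction` has shape [32] while `target.data` has shape [32, 13], and those two shapes cannot be broadcast against each other.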

Or better yet, keep the target as a list of class indices, [example_1_class, example_2_class, ...], instead of one-hot encoding it.
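As a rough sketch of the index-based variant: the example below converts one-hot labels to class indices once and then uses them for both the loss and the accuracy. It assumes the loss is `nn.CrossEntropyLoss`, which the question does not state; the three-class tensors are made up for illustration.

```python
import torch
import torch.nn as nn

# Made-up one-hot labels for three examples and three classes.
one_hot = torch.tensor([[1, 0, 0],
                        [0, 0, 1],
                        [0, 1, 0]])

# Convert once to class indices: one integer per example.
class_idx = one_hot.argmax(dim=1)

# Assumption: the model is trained with CrossEntropyLoss, which accepts
# raw scores of shape [batch, classes] and integer targets of shape [batch].
criterion = nn.CrossEntropyLoss()

torch.manual_seed(0)
logits = torch.randn(3, 3)  # stand-in for model(data)

loss = criterion(logits, class_idx)
correct = (logits.argmax(dim=1) == class_idx).sum().item()
```

If `validation_data` yields batches, as the 32 in the error message suggests, the accuracy should be divided by the total number of samples rather than by `len(validation_data)`.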

Upvotes: 1
