Reputation: 23
```python
import numpy as np
import torch

num_samples = 10

def predict(x):
    # draw num_samples networks from the guide
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [model(x).data for model in sampled_models]
    # average the outputs of the sampled networks
    mean = torch.mean(torch.stack(yhats), 0)
    # note: this returns a NumPy array, not a tensor
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.view(-1, 28*28))
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print("accuracy: %d %%" % (100 * correct / total))
```
Error:

```
    correct += (predicted == labels).sum().item()
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
 * (Tensor other)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!)
 * (Number other)
      didn't match because some of the arguments have invalid types: (!numpy.ndarray!)
```
Upvotes: 2
Views: 18681
Reputation: 114816
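You are trying to compare `predicted` and `labels`. However, your `predicted` is an `np.ndarray` (that is what `np.argmax` returns), while `labels` is a `torch.Tensor`. Therefore `eq()` (what the `==` operator calls here) cannot compare them: as the error message says, it only accepts a `Tensor` or a Number as the other operand, not a NumPy array.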
Replace the `np.argmax` with `torch.argmax`:

```python
return torch.argmax(mean, dim=1)
```
And you should be okay.
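A minimal end-to-end sketch of the corrected loop that runs without the asker's setup. It assumes a hypothetical `guide` that just returns a freshly initialised `torch.nn.Linear(784, 10)` on each call (standing in for the Pyro guide, which samples network weights) and a hypothetical `test_loader` built from random MNIST-shaped tensors, so the printed accuracy is meaningless (roughly chance level); only the types matter here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
num_samples = 10

# Hypothetical stand-in for the Pyro guide: each call returns a new
# randomly initialised linear classifier (784 inputs -> 10 classes).
def guide(x_data=None, y_data=None):
    return torch.nn.Linear(28 * 28, 10)

def predict(x):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [model(x).data for model in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    # The fix: argmax in PyTorch, so the result is a Tensor like `labels`.
    return torch.argmax(mean, dim=1)

# Hypothetical fake test set: 64 random "images" with random labels 0..9.
images_all = torch.rand(64, 1, 28, 28)
labels_all = torch.randint(0, 10, (64,))
test_loader = DataLoader(TensorDataset(images_all, labels_all), batch_size=16)

correct = 0
total = 0
for images, labels in test_loader:
    predicted = predict(images.view(-1, 28 * 28))
    total += labels.size(0)
    correct += (predicted == labels).sum().item()  # Tensor == Tensor, no TypeError
print("accuracy: %d %%" % (100 * correct / total))
```

Only the `return` line of `predict` differs from the question's code; the accuracy loop itself needs no change.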
Upvotes: 4