Reputation: 35
I tried to train on the MNIST dataset with SGD and a batch size of 32, but the loss does not decrease at all. I checked my model and loss function and read the documentation, but I couldn't figure out what I've done wrong.
I defined my neural network as follows:
```python
import torch
import torch.nn as nn

class classification(nn.Module):
    def __init__(self):
        super(classification, self).__init__()
        # construct layers for a neural network
        self.classifier1 = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=20*20),
            nn.Sigmoid(),
        )
        self.classifier2 = nn.Sequential(
            nn.Linear(in_features=20*20, out_features=10*10),
            nn.Sigmoid(),
        )
        self.classifier3 = nn.Sequential(
            nn.Linear(in_features=10*10, out_features=10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, inputs):               # [batchSize, 1, 28, 28]
        x = inputs.view(inputs.size(0), -1)  # [batchSize, 28*28]
        x = self.classifier1(x)              # [batchSize, 20*20]
        x = self.classifier2(x)              # [batchSize, 10*10]
        out = self.classifier3(x)            # [batchSize, 10]
        return out
```
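The shape comments in `forward` can be checked by pushing a random batch through the network. This is a minimal sketch, assuming CPU execution and random tensors in place of real MNIST images (`dummy` is a hypothetical name); the class is repeated in compact form so the snippet runs on its own:

```python
import torch
import torch.nn as nn

# Same architecture as above: 784 -> 400 -> 100 -> 10
class classification(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier1 = nn.Sequential(nn.Linear(28 * 28, 20 * 20), nn.Sigmoid())
        self.classifier2 = nn.Sequential(nn.Linear(20 * 20, 10 * 10), nn.Sigmoid())
        self.classifier3 = nn.Sequential(nn.Linear(10 * 10, 10), nn.LogSoftmax(dim=1))

    def forward(self, inputs):
        x = inputs.view(inputs.size(0), -1)  # flatten [B, 1, 28, 28] -> [B, 784]
        x = self.classifier1(x)
        x = self.classifier2(x)
        return self.classifier3(x)

# Random batch standing in for MNIST images (hypothetical data, shape only)
dummy = torch.randn(32, 1, 28, 28)
out = classification()(dummy)
print(out.shape)  # torch.Size([32, 10])
# Each row holds log-probabilities, so exp() of each row sums to 1
print(torch.allclose(out.exp().sum(dim=1), torch.ones(32)))  # True
```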
I defined the training loop as follows:
```python
import numpy as np
import torch
import torch.nn as nn

device = "cuda"
classifier = classification().to(device)
# optimizer
optimizer = torch.optim.SGD(classifier.parameters(), lr=learning_rate_value)
# loss function
criterion = nn.NLLLoss()
batch_size = 32
epoch = 30
# array to save loss history
loss_train_arr = np.zeros(epoch)
# use DataLoader to split the data into batches
batched_train = torch.utils.data.DataLoader(training_set, batch_size, shuffle=True)
for i in range(epoch):
    loss_train = 0
    # train and compute loss, accuracy
    for img, label in batched_train:
        img = img.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        predicted = classifier(img)
        label_predicted = torch.argmax(predicted, dim=1)
        loss = criterion(predicted, label)
        loss.backward
        optimizer.step()
        loss_train += loss.item()
    loss_train_arr[i] = loss_train / (len(batched_train.dataset) / batch_size)
```
My model ends with a LogSoftmax layer, so NLLLoss should be the right loss function. Still, the loss does not decrease at all.
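The reasoning about the loss is sound: `nn.LogSoftmax` followed by `nn.NLLLoss` computes the same value as `nn.CrossEntropyLoss` applied to raw scores. A quick check on random tensors (made-up data, not MNIST) illustrates this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(32, 10)          # raw scores: batch of 32, 10 classes
labels = torch.randint(0, 10, (32,))  # integer class targets, shape [32]

# The pairing used in the question: LogSoftmax output fed into NLLLoss
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
# CrossEntropyLoss applies log-softmax internally, so it takes raw logits
ce = nn.CrossEntropyLoss()(logits, labels)

print(torch.allclose(nll, ce))  # True
```

So the model/loss combination is not the cause of the flat loss curve.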
Upvotes: 0
Views: 754
Reputation: 7209
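If the code you posted is exactly the code you run, the problem is that you never actually call `backward` on the loss: the parentheses `()` are missing. On its own, `loss.backward` only looks up the method without running it, so no gradients are computed. `torch.optim.SGD` skips parameters whose `.grad` is `None`, so `optimizer.step()` silently leaves the weights unchanged. The line should read `loss.backward()`.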
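A tiny self-contained experiment makes the effect visible. This sketch assumes a toy 4-feature, 3-class model and random data (hypothetical, not the MNIST setup from the question); `train_step` is a helper written only for this illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.NLLLoss()
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

def train_step(call_backward):
    """Run one SGD step; return True if the first layer's weights changed."""
    before = model[0].weight.detach().clone()
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    if call_backward:
        loss.backward()  # actually computes gradients
    else:
        loss.backward    # only references the bound method; nothing runs
    optimizer.step()
    return not torch.equal(before, model[0].weight)

updated_without_call = train_step(call_backward=False)
updated_with_call = train_step(call_backward=True)
print(updated_without_call)  # False: no gradients, weights unchanged
print(updated_with_call)     # True: SGD updated the weights
```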
Upvotes: 2