Reputation: 115
I’m training a CNN to predict digits using the MNIST database. I’m doing data augmentation, and for some reason accuracy sharply decreases when advancing to the next epoch (iteration 60 in the image).
It has to do with the data augmentation (transform = my_transforms in the code), because when I deactivate augmentation (transform = None) accuracy doesn't drop when advancing to the next epoch. But I can't explain why. Does anyone have an idea why this happens?
my_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop((25,25)),
    transforms.Resize((28,28)),
    transforms.RandomRotation(degrees=45, fill=255),
    transforms.RandomVerticalFlip(p=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

dataset = MNISTDataset(transform=my_transforms)
train_loader = DataLoader(dataset=dataset, batch_size=1000, shuffle=True)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(20*4*4, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 20*4*4)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x
net = Net()
loss_function = nn.NLLLoss()
optimizer = optim.Adam(net.parameters())

EPOCHS = 2
iteracion = 0
for epoch in range(EPOCHS):
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.view(-1, 1, 28, 28)
        net.zero_grad()
        probabilities = net(inputs)
        matches = [torch.argmax(i) == int(j) for i, j in zip(probabilities, labels)]
        in_batch_acc = matches.count(True)/len(matches)
        loss = loss_function(torch.log(probabilities), labels)
        print('Loss:', round(float(loss), 3))
        print('In-batch acc:', round(in_batch_acc, 2))
        iteracion += 1
        loss.backward()
        optimizer.step()
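For reference, a minimal sketch to eyeball what the augmentation pipeline actually produces (this assumes matplotlib is installed; it is only a sanity-check helper, not part of the training code):

import matplotlib.pyplot as plt

# Pull one augmented batch and display the first 8 digits with their labels
inputs, labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img, lbl in zip(axes, inputs, labels):
    ax.imshow(img.squeeze(), cmap='gray')  # drop the channel dim for display
    ax.set_title(int(lbl))
    ax.axis('off')
plt.show()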
Upvotes: 0
Views: 489
Reputation: 51
I replicated your model with data augmentation and plotted the accuracy and loss, and it seems that the problem is the way you are plotting.
Below I attach my code, followed by the loss and accuracy plots:
Code:
# -- Imports -- #
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
import matplotlib.pyplot as plt
# -- Data Loader -- #
# datasets.MNIST already yields PIL images, so ToPILImage is not needed here
my_transforms = transforms.Compose([
    transforms.RandomCrop((25,25)),
    transforms.Resize((28,28)),
    transforms.RandomRotation(degrees=45, fill=255),
    transforms.RandomVerticalFlip(p=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])
dataset = datasets.MNIST('../data', train=True, download=True,
                         transform=my_transforms)
train_loader = DataLoader(dataset=dataset, batch_size=1000, shuffle=True)
# -- Define Model -- #
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(20*4*4, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 20*4*4)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)  # probabilities; log is taken later for NLLLoss
        return x
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = Net()
net.to(device)
loss_function = nn.NLLLoss()
optimizer = optim.Adam(net.parameters())

EPOCHS = 2
iteracion = 0
accuracy = []
loss_record = []
for epoch in range(EPOCHS):
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.view(-1, 1, 28, 28)
        inputs, labels = inputs.to(device), labels.to(device)
        # -- Forward -- #
        net.zero_grad()
        probabilities = net(inputs)
        matches = [torch.argmax(i) == int(j) for i, j in zip(probabilities, labels)]
        in_batch_acc = matches.count(True)/len(matches)
        loss = loss_function(torch.log(probabilities), labels)
        # -- Statistics -- #
        accuracy.append(in_batch_acc)
        loss_record.append(loss.item())  # .item() so the values can be plotted later
        print('Loss:', round(float(loss), 3))
        print('In-batch acc:', round(in_batch_acc, 2))
        iteracion += 1
        loss.backward()
        optimizer.step()
# -- Accuracy plot -- #
iterations = range(len(accuracy))  # 60 batches per epoch x 2 epochs = 120
plt.plot(iterations, accuracy, 'g', label='Accuracy')
plt.title('Accuracy')
plt.xlabel('Iterations')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# -- Loss plot -- #
plt.plot(iterations, loss_record, label='Loss')
plt.title('Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()
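If you want to make any epoch-boundary effect easier to see, a possible follow-up (my addition, assuming numpy is available; the 5-sample window and the 60-batch epoch length for batch_size=1000 are assumptions) is to smooth the noisy per-batch curve and mark where the second epoch starts:

# Hypothetical extra plot: running-mean smoothing plus an epoch-boundary marker,
# so a genuine accuracy drop at the boundary would stand out from batch noise.
import numpy as np

window = 5  # running-mean width; small enough to keep the curve responsive
smoothed = np.convolve(accuracy, np.ones(window) / window, mode='valid')

plt.plot(range(len(accuracy)), accuracy, alpha=0.3, label='Per-batch accuracy')
plt.plot(range(window - 1, len(accuracy)), smoothed, label='Running mean')
plt.axvline(x=60, color='r', linestyle='--', label='Epoch boundary')  # 60000/1000 batches
plt.title('Accuracy (smoothed)')
plt.xlabel('Iterations')
plt.ylabel('Accuracy')
plt.legend()
plt.show()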
Plots: (accuracy and loss vs. iterations, produced by the code above)
As you can see, there are no jumps at iteration 61, where the next epoch begins.
Upvotes: 1