Reputation: 371
I'm new to PyTorch, and I have been stuck on this problem for a while. I have trained a CNN for classifying X-ray images. The images can be found on this Kaggle page: https://www.kaggle.com/prashant268/chest-xray-covid19-pneumonia/ . I managed to get good accuracy on both the training and test data, but when I try to make predictions on new images I get the same (wrong) class as output for every image. Here's my model in detail.
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import glob
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.autograd import Variable
import torchvision
import pathlib
from google.colab import drive
# Mount Google Drive in the Colab runtime at /content/drive.
drive.mount('/content/drive')

# Training hyperparameters.
epochs = 20
batch_size = 128
learning_rate = 0.001

# Data transformation used for training: resize every image to 224x224,
# randomly mirror it as augmentation, convert it to a tensor and normalise
# each of the three channels with mean 0.5 and std 0.5.
_channel_mean = [0.5, 0.5, 0.5]
_channel_std = [0.5, 0.5, 0.5]
_training_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transformer = transforms.Compose(_training_steps)
# Load data with DataLoader
train_path = '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/Data/train'
test_path = '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/Data/test'

# Evaluation has to be deterministic, so the test split gets the same resize
# and normalisation as the training split but WITHOUT the random flip.
eval_transformer = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transformer)
test_dataset = torchvision.datasets.ImageFolder(test_path, transform=eval_transformer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

root = pathlib.Path(train_path)
# Take the class names from the dataset itself so that classes[i] is always
# the name behind label i. Listing the directory by hand would also pick up
# stray non-class entries and shift the indices.
classes = train_dataset.classes
print(classes)


def _count_images(folder):
    """Count the .jpg, .png and .jpeg files one directory level below *folder*."""
    return sum(
        len(glob.glob(f'{folder}/**/*.{extension}'))
        for extension in ('jpg', 'png', 'jpeg')
    )


train_count = _count_images(train_path)
test_count = _count_images(test_path)
print(train_count, test_count)
# Create the CNN
class CNN(nn.Module):
    """Small image classifier: two conv/pool stages and three linear layers.

    The layer sizes are fixed for 3-channel 224x224 inputs; the first linear
    layer expects exactly 24 * 53 * 53 features per image.

    Args:
        num_classes: Size of the output layer. When omitted, it falls back to
            ``len(classes)`` using the module-level ``classes`` list, which
            is what the original zero-argument ``CNN()`` call relied on.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: use the global class list.
            num_classes = len(classes)

        # Spatial size after a conv layer:
        #   n_out = (width + 2 * padding - kernel_size) / stride + 1
        # Shapes below are for a batch of N images.
        # [N, 3, 224, 224]
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5)
        # [N, 12, 220, 220]
        self.pool1 = nn.MaxPool2d(2, 2)  # halves height and width
        # [N, 12, 110, 110]
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5)
        # [N, 24, 106, 106]
        self.pool2 = nn.MaxPool2d(2, 2)
        # [N, 24, 53, 53], which becomes the input of the fully connected layers
        self.fc1 = nn.Linear(in_features=24 * 53 * 53, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        # Final layer: one raw score (logit) per class.
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Return one row of raw class scores per image in the batch ``x``."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample and keep the batch dimension fixed. With
        # view(-1, 24 * 53 * 53) a wrongly sized input could be reshaped into
        # a different number of rows instead of failing in fc1.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Training the model
model = CNN()
# CrossEntropyLoss is fed the raw outputs of the network, so the model itself
# has no softmax layer.
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
n_total_steps = len(train_loader)

for epoch in range(epochs):
    # Running totals for the accuracy shown during this epoch.
    n_correct = 0
    n_samples = 0
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = loss_function(outputs, labels)

        # Backpropagation and optimization
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        # Running accuracy: the predicted class is the index of the largest score.
        predicted = torch.max(outputs, 1).indices
        batch_hits = (predicted == labels).sum().item()
        n_correct += batch_hits
        n_samples += labels.size(0)
        acc = 100.0 * n_correct / n_samples
        print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Accuracy: {round(acc,2)} %, Loss: {loss.item():.4f}')
print('Done!!')
# Testing the model
# Switch to evaluation mode before measuring accuracy. This network has no
# dropout or batch-norm layers, so it is a safeguard rather than a fix.
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    # One counter per class, sized from the class list instead of a literal 3.
    n_class_correct = [0] * len(classes)
    n_class_samples = [0] * len(classes)
    for images, labels in test_loader:
        outputs = model(images)
        # max returns (value, index); the index is the predicted class
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        # Per-class tally. A model that collapses onto a single class shows
        # up here even when the overall accuracy looks acceptable.
        for label, guess in zip(labels.tolist(), predicted.tolist()):
            n_class_samples[label] += 1
            if label == guess:
                n_class_correct[label] += 1
    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {acc} %')
    for class_name, hits, total in zip(classes, n_class_correct, n_class_samples):
        if total:  # skip classes that have no test images
            print(f'Accuracy of {class_name}: {100.0 * hits / total} %')
torch.save(model.state_dict(),'/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/model.model')
For loading the model and trying to make predictions on new images, the code is as follows:
# Load the saved weights onto the CPU regardless of the device they were saved
# from; the prediction code converts the output with .numpy(), which needs
# CPU tensors.
checkpoint = torch.load(
    '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/model.model',
    map_location='cpu',
)
model = CNN()
model.load_state_dict(checkpoint)
model.eval()  # inference mode

# Data Transformation for inference: same resize and normalisation as in
# training, but without the random horizontal flip.
transformer = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# Making predictions on new data
from PIL import Image


def prediction(img_path, transformer):
    """Classify one image file and return the name of the predicted class.

    Uses the module-level ``model`` for the forward pass and the
    module-level ``classes`` list to turn the winning index into a name.

    Args:
        img_path: Path of the image file to classify.
        transformer: Callable that turns a PIL image into the tensor the
            model expects.

    Returns:
        The entry of ``classes`` at the index of the highest output score.
    """
    # Close the file handle as soon as the pixels are converted; convert()
    # returns a new image that stays valid after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert('RGB')
    # Add a leading batch dimension so the single image is fed to the model
    # as a batch of one: [C, H, W] -> [1, C, H, W].
    image_tensor = transformer(image).unsqueeze(0)
    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(image_tensor)
    index = output.argmax(dim=1).item()
    return classes[index]
# Folder holding the new, unlabeled images to classify.
pred_path = '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/Test_images/Data/'
test_imgs = glob.glob(pred_path+'/*')
# Print the predicted class name for every entry found in that folder.
for img_file in test_imgs:
    predicted_class = prediction(img_file, transformer)
    print(predicted_class)
I'm guessing the problem must be in the way that I am preprocessing the data, although I cannot find my mistake. Any help will be deeply appreciated, since I have been stuck on this for a while now. P.S. I can share my notebook as well, if it is of any help.
Upvotes: 0
Views: 3832
Reputation: 1686
Regarding your problem, here is a debugging approach that narrows down where the problem most likely is, which should make your issue much easier to fix.
My debugging process is based on the fact that your CNN performs well on the test set. First, temporarily set your test loader's batch size to 1. Then, in your test loop, at the point where you count the correct predictions, you can run the following code:
# Your code
# (`predicted` below comes from the torch.max line of your existing test loop)
outputs = model(images)  # Really only one image and 1 output.
# Altered Code:
correct = (predicted == labels).sum().item()  # This will be either 1 or 0 since you have only one image per batch
# My new code:
if correct:
    # The value is 1, so this image was classified correctly.
    # Drop the batch dimension to get a single image tensor.
    single_correct_image = images.squeeze(0)
    # Undo Normalize(mean=0.5, std=0.5) so the values are back in [0, 1],
    # which is the range ToPILImage expects for float tensors.
    single_correct_image = single_correct_image * 0.5 + 0.5
    # Then convert the tensor image into a PIL image
    pil_image = transforms.ToPILImage()(single_correct_image)
    # Save the PIL image. The target must be a file name with an image
    # extension, not just a directory.
    pil_image.save("/content/correct_test_image.png")
    # Terminate the testing process. Ignore the ValueError that says "terminating process"
    raise ValueError("terminating process")
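Now you have an image saved to disk that you know is classified correctly in the test set. The next step is to open that image and run it through your prediction function. A couple of things can happen, and each tells you something about your situation: if the image is still classified correctly, your prediction code works and the new images probably differ in some way from the data the model was trained on; if it is now misclassified, the problem lies in how the prediction code loads or preprocesses images.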
For additional reference, if you are still stuck to the point where you feel you can't solve it, I suggest using this notebook as a guide; it gives some suggestions on what code to at least inspect.
https://www.kaggle.com/salvation23/xray-cnn-pytorch
Sarthak Jain
Upvotes: 1