Olive Yew

Reputation: 371

Making predictions on new images using a CNN in pytorch

I'm new to PyTorch and have been stuck on this problem for a while. I have trained a CNN to classify X-ray images. The images can be found on this Kaggle page: https://www.kaggle.com/prashant268/chest-xray-covid19-pneumonia/ . I get good accuracy on both the training and test data, but when I try to make predictions on new images, I get the same (wrong) class for every image. Here is my model in detail.

import os 
import matplotlib.pyplot as plt
import numpy as np 
import torch
import glob 
import torch.nn.functional as F 
import torch.nn as nn 
from torchvision.transforms import transforms 
from torch.utils.data import DataLoader 
from torch.optim import Adam 
from torch.autograd import Variable 
import torchvision 
import pathlib 
from google.colab import drive 
drive.mount('/content/drive')

epochs = 20 
batch_size = 128
learning_rate = 0.001

#Data Transformation 
transformer = transforms.Compose([
                                  transforms.Resize((224,224)),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(), 
                                  transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
                                  ])

#Load data with DataLoader
train_path = '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/Data/train' 
test_path = '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/Data/test' 

train_loader = DataLoader(torchvision.datasets.ImageFolder(train_path,transform = transformer), batch_size= batch_size, shuffle= True)
test_loader = DataLoader(torchvision.datasets.ImageFolder(test_path,transform = transformer), batch_size= batch_size, shuffle= False)

root = pathlib.Path(train_path)
classes = sorted([j.name.split('/')[-1] for j in root.iterdir()])
print(classes)
train_count = len(glob.glob(train_path+'/**/*.jpg')) + len(glob.glob(train_path+'/**/*.png')) + len(glob.glob(train_path+'/**/*.jpeg'))
test_count = len(glob.glob(test_path+'/**/*.jpg')) + len(glob.glob(test_path+'/**/*.png')) + len(glob.glob(test_path+'/**/*.jpeg'))
print(train_count,test_count)

#Create the CNN 
class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    '''nout = [(width + 2*padding - kernel_size) / stride] + 1 '''
    # input: [128,3,224,224] (batch_size = 128)
    self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 12, kernel_size = 5) 
    # [128,12,220,220]
    self.pool1 = nn.MaxPool2d(2,2) #halves the height and width of the feature maps
    # [128,12,110,110]
    self.conv2 = nn.Conv2d(in_channels = 12, out_channels = 24, kernel_size = 5)
    # [128,24,106,106]
    self.pool2 = nn.MaxPool2d(2,2)
    # [128,24,53,53] which becomes the input of the fully connected layer 
    self.fc1 = nn.Linear(in_features = (24 * 53 * 53), out_features = 120) 
    self.fc2 = nn.Linear(in_features = 120, out_features = 84) 
    self.fc3 = nn.Linear(in_features = 84, out_features = len(classes)) #final layer, output will be the number of classes 

  def forward(self, x):
    x = self.pool1(F.relu(self.conv1(x)))  
    x = self.pool2(F.relu(self.conv2(x)))  
    x = x.view(-1, 24 * 53 * 53)            
    x = F.relu(self.fc1(x))               
    x = F.relu(self.fc2(x))              
    x = self.fc3(x)                       
    return x


# Training the model 
model = CNN()
loss_function = nn.CrossEntropyLoss() #includes the softmax activation function 
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

n_total_steps = len(train_loader)
for epoch in range(epochs):
  n_correct = 0
  n_samples = 0
  for i, (images, labels) in enumerate(train_loader):
    # Forward pass
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)
    n_samples += labels.size(0)
    n_correct += (predicted == labels).sum().item()

    loss = loss_function(outputs, labels)
    # Backpropagation and optimization 
    optimizer.zero_grad() #empty gradients 
    loss.backward()
    optimizer.step()

    acc = 100.0 * n_correct / n_samples

  print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Accuracy: {round(acc,2)} %, Loss: {loss.item():.4f}')
print('Done!!')

# Testing the model 
with torch.no_grad():
  n_correct = 0
  n_samples = 0
  n_class_correct = [0 for i in range(3)]
  n_class_samples = [0 for i in range(3)]
  for images, labels in test_loader:
    outputs = model(images)
    # max returns (value ,index)
    _, predicted = torch.max(outputs, 1)
    n_samples += labels.size(0)
    n_correct += (predicted == labels).sum().item() 

  acc = 100.0 * n_correct / n_samples
  print(f'Accuracy of the network: {acc} %')

torch.save(model.state_dict(),'/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/model.model')

To load the model and make predictions on new images, I use the following code:

checkpoint = torch.load('/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/model.model')
model = CNN()
model.load_state_dict(checkpoint)
model.eval()

#Data Transformation 
transformer = transforms.Compose([
                                  transforms.Resize((224,224)),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]) 
                                  ])

#Making predictions on new data 
from PIL import Image
def prediction(img_path,transformer):
  image = Image.open(img_path).convert('RGB')
  image_tensor = transformer(image)
  image_tensor = image_tensor.unsqueeze_(0) #add a batch dimension: [3,224,224] -> [1,3,224,224] 
  input_img = Variable(image_tensor)
  output = model(input_img)
  #print(output)
  index = output.data.numpy().argmax()
  pred = classes[index]
  return pred 

pred_path = '/content/drive/MyDrive/Chest X-ray (Covid-19 & Pneumonia)/Test_images/Data/'
test_imgs = glob.glob(pred_path+'/*')

for i in test_imgs:
    print(prediction(i,transformer))

I'm guessing the problem is in the way I am preprocessing the data, although I cannot find my mistake. Any help will be deeply appreciated, since I have been stuck on this for a while now. P.S. I can share my notebook as well, if that helps.

Upvotes: 0

Views: 3832

Answers (1)

SarthakJain

Reputation: 1686

Here is a way to debug this that should narrow down where the problem most likely is, which should make it much easier to fix.

My debugging process relies on the fact that your CNN performs well on the test set. First, temporarily set your test loader's batch size to 1. Then, in your test loop, at the point where you count the correct predictions, add the following code:

# Your code
outputs = model(images)  # with batch_size=1: one image in, one row of logits out
_, predicted = torch.max(outputs, 1)

# Altered code:
correct = (predicted == labels).sum().item()  # either 1 or 0, since there is only one image per batch

# My new code:
if correct:
    # The prediction was right: drop the batch dimension to get a single image tensor
    single_correct_image = images.squeeze(0)
    # Undo Normalize(mean=0.5, std=0.5) so the pixel values are back in [0, 1]
    single_correct_image = (single_correct_image * 0.5 + 0.5).clamp(0, 1)
    # Convert the tensor image into a PIL image
    pil_image = transforms.ToPILImage()(single_correct_image)
    # save() needs a file path, not a directory (this file name is only an example)
    pil_image.save("/content/correct_sample.png")

    # Stop the testing process here; the resulting ValueError is expected and can be ignored
    raise ValueError("terminating process")
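
A note on saving a tensor taken from the test loader: it has already passed through `transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])`, so its values lie in [-1, 1], while `ToPILImage` scales float tensors by 255 and therefore works best with values in [0, 1]. Below is a minimal sketch of the round trip, using a random tensor as a stand-in for a real image. The helper names `normalize` and `unnormalize` are made up for illustration and are not part of torchvision.

```python
import torch

MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize in the question

def normalize(x):
    # What transforms.Normalize computes per channel: (x - mean) / std
    return (x - MEAN) / STD

def unnormalize(x):
    # Inverse mapping, clamped to guard against floating-point drift outside [0, 1]
    return (x * STD + MEAN).clamp(0.0, 1.0)

torch.manual_seed(0)
pixels = torch.rand(3, 224, 224)     # stand-in for ToTensor() output, values in [0, 1)
normalized = normalize(pixels)       # what the test loop hands to the model
restored = unnormalize(normalized)   # what should be handed to transforms.ToPILImage()

print(normalized.min().item(), normalized.max().item())
print(torch.allclose(restored, pixels, atol=1e-6))
```

If the saved file looks washed out or has inverted patches when opened, a missing un-normalization step is one plausible cause worth ruling out before drawing conclusions about the model.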

Now you have an image saved to disk that you know the model classifies correctly in the test set. The next step is to open that image and run it through your prediction function. A couple of things can happen, and each tells you something about your situation:

  • If your model returns the wrong answer, then the problem lies in what differs between your prediction code and your testing code. One uses torch.max (followed by a sum over matches); the other uses NumPy's argmax. Use print statements to see what is going on there. Perhaps there is a conversion error, or the output's format is not what you expect.
  • If your code returns the right answer, then your model is simply failing to predict correctly on new images. I suggest running more trial cases with the above process.
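
The first case can be checked in isolation. Here is a minimal sketch, using hand-written logits rather than real model output, of how the two decoding styles relate: `torch.max(outputs, 1)` picks one class per row, while `.numpy().argmax()` without an axis works on the flattened array, so the two agree only when the batch holds a single image.

```python
import torch

# Hand-written logits for illustration: 2 images, 3 classes
logits = torch.tensor([[0.1, 2.0, 0.3],
                       [5.0, 0.2, 0.1]])

# Test-loop style: one class index per row
_, per_row = torch.max(logits, 1)
print(per_row.tolist())               # [1, 0]

# prediction()-style: argmax over the flattened array gives a single index
flat_index = logits.numpy().argmax()
print(int(flat_index))                # 3, the position of 5.0 after flattening

# With a batch of one (as after unsqueeze_(0)), both styles agree
single = logits[:1]
_, single_row = torch.max(single, 1)
print(single_row.item() == int(single.numpy().argmax()))   # True
```

So the NumPy call in the question's `prediction` function should be fine as long as exactly one image is passed in. If the saved image is still misclassified, that would point toward the image loading or preprocessing steps rather than the argmax itself.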

For additional reference, if you still get stuck to the point where you feel you can't solve it, I suggest using this notebook as a guide to which parts of the code to at least inspect:

https://www.kaggle.com/salvation23/xray-cnn-pytorch

Sarthak Jain

Upvotes: 1
