Mona Jalal

Reputation: 38275

Extracting feature vector for grey images via ResNet18: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

I have 600x800 images that have only one channel. I am trying to use a pre-trained ResNet18 to extract their features; however, the model expects a 3-channel input:

import torch
import torchvision
import torchvision.models as models
from PIL import Image

img = Image.open("labeled-data/train_moth/moth/frame163.png")


# Load the pretrained model
model = models.resnet18(pretrained=True)

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

# Set model to evaluation mode
model.eval()

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_vector(image):
    # Create a PyTorch tensor with the transformed image
    t_img = transforms(image)
    t_img = torch.cat((t_img, t_img, t_img), 0)
    # Create a vector of zeros that will hold our feature vector
    # The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512)

    # Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten())                 # <-- flatten

    # Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # Run the model on our transformed image
    with torch.no_grad():                               # <-- no_grad context
        model(t_img.unsqueeze(0))                       # <-- unsqueeze
    # Detach our copy function from the layer
    h.remove()
    # Return the feature vector
    return my_embedding


pic_vector = get_vector(img)

Here's the error I am getting:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-59ab45f8c1e6> in <module>
     42 
     43 
---> 44 pic_vector = get_vector(img)

<ipython-input-5-59ab45f8c1e6> in get_vector(image)
     21 def get_vector(image):
     22     # Create a PyTorch tensor with the transformed image
---> 23     t_img = transforms(image)
     24     t_img = torch.cat((t_img, t_img, t_img), 0)
     25     # Create a vector of zeros that will hold our feature vector

~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
     59     def __call__(self, img):
     60         for t in self.transforms:
---> 61             img = t(img)
     62         return img
     63 

~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, tensor)
    210             Tensor: Normalized Tensor image.
    211         """
--> 212         return F.normalize(tensor, self.mean, self.std, self.inplace)
    213 
    214     def __repr__(self):

~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
    296     if std.ndim == 1:
    297         std = std[:, None, None]
--> 298     tensor.sub_(mean).div_(std)
    299     return tensor
    300 

RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]



Code is from: https://stackoverflow.com/a/63552285/2414957

I thought using

t_img = torch.cat((t_img, t_img, t_img), 0)

would be helpful, but I was wrong.
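
The reason it did not help: Normalize runs inside transforms, so the pipeline fails on the 1-channel tensor before torch.cat is ever reached. A minimal reproduction of just the failing step (shapes taken from the traceback):

import torch
import torchvision

t = torch.zeros(1, 224, 224)  # 1-channel tensor, as ToTensor produces for a grayscale image
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
normalize(t)  # RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]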

Here's a bit about the image:

$ identify frame163.png 
frame163.png PNG 800x600 800x600+0+0 8-bit Gray 256c 175297B 0.000u 0:00.000
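
The same check in PIL confirms it (a quick sketch; mode 'L' means a single-channel grayscale image):

from PIL import Image

img = Image.open("labeled-data/train_moth/moth/frame163.png")
print(img.mode)  # 'L'
print(img.size)  # (800, 600)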

Upvotes: 0

Views: 1149

Answers (2)

Prajot Kuvalekar

Reputation: 6678

Many models (almost all models) from the torchvision module expect a 3-channel input, so whenever you use a pretrained model, just convert your image to RGB.

In your code, just change this

img = Image.open("labeled-data/train_moth/moth/frame163.png")

to this

img = Image.open("labeled-data/train_moth/moth/frame163.png").convert('RGB')

The above line will just stack your grayscale image into 3 identical channels.
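
Note that once the image is RGB, the torch.cat line in get_vector must also be removed, otherwise the tensor becomes [9, 224, 224] and the model's first convolution will reject it. A minimal sketch of the adjusted function, reusing the question's transforms, layer, and model:

def get_vector(image):
    t_img = transforms(image)  # already [3, 224, 224] for an RGB image
    my_embedding = torch.zeros(512)

    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten())

    h = layer.register_forward_hook(copy_data)
    with torch.no_grad():
        model(t_img.unsqueeze(0))
    h.remove()
    return my_embedding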

The second option you have is to redefine the model's first convolution so it takes a single-channel input:

import torch.nn as nn

model = models.resnet18(pretrained=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
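
Keep in mind that replacing conv1 this way re-initializes its weights randomly. A common trick (a sketch, my addition rather than part of the answer above) is to average the pretrained RGB kernels into the new single-channel layer:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18(pretrained=True)
rgb_weights = model.conv1.weight.data.clone()  # shape [64, 3, 7, 7]

model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
with torch.no_grad():
    # average over the RGB dimension -> shape [64, 1, 7, 7]
    model.conv1.weight.copy_(rgb_weights.mean(dim=1, keepdim=True))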

Please vote if you find this useful.

Upvotes: 1

Kenan

Reputation: 14124

Change the order of

t_img = transforms(image)
t_img = torch.cat((t_img, t_img, t_img), 0)

so the channels are stacked before the transforms run, because Normalize expects a 3-channel tensor of shape [C, H, W]. Since torch.cat works on tensors rather than PIL images, stack the channels with PIL instead:

t_img = Image.merge('RGB', (image, image, image))
t_img = transforms(t_img)
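
A quick shape check under this order (a sketch; transforms is the pipeline from the question):

from PIL import Image

img = Image.open("labeled-data/train_moth/moth/frame163.png")  # mode 'L'
t_img = transforms(Image.merge('RGB', (img, img, img)))
print(t_img.shape)  # torch.Size([3, 224, 224])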

Upvotes: 0
