Reputation: 38275
I have 600x800 images that have only 1 channel. I am trying to use a pre-trained ResNet18 to extract their features; however, the code expects 3 channels:
import torch
import torchvision
import torchvision.models as models
from PIL import Image
# Load the sample image; the `identify` output quoted later in this post
# reports it as an 800x600 8-bit grayscale PNG.
img = Image.open("labeled-data/train_moth/moth/frame163.png")
# Load the pretrained model
model = models.resnet18(pretrained=True)
# Use the model object to select the desired layer
layer = model._modules.get('avgpool')
# Set model to evaluation mode
model.eval()
# Preprocessing pipeline that get_vector() applies to an image before the
# forward pass.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    # NOTE(review): mean/std each have three entries (one per channel of a
    # 3-channel tensor). The traceback in this post is raised at this step
    # because the tensor reaching it has shape [1, 224, 224].
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def get_vector(image):
    """Return the output of the model's 'avgpool' layer for ``image``.

    A forward hook registered on the module-level ``layer`` copies that
    layer's flattened output into a 512-element tensor while ``model`` is
    run once on the transformed image; the hook is removed afterwards.

    NOTE(review): the channel replication (``torch.cat``) happens *after*
    ``transforms(image)``. The traceback in this post shows the call to
    ``transforms`` raising for the single-channel input, so the
    ``torch.cat`` line is never reached in that case.

    Args:
        image: Image accepted by the module-level ``transforms`` pipeline
            (a PIL image opened with ``Image.open`` in this post).

    Returns:
        The 512-element tensor filled in by the forward hook.
    """
    # Create a PyTorch tensor with the transformed image
    t_img = transforms(image)
    # Concatenate three copies of the tensor along dimension 0.
    t_img = torch.cat((t_img, t_img, t_img), 0)
    # Create a vector of zeros that will hold our feature vector
    # The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512)
    # Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten())                 # <-- flatten
    # Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # Run the model on our transformed image
    with torch.no_grad():                               # <-- no_grad context
        model(t_img.unsqueeze(0))                       # <-- unsqueeze
    # Detach our copy function from the layer
    h.remove()
    # Return the feature vector
    return my_embedding
Here's the error I am getting:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-59ab45f8c1e6> in <module>
42
43
---> 44 pic_vector = get_vector(img)
<ipython-input-5-59ab45f8c1e6> in get_vector(image)
21 def get_vector(image):
22 # Create a PyTorch tensor with the transformed image
---> 23 t_img = transforms(image)
24 t_img = torch.cat((t_img, t_img, t_img), 0)
25 # Create a vector of zeros that will hold our feature vector
~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
59 def __call__(self, img):
60 for t in self.transforms:
---> 61 img = t(img)
62 return img
63
~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, tensor)
210 Tensor: Normalized Tensor image.
211 """
--> 212 return F.normalize(tensor, self.mean, self.std, self.inplace)
213
214 def __repr__(self):
~/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
296 if std.ndim == 1:
297 std = std[:, None, None]
--> 298 tensor.sub_(mean).div_(std)
299 return tensor
300
RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
pic_vector = get_vector(img)
Code is from: https://stackoverflow.com/a/63552285/2414957
I thought using
t_img = torch.cat((t_img, t_img, t_img), 0)
would be helpful but I was wrong.
Here's a bit about the image:
$ identify frame163.png
frame163.png PNG 800x600 800x600+0+0 8-bit Gray 256c 175297B 0.000u 0:00.000
Upvotes: 0
Views: 1149
Reputation: 6678
Many models (almost all of them) in the torchvision module expect the input to have 3 channels.
So whenever you are using a pretrained model, just convert your image to RGB.
Looking at your code,
just change this
img = Image.open("labeled-data/train_moth/moth/frame163.png")
to this
img = Image.open("labeled-data/train_moth/moth/frame163.png").convert('RGB')
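The above line will just replicate your grayscale image so that it has 3 channels.
The second option is to redefine the model's first convolution so that it takes a single channel as input (this needs `import torch.nn as nn`; note that the replaced layer is newly initialized, so its weights are no longer pretrained):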
model = models.resnet18(pretrained=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Please vote if you find this useful.
Upvotes: 1
Reputation: 14124
Change the order of
t_img = transforms(image)
t_img = torch.cat((t_img, t_img, t_img), 0)
to
t_img = torch.cat((image, image, image), 0)
t_img = transforms(t_img)
The transforms expect the input to have shape [C, H, W].
Upvotes: 0