Reputation: 6614
PyTorch newbie here. I wrote a script (code below) that performs the following operations: load an image, perform a 2D convolution operation and then display the output and the input.
At present I have the image below, which seems off. How can I plot the feature map correctly?
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio
import sys
A = imageio.imread('LiT.png')
# Define how the convolution operation works
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
image_d = torch.FloatTensor(np.asarray(A.reshape(1, 3, A.shape[0] , A.shape[1])))
fc = conv2(image_d)
fc1 = fc.permute(0, 2, 3, 1).reshape([516, 780, 3])
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(A)
plt.subplot(1,2,2)
plt.imshow(fc1.data.numpy())
plt.show()
Upvotes: 0
Views: 766
Reputation: 3345
The issue with your code is this line:
image_d = torch.FloatTensor(np.asarray(A.reshape(1, 3, A.shape[0] , A.shape[1])))
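You can't just reshape the image; you need to transpose the channel axis. reshape only reinterprets the same flat buffer under a new shape, so values from different channels end up mixed together, whereas a transpose actually moves the channel axis to the front. As a remark for the future: if you get a stripy result like you did, the cause is most likely an incorrect permutation, transposition, or reshaping operation.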
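To make the difference concrete, here is a minimal sketch on a tiny made-up image (the 2x2 array `img` is purely illustrative, not your `LiT.png`). It compares what reshape and transpose put into the first channel:

```python
import numpy as np

# Hypothetical 2x2 RGB image in [H, W, C] layout; each pixel is (R, G, B)
img = np.array([[[10, 20, 30], [11, 21, 31]],
                [[12, 22, 32], [13, 23, 33]]])

reshaped = img.reshape(3, 2, 2)      # reinterprets the flat buffer, no axes are moved
transposed = img.transpose(2, 0, 1)  # moves the channel axis to the front: [C, H, W]

red_plane = img[:, :, 0]             # the true red channel

print(red_plane.tolist())            # [[10, 11], [12, 13]]
print(transposed[0].tolist())        # [[10, 11], [12, 13]], matches the red channel
print(reshaped[0].tolist())          # [[10, 20], [30, 11]], R, G and B values mixed
```

The reshaped "channel" interleaves red, green, and blue values, which is what produces the stripes after the convolution.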
Other than that, I also scaled the convolution output to [0, 1] so that it displays properly. Below is the working code:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio
import sys
A = imageio.imread('LiT.png')
# Define how the convolution operation works
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
# from [H, W, C] to [C, H, W]
transposed_image = A.transpose((2, 0, 1))
# add batch dim
transposed_image = np.expand_dims(transposed_image, 0)
image_d = torch.FloatTensor(transposed_image)
fc = conv2(image_d)
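# from [1, C, H, W] back to [H, W, C] for plotting
fc1 = fc.permute(0, 2, 3, 1)[0]
result = fc1.detach().numpy()
# min-max scale the output to [0, 1] so imshow can display it
max_ = np.max(result)
min_ = np.min(result)
result -= min_
result /= (max_ - min_)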
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(A)
plt.subplot(1,2,2)
plt.imshow(result)
plt.show()
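For reference, plt.imshow clips floating-point RGB data to the range [0, 1], so an unscaled feature map can look washed out or wrong. One common way to bring it into range is min-max normalization. The sketch below uses a hypothetical helper, `min_max_scale`, and a made-up `feature_map` array; neither is part of the script above:

```python
import numpy as np

def min_max_scale(x):
    """Linearly rescale an array to [0, 1] (hypothetical helper for illustration)."""
    x = np.asarray(x, dtype=np.float32)
    lo, hi = x.min(), x.max()
    if hi == lo:
        # Constant input: avoid dividing by zero and return all zeros
        return np.zeros_like(x)
    return (x - lo) / (hi - lo)

# Made-up values standing in for a convolution output
feature_map = np.array([[-2.0, 0.0], [1.0, 6.0]])
scaled = min_max_scale(feature_map)

print(scaled.tolist())  # [[0.0, 0.25], [0.375, 1.0]]
```

Note that the divisor is the range (max minus min), not the maximum alone; otherwise the result is not guaranteed to end at 1.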
Upvotes: 1
Reputation: 1402
To my understanding, the problem lies in how you are moving the channel dimension of the image with reshape. Instead, np.transpose or tensor.permute should be used. Using torch for the permutation:
image_d = torch.FloatTensor(np.asarray(A)).unsqueeze(0).permute(0,3,1,2)
Or, if we want to handle the permutation part in numpy:
image_d = np.transpose(np.asarray(A), (2,0,1))
image_d = torch.FloatTensor(image_d).unsqueeze(0)
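As a quick sanity check, the two variants can be compared on a small synthetic array. This is only a sketch: the 4x5x3 array `A` below is a stand-in for the loaded image, and torch.from_numpy is used in place of torch.FloatTensor since the stand-in is already float32.

```python
import numpy as np
import torch

# Hypothetical stand-in for the loaded image: H=4, W=5, C=3
A = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)

# Variant 1: permute in torch, [H, W, C] -> [1, H, W, C] -> [1, C, H, W]
via_torch = torch.from_numpy(A).unsqueeze(0).permute(0, 3, 1, 2)

# Variant 2: transpose in numpy, [H, W, C] -> [C, H, W], then add the batch dim
chw = np.ascontiguousarray(np.transpose(A, (2, 0, 1)))
via_numpy = torch.from_numpy(chw).unsqueeze(0)

print(tuple(via_torch.shape))             # (1, 3, 4, 5)
print(torch.equal(via_torch, via_numpy))  # True
```

Both produce the [N, C, H, W] layout that nn.Conv2d expects, so either can be passed to conv2.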
Upvotes: 1