randomal

Reputation: 6614

Display result of convolution in PyTorch

PyTorch newbie here. I wrote a script (code below) that performs the following operations: load an image, perform a 2D convolution operation and then display the output and the input.

At present I have the image below, which seems off. How can I plot the feature map correctly?


import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio
import sys

A = imageio.imread('LiT.png')
# Define how the convolution operation works
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)

image_d = torch.FloatTensor(np.asarray(A.reshape(1, 3, A.shape[0] , A.shape[1])))
fc = conv2(image_d)
fc1 = fc.permute(0, 2, 3, 1).reshape([516, 780, 3])

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(A)
plt.subplot(1,2,2)
plt.imshow(fc1.data.numpy())

plt.show()

Upvotes: 0

Views: 766

Answers (2)

Florian Blume

Reputation: 3345

The issue with your code is this line

image_d = torch.FloatTensor(np.asarray(A.reshape(1, 3, A.shape[0] , A.shape[1])))

You can't just reshape the image; you need to transpose the channels. As a remark for the future: if you get a stripy result like you did, the most likely cause is an incorrect permutation, transposition, or reshaping operation.
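To see why reshape and transpose differ, here is a minimal NumPy sketch on a made-up 2x2 image (the array values are hypothetical, chosen so that the last digit of each value is its channel index). reshape only reinterprets the flat memory buffer, while transpose actually moves the channel axis:

```python
import numpy as np

# Hypothetical tiny image in [H, W, C] layout: 2 x 2 pixels, 3 channels.
# Each value's last digit is its channel index (0, 1 or 2).
image = np.array([[[0, 1, 2], [10, 11, 12]],
                  [[20, 21, 22], [30, 31, 32]]])

reshaped = image.reshape(3, 2, 2)       # reinterprets the flat buffer
transposed = image.transpose(2, 0, 1)   # moves the channel axis to the front

print(reshaped[0])    # values from all three channels end up mixed together
print(transposed[0])  # only channel-0 values
```

Both results have shape (3, 2, 2), which is why the reshape version runs without error even though its "channels" no longer correspond to the original colour planes.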

Other than that, I also scaled the convolution output to [0, 1] so that imshow displays it properly. Below is the working code:

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio
import sys

A = imageio.imread('LiT.png')
# Define how the convolution operation works
conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)

# from [H, W, C] to [C, H, W]
transposed_image = A.transpose((2, 0, 1))
# add batch dim
transposed_image = np.expand_dims(transposed_image, 0)

image_d = torch.FloatTensor(transposed_image)
fc = conv2(image_d)
fc1 = fc.permute(0, 2, 3, 1)[0]
result = fc1.data.numpy()
# Min-max scale the output to [0, 1]: subtract the minimum, then divide by the range
min_ = np.min(result)
max_ = np.max(result)
result -= min_
result /= (max_ - min_)

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(A)
plt.subplot(1,2,2)
plt.imshow(result)

plt.show()
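The scaling step can also be pulled out into a small helper. This is only a sketch: the name min_max_scale and the sample values are made up for illustration, and it assumes a plain NumPy array as input. It divides by the range (max minus min) so that the result spans exactly [0, 1], and it guards against a constant input:

```python
import numpy as np

def min_max_scale(x):
    """Linearly rescale x so its smallest value maps to 0 and its largest to 1."""
    x = x.astype(np.float32)
    lo, hi = x.min(), x.max()
    if hi == lo:
        # Constant input: there is no range to scale, so return zeros
        return np.zeros_like(x)
    return (x - lo) / (hi - lo)

# Hypothetical feature-map values, including a negative one
feature_map = np.array([[-2.0, 0.0], [1.0, 6.0]])
scaled = min_max_scale(feature_map)
print(scaled.min(), scaled.max())
```

A convolution with randomly initialised weights usually produces negative values, so some rescaling of this kind is needed before passing a float array to plt.imshow.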

Upvotes: 1

asymptote

Reputation: 1402

To my understanding, the problem lies in using reshape to move the channel dimension of the image. Instead, np.transpose or tensor.permute should be used. Using torch for the permutation:

image_d  = torch.FloatTensor(np.asarray(A)).unsqueeze(0).permute(0,3,1,2)

Or, if we want to handle the permutation part in numpy:

image_d = np.transpose(np.asarray(A), (2,0,1))
image_d = torch.FloatTensor(image_d).unsqueeze(0)

Upvotes: 1
