Reputation: 376
I implemented FFT-based convolution in PyTorch and compared the result with spatial convolution via the conv2d() function. The convolution filter is a 3x3 average filter. As expected, conv2d() produced a smoothed output, but the FFT-based convolution returned a much blurrier one. The code and outputs are below -
spatial convolution -
from PIL import Image, ImageOps
import torch
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import numpy as np
im = Image.open("/kaggle/input/tiger.jpg")
im = im.resize((256,256))
gray_im = im.convert('L')
gray_im = ToTensor()(gray_im)
gray_im = gray_im.squeeze()
fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
conv_gray_im = gray_im.unsqueeze(0).unsqueeze(0)
conv_fil = fil.unsqueeze(0).unsqueeze(0)
conv_op = F.conv2d(conv_gray_im,conv_fil)
conv_op = conv_op.squeeze()
plt.figure()
plt.imshow(conv_op, cmap='gray')
FFT-based convolution -
def fftshift(image):
    sh = image.shape
    x = np.arange(0, sh[2], 1)
    y = np.arange(0, sh[3], 1)
    xm, ym = np.meshgrid(x,y)
    shifter = (-1)**(xm + ym)
    shifter = torch.from_numpy(shifter)
    return image*shifter
shift_im = fftshift(conv_gray_im)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))
shift_fil = fftshift(padded_fil)
fft_shift_im = torch.rfft(shift_im, 2, onesided=False)
fft_shift_fil = torch.rfft(shift_fil, 2, onesided=False)
shift_prod = fft_shift_im*fft_shift_fil
shift_fft_conv = fftshift(torch.irfft(shift_prod, 2, onesided=False))
fft_op = shift_fft_conv.squeeze()
plt.figure('shifted fft')
plt.imshow(fft_op, cmap='gray')
original image -
spatial convolution output -
fft-based convolution output -
Could someone kindly explain the issue?
Upvotes: 1
Views: 3152
Reputation: 60660
The main problem with your code is that Torch doesn't do complex numbers: the output of its FFT is a 3D array whose third dimension holds two values, one for the real component and one for the imaginary. Consequently, the multiplication is applied element by element to those values and is not a complex multiplication.
There currently is no complex multiplication defined in Torch (see this issue), so we'll have to define our own.
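To make the difference concrete, here is a minimal NumPy sketch (NumPy is used only for illustration; the array names are made up for this example) that stores two complex numbers as (real, imaginary) pairs, the layout described above, and compares the plain element-wise product with a true complex product:

```python
import numpy as np

# Two complex values stored as (real, imaginary) pairs,
# mimicking the trailing axis of length 2 described above.
a = np.array([1.0, 2.0])   # 1 + 2j
b = np.array([3.0, 4.0])   # 3 + 4j

# Plain element-wise product: real*real and imaginary*imaginary.
elementwise = a * b

# True complex product: (ar*br - ai*bi) + (ar*bi + ai*br) j
complex_product = np.array([a[0] * b[0] - a[1] * b[1],
                            a[0] * b[1] + a[1] * b[0]])

# Reference value from Python's built-in complex type.
reference = complex(a[0], a[1]) * complex(b[0], b[1])

print(elementwise, complex_product, reference)
```

The element-wise result drops the cross terms entirely, so the product of two spectra computed that way is not the spectrum of the convolution.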
A minor issue, but also important if you want to compare the two convolution operations, is the following:
The FFT takes the origin of its input in the first element (top-left pixel for an image). To avoid a shifted output, you need to generate a padded kernel where the origin of the kernel is the top-left pixel. This is quite tricky, actually...
Your current code:
fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
conv_fil = fil.unsqueeze(0).unsqueeze(0)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))
generates a padded kernel where the origin is in pixel (1,1), rather than (0,0). It needs to be shifted by one pixel in each direction. NumPy has a function roll that is useful for this; I don't know the Torch equivalent (I'm not at all familiar with Torch). This should work:
fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
# Keep the kernel 2D so it matches the 2D gray_im and the two pad-width pairs below
padded_fil = fil.numpy()
padded_fil = np.pad(padded_fil, ((0, gray_im.shape[0]-fil.shape[0]), (0, gray_im.shape[1]-fil.shape[1])))
# Move the kernel centre from pixel (1,1) to pixel (0,0)
padded_fil = np.roll(padded_fil, -1, axis=(0, 1))
padded_fil = torch.from_numpy(padded_fil)
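The effect of the one-pixel roll can be checked with a small NumPy sketch. It assumes a random 8x8 stand-in image (not the tiger image from the question) and uses a hypothetical helper fft_convolve; the reference is a circular 3x3 average built directly from shifted copies of the image:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((8, 8))            # stand-in for the grayscale image
kernel = np.full((3, 3), 1.0 / 9.0)   # 3x3 average filter

# Zero-pad the kernel to the image size, then roll its centre to (0, 0).
padded = np.pad(kernel, ((0, image.shape[0] - 3), (0, image.shape[1] - 3)))
centred = np.roll(padded, -1, axis=(0, 1))

def fft_convolve(img, ker):
    """Circular convolution via the full complex FFT (illustrative helper)."""
    return np.real(np.fft.ifft2(np.fft.fft2(img) * np.fft.fft2(ker)))

# Reference: circular 3x3 average computed from nine shifted copies.
reference = sum(np.roll(image, (dy, dx), axis=(0, 1))
                for dy in (-1, 0, 1) for dx in (-1, 0, 1)) / 9.0

with_roll = fft_convolve(image, centred)
without_roll = fft_convolve(image, padded)

print(np.allclose(with_roll, reference))
print(np.allclose(without_roll, np.roll(reference, (1, 1), axis=(0, 1))))
```

With the roll, the FFT result agrees with the direct average; without it, the output is the same image displaced by one pixel down and to the right.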
Finally, your fftshift function, applied to the spatial-domain image, causes the frequency-domain image (the result of the FFT applied to the image) to be shifted such that the origin is in the middle of the image, rather than the top-left. This shift is useful when looking at the output of the FFT, but is pointless when computing the convolution.
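A short NumPy sketch of both claims, assuming random 8x8 arrays with even side lengths (like the 256x256 image) as stand-ins: multiplying by (-1)^(x+y) before the FFT gives the same spectrum as np.fft.fftshift applied after it, and wrapping the convolution in that shift leaves the result unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((8, 8))    # even side lengths, like the 256x256 image
kernel = rng.random((8, 8))   # arbitrary kernel already padded to image size

# Checkerboard of +1 / -1, i.e. (-1)^(x + y).
y, x = np.indices(image.shape)
sign = (-1.0) ** (x + y)

# Modulating in the spatial domain centres the spectrum.
spectrum_of_modulated = np.fft.fft2(image * sign)
shifted_spectrum = np.fft.fftshift(np.fft.fft2(image))

# Convolution without any shift ...
plain = np.real(np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(kernel)))
# ... and with both inputs modulated and the output modulated back.
via_shift = sign * np.real(
    np.fft.ifft2(np.fft.fft2(image * sign) * np.fft.fft2(kernel * sign)))

print(np.allclose(spectrum_of_modulated, shifted_spectrum))
print(np.allclose(plain, via_shift))
```

The equivalence with fftshift relies on the even side lengths; either way, the shift only adds work to the convolution.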
Putting these things together, the convolution is now:
def complex_multiplication(t1, t2):
    real1, imag1 = t1[:,:,0], t1[:,:,1]
    real2, imag2 = t2[:,:,0], t2[:,:,1]
    return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim = -1)
fft_im = torch.rfft(gray_im, 2, onesided=False)
fft_fil = torch.rfft(padded_fil, 2, onesided=False)
fft_conv = torch.irfft(complex_multiplication(fft_im, fft_fil), 2, onesided=False)
Note that you can do one-sided FFTs to save a bit of computation time:
fft_im = torch.rfft(gray_im, 2, onesided=True)
fft_fil = torch.rfft(padded_fil, 2, onesided=True)
fft_conv = torch.irfft(complex_multiplication(fft_im, fft_fil), 2, onesided=True, signal_sizes=gray_im.shape)
Here the frequency-domain array is about half the size of the one produced by the full FFT, but only redundant parts are left out. The result of the convolution is unchanged.
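Note that torch.rfft and torch.irfft were removed in PyTorch 1.8 in favour of the torch.fft module, which returns complex tensors. The following is a sketch of the same one-sided convolution with that newer API, assuming a recent PyTorch release and a random 16x16 stand-in for gray_im; torch.roll plays the role of np.roll here, and the cross-check against conv2d with circular padding is only for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gray_im = torch.rand(16, 16)          # stand-in for the 2D grayscale image
fil = torch.full((3, 3), 1.0 / 9.0)   # 3x3 average filter

# Zero-pad the kernel to the image size and roll its centre to (0, 0).
pad_h = gray_im.shape[0] - fil.shape[0]
pad_w = gray_im.shape[1] - fil.shape[1]
padded_fil = F.pad(fil, (0, pad_w, 0, pad_h))   # (left, right, top, bottom)
padded_fil = torch.roll(padded_fil, shifts=(-1, -1), dims=(0, 1))

# One-sided FFTs return complex tensors, so * is a true complex product.
fft_im = torch.fft.rfft2(gray_im)
fft_fil = torch.fft.rfft2(padded_fil)
fft_conv = torch.fft.irfft2(fft_im * fft_fil, s=gray_im.shape)

# Cross-check: conv2d on a circularly padded image (the FFT result wraps around).
wrapped = F.pad(gray_im[None, None], (1, 1, 1, 1), mode="circular")
spatial = F.conv2d(wrapped, fil[None, None]).squeeze()

print(torch.allclose(fft_conv, spatial, atol=1e-5))
```

Because the FFT-based result is a circular convolution, it matches conv2d only when the image is padded by wrapping; the unpadded conv2d call in the question yields a 254x254 output instead.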
Upvotes: 1