Reputation: 51
I'm trying to perform a time-domain multiplication using 2D circular convolution in the frequency domain. I know how it works in the 1D case, for example:
x = [1 2 3 4 5];
y = [4 5 6 7 8];
xy = 1/5*ifft( cconv(fft(x), fft(y), 5));
xy0 = x.*y;
Both xy and xy0 are the same, which is what I want. However, cconv is not defined for the 2D case in MATLAB, and I don't know how to multiply two matrices of the same size element-wise using convolution in the frequency domain.
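The 1D identity above can also be checked outside MATLAB. The following is a minimal NumPy sketch, not a port of cconv: the helper name circ_conv_1d is hypothetical and simply evaluates the circular convolution sum directly.

```python
import numpy as np

def circ_conv_1d(a, b):
    """Direct circular convolution: out[k] = sum_m b[m] * a[(k - m) mod N]."""
    out = np.zeros(len(a), dtype=complex)
    for m in range(len(b)):
        # np.roll(a, m) places a[(k - m) mod N] at index k
        out += b[m] * np.roll(a, m)
    return out

x = np.array([1, 2, 3, 4, 5], dtype=float)
y = np.array([4, 5, 6, 7, 8], dtype=float)

# Circular convolution of the spectra, scaled by 1/N, then inverse FFT
xy = np.fft.ifft(circ_conv_1d(np.fft.fft(x), np.fft.fft(y))) / len(x)
xy0 = x * y

print(np.allclose(xy, xy0))
```

The 1/5 factor in the MATLAB line plays the same role as the division by len(x) here.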
Suppose we have the following matrices:
x = [3 5 4;
7 6 1;
-1 2 0];
y = [2 7 1;
2 -3 2;
5 6 9];
Clearly, the command 1/9*ifft2( conv2(fft2(x), fft2(y), 'same')) does not give the same result as x.*y
Can anyone please help me with this problem?
Upvotes: 2
Views: 681
Reputation: 1066
The convolution in your command should be a circular convolution, the same as the cconv you use in your 1D example. See https://www.mathworks.com/matlabcentral/answers/59333-convolution-in-frequency-domain-not-convolution-in-time-domain and Convolution of two fft function.
Here's an example I got to work in Python using NumPy and SciPy. I also had to shift the result of the frequency-domain convolution (by +1 +1) BEFORE the IFFT to match the multiplication result. I assume the way scipy pads moves the results away from their proper places. A quick check against the FFT of the spatial-domain product, which shows what the frequency-domain convolution should produce, should tell you how to shift in MATLAB or with a different padding.
Using the 'wrap' boundary in scipy.signal.convolve2d gives a circular convolution.
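To make the point about circular convolution concrete, here is a minimal sketch on the two matrices from the question, using only NumPy. The helper circ_conv_2d is a hypothetical name, not a library function; it evaluates the 2D circular convolution sum directly, so no padding or shifting is involved.

```python
import numpy as np

def circ_conv_2d(a, b):
    """Direct 2D circular convolution of two equally sized arrays."""
    out = np.zeros(a.shape, dtype=complex)
    rows, cols = b.shape
    for m in range(rows):
        for n in range(cols):
            # np.roll(a, (m, n)) places a[(k - m) % rows, (l - n) % cols] at [k, l]
            out += b[m, n] * np.roll(a, (m, n), axis=(0, 1))
    return out

x = np.array([[3, 5, 4], [7, 6, 1], [-1, 2, 0]], dtype=float)
y = np.array([[2, 7, 1], [2, -3, 2], [5, 6, 9]], dtype=float)

# Circular convolution of the 2D spectra, scaled by 1/9, then inverse FFT
xy = np.fft.ifft2(circ_conv_2d(np.fft.fft2(x), np.fft.fft2(y))) / x.size

print(np.allclose(xy, x * y))
```

The same double loop could be written in MATLAB with circshift in place of np.roll, though that translation is untested here.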
import numpy as np
from scipy import signal
###################################################
# Check Multiplication in Spatial == Conv in Freq #
###################################################
np.random.seed(1234)
fr = np.random.normal(0.1,1.0,(3,3))
imgs = np.random.normal(0,1.0,(3,3))
spatialMult = fr * imgs
filtFFT = np.fft.fft2(fr)
imgsFFT = np.fft.fft2(imgs)
ac = signal.convolve2d(imgsFFT.real,filtFFT.real,'same','wrap')
ad = signal.convolve2d(imgsFFT.real,filtFFT.imag,'same','wrap')
bc = signal.convolve2d(imgsFFT.imag,filtFFT.real,'same','wrap')
bd = signal.convolve2d(imgsFFT.imag,filtFFT.imag,'same','wrap')
r_conv = (ac - bd)
i_conv = (ad + bc)
res_conv = r_conv + 1j*i_conv
res_conv = res_conv / 9
# Positions are off by one in each dimension
# Need a +1 +1 shift to align: 'same' keeps the centre of the full output
res_conv = np.roll(res_conv, (1, 1), axis=(0, 1))
res_ifft = np.fft.ifft2(res_conv)
spat_fft = np.fft.fft2(spatialMult)
print(spat_fft)
print(res_conv)
print()
print(res_ifft.real)
print(spatialMult)
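The shift appears to come from mode 'same' rather than from the padding itself: 'same' returns the central 3x3 part of the 5x5 'full' output, so every index is off by (3 - 1) // 2 = 1. The sketch below checks this on fresh random data (the seed and variable names are arbitrary choices). It also assumes convolve2d is given the complex spectra directly instead of four real convolutions.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
b = rng.normal(size=(3, 3))
A, B = np.fft.fft2(a), np.fft.fft2(b)

# Complex inputs are passed as-is, so one call replaces ac, ad, bc and bd
same = signal.convolve2d(A, B, mode='same', boundary='wrap')

# For 3x3 arrays, same[j] holds the circular result at index (j + 1) % 3,
# so a roll by +1 along both axes restores the expected positions
circ = np.roll(same, (1, 1), axis=(0, 1))

print(np.allclose(circ / a.size, np.fft.fft2(a * b)))
```

For other array sizes the required shift may differ, so the comparison against the FFT of the spatial product is worth repeating.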
If you want the correct output immediately you can do the padding yourself using numpy:
ir = imgsFFT.real
im = imgsFFT.imag
f_r = filtFFT.real
f_i = filtFFT.imag
ir = np.pad(ir,((2,0),(2,0)),'wrap')
im = np.pad(im,((2,0),(2,0)),'wrap')
ac = signal.convolve2d(ir,f_r,'valid')
ad = signal.convolve2d(ir,f_i,'valid')
bc = signal.convolve2d(im,f_r,'valid')
bd = signal.convolve2d(im,f_i,'valid')
r_conv = (ac - bd)
i_conv = (ad + bc)
res_conv = r_conv + 1j*i_conv
res_conv = res_conv / 9  # n = 3*3 elements, same scaling as above
res_ifft = np.fft.ifft2(res_conv)
spat_fft = np.fft.fft2(spatialMult)
print(spat_fft)
print(res_conv)
print()
print(res_ifft.real)
print(spatialMult)
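The manual padding can be wrapped in a small function so that the pad widths follow the array shape instead of being hard-coded to 2. This is a sketch under assumptions: circ_conv2d_padded is a hypothetical helper, both inputs are taken to have the same shape, and the complex spectra are passed to convolve2d directly.

```python
import numpy as np
from scipy import signal

def circ_conv2d_padded(a, b):
    """Circular convolution of equally sized 2D arrays via wrap padding + 'valid'."""
    rows, cols = b.shape
    # Prepend (size - 1) wrapped rows and columns so that 'valid' output
    # index [k, l] lines up with circular index [k, l] and no roll is needed
    a_pad = np.pad(a, ((rows - 1, 0), (cols - 1, 0)), mode='wrap')
    return signal.convolve2d(a_pad, b, mode='valid')

rng = np.random.default_rng(1)
a = rng.normal(size=(4, 5))
b = rng.normal(size=(4, 5))

# Scale by the number of elements before the inverse FFT
freq = circ_conv2d_padded(np.fft.fft2(a), np.fft.fft2(b)) / a.size

print(np.allclose(np.fft.ifft2(freq), a * b))
```

With 3x3 inputs the pad widths reduce to the ((2,0),(2,0)) used above.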
Upvotes: 0