Mohammad Bazrafshan


Multiplication in time domain using 2D circular convolution in frequency domain

I'm trying to perform a time-domain multiplication using a 2D circular convolution in the frequency domain. I know how it works in 1D, for example:

x = [1 2 3 4 5];
y = [4 5 6 7 8];
xy = 1/5*ifft(cconv(fft(x), fft(y), 5));  % (1/N) * circular convolution of the spectra, N = 5
xy0 = x.*y;                               % direct element-wise product

Both xy and xy0 are equal, which is what I want. However, cconv only works on vectors in MATLAB, so for the 2D case I don't know how to compute the element-wise product of two matrices of the same size using a convolution in the frequency domain.
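For reference, the same 1D identity (the DFT of a product equals 1/N times the circular convolution of the DFTs) can be checked in Python/NumPy. This is a minimal sketch: NumPy has no `cconv`, so `cconv` below is a hypothetical helper written as a direct O(N^2) sum, not a library function.

```python
import numpy as np

def cconv(a, b):
    """Hypothetical helper: circular convolution of two equal-length 1D arrays."""
    n = len(a)
    # c[k] = sum_m a[m] * b[(k - m) mod n]
    return np.array([sum(a[m] * b[(k - m) % n] for m in range(n)) for k in range(n)])

x = np.array([1, 2, 3, 4, 5], dtype=float)
y = np.array([4, 5, 6, 7, 8], dtype=float)

# ifft of (1/N) * circular convolution of the spectra recovers the product
xy = np.fft.ifft(cconv(np.fft.fft(x), np.fft.fft(y)) / len(x))
xy0 = x * y

print(np.allclose(xy.real, xy0))  # True
```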

Suppose we have the following matrices:

x = [3 5 4;
    7 6 1;
    -1 2 0];

y = [2 7 1;
    2 -3 2;
    5 6 9];

However, the command 1/9*ifft2(conv2(fft2(x), fft2(y), 'same')) certainly does not give the same result as x.*y.

Can anyone please help me with this problem?
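The 2D version of the identity needs a convolution that wraps around in both dimensions, which conv2 with 'same' does not do. A minimal sketch in Python/NumPy, using a hypothetical helper `cconv2` written as a direct quadruple loop (fine for these 3x3 matrices, far too slow for real images), checks it with the matrices from the question:

```python
import numpy as np

def cconv2(a, b):
    """Hypothetical helper: 2D circular convolution of two equal-shape arrays."""
    rows, cols = a.shape
    out = np.zeros(a.shape, dtype=np.result_type(a, b))
    for k in range(rows):
        for l in range(cols):
            for m in range(rows):
                for n in range(cols):
                    # indices wrap around modulo the array size in both dimensions
                    out[k, l] += a[m, n] * b[(k - m) % rows, (l - n) % cols]
    return out

x = np.array([[3, 5, 4], [7, 6, 1], [-1, 2, 0]], dtype=float)
y = np.array([[2, 7, 1], [2, -3, 2], [5, 6, 9]], dtype=float)

# 2D identity: fft2(x*y) = (1/(rows*cols)) * cconv2(fft2(x), fft2(y))
xy = np.fft.ifft2(cconv2(np.fft.fft2(x), np.fft.fft2(y)) / x.size)
print(np.allclose(xy.real, x * y))  # True
```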


Answers (1)

Kevinj22


Your command should use a circular convolution, just as cconv does in your 1D example; conv2 performs a linear convolution. See https://www.mathworks.com/matlabcentral/answers/59333-convolution-in-frequency-domain-not-convolution-in-time-domain and "Convolution of two fft function".

Here's an example I got to work in Python using NumPy and SciPy. I had to shift the result of the frequency-domain convolution by +1 along each axis BEFORE the IFFT to match the multiplication result. The shift comes from mode='same', which returns the centre of the full convolution output and therefore offsets every sample by (kernel size - 1)//2, i.e. one position for a 3x3 kernel. Comparing against the FFT of the spatial-domain product shows the correct alignment, which should tell you how to shift in MATLAB or with a different padding.

Using boundary='wrap' in scipy.signal.convolve2d makes the convolution circular.
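The one-sample offset can be seen in isolation by comparing convolve2d with mode='same', boundary='wrap' against a reference circular convolution computed with the convolution theorem. A small sketch on random 3x3 arrays; `cconv2_fft` is a hypothetical helper name, and the roll of +1 assumes a 3x3 kernel:

```python
import numpy as np
from scipy import signal

def cconv2_fft(a, b):
    """Hypothetical helper: reference 2D circular convolution via the FFT."""
    return np.fft.ifft2(np.fft.fft2(a) * np.fft.fft2(b))

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
b = rng.normal(size=(3, 3))

ref = cconv2_fft(a, b).real
same_wrap = signal.convolve2d(a, b, mode='same', boundary='wrap')

# 'same' keeps the centre of the full output, so same_wrap[i, j] holds
# ref[(i + 1) % 3, (j + 1) % 3]; rolling by +1 on both axes realigns it
print(np.allclose(np.roll(same_wrap, (1, 1), axis=(0, 1)), ref))  # True
```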

import numpy as np
from scipy import signal

###################################################
# Check Multiplication in Spatial == Conv in Freq #
###################################################

np.random.seed(1234)
fr = np.random.normal(0.1,1.0,(3,3))
imgs = np.random.normal(0,1.0,(3,3))

spatialMult = fr * imgs

filtFFT = np.fft.fft2(fr)
imgsFFT = np.fft.fft2(imgs)

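# Circular ('wrap') 2D convolution of the spectra, split into real and
# imaginary parts: (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
ac = signal.convolve2d(imgsFFT.real,filtFFT.real,'same','wrap')
ad = signal.convolve2d(imgsFFT.real,filtFFT.imag,'same','wrap')
bc = signal.convolve2d(imgsFFT.imag,filtFFT.real,'same','wrap')
bd = signal.convolve2d(imgsFFT.imag,filtFFT.imag,'same','wrap')

r_conv = (ac - bd)
i_conv = (ad + bc)

res_conv = r_conv + 1j*i_conv
res_conv = res_conv / 9   # 1/N scaling, N = 3*3 samples
# 'same' mode keeps the centre of the full output, so with a 3x3 kernel
# every sample lands one position early; roll by +1 on both axes to align
res_conv = np.roll(res_conv, (1, 1), axis=(0, 1))
res_ifft = np.fft.ifft2(res_conv)

# The aligned convolution should equal the FFT of the spatial product...
spat_fft = np.fft.fft2(spatialMult)
print(spat_fft)
print(res_conv)
print()
# ...and its inverse FFT should equal the element-wise product
print(res_ifft.real)
print(spatialMult)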

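scipy.signal.convolve2d also accepts complex arrays (its own docstring example convolves with a complex Scharr kernel), so the four real-valued convolutions and the ac - bd / ad + bc bookkeeping can probably be collapsed into a single call. A sketch reusing the same random setup; the +1 roll again assumes 3x3 inputs:

```python
import numpy as np
from scipy import signal

np.random.seed(1234)
fr = np.random.normal(0.1, 1.0, (3, 3))
imgs = np.random.normal(0, 1.0, (3, 3))
spatialMult = fr * imgs

filtFFT = np.fft.fft2(fr)
imgsFFT = np.fft.fft2(imgs)

# One complex circular convolution instead of four real ones, scaled by 1/N
res_conv = signal.convolve2d(imgsFFT, filtFFT, mode='same', boundary='wrap') / imgs.size
# Undo the one-sample offset introduced by 'same' mode (3x3 kernel)
res_conv = np.roll(res_conv, (1, 1), axis=(0, 1))
res_ifft = np.fft.ifft2(res_conv)

print(np.allclose(res_ifft.real, spatialMult))  # True
```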
If you want correctly aligned output without the extra roll, you can do the wrap padding yourself with NumPy and use 'valid' mode:

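# Continues from the previous snippet (np, signal, imgsFFT, filtFFT, spatialMult)
ir = imgsFFT.real
im = imgsFFT.imag
f_r = filtFFT.real
f_i = filtFFT.imag

# Prepend two wrapped rows and columns (kernel size - 1) so that 'valid'
# convolution yields the circular convolution with no extra shift
ir = np.pad(ir,((2,0),(2,0)),'wrap')
im = np.pad(im,((2,0),(2,0)),'wrap')

ac = signal.convolve2d(ir,f_r,'valid')
ad = signal.convolve2d(ir,f_i,'valid')
bc = signal.convolve2d(im,f_r,'valid')
bd = signal.convolve2d(im,f_i,'valid')

r_conv = (ac - bd)
i_conv = (ad + bc)

res_conv = r_conv + 1j*i_conv
n = imgsFFT.size   # number of samples, 9 for a 3x3 array
res_conv = res_conv / n
res_ifft = np.fft.ifft2(res_conv)

spat_fft = np.fft.fft2(spatialMult)
print(spat_fft)
print(res_conv)
print()
print(res_ifft.real)
print(spatialMult)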
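The ((2,0),(2,0)) padding is specific to 3x3 inputs. A sketch of the same wrap-pad plus 'valid' approach for arbitrary equal-shape arrays, assuming both arrays have the same shape; `cconv2_valid` is a hypothetical helper name:

```python
import numpy as np
from scipy import signal

def cconv2_valid(a, b):
    """Hypothetical helper: 2D circular convolution via wrap padding + 'valid' mode."""
    rows, cols = a.shape
    # Prepend rows-1 and cols-1 wrapped samples so each output sees a full period
    padded = np.pad(a, ((rows - 1, 0), (cols - 1, 0)), mode='wrap')
    return signal.convolve2d(padded, b, mode='valid')

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 5))
y = rng.normal(size=(4, 5))

# Element-wise product recovered from the circular convolution of the spectra
X, Y = np.fft.fft2(x), np.fft.fft2(y)
xy = np.fft.ifft2(cconv2_valid(X, Y) / x.size)
print(np.allclose(xy.real, x * y))  # True
```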

