Reputation: 6098
TensorFlow's conv2d() is impractically slow for convolving large images with large kernels (filters). It takes a few minutes to convolve a 1024x1024 image with a kernel of the same size. For comparison, cv2.filter2D() returns the result immediately.
I found tf.fft2() and tf.rfft(), but it was not clear to me how to perform simple image filtering with these functions.
How can I implement fast 2D image filtering with TensorFlow using FFT?
Upvotes: 3
Views: 14630
Reputation: 6187
A linear discrete convolution of the form x * y can be computed using the convolution theorem and the discrete-time Fourier transform (DTFT). If x * y is a circular discrete convolution then it can be computed with the discrete Fourier transform (DFT).

The convolution theorem states that x * y can be computed using the Fourier transform as

x * y = F^-1{ F{x} . F{y} }

where F denotes the Fourier transform, F^-1 the inverse Fourier transform, and . element-wise multiplication. When x and y are discrete and their convolution is a linear convolution, this is computed using the DTFT as

x * y = DTFT^-1{ DTFT{x} . DTFT{y} }

If x and y are discrete and their convolution is a circular convolution, the DTFT above is replaced by the DFT. Note: linear convolution problems can be embedded in circular convolution problems.
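To make that embedding note concrete, here is a small NumPy sketch (my addition, not from the answers): zero-padding both sequences to the linear-convolution length N + M - 1 makes the DFT-based circular convolution equal the linear convolution.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])

# Direct linear convolution for reference (length N + M - 1 = 5).
direct = np.convolve(x, y)

# Circular convolution via the DFT, after zero-padding both inputs
# to the linear-convolution length: the "embedding".
L = len(x) + len(y) - 1
via_dft = np.fft.ifft(np.fft.fft(x, L) * np.fft.fft(y, L)).real

print(np.allclose(direct, via_dft))   # True
```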
I'm more familiar with MATLAB, but from reading the TensorFlow documentation for tf.signal.fft2d and tf.signal.ifft2d, the solution below should be easily convertible to TensorFlow by replacing the MATLAB functions fft2 and ifft2.
In MATLAB (and TensorFlow), fft2 (and tf.signal.fft2d) computes the DFT using the fast Fourier transform algorithm. If the convolution of x and y is circular it can be computed by

ifft2(fft2(x).*fft2(y))

where .* represents element-by-element multiplication in MATLAB. However, if it is linear then we zero-pad the data to length 2*N-1, where N is the length of one dimension (1024 in the question). In MATLAB this can be computed in one of two ways. Firstly, by
h = ifft2(fft2(x, 2*N-1, 2*N-1).*fft2(y, 2*N-1, 2*N-1));
where MATLAB computes the (2*N-1)-point 2D Fourier transform of x and y by zero padding, followed by the (2*N-1)-point 2D inverse Fourier transform. This method can't be used in TensorFlow (from my understanding of the documentation), so the next is the only option. In MATLAB and TensorFlow the convolution can be computed by first extending x and y to size (2*N-1) x (2*N-1) and then computing the (2*N-1)-point 2D Fourier transform and inverse Fourier transform:
x_extended = x;
x_extended(2*N-1, 2*N-1) = 0;
y_extended = y;
y_extended(2*N-1, 2*N-1) = 0;
h_extended = ifft2(fft2(x_extended).*fft2(y_extended));
In MATLAB, h and h_extended are exactly equal. For comparison, the convolution of x and y can be computed without the Fourier transform with
hC = conv2(x, y);
in MATLAB.
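The same equality can be checked in NumPy (a sketch under the assumption that numpy.fft.fft2/ifft2 behave like MATLAB's fft2/ifft2 and tf.signal.fft2d/ifft2d here); the nested loop stands in for conv2(x, y):

```python
import numpy as np

N = 4                                   # small stand-in for the 1024 in the question
rng = np.random.default_rng(0)
x = rng.standard_normal((N, N))
y = rng.standard_normal((N, N))

# Direct "full" linear 2D convolution, i.e. what MATLAB's conv2(x, y) returns.
M = 2 * N - 1
direct = np.zeros((M, M))
for i in range(N):
    for j in range(N):
        direct[i:i + N, j:j + N] += x[i, j] * y

# FFT approach: extend both arrays to (2N-1) x (2N-1) with zeros,
# multiply pointwise in the frequency domain, then invert.
x_extended = np.zeros((M, M)); x_extended[:N, :N] = x
y_extended = np.zeros((M, M)); y_extended[:N, :N] = y
h_extended = np.fft.ifft2(np.fft.fft2(x_extended) * np.fft.fft2(y_extended)).real

print(np.allclose(direct, h_extended))  # True
```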
In MATLAB on my laptop, conv2(x, y) takes 55 seconds, whereas the Fourier transform approach takes less than 0.4 seconds.
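On the Python side, scipy.signal.fftconvolve (mentioned in the next answer) shows the same gap versus the direct method. This sketch checks that the two agree and prints timings, which will of course vary by machine:

```python
import time
import numpy as np
from scipy.signal import convolve2d, fftconvolve

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 64))      # modest size so the direct method finishes quickly
y = rng.standard_normal((64, 64))

t0 = time.perf_counter()
direct = convolve2d(x, y)              # direct sliding-window convolution
t1 = time.perf_counter()
via_fft = fftconvolve(x, y)            # zero-pads and multiplies FFTs internally
t2 = time.perf_counter()

print(np.allclose(direct, via_fft))    # True
print(f"direct: {t1 - t0:.3f}s, fft: {t2 - t1:.3f}s")
```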
Upvotes: 6
Reputation: 41
This can be done in a way similar to how, for instance, scipy.signal.fftconvolve is implemented.

Here is an example. Assume we have an image im (2-dimensional; if you also have multiple channels you can use the 3D functions instead of the 2D ones) and a filter kernel (e.g. a Gaussian).
First, take the Fourier transform of the image and define the FFT lengths (useful if the filter has a different shape, in which case it will get zero-padded):

fft_length1 = tf.shape(im)[0]
fft_length2 = tf.shape(im)[1]
im_fft = tf.signal.rfft2d(im, fft_length=[fft_length1, fft_length2])

Next, take the FFT of the filter (note: for instance for a 2D Gaussian filter, make sure the center is in the top-left corner, i.e. use only a 'quarter'):

kernel_fft = tf.signal.rfft2d(kernel, fft_length=[fft_length1, fft_length2])

Finally, take the inverse transform to get back the convolved image:

im_blurred = tf.signal.irfft2d(im_fft * kernel_fft, [fft_length1, fft_length2])
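Here is a NumPy sketch of the same pipeline (numpy.fft.rfft2/irfft2 stand in for tf.signal.rfft2d/irfft2d; the impulse kernel is my own test case, not from the answer). It also shows one way to honor the "center in the top-left corner" remark: numpy.fft.ifftshift moves a centered kernel so its peak lands at index (0, 0), which keeps the filtered output from being shifted.

```python
import numpy as np

def fft_filter(im, kernel):
    """Circularly convolve an image with a centered, image-sized kernel via real FFTs."""
    h, w = im.shape
    # Move the kernel's center to index (0, 0) so the output is not shifted.
    kernel = np.fft.ifftshift(kernel)
    im_fft = np.fft.rfft2(im, s=(h, w))
    kernel_fft = np.fft.rfft2(kernel, s=(h, w))
    return np.fft.irfft2(im_fft * kernel_fft, s=(h, w))

# Sanity check with a unit impulse kernel: filtering must return the image unchanged.
rng = np.random.default_rng(0)
im = rng.standard_normal((8, 8))
delta = np.zeros((8, 8))
delta[4, 4] = 1.0                      # the "center" index for an even-sized kernel
print(np.allclose(fft_filter(im, delta), im))  # True
```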
Upvotes: 4