Faze Maze

Reputation: 11

How to correctly use Numpy's FFT function in PyTorch?

I was recently introduced to PyTorch and began working through the library's documentation and tutorials. In the "Creating extensions using numpy and scipy" tutorial, under "Parameter-less example", a sample function called BadFFTFunction is created using NumPy.

The description for the function states:

"This layer doesn’t particularly do anything useful or mathematically correct. It is aptly named BadFFTFunction."

The function and its usage are given as:

import torch
from torch.autograd import Function
from numpy.fft import rfft2, irfft2

class BadFFTFunction(Function):

    @staticmethod
    def forward(ctx, input):
        # Forward: magnitude of the 2-D real FFT, computed in NumPy
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # "Backward": inverse real FFT of the incoming gradient
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go)
        return torch.as_tensor(result, dtype=grad_output.dtype)

def incorrect_fft(input):
    return BadFFTFunction.apply(input)

input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
print(input.grad)

Unfortunately, I was only recently introduced to signal processing as well, and am unsure of where the (likely obvious) error is in this function.

How can BadFFTFunction be fixed so that its forward and backward outputs are correct, i.e. so that a differentiable FFT function can be used in PyTorch?

Upvotes: 1


Answers (2)

iacob

Reputation: 24281

As of version 1.8, PyTorch has a native torch.fft module whose functions support autograd:

torch.fft.fft(input)
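Applied to the question, the forward pass of BadFFTFunction can be written directly with torch.fft, so autograd derives the backward pass and no hand-written backward is needed. This is a minimal sketch, assuming PyTorch 1.8 or later and keeping the question's magnitude-of-rfft2 behavior; the helper name fft_magnitude is hypothetical:

```python
import torch

def fft_magnitude(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: magnitude of the 2-D real FFT over the last two dims.
    # torch.fft.rfft2 and Tensor.abs are both differentiable, so autograd
    # computes the gradient; no custom Function is required.
    return torch.fft.rfft2(x).abs()

x = torch.randn(8, 8, dtype=torch.float64, requires_grad=True)
y = fft_magnitude(x)
print(y.shape)  # torch.Size([8, 5]): rfft2 keeps only n // 2 + 1 columns
y.backward(torch.randn_like(y))
print(x.grad.shape)  # torch.Size([8, 8]): gradient has the input's shape
```

Note that |z| is not differentiable where z = 0, so the gradient is only meaningful where the Fourier coefficients are nonzero.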

Upvotes: 1

robintibor

Reputation: 1010

I think the errors are as follows. First, despite having FFT in its name, the function returns only the amplitudes (absolute values) of the FFT output, not the full complex coefficients. Second, simply applying the inverse FFT to the incoming gradient probably doesn't make much sense mathematically as the gradient of those amplitudes (?).
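The first point can be checked numerically: a circular shift of the input changes only the phase of its Fourier coefficients, so the magnitude returned by BadFFTFunction's forward cannot tell the two inputs apart, and inverting the magnitude alone does not give the input back. A small NumPy sketch (the random seed and shift amount are arbitrary choices):

```python
import numpy as np
from numpy.fft import rfft2, irfft2

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
# Circular shift along the last axis: multiplies each coefficient by a phase factor
x_shifted = np.roll(x, shift=3, axis=1)

mag = np.abs(rfft2(x))
mag_shifted = np.abs(rfft2(x_shifted))
print(np.allclose(mag, mag_shifted))  # True: the amplitudes are identical

# Treating the amplitudes as a spectrum and inverting it does not recover x,
# because the phase information was discarded by abs()
x_rec = irfft2(mag, s=x.shape)
print(np.allclose(x_rec, x))
```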

There is a package called pytorch-fft that tries to make an FFT function available in PyTorch. You can see some experimental code for autograd functionality here. Also note the discussion in this issue.
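For the second point, the gradient of the amplitudes can be derived by hand. For y = |z| with z = F x (F the 2-D DFT) and a real upstream gradient g, the gradient with respect to a real input is Re(F^H (g * z / |z|)), and with NumPy's default normalization F^H w equals (m * n) * ifft2(w) for an m-by-n array. Below is a minimal sketch of a fixed NumPy-based Function under these assumptions: the class name FFTMagnitude and helper correct_fft_magnitude are hypothetical, inputs are real 2-D CPU tensors, and the full fft2 is swapped in for rfft2 to keep the derivation simple:

```python
import numpy as np
import torch
from torch.autograd import Function


class FFTMagnitude(Function):
    """Hypothetical fixed layer: y = |fft2(x)| for a real 2-D CPU tensor x."""

    @staticmethod
    def forward(ctx, input):
        z = np.fft.fft2(input.detach().cpu().numpy())
        ctx.z = z  # keep the complex spectrum for the backward pass
        return torch.as_tensor(np.abs(z), dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        z = ctx.z
        g = grad_output.detach().cpu().numpy()
        mag = np.abs(z)
        # Unit phase factor z/|z|; |.| is not differentiable at 0, so use 0 there
        phase = np.divide(z, mag, out=np.zeros_like(z), where=mag > 0)
        m, n = z.shape[-2:]
        # F^H w == (m * n) * ifft2(w) with NumPy's default FFT normalization
        grad = np.real(np.fft.ifft2(g * phase)) * (m * n)
        return torch.as_tensor(grad, dtype=grad_output.dtype)


def correct_fft_magnitude(x):
    return FFTMagnitude.apply(x)


x = torch.randn(8, 8, dtype=torch.float64, requires_grad=True)
# gradcheck returns True when the analytic and numerical gradients agree
print(torch.autograd.gradcheck(correct_fft_magnitude, (x,)))
```

Unlike the inverse-FFT "gradient", this backward pass uses the saved phase of the forward spectrum, which is exactly the information the amplitudes alone lack.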

Upvotes: 2
