I was recently introduced to PyTorch and began running through the library's documentation and tutorials.
In the "Creating extensions using numpy and scipy" tutorial, under "Parameter-less example", a sample function called `BadFFTFunction` is created using NumPy.
The description for the function states:
"This layer doesn’t particularly do anything useful or mathematically correct.
It is aptly named BadFFTFunction"
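The function and its usage are given as follows (shown here with the imports it needs and with the static-method `Function` API, which current PyTorch versions require):

```python
import torch
from torch.autograd import Function
from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    @staticmethod
    def forward(ctx, input):
        # Compute in NumPy, then wrap the result back into a tensor
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return torch.from_numpy(result).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go)
        return torch.from_numpy(result).to(grad_output.dtype)


def incorrect_fft(input):
    return BadFFTFunction.apply(input)


input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result.detach())
result.backward(torch.randn(result.size()))
print(input.grad)
```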
Unfortunately, I am also new to signal processing and am unsure where the (likely obvious) error in this function is.
How might one fix this function so that its forward and backward outputs are correct? In other words, how can `BadFFTFunction` be fixed so that a differentiable FFT function can be used in PyTorch?
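I think the errors are as follows. First, despite having FFT in its name, the function returns only the amplitudes (absolute values) of the FFT output, not the full complex coefficients. Second, applying the inverse FFT to the incoming gradient does not give the gradient of those amplitudes: it ignores the derivative of the absolute value (which depends on the phase of each coefficient), and `irfft2` is not the adjoint (transpose) of `rfft2`, which is what backpropagating through a linear map requires.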
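To make those two points concrete, here is a minimal sketch (not taken from the tutorial) that keeps the magnitude-only output but gives it a mathematically consistent backward pass. It assumes a recent PyTorch with the static-method `Function` API, NumPy, and real 2-D CPU input; the names `MagnitudeFFTFunction` and `magnitude_fft` are hypothetical. For z = rfft2(x) and y = |z|, the chain rule gives grad_x = Re(F^H(g * z/|z|)), where g is the incoming gradient and F^H is the unnormalized inverse DFT applied to the half spectrum zero-padded back to full width.

```python
import numpy as np
import torch
from torch.autograd import Function


class MagnitudeFFTFunction(Function):
    """|rfft2(x)| of a real 2-D CPU tensor, with an explicit backward pass (hypothetical)."""

    @staticmethod
    def forward(ctx, input):
        x = input.detach().cpu().numpy()
        z = np.fft.rfft2(x)          # complex half spectrum, shape (M, N // 2 + 1)
        mag = np.abs(z)              # the layer's output: magnitudes only
        ctx.z, ctx.mag, ctx.in_shape = z, mag, x.shape
        return torch.from_numpy(mag).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        g = grad_output.detach().cpu().numpy()
        # Chain rule through |z|: multiply by the unit phase z / |z|
        # (set to 0 where |z| == 0, where the magnitude is not differentiable)
        phase = np.divide(ctx.z, ctx.mag, out=np.zeros_like(ctx.z), where=ctx.mag > 0)
        g_z = g * phase
        # Adjoint of rfft2 for real input: zero-pad the half spectrum to full
        # width, apply the *unnormalized* inverse DFT, keep the real part
        m, n = ctx.in_shape
        grad_x = np.real(np.fft.ifft2(g_z, s=(m, n))) * (m * n)
        return torch.from_numpy(grad_x).to(grad_output.dtype)


def magnitude_fft(input):
    return MagnitudeFFTFunction.apply(input)


if __name__ == "__main__":
    x = torch.randn(8, 8, dtype=torch.float64, requires_grad=True)
    # gradcheck compares the hand-written backward against finite differences
    # and raises an error if the two Jacobians disagree
    print(torch.autograd.gradcheck(magnitude_fft, (x,)))
```

Running the same `gradcheck` on the tutorial's `irfft2`-based backward should fail, which is a quick way to confirm that its gradient is not the true one.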
There is a package called pytorch-fft that aims to make an FFT function available in PyTorch. You can see some experimental code for autograd functionality here. Also note the discussion in this issue.
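Recent PyTorch releases (1.8 and later) also include a native `torch.fft` module whose functions support autograd, so a custom `Function` may not be needed at all. A minimal sketch, assuming such a version; the helper names `magnitude_fft` and `fft_roundtrip` are hypothetical. The first reproduces the tutorial's magnitude layer with a correct gradient; the second keeps the full complex spectrum, so `torch.fft.irfft2` (given the original signal size via `s`) recovers the input.

```python
import torch


def magnitude_fft(x: torch.Tensor) -> torch.Tensor:
    # Autograd differentiates through both rfft2 and abs; no custom Function needed
    return torch.fft.rfft2(x).abs()


def fft_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # Keep the full complex spectrum; irfft2 with s=x.shape[-2:] recovers x
    spectrum = torch.fft.rfft2(x)
    return torch.fft.irfft2(spectrum, s=x.shape[-2:])


x = torch.randn(8, 8, requires_grad=True)
mag = magnitude_fft(x)              # shape (8, 5): rfft2 keeps N // 2 + 1 columns
mag.backward(torch.randn(mag.shape))
print(x.grad.shape)                 # torch.Size([8, 8])
```

Passing `s` matters for odd widths, because the half spectrum alone does not tell `irfft2` whether the original last dimension was even or odd.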