DeepQuantum

Reputation: 241

Misalignment in 2-dimensional CUDA-FFTShift Function

I'm new to CUDA programming and need to perform an fft-shift operation on a flattened two-dimensional array. I've done some searching and come across this library, but I have so far been unable to make it work even after numerous attempts; the output is always misaligned.

I decided to write the in-place 2D fftshift function in Python, and there it works perfectly. I cannot figure out what I'm doing wrong in the CUDA version. I'm not an expert on the fftshift function in general, but the fact that the same logic works in Python is quite confusing to me.

#include <cufft.h>
#include <cuda_runtime.h>
#include <thrust/device_vector.h>

#include <cmath>
#include <cstddef>
#include <vector>

template <typename T>
__global__ void cufftShift_2D_kernel(T* array, int N)
{
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    int index = y * N + x;

    // offsetA jumps N/2 rows down and N/2 columns right (swaps top-left with bottom-right);
    // offsetB jumps N/2 rows down and N/2 columns left (swaps top-right with bottom-left).
    int offsetA = (N * N + N) / 2;
    int offsetB = (N * N - N) / 2;

    T temp;

    if (x < N / 2) {
        if (y < N / 2) {
            temp = array[index];
            array[index] = array[index + offsetA];
            array[index + offsetA] = temp;
        }
    }
    else if (y < N / 2) {
        temp = array[index];
        array[index] = array[index + offsetB];
        array[index + offsetB] = temp;
    }
}
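
For reference, my understanding is that a kernel which derives both an x and a y coordinate from blockIdx/threadIdx, as this one does, expects a two-dimensional grid, and that in a <<<A, B>>> launch the first argument specifies the grid (number of blocks) and the second the block (threads per block). A minimal sketch of such a launch; the 16x16 tile size is an arbitrary choice of mine, and device_array stands in for a real device pointer:

const int tile = 16;
dim3 threads2D(tile, tile);                                  // 16 x 16 = 256 threads per block
dim3 blocks2D((N + tile - 1) / tile, (N + tile - 1) / tile); // enough blocks to cover N x N
cufftShift_2D_kernel<<<blocks2D, threads2D>>>(device_array, N); // grid first, then block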

// Owns a device allocation via thrust and keeps the raw pointer handy for cuFFT calls.
template <class T>
struct GPUBuffer {
    thrust::device_vector<T> buffer;
    T* p;
    std::size_t mem_size;

    GPUBuffer() = delete;

    GPUBuffer(std::size_t size) :
        buffer(size),
        p(thrust::raw_pointer_cast(buffer.data())),
        mem_size(buffer.size() * sizeof(T)) {}
};
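
Side note: p is captured once in the constructor, so copying or moving a GPUBuffer would leave the copy's p pointing into the original object's storage. A sketch of a copy-safe variant (the name GPUBufferSafe is mine, not from my program), which derives the pointer on demand instead of caching it:

template <class T>
struct GPUBufferSafe {
    thrust::device_vector<T> buffer;

    GPUBufferSafe() = delete;
    explicit GPUBufferSafe(std::size_t size) : buffer(size) {}

    // Recompute the raw device pointer from the vector each call,
    // so the struct stays valid after copies and moves.
    T* p() { return thrust::raw_pointer_cast(buffer.data()); }
    std::size_t mem_size() const { return buffer.size() * sizeof(T); }
};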

// Pointwise complex multiply: field[idx] *= kernel[idx].
__global__ void multiplyBuffers(cufftComplex *kernel, cufftComplex *field, int size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        cufftComplex x = kernel[idx];
        cufftComplex y = field[idx];

        field[idx].x = x.x * y.x - x.y * y.y;
        field[idx].y = x.x * y.y + x.y * y.x; 
    }
}
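
Worth noting while reading the main function below: cuFFT transforms are unnormalized, so a CUFFT_FORWARD followed by a CUFFT_INVERSE scales every element by the number of elements (N * N for a 2D N x N transform). A sketch of a rescaling kernel, which is not part of my program, just here for completeness:

__global__ void normalizeBuffer(cufftComplex *data, int size, float factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        // factor would be 1.0f / (N * N) after a 2D N x N inverse transform
        data[idx].x *= factor;
        data[idx].y *= factor;
    }
}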

int main()
{
    const int field_size = 32 * scale;

    std::vector<float> field(field_size * field_size, 0.f);

    const int threadsPerBlock = 256;
    const int blocksPerGrid = (field.size() + threadsPerBlock - 1) / threadsPerBlock;

    cufftHandle normal, inv;

    cufftPlan2d(&normal, field_size, field_size, CUFFT_C2C);
    cufftPlan2d(&inv, field_size, field_size, CUFFT_C2C);

    GPUBuffer<cufftComplex> kernel_gpu(field.size());
    GPUBuffer<cufftComplex> field_gpu(field.size());
    GPUBuffer<cufftComplex> shift_gpu(field.size());

    std::vector<cufftComplex> host_output(field.size());

    std::vector<float> real_output(field.size());

    for (size_t i = 0; i < field.size(); i++)
    {
        field_gpu.buffer[i] = { static_cast<float>(i), 0.f };
        kernel_gpu.buffer[i] = { static_cast<float>(i), 0.f };
    }

    cufftExecC2C(normal, kernel_gpu.p, kernel_gpu.p, CUFFT_FORWARD);

    for (int i = 0; i < 1; ++i) {
        cufftExecC2C(normal, field_gpu.p, field_gpu.p, CUFFT_FORWARD);
        multiplyBuffers<<<threadsPerBlock, blocksPerGrid>>>(kernel_gpu.p, field_gpu.p, field.size());
        cufftExecC2C(inv, field_gpu.p, field_gpu.p, CUFFT_INVERSE);
        cudaMemcpy(shift_gpu.p, field_gpu.p, field_gpu.mem_size, cudaMemcpyDeviceToDevice);
        cufftShift_2D_kernel<<<threadsPerBlock, blocksPerGrid>>>(shift_gpu.p, field_size);
    }

    cudaMemcpy(host_output.data(), field_gpu.p, field_gpu.mem_size, cudaMemcpyDeviceToHost);

    for (size_t i = 0; i < field.size(); ++i) {
        real_output[i] = std::sqrt(host_output[i].x * host_output[i].x + host_output[i].y * host_output[i].y);
    }

    dump_array_to_file(real_output, field_size, field_size, "cuda_inv.txt");

    cudaMemcpy(host_output.data(), shift_gpu.p, shift_gpu.mem_size, cudaMemcpyDeviceToHost);

    for (size_t i = 0; i < field.size(); ++i) {
        real_output[i] = std::sqrt(host_output[i].x * host_output[i].x + host_output[i].y * host_output[i].y);
    }

    dump_array_to_file(real_output, field_size, field_size, "cuda_shifted.txt");

    cufftDestroy(normal);
    cufftDestroy(inv);
    return 0;
}
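
None of the CUDA or cuFFT return codes are checked above, which makes this hard to debug: every cuFFT call returns a cufftResult, every runtime call returns a cudaError_t, and kernel launch failures surface through cudaGetLastError / cudaDeviceSynchronize. A minimal checking sketch; the macro names CHECK_CUDA and CHECK_CUFFT are made up for illustration:

#include <cstdio>

#define CHECK_CUDA(call)                                                   \
    do {                                                                   \
        cudaError_t err = (call);                                          \
        if (err != cudaSuccess)                                            \
            std::printf("CUDA error: %s\n", cudaGetErrorString(err));      \
    } while (0)

#define CHECK_CUFFT(call)                                                  \
    do {                                                                   \
        cufftResult res = (call);                                          \
        if (res != CUFFT_SUCCESS)                                          \
            std::printf("cuFFT error: %d\n", static_cast<int>(res));       \
    } while (0)

// Usage:
// CHECK_CUFFT(cufftExecC2C(normal, field_gpu.p, field_gpu.p, CUFFT_FORWARD));
// someKernel<<<blocks, threads>>>(...);
// CHECK_CUDA(cudaGetLastError());
// CHECK_CUDA(cudaDeviceSynchronize());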

Here's the equivalent Python version that produces correct output:

import numpy as np
from numpy.typing import NDArray


def shift(array: NDArray):
    N = np.int32(np.sqrt(array.size))
    offsetA = (N * N + N) // 2
    offsetB = (N * N - N) // 2
    for (y, x), _ in np.ndenumerate(array.reshape((N, N))):
        idx = y * N + x
        if x < N // 2:
            if y < N // 2:
                array[idx], array[idx + offsetA] = array[idx + offsetA], array[idx]
        elif y < N // 2:
            array[idx], array[idx + offsetB] = array[idx + offsetB], array[idx]
    return array.reshape((N, N))

And what it looks like: [image: FFT-Results]

Upvotes: 1

Views: 119

Answers (0)
