czajah
czajah

Reputation: 439

Can I use FFTW to transform along the first and last axes of a 3D array?

Using fftw_plan_many_dft I can do transforms along the x,y and the y,z axes:

/// Forward 2D DFT over the (y, z) axes of a 3D array.
///
/// The input is processed as N_X consecutive slabs of N_Y*N_Z elements
/// (unit stride inside a slab, distance N_Y*N_Z between slabs); each slab is
/// transformed independently as an N_Y x N_Z 2D DFT.
///
/// @param input  samples to transform; expected to hold N_X*N_Y*N_Z elements
/// @param N_X    number of slabs (batch count)
/// @param N_Y    first transformed dimension
/// @param N_Z    second transformed dimension
/// @return a new vector of the same size holding the unnormalised transform
vector<complex<float_type>> yz_fft(vector<complex<float_type>> input, int N_X, int N_Y, int N_Z){
    vector<complex<float_type>> result(input.size());
    int rank = 2;
    int n[] = {N_Y,N_Z};
    int *inembed = n;
    int *onembed = n;
    int istride = 1;
    int ostride = 1;
    int idist = N_Y*N_Z;
    int odist = N_Y*N_Z;
    int howmany = N_X;
    fftw_plan plan = fftw_plan_many_dft(
            rank,
            n,
            howmany,
            reinterpret_cast<fftw_complex *>(input.data()),
            inembed,
            istride,
            idist,
            reinterpret_cast<fftw_complex *>(result.data()),
            onembed,
            ostride,
            odist,
            FFTW_FORWARD,
            FFTW_ESTIMATE);
    fftw_execute(plan);
    // A plan owns FFTW-internal resources; release it once the one-shot
    // transform has run, otherwise every call leaks a plan.
    fftw_destroy_plan(plan);
    return result;
}

/// Forward 2D DFT over the (x, y) axes of a 3D array.
///
/// Each of the N_Z transforms reads an N_X x N_Y grid whose consecutive
/// elements are N_Z apart (stride N_Z); successive transforms start one
/// element further on (distance 1).
///
/// @param input  samples to transform; expected to hold N_X*N_Y*N_Z elements
/// @param N_X    first transformed dimension
/// @param N_Y    second transformed dimension
/// @param N_Z    number of independent transforms (batch count)
/// @return a new vector of the same size holding the unnormalised transform
vector<complex<float_type>> xy_fft(vector<complex<float_type>> input, int N_X, int N_Y, int N_Z){
    vector<complex<float_type>> result(input.size());
    int rank = 2;
    int n[] = {N_X,N_Y};
    int *inembed = n;
    int *onembed = n;
    int istride = N_Z;
    int ostride = N_Z;
    int idist = 1;
    int odist = 1;
    int howmany = N_Z;
    fftw_plan plan = fftw_plan_many_dft(
            rank,
            n,
            howmany,
            reinterpret_cast<fftw_complex *>(input.data()),
            inembed,
            istride,
            idist,
            reinterpret_cast<fftw_complex *>(result.data()),
            onembed,
            ostride,
            odist,
            FFTW_FORWARD,
            FFTW_ESTIMATE);
    fftw_execute(plan);
    // A plan owns FFTW-internal resources; release it once the one-shot
    // transform has run, otherwise every call leaks a plan.
    fftw_destroy_plan(plan);
    return result;
}

but I can't figure out how to do the x,z transform. How do I do this?

Upvotes: 2

Views: 305

Answers (1)

czajah
czajah

Reputation: 439

So there is a way to use fftw_plan_many_dft to do the xz transform. The downvotes may suggest that people are not interested in this, but I decided to share it anyway. For the solution, see struct xz_fft_many below.

#include <algorithm>
#include <cassert>
#include <complex>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

#include <fftw3.h>

#include <benchmark/benchmark.h>


using namespace std;

using float_type = double;
using index_type = unsigned long;

/// Builds a test signal of N complex samples whose real parts count up
/// 0, 1, 2, ... and whose imaginary parts are zero.
vector<complex<float_type>> get_data(index_type N){
    vector<complex<float_type>> samples;
    samples.reserve(N);

    // The counter is an int, exactly like the seed handed to iota originally.
    for (int value = 0; samples.size() != N; ++value) {
        samples.emplace_back(value);
    }

    return samples;
}

void print(vector<complex<float_type>> data,index_type N_X,index_type N_Y,index_type N_Z){
    for(int i=0; i!=N_X; ++i){
        for(int j=0; j!=N_Y; ++j){
            for(int k=0; k!=N_Z; ++k){
                index_type idx = i*(N_Y*N_Z)+j*N_Z+k;
                cout<<"[ "<<i<<", "<<j<<", "<<k<<" ] = "<<data.data()[idx]<<endl;
            }
        }
    }
}

struct x_fft {
    vector<complex<float_type>>& data;
    vector<complex<float_type>> result;
    fftw_plan fft_plan;
    index_type N_X;
    index_type N_Y;
    index_type N_Z;


    x_fft(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
            : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
    {
        result = vector<complex<float_type>>(data.size());
        int rank = 1;
        int n[] = {(int)N_X};
        int *inembed = n;
        int *onembed = n;
        int istride = N_Y*N_Z;
        int ostride = istride;
        int idist = 1;
        int odist = idist;
        int howmany = N_Y*N_Z;
        fft_plan = fftw_plan_many_dft(
                rank,
                n,
                howmany,
                reinterpret_cast<fftw_complex *>(data.data()),
                inembed,
                istride,
                idist,
                reinterpret_cast<fftw_complex *>(result.data()),
                onembed,
                ostride,
                odist,
                FFTW_FORWARD,
                FFTW_MEASURE);
    }

    const vector<complex<float_type>> &getResult() const {
        return result;
    }

    vector<complex<float_type>>& run(){
        fftw_execute(fft_plan);
        return result;
    }

};

struct z_fft {
    vector<complex<float_type>>& data;
    vector<complex<float_type>> result;
    fftw_plan fft_plan;
    index_type N_X;
    index_type N_Y;
    index_type N_Z;


    z_fft(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
            : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
    {
        result = vector<complex<float_type>>(data.size());
        int rank = 1;
        int n[] = {(int)N_Z};
        int *inembed = n;
        int *onembed = n;
        int istride = 1;
        int ostride = istride;
        int idist = N_Z;
        int odist = idist;
        int howmany = N_X*N_Y;
        fft_plan = fftw_plan_many_dft(
                rank,
                n,
                howmany,
                reinterpret_cast<fftw_complex *>(data.data()),
                inembed,
                istride,
                idist,
                reinterpret_cast<fftw_complex *>(result.data()),
                onembed,
                ostride,
                odist,
                FFTW_FORWARD,
                FFTW_MEASURE);
    }

    vector<complex<float_type>>& run(){
        fftw_execute(fft_plan);
        return result;
    }

};


struct xz_fft_many {
    vector<complex<float_type>>& data;
    vector<complex<float_type>> result;
    fftw_plan fft_plan;
    index_type N_X;
    index_type N_Y;
    index_type N_Z;


    xz_fft_many(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
            : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
    {
        result = vector<complex<float_type>>(data.size());
        int rank = 2;
        int n[] = {(int) N_X, (int) N_Z};
        int inembed[] = {(int) N_X, (int) (N_Z*N_Y)};
        int *onembed = inembed;
        int istride = 1;
        int ostride = 1;
        int idist = N_Z;
        int odist = N_Z;
        int howmany = N_Y;
        fft_plan = fftw_plan_many_dft(
                rank,
                n,
                howmany,
                reinterpret_cast<fftw_complex *>(data.data()),
                inembed,
                istride,
                idist,
                reinterpret_cast<fftw_complex *>(result.data()),
                onembed,
                ostride,
                odist,
                FFTW_FORWARD,FFTW_MEASURE);
    }

    vector<complex<float_type>>& run(){
        fftw_execute(fft_plan);
        return result;
    }

};

struct xz_fft_composition {
    vector<complex<float_type>>& data;
    index_type N_X;
    index_type N_Y;
    index_type N_Z;
    x_fft* xFft;
    z_fft* zFft;


    xz_fft_composition(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
            : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
    {
        xFft = new x_fft(data,N_X,N_Y,N_Z);
        zFft = new z_fft(xFft->result,N_X,N_Y,N_Z);
    }

    vector<complex<float_type>>& run(){
        xFft->run();
        return zFft->run();
    }

};

struct TestData{
    index_type N_X = 512;
    index_type N_Y = 16;
    index_type N_Z = 16;

    index_type ARRAY_SIZE = N_X * N_Y * N_Z;
    std::vector<complex<float_type>> data = get_data(ARRAY_SIZE);

    TestData() {
//        print(data,N_X,N_Y,N_Z);
    }
};

TestData testData;

struct SanityTest{
    SanityTest() {
        xz_fft_many fft_many(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
        xz_fft_composition fft_composition(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
        std::vector<complex<float_type>> fft_many_result =  fft_many.run();
        std::vector<complex<float_type>> fft_composition_result =  fft_composition.run();

        bool equal = std::equal(fft_composition_result.begin(), fft_composition_result.end(), fft_many_result.begin());
        assert(equal);
        if(equal){
            cout << "ok" << endl;
        }
    }
};

SanityTest sanityTest;

/// Benchmarks one execution of the single-plan (x, z) transform per iteration.
/// Planning happens once, outside the timed loop.
static void XZ_test_many(benchmark::State& state) {
    xz_fft_many fft(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
    for (auto _ : state) {
        // Bind by reference: copying the result vector here would be timed
        // together with the transform.
        auto& result = fft.run();
        benchmark::DoNotOptimize(result.data());
        benchmark::ClobberMemory();
    }
}

/// Benchmarks one execution of the two-stage (x then z) transform per
/// iteration. Planning happens once, outside the timed loop.
static void XZ_test_composition(benchmark::State& state) {
    xz_fft_composition fft(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
    for (auto _ : state) {
        // Bind by reference: copying the result vector here would be timed
        // together with the transform.
        auto& result = fft.run();
        benchmark::DoNotOptimize(result.data());
        benchmark::ClobberMemory();
    }
}

// Register both benchmarks, each with a fixed count of 1000 iterations.
BENCHMARK(XZ_test_many)->Iterations(1000);
BENCHMARK(XZ_test_composition)->Iterations(1000);

// Google Benchmark macro that supplies the program's main().
BENCHMARK_MAIN();

If I have done the benchmarks correctly, there are some significant differences between the fftw_plan_many_dft and composition approaches for different N_X, N_Y, N_Z combinations. For example, using

    index_type N_X = 512;
    index_type N_Y = 16;
    index_type N_Z = 16;

I got almost a twofold difference in favour of fftw_plan_many_dft, but for other sets of input parameters I often found the composition approach to be faster, though not by much.

------------------------------------------------------------------------------
Benchmark                                    Time             CPU   Iterations
------------------------------------------------------------------------------
XZ_test_many/iterations:1000           1412647 ns      1364813 ns         1000
XZ_test_composition/iterations:1000    2619807 ns      2542472 ns         1000

Upvotes: 3

Related Questions