aiwyn

How to use thrust::transform on a larger vector derived from a smaller vector?

Input and starting arrays:

dv_A = { 5, -3, 2, 6 }  // 4 elements
dv_B = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }  // 12 elements

Expected output:

dv_B = { 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1 }

For every element in dv_A, there are dv_A.size() - 1 elements in dv_B. This is because each element of dv_A should have a child element in dv_B for each of the other dv_A elements (i.e. excluding itself). Therefore, if dv_A has 4 elements, each of them has 3 children in dv_B, so dv_B has 4 * 3 = 12 elements.

I want to set each dv_B element to 1 if its corresponding dv_A element is > 0, and to 0 otherwise. Correspondence is determined by the position of the element in dv_B. For example:

The first 3 dv_B values are determined by dv_A[0], the next 3 by dv_A[1], and so on. In general, dv_B[i] corresponds to dv_A[i / (dv_A.size() - 1)] (integer division).

Here's my attempt so far:

thrust::transform(
    dv_B.begin(),
    dv_B.end(),
    thrust::make_transform_iterator(
        dv_A.begin(),
        _1 % dv_A
    ), 
    dv_B.begin(),
    _2 > 0 // When the 2nd argument is greater than 0, set the position in dv_B to 1.
);

Answers (2)

paleonix

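Packing the fancy-iterator creation into an appropriately named factory function makes this version quite readable as well. make_repeat_it_begin chains a counting iterator, a transform iterator that divides each index by the repetition count, and a permutation iterator that reads dv_A at the resulting index, so each dv_A element is repeated size_A - 1 times. This solution may be the more elegant one, especially if you need the pattern more than once.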

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

#include <iostream>
#include <iterator>  // std::ostream_iterator

// nvcc doesn't seem to like __device__ or __host__ __device__ lambdas in auto
// return types, so I defined this functor instead
template <typename T>
class Div {
    T div_{};
    public:
    Div(T div) : div_{div} {}
    __host__ __device__ T operator()(T in) const noexcept { return in / div_; }
};

// I stole "repeat" from numpy. Another version using modulo (%) and therefore
// repeating the whole input instead of single elements would be called "tile".
template <class It>
auto make_repeat_it_begin(It input, int repetitions) {
    using diff_t = typename It::difference_type;
    return thrust::make_permutation_iterator(
                input,
                thrust::make_transform_iterator(
                    thrust::make_counting_iterator(diff_t{0}),
                    Div{static_cast<diff_t>(repetitions)}));
}

int main() {
    int A[] = {5, -3, 2, 6};
    constexpr int size_A = sizeof(A) / sizeof(A[0]);

    thrust::host_vector<int> hv_A(A, A + size_A);
    thrust::device_vector<int> dv_A(hv_A);
    thrust::device_vector<int> dv_B(size_A * (size_A - 1));

    // Each dv_A element is repeated (size_A - 1) times: 5,5,5,-3,-3,-3,...
    auto A_repeat_it = make_repeat_it_begin(dv_A.begin(), size_A - 1);
    
    thrust::transform(A_repeat_it, A_repeat_it + dv_B.size(), 
                      dv_B.begin(),
                      [] __device__ (int a){ return a > 0 ? 1 : 0; });

    thrust::host_vector<int> hv_B(dv_B);
    // Prints 1,1,1,0,0,0,1,1,1,1,1,1,
    thrust::copy(hv_B.begin(), hv_B.end(),
                 std::ostream_iterator<int>(std::cout, ","));
    std::cout << '\n';
}

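Because of the __device__ lambda, nvcc needs the --extended-lambda flag. The code also uses class template argument deduction (Div{...}), which requires C++17 (e.g. -std=c++17).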

Abator Abetor

The serial code could look something like this:

for(int i = 0; i < dv_b.size(); i++){
    const int readIndex = i / (dv_a.size() - 1);
    if(dv_a[readIndex] > 0) dv_b[i] = 1;
    else dv_b[i] = 0;
}

which maps directly onto thrust::for_each over a counting iterator. I think this makes the code clearer than using transform together with various fancy iterators.

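#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

// The __device__ lambda requires compiling with nvcc --extended-lambda.
thrust::for_each(
    thrust::device,
    thrust::make_counting_iterator(0),
    thrust::make_counting_iterator(0) + dv_b.size(),
    [
     s = dv_a.size() - 1,                           // children per dv_a element
     // device_vector can't be used in device code, so capture raw pointers
     dv_a = thrust::raw_pointer_cast(dv_a.data()),
     dv_b = thrust::raw_pointer_cast(dv_b.data())
    ] __device__ (int i){
        const int readIndex = i / s;  // index of the owning dv_a element
        if(dv_a[readIndex] > 0) dv_b[i] = 1;
        else dv_b[i] = 0;
    }
);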
