Reputation: 1526
I am trying to implement a fixed-size version of Thrust's device vector. I coded some initial version but I am getting a curious template error.
Here is the code:
#include <iostream>
#include <array>
enum class memcpy_t {
host_to_host,
host_to_device,
device_to_host,
device_to_device
};
template <typename T, std::size_t N>
struct cuda_allocator {
using pointer = T*;
static void allocate(T *dev_mem) {
cudaMalloc(&dev_mem, N * sizeof(T));
}
static void deallocate(T *dev_mem) {
cudaFree(dev_mem);
}
template <memcpy_t ct>
static void copy (T *dst, T *src) {
switch(ct) {
case memcpy_t::host_to_host:
cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyHostToHost);
break;
case memcpy_t::host_to_device:
cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyHostToDevice);
break;
case memcpy_t::device_to_host:
cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyDeviceToHost);
break;
case memcpy_t::device_to_device:
cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyDeviceToDevice);
break;
default:
break;
}
}
};
template <typename T, std::size_t N>
struct gpu_array {
using allocator = cuda_allocator<T, N>;
using pointer = typename allocator::pointer;
using value_type = T;
using iterator = T*;
using const_iterator = T const*;
gpu_array() {
allocator::allocate(data);
}
gpu_array(std::array<T, N> host_arr) {
allocator::allocate(data);
allocator::copy<memcpy_t::host_to_device>(data, host_arr.begin());
}
gpu_array& operator=(gpu_array const& o) {
allocator::allocate(data);
allocator::copy<memcpy_t::device_to_device>(data, o.begin());
}
operator std::array<T, N>() {
std::array<T, N> res;
allocator::copy<memcpy_t::device_to_host>(res.begin(), data);
return res;
}
~gpu_array() {
allocator::deallocate(data);
}
__device__ iterator begin() { return data; }
__device__ iterator end() { return data + N; }
__device__ const_iterator begin() const { return data; }
__device__ const_iterator end() const { return data + N; }
private:
T* data;
};
template <typename T, std::size_t N>
__global__ void add_kernel(gpu_array<T,N> &r,
gpu_array<T,N> const&a1,
gpu_array<T,N> const&a2) {
int i = blockIdx.x*blockDim.x + threadIdx.x;
r.begin()[i] = a1.begin()[i] + a2.begin()[i];
}
template <typename T, std::size_t N>
gpu_array<T, N> operator+(gpu_array<T,N> const&a1,
gpu_array<T,N> const&a2)
{
gpu_array<T, N> res;
add_kernel<<<(N+255)/256, 256>>>(res, a1, a2);
return res;
}
const int N = 1<<20;
int main() {
std::array<float, N> x,y;
for (int i = 0; i < N; i++) {
x[i] = 1.0f;
y[i] = 2.0f;
}
gpu_array<float, N> dx{x};
gpu_array<float, N> dy{y};
std::array<float, N> res = dx + dy;
for(const auto& elem : res) {
std::cout << elem << ", ";
}
}
There might be tons of other errors but I am stuck at a curious one. nvcc
gives me the following error:
error: no match for 'operator<' (operand types are '<unresolved overloaded function type>' and 'memcpy_t')
allocator::copy<memcpy_t::host_to_device>(data, host_arr.begin());
For some reason, does it see my enum class template parameter as operator<
? By the way, this is compiled with options -arch=sm_70 -std=c++14
. I am not well-educated on how C++ and CUDA interacts so I could not solve the problem.
Upvotes: 1
Views: 313
Reputation: 72349
It took a bit of head scratching, but the underlying problem here is defective syntax according to the C++ standard. It is the host compiler generating the error, and it is perfectly correct for it to do so, as far as I can see. Refer here for all the gory details.
Your code which uses the specializations of copy
should look like this:
gpu_array(std::array<T, N> host_arr) {
allocator::allocate(data);
allocator::template copy<memcpy_t::host_to_device>(data, host_arr.begin());
}
gpu_array& operator=(gpu_array const& o) {
allocator::allocate(data);
allocator::template copy<memcpy_t::device_to_device>(data, o.begin());
}
operator std::array<T, N>() {
std::array<T, N> res;
allocator::template copy<memcpy_t::device_to_host>(res.begin(), data);
return res;
}
That might be the strangest looking syntax ever, but it is what is required to make the compiler honor <
as a template token and not an operator. Fix that everywhere in your code and this particular compiler error should disappear.
Upvotes: 2