I started using cooperative_groups and often find myself wishing for a method that replaces the second line below.
thread_block_tile<32> tile = tiled_partition<32>(this_thread_block());
int tileId = this_thread_block().thread_rank()/tile.size();
My assumptions here are that:
- tileId is the same for every thread in the same tile.
- tileId goes from 0 to (this_thread_block().size())/tile.size().
I looked into https://devblogs.nvidia.com/cooperative-groups/ and https://docs.nvidia.com/cuda/archive/9.2/cuda-c-programming-guide/index.html#thread-block-tiles-cg. In both sources there is an example similar to:
thread_group tile4 = tiled_partition(this_thread_block(), 4);
if (tile4.thread_rank()==0)
    printf("Hello from tile4 rank 0: %d\n",
           this_thread_block().thread_rank());
that produces:
Hello from tile4 rank 0: 0
Hello from tile4 rank 0: 4
Hello from tile4 rank 0: 8
Hello from tile4 rank 0: 12
Which seems to fit with the assumptions.
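The two assumptions can also be checked directly rather than inferred from the printout. Below is a minimal sketch, assuming a single block whose size is a multiple of 4; the names record_tile_ids and collect_tile_ids are made up for illustration, and error checking is omitted:

```cuda
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <vector>

namespace cg = cooperative_groups;

// Each thread stores the tile id it computes for itself.
__global__ void record_tile_ids(int* tile_ids)
{
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<4> tile = cg::tiled_partition<4>(block);

    // Same formula as in the question: rank in the block divided by tile size.
    int tileId = block.thread_rank() / tile.size();
    tile_ids[block.thread_rank()] = tileId;
}

// Hypothetical host helper: launches one block and copies the ids back.
// Error checking is left out to keep the sketch short.
std::vector<int> collect_tile_ids(int block_size)
{
    int* d_ids = nullptr;
    cudaMalloc(&d_ids, block_size * sizeof(int));
    record_tile_ids<<<1, block_size>>>(d_ids);

    std::vector<int> ids(block_size);
    // A blocking cudaMemcpy on the default stream waits for the kernel.
    cudaMemcpy(ids.data(), d_ids, block_size * sizeof(int),
               cudaMemcpyDeviceToHost);
    cudaFree(d_ids);
    return ids;
}
```

With a block of 16 threads, threads 0 to 3 should report tile 0, threads 4 to 7 tile 1, and so on, which lines up with the ranks 0, 4, 8, 12 printed above.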
I am left with two questions:
tileId?
Example use case:
__device__
int someFkt(thread_block_tile<16> tile, int* data)
{
    // some stuff that works best using 16 threads
    return 0; // placeholder so the function returns a value
}
__global__
void some_kernel(int* data)
{
    thread_block_tile<16> tile = tiled_partition<16>(this_thread_block());
    // index of this thread's tile within the block
    int tileId = this_thread_block().thread_rank()/tile.size();
    // each tile works on its own 16-element slice of data
    int result = someFkt(tile,data+tileId*tile.size());
}
Upvotes: 1
Views: 649
Reputation: 7374
It is correct that tileId goes from 0 up to, but not including, (this_thread_block().size())/32 if the tile size is 32.
tileId is indeed the same for all threads in the same tile. The tileId values also repeat in every block, so each block has tileId 0, 1, ...
Only thread_block provides its own index; it offers the following additional block-specific functionality:
dim3 group_index(); // 3-dimensional block index within the grid
dim3 thread_index(); // 3-dimensional thread index within the block
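Because tileId repeats in every block, a grid-wide tile index has to combine it with the block index from group_index(). Here is a minimal sketch, assuming a 1D grid and 1D blocks whose size is a multiple of 32; global_tile_id and record_global_tile_ids are hypothetical names, not part of cooperative_groups:

```cuda
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Pure index arithmetic, usable on host and device (hypothetical helper).
__host__ __device__ constexpr unsigned global_tile_id(unsigned block_id,
                                                      unsigned tiles_per_block,
                                                      unsigned tile_id_in_block)
{
    return block_id * tiles_per_block + tile_id_in_block;
}

// Every thread writes the grid-wide id of the tile it belongs to.
__global__ void record_global_tile_ids(unsigned* out)
{
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);

    unsigned tileId        = block.thread_rank() / tile.size(); // repeats per block
    unsigned tilesPerBlock = block.size() / tile.size();
    unsigned blockId       = block.group_index().x;             // assumes a 1D grid

    // Assumes a 1D block, so thread_index().x equals thread_rank().
    unsigned globalThread = blockId * block.size() + block.thread_index().x;
    out[globalThread] = global_tile_id(blockId, tilesPerBlock, tileId);
}
```

For 2D or 3D launches the block index would first have to be linearized from all three components of group_index().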
Not sure if this was a typo in your example use case:
int tileId = this_thread_block().thread_rank()/32;
The correct form is:
int tileId = this_thread_block().thread_rank()/16;
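One way to rule out this kind of mismatch is to derive the divisor from the tile itself instead of writing 16 or 32 by hand. The sketch below assumes a single block whose size is a multiple of 16; tile_id and sum_slices are hypothetical names, not part of cooperative_groups, and the reduction only stands in for "some stuff that works best using 16 threads":

```cuda
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Hypothetical helper: index of `tile` within the current thread block,
// computed from the tile's own size so the divisor cannot drift.
template <typename Tile>
__device__ unsigned tile_id(const Tile& tile)
{
    return cg::this_thread_block().thread_rank() / tile.size();
}

// Each 16-thread tile sums its own 16-element slice of `data`.
__global__ void sum_slices(const int* data, int* sums)
{
    cg::thread_block_tile<16> tile =
        cg::tiled_partition<16>(cg::this_thread_block());
    unsigned tileId = tile_id(tile);

    // The slice offset also uses tile.size(), matching the tile width.
    int value = data[tileId * tile.size() + tile.thread_rank()];

    // Tree reduction inside the tile using register shuffles.
    for (unsigned offset = tile.size() / 2; offset > 0; offset /= 2)
        value += tile.shfl_down(value, offset);

    // After the loop, rank 0 of each tile holds the sum of its slice.
    if (tile.thread_rank() == 0)
        sums[tileId] = value;
}
```

Later toolkits (CUDA 11 and up, to my knowledge) also provide tile.meta_group_rank() on tiles, which may make such a helper unnecessary; check the documentation for your version.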
Upvotes: 1