sungjun cho

How can I synchronize threads within a warp in a conditional while loop in CUDA?

Suppose we have the following code:

while (condition) {
  ...

  for (uint32_t gap = x >> 1; gap > 0; gap >>= 1) {
    val += __shfl_down_sync(mask, val, gap);
  }

  if (warpLane == 0)
    atomicAdd(&global_memory[threadIdx.x], val);

  ...
}

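In this scenario, suppose the number of threads of the warp that enter the loop body changes from one iteration to the next: all 32 threads on the first iteration, all 32 threads on the second, and only 16 threads on the third.

How can I get the mask of the threads that participate in the current iteration of the while loop?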

According to the guide at https://devblogs.nvidia.com/using-cuda-warp-level-primitives, the code below may cause undefined behavior:

while (condition) {
  uint32_t active = __activemask();
  for (uint32_t gap = x >> 1; gap > 0; gap >>= 1) {
    val += __shfl_down_sync(active, val, gap);
  }

  if (warpLane == 0)
    atomicAdd(&global_memory[threadIdx.x], val);

  ...
}

According to the guide, __activemask() only reports the threads that happen to be converged when it is called, so it might not return the mask I expect.

The following version, which adds a warp-level synchronization at the end of the loop body, also causes undefined behavior according to the same guide:

while (condition) {
  uint32_t active = __activemask();
  for (uint32_t gap = x >> 1; gap > 0; gap >>= 1) {
    val += __shfl_down_sync(active, val, gap);
  }

  if (warpLane == 0)
    atomicAdd(&global_memory[threadIdx.x], val);

  ...
  __syncwarp(active);
}

How, then, can I get the correct mask?

Answer

Oblivion

You can use Cooperative Groups, for example:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

while (condition) {
  ...
  // Group of the threads active at this point. This line can be moved out
  // of the while loop if the condition does not cause thread divergence.
  cg::coalesced_group active = cg::coalesced_threads();

  for (uint32_t gap = x >> 1; gap > 0; gap >>= 1) {
    // shfl_down indexes by rank inside the group, not by warp lane.
    val += active.shfl_down(val, gap);
  }

  // Warp lane 0 may have left the loop, so let the group's first rank write.
  if (active.thread_rank() == 0)
    atomicAdd(&global_memory[threadIdx.x], val);

  ...
}
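A minimal sketch of the same idea using cg::reduce (from <cooperative_groups/reduce.h>, available since CUDA 11), which, unlike a fixed halving shfl_down loop, gives the correct sum for a coalesced group of any size. The kernel countPerIteration, the helper runCountPerIteration and the per-lane iteration counts are hypothetical, chosen to reproduce the 32, 32, 16 pattern from the question. Keep in mind that coalesced_threads(), like __activemask(), in effect captures only the threads converged at that call; because each group's leader adds its group's count with atomicAdd, the per-iteration totals still come out right even if the warp happens to be split into several groups.

```cuda
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// Hypothetical demo kernel: lanes 0-15 run 3 iterations and lanes 16-31
// run 2, so the loop body is entered by 32, 32 and then 16 threads.
__global__ void countPerIteration(int *perIter) {
  int lane = threadIdx.x % 32;
  int iters = (lane < 16) ? 3 : 2;
  int it = 0;
  while (it < iters) {
    // Threads that are converged at this point of this iteration.
    cg::coalesced_group active = cg::coalesced_threads();
    // Each thread contributes 1, so n is the size of the group.
    // cg::reduce is correct for any group size, not only powers of two.
    int one = 1;
    int n = cg::reduce(active, one, cg::plus<int>());
    // One writer per group; atomicAdd combines partial counts if the warp
    // happened to be split into several coalesced groups.
    if (active.thread_rank() == 0)
      atomicAdd(&perIter[it], n);
    ++it;
  }
}

// Hypothetical host helper: launches one warp, copies back the three
// per-iteration counts, and returns false on a CUDA error.
bool runCountPerIteration(int out[3]) {
  int *d = nullptr;
  if (cudaMalloc(&d, 3 * sizeof(int)) != cudaSuccess) return false;
  cudaMemset(d, 0, 3 * sizeof(int));
  countPerIteration<<<1, 32>>>(d);
  if (cudaGetLastError() != cudaSuccess) { cudaFree(d); return false; }
  // cudaMemcpy waits for the kernel to finish before copying.
  cudaError_t err = cudaMemcpy(out, d, 3 * sizeof(int), cudaMemcpyDeviceToHost);
  cudaFree(d);
  return err == cudaSuccess;
}
```

Calling runCountPerIteration(out) on a host array of three ints should leave out = {32, 32, 16}.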

If you want to generate the mask yourself, the old-fashioned way, you can use:

uint32_t FullMask = 0xFFFFFFFF;
uint32_t mask =  __ballot_sync(FullMask, someCondition);
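A fuller sketch of the ballot approach inside the loop, assuming the loop condition can be re-evaluated at the end of each iteration: the mask is computed by all 32 lanes before the loop and then refreshed with __ballot_sync(mask, cond) as the last statement of each iteration, where every lane of the current mask is guaranteed to be present. The names maskedSum, ballotLoop and runBallotLoop and the per-lane iteration counts are hypothetical. The reduction deliberately shuffles only from lanes inside the mask: a __shfl_down_sync whose source lane is inactive returns an undefined value, which is what the gap-halving loop in the question runs into once only 16 lanes remain and a gap reaches past them.

```cuda
#include <cuda_runtime.h>

constexpr unsigned FULL_MASK = 0xFFFFFFFFu;

// Sum val over exactly the lanes in mask. Every lane in mask must call this
// with the same mask; each shuffle reads from a lane inside mask, so no
// inactive lane is ever used as a source.
__device__ int maskedSum(unsigned mask, int val) {
  int total = 0;
  for (unsigned m = mask; m != 0; m &= m - 1) {
    int src = __ffs(m) - 1;  // lowest lane still left in m
    total += __shfl_sync(mask, val, src);
  }
  return total;
}

// Hypothetical demo kernel: lane i contributes i + 1 per iteration; lanes
// 0-15 run 3 iterations and lanes 16-31 run 2 (the 32, 32, 16 pattern).
__global__ void ballotLoop(int *perIter) {
  int lane = threadIdx.x % 32;
  int iters = (lane < 16) ? 3 : 2;
  int it = 0;
  bool cond = it < iters;
  // All 32 lanes execute this ballot before any lane can leave the loop.
  unsigned mask = __ballot_sync(FULL_MASK, cond);
  while (cond) {
    int sum = maskedSum(mask, lane + 1);
    if (lane == __ffs(mask) - 1)  // lowest participating lane writes
      atomicAdd(&perIter[it], sum);
    ++it;
    cond = it < iters;
    // Every lane of the current mask reaches this line, so passing mask is
    // valid; the result is the mask for the next iteration.
    mask = __ballot_sync(mask, cond);
  }
}

// Hypothetical host helper: launches one warp and copies back the three
// per-iteration sums; returns false on a CUDA error.
bool runBallotLoop(int out[3]) {
  int *d = nullptr;
  if (cudaMalloc(&d, 3 * sizeof(int)) != cudaSuccess) return false;
  cudaMemset(d, 0, 3 * sizeof(int));
  ballotLoop<<<1, 32>>>(d);
  if (cudaGetLastError() != cudaSuccess) { cudaFree(d); return false; }
  // cudaMemcpy waits for the kernel to finish before copying.
  cudaError_t err = cudaMemcpy(out, d, 3 * sizeof(int), cudaMemcpyDeviceToHost);
  cudaFree(d);
  return err == cudaSuccess;
}
```

runBallotLoop(out) should give out = {528, 528, 136}, the sums 1 + ... + 32 for the two full iterations and 1 + ... + 16 for the last one. On devices of compute capability 8.0 and higher, __reduce_add_sync(mask, value) computes the same masked sum in one intrinsic; the shuffle loop above works on any architecture with __shfl_sync but costs one shuffle per participating lane.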

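However, if your code branches further, you have to keep track of the mask from before each branch and pass it to the next ballot instead of FullMask, because every lane named in the mask of a *_sync intrinsic must actually execute that call. For the same reason, the ballot with FullMask has to run at a point that all 32 lanes reach, for example before the loop. So the update before the second branch would be: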

uint32_t newMask =  __ballot_sync(mask, someNewCondition);
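The mask bookkeeping for nested branches can be sketched as follows. The branch conditions (lanes below 24 for the outer branch, even lanes for the inner one) and the names nestedMasks and runNestedMasks are hypothetical. Each ballot is passed exactly the set of lanes that reach it, and the resulting newMask can then be used for any *_sync call inside the inner branch.

```cuda
#include <cuda_runtime.h>

constexpr unsigned FULL_MASK = 0xFFFFFFFFu;

// Hypothetical conditions: the outer branch admits lanes 0-23 and the
// inner branch admits the even lanes among them.
__global__ void nestedMasks(unsigned *out) {
  int lane = threadIdx.x % 32;
  bool outer = lane < 24;
  // All 32 lanes reach this ballot, so FULL_MASK is valid here.
  unsigned mask = __ballot_sync(FULL_MASK, outer);
  if (outer) {
    bool inner = (lane % 2) == 0;
    // Only the lanes in mask reach this ballot, so mask is passed.
    unsigned newMask = __ballot_sync(mask, inner);
    if (inner) {
      // Every lane in newMask is here, so newMask is a valid shuffle mask.
      // Lane 22 is in newMask; all participants read its lane number.
      int fromLane22 = __shfl_sync(newMask, lane, 22);
      if (lane == 0) {
        out[0] = mask;
        out[1] = newMask;
        out[2] = (unsigned)fromLane22;
      }
    }
  }
}

// Hypothetical host helper: launches one warp and copies back the two
// masks and the shuffled value; returns false on a CUDA error.
bool runNestedMasks(unsigned out[3]) {
  unsigned *d = nullptr;
  if (cudaMalloc(&d, 3 * sizeof(unsigned)) != cudaSuccess) return false;
  nestedMasks<<<1, 32>>>(d);
  if (cudaGetLastError() != cudaSuccess) { cudaFree(d); return false; }
  // cudaMemcpy waits for the kernel to finish before copying.
  cudaError_t err = cudaMemcpy(out, d, 3 * sizeof(unsigned), cudaMemcpyDeviceToHost);
  cudaFree(d);
  return err == cudaSuccess;
}
```

With these conditions, mask should be 0x00FFFFFF (lanes 0-23), newMask should be 0x00555555 (even lanes 0-22), and the shuffled value should be 22.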
