#ifndef _PARALLEL_REDUCTION_H_
#define _PARALLEL_REDUCTION_H_

#define _CG_ABI_EXPERIMENTAL
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <stdio.h>

namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
template <class T>
struct SharedMemory {
  __device__ inline operator T *() {
    extern __shared__ int __smem[];
    return (T *)__smem;
  }

  __device__ inline operator const T *() const {
    extern __shared__ int __smem[];
    return (T *)__smem;
  }
};
// Specialization for double to avoid unaligned memory
// access compile errors
template <>
struct SharedMemory<double> {
  __device__ inline operator double *() {
    extern __shared__ double __smem_d[];
    return (double *)__smem_d;
  }

  __device__ inline operator const double *() const {
    extern __shared__ double __smem_d[];
    return (double *)__smem_d;
  }
};
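/*
  Usage sketch (illustrative, not part of the original file): the conversion
  operators above alias the dynamically sized shared memory buffer supplied
  at kernel launch, so a kernel obtains a typed view with

      T *sdata = SharedMemory<T>();

  while the host sizes the buffer via the third launch-configuration
  argument, e.g.

      reduce0<float><<<grid, block, threads * sizeof(float)>>>(in, out, n);
*/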
template <class T>
__device__ __forceinline__ T warpReduceSum(unsigned int mask, T mySum) {
  // warp-level reduction using shuffle-down
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    mySum += __shfl_down_sync(mask, mySum, offset);
  }
  return mySum;
}
#if __CUDA_ARCH__ >= 800
// On SM 8.0+, the __reduce_add_sync() intrinsic computes the warp-wide
// integer sum directly in hardware
template <>
__device__ __forceinline__ int warpReduceSum<int>(unsigned int mask,
                                                  int mySum) {
  mySum = __reduce_add_sync(mask, mySum);
  return mySum;
}
#endif
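/*
  Minimal usage sketch for warpReduceSum (a hypothetical kernel, not part of
  the original sample). Each full warp reduces its 32 values and lane 0
  writes the warp's sum. Assumes blockDim.x is a multiple of warpSize so the
  full mask 0xffffffff is valid for every warp:

    template <class T>
    __global__ void warpSumExample(const T *in, T *out) {
      unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
      T v = warpReduceSum<T>(0xffffffff, in[i]);
      if ((threadIdx.x % warpSize) == 0) out[i / warpSize] = v;
    }
*/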
/*
  Parallel sum reduction using shared memory.

  This version is completely unoptimized: interleaved addressing causes
  highly divergent warps, and it uses the slow % operator.
*/
template <class T>
__global__ void reduce0(T *g_idata, T *g_odata, unsigned int n) {
  // Handle to thread block group
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  // load shared mem
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

  sdata[tid] = (i < n) ? g_idata[i] : 0;

  cg::sync(cta);

  // do reduction in shared mem
  for (unsigned int s = 1; s < blockDim.x; s *= 2) {
    // modulo arithmetic is slow!
    if ((tid % (2 * s)) == 0) {
      sdata[tid] += sdata[tid + s];
    }
    cg::sync(cta);
  }

  // write result for this block to global mem
  if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}
/*
  This version uses contiguous threads, but its interleaved addressing
  results in many shared memory bank conflicts.
*/
template <class T>
__global__ void reduce1(T *g_idata, T *g_odata, unsigned int n) {
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  // load shared mem
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

  sdata[tid] = (i < n) ? g_idata[i] : 0;

  cg::sync(cta);

  // do reduction in shared mem
  for (unsigned int s = 1; s < blockDim.x; s *= 2) {
    int index = 2 * s * tid;

    if (index < blockDim.x) {
      sdata[index] += sdata[index + s];
    }
    cg::sync(cta);
  }

  // write result for this block to global mem
  if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}
/*
  This version uses sequential addressing -- no divergence or bank conflicts.
*/
template <class T>
__global__ void reduce2(T *g_idata, T *g_odata, unsigned int n) {
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  // load shared mem
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;

  sdata[tid] = (i < n) ? g_idata[i] : 0;

  cg::sync(cta);

  // do reduction in shared mem
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      sdata[tid] += sdata[tid + s];
    }
    cg::sync(cta);
  }

  // write result for this block to global mem
  if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}
/*
  This version uses n/2 threads -- it performs the first level of reduction
  while reading from global memory.
*/
template <class T>
__global__ void reduce3(T *g_idata, T *g_odata, unsigned int n) {
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  // perform first level of reduction,
  // reading from global memory, writing to shared memory
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x * (blockDim.x * 2) + threadIdx.x;

  T mySum = (i < n) ? g_idata[i] : 0;

  if (i + blockDim.x < n) mySum += g_idata[i + blockDim.x];

  sdata[tid] = mySum;
  cg::sync(cta);

  // do reduction in shared mem
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) {
      sdata[tid] = mySum = mySum + sdata[tid + s];
    }
    cg::sync(cta);
  }

  // write result for this block to global mem
  if (tid == 0) g_odata[blockIdx.x] = mySum;
}
/*
  This version unrolls the final warp's work and reduces it with warp
  shuffle operations, avoiding loop and synchronization overhead there.

  Note, this kernel needs a minimum of 64*sizeof(T) bytes of shared memory:
  if blockSize <= 32, allocate 64*sizeof(T) bytes; if blockSize > 32,
  allocate blockSize*sizeof(T) bytes.
*/
template <class T, unsigned int blockSize>
__global__ void reduce4(T *g_idata, T *g_odata, unsigned int n) {
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  // perform first level of reduction,
  // reading from global memory, writing to shared memory
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x * (blockDim.x * 2) + threadIdx.x;

  T mySum = (i < n) ? g_idata[i] : 0;

  if (i + blockSize < n) mySum += g_idata[i + blockSize];

  sdata[tid] = mySum;
  cg::sync(cta);

  // do reduction in shared mem until only the final warp's work remains
  for (unsigned int s = blockDim.x / 2; s > 32; s >>= 1) {
    if (tid < s) {
      sdata[tid] = mySum = mySum + sdata[tid + s];
    }
    cg::sync(cta);
  }

  cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
  if (cta.thread_rank() < 32) {
    // fetch final intermediate sum from 2nd warp
    if (blockSize >= 64) mySum += sdata[tid + 32];
    // reduce final warp using shuffle
    for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
      mySum += tile32.shfl_down(mySum, offset);
    }
  }

  // write result for this block to global mem
  if (cta.thread_rank() == 0) g_odata[blockIdx.x] = mySum;
}
/*
  This version is completely unrolled. It uses a template parameter to
  achieve optimal code for any (power of 2) thread block size; this requires
  a switch statement in the host code to select the right instantiation at
  compile time. The final warp is reduced with shuffle operations.
*/
template <class T, unsigned int blockSize>
__global__ void reduce5(T *g_idata, T *g_odata, unsigned int n) {
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  // perform first level of reduction,
  // reading from global memory, writing to shared memory
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x * (blockSize * 2) + threadIdx.x;

  T mySum = (i < n) ? g_idata[i] : 0;

  if (i + blockSize < n) mySum += g_idata[i + blockSize];

  sdata[tid] = mySum;
  cg::sync(cta);

  // do reduction in shared mem, fully unrolled via blockSize
  if ((blockSize >= 512) && (tid < 256)) {
    sdata[tid] = mySum = mySum + sdata[tid + 256];
  }
  cg::sync(cta);
  if ((blockSize >= 256) && (tid < 128)) {
    sdata[tid] = mySum = mySum + sdata[tid + 128];
  }
  cg::sync(cta);
  if ((blockSize >= 128) && (tid < 64)) {
    sdata[tid] = mySum = mySum + sdata[tid + 64];
  }
  cg::sync(cta);

  cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
  if (cta.thread_rank() < 32) {
    // fetch final intermediate sum from 2nd warp
    if (blockSize >= 64) mySum += sdata[tid + 32];
    // reduce final warp using shuffle
    for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
      mySum += tile32.shfl_down(mySum, offset);
    }
  }

  // write result for this block to global mem
  if (cta.thread_rank() == 0) g_odata[blockIdx.x] = mySum;
}
/*
  This version adds multiple elements per thread sequentially. This reduces
  the overall cost of the algorithm while keeping the work complexity O(n)
  and the step complexity O(log n). (Brent's Theorem optimization)

  Note, this kernel needs a minimum of 64*sizeof(T) bytes of shared memory:
  if blockSize <= 32, allocate 64*sizeof(T) bytes; if blockSize > 32,
  allocate blockSize*sizeof(T) bytes.
*/
template <class T, unsigned int blockSize, bool nIsPow2>
__global__ void reduce6(T *g_idata, T *g_odata, unsigned int n) {
  cg::thread_block cta = cg::this_thread_block();
  T *sdata = SharedMemory<T>();

  unsigned int tid = threadIdx.x;
  unsigned int gridSize = blockSize * gridDim.x;
  T mySum = 0;

  // we reduce multiple elements per thread; the number is determined by the
  // number of active thread blocks (via gridDim). More blocks result in a
  // larger gridSize and therefore fewer elements per thread.
  if (nIsPow2) {
    unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x;
    gridSize = gridSize << 1;
    while (i < n) {
      mySum += g_idata[i];
      // the bounds check is optimized away for power-of-2 sized arrays
      if ((i + blockSize) < n) {
        mySum += g_idata[i + blockSize];
      }
      i += gridSize;
    }
  } else {
    unsigned int i = blockIdx.x * blockSize + threadIdx.x;
    while (i < n) {
      mySum += g_idata[i];
      i += gridSize;
    }
  }

  // each thread puts its local sum into shared memory
  sdata[tid] = mySum;
  cg::sync(cta);

  // do reduction in shared mem, unrolled via the blockSize template parameter
  if ((blockSize >= 512) && (tid < 256)) {
    sdata[tid] = mySum = mySum + sdata[tid + 256];
  }
  cg::sync(cta);
  if ((blockSize >= 256) && (tid < 128)) {
    sdata[tid] = mySum = mySum + sdata[tid + 128];
  }
  cg::sync(cta);
  if ((blockSize >= 128) && (tid < 64)) {
    sdata[tid] = mySum = mySum + sdata[tid + 64];
  }
  cg::sync(cta);

  cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
  if (cta.thread_rank() < 32) {
    // fetch final intermediate sum from 2nd warp
    if (blockSize >= 64) mySum += sdata[tid + 32];
    // reduce final warp using shuffle
    for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
      mySum += tile32.shfl_down(mySum, offset);
    }
  }

  // write result for this block to global mem
  if (cta.thread_rank() == 0) g_odata[blockIdx.x] = mySum;
}
/*
  This version reduces within each warp using warpReduceSum, stores one
  partial sum per warp in shared memory, and lets the first warp reduce
  those. It therefore needs only (blockSize / warpSize) + 1 elements of
  shared memory.
*/
template <typename T, unsigned int blockSize, bool nIsPow2>
__global__ void reduce7(const T *__restrict__ g_idata, T *__restrict__ g_odata,
                        unsigned int n) {
  T *sdata = SharedMemory<T>();

  // perform first level of reduction,
  // reading from global memory, writing to shared memory
  unsigned int tid = threadIdx.x;
  unsigned int gridSize = blockSize * gridDim.x;

  // build the active-lane mask for a possibly partial last warp; e.g. for
  // blockSize = 16: maskLength = 32 - 16 = 16, so mask = 0x0000ffff
  unsigned int maskLength = (blockSize & 31);  // 31 = warpSize - 1
  maskLength = (maskLength > 0) ? (32 - maskLength) : maskLength;
  const unsigned int mask = (0xffffffff) >> maskLength;

  T mySum = 0;

  // we reduce multiple elements per thread; the number is determined by the
  // number of active thread blocks (via gridDim)
  if (nIsPow2) {
    unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x;
    gridSize = gridSize << 1;
    while (i < n) {
      mySum += g_idata[i];
      // the bounds check is optimized away for power-of-2 sized arrays
      if ((i + blockSize) < n) {
        mySum += g_idata[i + blockSize];
      }
      i += gridSize;
    }
  } else {
    unsigned int i = blockIdx.x * blockSize + threadIdx.x;
    while (i < n) {
      mySum += g_idata[i];
      i += gridSize;
    }
  }

  // reduce within each warp using shuffle
  // (or __reduce_add_sync for int on SM 8.0+)
  mySum = warpReduceSum<T>(mask, mySum);

  // each warp's lane 0 puts its partial sum into shared memory
  if ((tid % warpSize) == 0) {
    sdata[tid / warpSize] = mySum;
  }

  __syncthreads();

  // reduce the per-warp sums using the first warp
  const unsigned int shmem_extent =
      (blockSize / warpSize) > 0 ? (blockSize / warpSize) : 1;
  const unsigned int ballot_result = __ballot_sync(mask, tid < shmem_extent);
  if (tid < shmem_extent) {
    mySum = sdata[tid];
    mySum = warpReduceSum<T>(ballot_result, mySum);
  }

  // write result for this block to global mem
  if (tid == 0) {
    g_odata[blockIdx.x] = mySum;
  }
}
// Reduce 'in' across all threads of the given cooperative group
template <typename T, typename Group>
__device__ T cg_reduce_n(T in, Group &threads) {
  return cg::reduce(threads, in, cg::plus<T>());
}
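/*
  Usage sketch (illustrative assumption, not from the original file):
  cg::reduce works over any cooperative-groups tile, e.g. a 32-thread tile
  inside a kernel:

    cg::thread_block cta = cg::this_thread_block();
    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
    float tileSum = cg_reduce_n(myVal, tile);  // sum across the 32 lanes
*/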
template <class T>
__global__ void cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
  // Shared memory for intermediate steps
  T *sdata = SharedMemory<T>();
  // Handle to thread block group
  cg::thread_block cta = cg::this_thread_block();
  // Handle to tile in thread block
  cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);

  unsigned int ctaSize = cta.size();
  unsigned int numCtas = gridDim.x;
  unsigned int threadRank = cta.thread_rank();
  unsigned int threadIndex = (blockIdx.x * ctaSize) + threadRank;

  T threadVal = 0;
  {
    unsigned int i = threadIndex;
    unsigned int indexStride = (numCtas * ctaSize);
    while (i < n) {
      threadVal += g_idata[i];
      i += indexStride;
    }
    sdata[threadRank] = threadVal;
  }

  // wait for all tiles to finish and reduce within the CTA
  {
    unsigned int ctaSteps = tile.meta_group_size();
    unsigned int ctaIndex = ctaSize >> 1;
    while (ctaIndex >= 32) {
      cta.sync();
      if (threadRank < ctaIndex) {
        threadVal += sdata[threadRank + ctaIndex];
        sdata[threadRank] = threadVal;
      }
      ctaSteps >>= 1;
      ctaIndex >>= 1;
    }
  }

  // shuffle-based reduction for the final warp instead of shared memory
  {
    cta.sync();
    if (tile.meta_group_rank() == 0) {
      threadVal = cg_reduce_n(threadVal, tile);
    }
  }

  if (threadRank == 0) g_odata[blockIdx.x] = threadVal;
}
// Reduction over tiles wider than one warp, using the experimental
// cooperative-groups large-tile API
template <class T, size_t BlockSize, size_t MultiWarpGroupSize>
__global__ void multi_warp_cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
  // Shared memory for intermediate steps
  T *sdata = SharedMemory<T>();
  __shared__ cg::experimental::block_tile_memory<sizeof(T), BlockSize> scratch;

  // Handle to thread block group
  auto cta = cg::experimental::this_thread_block(scratch);
  // Handle to multiWarpTile in thread block
  auto multiWarpTile =
      cg::experimental::tiled_partition<MultiWarpGroupSize>(cta);

  unsigned int gridSize = BlockSize * gridDim.x;
  T threadVal = 0;

  // we reduce multiple elements per thread; the number is determined by the
  // number of active thread blocks (via gridDim)
  int nIsPow2 = !(n & (n - 1));
  if (nIsPow2) {
    unsigned int i = blockIdx.x * BlockSize * 2 + threadIdx.x;
    gridSize = gridSize << 1;
    while (i < n) {
      threadVal += g_idata[i];
      // ensure we don't read out of bounds
      if ((i + BlockSize) < n) {
        threadVal += g_idata[i + BlockSize];
      }
      i += gridSize;
    }
  } else {
    unsigned int i = blockIdx.x * BlockSize + threadIdx.x;
    while (i < n) {
      threadVal += g_idata[i];
      i += gridSize;
    }
  }

  // reduce within each multi-warp tile, then combine the tile results
  threadVal = cg_reduce_n(threadVal, multiWarpTile);

  if (multiWarpTile.thread_rank() == 0) {
    sdata[multiWarpTile.meta_group_rank()] = threadVal;
  }
  cg::sync(cta);

  if (threadIdx.x == 0) {
    threadVal = 0;
    for (int i = 0; i < multiWarpTile.meta_group_size(); i++) {
      threadVal += sdata[i];
    }
    g_odata[blockIdx.x] = threadVal;
  }
}
extern "C" bool isPow2(unsigned int x);
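// isPow2() is defined in the host-side sources. A minimal sketch of the
// usual bit trick (an assumption, not copied from that implementation):
//
//   extern "C" bool isPow2(unsigned int x) { return (x & (x - 1)) == 0; }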
// Wrapper function for kernel launch
template <class T>
void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
            T *d_odata) {
  dim3 dimBlock(threads, 1, 1);
  dim3 dimGrid(blocks, 1, 1);

  // when there is only one warp per block, we need to allocate two warps
  // worth of shared memory so that we don't index shared memory out of bounds
  int smemSize =
      (threads <= 32) ? 2 * threads * sizeof(T) : threads * sizeof(T);

  // multi_warp_cg_reduce (kernel 9) cannot work with fewer than 64 threads,
  // so fall back to kernel 7 in that case
  if (threads < 64 && whichKernel == 9) {
    whichKernel = 7;
  }
  // choose which of the optimized versions of reduction to launch
  switch (whichKernel) {
    case 0:
      reduce0<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
      break;

    case 1:
      reduce1<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
      break;

    case 2:
      reduce2<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
      break;

    case 3:
      reduce3<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
      break;
    case 4:
      // dispatch on the thread count so blockSize is a compile-time constant
      switch (threads) {
        case 512: reduce4<T, 512><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case 256: reduce4<T, 256><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case 128: reduce4<T, 128><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  64: reduce4<T,  64><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  32: reduce4<T,  32><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  16: reduce4<T,  16><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   8: reduce4<T,   8><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   4: reduce4<T,   4><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   2: reduce4<T,   2><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   1: reduce4<T,   1><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
      }
      break;
    case 5:
      switch (threads) {
        case 512: reduce5<T, 512><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case 256: reduce5<T, 256><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case 128: reduce5<T, 128><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  64: reduce5<T,  64><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  32: reduce5<T,  32><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  16: reduce5<T,  16><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   8: reduce5<T,   8><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   4: reduce5<T,   4><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   2: reduce5<T,   2><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   1: reduce5<T,   1><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
      }
      break;
    case 6:
      if (isPow2(size)) {
        switch (threads) {
          case 512: reduce6<T, 512, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case 256: reduce6<T, 256, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case 128: reduce6<T, 128, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  64: reduce6<T,  64, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  32: reduce6<T,  32, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  16: reduce6<T,  16, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   8: reduce6<T,   8, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   4: reduce6<T,   4, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   2: reduce6<T,   2, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   1: reduce6<T,   1, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        }
      } else {
        switch (threads) {
          case 512: reduce6<T, 512, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case 256: reduce6<T, 256, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case 128: reduce6<T, 128, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  64: reduce6<T,  64, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  32: reduce6<T,  32, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  16: reduce6<T,  16, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   8: reduce6<T,   8, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   4: reduce6<T,   4, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   2: reduce6<T,   2, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   1: reduce6<T,   1, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        }
      }
      break;
    case 7:
      // the reduce7 kernel requires only blockSize/warpSize + 1
      // elements of shared memory
      smemSize = ((threads / 32) + 1) * sizeof(T);
      if (isPow2(size)) {
        switch (threads) {
          case 1024: reduce7<T, 1024, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  512: reduce7<T,  512, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  256: reduce7<T,  256, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  128: reduce7<T,  128, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   64: reduce7<T,   64, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   32: reduce7<T,   32, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   16: reduce7<T,   16, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    8: reduce7<T,    8, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    4: reduce7<T,    4, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    2: reduce7<T,    2, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    1: reduce7<T,    1, true><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        }
      } else {
        switch (threads) {
          case 1024: reduce7<T, 1024, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  512: reduce7<T,  512, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  256: reduce7<T,  256, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case  128: reduce7<T,  128, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   64: reduce7<T,   64, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   32: reduce7<T,   32, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case   16: reduce7<T,   16, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    8: reduce7<T,    8, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    4: reduce7<T,    4, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    2: reduce7<T,    2, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
          case    1: reduce7<T,    1, false><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        }
      }
      break;
    case 8:
      cg_reduce<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
      break;
    case 9: {
      constexpr int numOfMultiWarpGroups = 2;
      smemSize = numOfMultiWarpGroups * sizeof(T);
      switch (threads) {
        case 1024: multi_warp_cg_reduce<T, 1024, 1024 / numOfMultiWarpGroups><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  512: multi_warp_cg_reduce<T,  512,  512 / numOfMultiWarpGroups><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  256: multi_warp_cg_reduce<T,  256,  256 / numOfMultiWarpGroups><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case  128: multi_warp_cg_reduce<T,  128,  128 / numOfMultiWarpGroups><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        case   64: multi_warp_cg_reduce<T,   64,   64 / numOfMultiWarpGroups><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size); break;
        default:
          printf("thread block size of < 64 is not supported for this kernel\n");
          break;
      }
      break;
    }
  }
}
// Instantiate the reduction function for 3 types
template void reduce<int>(int size, int threads, int blocks, int whichKernel,
                          int *d_idata, int *d_odata);

template void reduce<float>(int size, int threads, int blocks, int whichKernel,
                            float *d_idata, float *d_odata);

template void reduce<double>(int size, int threads, int blocks,
                             int whichKernel, double *d_idata,
                             double *d_odata);
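/*
  Minimal host-side usage sketch (an illustrative assumption; the sample's
  actual driver code differs). One launch produces one partial sum per
  block, and the partial sums are then summed on the host. Needs <vector>
  and <numeric>:

    int n = 1 << 20, threads = 256, whichKernel = 6;
    int blocks = (n + threads * 2 - 1) / (threads * 2);
    reduce<float>(n, threads, blocks, whichKernel, d_idata, d_odata);
    std::vector<float> partial(blocks);
    cudaMemcpy(partial.data(), d_odata, blocks * sizeof(float),
               cudaMemcpyDeviceToHost);
    float result = std::accumulate(partial.begin(), partial.end(), 0.0f);
*/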
#endif  // #ifndef _PARALLEL_REDUCTION_H_