36 #include "../util_type.cuh"
37 #include "../util_ptx.cuh"
38 #include "../util_namespace.cuh"
107 int PTX_ARCH = CUB_PTX_ARCH>
// Total number of threads in the block: the product of the three block dimensions.
120 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
// Scratch with one slot per thread, used to exchange boundary items between
// neighboring threads: first_items[t] receives thread t's input[0] and last_items[t]
// receives thread t's input[ITEMS_PER_THREAD - 1].
127 T first_items[BLOCK_THREADS];
128 T last_items[BLOCK_THREADS];
// Returns a reference to a _TempStorage instance declared __shared__ inside this
// function. A __shared__ variable declared in a device function has the lifetime of
// the thread block, so the returned reference stays valid for the block's duration
// and is the same object for every thread in the block.
137 __device__ __forceinline__ _TempStorage& PrivateStorage()
139 __shared__ _TempStorage private_storage;
140 return private_storage;
// Adapter that invokes the user's flag operator on an item pair (a, b).
// HAS_PARAM defaults to BinaryOpHasIdxParam<T, FlagOp>::HAS_PARAM. This primary
// template is the HAS_PARAM == true case: it forwards idx as a third argument.
// The false case (defined separately) calls flag_op with only (a, b).
145 template <typename FlagOp, bool HAS_PARAM = BinaryOpHasIdxParam<T, FlagOp>::HAS_PARAM>
149 static __device__ __forceinline__
bool FlagT(FlagOp flag_op,
const T &a,
const T &b,
int idx)
151 return flag_op(a, b, idx);
// Specialization for flag operators that do not take an index (HAS_PARAM == false).
// idx is still accepted so every caller uses the same signature, but it is not
// passed to flag_op.
156 template <
typename FlagOp>
157 struct ApplyOp<FlagOp, false>
160 static __device__ __forceinline__
bool FlagT(FlagOp flag_op,
const T &a,
const T &b,
int idx)
162 return flag_op(a, b);
// Compile-time unrolled loop over item positions ITERATION .. MAX_ITERATIONS - 1.
// Each step handles one position and then calls Iterate<ITERATION + 1, MAX_ITERATIONS>.
// The Iterate<MAX_ITERATIONS, MAX_ITERATIONS> specialization stops the recursion.
167 template <
int ITERATION,
int MAX_ITERATIONS>
// FlagHeads step: records input[ITERATION - 1] as the predecessor of position
// ITERATION, so the recursion must start at ITERATION >= 1 (callers use Iterate<1, ...>).
// The flag comes from FlagOp with idx = linear_tid * ITEMS_PER_THREAD + ITERATION,
// i.e. the position of input[ITERATION] when each thread's items are laid out
// consecutively by linear_tid.
// NOTE(review): the first FlagT arguments are on lines not visible in this excerpt —
// presumably flag_op, preds[ITERATION], input[ITERATION]; confirm.
172 int ITEMS_PER_THREAD,
175 static __device__ __forceinline__
void FlagHeads(
177 FlagT (&flags)[ITEMS_PER_THREAD],
178 T (&input)[ITEMS_PER_THREAD],
179 T (&preds)[ITEMS_PER_THREAD],
182 preds[ITERATION] = input[ITERATION - 1];
184 flags[ITERATION] = ApplyOp<FlagOp>::FlagT(
188 (linear_tid * ITEMS_PER_THREAD) + ITERATION);
190 Iterate<ITERATION + 1, MAX_ITERATIONS>::FlagHeads(linear_tid, flags, input, preds, flag_op);
// FlagTails step: flags position ITERATION by comparing against its in-thread
// successor input[ITERATION + 1]. idx is that successor's position,
// linear_tid * ITEMS_PER_THREAD + ITERATION + 1. Because input[ITERATION + 1] is read,
// MAX_ITERATIONS must not exceed ITEMS_PER_THREAD - 1 (callers use
// Iterate<0, ITEMS_PER_THREAD - 1>).
// NOTE(review): the leading FlagT arguments are not visible here — presumably
// flag_op and input[ITERATION]; confirm.
195 int ITEMS_PER_THREAD,
198 static __device__ __forceinline__
void FlagTails(
200 FlagT (&flags)[ITEMS_PER_THREAD],
201 T (&input)[ITEMS_PER_THREAD],
204 flags[ITERATION] = ApplyOp<FlagOp>::FlagT(
207 input[ITERATION + 1],
208 (linear_tid * ITEMS_PER_THREAD) + ITERATION + 1);
210 Iterate<ITERATION + 1, MAX_ITERATIONS>::FlagTails(linear_tid, flags, input, flag_op);
// Recursion terminator for Iterate: selected once ITERATION has reached MAX_ITERATIONS.
// Its FlagHeads/FlagTails mirror the parameter lists of the recursive versions.
// Their bodies are not visible in this excerpt (presumably empty) — confirm against
// the full file.
216 template <
int MAX_ITERATIONS>
217 struct Iterate<MAX_ITERATIONS, MAX_ITERATIONS>
221 int ITEMS_PER_THREAD,
224 static __device__ __forceinline__
void FlagHeads(
226 FlagT (&flags)[ITEMS_PER_THREAD],
227 T (&input)[ITEMS_PER_THREAD],
228 T (&preds)[ITEMS_PER_THREAD],
234 int ITEMS_PER_THREAD,
237 static __device__ __forceinline__
void FlagTails(
239 FlagT (&flags)[ITEMS_PER_THREAD],
240 T (&input)[ITEMS_PER_THREAD],
// Reference to the scratch used to exchange boundary items between neighboring threads.
251 _TempStorage &temp_storage;
// Constructor initializers (parameter list not visible; presumably the default
// constructor). They bind temp_storage to the __shared__ instance from
// PrivateStorage() and compute this thread's linear id via RowMajorTid over the
// three block dimensions.
273 temp_storage(PrivateStorage()),
274 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// Constructor initializers for caller-supplied storage. Names inside a
// mem-initializer's expression are looked up in the constructor's scope, so
// temp_storage.Alias() refers to the constructor parameter (assuming it is also
// named temp_storage; the signature is not visible here), not to the member.
285 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
296 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// Internal FlagHeads overload that also fills preds with each item's predecessor.
// Each thread publishes its last item to temp_storage.last_items[linear_tid]. It then
// takes its first predecessor from the previous thread's slot, and the remaining
// positions use in-thread predecessors via Iterate.
299 int ITEMS_PER_THREAD,
302 __device__ __forceinline__
void FlagHeads(
303 FlagT (&head_flags)[ITEMS_PER_THREAD],
304 T (&input)[ITEMS_PER_THREAD],
305 T (&preds)[ITEMS_PER_THREAD],
309 temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1];
// NOTE(review): the read below consumes another thread's write above, so a block-wide
// barrier must separate them. The intervening lines are not visible in this excerpt —
// confirm the barrier exists. Also confirm this read is skipped for linear_tid == 0,
// where index linear_tid - 1 would be out of bounds.
320 preds[0] = temp_storage.last_items[linear_tid - 1];
// The first item's idx is its position, linear_tid * ITEMS_PER_THREAD.
321 head_flags[0] = ApplyOp<FlagOp>::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD);
// Positions 1 .. ITEMS_PER_THREAD - 1 compare against the preceding in-thread item.
325 Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
// Overload taking tile_predecessor_item: thread 0 uses it as the predecessor of its
// first item; every other thread reads the previous thread's last item.
329 int ITEMS_PER_THREAD,
332 __device__ __forceinline__
void FlagHeads(
333 FlagT (&head_flags)[ITEMS_PER_THREAD],
334 T (&input)[ITEMS_PER_THREAD],
335 T (&preds)[ITEMS_PER_THREAD],
337 T tile_predecessor_item)
340 temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1];
// NOTE(review): a block-wide barrier is required between the shared write above and
// the neighbor read below. It is not visible in this excerpt — confirm.
345 preds[0] = (linear_tid == 0) ?
346 tile_predecessor_item :
347 temp_storage.last_items[linear_tid - 1];
349 head_flags[0] = ApplyOp<FlagOp>::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD);
352 Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
355 #endif // DOXYGEN_SHOULD_SKIP_THIS
// FlagHeads without a tile predecessor. It allocates a per-thread preds array as
// scratch and delegates to the overload that computes head flags and fills preds.
408 int ITEMS_PER_THREAD,
412 FlagT (&head_flags)[ITEMS_PER_THREAD],
413 T (&input)[ITEMS_PER_THREAD],
416 T preds[ITEMS_PER_THREAD];
417 FlagHeads(head_flags, input, preds, flag_op);
// FlagHeads with a tile predecessor. It allocates a per-thread preds array as scratch
// and forwards tile_predecessor_item to the overload that fills preds; that overload
// uses it as thread 0's first predecessor.
477 int ITEMS_PER_THREAD,
481 FlagT (&head_flags)[ITEMS_PER_THREAD],
482 T (&input)[ITEMS_PER_THREAD],
484 T tile_predecessor_item)
486 T preds[ITEMS_PER_THREAD];
487 FlagHeads(head_flags, input, preds, flag_op, tile_predecessor_item);
// FlagTails without a tile successor. Each thread publishes its first item to
// temp_storage.first_items[linear_tid]. The last item of threads other than
// BLOCK_THREADS - 1 is compared with the next thread's first item. The value used for
// the final thread is on a line not visible in this excerpt.
550 int ITEMS_PER_THREAD,
554 FlagT (&tail_flags)[ITEMS_PER_THREAD],
555 T (&input)[ITEMS_PER_THREAD],
559 temp_storage.first_items[linear_tid] = input[0];
// NOTE(review): a block-wide barrier is required between the shared write above and
// the neighbor read below. It is not visible in this excerpt — confirm.
// For the cross-thread comparison, idx is the successor's position:
// linear_tid * ITEMS_PER_THREAD + ITEMS_PER_THREAD.
564 tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ?
566 ApplyOp<FlagOp>::FlagT(
568 input[ITEMS_PER_THREAD - 1],
569 temp_storage.first_items[linear_tid + 1],
570 (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);
// Positions 0 .. ITEMS_PER_THREAD - 2 compare against the next in-thread item.
573 Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
// FlagTails with a tile successor. Each thread publishes its first item to
// temp_storage.first_items[linear_tid]. The last thread (BLOCK_THREADS - 1) uses
// tile_successor_item as the successor of its final item; every other thread uses the
// next thread's first item.
634 int ITEMS_PER_THREAD,
638 FlagT (&tail_flags)[ITEMS_PER_THREAD],
639 T (&input)[ITEMS_PER_THREAD],
641 T tile_successor_item)
644 temp_storage.first_items[linear_tid] = input[0];
// NOTE(review): a block-wide barrier is required between the shared write above and
// the neighbor read below. It is not visible in this excerpt — confirm.
649 T successor_item = (linear_tid == BLOCK_THREADS - 1) ?
650 tile_successor_item :
651 temp_storage.first_items[linear_tid + 1];
// idx is the successor's position, linear_tid * ITEMS_PER_THREAD + ITEMS_PER_THREAD.
653 tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp<FlagOp>::FlagT(
655 input[ITEMS_PER_THREAD - 1],
657 (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);
// Positions 0 .. ITEMS_PER_THREAD - 2 compare against the next in-thread item.
660 Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
// Computes head and tail flags in one pass with no tile predecessor or successor.
// Each thread publishes both its first and last items to temp_storage so that
// neighbors can read them.
732 int ITEMS_PER_THREAD,
736 FlagT (&head_flags)[ITEMS_PER_THREAD],
737 FlagT (&tail_flags)[ITEMS_PER_THREAD],
738 T (&input)[ITEMS_PER_THREAD],
742 temp_storage.first_items[linear_tid] = input[0];
743 temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1];
// NOTE(review): a block-wide barrier is required between the shared writes above and
// the neighbor reads below. It is not visible in this excerpt — confirm.
// preds is per-thread scratch holding each item's predecessor.
747 T preds[ITEMS_PER_THREAD];
// NOTE(review): must be skipped for linear_tid == 0 (index -1). The guarding branch is
// not visible here — confirm.
750 preds[0] = temp_storage.last_items[linear_tid - 1];
757 head_flags[0] = ApplyOp<FlagOp>::FlagT(
761 linear_tid * ITEMS_PER_THREAD);
// For threads other than the last, the final item is compared with the next thread's
// first item. The last thread's value is on a line not visible in this excerpt.
766 tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ?
768 ApplyOp<FlagOp>::FlagT(
770 input[ITEMS_PER_THREAD - 1],
771 temp_storage.first_items[linear_tid + 1],
772 (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);
// In-thread head flags for positions 1.., then in-thread tail flags for 0 .. IPT - 2.
775 Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
778 Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
// Computes head and tail flags in one pass with a tile successor. The last thread
// (BLOCK_THREADS - 1) uses tile_successor_item as the successor of its final item.
// Note the parameter order: tile_successor_item precedes input.
848 int ITEMS_PER_THREAD,
852 FlagT (&head_flags)[ITEMS_PER_THREAD],
853 FlagT (&tail_flags)[ITEMS_PER_THREAD],
854 T tile_successor_item,
855 T (&input)[ITEMS_PER_THREAD],
859 temp_storage.first_items[linear_tid] = input[0];
860 temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1];
// NOTE(review): a block-wide barrier is required between the shared writes above and
// the neighbor reads below. It is not visible in this excerpt — confirm.
864 T preds[ITEMS_PER_THREAD];
// NOTE(review): must be skipped for linear_tid == 0 (index -1). The guarding branch is
// not visible here — confirm.
873 preds[0] = temp_storage.last_items[linear_tid - 1];
874 head_flags[0] = ApplyOp<FlagOp>::FlagT(
878 linear_tid * ITEMS_PER_THREAD);
882 T successor_item = (linear_tid == BLOCK_THREADS - 1) ?
883 tile_successor_item :
884 temp_storage.first_items[linear_tid + 1];
// idx is the successor's position, linear_tid * ITEMS_PER_THREAD + ITEMS_PER_THREAD.
886 tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp<FlagOp>::FlagT(
888 input[ITEMS_PER_THREAD - 1],
890 (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);
893 Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
896 Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
// Computes head and tail flags in one pass with a tile predecessor. Thread 0 uses
// tile_predecessor_item as the predecessor of its first item. Note the parameter
// order: tile_predecessor_item follows head_flags.
972 int ITEMS_PER_THREAD,
976 FlagT (&head_flags)[ITEMS_PER_THREAD],
977 T tile_predecessor_item,
978 FlagT (&tail_flags)[ITEMS_PER_THREAD],
979 T (&input)[ITEMS_PER_THREAD],
983 temp_storage.first_items[linear_tid] = input[0];
984 temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1];
// NOTE(review): a block-wide barrier is required between the shared writes above and
// the neighbor reads below. It is not visible in this excerpt — confirm.
988 T preds[ITEMS_PER_THREAD];
991 preds[0] = (linear_tid == 0) ?
992 tile_predecessor_item :
993 temp_storage.last_items[linear_tid - 1];
995 head_flags[0] = ApplyOp<FlagOp>::FlagT(
999 linear_tid * ITEMS_PER_THREAD);
// For threads other than the last, the final item is compared with the next thread's
// first item. The last thread's value is on a line not visible in this excerpt.
1002 tail_flags[ITEMS_PER_THREAD - 1] = (linear_tid == BLOCK_THREADS - 1) ?
1004 ApplyOp<FlagOp>::FlagT(
1006 input[ITEMS_PER_THREAD - 1],
1007 temp_storage.first_items[linear_tid + 1],
1008 (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);
1011 Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
1014 Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
// Computes head and tail flags in one pass with both a tile predecessor and a tile
// successor. Thread 0 uses tile_predecessor_item for its first item's predecessor.
// Thread BLOCK_THREADS - 1 uses tile_successor_item for its last item's successor.
1091 int ITEMS_PER_THREAD,
1095 FlagT (&head_flags)[ITEMS_PER_THREAD],
1096 T tile_predecessor_item,
1097 FlagT (&tail_flags)[ITEMS_PER_THREAD],
1098 T tile_successor_item,
1099 T (&input)[ITEMS_PER_THREAD],
1103 temp_storage.first_items[linear_tid] = input[0];
1104 temp_storage.last_items[linear_tid] = input[ITEMS_PER_THREAD - 1];
// NOTE(review): a block-wide barrier is required between the shared writes above and
// the neighbor reads below. It is not visible in this excerpt — confirm.
1108 T preds[ITEMS_PER_THREAD];
1111 preds[0] = (linear_tid == 0) ?
1112 tile_predecessor_item :
1113 temp_storage.last_items[linear_tid - 1];
1115 head_flags[0] = ApplyOp<FlagOp>::FlagT(
1119 linear_tid * ITEMS_PER_THREAD);
1122 T successor_item = (linear_tid == BLOCK_THREADS - 1) ?
1123 tile_successor_item :
1124 temp_storage.first_items[linear_tid + 1];
// idx is the successor's position, linear_tid * ITEMS_PER_THREAD + ITEMS_PER_THREAD.
1126 tail_flags[ITEMS_PER_THREAD - 1] = ApplyOp<FlagOp>::FlagT(
1128 input[ITEMS_PER_THREAD - 1],
1130 (linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);
1133 Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
1136 Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);