38 #include "block_radix_rank.cuh"
39 #include "../util_ptx.cuh"
40 #include "../util_arch.cuh"
41 #include "../util_type.cuh"
42 #include "../util_namespace.cuh"
// Tail of the class template parameter list.
// NOTE(review): the opening `template <` and the leading parameters are not
// visible in this excerpt.
// ITEMS_PER_THREAD: extent of the key, value and rank arrays taken by the
// methods of this class.
122 int ITEMS_PER_THREAD,
// ValueT: type of the values carried alongside the keys; defaults to NullType.
123 typename ValueT = NullType,
// MEMOIZE_OUTER_SCAN: defaults to true when CUB_PTX_ARCH >= 350, false otherwise.
125 bool MEMOIZE_OUTER_SCAN = (CUB_PTX_ARCH >= 350) ?
true :
false,
// SMEM_CONFIG: shared-memory bank-size setting; defaults to
// cudaSharedMemBankSizeFourByte.
127 cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte,
// PTX_ARCH: defaults to CUB_PTX_ARCH.
130 int PTX_ARCH = CUB_PTX_ARCH>
// Product of the three block dimensions given as template parameters.
// NOTE(review): the enclosing enum/constant declaration is not visible here.
142 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
// Traits specialization for the key type.
149 typedef Traits<KeyT> KeyTraits;
// Unsigned-bits type that KeyTraits associates with KeyT; the sort routines
// below reinterpret the key array as this type.
150 typedef typename KeyTraits::UnsignedBits UnsignedBits;
// BlockRadixRank instantiation used by the ascending RankKeys overload.
// NOTE(review): most template arguments are not visible in this excerpt;
// only INNER_SCAN_ALGORITHM can be seen.
153 typedef BlockRadixRank<
158 INNER_SCAN_ALGORITHM,
163 AscendingBlockRadixRank;
// BlockRadixRank instantiation used by the descending RankKeys overload.
// NOTE(review): the argument that distinguishes it from the ascending
// typedef is not visible here -- confirm against the full file.
166 typedef BlockRadixRank<
171 INNER_SCAN_ALGORITHM,
176 DescendingBlockRadixRank;
// Temporary storage for the ascending ranker.
// NOTE(review): "asending" is a misspelling of "ascending"; it is also used
// in the ascending RankKeys overload, so any rename must change both places.
// NOTE(review): the enclosing struct/union declaration is not visible here.
189 typename AscendingBlockRadixRank::TempStorage asending_ranking_storage;
// Temporary storage for the descending ranker.
190 typename DescendingBlockRadixRank::TempStorage descending_ranking_storage;
// Reference to the storage used by every method of this class; bound in the
// constructors' initializer lists.
202 _TempStorage &temp_storage;
// Returns a reference to a function-scope __shared__ _TempStorage instance.
// Used by the constructor that is not handed caller-provided storage.
212 __device__ __forceinline__ _TempStorage& PrivateStorage()
// Declared __shared__, so the object is not a per-thread local.
214 __shared__ _TempStorage private_storage;
215 return private_storage;
// Ascending-order overload of RankKeys, selected by an Int2Type<false> tag.
// Forwards to an AscendingBlockRadixRank constructed over
// temp_storage.asending_ranking_storage.
// NOTE(review): the parameters between `ranks` and `is_descending`, and the
// argument list of the forwarded call, are not visible in this excerpt.
219 __device__ __forceinline__
void RankKeys(
220 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
221 int (&ranks)[ITEMS_PER_THREAD],
224 Int2Type<false> is_descending)
// "asending" is spelled as in the member declaration; keep the two in sync.
226 AscendingBlockRadixRank(temp_storage.asending_ranking_storage).RankKeys(
// Descending-order overload of RankKeys, selected by an Int2Type<true> tag.
// Forwards to a DescendingBlockRadixRank constructed over
// temp_storage.descending_ranking_storage.
// NOTE(review): the parameters between `ranks` and `is_descending`, and the
// argument list of the forwarded call, are not visible in this excerpt.
234 __device__ __forceinline__
void RankKeys(
235 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD],
236 int (&ranks)[ITEMS_PER_THREAD],
239 Int2Type<true> is_descending)
241 DescendingBlockRadixRank(temp_storage.descending_ranking_storage).RankKeys(
// Key-value overload of ExchangeValues for a blocked result.
// Chosen when is_keys_only is Int2Type<false> and is_blocked is Int2Type<true>.
// Scatters `values` according to `ranks` via BlockExchangeValues::ScatterToBlocked,
// using temp_storage.exchange_values.
// NOTE(review): lines between the signature and the call are not visible in
// this excerpt.
249 __device__ __forceinline__
void ExchangeValues(
250 ValueT (&values)[ITEMS_PER_THREAD],
251 int (&ranks)[ITEMS_PER_THREAD],
252 Int2Type<false> is_keys_only,
253 Int2Type<true> is_blocked)
258 BlockExchangeValues(temp_storage.exchange_values).ScatterToBlocked(values, ranks);
// Key-value overload of ExchangeValues for a striped result.
// Chosen when both is_keys_only and is_blocked are Int2Type<false>.
// Scatters `values` according to `ranks` via BlockExchangeValues::ScatterToStriped,
// using temp_storage.exchange_values.
// NOTE(review): lines between the signature and the call are not visible in
// this excerpt.
262 __device__ __forceinline__
void ExchangeValues(
263 ValueT (&values)[ITEMS_PER_THREAD],
264 int (&ranks)[ITEMS_PER_THREAD],
265 Int2Type<false> is_keys_only,
266 Int2Type<false> is_blocked)
271 BlockExchangeValues(temp_storage.exchange_values).ScatterToStriped(values, ranks);
// Keys-only overload of ExchangeValues (is_keys_only is Int2Type<true>).
// Templated on IS_BLOCKED so it accepts either arrangement tag.
// NOTE(review): the body is not visible in this excerpt; presumably it does
// nothing because there are no values to move -- confirm against the full file.
275 template <
int IS_BLOCKED>
276 __device__ __forceinline__
void ExchangeValues(
277 ValueT (&values)[ITEMS_PER_THREAD],
278 int (&ranks)[ITEMS_PER_THREAD],
279 Int2Type<true> is_keys_only,
280 Int2Type<IS_BLOCKED> is_blocked)
// Shared implementation behind the public entry points that call SortBlocked.
// DESCENDING and KEYS_ONLY arrive as Int2Type tags and pick the RankKeys and
// ExchangeValues overloads at compile time.
// NOTE(review): the begin_bit / end_bit parameters, the braces and the loop
// header are not visible in this excerpt.
284 template <
int DESCENDING,
int KEYS_ONLY>
285 __device__ __forceinline__
void SortBlocked(
286 KeyT (&keys)[ITEMS_PER_THREAD],
287 ValueT (&values)[ITEMS_PER_THREAD],
290 Int2Type<DESCENDING> is_descending,
291 Int2Type<KEYS_ONLY> is_keys_only)
// View the caller's key array in place as UnsignedBits; writes through
// unsigned_keys therefore modify `keys` itself.
293 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
294 reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]
>(keys);
// Apply KeyTraits::TwiddleIn to every key before ranking.
298 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
300 unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
// Bits handled in this pass: RADIX_BITS, or fewer if less than that remain.
306 int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
// Rank the keys on the current digit, then advance to the next digit.
309 int ranks[ITEMS_PER_THREAD];
310 RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
311 begin_bit += RADIX_BITS;
// Move keys to their ranked positions, blocked arrangement.
316 BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
// Move values the same way; resolves to the keys-only overload when
// is_keys_only is Int2Type<true>.
319 ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());
// Stop once every requested bit has been processed.
322 if (begin_bit >= end_bit)
break;
// Apply KeyTraits::TwiddleOut to every key after the passes
// (presumably the inverse of TwiddleIn -- confirm in KeyTraits).
329 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
331 unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
337 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
// Shared implementation behind the public entry points that call
// SortBlockedToStriped. Same pass structure as SortBlocked, except that one
// branch scatters to a striped arrangement instead of a blocked one.
// NOTE(review): the begin_bit / end_bit parameters, the braces, the loop
// header and any break statement are not visible in this excerpt.
340 template <
int DESCENDING,
int KEYS_ONLY>
341 __device__ __forceinline__
void SortBlockedToStriped(
342 KeyT (&keys)[ITEMS_PER_THREAD],
343 ValueT (&values)[ITEMS_PER_THREAD],
346 Int2Type<DESCENDING> is_descending,
347 Int2Type<KEYS_ONLY> is_keys_only)
// View the caller's key array in place as UnsignedBits; writes through
// unsigned_keys therefore modify `keys` itself.
349 UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD] =
350 reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]
>(keys);
// Apply KeyTraits::TwiddleIn to every key before ranking.
354 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
356 unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]);
// Bits handled in this pass: RADIX_BITS, or fewer if less than that remain.
362 int pass_bits = CUB_MIN(RADIX_BITS, end_bit - begin_bit);
// Rank the keys on the current digit, then advance to the next digit.
365 int ranks[ITEMS_PER_THREAD];
366 RankKeys(unsigned_keys, ranks, begin_bit, pass_bits, is_descending);
367 begin_bit += RADIX_BITS;
// All requested bits processed: scatter keys and values to striped.
// NOTE(review): presumably this is the final-pass branch and the blocked
// scatter below is the alternative -- the braces/else are not visible here.
372 if (begin_bit >= end_bit)
375 BlockExchangeKeys(temp_storage.exchange_keys).ScatterToStriped(keys, ranks);
378 ExchangeValues(values, ranks, is_keys_only, Int2Type<false>());
// Scatter keys and values to a blocked arrangement.
385 BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
388 ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());
// Apply KeyTraits::TwiddleOut to every key after the passes
// (presumably the inverse of TwiddleIn -- confirm in KeyTraits).
395 for (
int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
397 unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]);
401 #endif // DOXYGEN_SHOULD_SKIP_THIS
// Constructor initializer list: binds temp_storage to the __shared__ object
// returned by PrivateStorage() and sets linear_tid from RowMajorTid over the
// three block dimensions.
// NOTE(review): the constructor signature is not visible in this excerpt.
417 temp_storage(PrivateStorage()),
418 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// Constructor initializer list: binds temp_storage to the result of
// temp_storage.Alias() and sets linear_tid from RowMajorTid over the three
// block dimensions.
// NOTE(review): the constructor signature is not visible; the inner
// `temp_storage` presumably names a constructor parameter that shadows the
// member -- confirm against the full file.
428 temp_storage(temp_storage.Alias()),
429 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// Keys-only sort: forwards to SortBlocked with an Int2Type<false> descending
// tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the begin_bit parameter is passed on below but its
// declaration is not visible in this excerpt.
476 __device__ __forceinline__
void Sort(
477 KeyT (&keys)[ITEMS_PER_THREAD],
479 int end_bit =
sizeof(KeyT) * 8)
// Placeholder array so the shared SortBlocked signature can be used.
481 NullType values[ITEMS_PER_THREAD];
// KEYS_ONLY is defined outside this excerpt.
483 SortBlocked(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
// Key-value sort: forwards keys and values to SortBlocked with an
// Int2Type<false> descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the begin_bit parameter is passed on below but its
// declaration is not visible in this excerpt.
531 __device__ __forceinline__
void Sort(
532 KeyT (&keys)[ITEMS_PER_THREAD],
533 ValueT (&values)[ITEMS_PER_THREAD],
535 int end_bit =
sizeof(KeyT) * 8)
// KEYS_ONLY is defined outside this excerpt.
537 SortBlocked(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
// Keys-only entry point that forwards to SortBlocked with an Int2Type<true>
// descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the function name, qualifiers and begin_bit parameter are
// not visible in this excerpt.
578 KeyT (&keys)[ITEMS_PER_THREAD],
580 int end_bit =
sizeof(KeyT) * 8)
// Placeholder array so the shared SortBlocked signature can be used.
582 NullType values[ITEMS_PER_THREAD];
584 SortBlocked(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
// Key-value entry point that forwards to SortBlocked with an Int2Type<true>
// descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the function name, qualifiers and begin_bit parameter are
// not visible in this excerpt.
633 KeyT (&keys)[ITEMS_PER_THREAD],
634 ValueT (&values)[ITEMS_PER_THREAD],
636 int end_bit =
sizeof(KeyT) * 8)
638 SortBlocked(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
// Keys-only entry point that forwards to the tag-dispatched
// SortBlockedToStriped with an Int2Type<false> descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the function name, qualifiers and begin_bit parameter are
// not visible in this excerpt.
688 KeyT (&keys)[ITEMS_PER_THREAD],
690 int end_bit =
sizeof(KeyT) * 8)
// Placeholder array so the shared implementation's signature can be used.
692 NullType values[ITEMS_PER_THREAD];
694 SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
// Key-value entry point that forwards to the tag-dispatched
// SortBlockedToStriped with an Int2Type<false> descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the function name, qualifiers and begin_bit parameter are
// not visible in this excerpt.
743 KeyT (&keys)[ITEMS_PER_THREAD],
744 ValueT (&values)[ITEMS_PER_THREAD],
746 int end_bit =
sizeof(KeyT) * 8)
748 SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<false>(), Int2Type<KEYS_ONLY>());
// Keys-only entry point that forwards to the tag-dispatched
// SortBlockedToStriped with an Int2Type<true> descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the function name, qualifiers and begin_bit parameter are
// not visible in this excerpt.
791 KeyT (&keys)[ITEMS_PER_THREAD],
793 int end_bit =
sizeof(KeyT) * 8)
// Placeholder array so the shared implementation's signature can be used.
795 NullType values[ITEMS_PER_THREAD];
797 SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());
// Key-value entry point that forwards to the tag-dispatched
// SortBlockedToStriped with an Int2Type<true> descending tag.
// end_bit defaults to the bit width of KeyT (sizeof(KeyT) * 8).
// NOTE(review): the function name, qualifiers and begin_bit parameter are
// not visible in this excerpt.
846 KeyT (&keys)[ITEMS_PER_THREAD],
847 ValueT (&values)[ITEMS_PER_THREAD],
849 int end_bit =
sizeof(KeyT) * 8)
851 SortBlockedToStriped(keys, values, begin_bit, end_bit, Int2Type<true>(), Int2Type<KEYS_ONLY>());