CUB
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups
block_radix_sort.cuh
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2015, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
35 #pragma once
36 
37 #include "block_exchange.cuh"
38 #include "block_radix_rank.cuh"
39 #include "../util_ptx.cuh"
40 #include "../util_arch.cuh"
41 #include "../util_type.cuh"
42 #include "../util_namespace.cuh"
43 
45 CUB_NS_PREFIX
46 
48 namespace cub {
49 
119 template <
120  typename KeyT,
121  int BLOCK_DIM_X,
122  int ITEMS_PER_THREAD,
123  typename ValueT = NullType,
124  int RADIX_BITS = 4,
125  bool MEMOIZE_OUTER_SCAN = (CUB_PTX_ARCH >= 350) ? true : false,
126  BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
127  cudaSharedMemConfig SMEM_CONFIG = cudaSharedMemBankSizeFourByte,
128  int BLOCK_DIM_Y = 1,
129  int BLOCK_DIM_Z = 1,
130  int PTX_ARCH = CUB_PTX_ARCH>
132 {
133 private:
134 
135  /******************************************************************************
136  * Constants and type definitions
137  ******************************************************************************/
138 
/// Compile-time constants derived from the template parameters
enum
{
    // The thread block size in threads
    BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

    // Whether or not there are values to be trucked along with keys.
    // This enumerator is consumed as Int2Type<KEYS_ONLY> by every public
    // Sort* entry point; it is non-zero exactly when ValueT is NullType
    // (the default value type), i.e. when no values accompany the keys.
    KEYS_ONLY = Equals<ValueT, NullType>::VALUE
};
147 
/// Traits of the key type
typedef Traits<KeyT>                        KeyTraits;

/// Unsigned bit-level type published by the key traits; the sorting
/// routines reinterpret the key arrays as arrays of this type
typedef typename KeyTraits::UnsignedBits    UnsignedBits;
151 
/// Ranking utility instantiated by the ascending RankKeys overload
/// (third template argument is false)
typedef BlockRadixRank<
        BLOCK_DIM_X, RADIX_BITS, false,
        MEMOIZE_OUTER_SCAN, INNER_SCAN_ALGORITHM, SMEM_CONFIG,
        BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>
    AscendingBlockRadixRank;
164 
/// Ranking utility instantiated by the descending RankKeys overload
/// (third template argument is true)
typedef BlockRadixRank<
        BLOCK_DIM_X, RADIX_BITS, true,
        MEMOIZE_OUTER_SCAN, INNER_SCAN_ALGORITHM, SMEM_CONFIG,
        BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>
    DescendingBlockRadixRank;
177 
180 
183 
/// Storage layout used by the sorting passes.  Every member lives in the
/// same anonymous union, so the ranking scratch space and the key/value
/// exchange scratch space overlay one another; the sorting routines place
/// __syncthreads() barriers between the steps that reuse it.
struct _TempStorage
{
    union
    {
        // Scratch space handed to AscendingBlockRadixRank
        typename AscendingBlockRadixRank::TempStorage   asending_ranking_storage;

        // Scratch space handed to DescendingBlockRadixRank
        typename DescendingBlockRadixRank::TempStorage  descending_ranking_storage;

        // Scratch space handed to BlockExchangeKeys
        typename BlockExchangeKeys::TempStorage         exchange_keys;

        // Scratch space handed to BlockExchangeValues
        typename BlockExchangeValues::TempStorage       exchange_values;
    };
};
195 
196 
197  /******************************************************************************
198  * Thread fields
199  ******************************************************************************/
200 
/// Reference to the storage used by the ranking and exchange steps; bound
/// by the constructors either to PrivateStorage() or to the caller's
/// TempStorage allocation
_TempStorage    &temp_storage;

/// Value of RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z) captured at
/// construction time
int             linear_tid;
206 
207  /******************************************************************************
208  * Utility methods
209  ******************************************************************************/
210 
/// Provides the storage used by the default constructor: a function-scope
/// __shared__ _TempStorage object, returned by reference.
__device__ __forceinline__ _TempStorage& PrivateStorage()
{
    __shared__ _TempStorage block_private_storage;

    return block_private_storage;
}
217 
/// Ascending overload (selected by Int2Type<false>): computes a rank for
/// each key using AscendingBlockRadixRank.
///
/// \param unsigned_keys  Keys viewed as UnsignedBits
/// \param ranks          Receives one rank per key
/// \param begin_bit      First bit of the digit being ranked
/// \param pass_bits      Number of bits in the digit being ranked
/// \param is_descending  Tag used only for overload selection
__device__ __forceinline__ void RankKeys(
    UnsignedBits    (&unsigned_keys)[ITEMS_PER_THREAD],
    int             (&ranks)[ITEMS_PER_THREAD],
    int             begin_bit,
    int             pass_bits,
    Int2Type<false> is_descending)
{
    // Bind the ranking utility to its slot of the storage union
    AscendingBlockRadixRank ranker(temp_storage.asending_ranking_storage);

    ranker.RankKeys(unsigned_keys, ranks, begin_bit, pass_bits);
}
232 
/// Descending overload (selected by Int2Type<true>): computes a rank for
/// each key using DescendingBlockRadixRank.
///
/// \param unsigned_keys  Keys viewed as UnsignedBits
/// \param ranks          Receives one rank per key
/// \param begin_bit      First bit of the digit being ranked
/// \param pass_bits      Number of bits in the digit being ranked
/// \param is_descending  Tag used only for overload selection
__device__ __forceinline__ void RankKeys(
    UnsignedBits    (&unsigned_keys)[ITEMS_PER_THREAD],
    int             (&ranks)[ITEMS_PER_THREAD],
    int             begin_bit,
    int             pass_bits,
    Int2Type<true>  is_descending)
{
    // Bind the ranking utility to its slot of the storage union
    DescendingBlockRadixRank ranker(temp_storage.descending_ranking_storage);

    ranker.RankKeys(unsigned_keys, ranks, begin_bit, pass_bits);
}
247 
/// Key-value overload producing a blocked arrangement (selected by
/// is_keys_only = Int2Type<false>, is_blocked = Int2Type<true>): scatters
/// the values to the positions given by ranks.
///
/// \param values        Values to rearrange (updated in place)
/// \param ranks         Destination rank of each item
/// \param is_keys_only  Tag used only for overload selection
/// \param is_blocked    Tag used only for overload selection
__device__ __forceinline__ void ExchangeValues(
    ValueT          (&values)[ITEMS_PER_THREAD],
    int             (&ranks)[ITEMS_PER_THREAD],
    Int2Type<false> is_keys_only,
    Int2Type<true>  is_blocked)
{
    // Barrier before the exchange storage is used for the values
    __syncthreads();

    BlockExchangeValues value_exchange(temp_storage.exchange_values);
    value_exchange.ScatterToBlocked(values, ranks);
}
260 
/// Key-value overload producing a striped arrangement (selected by
/// is_keys_only = Int2Type<false>, is_blocked = Int2Type<false>): scatters
/// the values to the positions given by ranks.
///
/// \param values        Values to rearrange (updated in place)
/// \param ranks         Destination rank of each item
/// \param is_keys_only  Tag used only for overload selection
/// \param is_blocked    Tag used only for overload selection
__device__ __forceinline__ void ExchangeValues(
    ValueT          (&values)[ITEMS_PER_THREAD],
    int             (&ranks)[ITEMS_PER_THREAD],
    Int2Type<false> is_keys_only,
    Int2Type<false> is_blocked)
{
    // Barrier before the exchange storage is used for the values
    __syncthreads();

    // Exchange values through the exchange storage into a striped arrangement
    BlockExchangeValues value_exchange(temp_storage.exchange_values);
    value_exchange.ScatterToStriped(values, ranks);
}
273 
/// Keys-only overload (selected by is_keys_only = Int2Type<true>, for either
/// arrangement): there are no values to move, so this does nothing and, in
/// particular, issues no barrier.
template <int IS_BLOCKED>
__device__ __forceinline__ void ExchangeValues(
    ValueT                  (&values)[ITEMS_PER_THREAD],
    int                     (&ranks)[ITEMS_PER_THREAD],
    Int2Type<true>          is_keys_only,
    Int2Type<IS_BLOCKED>    is_blocked)
{
    // Intentionally empty
}
282 
/// Shared implementation of the blocked-to-blocked sorts.
///
/// Processes the key bits from begin_bit upward, RADIX_BITS at a time (the
/// last digit may be narrower), until begin_bit reaches end_bit.  Each pass
/// ranks the keys on the current digit and scatters keys (and values, unless
/// keys-only) to a blocked arrangement.  At least one pass is always run.
///
/// \param keys           Keys to sort (updated in place)
/// \param values         Values associated with the keys (updated in place)
/// \param begin_bit      First key bit taking part in the comparison
/// \param end_bit        One past the last key bit taking part
/// \param is_descending  Tag selecting the ascending/descending RankKeys overload
/// \param is_keys_only   Tag selecting the ExchangeValues overload
template <int DESCENDING, int KEYS_ONLY>
__device__ __forceinline__ void SortBlocked(
    KeyT                    (&keys)[ITEMS_PER_THREAD],
    ValueT                  (&values)[ITEMS_PER_THREAD],
    int                     begin_bit,
    int                     end_bit,
    Int2Type<DESCENDING>    is_descending,
    Int2Type<KEYS_ONLY>     is_keys_only)
{
    // Alias the caller's keys as UnsignedBits; no copy is made
    UnsignedBits (&key_bits)[ITEMS_PER_THREAD] =
        reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]>(keys);

    // Apply KeyTraits::TwiddleIn to every key before ranking
    // (presumably an order-preserving bit transform -- see KeyTraits)
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        key_bits[ITEM] = KeyTraits::TwiddleIn(key_bits[ITEM]);

    // One digit per iteration
    bool more_passes;
    do
    {
        // Width of this digit: RADIX_BITS, or whatever is left
        const int bits_left = end_bit - begin_bit;
        const int pass_bits = (bits_left < RADIX_BITS) ? bits_left : RADIX_BITS;

        // Rank the blocked keys on the current digit
        int ranks[ITEMS_PER_THREAD];
        RankKeys(key_bits, ranks, begin_bit, pass_bits, is_descending);
        begin_bit += RADIX_BITS;

        // Ranking and exchange share the same storage union
        __syncthreads();

        // Scatter keys, then values, into a blocked arrangement
        BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
        ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());

        // Only synchronize again if another pass will reuse the storage
        more_passes = (begin_bit < end_bit);
        if (more_passes)
            __syncthreads();
    }
    while (more_passes);

    // Undo the earlier transform with KeyTraits::TwiddleOut
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        key_bits[ITEM] = KeyTraits::TwiddleOut(key_bits[ITEM]);
}
334 
335 public:
336 
337 #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
338 
/// Shared implementation of the blocked-to-striped sorts.
///
/// Works like the blocked sort (one RADIX_BITS-wide digit per pass, starting
/// at begin_bit), except that the final pass scatters keys (and values,
/// unless keys-only) to a striped arrangement; every earlier pass scatters
/// to a blocked arrangement.  At least one pass is always run.
///
/// \param keys           Keys to sort (updated in place)
/// \param values         Values associated with the keys (updated in place)
/// \param begin_bit      First key bit taking part in the comparison
/// \param end_bit        One past the last key bit taking part
/// \param is_descending  Tag selecting the ascending/descending RankKeys overload
/// \param is_keys_only   Tag selecting the ExchangeValues overload
template <int DESCENDING, int KEYS_ONLY>
__device__ __forceinline__ void SortBlockedToStriped(
    KeyT                    (&keys)[ITEMS_PER_THREAD],
    ValueT                  (&values)[ITEMS_PER_THREAD],
    int                     begin_bit,
    int                     end_bit,
    Int2Type<DESCENDING>    is_descending,
    Int2Type<KEYS_ONLY>     is_keys_only)
{
    // Alias the caller's keys as UnsignedBits; no copy is made
    UnsignedBits (&key_bits)[ITEMS_PER_THREAD] =
        reinterpret_cast<UnsignedBits (&)[ITEMS_PER_THREAD]>(keys);

    // Apply KeyTraits::TwiddleIn to every key before ranking
    // (presumably an order-preserving bit transform -- see KeyTraits)
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        key_bits[ITEM] = KeyTraits::TwiddleIn(key_bits[ITEM]);

    // One digit per iteration; the loop exits from the final pass
    for (;;)
    {
        // Width of this digit: RADIX_BITS, or whatever is left
        const int bits_left = end_bit - begin_bit;
        const int pass_bits = (bits_left < RADIX_BITS) ? bits_left : RADIX_BITS;

        // Rank the blocked keys on the current digit
        int ranks[ITEMS_PER_THREAD];
        RankKeys(key_bits, ranks, begin_bit, pass_bits, is_descending);
        begin_bit += RADIX_BITS;

        // Ranking and exchange share the same storage union
        __syncthreads();

        const bool is_last_pass = (begin_bit >= end_bit);
        if (!is_last_pass)
        {
            // Intermediate pass: keep everything in a blocked arrangement
            BlockExchangeKeys(temp_storage.exchange_keys).ScatterToBlocked(keys, ranks);
            ExchangeValues(values, ranks, is_keys_only, Int2Type<true>());

            // The next pass reuses the storage for ranking
            __syncthreads();
            continue;
        }

        // Final pass: leave keys and values in a striped arrangement
        BlockExchangeKeys(temp_storage.exchange_keys).ScatterToStriped(keys, ranks);
        ExchangeValues(values, ranks, is_keys_only, Int2Type<false>());
        break;
    }

    // Undo the earlier transform with KeyTraits::TwiddleOut
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        key_bits[ITEM] = KeyTraits::TwiddleOut(key_bits[ITEM]);
}
400 
401 #endif // DOXYGEN_SHOULD_SKIP_THIS
402 
/// Public storage type callers allocate and pass to the constructor; it
/// wraps _TempStorage in Uninitialized<> and is unwrapped via Alias().
struct TempStorage : Uninitialized<_TempStorage>
{
};
405 
406 
407  /******************************************************************/
411 
/// Constructs the sorter on top of the class's own storage: temp_storage
/// is bound to the __shared__ object returned by PrivateStorage().
__device__ __forceinline__ BlockRadixSort()
    : temp_storage(PrivateStorage())
    , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
420 
421 
/// Constructs the sorter on top of caller-provided storage.
///
/// \param temp_storage  Caller's TempStorage allocation; the member
///                      reference is bound to temp_storage.Alias()
__device__ __forceinline__ BlockRadixSort(TempStorage &temp_storage)
    : temp_storage(temp_storage.Alias())
    , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
{}
431 
432 
434  /******************************************************************/
438 
/// Ascending keys-only sort; input and output are both in a blocked
/// arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void Sort(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    // Stand-in value array so the shared implementation can be reused
    NullType            placeholder_values[ITEMS_PER_THREAD];

    Int2Type<false>     ascending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlocked(keys, placeholder_values, begin_bit, end_bit, ascending_tag, keys_only_tag);
}
485 
486 
/// Ascending key-value sort; input and output are both in a blocked
/// arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param values     Values associated with the keys (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void Sort(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    ValueT  (&values)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    Int2Type<false>     ascending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlocked(keys, values, begin_bit, end_bit, ascending_tag, keys_only_tag);
}
539 
/// Descending keys-only sort; input and output are both in a blocked
/// arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void SortDescending(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    // Stand-in value array so the shared implementation can be reused
    NullType            placeholder_values[ITEMS_PER_THREAD];

    Int2Type<true>      descending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlocked(keys, placeholder_values, begin_bit, end_bit, descending_tag, keys_only_tag);
}
586 
587 
/// Descending key-value sort; input and output are both in a blocked
/// arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param values     Values associated with the keys (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void SortDescending(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    ValueT  (&values)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    Int2Type<true>      descending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlocked(keys, values, begin_bit, end_bit, descending_tag, keys_only_tag);
}
640 
641 
643  /******************************************************************/
647 
648 
/// Ascending keys-only sort that takes keys in a blocked arrangement and
/// leaves them in a striped arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void SortBlockedToStriped(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    // Stand-in value array so the shared implementation can be reused
    NullType            placeholder_values[ITEMS_PER_THREAD];

    Int2Type<false>     ascending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlockedToStriped(keys, placeholder_values, begin_bit, end_bit, ascending_tag, keys_only_tag);
}
696 
697 
/// Ascending key-value sort that takes items in a blocked arrangement and
/// leaves them in a striped arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param values     Values associated with the keys (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void SortBlockedToStriped(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    ValueT  (&values)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    Int2Type<false>     ascending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlockedToStriped(keys, values, begin_bit, end_bit, ascending_tag, keys_only_tag);
}
750 
751 
/// Descending keys-only sort that takes keys in a blocked arrangement and
/// leaves them in a striped arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void SortDescendingBlockedToStriped(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    // Stand-in value array so the shared implementation can be reused
    NullType            placeholder_values[ITEMS_PER_THREAD];

    Int2Type<true>      descending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlockedToStriped(keys, placeholder_values, begin_bit, end_bit, descending_tag, keys_only_tag);
}
799 
800 
/// Descending key-value sort that takes items in a blocked arrangement and
/// leaves them in a striped arrangement.
///
/// \param keys       Keys to sort (updated in place)
/// \param values     Values associated with the keys (updated in place)
/// \param begin_bit  First key bit taking part in the comparison (default 0)
/// \param end_bit    One past the last key bit taking part
///                   (default sizeof(KeyT) * 8)
__device__ __forceinline__ void SortDescendingBlockedToStriped(
    KeyT    (&keys)[ITEMS_PER_THREAD],
    ValueT  (&values)[ITEMS_PER_THREAD],
    int     begin_bit = 0,
    int     end_bit   = sizeof(KeyT) * 8)
{
    Int2Type<true>      descending_tag;
    Int2Type<KEYS_ONLY> keys_only_tag;

    SortBlockedToStriped(keys, values, begin_bit, end_bit, descending_tag, keys_only_tag);
}
853 
854 
856 
857 };
858 
863 } // CUB namespace
864 CUB_NS_POSTFIX // Optional outer namespace(s)
865