CUB
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups
warp_scan.cuh
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2015, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
34 #pragma once
35 
36 #include "specializations/warp_scan_shfl.cuh"
37 #include "specializations/warp_scan_smem.cuh"
38 #include "../thread/thread_operators.cuh"
39 #include "../util_arch.cuh"
40 #include "../util_type.cuh"
41 #include "../util_namespace.cuh"
42 
44 CUB_NS_PREFIX
45 
47 namespace cub {
48 
/**
 * \brief Warp-level scan front-end that forwards every operation to a
 *        back-end chosen at compile time.
 *
 * The back-end (\p InternalWarpScan) is WarpScanShfl when
 * <tt>PTX_ARCH >= 300</tt> and \p LOGICAL_WARP_THREADS is a power of two,
 * and WarpScanSmem otherwise.  Every public method below constructs a
 * back-end object from the bound temporary storage and delegates to the
 * back-end method of the matching name; the scan algorithms themselves are
 * implemented in the back-end, not in this class.
 *
 * \tparam T                     Element type passed to and returned from the scans.
 * \tparam LOGICAL_WARP_THREADS  Number of threads in a logical warp
 *                               (defaults to CUB_PTX_WARP_THREADS).
 * \tparam PTX_ARCH              PTX architecture value used to select the back-end
 *                               (defaults to CUB_PTX_ARCH).
 */
template <
    typename    T,
    int         LOGICAL_WARP_THREADS    = CUB_PTX_WARP_THREADS,
    int         PTX_ARCH                = CUB_PTX_ARCH>
class WarpScan
{
private:

    //---------------------------------------------------------------------
    // Compile-time constants and type selection
    //---------------------------------------------------------------------

    enum
    {
        /// True when the logical warp size equals CUB_WARP_THREADS(PTX_ARCH)
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// True when LOGICAL_WARP_THREADS has at most one bit set
        IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0),

        /// True when Traits<T>::CATEGORY is SIGNED_INTEGER or UNSIGNED_INTEGER
        /// (not read by any member visible in this class)
        IS_INTEGER = ((Traits<T>::CATEGORY == SIGNED_INTEGER) || (Traits<T>::CATEGORY == UNSIGNED_INTEGER))
    };

    /// Back-end implementation: WarpScanShfl for PTX_ARCH >= 300 with a
    /// power-of-two logical warp, WarpScanSmem in every other case
    typedef typename If<(PTX_ARCH >= 300) && (IS_POW_OF_TWO),
        WarpScanShfl<T, LOGICAL_WARP_THREADS, PTX_ARCH>,
        WarpScanSmem<T, LOGICAL_WARP_THREADS, PTX_ARCH> >::Type InternalWarpScan;

    /// Temporary-storage type required by the selected back-end
    typedef typename InternalWarpScan::TempStorage _TempStorage;


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    /// Reference to the caller-supplied temporary storage handed to the back-end
    _TempStorage &temp_storage;

    /// LaneId() for an architectural warp, otherwise LaneId() % LOGICAL_WARP_THREADS
    /// (not read by any method visible in this class)
    int lane_id;


public:

    //---------------------------------------------------------------------
    // Public types
    //---------------------------------------------------------------------

    /// Opaque temporary-storage type that callers allocate and pass to the constructor
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /**
     * \brief Binds this instance to \p temp_storage and records the calling
     *        thread's lane index.
     *
     * \param[in] temp_storage  Temporary storage; its Alias() is kept by reference.
     */
    __device__ __forceinline__ WarpScan(TempStorage &temp_storage)
    :
        temp_storage(temp_storage.Alias()),
        lane_id(IS_ARCH_WARP ? LaneId() : LaneId() % LOGICAL_WARP_THREADS)
    {}


    //---------------------------------------------------------------------
    // Inclusive prefix sums
    //---------------------------------------------------------------------

    /**
     * \brief Inclusive sum: forwards to the back-end's InclusiveScan with cub::Sum().
     *
     * \param[in]  input   Calling thread's input item.
     * \param[out] output  Calling thread's result (may alias a copy of \p input,
     *                     since \p input is taken by value).
     */
    __device__ __forceinline__ void InclusiveSum(T input, T &output)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.InclusiveScan(input, output, cub::Sum());
    }


    /**
     * \brief Inclusive sum that also reports the warp-wide aggregate:
     *        forwards to the back-end's InclusiveScan with cub::Sum().
     *
     * \param[in]  input           Calling thread's input item.
     * \param[out] output          Calling thread's result.
     * \param[out] warp_aggregate  Aggregate value as produced by the back-end.
     */
    __device__ __forceinline__ void InclusiveSum(T input, T &output, T &warp_aggregate)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.InclusiveScan(input, output, cub::Sum(), warp_aggregate);
    }


    //---------------------------------------------------------------------
    // Exclusive prefix sums
    //---------------------------------------------------------------------

    /**
     * \brief Exclusive sum: forwards to the back-end's ExclusiveScan using
     *        ZeroInitialize<T>() as the identity and cub::Sum() as the operator.
     *
     * \param[in]  input   Calling thread's input item.
     * \param[out] output  Calling thread's result.
     */
    __device__ __forceinline__ void ExclusiveSum(T input, T &output)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.ExclusiveScan(input, output, ZeroInitialize<T>(), cub::Sum());
    }


    /**
     * \brief Exclusive sum that also reports the warp-wide aggregate: forwards
     *        to the back-end's ExclusiveScan using ZeroInitialize<T>() as the
     *        identity and cub::Sum() as the operator.
     *
     * \param[in]  input           Calling thread's input item.
     * \param[out] output          Calling thread's result.
     * \param[out] warp_aggregate  Aggregate value as produced by the back-end.
     */
    __device__ __forceinline__ void ExclusiveSum(T input, T &output, T &warp_aggregate)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.ExclusiveScan(input, output, ZeroInitialize<T>(), cub::Sum(), warp_aggregate);
    }


    //---------------------------------------------------------------------
    // Inclusive prefix scans with a caller-supplied operator
    //---------------------------------------------------------------------

    /**
     * \brief Inclusive scan with a caller-supplied operator; forwards to the
     *        back-end's InclusiveScan.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input    Calling thread's input item.
     * \param[out] output   Calling thread's result.
     * \param[in]  scan_op  Scan functor, passed through to the back-end by value.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(T input, T &output, ScanOp scan_op)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.InclusiveScan(input, output, scan_op);
    }


    /**
     * \brief Inclusive scan with a caller-supplied operator that also reports
     *        the warp-wide aggregate; forwards to the back-end's InclusiveScan.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input           Calling thread's input item.
     * \param[out] output          Calling thread's result.
     * \param[in]  scan_op         Scan functor, passed through to the back-end by value.
     * \param[out] warp_aggregate  Aggregate value as produced by the back-end.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(T input, T &output, ScanOp scan_op, T &warp_aggregate)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.InclusiveScan(input, output, scan_op, warp_aggregate);
    }


    //---------------------------------------------------------------------
    // Exclusive prefix scans with an explicit identity
    //---------------------------------------------------------------------

    /**
     * \brief Exclusive scan with a caller-supplied operator and identity;
     *        forwards to the back-end's ExclusiveScan.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input     Calling thread's input item.
     * \param[out] output    Calling thread's result.
     * \param[in]  identity  Identity value handed to the back-end.
     * \param[in]  scan_op   Scan functor, passed through to the back-end by value.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(T input, T &output, T identity, ScanOp scan_op)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.ExclusiveScan(input, output, identity, scan_op);
    }


    /**
     * \brief Exclusive scan with a caller-supplied operator and identity that
     *        also reports the warp-wide aggregate; forwards to the back-end's
     *        ExclusiveScan.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input           Calling thread's input item.
     * \param[out] output          Calling thread's result.
     * \param[in]  identity        Identity value handed to the back-end.
     * \param[in]  scan_op         Scan functor, passed through to the back-end by value.
     * \param[out] warp_aggregate  Aggregate value as produced by the back-end.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(T input, T &output, T identity, ScanOp scan_op, T &warp_aggregate)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.ExclusiveScan(input, output, identity, scan_op, warp_aggregate);
    }


    //---------------------------------------------------------------------
    // Exclusive prefix scans without an identity
    //---------------------------------------------------------------------

    /**
     * \brief Exclusive scan with a caller-supplied operator and no identity;
     *        forwards to the back-end's identity-less ExclusiveScan.
     *
     * NOTE(review): with no identity supplied, the value written to \p output
     * for the first lane is decided by the back-end -- confirm against the
     * back-end before relying on it.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input    Calling thread's input item.
     * \param[out] output   Calling thread's result.
     * \param[in]  scan_op  Scan functor, passed through to the back-end by value.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(T input, T &output, ScanOp scan_op)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.ExclusiveScan(input, output, scan_op);
    }


    /**
     * \brief Exclusive scan with a caller-supplied operator and no identity
     *        that also reports the warp-wide aggregate; forwards to the
     *        back-end's identity-less ExclusiveScan.
     *
     * NOTE(review): with no identity supplied, the value written to \p output
     * for the first lane is decided by the back-end -- confirm against the
     * back-end before relying on it.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input           Calling thread's input item.
     * \param[out] output          Calling thread's result.
     * \param[in]  scan_op         Scan functor, passed through to the back-end by value.
     * \param[out] warp_aggregate  Aggregate value as produced by the back-end.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(T input, T &output, ScanOp scan_op, T &warp_aggregate)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.ExclusiveScan(input, output, scan_op, warp_aggregate);
    }


    //---------------------------------------------------------------------
    // Combined inclusive + exclusive operations
    //---------------------------------------------------------------------

    /**
     * \brief Produces inclusive and exclusive sums in one call: forwards to
     *        the back-end's Scan using ZeroInitialize<T>() as the identity and
     *        cub::Sum() as the operator.
     *
     * \param[in]  input             Calling thread's input item.
     * \param[out] inclusive_output  Calling thread's inclusive result.
     * \param[out] exclusive_output  Calling thread's exclusive result.
     */
    __device__ __forceinline__ void Sum(T input, T &inclusive_output, T &exclusive_output)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.Scan(input, inclusive_output, exclusive_output, ZeroInitialize<T>(), cub::Sum());
    }


    /**
     * \brief Produces inclusive and exclusive scan results in one call using a
     *        caller-supplied operator and identity; forwards to the back-end's Scan.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input             Calling thread's input item.
     * \param[out] inclusive_output  Calling thread's inclusive result.
     * \param[out] exclusive_output  Calling thread's exclusive result.
     * \param[in]  identity          Identity value handed to the back-end.
     * \param[in]  scan_op           Scan functor, passed through to the back-end by value.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void Scan(T input, T &inclusive_output, T &exclusive_output, T identity, ScanOp scan_op)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.Scan(input, inclusive_output, exclusive_output, identity, scan_op);
    }


    /**
     * \brief Produces inclusive and exclusive scan results in one call using a
     *        caller-supplied operator and no identity; forwards to the
     *        back-end's identity-less Scan.
     *
     * \tparam ScanOp  Type of the scan functor.
     *
     * \param[in]  input             Calling thread's input item.
     * \param[out] inclusive_output  Calling thread's inclusive result.
     * \param[out] exclusive_output  Calling thread's exclusive result.
     * \param[in]  scan_op           Scan functor, passed through to the back-end by value.
     */
    template <typename ScanOp>
    __device__ __forceinline__ void Scan(T input, T &inclusive_output, T &exclusive_output, ScanOp scan_op)
    {
        InternalWarpScan scanner(temp_storage);
        scanner.Scan(input, inclusive_output, exclusive_output, scan_op);
    }


    //---------------------------------------------------------------------
    // Data exchange
    //---------------------------------------------------------------------

    /**
     * \brief Forwards to the back-end's Broadcast and returns its result.
     *
     * \param[in] input     Value offered by the calling thread.
     * \param[in] src_lane  Lane index passed to the back-end as the source.
     *
     * \return The value returned by the back-end's Broadcast.
     */
    __device__ __forceinline__ T Broadcast(T input, unsigned int src_lane)
    {
        InternalWarpScan scanner(temp_storage);
        return scanner.Broadcast(input, src_lane);
    }

};
920  // end group WarpModule
922 
923 } // CUB namespace
924 CUB_NS_POSTFIX // Optional outer namespace(s)