CUB
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups
warp_reduce.cuh
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2015, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
34 #pragma once
35 
36 #include "specializations/warp_reduce_shfl.cuh"
37 #include "specializations/warp_reduce_smem.cuh"
38 #include "../thread/thread_operators.cuh"
39 #include "../util_arch.cuh"
40 #include "../util_type.cuh"
41 #include "../util_namespace.cuh"
42 
44 CUB_NS_PREFIX
45 
47 namespace cub {
48 
49 
/**
 * \brief WarpReduce provides collective methods for computing a parallel
 *        reduction of items partitioned across a (logical) CUDA thread warp.
 *
 * \tparam T                     The reduction input/output element type.
 * \tparam LOGICAL_WARP_THREADS  <b>[optional]</b> Number of threads per logical warp
 *                               (default: CUB_PTX_WARP_THREADS).
 * \tparam PTX_ARCH              <b>[optional]</b> PTX compute capability this collective is
 *                               specialized for (default: CUB_PTX_ARCH).
 *
 * \par Implementation selection
 * When PTX_ARCH >= 300 and LOGICAL_WARP_THREADS is a power of two, the
 * WarpReduceShfl specialization is used; otherwise WarpReduceSmem is used.
 * All reduction methods below forward to that InternalWarpReduce type.
 */
template <
    typename    T,
    int         LOGICAL_WARP_THREADS    = CUB_PTX_WARP_THREADS,
    int         PTX_ARCH                = CUB_PTX_ARCH>
class WarpReduce
{
private:

    /******************************************************************************
     * Constants and type definitions
     ******************************************************************************/

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// Whether the logical warp size is a power of two; gates selection of
        /// the shuffle-based specialization below. Computed inline so this
        /// class does not depend on any additional helper trait.
        IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0)
    };

public:

    #ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document

    /// Internal specialization: shuffle-based on SM30+ with power-of-two
    /// logical warps, shared-memory-based otherwise.
    typedef typename If<(PTX_ARCH >= 300) && (IS_POW_OF_TWO),
        WarpReduceShfl<T, LOGICAL_WARP_THREADS, PTX_ARCH>,
        WarpReduceSmem<T, LOGICAL_WARP_THREADS, PTX_ARCH> >::Type InternalWarpReduce;

    #endif // DOXYGEN_SHOULD_SKIP_THIS


private:

    /// Storage layout required by the selected internal specialization
    typedef typename InternalWarpReduce::TempStorage _TempStorage;


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Reference to the caller-provided temporary storage (bound in the constructor)
    _TempStorage &temp_storage;


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

public:

    /// Opaque temporary-storage type the caller must provide to the constructor.
    /// Wrapped in Uninitialized<> so it can be declared without running
    /// _TempStorage's constructors.
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     *
     * \param[in] temp_storage  Reference to memory allocation having layout type TempStorage.
     *                          Only a reference is kept (via Alias()); the caller owns the storage.
     */
    __device__ __forceinline__ WarpReduce(
        TempStorage &temp_storage)
    :
        temp_storage(temp_storage.Alias())
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Summation reductions
     *********************************************************************/
    //@{

    /**
     * \brief Computes a warp-wide sum over all LOGICAL_WARP_THREADS lanes.
     *
     * \param[in] input  Calling thread's input.
     * \return The value returned by InternalWarpReduce::Reduce with cub::Sum().
     *         NOTE(review): per CUB convention only lane 0 is presumably guaranteed
     *         to hold the aggregate; that contract lives in InternalWarpReduce — confirm.
     */
    __device__ __forceinline__ T Sum(
        T input)
    {
        // InternalWarpReduce is a dependent type, so the member-template call
        // needs the `template` disambiguator (otherwise `<` parses as less-than).
        // Template args: <true> = all lanes valid, <1> = presumably one folded
        // item per lane (see InternalWarpReduce::Reduce).
        return InternalWarpReduce(temp_storage).template Reduce<true, 1>(input, LOGICAL_WARP_THREADS, cub::Sum());
    }

    /**
     * \brief Computes a partially-full warp-wide sum over the first \p valid_items lanes.
     *
     * \param[in] input        Calling thread's input.
     * \param[in] valid_items  Number of lanes holding valid input (presumably the
     *                         first \p valid_items lanes — confirm in InternalWarpReduce).
     * \return The value returned by InternalWarpReduce::Reduce with cub::Sum().
     */
    __device__ __forceinline__ T Sum(
        T input,
        int valid_items)
    {
        // <false> selects the partial-warp (bounds-checked) reduction path; the
        // caller-supplied valid_items bounds which lanes contribute.
        return InternalWarpReduce(temp_storage).template Reduce<false, 1>(input, valid_items, cub::Sum());
    }


    /**
     * \brief Computes a segmented sum where segments are delimited by head flags.
     *
     * \tparam FlagT  Head-flag type.
     * \param[in] input      Calling thread's input.
     * \param[in] head_flag  Presumably nonzero when this lane starts a new segment — confirm.
     * \return Result of HeadSegmentedReduce(input, head_flag, cub::Sum()).
     */
    template <
        typename FlagT>
    __device__ __forceinline__ T HeadSegmentedSum(
        T input,
        FlagT head_flag)
    {
        return HeadSegmentedReduce(input, head_flag, cub::Sum());
    }


    /**
     * \brief Computes a segmented sum where segments are delimited by tail flags.
     *
     * \tparam FlagT  Tail-flag type.
     * \param[in] input      Calling thread's input.
     * \param[in] tail_flag  Presumably nonzero when this lane ends a segment — confirm.
     * \return Result of TailSegmentedReduce(input, tail_flag, cub::Sum()).
     */
    template <
        typename FlagT>
    __device__ __forceinline__ T TailSegmentedSum(
        T input,
        FlagT tail_flag)
    {
        return TailSegmentedReduce(input, tail_flag, cub::Sum());
    }



    //@}  end member group
    /******************************************************************//**
     * \name Generic reductions
     *********************************************************************/
    //@{

    /**
     * \brief Computes a warp-wide reduction over all LOGICAL_WARP_THREADS lanes
     *        using the supplied binary reduction functor.
     *
     * \tparam ReductionOp  Binary reduction functor type.
     * \param[in] input         Calling thread's input.
     * \param[in] reduction_op  Binary reduction operator.
     * \return The value returned by InternalWarpReduce::Reduce.
     *         NOTE(review): lane(s) holding the valid aggregate are determined by
     *         InternalWarpReduce — presumably lane 0; confirm.
     */
    template <typename ReductionOp>
    __device__ __forceinline__ T Reduce(
        T input,
        ReductionOp reduction_op)
    {
        // `template` disambiguator required: InternalWarpReduce is a dependent type.
        return InternalWarpReduce(temp_storage).template Reduce<true, 1>(input, LOGICAL_WARP_THREADS, reduction_op);
    }

    /**
     * \brief Computes a partially-full warp-wide reduction over \p valid_items lanes
     *        using the supplied binary reduction functor.
     *
     * \tparam ReductionOp  Binary reduction functor type.
     * \param[in] input         Calling thread's input.
     * \param[in] reduction_op  Binary reduction operator.
     * \param[in] valid_items   Number of lanes holding valid input.
     * \return The value returned by InternalWarpReduce::Reduce.
     */
    template <typename ReductionOp>
    __device__ __forceinline__ T Reduce(
        T input,
        ReductionOp reduction_op,
        int valid_items)
    {
        // <false> selects the partial-warp (bounds-checked) path.
        return InternalWarpReduce(temp_storage).template Reduce<false, 1>(input, valid_items, reduction_op);
    }


    /**
     * \brief Computes a segmented reduction where segments are delimited by head flags.
     *
     * \tparam ReductionOp  Binary reduction functor type.
     * \tparam FlagT        Head-flag type.
     * \param[in] input         Calling thread's input.
     * \param[in] head_flag     Presumably nonzero when this lane starts a new segment — confirm.
     * \param[in] reduction_op  Binary reduction operator.
     * \return The value returned by InternalWarpReduce::SegmentedReduce<true>.
     */
    template <
        typename ReductionOp,
        typename FlagT>
    __device__ __forceinline__ T HeadSegmentedReduce(
        T input,
        FlagT head_flag,
        ReductionOp reduction_op)
    {
        // <true> = flags are head flags (inferred from the Tail variant passing <false>).
        return InternalWarpReduce(temp_storage).template SegmentedReduce<true>(input, head_flag, reduction_op);
    }


    /**
     * \brief Computes a segmented reduction where segments are delimited by tail flags.
     *
     * \tparam ReductionOp  Binary reduction functor type.
     * \tparam FlagT        Tail-flag type.
     * \param[in] input         Calling thread's input.
     * \param[in] tail_flag     Presumably nonzero when this lane ends a segment — confirm.
     * \param[in] reduction_op  Binary reduction operator.
     * \return The value returned by InternalWarpReduce::SegmentedReduce<false>.
     */
    template <
        typename ReductionOp,
        typename FlagT>
    __device__ __forceinline__ T TailSegmentedReduce(
        T input,
        FlagT tail_flag,
        ReductionOp reduction_op)
    {
        // <false> = flags are tail flags.
        return InternalWarpReduce(temp_storage).template SegmentedReduce<false>(input, tail_flag, reduction_op);
    }



    //@}  end member group
};
608  // end group WarpModule
610 
611 } // CUB namespace
612 CUB_NS_POSTFIX // Optional outer namespace(s)