5 #ifndef FML_GPU_ARCH_CUDA_GPURAND_H
6 #define FML_GPU_ARCH_CUDA_GPURAND_H
19 static const curandRngType_t gen_type = CURAND_RNG_PSEUDO_MTGP32;
26 static inline void check_init(curandStatus_t st)
28 if (st != CURAND_STATUS_SUCCESS)
29 throw std::runtime_error(
"unable to initialize GPU generator");
32 static inline void check_seed_set(curandStatus_t st)
34 if (st != CURAND_STATUS_SUCCESS)
35 throw std::runtime_error(
"unable to set GPU generator seed");
38 static inline void check_generation(curandStatus_t st)
40 if (st != CURAND_STATUS_SUCCESS)
41 throw std::runtime_error(
"unable to utilize GPU generator");
49 static inline curandStatus_t gpu_rand_unif(curandGenerator_t generator,
float *outputPtr,
size_t num)
51 return curandGenerateUniform(generator, outputPtr, num);
54 static inline curandStatus_t gpu_rand_unif(curandGenerator_t generator,
double *outputPtr,
size_t num)
56 return curandGenerateUniformDouble(generator, outputPtr, num);
61 static inline curandStatus_t gpu_rand_norm(curandGenerator_t generator,
float *outputPtr,
size_t num,
float mean,
float stddev)
63 return curandGenerateNormal(generator, outputPtr, num, mean, stddev);
66 static inline curandStatus_t gpu_rand_norm(curandGenerator_t generator,
double *outputPtr,
size_t num,
double mean,
double stddev)
68 return curandGenerateNormalDouble(generator, outputPtr, num, mean, stddev);
74 template <
typename REAL>
75 inline void gen_runif(
const uint32_t seed,
const size_t len, REAL *x)
78 curandGenerator_t gen;
80 st = curandCreateGenerator(&gen, defs::gen_type);
83 st = curandSetPseudoRandomGeneratorSeed(gen, seed);
84 err::check_seed_set(st);
85 st = generics::gpu_rand_unif(gen, x, len);
86 err::check_generation(st);
88 curandDestroyGenerator(gen);
93 template <
typename REAL>
94 inline void gen_rnorm(
const uint32_t seed,
const REAL mean,
const REAL sd,
const size_t len, REAL *x)
97 curandGenerator_t gen;
99 st = curandCreateGenerator(&gen, defs::gen_type);
102 st = curandSetPseudoRandomGeneratorSeed(gen, seed);
103 err::check_seed_set(st);
104 st = generics::gpu_rand_norm(gen, x, len, mean, sd);
105 err::check_generation(st);
107 curandDestroyGenerator(gen);