5 #ifndef FML_GPU_ARCH_HIP_GPURAND_H
6 #define FML_GPU_ARCH_HIP_GPURAND_H
10 #include <hiprand.hpp>
17 static const hiprandRngType_t gen_type = HIPRAND_RNG_PSEUDO_MTGP32;
24 static inline void check_init(hiprandStatus_t st)
26 if (st != HIPRAND_STATUS_SUCCESS)
27 throw std::runtime_error(
"unable to initialize GPU generator");
30 static inline void check_seed_set(hiprandStatus_t st)
32 if (st != HIPRAND_STATUS_SUCCESS)
33 throw std::runtime_error(
"unable to set GPU generator seed");
36 static inline void check_generation(hiprandStatus_t st)
38 if (st != HIPRAND_STATUS_SUCCESS)
39 throw std::runtime_error(
"unable to utilize GPU generator");
47 static inline hiprandStatus_t gpu_rand_unif(hiprandGenerator_t generator,
float *outputPtr,
size_t num)
49 return hiprandGenerateUniform(generator, outputPtr, num);
52 static inline hiprandStatus_t gpu_rand_unif(hiprandGenerator_t generator,
double *outputPtr,
size_t num)
54 return hiprandGenerateUniformDouble(generator, outputPtr, num);
59 static inline hiprandStatus_t gpu_rand_norm(hiprandGenerator_t generator,
float *outputPtr,
size_t num,
float mean,
float stddev)
61 return hiprandGenerateNormal(generator, outputPtr, num, mean, stddev);
64 static inline hiprandStatus_t gpu_rand_norm(hiprandGenerator_t generator,
double *outputPtr,
size_t num,
double mean,
double stddev)
66 return hiprandGenerateNormalDouble(generator, outputPtr, num, mean, stddev);
72 template <
typename REAL>
73 inline void gen_runif(
const uint32_t seed,
const size_t len, REAL *x)
76 hiprandGenerator_t gen;
78 st = hiprandCreateGenerator(&gen, defs::gen_type);
81 st = hiprandSetPseudoRandomGeneratorSeed(gen, seed);
82 err::check_seed_set(st);
83 st = generics::gpu_rand_unif(gen, x, len);
84 err::check_generation(st);
86 hiprandDestroyGenerator(gen);
91 template <
typename REAL>
92 inline void gen_rnorm(
const uint32_t seed,
const REAL mean,
const REAL sd,
const size_t len, REAL *x)
95 hiprandGenerator_t gen;
97 st = hiprandCreateGenerator(&gen, defs::gen_type);
100 st = hiprandSetPseudoRandomGeneratorSeed(gen, seed);
101 err::check_seed_set(st);
102 st = generics::gpu_rand_norm(gen, x, len, mean, sd);
103 err::check_generation(st);
105 hiprandDestroyGenerator(gen);