5 #ifndef FML_GPU_ARCH_CUDA_GPUBLAS_H
6 #define FML_GPU_ARCH_CUDA_GPUBLAS_H
// Translate a cuBLAS status code into a short human-readable description.
// check_ret uses this to build the text of the exception it throws.
//
// @param check  status code returned by a cuBLAS call.
// @return       a fixed description for each status tested below; the final
//               "unknown cuBLAS error occurred" return appears to be the
//               fallback for any status not matched by the earlier branches.
19 inline std::string get_cublas_error_msg(cublasStatus_t check)
21 if (check == CUBLAS_STATUS_SUCCESS)
// NOTE(review): the statement executed for CUBLAS_STATUS_SUCCESS is not
// visible in this listing. Confirm what this branch returns.
23 else if (check == CUBLAS_STATUS_NOT_INITIALIZED)
24 return "cuBLAS not initialized";
25 else if (check == CUBLAS_STATUS_ALLOC_FAILED)
26 return "internal cuBLAS memory allocation failed";
27 else if (check == CUBLAS_STATUS_INVALID_VALUE)
28 return "unsupported parameter";
29 else if (check == CUBLAS_STATUS_ARCH_MISMATCH)
30 return "function requires feature missing from device architecture";
31 else if (check == CUBLAS_STATUS_MAPPING_ERROR)
32 return "access to GPU memory space failed";
33 else if (check == CUBLAS_STATUS_EXECUTION_FAILED)
34 return "GPU program failed to execute";
35 else if (check == CUBLAS_STATUS_INTERNAL_ERROR)
36 return "internal cuBLAS operation failed";
37 else if (check == CUBLAS_STATUS_NOT_SUPPORTED)
38 return "requested functionality is not supported";
39 else if (check == CUBLAS_STATUS_LICENSE_ERROR)
40 return "error with cuBLAS license check";
// Fallback for status codes not listed above (e.g. ones added by newer cuBLAS
// releases).
42 return "unknown cuBLAS error occurred";
45 inline void check_ret(cublasStatus_t check, std::string op)
47 if (check != CUBLAS_STATUS_SUCCESS)
49 std::string msg =
"cuBLAS " + op +
"() failed with error: " + get_cublas_error_msg(check);
50 throw std::runtime_error(msg);
57 inline cublasStatus_t set_math_mode(cublasHandle_t handle, cublasMath_t mode)
59 return cublasSetMathMode(handle, mode);
62 inline cublasStatus_t get_math_mode(cublasHandle_t handle, cublasMath_t *mode)
64 return cublasGetMathMode(handle, mode);
// Describe the cuBLAS math mode currently configured on `handle` as a string.
//
// Queries the mode with get_math_mode and throws (through err::check_ret) if
// that query fails. The mode is then compared against the known cublasMath_t
// values, and the matching description is assigned to `ret`.
// @throws std::runtime_error if the query fails. The throw at the end appears
//         to fire when no branch matches (the enclosing `else` is not visible).
// NOTE(review): the declarations of `mode` and `ret`, the strings assigned for
// the DEFAULT / PEDANTIC / TENSOR_OP branches, the `#endif` and the final
// return are not visible in this listing. Confirm them against the full file.
67 inline std::string get_math_mode_string(cublasHandle_t handle)
70 cublasStatus_t check = get_math_mode(handle, &mode);
71 err::check_ret(check,
"cublasGetMathMode");
74 if (mode == CUBLAS_DEFAULT_MATH)
// NOTE(review): __CUDACC_VER_MAJOR__ is defined only when nvcc compiles this
// header. A host compiler treats it as 0 and drops these branches even
// against CUDA 11 headers. Consider CUBLAS_VER_MAJOR instead — confirm.
76 #if __CUDACC_VER_MAJOR__ >= 11
77 else if (mode == CUBLAS_PEDANTIC_MATH)
79 else if (mode == CUBLAS_TF32_TENSOR_OP_MATH)
80 ret =
"TF32 tensor op";
// NOTE(review): cuBLAS documents DISALLOW_REDUCED_PRECISION_REDUCTION as a flag
// that may be OR-ed with another mode. An exact `==` test will not recognise
// combined values, which would then reach the throw below — verify.
81 else if (mode == CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION)
82 ret =
"disallow reduced precision";
84 else if (mode == CUBLAS_TENSOR_OP_MATH)
88 throw std::runtime_error(
"unable to determine cuBLAS math mode");
95 inline cublasStatus_t Iamax(cublasHandle_t handle,
int n,
const float *x,
96 int incx,
int *result)
98 return cublasIsamax(handle, n, x, incx, result);
101 inline cublasStatus_t Iamax(cublasHandle_t handle,
int n,
const double *x,
102 int incx,
int *result)
104 return cublasIdamax(handle, n, x, incx, result);
109 inline cublasStatus_t Iamin(cublasHandle_t handle,
int n,
const float *x,
110 int incx,
int *result)
112 return cublasIsamin(handle, n, x, incx, result);
115 inline cublasStatus_t Iamin(cublasHandle_t handle,
int n,
const double *x,
116 int incx,
int *result)
118 return cublasIdamin(handle, n, x, incx, result);
123 inline cublasStatus_t
dot(cublasHandle_t handle,
int n,
const float *x,
124 int incx,
const float *y,
int incy,
float *result)
126 return cublasSdot(handle, n, x, incx, y, incy, result);
129 inline cublasStatus_t
dot(cublasHandle_t handle,
int n,
const double *x,
130 int incx,
const double *y,
int incy,
double *result)
132 return cublasDdot(handle, n, x, incx, y, incy, result);
137 inline cublasStatus_t axpy(cublasHandle_t handle,
int n,
const float *alpha,
138 const float *x,
int incx,
float *y,
int incy)
140 return cublasSaxpy(handle, n, alpha, x, incx, y, incy);
143 inline cublasStatus_t axpy(cublasHandle_t handle,
int n,
const double *alpha,
144 const double *x,
int incx,
double *y,
int incy)
146 return cublasDaxpy(handle, n, alpha, x, incx, y, incy);
151 inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
152 cublasOperation_t transb,
int m,
int n,
int k,
const __half alpha,
153 const __half *A,
int lda,
const __half *B,
int ldb,
const __half beta,
156 return cublasHgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
160 inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
161 cublasOperation_t transb,
int m,
int n,
int k,
const float alpha,
162 const float *A,
int lda,
const float *B,
int ldb,
const float beta,
165 return cublasSgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
169 inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
170 cublasOperation_t transb,
int m,
int n,
int k,
const double alpha,
171 const double *A,
int lda,
const double *B,
int ldb,
const double beta,
174 return cublasDgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
180 inline cublasStatus_t syrk(cublasHandle_t handle, cublasFillMode_t uplo,
181 cublasOperation_t trans,
int n,
int k,
const float alpha,
const float *A,
182 int lda,
const float beta,
float *C,
int ldc)
184 return cublasSsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
187 inline cublasStatus_t syrk(cublasHandle_t handle, cublasFillMode_t uplo,
188 cublasOperation_t trans,
int n,
int k,
const double alpha,
const double *A,
189 int lda,
const double beta,
double *C,
int ldc)
191 return cublasDsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
196 inline cublasStatus_t geam(cublasHandle_t handle, cublasOperation_t transa,
197 cublasOperation_t transb,
int m,
int n,
const float alpha,
const float *A,
198 int lda,
const float beta,
const float *B,
int ldb,
float *C,
int ldc)
200 return cublasSgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
204 inline cublasStatus_t geam(cublasHandle_t handle, cublasOperation_t transa,
205 cublasOperation_t transb,
int m,
int n,
const double alpha,
const double *A,
206 int lda,
const double beta,
const double *B,
int ldb,
double *C,
int ldc)
208 return cublasDgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
214 inline cublasStatus_t trsm(cublasHandle_t handle, cublasSideMode_t side,
215 cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag,
216 int m,
int n,
const float alpha,
const float *A,
int lda,
float *B,
219 return cublasStrsm(handle, side, uplo, trans, diag, m, n, &alpha, A, lda,
223 inline cublasStatus_t trsm(cublasHandle_t handle, cublasSideMode_t side,
224 cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag,
225 int m,
int n,
const double alpha,
const double *A,
int lda,
double *B,
228 return cublasDtrsm(handle, side, uplo, trans, diag, m, n, &alpha, A, lda,
234 inline cublasStatus_t trmm(cublasHandle_t handle, cublasSideMode_t side,
235 cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag,
236 int m,
int n,
const float alpha,
const float *A,
int lda,
const float *B,
237 int ldb,
float *C,
int ldc)
239 return cublasStrmm(handle, side, uplo, trans, diag, m, n, &alpha, A, lda,
243 inline cublasStatus_t trmm(cublasHandle_t handle, cublasSideMode_t side,
244 cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag,
245 int m,
int n,
const double alpha,
const double *A,
int lda,
const double *B,
246 int ldb,
double *C,
int ldc)
248 return cublasDtrmm(handle, side, uplo, trans, diag, m, n, &alpha, A, lda,