5 #ifndef FML_GPU_ARCH_CUDA_GPULAPACK_H
6 #define FML_GPU_ARCH_CUDA_GPULAPACK_H
11 #include <cusolverDn.h>
20 inline std::string get_cusolver_error_msg(cusolverStatus_t check)
24 if (check == CUSOLVER_STATUS_NOT_INITIALIZED)
25 return "cuSOLVER not initialized";
26 else if (check == CUSOLVER_STATUS_ALLOC_FAILED)
27 return "internal cuSOLVER memory allocation failed";
28 else if (check == CUSOLVER_STATUS_INVALID_VALUE)
29 return "unsupported parameter";
30 else if (check == CUSOLVER_STATUS_ARCH_MISMATCH)
31 return "function requires feature missing from device architecture";
32 else if (check == CUSOLVER_STATUS_EXECUTION_FAILED)
33 return "GPU program failed to execute";
34 else if (check == CUSOLVER_STATUS_INTERNAL_ERROR)
35 return "internal cuSOLVER operation failed";
36 else if (check == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
37 return "matrix type not supported";
39 return "unknown cuSOLVER error occurred";
42 inline void check_ret(cusolverStatus_t check, std::string op)
44 if (check != CUSOLVER_STATUS_SUCCESS)
46 std::string msg =
"cuSOLVER " + op +
"() failed with error: " + get_cusolver_error_msg(check);
47 throw std::runtime_error(msg);
54 inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle,
int m,
int n,
55 float *A,
int lda,
int *lwork)
57 return cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, lwork);
60 inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle,
int m,
int n,
61 double *A,
int lda,
int *lwork)
63 return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
66 inline cusolverStatus_t getrf(cusolverDnHandle_t handle,
int m,
int n,
67 float *A,
int lda,
float *work,
int *ipiv,
int *info)
69 return cusolverDnSgetrf(handle, m, n, A, lda, work, ipiv, info);
72 inline cusolverStatus_t getrf(cusolverDnHandle_t handle,
int m,
int n,
73 double *A,
int lda,
double *work,
int *ipiv,
int *info)
75 return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
80 inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle,
int m,
int n,
84 return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
87 inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle,
int m,
int n,
88 double *A,
int *lwork)
91 return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
94 inline cusolverStatus_t gesvd(cusolverDnHandle_t handle,
signed char jobu,
95 signed char jobvt,
const int m,
const int n,
float *A,
const int lda,
96 float *S,
float *U,
const int ldu,
float *VT,
const int ldvt,
float *work,
97 const int lwork,
float *rwork,
int *info)
99 return cusolverDnSgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
100 ldvt, work, lwork, rwork, info);
103 inline cusolverStatus_t gesvd(cusolverDnHandle_t handle,
signed char jobu,
104 signed char jobvt,
const int m,
const int n,
double *A,
const int lda,
105 double *S,
double *U,
const int ldu,
double *VT,
const int ldvt,
double *work,
106 const int lwork,
double *rwork,
int *info)
108 return cusolverDnDgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
109 ldvt, work, lwork, rwork, info);
114 inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
115 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
const float *A,
116 int lda,
const float *W,
int *lwork)
118 return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
122 inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
123 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
const double *A,
124 int lda,
const double *W,
int *lwork)
126 return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
130 inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
131 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
float *A,
int lda,
132 float *W,
float *work,
int lwork,
int *devInfo)
134 return cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
138 inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
139 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
double *A,
int lda,
140 double *W,
double *work,
int lwork,
int *devInfo)
142 return cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
148 inline cublasStatus_t getri_batched(cublasHandle_t handle,
const int n,
149 const float **Aarray,
const int lda,
const int *devIpiv,
float **Carray,
150 const int ldb,
int *info,
const int batchSize)
152 return cublasSgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
156 inline cublasStatus_t getri_batched(cublasHandle_t handle,
const int n,
157 const double **Aarray,
const int lda,
const int *devIpiv,
double **Carray,
158 const int ldb,
int *info,
const int batchSize)
160 return cublasDgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
166 inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
167 cublasOperation_t trans,
const int n,
const int nrhs,
const float *A,
168 const int lda,
const int *devIpiv,
float *B,
const int ldb,
int *info)
170 return cusolverDnSgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
174 inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
175 cublasOperation_t trans,
const int n,
const int nrhs,
const double *A,
176 const int lda,
const int *devIpiv,
double *B,
const int ldb,
int *info)
178 return cusolverDnDgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
184 inline cusolverStatus_t geqrf_buflen(cusolverDnHandle_t handle,
185 const int m,
const int n,
float *A,
const int lda,
int *lwork)
187 return cusolverDnSgeqrf_bufferSize(handle, m, n, A, lda, lwork);
190 inline cusolverStatus_t geqrf_buflen(cusolverDnHandle_t handle,
191 const int m,
const int n,
double *A,
const int lda,
int *lwork)
193 return cusolverDnDgeqrf_bufferSize(handle, m, n, A, lda, lwork);
196 inline cusolverStatus_t geqrf(cusolverDnHandle_t handle,
const int m,
197 const int n,
float *A,
const int lda,
float *tau,
float *work,
198 const int lwork,
int *info)
200 return cusolverDnSgeqrf(handle, m, n, A, lda, tau, work, lwork, info);
203 inline cusolverStatus_t geqrf(cusolverDnHandle_t handle,
const int m,
204 const int n,
double *A,
const int lda,
double *tau,
double *work,
205 const int lwork,
int *info)
207 return cusolverDnDgeqrf(handle, m, n, A, lda, tau, work, lwork, info);
212 inline cusolverStatus_t ormqr_buflen(cusolverDnHandle_t handle,
213 cublasSideMode_t side, cublasOperation_t trans,
const int m,
const int n,
214 const int k,
const float *A,
const int lda,
const float *tau,
215 const float *C,
const int ldc,
int *lwork)
217 return cusolverDnSormqr_bufferSize(handle, side, trans, m, n, k, A, lda,
221 inline cusolverStatus_t ormqr_buflen(cusolverDnHandle_t handle,
222 cublasSideMode_t side, cublasOperation_t trans,
const int m,
const int n,
223 const int k,
const double *A,
const int lda,
const double *tau,
224 const double *C,
const int ldc,
int *lwork)
226 return cusolverDnDormqr_bufferSize(handle, side, trans, m, n, k, A, lda,
230 inline cusolverStatus_t ormqr(cusolverDnHandle_t handle,
231 cublasSideMode_t side, cublasOperation_t trans,
const int m,
const int n,
232 const int k,
const float *A,
const int lda,
const float *tau,
float *C,
233 const int ldc,
float *work,
const int lwork,
int *info)
235 return cusolverDnSormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc,
239 inline cusolverStatus_t ormqr(cusolverDnHandle_t handle,
240 cublasSideMode_t side, cublasOperation_t trans,
const int m,
const int n,
241 const int k,
const double *A,
const int lda,
const double *tau,
double *C,
242 const int ldc,
double *work,
const int lwork,
int *info)
244 return cusolverDnDormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc,
250 inline cusolverStatus_t orgqr_buflen(cusolverDnHandle_t handle,
int m,
251 int n,
int k,
const float *A,
int lda,
const float *tau,
int *lwork)
253 return cusolverDnSorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork);
256 inline cusolverStatus_t orgqr_buflen(cusolverDnHandle_t handle,
int m,
257 int n,
int k,
const double *A,
int lda,
const double *tau,
int *lwork)
259 return cusolverDnDorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork);
262 inline cusolverStatus_t orgqr(cusolverDnHandle_t handle,
int m,
263 int n,
int k,
float *A,
int lda,
const float *tau,
float *work,
int lwork,
266 return cusolverDnSorgqr(handle, m, n, k, A, lda, tau, work, lwork, info);
269 inline cusolverStatus_t orgqr(cusolverDnHandle_t handle,
int m,
270 int n,
int k,
double *A,
int lda,
const double *tau,
double *work,
int lwork,
273 return cusolverDnDorgqr(handle, m, n, k, A, lda, tau, work, lwork, info);
278 inline cusolverStatus_t potrf_buflen(cusolverDnHandle_t handle,
279 cublasFillMode_t uplo,
int n,
float *A,
int lda,
int *lwork)
281 return cusolverDnSpotrf_bufferSize(handle, uplo, n, A, lda, lwork);
284 inline cusolverStatus_t potrf_buflen(cusolverDnHandle_t handle,
285 cublasFillMode_t uplo,
int n,
double *A,
int lda,
int *lwork)
287 return cusolverDnDpotrf_bufferSize(handle, uplo, n, A, lda, lwork);
290 inline cusolverStatus_t potrf(cusolverDnHandle_t handle,
291 cublasFillMode_t uplo,
const int n,
float *A,
const int lda,
float *work,
292 const int lwork,
int *info)
294 return cusolverDnSpotrf(handle, uplo, n, A, lda, work, lwork, info);
297 inline cusolverStatus_t potrf(cusolverDnHandle_t handle,
298 cublasFillMode_t uplo,
const int n,
double *A,
const int lda,
double *work,
299 const int lwork,
int *info)
301 return cusolverDnDpotrf(handle, uplo, n, A, lda, work, lwork, info);
306 inline cusolverStatus_t potri_buflen(cusolverDnHandle_t handle,
307 cublasFillMode_t uplo,
int n,
float *A,
int lda,
int *lwork)
309 return cusolverDnSpotri_bufferSize(handle, uplo, n, A, lda, lwork);
312 inline cusolverStatus_t potri_buflen(cusolverDnHandle_t handle,
313 cublasFillMode_t uplo,
int n,
double *A,
int lda,
int *lwork)
315 return cusolverDnDpotri_bufferSize(handle, uplo, n, A, lda, lwork);
318 inline cusolverStatus_t potri(cusolverDnHandle_t handle,
319 cublasFillMode_t uplo,
const int n,
float *A,
const int lda,
float *work,
320 const int lwork,
int *info)
322 return cusolverDnSpotri(handle, uplo, n, A, lda, work, lwork, info);
325 inline cusolverStatus_t potri(cusolverDnHandle_t handle,
326 cublasFillMode_t uplo,
const int n,
double *A,
const int lda,
double *work,
327 const int lwork,
int *info)
329 return cusolverDnDpotri(handle, uplo, n, A, lda, work, lwork, info);