5 #ifndef FML_GPU_ARCH_HIP_CUSOLVER_H
6 #define FML_GPU_ARCH_HIP_CUSOLVER_H
11 #include <rocsolver.h>
18 inline std::string get_rocsolver_error_msg(cusolverStatus_t check)
20 if (check == CUSOLVER_STATUS_SUCCESS)
22 else if (check == CUSOLVER_STATUS_NOT_INITIALIZED)
23 return "cuSOLVER not initialized";
24 else if (check == CUSOLVER_STATUS_ALLOC_FAILED)
25 return "internal cuSOLVER memory allocation failed";
26 else if (check == CUSOLVER_STATUS_INVALID_VALUE)
27 return "unsupported parameter";
28 else if (check == CUSOLVER_STATUS_ARCH_MISMATCH)
29 return "function requires feature missing from device architecture";
30 else if (check == CUSOLVER_STATUS_EXECUTION_FAILED)
31 return "GPU program failed to execute";
32 else if (check == CUSOLVER_STATUS_INTERNAL_ERROR)
33 return "internal cuSOLVER operation failed";
34 else if (check == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
35 return "matrix type not supported";
37 return "unknown cuSOLVER error occurred";
40 inline void check_gpusolver_ret(rocsolver_status check, std::string op)
42 if (check != CUSOLVER_STATUS_SUCCESS)
44 std::string msg =
"rocsolver " + op +
"() failed with error: " + get_rocsolver_error_msg(check);
45 throw std::runtime_error(msg);
52 inline rocsolver_status getrf_buflen(rocsolver_handle handle,
int m,
int n,
53 float *A,
int lda,
int *lwork)
55 return cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, lwork);
58 inline rocsolver_status getrf_buflen(rocsolver_handle handle,
int m,
int n,
59 double *A,
int lda,
int *lwork)
61 return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
64 inline rocsolver_status getrf(rocsolver_handle handle,
int m,
int n,
65 float *A,
int lda,
float *work,
int *ipiv,
int *info)
67 return cusolverDnSgetrf(handle, m, n, A, lda, work, ipiv, info);
70 inline rocsolver_status getrf(rocsolver_handle handle,
int m,
int n,
71 double *A,
int lda,
double *work,
int *ipiv,
int *info)
73 return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
78 inline rocsolver_status gesvd_buflen(rocsolver_handle handle,
int m,
int n,
82 return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
85 inline rocsolver_status gesvd_buflen(rocsolver_handle handle,
int m,
int n,
86 double *A,
int *lwork)
89 return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
92 inline rocsolver_status gesvd(rocsolver_handle handle,
signed char jobu,
93 signed char jobvt,
const int m,
const int n,
float *A,
const int lda,
94 float *S,
float *U,
const int ldu,
float *VT,
const int ldvt,
float *work,
95 const int lwork,
float *rwork,
int *info)
97 return cusolverDnSgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
98 ldvt, work, lwork, rwork, info);
101 inline rocsolver_status gesvd(rocsolver_handle handle,
signed char jobu,
102 signed char jobvt,
const int m,
const int n,
double *A,
const int lda,
103 double *S,
double *U,
const int ldu,
double *VT,
const int ldvt,
double *work,
104 const int lwork,
double *rwork,
int *info)
106 return cusolverDnDgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
107 ldvt, work, lwork, rwork, info);
// Workspace-size wrapper for the float syevd(): forwards its arguments to
// cusolverDnSsyevd_bufferSize, passing lwork through as the output pointer.
// NOTE(review): this header includes rocsolver.h and takes a
// rocsolver_handle, yet this body calls a cuSOLVER function. The parameters
// also use cuSOLVER/cuBLAS enum types (cusolverEigMode_t, cublasFillMode_t).
// This looks like an unfinished port. A rocSOLVER version presumably needs
// rocblas_evect / rocblas_fill arguments instead, which changes this
// interface. Confirm against the rocSOLVER API before porting.
112 inline rocsolver_status syevd_buflen(rocsolver_handle handle,
113 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
const float *A,
114 int lda,
const float *W,
int *lwork)
116 return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
// Workspace-size wrapper for the double syevd(): forwards its arguments to
// cusolverDnDsyevd_bufferSize, passing lwork through as the output pointer.
// NOTE(review): cuSOLVER call and cuSOLVER/cuBLAS enum parameter types in a
// header that targets rocSOLVER (rocsolver.h, rocsolver_handle). This looks
// like an unfinished port. Porting presumably requires rocblas_evect /
// rocblas_fill parameters; confirm against the rocSOLVER API.
120 inline rocsolver_status syevd_buflen(rocsolver_handle handle,
121 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
const double *A,
122 int lda,
const double *W,
int *lwork)
124 return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
// Symmetric eigendecomposition wrapper (float): forwards A, W, the caller's
// work buffer of length lwork, and devInfo to cusolverDnSsyevd.
// NOTE(review): this is a cuSOLVER call inside a rocSOLVER header, and the
// jobz/uplo parameters use cuSOLVER/cuBLAS enum types. A rocSOLVER port
// (presumably rocsolver_ssyevd) likely needs rocblas_evect / rocblas_fill and
// a different buffer layout. Confirm against the rocSOLVER API.
128 inline rocsolver_status syevd(rocsolver_handle handle,
129 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
float *A,
int lda,
130 float *W,
float *work,
int lwork,
int *devInfo)
132 return cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
// Symmetric eigendecomposition wrapper (double): forwards A, W, the caller's
// work buffer of length lwork, and devInfo to cusolverDnDsyevd.
// NOTE(review): this is a cuSOLVER call inside a rocSOLVER header, and the
// jobz/uplo parameters use cuSOLVER/cuBLAS enum types. A rocSOLVER port
// (presumably rocsolver_dsyevd) likely needs rocblas_evect / rocblas_fill and
// a different buffer layout. Confirm against the rocSOLVER API.
136 inline rocsolver_status syevd(rocsolver_handle handle,
137 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
double *A,
int lda,
138 double *W,
double *work,
int lwork,
int *devInfo)
140 return cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
// Batched getri wrapper (float): forwards the read-only input array Aarray,
// the pivots devIpiv, the output array Carray, info and batchSize to cuBLAS's
// cublasSgetriBatched.
// NOTE(review): the handle and return type are rocBLAS types, but the call is
// to cuBLAS, so this will not build against rocSOLVER. rocSOLVER's batched
// getri variants presumably differ (e.g. in-place vs out-of-place, pivot
// stride argument). Confirm the matching rocSOLVER routine before porting.
146 inline rocblas_status getri_batched(rocblas_handle handle,
const int n,
147 const float **Aarray,
const int lda,
const int *devIpiv,
float **Carray,
148 const int ldb,
int *info,
const int batchSize)
150 return cublasSgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
// Batched getri wrapper (double): forwards the read-only input array Aarray,
// the pivots devIpiv, the output array Carray, info and batchSize to cuBLAS's
// cublasDgetriBatched.
// NOTE(review): the handle and return type are rocBLAS types, but the call is
// to cuBLAS, so this will not build against rocSOLVER. Confirm the matching
// rocSOLVER batched getri routine and its argument list before porting.
154 inline rocblas_status getri_batched(rocblas_handle handle,
const int n,
155 const double **Aarray,
const int lda,
const int *devIpiv,
double **Carray,
156 const int ldb,
int *info,
const int batchSize)
158 return cublasDgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
// Solve with LU factors (float): forwards trans, the factored matrix A, its
// pivots devIpiv and the right-hand sides B (overwritten in place, since B is
// non-const) to rocsolver_sgetrs.
// NOTE(review): confirm this argument list matches rocsolver_sgetrs. Check
// whether it accepts an info pointer at all, and whether A must be a
// non-const pointer (this wrapper passes a const float *).
164 inline rocsolver_status getrs(rocsolver_handle handle,
165 rocblas_operation trans,
const int n,
const int nrhs,
const float *A,
166 const int lda,
const int *devIpiv,
float *B,
const int ldb,
int *info)
168 return rocsolver_sgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
172 inline rocsolver_status getrs(rocsolver_handle handle,
173 rocblas_operation trans,
const int n,
const int nrhs,
const double *A,
174 const int lda,
const int *devIpiv,
double *B,
const int ldb,
int *info)
176 return rocsolver_dgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,