// Include guard and cuSOLVER include, followed by the __half overload of gemm.
// The visible part of the overload forwards to cublasHgemm and passes the
// by-value scalar alpha by address (&alpha).
// NOTE(review): this text is a source listing with embedded line numbers; the
// tail of the parameter list, the braces, and the end of the cublasHgemm call
// (presumably &beta, C, ldc) are not visible here -- confirm against the
// original header before relying on it.
5 #ifndef FML_GPU_INTERNALS_CUDA_CUSOLVER_H 6 #define FML_GPU_INTERNALS_CUDA_CUSOLVER_H 11 #include <cusolverDn.h> 16 inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
17 cublasOperation_t transb,
int m,
int n,
int k,
const __half alpha,
18 const __half *A,
int lda,
const __half *B,
int ldb,
const __half beta,
21 return cublasHgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
// float overload of gemm: the visible part forwards to cublasSgemm and passes
// the by-value scalar alpha by address (&alpha).
// NOTE(review): the end of the parameter list, the braces, and the tail of the
// cublasSgemm call (presumably &beta, C, ldc) are missing from this listing --
// confirm against the original header.
25 inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
26 cublasOperation_t transb,
int m,
int n,
int k,
const float alpha,
27 const float *A,
int lda,
const float *B,
int ldb,
const float beta,
30 return cublasSgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
// double overload of gemm: the visible part forwards to cublasDgemm and passes
// the by-value scalar alpha by address (&alpha).
// NOTE(review): the end of the parameter list, the braces, and the tail of the
// cublasDgemm call (presumably &beta, C, ldc) are missing from this listing --
// confirm against the original header.
34 inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
35 cublasOperation_t transb,
int m,
int n,
int k,
const double alpha,
36 const double *A,
int lda,
const double *B,
int ldb,
const double beta,
39 return cublasDgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
45 inline cublasStatus_t syrk(cublasHandle_t handle, cublasFillMode_t uplo,
46 cublasOperation_t trans,
int n,
int k,
const float alpha,
const float *A,
47 int lda,
const float beta,
float *C,
int ldc)
49 return cublasSsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
52 inline cublasStatus_t syrk(cublasHandle_t handle, cublasFillMode_t uplo,
53 cublasOperation_t trans,
int n,
int k,
const double alpha,
const double *A,
54 int lda,
const double beta,
double *C,
int ldc)
56 return cublasDsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
// float overload of geam: the visible part forwards to cublasSgeam, passing the
// by-value scalars alpha and beta by address.
// NOTE(review): the braces and the tail of the cublasSgeam call (presumably
// ldb, C, ldc) are missing from this listing -- confirm against the original
// header.
61 inline cublasStatus_t geam(cublasHandle_t handle, cublasOperation_t transa,
62 cublasOperation_t transb,
int m,
int n,
const float alpha,
const float *A,
63 int lda,
const float beta,
const float *B,
int ldb,
float *C,
int ldc)
65 return cublasSgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
// double overload of geam: the visible part forwards to cublasDgeam, passing
// the by-value scalars alpha and beta by address.
// NOTE(review): the braces and the tail of the cublasDgeam call (presumably
// ldb, C, ldc) are missing from this listing -- confirm against the original
// header.
69 inline cublasStatus_t geam(cublasHandle_t handle, cublasOperation_t transa,
70 cublasOperation_t transb,
int m,
int n,
const double alpha,
const double *A,
71 int lda,
const double beta,
const double *B,
int ldb,
double *C,
int ldc)
73 return cublasDgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
79 inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle,
int m,
int n,
80 float *A,
int lda,
int *lwork)
82 return cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, lwork);
85 inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle,
int m,
int n,
86 double *A,
int lda,
int *lwork)
88 return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
91 inline cusolverStatus_t getrf(cusolverDnHandle_t handle,
int m,
int n,
92 float *A,
int lda,
float *work,
int *ipiv,
int *info)
94 return cusolverDnSgetrf(handle, m, n, A, lda, work, ipiv, info);
97 inline cusolverStatus_t getrf(cusolverDnHandle_t handle,
int m,
int n,
98 double *A,
int lda,
double *work,
int *ipiv,
int *info)
100 return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
105 inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle,
int m,
int n,
106 float *A,
int *lwork)
109 return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
112 inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle,
int m,
int n,
113 double *A,
int *lwork)
116 return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
119 inline cusolverStatus_t gesvd(cusolverDnHandle_t handle,
signed char jobu,
120 signed char jobvt,
const int m,
const int n,
float *A,
const int lda,
121 float *S,
float *U,
const int ldu,
float *VT,
const int ldvt,
float *work,
122 const int lwork,
float *rwork,
int *info)
124 return cusolverDnSgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
125 ldvt, work, lwork, rwork, info);
128 inline cusolverStatus_t gesvd(cusolverDnHandle_t handle,
signed char jobu,
129 signed char jobvt,
const int m,
const int n,
double *A,
const int lda,
130 double *S,
double *U,
const int ldu,
double *VT,
const int ldvt,
double *work,
131 const int lwork,
double *rwork,
int *info)
133 return cusolverDnDgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
134 ldvt, work, lwork, rwork, info);
// float overload of syevd_buflen: the visible part forwards its arguments to
// cusolverDnSsyevd_bufferSize in declaration order.
// NOTE(review): the braces and the tail of the call (presumably lwork) are
// missing from this listing -- confirm against the original header.
139 inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
140 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
const float *A,
141 int lda,
const float *W,
int *lwork)
143 return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
// double overload of syevd_buflen: the visible part forwards its arguments to
// cusolverDnDsyevd_bufferSize in declaration order.
// NOTE(review): the braces and the tail of the call (presumably lwork) are
// missing from this listing -- confirm against the original header.
147 inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
148 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
const double *A,
149 int lda,
const double *W,
int *lwork)
151 return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
// float overload of syevd: the visible part forwards its arguments to
// cusolverDnSsyevd in declaration order.
// NOTE(review): the braces and the tail of the call (presumably devInfo) are
// missing from this listing -- confirm against the original header.
155 inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
156 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
float *A,
int lda,
157 float *W,
float *work,
int lwork,
int *devInfo)
159 return cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
// double overload of syevd: the visible part forwards its arguments to
// cusolverDnDsyevd in declaration order.
// NOTE(review): the braces and the tail of the call (presumably devInfo) are
// missing from this listing -- confirm against the original header.
163 inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
164 cusolverEigMode_t jobz, cublasFillMode_t uplo,
int n,
double *A,
int lda,
165 double *W,
double *work,
int lwork,
int *devInfo)
167 return cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
// float overload of getri_batched: the visible part forwards its arguments to
// cublasSgetriBatched in declaration order (ldb is passed in the position
// following Carray).
// NOTE(review): the braces and the tail of the call (presumably info,
// batchSize) are missing from this listing -- confirm against the original
// header.
173 inline cublasStatus_t getri_batched(cublasHandle_t handle,
const int n,
174 const float **Aarray,
const int lda,
const int *devIpiv,
float **Carray,
175 const int ldb,
int *info,
const int batchSize)
177 return cublasSgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
// double overload of getri_batched: the visible part forwards its arguments to
// cublasDgetriBatched in declaration order (ldb is passed in the position
// following Carray).
// NOTE(review): the braces and the tail of the call (presumably info,
// batchSize) are missing from this listing -- confirm against the original
// header.
181 inline cublasStatus_t getri_batched(cublasHandle_t handle,
const int n,
182 const double **Aarray,
const int lda,
const int *devIpiv,
double **Carray,
183 const int ldb,
int *info,
const int batchSize)
185 return cublasDgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
// float overload of getrs: the visible part forwards its arguments to
// cusolverDnSgetrs in declaration order.
// NOTE(review): the braces and the tail of the call (presumably info) are
// missing from this listing -- confirm against the original header.
191 inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
192 cublasOperation_t trans,
const int n,
const int nrhs,
const float *A,
193 const int lda,
const int *devIpiv,
float *B,
const int ldb,
int *info)
195 return cusolverDnSgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
// double overload of getrs: the visible part forwards its arguments to
// cusolverDnDgetrs in declaration order.
// NOTE(review): the braces and the tail of the call (presumably info) are
// missing from this listing -- confirm against the original header.
199 inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
200 cublasOperation_t trans,
const int n,
const int nrhs,
const double *A,
201 const int lda,
const int *devIpiv,
double *B,
const int ldb,
int *info)
203 return cusolverDnDgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
Definition: culapack.hh:14