fml  0.1-0
Fused Matrix Library
culapack.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_INTERNALS_CUDA_CUSOLVER_H
6 #define FML_GPU_INTERNALS_CUDA_CUSOLVER_H
7 #pragma once
8 
9 
10 #include <cublas.h>
11 #include <cusolverDn.h>
12 
13 
14 namespace culapack
15 {
16  inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
17  cublasOperation_t transb, int m, int n, int k, const __half alpha,
18  const __half *A, int lda, const __half *B, int ldb, const __half beta,
19  __half *C, int ldc)
20  {
21  return cublasHgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
22  &beta, C, ldc);
23  }
24 
25  inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
26  cublasOperation_t transb, int m, int n, int k, const float alpha,
27  const float *A, int lda, const float *B, int ldb, const float beta,
28  float *C, int ldc)
29  {
30  return cublasSgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
31  &beta, C, ldc);
32  }
33 
34  inline cublasStatus_t gemm(cublasHandle_t handle, cublasOperation_t transa,
35  cublasOperation_t transb, int m, int n, int k, const double alpha,
36  const double *A, int lda, const double *B, int ldb, const double beta,
37  double *C, int ldc)
38  {
39  return cublasDgemm(handle, transa, transb, m, n, k, &alpha, A, lda, B, ldb,
40  &beta, C, ldc);
41  }
42 
43 
44 
45  inline cublasStatus_t syrk(cublasHandle_t handle, cublasFillMode_t uplo,
46  cublasOperation_t trans, int n, int k, const float alpha, const float *A,
47  int lda, const float beta, float *C, int ldc)
48  {
49  return cublasSsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
50  }
51 
52  inline cublasStatus_t syrk(cublasHandle_t handle, cublasFillMode_t uplo,
53  cublasOperation_t trans, int n, int k, const double alpha, const double *A,
54  int lda, const double beta, double *C, int ldc)
55  {
56  return cublasDsyrk(handle, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc);
57  }
58 
59 
60 
61  inline cublasStatus_t geam(cublasHandle_t handle, cublasOperation_t transa,
62  cublasOperation_t transb, int m, int n, const float alpha, const float *A,
63  int lda, const float beta, const float *B, int ldb, float *C, int ldc)
64  {
65  return cublasSgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
66  ldb, C, ldc);
67  }
68 
69  inline cublasStatus_t geam(cublasHandle_t handle, cublasOperation_t transa,
70  cublasOperation_t transb, int m, int n, const double alpha, const double *A,
71  int lda, const double beta, const double *B, int ldb, double *C, int ldc)
72  {
73  return cublasDgeam(handle, transa, transb, m, n, &alpha, A, lda, &beta, B,
74  ldb, C, ldc);
75  }
76 
77 
78 
79  inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle, int m, int n,
80  float *A, int lda, int *lwork)
81  {
82  return cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, lwork);
83  }
84 
85  inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle, int m, int n,
86  double *A, int lda, int *lwork)
87  {
88  return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
89  }
90 
91  inline cusolverStatus_t getrf(cusolverDnHandle_t handle, int m, int n,
92  float *A, int lda, float *work, int *ipiv, int *info)
93  {
94  return cusolverDnSgetrf(handle, m, n, A, lda, work, ipiv, info);
95  }
96 
97  inline cusolverStatus_t getrf(cusolverDnHandle_t handle, int m, int n,
98  double *A, int lda, double *work, int *ipiv, int *info)
99  {
100  return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
101  }
102 
103 
104 
105  inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle, int m, int n,
106  float *A, int *lwork)
107  {
108  (void)A;
109  return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
110  }
111 
112  inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle, int m, int n,
113  double *A, int *lwork)
114  {
115  (void)A;
116  return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
117  }
118 
119  inline cusolverStatus_t gesvd(cusolverDnHandle_t handle, signed char jobu,
120  signed char jobvt, const int m, const int n, float *A, const int lda,
121  float *S, float *U, const int ldu, float *VT, const int ldvt, float *work,
122  const int lwork, float *rwork, int *info)
123  {
124  return cusolverDnSgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
125  ldvt, work, lwork, rwork, info);
126  }
127 
128  inline cusolverStatus_t gesvd(cusolverDnHandle_t handle, signed char jobu,
129  signed char jobvt, const int m, const int n, double *A, const int lda,
130  double *S, double *U, const int ldu, double *VT, const int ldvt, double *work,
131  const int lwork, double *rwork, int *info)
132  {
133  return cusolverDnDgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
134  ldvt, work, lwork, rwork, info);
135  }
136 
137 
138 
139  inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
140  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const float *A,
141  int lda, const float *W, int *lwork)
142  {
143  return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
144  lwork);
145  }
146 
147  inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
148  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const double *A,
149  int lda, const double *W, int *lwork)
150  {
151  return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
152  lwork);
153  }
154 
155  inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
156  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, float *A, int lda,
157  float *W, float *work, int lwork, int *devInfo)
158  {
159  return cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
160  devInfo);
161  }
162 
163  inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
164  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, double *A, int lda,
165  double *W, double *work, int lwork, int *devInfo)
166  {
167  return cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
168  devInfo);
169  }
170 
171 
172 
173  inline cublasStatus_t getri_batched(cublasHandle_t handle, const int n,
174  const float **Aarray, const int lda, const int *devIpiv, float **Carray,
175  const int ldb, int *info, const int batchSize)
176  {
177  return cublasSgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
178  info, batchSize);
179  }
180 
181  inline cublasStatus_t getri_batched(cublasHandle_t handle, const int n,
182  const double **Aarray, const int lda, const int *devIpiv, double **Carray,
183  const int ldb, int *info, const int batchSize)
184  {
185  return cublasDgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
186  info, batchSize);
187  }
188 
189 
190 
191  inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
192  cublasOperation_t trans, const int n, const int nrhs, const float *A,
193  const int lda, const int *devIpiv, float *B, const int ldb, int *info)
194  {
195  return cusolverDnSgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
196  info);
197  }
198 
199  inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
200  cublasOperation_t trans, const int n, const int nrhs, const double *A,
201  const int lda, const int *devIpiv, double *B, const int ldb, int *info)
202  {
203  return cusolverDnDgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
204  info);
205  }
206 }
207 
208 
209 #endif
Definition: culapack.hh:14