fml  0.1-0
Fused Matrix Library
gpulapack.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_ARCH_CUDA_GPULAPACK_H
6 #define FML_GPU_ARCH_CUDA_GPULAPACK_H
7 #pragma once
8 
9 
#include <stdexcept>
#include <string>

#include <cublas.h>
#include <cusolverDn.h>
12 
13 
14 namespace fml
15 {
16 namespace gpulapack
17 {
18  namespace err
19  {
20  inline std::string get_cusolver_error_msg(cusolverStatus_t check)
21  {
22  // if (check == CUSOLVER_STATUS_SUCCESS)
23  // return "";
24  if (check == CUSOLVER_STATUS_NOT_INITIALIZED)
25  return "cuSOLVER not initialized";
26  else if (check == CUSOLVER_STATUS_ALLOC_FAILED)
27  return "internal cuSOLVER memory allocation failed";
28  else if (check == CUSOLVER_STATUS_INVALID_VALUE)
29  return "unsupported parameter";
30  else if (check == CUSOLVER_STATUS_ARCH_MISMATCH)
31  return "function requires feature missing from device architecture";
32  else if (check == CUSOLVER_STATUS_EXECUTION_FAILED)
33  return "GPU program failed to execute";
34  else if (check == CUSOLVER_STATUS_INTERNAL_ERROR)
35  return "internal cuSOLVER operation failed";
36  else if (check == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
37  return "matrix type not supported";
38  else
39  return "unknown cuSOLVER error occurred";
40  }
41 
42  inline void check_ret(cusolverStatus_t check, std::string op)
43  {
44  if (check != CUSOLVER_STATUS_SUCCESS)
45  {
46  std::string msg = "cuSOLVER " + op + "() failed with error: " + get_cusolver_error_msg(check);
47  throw std::runtime_error(msg);
48  }
49  }
50  }
51 
52 
53 
54  inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle, int m, int n,
55  float *A, int lda, int *lwork)
56  {
57  return cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, lwork);
58  }
59 
60  inline cusolverStatus_t getrf_buflen(cusolverDnHandle_t handle, int m, int n,
61  double *A, int lda, int *lwork)
62  {
63  return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
64  }
65 
66  inline cusolverStatus_t getrf(cusolverDnHandle_t handle, int m, int n,
67  float *A, int lda, float *work, int *ipiv, int *info)
68  {
69  return cusolverDnSgetrf(handle, m, n, A, lda, work, ipiv, info);
70  }
71 
72  inline cusolverStatus_t getrf(cusolverDnHandle_t handle, int m, int n,
73  double *A, int lda, double *work, int *ipiv, int *info)
74  {
75  return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
76  }
77 
78 
79 
80  inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle, int m, int n,
81  float *A, int *lwork)
82  {
83  (void)A;
84  return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
85  }
86 
87  inline cusolverStatus_t gesvd_buflen(cusolverDnHandle_t handle, int m, int n,
88  double *A, int *lwork)
89  {
90  (void)A;
91  return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
92  }
93 
94  inline cusolverStatus_t gesvd(cusolverDnHandle_t handle, signed char jobu,
95  signed char jobvt, const int m, const int n, float *A, const int lda,
96  float *S, float *U, const int ldu, float *VT, const int ldvt, float *work,
97  const int lwork, float *rwork, int *info)
98  {
99  return cusolverDnSgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
100  ldvt, work, lwork, rwork, info);
101  }
102 
103  inline cusolverStatus_t gesvd(cusolverDnHandle_t handle, signed char jobu,
104  signed char jobvt, const int m, const int n, double *A, const int lda,
105  double *S, double *U, const int ldu, double *VT, const int ldvt, double *work,
106  const int lwork, double *rwork, int *info)
107  {
108  return cusolverDnDgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
109  ldvt, work, lwork, rwork, info);
110  }
111 
112 
113 
114  inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
115  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const float *A,
116  int lda, const float *W, int *lwork)
117  {
118  return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
119  lwork);
120  }
121 
122  inline cusolverStatus_t syevd_buflen(cusolverDnHandle_t handle,
123  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const double *A,
124  int lda, const double *W, int *lwork)
125  {
126  return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
127  lwork);
128  }
129 
130  inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
131  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, float *A, int lda,
132  float *W, float *work, int lwork, int *devInfo)
133  {
134  return cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
135  devInfo);
136  }
137 
138  inline cusolverStatus_t syevd(cusolverDnHandle_t handle,
139  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, double *A, int lda,
140  double *W, double *work, int lwork, int *devInfo)
141  {
142  return cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
143  devInfo);
144  }
145 
146 
147 
148  inline cublasStatus_t getri_batched(cublasHandle_t handle, const int n,
149  const float **Aarray, const int lda, const int *devIpiv, float **Carray,
150  const int ldb, int *info, const int batchSize)
151  {
152  return cublasSgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
153  info, batchSize);
154  }
155 
156  inline cublasStatus_t getri_batched(cublasHandle_t handle, const int n,
157  const double **Aarray, const int lda, const int *devIpiv, double **Carray,
158  const int ldb, int *info, const int batchSize)
159  {
160  return cublasDgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
161  info, batchSize);
162  }
163 
164 
165 
166  inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
167  cublasOperation_t trans, const int n, const int nrhs, const float *A,
168  const int lda, const int *devIpiv, float *B, const int ldb, int *info)
169  {
170  return cusolverDnSgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
171  info);
172  }
173 
174  inline cusolverStatus_t getrs(cusolverDnHandle_t handle,
175  cublasOperation_t trans, const int n, const int nrhs, const double *A,
176  const int lda, const int *devIpiv, double *B, const int ldb, int *info)
177  {
178  return cusolverDnDgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
179  info);
180  }
181 
182 
183 
184  inline cusolverStatus_t geqrf_buflen(cusolverDnHandle_t handle,
185  const int m, const int n, float *A, const int lda, int *lwork)
186  {
187  return cusolverDnSgeqrf_bufferSize(handle, m, n, A, lda, lwork);
188  }
189 
190  inline cusolverStatus_t geqrf_buflen(cusolverDnHandle_t handle,
191  const int m, const int n, double *A, const int lda, int *lwork)
192  {
193  return cusolverDnDgeqrf_bufferSize(handle, m, n, A, lda, lwork);
194  }
195 
196  inline cusolverStatus_t geqrf(cusolverDnHandle_t handle, const int m,
197  const int n, float *A, const int lda, float *tau, float *work,
198  const int lwork, int *info)
199  {
200  return cusolverDnSgeqrf(handle, m, n, A, lda, tau, work, lwork, info);
201  }
202 
203  inline cusolverStatus_t geqrf(cusolverDnHandle_t handle, const int m,
204  const int n, double *A, const int lda, double *tau, double *work,
205  const int lwork, int *info)
206  {
207  return cusolverDnDgeqrf(handle, m, n, A, lda, tau, work, lwork, info);
208  }
209 
210 
211 
212  inline cusolverStatus_t ormqr_buflen(cusolverDnHandle_t handle,
213  cublasSideMode_t side, cublasOperation_t trans, const int m, const int n,
214  const int k, const float *A, const int lda, const float *tau,
215  const float *C, const int ldc, int *lwork)
216  {
217  return cusolverDnSormqr_bufferSize(handle, side, trans, m, n, k, A, lda,
218  tau, C, ldc, lwork);
219  }
220 
221  inline cusolverStatus_t ormqr_buflen(cusolverDnHandle_t handle,
222  cublasSideMode_t side, cublasOperation_t trans, const int m, const int n,
223  const int k, const double *A, const int lda, const double *tau,
224  const double *C, const int ldc, int *lwork)
225  {
226  return cusolverDnDormqr_bufferSize(handle, side, trans, m, n, k, A, lda,
227  tau, C, ldc, lwork);
228  }
229 
230  inline cusolverStatus_t ormqr(cusolverDnHandle_t handle,
231  cublasSideMode_t side, cublasOperation_t trans, const int m, const int n,
232  const int k, const float *A, const int lda, const float *tau, float *C,
233  const int ldc, float *work, const int lwork, int *info)
234  {
235  return cusolverDnSormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc,
236  work, lwork, info);
237  }
238 
239  inline cusolverStatus_t ormqr(cusolverDnHandle_t handle,
240  cublasSideMode_t side, cublasOperation_t trans, const int m, const int n,
241  const int k, const double *A, const int lda, const double *tau, double *C,
242  const int ldc, double *work, const int lwork, int *info)
243  {
244  return cusolverDnDormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc,
245  work, lwork, info);
246  }
247 
248 
249 
250  inline cusolverStatus_t orgqr_buflen(cusolverDnHandle_t handle, int m,
251  int n, int k, const float *A, int lda, const float *tau, int *lwork)
252  {
253  return cusolverDnSorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork);
254  }
255 
256  inline cusolverStatus_t orgqr_buflen(cusolverDnHandle_t handle, int m,
257  int n, int k, const double *A, int lda, const double *tau, int *lwork)
258  {
259  return cusolverDnDorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork);
260  }
261 
262  inline cusolverStatus_t orgqr(cusolverDnHandle_t handle, int m,
263  int n, int k, float *A, int lda, const float *tau, float *work, int lwork,
264  int *info)
265  {
266  return cusolverDnSorgqr(handle, m, n, k, A, lda, tau, work, lwork, info);
267  }
268 
269  inline cusolverStatus_t orgqr(cusolverDnHandle_t handle, int m,
270  int n, int k, double *A, int lda, const double *tau, double *work, int lwork,
271  int *info)
272  {
273  return cusolverDnDorgqr(handle, m, n, k, A, lda, tau, work, lwork, info);
274  }
275 
276 
277 
278  inline cusolverStatus_t potrf_buflen(cusolverDnHandle_t handle,
279  cublasFillMode_t uplo, int n, float *A, int lda, int *lwork)
280  {
281  return cusolverDnSpotrf_bufferSize(handle, uplo, n, A, lda, lwork);
282  }
283 
284  inline cusolverStatus_t potrf_buflen(cusolverDnHandle_t handle,
285  cublasFillMode_t uplo, int n, double *A, int lda, int *lwork)
286  {
287  return cusolverDnDpotrf_bufferSize(handle, uplo, n, A, lda, lwork);
288  }
289 
290  inline cusolverStatus_t potrf(cusolverDnHandle_t handle,
291  cublasFillMode_t uplo, const int n, float *A, const int lda, float *work,
292  const int lwork, int *info)
293  {
294  return cusolverDnSpotrf(handle, uplo, n, A, lda, work, lwork, info);
295  }
296 
297  inline cusolverStatus_t potrf(cusolverDnHandle_t handle,
298  cublasFillMode_t uplo, const int n, double *A, const int lda, double *work,
299  const int lwork, int *info)
300  {
301  return cusolverDnDpotrf(handle, uplo, n, A, lda, work, lwork, info);
302  }
303 
304 
305 
306  inline cusolverStatus_t potri_buflen(cusolverDnHandle_t handle,
307  cublasFillMode_t uplo, int n, float *A, int lda, int *lwork)
308  {
309  return cusolverDnSpotri_bufferSize(handle, uplo, n, A, lda, lwork);
310  }
311 
312  inline cusolverStatus_t potri_buflen(cusolverDnHandle_t handle,
313  cublasFillMode_t uplo, int n, double *A, int lda, int *lwork)
314  {
315  return cusolverDnDpotri_bufferSize(handle, uplo, n, A, lda, lwork);
316  }
317 
318  inline cusolverStatus_t potri(cusolverDnHandle_t handle,
319  cublasFillMode_t uplo, const int n, float *A, const int lda, float *work,
320  const int lwork, int *info)
321  {
322  return cusolverDnSpotri(handle, uplo, n, A, lda, work, lwork, info);
323  }
324 
325  inline cusolverStatus_t potri(cusolverDnHandle_t handle,
326  cublasFillMode_t uplo, const int n, double *A, const int lda, double *work,
327  const int lwork, int *info)
328  {
329  return cusolverDnDpotri(handle, uplo, n, A, lda, work, lwork, info);
330  }
331 }
332 }
333 
334 
335 #endif
fml
Core namespace.
Definition: dimops.hh:10