fml  0.1-0
Fused Matrix Library
gpulapack.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_ARCH_HIP_CUSOLVER_H
6 #define FML_GPU_ARCH_HIP_CUSOLVER_H
7 #pragma once
8 
9 
#include <stdexcept>
#include <string>

#include <rocblas.h>
#include <rocsolver.h>
12 
13 
14 namespace gpulapack
15 {
16  namespace err
17  {
18  inline std::string get_rocsolver_error_msg(cusolverStatus_t check)
19  {
20  if (check == CUSOLVER_STATUS_SUCCESS)
21  return "";
22  else if (check == CUSOLVER_STATUS_NOT_INITIALIZED)
23  return "cuSOLVER not initialized";
24  else if (check == CUSOLVER_STATUS_ALLOC_FAILED)
25  return "internal cuSOLVER memory allocation failed";
26  else if (check == CUSOLVER_STATUS_INVALID_VALUE)
27  return "unsupported parameter";
28  else if (check == CUSOLVER_STATUS_ARCH_MISMATCH)
29  return "function requires feature missing from device architecture";
30  else if (check == CUSOLVER_STATUS_EXECUTION_FAILED)
31  return "GPU program failed to execute";
32  else if (check == CUSOLVER_STATUS_INTERNAL_ERROR)
33  return "internal cuSOLVER operation failed";
34  else if (check == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
35  return "matrix type not supported";
36  else
37  return "unknown cuSOLVER error occurred";
38  }
39 
40  inline void check_gpusolver_ret(rocsolver_status check, std::string op)
41  {
42  if (check != CUSOLVER_STATUS_SUCCESS)
43  {
44  std::string msg = "rocsolver " + op + "() failed with error: " + get_rocsolver_error_msg(check);
45  throw std::runtime_error(msg);
46  }
47  }
48  }
49 
50 
51 
52  inline rocsolver_status getrf_buflen(rocsolver_handle handle, int m, int n,
53  float *A, int lda, int *lwork)
54  {
55  return cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, lwork);
56  }
57 
58  inline rocsolver_status getrf_buflen(rocsolver_handle handle, int m, int n,
59  double *A, int lda, int *lwork)
60  {
61  return cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, lwork);
62  }
63 
64  inline rocsolver_status getrf(rocsolver_handle handle, int m, int n,
65  float *A, int lda, float *work, int *ipiv, int *info)
66  {
67  return cusolverDnSgetrf(handle, m, n, A, lda, work, ipiv, info);
68  }
69 
70  inline rocsolver_status getrf(rocsolver_handle handle, int m, int n,
71  double *A, int lda, double *work, int *ipiv, int *info)
72  {
73  return cusolverDnDgetrf(handle, m, n, A, lda, work, ipiv, info);
74  }
75 
76 
77 
78  inline rocsolver_status gesvd_buflen(rocsolver_handle handle, int m, int n,
79  float *A, int *lwork)
80  {
81  (void)A;
82  return cusolverDnSgesvd_bufferSize(handle, m, n, lwork);
83  }
84 
85  inline rocsolver_status gesvd_buflen(rocsolver_handle handle, int m, int n,
86  double *A, int *lwork)
87  {
88  (void)A;
89  return cusolverDnDgesvd_bufferSize(handle, m, n, lwork);
90  }
91 
92  inline rocsolver_status gesvd(rocsolver_handle handle, signed char jobu,
93  signed char jobvt, const int m, const int n, float *A, const int lda,
94  float *S, float *U, const int ldu, float *VT, const int ldvt, float *work,
95  const int lwork, float *rwork, int *info)
96  {
97  return cusolverDnSgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
98  ldvt, work, lwork, rwork, info);
99  }
100 
101  inline rocsolver_status gesvd(rocsolver_handle handle, signed char jobu,
102  signed char jobvt, const int m, const int n, double *A, const int lda,
103  double *S, double *U, const int ldu, double *VT, const int ldvt, double *work,
104  const int lwork, double *rwork, int *info)
105  {
106  return cusolverDnDgesvd(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT,
107  ldvt, work, lwork, rwork, info);
108  }
109 
110 
111 
112  inline rocsolver_status syevd_buflen(rocsolver_handle handle,
113  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const float *A,
114  int lda, const float *W, int *lwork)
115  {
116  return cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
117  lwork);
118  }
119 
120  inline rocsolver_status syevd_buflen(rocsolver_handle handle,
121  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, const double *A,
122  int lda, const double *W, int *lwork)
123  {
124  return cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W,
125  lwork);
126  }
127 
128  inline rocsolver_status syevd(rocsolver_handle handle,
129  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, float *A, int lda,
130  float *W, float *work, int lwork, int *devInfo)
131  {
132  return cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
133  devInfo);
134  }
135 
136  inline rocsolver_status syevd(rocsolver_handle handle,
137  cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, double *A, int lda,
138  double *W, double *work, int lwork, int *devInfo)
139  {
140  return cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork,
141  devInfo);
142  }
143 
144 
145 
146  inline rocblas_status getri_batched(rocblas_handle handle, const int n,
147  const float **Aarray, const int lda, const int *devIpiv, float **Carray,
148  const int ldb, int *info, const int batchSize)
149  {
150  return cublasSgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
151  info, batchSize);
152  }
153 
154  inline rocblas_status getri_batched(rocblas_handle handle, const int n,
155  const double **Aarray, const int lda, const int *devIpiv, double **Carray,
156  const int ldb, int *info, const int batchSize)
157  {
158  return cublasDgetriBatched(handle, n, Aarray, lda, devIpiv, Carray, ldb,
159  info, batchSize);
160  }
161 
162 
163 
164  inline rocsolver_status getrs(rocsolver_handle handle,
165  rocblas_operation trans, const int n, const int nrhs, const float *A,
166  const int lda, const int *devIpiv, float *B, const int ldb, int *info)
167  {
168  return rocsolver_sgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
169  info);
170  }
171 
172  inline rocsolver_status getrs(rocsolver_handle handle,
173  rocblas_operation trans, const int n, const int nrhs, const double *A,
174  const int lda, const int *devIpiv, double *B, const int ldb, int *info)
175  {
176  return rocsolver_dgetrs(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb,
177  info);
178  }
179 }
180 
181 
182 #endif