5 #ifndef FML_GPU_LINALG_LINALG_FACTORIZATIONS_H
6 #define FML_GPU_LINALG_LINALG_FACTORIZATIONS_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpu_utils.hh"
17 #include "../internals/gpuscalar.hh"
18 #include "../internals/kernelfuns.hh"
21 #include "../gpumat.hh"
22 #include "../gpuvec.hh"
24 #include "linalg_err.hh"
25 #include "linalg_blas.hh"
// Fragment of the getrf (LU) driver; its signature and several statements are
// not visible in this excerpt.
// Visible flow: verify that x and p are on the same card, read the dimensions
// and card handle of x, query the getrf buffer length, call getrf, then read
// the status value out of info_device into info.
53 template <
typename REAL>
56 err::check_card(x, p);
59 const len_t m = x.
nrows();
60 const len_t n = x.
ncols();
61 auto c = x.get_card();
// Pivot length is min(m, n).
63 const len_t lipiv = std::min(m, n);
// NOTE(review): the statement guarded by this condition is not visible here.
64 if (!p.get_card()->valid_card())
// Backend selection by preprocessor; an unrecognized backend is a hard
// compile-time error.
69 #if defined(FML_GPULAPACK_VENDOR)
71 gpulapack_status_t check = gpulapack::getrf_buflen(c->lapack_handle(), m,
73 gpulapack::err::check_ret(check,
"getrf_bufferSize");
78 check = gpulapack::getrf(c->lapack_handle(), m, n, x.
data_ptr(), m,
// Read info before check_ret so it is populated even if check_ret reports.
81 info_device.get_val(&info);
82 gpulapack::err::check_ret(check,
"getrf");
83 #elif defined(FML_GPULAPACK_MAGMA)
88 #error "Unsupported GPU lapack"
// Fragment of a getrf wrapper (signature not visible in this excerpt).
// The visible statement hands info to check_info under the "getrf" label.
93 template <
typename REAL>
101 fml::linalgutils::check_info(info,
"getrf");
// Fragment of the gesvd driver (called elsewhere in this file as
// svd_internals(nu, nv, x, s, u, vt)); signature not visible in this excerpt.
// Visible flow: read dimensions, choose jobu/jobvt from nu/nv, query the gesvd
// buffer length, allocate work/rwork on the card, call gesvd, then read the
// status value out of info_device into info.
108 template <
typename REAL>
111 auto c = x.get_card();
113 const len_t m = x.
nrows();
114 const len_t n = x.
ncols();
115 const len_t minmn = std::min(m, n);
// Job flags passed to gesvd; the branches that assign them are only partly
// visible here.
119 signed char jobu, jobvt;
120 if (nu == 0 && nv == 0)
135 gpulapack_status_t check = gpulapack::gesvd_buflen(c->lapack_handle(), m, n,
137 gpulapack::err::check_ret(check,
"gesvd_bufferSize");
139 gpuvec<REAL> work(c, lwork);
// NOTE(review): rwork has length minmn-1, which is negative when either
// dimension is 0 -- confirm gpuvec handles that or that callers exclude it.
140 gpuvec<REAL> rwork(c, minmn-1);
143 gpuscalar<int> info_device(c, info);
145 check = gpulapack::gesvd(c->lapack_handle(), jobu, jobvt, m, n, x.
data_ptr(),
147 lwork, rwork.data_ptr(), info_device.data_ptr());
149 info_device.get_val(&info);
150 gpulapack::err::check_ret(check,
"gesvd");
// Fragment of the singular-values-only wrapper (signature not visible).
// Verifies x and s are on the same card, calls svd_internals with nu = nv = 0
// and a placeholder object for both vector outputs, then validates info.
176 template <
typename REAL>
179 err::check_card(x, s);
182 int info = svd_internals(0, 0, x, s, ignored, ignored);
183 fml::linalgutils::check_info(info,
"gesvd");
// Fragment of the SVD wrapper that also returns vectors (signature not
// visible). All outputs must be on the same card as x.
187 template <
typename REAL>
190 err::check_card(x, s);
191 err::check_card(x, u);
192 err::check_card(x, vt);
// Two call paths are visible: one on x directly, and one on tx with the
// vector outputs passed as (v, u).
// NOTE(review): the condition selecting between them, and the definitions of
// tx and v, are not visible here -- presumably tx is a transpose of x.
196 int info = svd_internals(1, 1, x, s, u, vt);
197 fml::linalgutils::check_info(info,
"gesvd");
203 int info = svd_internals(1, 1, tx, s, v, u);
205 fml::linalgutils::check_info(info,
"gesvd");
// Driver for the syevd eigen-solver. The parameter list is only partly visible
// in this excerpt (only_values and x are shown).
// Visible flow: reject non-square x, pick the eigen mode, query the syevd
// buffer length, call syevd with GPUBLAS_FILL_L, then read the status value
// out of info_device into info and validate it.
// Throws std::runtime_error when x is not square.
213 template <
typename REAL>
214 int eig_sym_internals(
const bool only_values,
gpumat<REAL> &x,
218 throw std::runtime_error(
"'x' must be a square matrix");
220 auto c = x.get_card();
// NOTE(review): this uses the cuSOLVER type and enumerators directly, while
// the rest of the file goes through gpulapack/GPUBLAS_* names -- confirm
// whether an abstracted alias should be used here.
225 cusolverEigMode_t jobz;
227 jobz = CUSOLVER_EIG_MODE_NOVECTOR;
229 jobz = CUSOLVER_EIG_MODE_VECTOR;
232 gpulapack_status_t check = gpulapack::syevd_buflen(c->lapack_handle(), jobz,
234 gpulapack::err::check_ret(check,
"syevd_bufferSize");
241 check = gpulapack::syevd(c->lapack_handle(), jobz, GPUBLAS_FILL_L,
243 info_device.data_ptr());
245 info_device.get_val(&info);
246 gpulapack::err::check_ret(check,
"syevd");
// NOTE(review): the visible callers in this file also pass the returned info
// to check_info, so this check appears to be duplicated.
247 fml::linalgutils::check_info(info,
"syevd");
// Fragment of the eigenvalues-only wrapper (signature not visible).
// Verifies x and values are on the same card, calls eig_sym_internals with
// only_values = true and a placeholder for the vectors argument, then
// validates info.
280 template <
typename REAL>
283 err::check_card(x, values);
286 int info = eig_sym_internals(
true, x, values, ignored);
287 fml::linalgutils::check_info(info,
"syevd");
// Fragment of the eigenvalues-and-eigenvectors wrapper (signature not
// visible). Verifies values and vectors are on the same card as x, calls
// eig_sym_internals with only_values = false, then validates info.
291 template <
typename REAL>
294 err::check_card(x, values);
295 err::check_card(x, vectors);
297 int info = eig_sym_internals(
false, x, values, vectors);
298 fml::linalgutils::check_info(info,
"syevd");
// Fragment of a routine that factors a square x and then calls getrs with
// nrhs == n (signature not visible).
// NOTE(review): presumably this is matrix inversion via LU against an identity
// right-hand side -- the right-hand-side setup is not visible; confirm.
// Throws std::runtime_error when x is not square.
321 template <
typename REAL>
325 throw std::runtime_error(
"'x' must be a square matrix");
328 auto c = x.get_card();
// Validate the info produced by the factorization step (call not visible).
332 fml::linalgutils::check_info(info,
"getrf");
335 const len_t n = x.
nrows();
336 const len_t nrhs = n;
342 gpulapack_status_t check = gpulapack::getrs(c->lapack_handle(), GPUBLAS_OP_N, n,
345 info_device.get_val(&info);
346 gpulapack::err::check_ret(check,
"getrs");
347 fml::linalgutils::check_info(info,
"getrs");
// Shared solve helper: factors x and calls getrs on the right-hand-side
// buffer y_d, which getrs receives with leading dimension n.
//   x     square coefficient matrix
//   ylen  length of the right-hand side, used for the compatibility check
//         (NOTE(review): the exact condition is not visible in this excerpt)
//   nrhs  number of right-hand-side columns passed to getrs
//   y_d   right-hand-side pointer handed to getrs
//         (NOTE(review): presumably device memory, going by the _d suffix)
// Throws std::runtime_error when x is not square or y is incompatible with x.
356 template <
typename REAL>
357 void solver(
gpumat<REAL> &x, len_t ylen, len_t nrhs, REAL *y_d)
359 const len_t n = x.
nrows();
361 throw std::runtime_error(
"'x' must be a square matrix");
363 throw std::runtime_error(
"rhs 'y' must be compatible with data matrix 'x'");
366 auto c = x.get_card();
// Validate the info produced by the factorization step (call not visible).
370 fml::linalgutils::check_info(info,
"getrf");
// p holds the pivot data passed to getrs; its definition is not visible here.
375 gpulapack_status_t check = gpulapack::getrs(c->lapack_handle(), GPUBLAS_OP_N,
376 n, nrhs, x.
data_ptr(), n, p.data_ptr(), y_d, n, info_device.data_ptr());
378 info_device.get_val(&info);
379 gpulapack::err::check_ret(check,
"getrs");
380 fml::linalgutils::check_info(info,
"getrs");
// Fragment of a solve wrapper (signature and remaining body not visible).
// The visible statement verifies that x and y are on the same card.
403 template <
typename REAL>
406 err::check_card(x, y);
// Fragment of a second solve wrapper (signature and remaining body not
// visible). The visible statement verifies that x and y are on the same card.
411 template <
typename REAL>
414 err::check_card(x, y);
// Fragment of the geqrf (QR) driver, called elsewhere in this file as
// qr_internals(pivot, x, qraux, work); signature not visible in this excerpt.
// Visible flow: reject pivoting, read dimensions and the card handle, query
// the geqrf buffer length, grow work when it is too small, call geqrf, then
// read the status value out of info_device into info and validate it.
// Throws std::runtime_error when pivoting is requested.
422 template <
typename REAL>
426 throw std::runtime_error(
"pivoting not supported at this time");
428 const len_t m = x.
nrows();
429 const len_t n = x.
ncols();
430 const len_t minmn = std::min(m, n);
431 auto c = x.get_card();
436 gpulapack_status_t check = gpulapack::geqrf_buflen(c->lapack_handle(), m,
438 gpulapack::err::check_ret(check,
"geqrf_bufferSize");
// The workspace is only touched when the required length exceeds its size.
440 if (lwork > work.
size())
446 check = gpulapack::geqrf(c->lapack_handle(), m, n, x.
data_ptr(), m,
449 info_device.get_val(&info);
// Report a failed return status under the routine that actually ran; this
// label previously read "syevd", a copy-paste slip from the eigen-solver.
450 gpulapack::err::check_ret(check,
"geqrf");
451 fml::linalgutils::check_info(info,
"geqrf");
// Fragment of the QR wrapper (signature not visible).
// Verifies x and qraux are on the same card, then delegates to qr_internals.
477 template <
typename REAL>
480 err::check_card(x, qraux);
482 qr_internals(pivot, x, qraux, work);
// Fragment of the routine that applies ormqr from the left, untransposed, to
// a QR factorization (signature not visible).
// NOTE(review): presumably this recovers Q from QR and qraux; the setup of Q
// before the ormqr call is not visible -- confirm.
// Visible flow: verify all arguments are on the same card as QR, query the
// ormqr buffer length, grow work when it is too small, call ormqr, then read
// the status value out of info_device into info and validate it.
504 template <
typename REAL>
508 err::check_card(QR, qraux);
509 err::check_card(QR, Q);
510 err::check_card(QR, work);
512 const len_t m = QR.
nrows();
513 const len_t n = QR.
ncols();
514 const len_t minmn = std::min(m, n);
516 auto c = QR.get_card();
519 gpulapack_status_t check = gpulapack::ormqr_buflen(c->lapack_handle(),
520 GPUBLAS_SIDE_LEFT, GPUBLAS_OP_N, m, minmn, minmn, QR.
data_ptr(), m,
// The workspace is only touched when the required length exceeds its size.
523 if (lwork > work.
size())
532 check = gpulapack::ormqr(c->lapack_handle(), GPUBLAS_SIDE_LEFT,
536 info_device.get_val(&info);
537 gpulapack::err::check_ret(check,
"ormqr");
538 fml::linalgutils::check_info(info,
"ormqr");
// Fragment of the routine that copies the upper triangle of QR into R
// (signature not visible). Verifies QR and R are on the same card first.
558 template <
typename REAL>
561 err::check_card(QR, R);
563 const len_t m = QR.
nrows();
564 const len_t n = QR.
ncols();
565 const len_t minmn = std::min(m, n);
// The source is passed with leading dimension m and the destination with
// leading dimension minmn.
// NOTE(review): the sizing of R is not visible here.
569 fml::gpu_utils::lacpy(GPUBLAS_FILL_U, m, n, QR.
data_ptr(), m, R.
data_ptr(), minmn);
// Fragment of the LQ wrapper (signature not visible).
// Verifies x and lqaux are on the same card, then runs qr_internals without
// pivoting on tx.
// NOTE(review): tx is presumably the transpose of x; its construction and any
// transpose back are not visible here.
596 template <
typename REAL>
599 err::check_card(x, lqaux);
605 qr_internals(
false, tx, lqaux, work);
// Fragment of the routine that copies the lower triangle of LQ into L
// (signature not visible). Verifies LQ and L are on the same card first.
628 template <
typename REAL>
631 err::check_card(LQ, L);
633 const len_t m = LQ.
nrows();
634 const len_t n = LQ.
ncols();
635 const len_t minmn = std::min(m, n);
// Both source and destination are passed with leading dimension m.
// NOTE(review): the sizing of L is not visible here.
639 fml::gpu_utils::lacpy(GPUBLAS_FILL_L, m, n, LQ.
data_ptr(), m, L.
data_ptr(), m);
// Fragment of the routine that applies ormqr from the right, transposed, for
// an LQ factorization (signature not visible).
// Visible flow: verify all arguments are on the same card as LQ, query the
// ormqr buffer length, grow work when it is too small, call ormqr, then read
// the status value out of info_device into info and validate it.
662 template <
typename REAL>
666 err::check_card(LQ, lqaux);
667 err::check_card(LQ, Q);
668 err::check_card(LQ, work);
673 const len_t m = LQ.
nrows();
674 const len_t n = LQ.
ncols();
675 const len_t minmn = std::min(m, n);
// NOTE(review): QR is not declared in the visible lines; presumably it is a
// local holding the transpose of LQ, declared on a line missing from this
// excerpt -- confirm.
677 auto c = QR.get_card();
680 gpulapack_status_t check = gpulapack::ormqr_buflen(c->lapack_handle(),
681 GPUBLAS_SIDE_RIGHT, GPUBLAS_OP_T, minmn, n, n, QR.
data_ptr(), QR.
nrows(),
// The workspace is only touched when the required length exceeds its size.
684 if (lwork > work.
size())
693 check = gpulapack::ormqr(c->lapack_handle(), GPUBLAS_SIDE_RIGHT,
697 info_device.get_val(&info);
698 gpulapack::err::check_ret(check,
"ormqr");
699 fml::linalgutils::check_info(info,
"ormqr");
// Fragment of the Cholesky (potrf) routine; signature not visible in this
// excerpt.
// Visible flow: reject non-square x, query the potrf buffer length, call
// potrf with the lower-triangle fill mode, read the status value out of
// info_device into info and validate it, then call tri2zero on x.
// Throws std::runtime_error when x is not square, and with a "chol: leading
// minor ..." message on the failure path.
723 template <
typename REAL>
726 const len_t n = x.
nrows();
728 throw std::runtime_error(
"'x' must be a square matrix");
730 auto c = x.get_card();
731 const auto fill = GPUBLAS_FILL_L;
734 gpulapack_status_t check = gpulapack::potrf_buflen(c->lapack_handle(), fill, n,
736 gpulapack::err::check_ret(check,
"potrf_bufferSize");
742 check = gpulapack::potrf(c->lapack_handle(), fill, n, x.
data_ptr(), n,
743 work.
data_ptr(), lwork, info_device.data_ptr());
745 info_device.get_val(&info);
746 gpulapack::err::check_ret(check,
"potrf");
748 fml::linalgutils::check_info(info,
"potrf");
// NOTE(review): the condition guarding this throw is not visible; presumably
// it fires for a positive info -- confirm.
750 throw std::runtime_error(
"chol: leading minor of order " + std::to_string(info) +
" is not positive definite");
// NOTE(review): presumably zeroes the upper triangle, excluding the diagonal,
// so only the lower factor remains -- confirm against tri2zero's contract.
752 fml::gpu_utils::tri2zero(
'U',
false, n, n, x.
data_ptr(), n);