5 #ifndef FML_MPI_LINALG_LINALG_FACTORIZATIONS_H
6 #define FML_MPI_LINALG_LINALG_FACTORIZATIONS_H
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
19 #include "../mpimat.hh"
21 #include "linalg_blas.hh"
22 #include "linalg_err.hh"
23 #include "scalapack.hh"
// Reads the row count of x into m and sets lipiv to min(nrows, ncols).
// NOTE(review): lipiv is presumably the length of the pivot vector used by
// the LU (getrf) step — confirm against the full definition.
53 template <
typename REAL>
57 const len_t m = x.
nrows();
58 const len_t lipiv = std::min(m, x.
ncols());
// Hands the status code 'info' to check_info under the routine label "getrf".
// NOTE(review): how check_info reports a non-zero code is defined in
// linalgutils — confirm there.
66 template <
typename REAL>
74 fml::linalgutils::check_info(info,
"getrf");
// SVD driver built on fml::scalapack::gesvd, run in two passes: a workspace
// size query followed by the actual decomposition.
// NOTE(review): presumably this is svd_internals(nu, nv, x, s, u, vt), the
// helper the svd() overloads in this file call — confirm.
81 template <
typename REAL>
87 const len_t m = x.
nrows();
88 const len_t n = x.
ncols();
89 const len_t minmn = std::min(m, n);
// Case where neither the u nor the vt factor was requested by the caller.
93 if (nu == 0 && nv == 0)
// The factor matrices reuse x's blocking factors: u becomes m x minmn and
// vt becomes minmn x n.
103 const int mb = x.bf_rows();
104 const int nb = x.bf_cols();
106 u.
resize(m, minmn, mb, nb);
107 vt.
resize(minmn, n, mb, nb);
// Pass 1: lwork = -1 with the scalar 'tmp' standing in for the work buffer;
// tmp is then read back as the required workspace length.
111 fml::scalapack::gesvd(jobu, jobvt, m, n, x.
data_ptr(), x.desc_ptr(), s.
data_ptr(), u.
data_ptr(), u.desc_ptr(), vt.
data_ptr(), vt.desc_ptr(), &tmp, -1, &info);
// NOTE(review): unlike the getri/geqrf/gelqf paths in this file, lwork is not
// clamped to at least 1 here — confirm a zero-length workspace is acceptable.
112 int lwork = (int) tmp;
113 cpuvec<REAL> work(lwork);
// Pass 2: the decomposition itself, using the freshly allocated workspace.
115 fml::scalapack::gesvd(jobu, jobvt, m, n, x.
data_ptr(), x.desc_ptr(), s.
data_ptr(), u.
data_ptr(), u.desc_ptr(), vt.
data_ptr(), vt.desc_ptr(), work.data_ptr(), lwork, &info);
// Singular-values-only path: calls svd_internals with both factor flags set
// to 0 and the placeholder matrix 'ignored' in both factor slots, then passes
// the returned status to check_info under the label "gesvd".
142 template <
typename REAL>
146 int info = svd_internals(0, 0, x, s, ignored, ignored);
147 fml::linalgutils::check_info(info,
"gesvd");
// Full SVD path: u and vt are each checked against x with err::check_grid,
// then svd_internals is called with both factor flags set to 1. The returned
// status goes to check_info under the label "gesvd".
// NOTE(review): check_grid presumably rejects operands on different process
// grids — confirm in linalg_err.hh.
151 template <
typename REAL>
154 err::check_grid(x, u);
155 err::check_grid(x, vt);
157 int info = svd_internals(1, 1, x, s, u, vt);
158 fml::linalgutils::check_info(info,
"gesvd");
// Symmetric eigen-decomposition driver built on fml::scalapack::syevr.
// Returns an int status code (callers forward it to check_info as "syevr").
// 'only_values' selects between the values-only and values-plus-vectors
// paths (see the two callers in this file).
// Throws std::runtime_error when 'x' is not square (per the message below).
165 template <
typename REAL>
166 int eig_sym_internals(
const bool only_values,
mpimat<REAL> &x,
170 throw std::runtime_error(
"'x' must be a square matrix");
// Output counters filled in by syevr.
173 int val_found, vec_found;
// The eigenvector matrix is sized n x n with x's blocking factors.
184 vectors.
resize(n, n, x.bf_rows(), x.bf_cols());
// Pass 1: lwork = -1 and liwork = -1, with 'worksize' and 'liwork' receiving
// the required real and integer workspace lengths.
// NOTE(review): 'A' and 'L' presumably follow the LAPACK syevr convention
// (range = all eigenvalues, uplo = lower triangle) — confirm in scalapack.hh.
190 fml::scalapack::syevr(jobz,
'A',
'L', n, x.
data_ptr(), x.desc_ptr(),
191 (REAL) 0.f, (REAL) 0.f, 0, 0, &val_found, &vec_found,
193 &worksize, -1, &liwork, -1, &info);
// Allocate both workspaces at the sizes reported by the query.
195 lwork = (int) worksize;
196 cpuvec<REAL> work(lwork);
197 cpuvec<int> iwork(liwork);
// Pass 2: the decomposition itself, using the allocated workspaces.
199 fml::scalapack::syevr(jobz,
'A',
'L', n, x.
data_ptr(), x.desc_ptr(),
200 (REAL) 0.f, (REAL) 0.f, 0, 0, &val_found, &vec_found,
202 work.data_ptr(), lwork, iwork.data_ptr(), liwork, &info);
// Eigenvalues-only path: calls eig_sym_internals with only_values = true and
// the placeholder matrix 'ignored' in the vectors slot, then passes the
// returned status to check_info under the label "syevr".
230 template <
typename REAL>
235 int info = eig_sym_internals(
true, x, values, ignored);
236 fml::linalgutils::check_info(info,
"syevr");
// Eigenvalues-and-eigenvectors path: 'vectors' is checked against x with
// err::check_grid, then eig_sym_internals is called with only_values = false.
// The returned status goes to check_info under the label "syevr".
240 template <
typename REAL>
243 err::check_grid(x, vectors);
245 int info = eig_sym_internals(
false, x, values, vectors);
246 fml::linalgutils::check_info(info,
"syevr");
// Matrix inversion via an LU step followed by fml::scalapack::getri.
// Throws std::runtime_error when 'x' is not square (per the message below).
// NOTE(review): presumably the public invert() routine and presumably
// in-place on x — confirm against the full definition.
270 template <
typename REAL>
274 throw std::runtime_error(
"'x' must be a square matrix");
// Status of the factorization step, checked under the label "getrf".
280 fml::linalgutils::check_info(info,
"getrf");
283 const len_t n = x.
nrows();
// Workspace query: lwork = -1 and liwork = -1; 'tmp' and 'liwork' receive
// the required real and integer workspace lengths.
286 fml::scalapack::getri(n, x.
data_ptr(), x.desc_ptr(), p.
data_ptr(), &tmp, -1, &liwork, -1, &info);
// Clamp so the real workspace is never requested with a length below 1.
287 int lwork = std::max(1, (
int)tmp);
// Status of the inversion step, checked under the label "getri".
292 fml::linalgutils::check_info(info,
"getri");
// Linear solve with data matrix 'x' and right-hand side 'y'.
// y is checked against x with err::check_grid before any size validation.
// Throws std::runtime_error when 'x' is not square, or when 'y' is not
// compatible with 'x' (per the messages below).
// The solver status is passed to check_info under the label "gesv".
// NOTE(review): presumably the public solve() routine — confirm.
318 template <
typename REAL>
321 err::check_grid(x, y);
323 const len_t n = x.
nrows();
325 throw std::runtime_error(
"'x' must be a square matrix");
327 throw std::runtime_error(
"rhs 'y' must be compatible with data matrix 'x'");
332 fml::linalgutils::check_info(info,
"gesv");
// QR factorization driver using fml::scalapack::geqpf or geqrf.
// NOTE(review): presumably this is qr_internals(pivot, x, qraux, work), with
// geqpf used on the pivoted path and geqrf otherwise — confirm the branching.
339 template <
typename REAL>
342 const len_t m = x.
nrows();
343 const len_t n = x.
ncols();
344 const len_t minmn = std::min(m, n);
346 const int *descx = x.desc_ptr();
// Workspace queries: lwork = -1 with NULL data pointers; 'tmp' receives the
// required workspace length.
353 fml::scalapack::geqpf(m, n, NULL, descx, NULL, NULL, &tmp, -1, &info);
355 fml::scalapack::geqrf(m, n, NULL, descx, NULL, &tmp, -1, &info);
// Clamp the workspace length to at least 1, then compare it with the size
// of the caller-provided 'work' buffer.
357 int lwork = std::max((
int) tmp, 1);
358 if (lwork > work.
size())
// Pivoted factorization; 'p' supplies the pivot storage.
365 fml::scalapack::geqpf(m, n, x.
data_ptr(), descx, p.data_ptr(),
// Status is checked under the label of whichever routine was run.
375 fml::linalgutils::check_info(info,
"geqpf");
377 fml::linalgutils::check_info(info,
"geqrf");
// Delegates to qr_internals, forwarding the pivot flag, the matrix x, the
// 'qraux' vector and the 'work' buffer.
// NOTE(review): presumably the public qr() entry point — confirm.
407 template <
typename REAL>
411 qr_internals(pivot, x, qraux, work);
// Builds Q from a QR factorization using fml::scalapack::ormqr.
// Q is checked against QR with err::check_grid first.
// NOTE(review): presumably the public qr_Q() routine — confirm.
435 template <
typename REAL>
438 err::check_grid(QR, Q);
440 const len_t m = QR.
nrows();
441 const len_t n = QR.
ncols();
442 const len_t minmn = std::min(m, n);
444 const int *descQR = QR.desc_ptr();
448 const int *descQ = Q.desc_ptr();
// Workspace query: lwork = -1 with NULL data pointers; 'tmp' receives the
// required workspace length. Dimensions passed are m, minmn, minmn.
// NOTE(review): 'L' and 'N' presumably follow the LAPACK ormqr convention
// (apply from the left, no transpose) — confirm in scalapack.hh.
452 fml::scalapack::ormqr(
'L',
'N', m, minmn, minmn, NULL, descQR,
453 NULL, NULL, descQ, &tmp, -1, &info);
// NOTE(review): unlike the geqrf/gelqf drivers in this file, lwork is not
// clamped to at least 1 here — confirm that is intended.
455 int lwork = (int) tmp;
456 if (lwork > work.
size())
// Actual application of the reflectors stored in QR.
459 fml::scalapack::ormqr(
'L',
'N', m, minmn, minmn, QR.
data_ptr(), descQR,
461 fml::linalgutils::check_info(info,
"ormqr");
// Extracts R from a QR factorization by copying from QR into R with
// fml::scalapack::lacpy. R is checked against QR with err::check_grid first.
// NOTE(review): presumably the public qr_R() routine — confirm.
483 template <
typename REAL>
486 err::check_grid(QR, R);
488 const len_t m = QR.
nrows();
489 const len_t n = QR.
ncols();
490 const len_t minmn = std::min(m, n);
// NOTE(review): 'U' presumably selects the upper triangle, following the
// LAPACK lacpy convention — confirm in scalapack.hh.
494 fml::scalapack::lacpy(
'U', m, n, QR.
data_ptr(), QR.desc_ptr(), R.
data_ptr(),
// LQ factorization driver using fml::scalapack::gelqf.
// NOTE(review): presumably this is lq_internals(x, lqaux, work) — confirm.
502 template <
typename REAL>
505 const len_t m = x.
nrows();
506 const len_t n = x.
ncols();
507 const len_t minmn = std::min(m, n);
509 const int *descx = x.desc_ptr();
// Workspace query: lwork = -1 with NULL data pointers; 'tmp' receives the
// required workspace length, which is then clamped to at least 1 and
// compared with the size of the caller-provided 'work' buffer.
515 fml::scalapack::gelqf(m, n, NULL, descx, NULL, &tmp, -1, &info);
516 int lwork = std::max((
int) tmp, 1);
517 if (lwork > work.
size())
524 fml::linalgutils::check_info(info,
"gelqf");
// Delegates to lq_internals, forwarding the matrix x, the 'lqaux' vector and
// the 'work' buffer.
// NOTE(review): presumably the public lq() entry point — confirm.
551 template <
typename REAL>
555 lq_internals(x, lqaux, work);
// Extracts L from an LQ factorization by copying from LQ into L with
// fml::scalapack::lacpy. L is checked against LQ with err::check_grid first.
// NOTE(review): presumably the public lq_L() routine — confirm.
577 template <
typename REAL>
580 err::check_grid(LQ, L);
582 const len_t m = LQ.
nrows();
583 const len_t n = LQ.
ncols();
584 const len_t minmn = std::min(m, n);
// NOTE(review): 'L' presumably selects the lower triangle, following the
// LAPACK lacpy convention — confirm in scalapack.hh.
589 fml::scalapack::lacpy(
'L', m, n, LQ.
data_ptr(), LQ.desc_ptr(), L.
data_ptr(),
// Builds Q from an LQ factorization using fml::scalapack::ormlq.
// Q is checked against LQ with err::check_grid first.
// NOTE(review): presumably the public lq_Q() routine — confirm.
614 template <
typename REAL>
617 err::check_grid(LQ, Q);
619 const len_t m = LQ.
nrows();
620 const len_t n = LQ.
ncols();
621 const len_t minmn = std::min(m, n);
623 const int *descLQ = LQ.desc_ptr();
627 const int *descQ = Q.desc_ptr();
// Workspace query: lwork = -1 with NULL data pointers; 'tmp' receives the
// required workspace length. Dimensions passed are minmn, n, minmn.
// NOTE(review): 'R' and 'N' presumably follow the LAPACK ormlq convention
// (apply from the right, no transpose) — confirm in scalapack.hh.
631 fml::scalapack::ormlq(
'R',
'N', minmn, n, minmn, NULL, descLQ,
632 NULL, NULL, descQ, &tmp, -1, &info);
// NOTE(review): unlike the geqrf/gelqf drivers in this file, lwork is not
// clamped to at least 1 here — confirm that is intended.
634 int lwork = (int) tmp;
635 if (lwork > work.
size())
// Actual application of the reflectors stored in LQ.
638 fml::scalapack::ormlq(
'R',
'N', minmn, n, minmn, LQ.
data_ptr(), descLQ,
640 fml::linalgutils::check_info(info,
"ormlq");
// Cholesky factorization of x via fml::scalapack::potrf, operating on x's
// own data pointer.
// Throws std::runtime_error when 'x' is not square, or when potrf reports
// that a leading minor is not positive definite (per the messages below).
664 template <
typename REAL>
667 const len_t n = x.
nrows();
669 throw std::runtime_error(
"'x' must be a square matrix");
// NOTE(review): 'L' presumably requests the lower-triangular factor,
// following the LAPACK potrf convention — confirm in scalapack.hh.
672 fml::scalapack::potrf(
'L', n, x.
data_ptr(), x.desc_ptr(), &info);
// Status handling: one path goes through check_info under the label "potrf";
// the other reports 'info' as the order of the failing leading minor.
// NOTE(review): presumably info < 0 takes the check_info path and info > 0
// the throw — confirm the guarding conditions.
675 fml::linalgutils::check_info(info,
"potrf");
677 throw std::runtime_error(
"chol: leading minor of order " + std::to_string(info) +
" is not positive definite");
// NOTE(review): tri2zero with 'U' and false presumably zeroes the upper
// triangle while leaving the diagonal untouched, so only the lower factor
// remains in x — confirm in mpi_utils.hh.
679 fml::mpi_utils::tri2zero(
'U',
false, x.get_grid(), n, n, x.
data_ptr(), x.desc_ptr());