5 #ifndef FML_CPU_LINALG_LINALG_FACTORIZATIONS_H
6 #define FML_CPU_LINALG_LINALG_FACTORIZATIONS_H
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
16 #include "../internals/cpu_utils.hh"
19 #include "../cpumat.hh"
20 #include "../cpuvec.hh"
23 #include "linalg_blas.hh"
// NOTE(review): excerpt of a routine whose name and signature are not visible
// here; only its dimension bookkeeping is shown.
52 template <
typename REAL>
// Row count of the input matrix x.
56 const len_t m = x.
nrows();
// Smaller of the two dimensions of x. Presumably the pivot-vector length
// (the name reads as "length of ipiv") -- confirm against the hidden lines.
57 const len_t lipiv = std::min(m, x.
ncols());
// NOTE(review): only the status check of this routine is visible.
65 template <
typename REAL>
// Hands the status code and the routine name "getrf" to
// fml::linalgutils::check_info. That helper is defined elsewhere; presumably
// it raises on a bad status -- confirm.
73 fml::linalgutils::check_info(info,
"getrf");
// NOTE(review): body excerpt of an SVD helper built on fml::lapack::gesdd.
// The signature is not visible; presumably this is the svd_internals() that
// the wrappers further down call -- confirm.
80 template <
typename REAL>
88 const len_t m = x.
nrows();
89 const len_t n = x.
ncols();
90 const len_t minmn = std::min(m, n);
// Branch taken when both nu and nv are zero (presumably: no singular vectors
// requested on either side -- the branch body is not visible).
94 if (nu == 0 && nv == 0)
// Branch taken when neither nu nor nv exceeds min(m, n).
99 else if (nu <= minmn && nv <= minmn)
// Integer workspace shared by both gesdd calls below; 8*min(m, n) entries.
113 cpuvec<int> iwork(8*minmn);
// First call passes -1 as the workspace length and &tmp as the workspace:
// the LAPACK workspace-query convention, which reports the needed size in tmp.
116 fml::lapack::gesdd(jobz, m, n, x.
data_ptr(), m, s.
data_ptr(), u.
data_ptr(), m, vt.
data_ptr(), ldvt, &tmp, -1, iwork.data_ptr(), &info);
// tmp is truncated to int and used as the length of the real workspace.
// NOTE(review): unlike the QR/LQ helpers, this is not clamped to at least 1.
117 int lwork = (int) tmp;
118 cpuvec<REAL> work(lwork);
// Second call: same arguments, now with the allocated workspace.
120 fml::lapack::gesdd(jobz, m, n, x.
data_ptr(), m, s.
data_ptr(), u.
data_ptr(), m, vt.
data_ptr(), ldvt, work.data_ptr(), lwork, iwork.data_ptr(), &info);
// NOTE(review): excerpt of an SVD wrapper; its signature is not visible.
145 template <
typename REAL>
// Both leading arguments are 0 and the same `ignored` object fills both
// trailing matrix slots, so only x and s carry caller data in this call.
149 int info = svd_internals(0, 0, x, s, ignored, ignored);
// The returned status is reported under the LAPACK routine name "gesdd".
150 fml::linalgutils::check_info(info,
"gesdd");
// NOTE(review): excerpt of an SVD wrapper; its signature is not visible.
154 template <
typename REAL>
// Both leading arguments are 1, and u and vt are passed through in the
// trailing matrix slots.
157 int info = svd_internals(1, 1, x, s, u, vt);
// The returned status is reported under the LAPACK routine name "gesdd".
158 fml::linalgutils::check_info(info,
"gesdd");
// Shared helper for the symmetric eigen routines, built on fml::lapack::syevr.
//   only_values  flag supplied by the callers: `true` from the wrapper that
//                passes an `ignored` matrix, `false` from the wrapper that
//                passes `vectors`.
//   x            input matrix; a std::runtime_error is thrown if it is not
//                square (per the message below).
// Returns an int that the callers hand to check_info under the name "syevr".
// NOTE(review): the remaining parameters and several body lines are not
// visible in this excerpt.
165 template <
typename REAL>
166 int eig_sym_internals(
const bool only_values,
cpumat<REAL> &x,
170 throw std::runtime_error(
"'x' must be a square matrix");
// First call passes -1 for both workspace lengths, with &worksize and &liwork
// receiving the sizes (LAPACK workspace-query convention).
// NOTE(review): 'A' and 'L' presumably map to RANGE = all eigenvalues and
// UPLO = lower triangle, assuming the wrapper mirrors ?syevr's argument order.
191 fml::lapack::syevr(jobz,
'A',
'L', n, x.
data_ptr(), n, (REAL) 0.f, (REAL) 0.f,
193 support.
data_ptr(), &worksize, -1, &liwork, -1,
// Allocate the real and integer workspaces using the queried sizes.
196 lwork = (int) worksize;
197 cpuvec<REAL> work(lwork);
198 cpuvec<int> iwork(liwork);
// Second call: same leading arguments, now with the allocated workspaces.
200 fml::lapack::syevr(jobz,
'A',
'L', n, x.
data_ptr(), n, (REAL) 0.f, (REAL) 0.f,
202 support.
data_ptr(), work.data_ptr(), lwork, iwork.data_ptr(), liwork,
// NOTE(review): excerpt of an eigen wrapper; its signature is not visible.
229 template <
typename REAL>
// only_values = true; an `ignored` object fills the last argument slot.
234 int info = eig_sym_internals(
true, x, values, ignored);
// The returned status is reported under the LAPACK routine name "syevr".
235 fml::linalgutils::check_info(info,
"syevr");
// NOTE(review): excerpt of an eigen wrapper; its signature is not visible.
239 template <
typename REAL>
// only_values = false; `vectors` is passed in the last argument slot.
242 int info = eig_sym_internals(
false, x, values, vectors);
// The returned status is reported under the LAPACK routine name "syevr".
243 fml::linalgutils::check_info(info,
"syevr");
// NOTE(review): excerpt of a routine that reports status for "getrf" and then
// "getri" (the LAPACK factor-then-invert pair). Its name, signature and the
// LAPACK calls themselves are not visible here.
264 template <
typename REAL>
267 const len_t n = x.
nrows();
// Rejects non-square input; the guarding condition is on a hidden line.
269 throw std::runtime_error(
"'x' must be a square matrix");
// Status check for the factorization step.
275 fml::linalgutils::check_info(info,
"getrf");
// tmp is truncated to int and used as a workspace length (presumably filled
// by a workspace query on a hidden line -- confirm).
280 int lwork = (int) tmp;
// Status check for the inversion step.
284 fml::linalgutils::check_info(info,
"getri");
// Solves a linear system through fml::lapack::gesv.
//   x     coefficient matrix; must be square (std::runtime_error otherwise).
//   ylen  length of the right-hand side; used in a compatibility check whose
//         condition is not visible here.
//   nrhs  number of right-hand sides, forwarded to gesv.
//   y_d   raw pointer forwarded to gesv as the right-hand-side buffer.
// Throws std::runtime_error for a non-square x or an incompatible y.
// NOTE(review): per LAPACK ?gesv, both x and the buffer at y_d are overwritten
// (LU factors and solution respectively) -- confirm the wrapper does the same.
291 template <
typename REAL>
292 void solver(
cpumat<REAL> &x, len_t ylen, len_t nrhs, REAL *y_d)
294 const len_t n = x.
nrows();
296 throw std::runtime_error(
"'x' must be a square matrix");
298 throw std::runtime_error(
"rhs 'y' must be compatible with data matrix 'x'");
// n serves as the order of the system and as the leading dimension of both
// x and the right-hand-side buffer; p supplies the pivot storage.
302 fml::lapack::gesv(n, nrhs, x.
data_ptr(), n, p.data_ptr(), y_d, n, &info);
303 fml::linalgutils::check_info(info,
"gesv");
325 template <
typename REAL>
332 template <
typename REAL>
// NOTE(review): body excerpt of a QR helper; the signature is not visible
// (presumably the qr_internals(pivot, x, qraux, work) called further down).
342 template <
typename REAL>
345 const len_t m = x.
nrows();
346 const len_t n = x.
ncols();
347 const len_t minmn = std::min(m, n);
// Workspace queries: data pointers are NULL and the workspace length is -1,
// so only tmp (and info) are written. The condition choosing between geqp3
// and geqrf is on hidden lines -- presumably the `pivot` flag.
354 fml::lapack::geqp3(m, n, NULL, m, NULL, NULL, &tmp, -1, &info);
356 fml::lapack::geqrf(m, n, NULL, m, NULL, &tmp, -1, &info);
// Queried size truncated to int and clamped to at least 1.
358 int lwork = std::max((
int) tmp, 1);
// Acts only when the caller-supplied workspace is smaller than required
// (the action itself is on a hidden line -- presumably a resize).
359 if (lwork > work.
size())
// Status is reported under whichever routine name was used.
374 fml::linalgutils::check_info(info,
"geqp3");
376 fml::linalgutils::check_info(info,
"geqrf");
// NOTE(review): excerpt of a QR wrapper; its signature is not visible.
404 template <
typename REAL>
// Delegates to qr_internals, forwarding pivot, x, qraux and work. Any return
// value is discarded here.
408 qr_internals(pivot, x, qraux, work);
// NOTE(review): excerpt of a routine that applies fml::lapack::ormqr to a
// stored QR factorization (QR plus qraux); name and signature are not visible.
430 template <
typename REAL>
434 const len_t m = QR.
nrows();
435 const len_t n = QR.
ncols();
436 const len_t minmn = std::min(m, n);
// Workspace query: NULL for the reflector and target pointers, -1 for the
// workspace length; the required size comes back in tmp.
440 fml::lapack::ormqr(
'L',
'N', m, minmn, m, QR.
data_ptr(), m, NULL,
441 NULL, m, &tmp, -1, &info);
// NOTE(review): no lower clamp here, unlike std::max(..., 1) in the QR helper.
443 int lwork = (int) tmp;
// Acts only when the supplied workspace is too small (action not visible).
444 if (lwork > work.
size())
// Actual application, using qraux as the reflector scalars.
450 fml::lapack::ormqr(
'L',
'N', m, minmn, m, QR.
data_ptr(), m, qraux.
data_ptr(),
452 fml::linalgutils::check_info(info,
"ormqr");
// NOTE(review): only the dimension setup of this routine is visible; its
// name, signature and remaining body are not shown.
472 template <
typename REAL>
// Rows, columns and the smaller of the two for the stored factorization QR.
475 const len_t m = QR.
nrows();
476 const len_t n = QR.
ncols();
477 const len_t minmn = std::min(m, n);
// NOTE(review): body excerpt of an LQ helper; the signature is not visible
// (presumably the lq_internals(x, lqaux, work) called further down).
488 template <
typename REAL>
491 const len_t m = x.
nrows();
492 const len_t n = x.
ncols();
493 const len_t minmn = std::min(m, n);
// Workspace query: data pointers are NULL and the workspace length is -1,
// so only tmp (and info) are written.
499 fml::lapack::gelqf(m, n, NULL, m, NULL, &tmp, -1, &info);
// Queried size truncated to int and clamped to at least 1.
500 int lwork = std::max((
int) tmp, 1);
// Acts only when the supplied workspace is too small (action not visible).
501 if (lwork > work.
size())
508 fml::linalgutils::check_info(info,
"gelqf");
// NOTE(review): excerpt of an LQ wrapper; its signature is not visible.
533 template <
typename REAL>
// Delegates to lq_internals, forwarding x, lqaux and work. Any return value
// is discarded here.
537 lq_internals(x, lqaux, work);
// NOTE(review): only the dimension setup of this routine is visible; its
// name, signature and remaining body are not shown.
557 template <
typename REAL>
// Rows, columns and the smaller of the two for the stored factorization LQ.
560 const len_t m = LQ.
nrows();
561 const len_t n = LQ.
ncols();
562 const len_t minmn = std::min(m, n);
// NOTE(review): excerpt of a routine that applies fml::lapack::ormlq to a
// stored LQ factorization (LQ plus lqaux); name and signature are not visible.
589 template <
typename REAL>
593 const len_t m = LQ.
nrows();
594 const len_t n = LQ.
ncols();
595 const len_t minmn = std::min(m, n);
// Workspace query: NULL for the reflector and target pointers, -1 for the
// workspace length; the required size comes back in tmp.
599 fml::lapack::ormlq(
'R',
'N', minmn, n, minmn, LQ.
data_ptr(), m, NULL,
600 NULL, minmn, &tmp, -1, &info);
// NOTE(review): no lower clamp here, unlike std::max(..., 1) in the LQ helper.
602 int lwork = (int) tmp;
// Acts only when the supplied workspace is too small (action not visible).
603 if (lwork > work.
size())
// Actual application, using lqaux as the reflector scalars.
609 fml::lapack::ormlq(
'R',
'N', minmn, n, minmn, LQ.
data_ptr(), m, lqaux.
data_ptr(),
611 fml::linalgutils::check_info(info,
"ormlq");
// NOTE(review): excerpt of the Cholesky routine (the error text below is
// prefixed "chol:"); its signature and branch conditions are not visible.
// Throws std::runtime_error for a non-square x, and again when the
// factorization reports a non-positive-definite leading minor.
635 template <
typename REAL>
638 const len_t n = x.
nrows();
640 throw std::runtime_error(
"'x' must be a square matrix");
// In-place factorization of x with uplo = 'L'; n is both the order and the
// leading dimension.
643 fml::lapack::potrf(
'L', n, x.
data_ptr(), n, &info);
// Two failure paths follow; their guarding conditions are on hidden lines.
// Presumably info < 0 goes to check_info and info > 0 to the throw -- confirm.
646 fml::linalgutils::check_info(info,
"potrf");
// info is embedded in the message as the order of the failing leading minor.
648 throw std::runtime_error(
"chol: leading minor of order " + std::to_string(info) +
" is not positive definite");
// Post-processing on the 'U' part of x with the flag set to false. Presumably
// this zeroes the strict upper triangle left behind by potrf -- confirm
// against tri2zero's definition.
650 fml::cpu_utils::tri2zero(
'U',
false, n, n, x.
data_ptr(), n);