fml  0.1-0
Fused Matrix Library
linalg_factorizations.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_LINALG_FACTORIZATIONS_H
6 #define FML_MPI_LINALG_LINALG_FACTORIZATIONS_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
14 
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
17 
18 #include "../copy.hh"
19 #include "../mpimat.hh"
20 
21 #include "linalg_blas.hh"
22 #include "linalg_err.hh"
23 #include "scalapack.hh"
24 
25 
26 namespace fml
27 {
28 namespace linalg
29 {
53  template <typename REAL>
54  void lu(mpimat<REAL> &x, cpuvec<int> &p, int &info)
55  {
56  info = 0;
57  const len_t m = x.nrows();
58  const len_t lipiv = std::min(m, x.ncols());
59 
60  p.resize(lipiv);
61 
62  fml::scalapack::getrf(m, x.ncols(), x.data_ptr(), x.desc_ptr(), p.data_ptr(), &info);
63  }
64 
66  template <typename REAL>
67  void lu(mpimat<REAL> &x)
68  {
69  cpuvec<int> p;
70  int info;
71 
72  lu(x, p, info);
73 
74  fml::linalgutils::check_info(info, "getrf");
75  }
76 
77 
78 
79  namespace
80  {
81  template <typename REAL>
82  int svd_internals(const int nu, const int nv, mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
83  {
84  int info = 0;
85  char jobu, jobvt;
86 
87  const len_t m = x.nrows();
88  const len_t n = x.ncols();
89  const len_t minmn = std::min(m, n);
90 
91  s.resize(minmn);
92 
93  if (nu == 0 && nv == 0)
94  {
95  jobu = 'N';
96  jobvt = 'N';
97  }
98  else // if (nu <= minmn && nv <= minmn)
99  {
100  jobu = 'V';
101  jobvt = 'V';
102 
103  const int mb = x.bf_rows();
104  const int nb = x.bf_cols();
105 
106  u.resize(m, minmn, mb, nb);
107  vt.resize(minmn, n, mb, nb);
108  }
109 
110  REAL tmp;
111  fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), &tmp, -1, &info);
112  int lwork = (int) tmp;
113  cpuvec<REAL> work(lwork);
114 
115  fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), work.data_ptr(), lwork, &info);
116 
117  return info;
118  }
119  }
120 
// svd (singular values only): computes the singular values of x into s.
// NOTE(review): this listing skips the declarator line between the template
// header and the opening brace. From the body, the function uses `x` (calls
// get_grid()) and `s`. Confirm the exact signature in the real header.
// No singular vectors are requested (nu = nv = 0), so a single placeholder
// matrix on x's grid is passed for both u and vt. x is handed to ScaLAPACK
// gesvd and is presumably overwritten. The gesvd status goes to check_info.
142  template <typename REAL>
144  {
145  mpimat<REAL> ignored(x.get_grid());
146  int info = svd_internals(0, 0, x, s, ignored, ignored);
147  fml::linalgutils::check_info(info, "gesvd");
148  }
149 
// svd (values and vectors): singular values into s, left singular vectors into
// u, and right singular vectors into vt (gesvd's VT output).
// NOTE(review): this listing skips the declarator line between the template
// header and the opening brace. From the body, the function uses x, s, u and
// vt. Confirm the exact signature in the real header.
// u and vt are checked against x's grid first. svd_internals then resizes them
// to m x min(m,n) and min(m,n) x n using x's blocking factors. The gesvd
// status goes to check_info.
151  template <typename REAL>
153  {
154  err::check_grid(x, u);
155  err::check_grid(x, vt);
156 
157  int info = svd_internals(1, 1, x, s, u, vt);
158  fml::linalgutils::check_info(info, "gesvd");
159  }
160 
161 
162 
163  namespace
164  {
165  template <typename REAL>
166  int eig_sym_internals(const bool only_values, mpimat<REAL> &x,
167  cpuvec<REAL> &values, mpimat<REAL> &vectors)
168  {
169  if (!x.is_square())
170  throw std::runtime_error("'x' must be a square matrix");
171 
172  int info = 0;
173  int val_found, vec_found;
174  char jobz;
175 
176  len_t n = x.nrows();
177  values.resize(n);
178 
179  if (only_values)
180  jobz = 'N';
181  else
182  {
183  jobz = 'V';
184  vectors.resize(n, n, x.bf_rows(), x.bf_cols());
185  }
186 
187  REAL worksize;
188  int lwork, liwork;
189 
190  fml::scalapack::syevr(jobz, 'A', 'L', n, x.data_ptr(), x.desc_ptr(),
191  (REAL) 0.f, (REAL) 0.f, 0, 0, &val_found, &vec_found,
192  values.data_ptr(), vectors.data_ptr(), vectors.desc_ptr(),
193  &worksize, -1, &liwork, -1, &info);
194 
195  lwork = (int) worksize;
196  cpuvec<REAL> work(lwork);
197  cpuvec<int> iwork(liwork);
198 
199  fml::scalapack::syevr(jobz, 'A', 'L', n, x.data_ptr(), x.desc_ptr(),
200  (REAL) 0.f, (REAL) 0.f, 0, 0, &val_found, &vec_found,
201  values.data_ptr(), vectors.data_ptr(), vectors.desc_ptr(),
202  work.data_ptr(), lwork, iwork.data_ptr(), liwork, &info);
203 
204  return info;
205  }
206  }
207 
// eigen_sym (values only): computes all eigenvalues of the symmetric matrix x
// into `values` via ScaLAPACK syevr.
// NOTE(review): this listing skips the declarator line between the template
// header and the opening brace. From the body, the function uses `x` (calls
// get_grid()) and `values`. Confirm the exact signature in the real header.
// A placeholder matrix on x's grid stands in for the unused eigenvector
// output. eig_sym_internals throws std::runtime_error if x is not square.
// The syevr status goes to check_info.
230  template <typename REAL>
232  {
233  mpimat<REAL> ignored(x.get_grid());
234 
235  int info = eig_sym_internals(true, x, values, ignored);
236  fml::linalgutils::check_info(info, "syevr");
237  }
238 
240  template <typename REAL>
241  void eigen_sym(mpimat<REAL> &x, cpuvec<REAL> &values, mpimat<REAL> &vectors)
242  {
243  err::check_grid(x, vectors);
244 
245  int info = eig_sym_internals(false, x, values, vectors);
246  fml::linalgutils::check_info(info, "syevr");
247  }
248 
249 
250 
// invert: replaces the square matrix x with its inverse. It first factors
// x = PLU with lu(), then calls ScaLAPACK getri on the factors.
// NOTE(review): this listing skips the declarator line between the template
// header and the opening brace. From the body, the function uses `x`.
// Confirm the exact signature in the real header.
// Throws std::runtime_error if x is not square. The getrf and getri statuses
// are each passed to check_info.
270  template <typename REAL>
272  {
273  if (!x.is_square())
274  throw std::runtime_error("'x' must be a square matrix");
275 
276  // Factor x = LU
277  cpuvec<int> p;
278  int info;
279  lu(x, p, info);
280  fml::linalgutils::check_info(info, "getrf");
281 
282  // Invert
283  const len_t n = x.nrows();
// Workspace query (lwork = liwork = -1): getri writes the optimal real
// workspace size to tmp and the integer workspace size to liwork.
284  REAL tmp;
285  int liwork;
286  fml::scalapack::getri(n, x.data_ptr(), x.desc_ptr(), p.data_ptr(), &tmp, -1, &liwork, -1, &info);
// lwork is clamped to at least 1; liwork is used as reported.
287  int lwork = std::max(1, (int)tmp);
288  cpuvec<REAL> work(lwork);
289  cpuvec<int> iwork(liwork);
290 
291  fml::scalapack::getri(n, x.data_ptr(), x.desc_ptr(), p.data_ptr(), work.data_ptr(), lwork, iwork.data_ptr(), liwork, &info);
292  fml::linalgutils::check_info(info, "getri");
293  }
294 
295 
296 
// solve: solves the linear system x * b = y with ScaLAPACK gesv. x is
// overwritten by its LU factors and y by the solution (per gesv's contract;
// the wrapper is not visible here).
// NOTE(review): this listing skips the declarator line between the template
// header and the opening brace. From the body, the function uses `x` and `y`
// (y exposes ncols()/desc_ptr(), so it is a distributed matrix). Confirm the
// exact signature in the real header.
// Throws std::runtime_error if x is not square or if y's row count differs
// from x's. err::check_grid is applied to x and y first. The gesv status goes
// to check_info.
318  template <typename REAL>
320  {
321  err::check_grid(x, y);
322 
323  const len_t n = x.nrows();
324  if (!x.is_square())
325  throw std::runtime_error("'x' must be a square matrix");
326  if (n != y.nrows())
327  throw std::runtime_error("rhs 'y' must be compatible with data matrix 'x'");
328 
329  int info;
// NOTE(review): ScaLAPACK documents IPIV for p?gesv as LOCr(M_A)+MB_A
// entries per process. A length of n may fall short of that on some grids;
// consider n + x.bf_rows(), which is always a sufficient upper bound.
330  cpuvec<int> p(n);
331  fml::scalapack::gesv(n, y.ncols(), x.data_ptr(), x.desc_ptr(), p.data_ptr(), y.data_ptr(), y.desc_ptr(), &info);
332  fml::linalgutils::check_info(info, "gesv");
333  }
334 
335 
336 
337  namespace
338  {
339  template <typename REAL>
340  void qr_internals(const bool pivot, mpimat<REAL> &x, cpuvec<REAL> &qraux, cpuvec<REAL> &work)
341  {
342  const len_t m = x.nrows();
343  const len_t n = x.ncols();
344  const len_t minmn = std::min(m, n);
345 
346  const int *descx = x.desc_ptr();
347 
348  int info = 0;
349  qraux.resize(minmn);
350 
351  REAL tmp;
352  if (pivot)
353  fml::scalapack::geqpf(m, n, NULL, descx, NULL, NULL, &tmp, -1, &info);
354  else
355  fml::scalapack::geqrf(m, n, NULL, descx, NULL, &tmp, -1, &info);
356 
357  int lwork = std::max((int) tmp, 1);
358  if (lwork > work.size())
359  work.resize(lwork);
360 
361  if (pivot)
362  {
363  cpuvec<int> p(n);
364  p.fill_zero();
365  fml::scalapack::geqpf(m, n, x.data_ptr(), descx, p.data_ptr(),
366  qraux.data_ptr(), work.data_ptr(), lwork, &info);
367  }
368  else
369  fml::scalapack::geqrf(m, n, x.data_ptr(), descx, qraux.data_ptr(),
370  work.data_ptr(), lwork, &info);
371 
372  if (info != 0)
373  {
374  if (pivot)
375  fml::linalgutils::check_info(info, "geqpf");
376  else
377  fml::linalgutils::check_info(info, "geqrf");
378  }
379  }
380  }
381 
407  template <typename REAL>
408  void qr(const bool pivot, mpimat<REAL> &x, cpuvec<REAL> &qraux)
409  {
410  cpuvec<REAL> work;
411  qr_internals(pivot, x, qraux, work);
412  }
413 
435  template <typename REAL>
436  void qr_Q(const mpimat<REAL> &QR, const cpuvec<REAL> &qraux, mpimat<REAL> &Q, cpuvec<REAL> &work)
437  {
438  err::check_grid(QR, Q);
439 
440  const len_t m = QR.nrows();
441  const len_t n = QR.ncols();
442  const len_t minmn = std::min(m, n);
443 
444  const int *descQR = QR.desc_ptr();
445 
446  Q.resize(m, minmn);
447  Q.fill_eye();
448  const int *descQ = Q.desc_ptr();
449 
450  int info = 0;
451  REAL tmp;
452  fml::scalapack::ormqr('L', 'N', m, minmn, minmn, NULL, descQR,
453  NULL, NULL, descQ, &tmp, -1, &info);
454 
455  int lwork = (int) tmp;
456  if (lwork > work.size())
457  work.resize(lwork);
458 
459  fml::scalapack::ormqr('L', 'N', m, minmn, minmn, QR.data_ptr(), descQR,
460  qraux.data_ptr(), Q.data_ptr(), descQ, work.data_ptr(), lwork, &info);
461  fml::linalgutils::check_info(info, "ormqr");
462  }
463 
483  template <typename REAL>
484  void qr_R(const mpimat<REAL> &QR, mpimat<REAL> &R)
485  {
486  err::check_grid(QR, R);
487 
488  const len_t m = QR.nrows();
489  const len_t n = QR.ncols();
490  const len_t minmn = std::min(m, n);
491 
492  R.resize(minmn, n);
493  R.fill_zero();
494  fml::scalapack::lacpy('U', m, n, QR.data_ptr(), QR.desc_ptr(), R.data_ptr(),
495  R.desc_ptr());
496  }
497 
498 
499 
500  namespace
501  {
502  template <typename REAL>
503  void lq_internals(mpimat<REAL> &x, cpuvec<REAL> &lqaux, cpuvec<REAL> &work)
504  {
505  const len_t m = x.nrows();
506  const len_t n = x.ncols();
507  const len_t minmn = std::min(m, n);
508 
509  const int *descx = x.desc_ptr();
510 
511  int info = 0;
512  lqaux.resize(minmn);
513 
514  REAL tmp;
515  fml::scalapack::gelqf(m, n, NULL, descx, NULL, &tmp, -1, &info);
516  int lwork = std::max((int) tmp, 1);
517  if (lwork > work.size())
518  work.resize(lwork);
519 
520  fml::scalapack::gelqf(m, n, x.data_ptr(), descx, lqaux.data_ptr(),
521  work.data_ptr(), lwork, &info);
522 
523  if (info != 0)
524  fml::linalgutils::check_info(info, "gelqf");
525  }
526  }
527 
551  template <typename REAL>
552  void lq(mpimat<REAL> &x, cpuvec<REAL> &lqaux)
553  {
554  cpuvec<REAL> work;
555  lq_internals(x, lqaux, work);
556  }
557 
577  template <typename REAL>
578  void lq_L(const mpimat<REAL> &LQ, mpimat<REAL> &L)
579  {
580  err::check_grid(LQ, L);
581 
582  const len_t m = LQ.nrows();
583  const len_t n = LQ.ncols();
584  const len_t minmn = std::min(m, n);
585 
586  L.resize(m, minmn);
587  L.fill_zero();
588 
589  fml::scalapack::lacpy('L', m, n, LQ.data_ptr(), LQ.desc_ptr(), L.data_ptr(),
590  L.desc_ptr());
591  }
592 
614  template <typename REAL>
615  void lq_Q(const mpimat<REAL> &LQ, const cpuvec<REAL> &lqaux, mpimat<REAL> &Q, cpuvec<REAL> &work)
616  {
617  err::check_grid(LQ, Q);
618 
619  const len_t m = LQ.nrows();
620  const len_t n = LQ.ncols();
621  const len_t minmn = std::min(m, n);
622 
623  const int *descLQ = LQ.desc_ptr();
624 
625  Q.resize(minmn, n);
626  Q.fill_eye();
627  const int *descQ = Q.desc_ptr();
628 
629  int info = 0;
630  REAL tmp;
631  fml::scalapack::ormlq('R', 'N', minmn, n, minmn, NULL, descLQ,
632  NULL, NULL, descQ, &tmp, -1, &info);
633 
634  int lwork = (int) tmp;
635  if (lwork > work.size())
636  work.resize(lwork);
637 
638  fml::scalapack::ormlq('R', 'N', minmn, n, minmn, LQ.data_ptr(), descLQ,
639  lqaux.data_ptr(), Q.data_ptr(), descQ, work.data_ptr(), lwork, &info);
640  fml::linalgutils::check_info(info, "ormlq");
641  }
642 
643 
644 
// chol: Cholesky factorization of the square matrix x, using the lower
// triangle ('L') with ScaLAPACK potrf.
// NOTE(review): this listing skips the declarator line between the template
// header and the opening brace. From the body, the function uses `x`.
// Confirm the exact signature in the real header.
// Throws std::runtime_error if x is not square, or if potrf reports a leading
// minor that is not positive definite (info > 0). A negative info (bad
// argument) goes to check_info instead.
// NOTE(review): std::to_string needs <string>, which this header does not
// include directly; it appears to rely on a transitive include.
664  template <typename REAL>
666  {
667  const len_t n = x.nrows();
668  if (n != x.ncols())
669  throw std::runtime_error("'x' must be a square matrix");
670 
671  int info = 0;
672  fml::scalapack::potrf('L', n, x.data_ptr(), x.desc_ptr(), &info);
673 
674  if (info < 0)
675  fml::linalgutils::check_info(info, "potrf");
676  else if (info > 0)
677  throw std::runtime_error("chol: leading minor of order " + std::to_string(info) + " is not positive definite");
678 
// potrf leaves the upper triangle untouched; tri2zero('U', ...) presumably
// zeroes it so x holds only the lower factor (confirm in mpi_utils).
679  fml::mpi_utils::tri2zero('U', false, x.get_grid(), n, n, x.data_ptr(), x.desc_ptr());
680  }
681 }
682 }
683 
684 
685 #endif
fml::mpimat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: mpimat.hh:323
fml::linalg::solve
void solve(cpumat< REAL > &x, cpuvec< REAL > &y)
Solve a system of equations.
Definition: linalg_factorizations.hh:326
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::unimat::is_square
bool is_square() const
Is the matrix square?
Definition: unimat.hh:34
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::cpuvec::resize
void resize(len_t size)
Resize the internal object storage.
Definition: cpuvec.hh:185
fml::linalg::qr
void qr(const bool pivot, cpumat< REAL > &x, cpuvec< REAL > &qraux)
Computes the QR decomposition.
Definition: linalg_factorizations.hh:405
fml::linalg::lu
void lu(cpumat< REAL > &x, cpuvec< int > &p, int &info)
Computes the PLU factorization with partial pivoting.
Definition: linalg_factorizations.hh:53
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::linalg::svd
void svd(cpumat< REAL > &x, cpuvec< REAL > &s)
Computes the singular value decomposition.
Definition: linalg_factorizations.hh:146
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml::mpimat::fill_eye
void fill_eye()
Set diagonal entries to 1 and non-diagonal entries to 0.
Definition: mpimat.hh:654
fml::linalg::lq_L
void lq_L(const cpumat< REAL > &LQ, cpumat< REAL > &L)
Recover the L matrix from an LQ decomposition.
Definition: linalg_factorizations.hh:558
fml
Core namespace.
Definition: linalgutils.hh:15
fml::mpimat::fill_zero
void fill_zero()
Set all values to zero.
Definition: mpimat.hh:562
fml::linalg::qr_Q
void qr_Q(const cpumat< REAL > &QR, const cpuvec< REAL > &qraux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from a QR decomposition.
Definition: linalg_factorizations.hh:431
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26
fml::linalg::qr_R
void qr_R(const cpumat< REAL > &QR, cpumat< REAL > &R)
Recover the R matrix from a QR decomposition.
Definition: linalg_factorizations.hh:473
fml::linalg::chol
void chol(cpumat< REAL > &x)
Compute the Choleski factorization.
Definition: linalg_factorizations.hh:636
fml::linalg::lq
void lq(cpumat< REAL > &x, cpuvec< REAL > &lqaux)
Computes the LQ decomposition.
Definition: linalg_factorizations.hh:534
fml::linalg::eigen_sym
void eigen_sym(cpumat< REAL > &x, cpuvec< REAL > &values)
Compute the eigenvalues and optionally the eigenvectors for a symmetric matrix.
Definition: linalg_factorizations.hh:230
fml::linalg::invert
void invert(cpumat< REAL > &x)
Compute the matrix inverse.
Definition: linalg_factorizations.hh:265
fml::linalg::lq_Q
void lq_Q(const cpumat< REAL > &LQ, const cpuvec< REAL > &lqaux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from an LQ decomposition.
Definition: linalg_factorizations.hh:590