fml  0.1-0
Fused Matrix Library
svd.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_SVD_H
6 #define FML_MPI_LINALG_SVD_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
14 
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
17 
18 #include "../copy.hh"
19 #include "../mpimat.hh"
20 
21 #include "internals/err.hh"
22 #include "internals/scalapack.hh"
23 #include "eigen.hh"
24 #include "qr.hh"
25 #include "xpose.hh"
26 
27 
28 namespace fml
29 {
30 namespace linalg
31 {
32  namespace
33  {
34  namespace
35  {
36  template <typename REAL>
37  int svd_internals(const int nu, const int nv, mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
38  {
39  int info = 0;
40  char jobu, jobvt;
41 
42  const len_t m = x.nrows();
43  const len_t n = x.ncols();
44  const len_t minmn = std::min(m, n);
45 
46  s.resize(minmn);
47 
48  if (nu == 0 && nv == 0)
49  {
50  jobu = 'N';
51  jobvt = 'N';
52  }
53  else // if (nu <= minmn && nv <= minmn)
54  {
55  jobu = 'V';
56  jobvt = 'V';
57 
58  const int mb = x.bf_rows();
59  const int nb = x.bf_cols();
60 
61  u.resize(m, minmn, mb, nb);
62  vt.resize(minmn, n, mb, nb);
63  }
64 
65  REAL tmp;
66  fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), &tmp, -1, &info);
67  int lwork = (int) tmp;
68  cpuvec<REAL> work(lwork);
69 
70  fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), work.data_ptr(), lwork, &info);
71 
72  return info;
73  }
74  }
75 
97  template <typename REAL>
98  void svd(mpimat<REAL> &x, cpuvec<REAL> &s)
99  {
100  mpimat<REAL> ignored(x.get_grid());
101  int info = svd_internals(0, 0, x, s, ignored, ignored);
102  fml::linalgutils::check_info(info, "gesvd");
103  }
104 
106  template <typename REAL>
107  void svd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
108  {
109  err::check_grid(x, u, vt);
110 
111  int info = svd_internals(1, 1, x, s, u, vt);
112  fml::linalgutils::check_info(info, "gesvd");
113  }
114 
115 
116 
117  template <typename REAL>
118  void tssvd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
119  {
120  const len_t m = x.nrows();
121  const len_t n = x.ncols();
122  if (m <= n)
123  throw std::runtime_error("'x' must have more rows than cols");
124 
125  const grid g = x.get_grid();
126 
127  cpuvec<REAL> qraux(n);
128  cpuvec<REAL> work(m);
129  qr_internals(false, x, qraux, work);
130 
131  mpimat<REAL> R(g, n, n, x.bf_rows(), x.bf_cols());
132  qr_R(x, R);
133 
134  mpimat<REAL> u_R(g, n, n, x.bf_rows(), x.bf_cols());
135  svd(R, s, u_R, vt);
136 
137  u.resize(m, n);
138  u.fill_eye();
139 
140  qr_Q(x, qraux, u, work);
141 
142  matmult(false, false, (REAL)1.0, u, u_R, x);
143  copy::mpi2mpi(x, u);
144  }
145 
146  template <typename REAL>
147  void tssvd(mpimat<REAL> &x, cpuvec<REAL> &s)
148  {
149  const len_t m = x.nrows();
150  const len_t n = x.ncols();
151  if (m <= n)
152  throw std::runtime_error("'x' must have more rows than cols");
153 
154  const grid g = x.get_grid();
155  s.resize(n);
156 
157  cpuvec<REAL> qraux(n);
158  cpuvec<REAL> work(m);
159  qr_internals(false, x, qraux, work);
160 
161  fml::mpi_utils::tri2zero('L', false, g, n, n, x.data_ptr(), x.desc_ptr());
162 
163  int info = 0;
164 
165  REAL tmp;
166  fml::scalapack::gesvd('N', 'N', n, n, x.data_ptr(), x.desc_ptr(),
167  s.data_ptr(), NULL, NULL, NULL, NULL, &tmp, -1, &info);
168  int lwork = (int) tmp;
169  if (lwork > work.size())
170  work.resize(lwork);
171 
172  fml::scalapack::gesvd('N', 'N', n, n, x.data_ptr(), x.desc_ptr(),
173  s.data_ptr(), NULL, NULL, NULL, NULL, work.data_ptr(), lwork, &info);
174  fml::linalgutils::check_info(info, "gesvd");
175  }
176 
177 
178 
179  template <typename REAL>
180  void sfsvd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
181  {
182  const len_t m = x.nrows();
183  const len_t n = x.ncols();
184  if (m >= n)
185  throw std::runtime_error("'x' must have more cols than rows");
186 
187  const grid g = x.get_grid();
188 
189  cpuvec<REAL> lqaux;
190  cpuvec<REAL> work;
191  lq_internals(x, lqaux, work);
192 
193  mpimat<REAL> L(g, m, m, x.bf_rows(), x.bf_cols());
194  lq_L(x, L);
195 
196  mpimat<REAL> vt_L(g, m, m, x.bf_rows(), x.bf_cols());
197  svd(L, s, u, vt_L);
198 
199  vt.resize(n, m);
200  vt.fill_eye();
201 
202  lq_Q(x, lqaux, vt, work);
203 
204  matmult(false, false, (REAL)1.0, vt_L, vt, x);
205  copy::mpi2mpi(x, vt);
206  }
207 
208  template <typename REAL>
209  void sfsvd(mpimat<REAL> &x, cpuvec<REAL> &s)
210  {
211  const len_t m = x.nrows();
212  const len_t n = x.ncols();
213  if (m >= n)
214  throw std::runtime_error("'x' must have more cols than rows");
215 
216  const grid g = x.get_grid();
217  s.resize(m);
218 
219  cpuvec<REAL> lqaux;
220  cpuvec<REAL> work;
221  lq_internals(x, lqaux, work);
222 
223  fml::mpi_utils::tri2zero('U', false, g, m, m, x.data_ptr(), x.desc_ptr());
224 
225  int info = 0;
226 
227  REAL tmp;
228  fml::scalapack::gesvd('N', 'N', m, m, x.data_ptr(), x.desc_ptr(),
229  s.data_ptr(), NULL, NULL, NULL, NULL, &tmp, -1, &info);
230  int lwork = (int) tmp;
231  if (lwork > work.size())
232  work.resize(lwork);
233 
234  fml::scalapack::gesvd('N', 'N', m, m, x.data_ptr(), x.desc_ptr(),
235  s.data_ptr(), NULL, NULL, NULL, NULL, work.data_ptr(), lwork, &info);
236  fml::linalgutils::check_info(info, "gesvd");
237  }
238  }
239 
273  template <typename REAL>
275  {
276  err::check_grid(x, u, vt);
277 
278  if (x.is_square())
279  svd(x, s, u, vt);
280  else if (x.nrows() > x.ncols())
281  tssvd(x, s, u, vt);
282  else
283  sfsvd(x, s, u, vt);
284  }
285 
287  template <typename REAL>
289  {
290  if (x.is_square())
291  svd(x, s);
292  else if (x.nrows() > x.ncols())
293  tssvd(x, s);
294  else
295  sfsvd(x, s);
296  }
297 
298 
299 
325  template <typename REAL>
327  {
328  err::check_grid(x, u, vt);
329 
330  const len_t m = x.nrows();
331  const len_t n = x.ncols();
332  const len_t minmn = std::min(m, n);
333 
334  const grid g = x.get_grid();
335  mpimat<REAL> cp(g, x.bf_rows(), x.bf_cols());
336 
337  if (m >= n)
338  {
339  crossprod((REAL)1.0, x, cp);
340  eigen_sym(cp, s, vt);
341  vt.rev_cols();
342  copy::mpi2mpi(vt, cp);
343  }
344  else
345  {
346  tcrossprod((REAL)1.0, x, cp);
347  eigen_sym(cp, s, u);
348  u.rev_cols();
349  copy::mpi2mpi(u, cp);
350  }
351 
352  s.rev();
353  REAL *s_d = s.data_ptr();
354  #pragma omp for simd
355  for (len_t i=0; i<s.size(); i++)
356  s_d[i] = sqrt(fabs(s_d[i]));
357 
358  len_t m_local, n_local;
359  REAL *ev_d;
360  if (m >= n)
361  {
362  m_local = vt.nrows_local();
363  n_local = vt.ncols_local();
364  ev_d = vt.data_ptr();
365  }
366  else
367  {
368  m_local = cp.nrows_local();
369  n_local = cp.ncols_local();
370  ev_d = cp.data_ptr();
371  }
372 
373  for (len_t j=0; j<n_local; j++)
374  {
375  #pragma omp for simd
376  for (len_t i=0; i<m_local; i++)
377  {
378  const int gi = fml::bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow());
379  const int gj = fml::bcutils::l2g(j, x.bf_cols(), g.npcol(), g.mycol());
380 
381  if (gi < minmn && gj < minmn)
382  ev_d[i + m_local*j] /= s_d[gj];
383  }
384  }
385 
386  if (m >= n)
387  {
388  matmult(false, false, (REAL)1.0, x, vt, u);
389  xpose(cp, vt);
390  }
391  else
392  matmult(true, false, (REAL)1.0, cp, x, vt);
393  }
394 
396  template <typename REAL>
397  void cpsvd(const mpimat<REAL> &x, cpuvec<REAL> &s)
398  {
399  const len_t m = x.nrows();
400  const len_t n = x.ncols();
401 
402  const grid g = x.get_grid();
403  mpimat<REAL> cp(g, x.bf_rows(), x.bf_cols());
404 
405  if (m >= n)
406  crossprod((REAL)1.0, x, cp);
407  else
408  tcrossprod((REAL)1.0, x, cp);
409 
410  eigen_sym(cp, s);
411 
412  s.rev();
413  REAL *s_d = s.data_ptr();
414  #pragma omp for simd
415  for (len_t i=0; i<s.size(); i++)
416  s_d[i] = sqrt(fabs(s_d[i]));
417  }
418 
419 
420 
421  namespace
422  {
423  template <typename REAL>
424  void rsvd_A(const uint32_t seed, const int k, const int q, mpimat<REAL> &x,
425  mpimat<REAL> &QY)
426  {
427  const len_t m = x.nrows();
428  const len_t n = x.ncols();
429 
430  const int mb = x.bf_rows();
431  const int nb = x.bf_cols();
432 
433  auto g = x.get_grid();
434 
435  mpimat<REAL> omega(g, n, 2*k, mb, nb);
436  omega.fill_runif(seed);
437 
438  mpimat<REAL> Y(g, m, 2*k, mb, nb);
439  mpimat<REAL> Z(g, n, 2*k, mb, nb);
440  mpimat<REAL> QZ(g, n, 2*k, mb, nb);
441 
442  cpuvec<REAL> qraux;
443  cpuvec<REAL> work;
444 
445  mpimat<REAL> B(g, 2*k, n, mb, nb);
446 
447  // Stage A
448  matmult(false, false, (REAL)1.0, x, omega, Y);
449  qr_internals(false, Y, qraux, work);
450  qr_Q(Y, qraux, QY, work);
451 
452  for (int i=0; i<q; i++)
453  {
454  matmult(true, false, (REAL)1.0, x, QY, Z);
455  qr_internals(false, Z, qraux, work);
456  qr_Q(Z, qraux, QZ, work);
457 
458  matmult(false, false, (REAL)1.0, x, QZ, Y);
459  qr_internals(false, Y, qraux, work);
460  qr_Q(Y, qraux, QY, work);
461  }
462  }
463  }
464 
486  template <typename REAL>
487  void rsvd(const uint32_t seed, const int k, const int q, mpimat<REAL> &x,
488  cpuvec<REAL> &s)
489  {
490  const len_t m = x.nrows();
491  const len_t n = x.ncols();
492 
493  mpimat<REAL> QY(x.get_grid(), m, 2*k, x.bf_rows(), x.bf_cols());
494  mpimat<REAL> B(x.get_grid(), 2*k, n, x.bf_rows(), x.bf_cols());
495 
496  // Stage A
497  rsvd_A(seed, k, q, x, QY);
498 
499  // Stage B
500  matmult(true, false, (REAL)1.0, QY, x, B);
501 
502  svd(B, s);
503 
504  s.resize(k);
505  }
506 
508  template <typename REAL>
509  void rsvd(const uint32_t seed, const int k, const int q, mpimat<REAL> &x,
511  {
512  err::check_grid(x, u, vt);
513 
514  const len_t m = x.nrows();
515  const len_t n = x.ncols();
516 
517  mpimat<REAL> QY(x.get_grid(), m, 2*k, x.bf_rows(), x.bf_cols());
518  mpimat<REAL> B(x.get_grid(), 2*k, n, x.bf_rows(), x.bf_cols());
519 
520  // Stage A
521  rsvd_A(seed, k, q, x, QY);
522 
523  // Stage B
524  matmult(true, false, (REAL)1.0, QY, x, B);
525 
526  mpimat<REAL> uB(x.get_grid());
527  svd(B, s, uB, vt);
528 
529  s.resize(k);
530 
531  matmult(false, false, (REAL)1.0, QY, uB, u);
532  u.resize(u.nrows(), k);
533  }
534 }
535 }
536 
537 
538 #endif
fml::mpimat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: mpimat.hh:326
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: crossprod.hh:37
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::grid::mycol
int mycol() const
The process column (0-based index) of the calling process.
Definition: grid.hh:129
fml::unimat::is_square
bool is_square() const
Is the matrix square?
Definition: unimat.hh:34
fml::linalg::xpose
void xpose(const cpumat< REAL > &x, cpumat< REAL > &tx)
Computes the transpose out-of-place (i.e. in a copy).
Definition: xpose.hh:37
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::copy::mpi2mpi
void mpi2mpi(const mpimat< REAL_IN > &mpi_in, mpimat< REAL_OUT > &mpi_out)
Copy data from an MPI object to another.
Definition: copy.hh:288
fml::cpuvec::resize
void resize(len_t size)
Resize the internal object storage.
Definition: cpuvec.hh:210
fml::linalg::rsvd
void rsvd(const uint32_t seed, const int k, const int q, cpumat< REAL > &x, cpuvec< REAL > &s)
Computes the truncated singular value decomposition using the normal projections method of Halko et al.
Definition: svd.hh:462
fml::cpuvec::rev
void rev()
Reverse the vector.
Definition: cpuvec.hh:447
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::mpimat::rev_cols
void rev_cols()
Reverse the columns of the matrix.
Definition: mpimat.hh:965
fml::linalg::qrsvd
void qrsvd(cpumat< REAL > &x, cpuvec< REAL > &s, cpumat< REAL > &u, cpumat< REAL > &vt)
Computes the singular value decomposition by first reducing the rectangular matrix to a square matrix using an orthogonal factorization.
Definition: svd.hh:273
fml::linalg::tssvd
void tssvd(parmat_cpu< REAL > &x, cpuvec< REAL > &s)
Computes the singular value decomposition by first reducing the rectangular matrix to a square matrix using an orthogonal factorization.
Definition: svd.hh:130
fml::linalg::svd
void svd(cpumat< REAL > &x, cpuvec< REAL > &s)
Computes the singular value decomposition.
Definition: svd.hh:101
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml::linalg::lq_L
void lq_L(const cpumat< REAL > &LQ, cpumat< REAL > &L)
Recover the L matrix from an LQ decomposition.
Definition: qr.hh:247
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::qr_Q
void qr_Q(const cpumat< REAL > &QR, const cpuvec< REAL > &qraux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from a QR decomposition.
Definition: qr.hh:120
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26
fml::linalg::qr_R
void qr_R(const cpumat< REAL > &QR, cpumat< REAL > &R)
Recover the R matrix from a QR decomposition.
Definition: qr.hh:162
fml::linalg::eigen_sym
void eigen_sym(cpumat< REAL > &x, cpuvec< REAL > &values)
Compute the eigenvalues and optionally the eigenvectors for a symmetric matrix.
Definition: eigen.hh:95
fml::grid::npcol
int npcol() const
The number of processes columns in the BLACS context.
Definition: grid.hh:125
fml::grid::myrow
int myrow() const
The process row (0-based index) of the calling process.
Definition: grid.hh:127
fml::linalg::tcrossprod
void tcrossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x*x^T.
Definition: crossprod.hh:81
fml::linalg::lq_Q
void lq_Q(const cpumat< REAL > &LQ, const cpuvec< REAL > &lqaux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from an LQ decomposition.
Definition: qr.hh:279
fml::grid::nprow
int nprow() const
The number of processes rows in the BLACS context.
Definition: grid.hh:123
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43
fml::linalg::cpsvd
void cpsvd(const cpumat< REAL > &x, cpuvec< REAL > &s, cpumat< REAL > &u, cpumat< REAL > &vt)
Computes the singular value decomposition using the "crossproducts SVD". This method is not numerically stable.
Definition: svd.hh:323