5 #ifndef FML_MPI_LINALG_LINALG_SVD_H
6 #define FML_MPI_LINALG_LINALG_SVD_H
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
19 #include "../mpimat.hh"
21 #include "linalg_blas.hh"
22 #include "linalg_eigen.hh"
23 #include "linalg_qr.hh"
// Shared driver behind the svd() overloads: runs fml::scalapack::gesvd on 'x',
// passing 's' as the singular-value output and 'u' / 'vt' as the output
// matrices for the singular vectors.
//   nu, nv : the overloads in this file pass (0, 0) when only the values are
//            wanted and (1, 1) when both factors are wanted.
//   x      : input matrix; its storage is handed directly to gesvd.
//            NOTE(review): ScaLAPACK documents that p?gesvd destroys the
//            contents of the input matrix -- callers should presumably treat
//            'x' as overwritten; confirm.
//   return : an int status code; callers forward it to
//            fml::linalgutils::check_info. The return statement and the
//            declarations of jobu, jobvt, tmp and info are not visible in this
//            view.
34 template <
typename REAL>
35 int svd_internals(
const int nu,
const int nv, mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
40 const len_t m = x.nrows();
41 const len_t n = x.ncols();
// min(m, n) is the inner dimension used when sizing u and vt below.
42 const len_t minmn = std::min(m, n);
// Values-only request. NOTE(review): jobu/jobvt are presumably selected in
// this branch and its counterpart, on lines not visible here.
46 if (nu == 0 && nv == 0)
// The outputs reuse x's blocking factors.
56 const int mb = x.bf_rows();
57 const int nb = x.bf_cols();
// u is resized to m x minmn and vt to minmn x n.
59 u.resize(m, minmn, mb, nb);
60 vt.resize(minmn, n, mb, nb);
// Workspace query: lwork = -1 follows the LAPACK/ScaLAPACK convention of
// reporting the required workspace size (here through 'tmp') without doing
// the factorization.
64 fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), &tmp, -1, &info);
// The size comes back as a REAL and is truncated to int before allocating.
65 int lwork = (int) tmp;
66 cpuvec<REAL> work(lwork);
// Actual factorization, using the workspace sized by the query above.
68 fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), work.data_ptr(), lwork, &info);
// Singular values only: calls svd_internals with nu = nv = 0 and writes the
// values into 's'.
//   x : matrix to factor; passed by non-const reference straight through to
//       svd_internals.
//   s : receives the singular values.
95 template <
typename REAL>
96 void svd(mpimat<REAL> &x, cpuvec<REAL> &s)
// svd_internals takes u and vt by reference, so one placeholder matrix built
// on x's grid is supplied for both.
98 mpimat<REAL> ignored(x.get_grid());
99 int info = svd_internals(0, 0, x, s, ignored, ignored);
// Hand the status code to check_info, tagged with the routine name.
// NOTE(review): presumably raises on a nonzero code -- check_info is not
// visible here.
100 fml::linalgutils::check_info(info,
"gesvd");
// Full SVD: calls svd_internals with nu = nv = 1 so that 'u' and 'vt' are
// produced alongside the singular values in 's'.
//   x     : matrix to factor; passed by non-const reference to svd_internals.
//   s     : receives the singular values.
//   u, vt : output matrices for the factors.
104 template <
typename REAL>
105 void svd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
// NOTE(review): presumably verifies that x, u and vt live on the same process
// grid -- check_grid is defined elsewhere.
107 err::check_grid(x, u, vt);
109 int info = svd_internals(1, 1, x, s, u, vt);
// Hand the status code to check_info, tagged with the routine name.
110 fml::linalgutils::check_info(info,
"gesvd");
// SVD with singular vectors, built on a QR factorization of 'x'.
// Visible steps: QR-factor x in place, set up n x n matrices R and u_R, expand
// Q into 'u' with qr_Q, then multiply u by u_R.
// NOTE(review): the extraction of R, the SVD of R (which presumably fills s,
// u_R and vt) and any final copy back into 'u' are on lines not visible in
// this view.
115 template <
typename REAL>
116 void tssvd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
118 const len_t m = x.nrows();
119 const len_t n = x.ncols();
// Shape guard; its condition is not visible here, but the message states that
// x must have more rows than columns.
121 throw std::runtime_error(
"'x' must have more rows than cols");
123 const grid g = x.get_grid();
// qraux has length n; 'work' starts at length m and is shared by qr_internals
// and qr_Q below.
125 cpuvec<REAL> qraux(n);
126 cpuvec<REAL> work(m);
127 qr_internals(
false, x, qraux, work);
// n x n matrix on x's grid and blocking; by its name, holds the R factor.
129 mpimat<REAL> R(g, n, n, x.bf_rows(), x.bf_cols());
// n x n matrix; by its name, the left singular vectors of R.
132 mpimat<REAL> u_R(g, n, n, x.bf_rows(), x.bf_cols());
// Form Q explicitly in 'u' from the factored x and qraux.
138 qr_Q(x, qraux, u, work);
// Product of u and u_R. NOTE(review): matmult's last argument appears to be
// the output (as in its other uses in this file), so the result lands in 'x'
// -- confirm that it is copied into 'u' afterwards.
140 matmult(
false,
false, (REAL)1.0, u, u_R, x);
// Singular values only, via QR: factor 'x' in place, clear one triangle of the
// leading n x n block, then run a values-only gesvd ('N', 'N') on that block.
//   x : matrix to factor; modified in place by qr_internals and tri2zero.
//   s : receives the singular values.
// The declarations of tmp and info are not visible in this view.
144 template <
typename REAL>
145 void tssvd(mpimat<REAL> &x, cpuvec<REAL> &s)
147 const len_t m = x.nrows();
148 const len_t n = x.ncols();
// Shape guard; its condition is not visible here, but the message states that
// x must have more rows than columns.
150 throw std::runtime_error(
"'x' must have more rows than cols");
152 const grid g = x.get_grid();
155 cpuvec<REAL> qraux(n);
156 cpuvec<REAL> work(m);
157 qr_internals(
false, x, qraux, work);
// NOTE(review): presumably zeroes the lower triangle of the leading n x n
// block (leaving the diagonal) so that only R remains -- confirm against
// tri2zero's definition.
159 fml::mpi_utils::tri2zero(
'L',
false, g, n, n, x.data_ptr(), x.desc_ptr());
// Workspace query (lwork = -1): the required size is reported through 'tmp'.
// No u / vt storage is supplied because both job flags are 'N'.
164 fml::scalapack::gesvd(
'N',
'N', n, n, x.data_ptr(), x.desc_ptr(),
165 s.data_ptr(), NULL, NULL, NULL, NULL, &tmp, -1, &info);
166 int lwork = (int) tmp;
// 'work' is reused from the QR step; the body of this check is not visible,
// presumably it grows 'work' to lwork.
167 if (lwork > work.size())
// Values-only SVD of the n x n block.
170 fml::scalapack::gesvd(
'N',
'N', n, n, x.data_ptr(), x.desc_ptr(),
171 s.data_ptr(), NULL, NULL, NULL, NULL, work.data_ptr(), lwork, &info);
172 fml::linalgutils::check_info(info,
"gesvd");
// SVD with singular vectors, built on an LQ factorization of 'x'.
// Visible steps: LQ-factor x in place, set up m x m matrices L and vt_L,
// expand Q into 'vt' with lq_Q, then multiply vt_L by vt.
// NOTE(review): the declarations of lqaux and work, the extraction of L, the
// SVD of L (which presumably fills s, u and vt_L) and any final copy back into
// 'vt' are on lines not visible in this view.
177 template <
typename REAL>
178 void sfsvd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
180 const len_t m = x.nrows();
181 const len_t n = x.ncols();
// Shape guard; its condition is not visible here, but the message states that
// x must have more columns than rows.
183 throw std::runtime_error(
"'x' must have more cols than rows");
185 const grid g = x.get_grid();
189 lq_internals(x, lqaux, work);
// m x m matrix on x's grid and blocking; by its name, holds the L factor.
191 mpimat<REAL> L(g, m, m, x.bf_rows(), x.bf_cols());
// m x m matrix; by its name, the right singular vectors (transposed) of L.
194 mpimat<REAL> vt_L(g, m, m, x.bf_rows(), x.bf_cols());
// Form Q explicitly in 'vt' from the factored x and lqaux.
200 lq_Q(x, lqaux, vt, work);
// Product of vt_L and vt. NOTE(review): matmult's last argument appears to be
// the output (as in its other uses in this file), so the result lands in 'x'
// -- confirm that it is copied into 'vt' afterwards.
202 matmult(
false,
false, (REAL)1.0, vt_L, vt, x);
// Singular values only, via LQ: factor 'x' in place, clear one triangle of the
// leading m x m block, then run a values-only gesvd ('N', 'N') on that block.
//   x : matrix to factor; modified in place by lq_internals and tri2zero.
//   s : receives the singular values.
// The declarations of lqaux, work, tmp and info are not visible in this view.
206 template <
typename REAL>
207 void sfsvd(mpimat<REAL> &x, cpuvec<REAL> &s)
209 const len_t m = x.nrows();
210 const len_t n = x.ncols();
// Shape guard; its condition is not visible here, but the message states that
// x must have more columns than rows.
212 throw std::runtime_error(
"'x' must have more cols than rows");
214 const grid g = x.get_grid();
219 lq_internals(x, lqaux, work);
// NOTE(review): presumably zeroes the upper triangle of the leading m x m
// block (leaving the diagonal) so that only L remains -- confirm against
// tri2zero's definition.
221 fml::mpi_utils::tri2zero(
'U',
false, g, m, m, x.data_ptr(), x.desc_ptr());
// Workspace query (lwork = -1): the required size is reported through 'tmp'.
// No u / vt storage is supplied because both job flags are 'N'.
226 fml::scalapack::gesvd(
'N',
'N', m, m, x.data_ptr(), x.desc_ptr(),
227 s.data_ptr(), NULL, NULL, NULL, NULL, &tmp, -1, &info);
228 int lwork = (int) tmp;
// 'work' is reused from the LQ step; the body of this check is not visible,
// presumably it grows 'work' to lwork.
229 if (lwork > work.size())
// Values-only SVD of the m x m block.
232 fml::scalapack::gesvd(
'N',
'N', m, m, x.data_ptr(), x.desc_ptr(),
233 s.data_ptr(), NULL, NULL, NULL, NULL, work.data_ptr(), lwork, &info);
234 fml::linalgutils::check_info(info,
"gesvd");
// NOTE(review): the signature and the rest of the body of this template are
// not visible in this view; only the opening grid check can be documented.
271 template <
typename REAL>
// NOTE(review): presumably verifies that x, u and vt live on the same process
// grid before any work is done -- check_grid is defined elsewhere.
274 err::check_grid(x, u, vt);
285 template <
typename REAL>
// NOTE(review): the signature is not visible in this view. From the visible
// body this routine works with x, s, u and vt plus a local matrix 'cp', and
// the raw pointers s_d and ev_d, all declared on lines not shown.
// Visible steps: take sqrt(|.|) of every entry of s, divide the columns of a
// locally stored matrix by those values, then finish with a matrix product.
323 template <
typename REAL>
// NOTE(review): presumably verifies that x, u and vt share a process grid.
326 err::check_grid(x, u, vt);
328 const len_t m = x.
nrows();
329 const len_t n = x.
ncols();
330 const len_t minmn = std::min(m, n);
332 const grid g = x.get_grid();
// Replace each entry of s by the square root of its absolute value.
// NOTE(review): presumably these are eigenvalues of a crossproduct computed
// above, and fabs protects against small negative round-off -- confirm.
353 for (len_t i=0; i<s.
size(); i++)
354 s_d[i] = sqrt(fabs(s_d[i]));
// Local (per-process) dimensions of the matrix being rescaled. The two pairs
// of assignments below belong to alternative branches whose conditions are not
// visible here: one reads them from vt, the other from cp.
356 len_t m_local, n_local;
360 m_local = vt.nrows_local();
361 n_local = vt.ncols_local();
366 m_local = cp.nrows_local();
367 n_local = cp.ncols_local();
// Walk the locally stored block; ev_d is addressed as i + m_local*j, i.e. with
// m_local as the leading dimension.
371 for (len_t j=0; j<n_local; j++)
374 for (len_t i=0; i<m_local; i++)
// Map the local (i, j) position to global indices using x's blocking factors
// and this process's grid coordinates.
376 const int gi = fml::bcutils::l2g(i, x.bf_rows(), g.
nprow(), g.
myrow());
377 const int gj = fml::bcutils::l2g(j, x.bf_cols(), g.
npcol(), g.
mycol());
// Only entries inside the leading minmn x minmn block are scaled; each is
// divided by the value belonging to its global column.
379 if (gi < minmn && gj < minmn)
380 ev_d[i + m_local*j] /= s_d[gj];
// NOTE(review): the two products below presumably sit in the same alternative
// branches as above. Reading matmult's last argument as the output and its
// leading flags as transpose switches, they form u = x * vt and
// vt = cp^T * x respectively -- confirm against matmult's definition.
386 matmult(
false,
false, (REAL)1.0, x, vt, u);
390 matmult(
true,
false, (REAL)1.0, cp, x, vt);
// NOTE(review): the signature is not visible in this view. The visible body
// reads x's dimensions and grid, then overwrites every entry of s (through the
// pointer s_d, declared on a line not shown) with sqrt(|value|).
394 template <
typename REAL>
397 const len_t m = x.
nrows();
398 const len_t n = x.
ncols();
400 const grid g = x.get_grid();
// Replace each entry of s by the square root of its absolute value.
// NOTE(review): presumably converts eigenvalues computed above into singular
// values, with fabs guarding against small negative round-off -- confirm.
413 for (len_t i=0; i<s.
size(); i++)
414 s_d[i] = sqrt(fabs(s_d[i]));
// Builds a Q factor from repeated products with 'x', starting from a random
// matrix.
//   seed : seed passed to omega.fill_runif.
//   k    : not used in the visible lines; callers in this file allocate the
//          matrix they pass as the fifth argument with 2*k columns.
//   q    : number of passes through the refinement loop below.
//   x    : input matrix.
// NOTE(review): the rest of the parameter list is not visible; callers pass an
// m x 2k matrix named QY as the fifth argument, which matches the QY written
// below. The declarations of omega, Y, Z, QZ, qraux and work are also on lines
// not shown. The comments below read matmult's last argument as the output and
// its first flag as "transpose the first operand" -- confirm.
421 template <
typename REAL>
422 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
mpimat<REAL> &x,
425 const len_t m = x.
nrows();
426 const len_t n = x.
ncols();
428 const int mb = x.bf_rows();
429 const int nb = x.bf_cols();
431 auto g = x.get_grid();
// Fill the test matrix with seeded random values.
434 omega.fill_runif(seed);
// Y = x * omega, then QR-factor Y and expand its Q factor into QY.
446 matmult(
false,
false, (REAL)1.0, x, omega, Y);
447 qr_internals(
false, Y, qraux, work);
448 qr_Q(Y, qraux, QY, work);
// q refinement passes; each pass re-orthogonalizes with a QR after every
// product.
450 for (
int i=0; i<q; i++)
// Z = x^T * QY, then QR-factor Z and expand its Q factor into QZ.
452 matmult(
true,
false, (REAL)1.0, x, QY, Z);
453 qr_internals(
false, Z, qraux, work);
454 qr_Q(Z, qraux, QZ, work);
// Y = x * QZ, then QR-factor Y and expand its Q factor back into QY.
456 matmult(
false,
false, (REAL)1.0, x, QZ, Y);
457 qr_internals(
false, Y, qraux, work);
458 qr_Q(Y, qraux, QY, work);
// NOTE(review): the signature is not visible in this view. The visible body
// uses seed, k, q and x: it obtains an m x 2k matrix QY from rsvd_A and forms
// the 2k x n matrix B from QY and x. What is then done with B is on lines not
// shown.
484 template <
typename REAL>
488 const len_t m = x.
nrows();
489 const len_t n = x.
ncols();
// Both temporaries live on x's grid and use x's blocking factors.
491 mpimat<REAL> QY(x.get_grid(), m, 2*k, x.bf_rows(), x.bf_cols());
492 mpimat<REAL> B(x.get_grid(), 2*k, n, x.bf_rows(), x.bf_cols());
// Fill QY.
495 rsvd_A(seed, k, q, x, QY);
// NOTE(review): reading the first flag as "transpose the first operand" and
// the last argument as the output, this is B = QY^T * x; the dimensions
// (2k x m times m x n) are consistent with that reading.
498 matmult(
true,
false, (REAL)1.0, QY, x, B);
// NOTE(review): the signature is not visible in this view, and the body may
// continue past the last visible line. The visible body uses seed, k, q, x, u
// and vt: it obtains an m x 2k matrix QY from rsvd_A, forms the 2k x n matrix
// B from QY and x, and finally multiplies QY by uB into 'u'. The step that
// produces uB (declared on a line not shown) is not visible.
506 template <
typename REAL>
// NOTE(review): presumably verifies that x, u and vt share a process grid.
510 err::check_grid(x, u, vt);
512 const len_t m = x.
nrows();
513 const len_t n = x.
ncols();
// Both temporaries live on x's grid and use x's blocking factors.
515 mpimat<REAL> QY(x.get_grid(), m, 2*k, x.bf_rows(), x.bf_cols());
516 mpimat<REAL> B(x.get_grid(), 2*k, n, x.bf_rows(), x.bf_cols());
// Fill QY.
519 rsvd_A(seed, k, q, x, QY);
// NOTE(review): reading the first flag as "transpose the first operand" and
// the last argument as the output, this is B = QY^T * x.
522 matmult(
true,
false, (REAL)1.0, QY, x, B);
// u = QY * uB (same reading of matmult's arguments).
529 matmult(
false,
false, (REAL)1.0, QY, uB, u);