5 #ifndef FML_MPI_LINALG_SVD_H
6 #define FML_MPI_LINALG_SVD_H
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
19 #include "../mpimat.hh"
21 #include "internals/err.hh"
22 #include "internals/scalapack.hh"
// Internal driver shared by the svd() overloads in this file. It reads x's
// dimensions, sizes u and vt, and calls fml::scalapack::gesvd twice: once with
// lwork = -1 so that the required workspace length comes back through `tmp`,
// and once more with a `work` buffer of that length.
// NOTE(review): the declarations of `info`, `jobu`, `jobvt` and `tmp`, the
// bodies attached to the nu/nv test, and the return statement sit on lines
// that are not visible in this view. The callers below treat the returned int
// as the gesvd `info` status.
36 template <
typename REAL>
37 int svd_internals(
const int nu,
const int nv, mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
// Dimensions of x; minmn bounds the sizes chosen for u and vt below.
42 const len_t m = x.nrows();
43 const len_t n = x.ncols();
44 const len_t minmn = std::min(m, n);
// The callers in this file pass (0, 0) when only `s` is wanted and (1, 1)
// when u and vt are wanted as well.
48 if (nu == 0 && nv == 0)
// u becomes m x minmn and vt becomes minmn x n, both reusing x's blocking
// factors. NOTE(review): presumably this happens only on the "vectors
// requested" path -- the branch structure is not visible here; confirm.
58 const int mb = x.bf_rows();
59 const int nb = x.bf_cols();
61 u.resize(m, minmn, mb, nb);
62 vt.resize(minmn, n, mb, nb);
// Size query: lwork = -1 with `&tmp` as the work array.
66 fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), &tmp, -1, &info);
// NOTE(review): C-style cast from REAL to int; a static_cast would make the
// narrowing explicit.
67 int lwork = (int) tmp;
68 cpuvec<REAL> work(lwork);
// Actual decomposition, using the workspace sized above. x's data pointer is
// passed in, so x's contents should be assumed overwritten (TODO confirm
// against the gesvd wrapper).
70 fml::scalapack::gesvd(jobu, jobvt, m, n, x.data_ptr(), x.desc_ptr(), s.data_ptr(), u.data_ptr(), u.desc_ptr(), vt.data_ptr(), vt.desc_ptr(), work.data_ptr(), lwork, &info);
97 template <
typename REAL>
98 void svd(mpimat<REAL> &x, cpuvec<REAL> &s)
100 mpimat<REAL> ignored(x.get_grid());
101 int info = svd_internals(0, 0, x, s, ignored, ignored);
102 fml::linalgutils::check_info(info,
"gesvd");
106 template <
typename REAL>
107 void svd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
109 err::check_grid(x, u, vt);
111 int info = svd_internals(1, 1, x, s, u, vt);
112 fml::linalgutils::check_info(info,
"gesvd");
// SVD for an x with more rows than columns (per the error message below),
// routed through a QR factorization of x. This overload also produces u and vt.
// NOTE(review): the condition guarding the throw and the steps that fill R,
// u_R, s and vt are on lines not visible in this view.
117 template <
typename REAL>
118 void tssvd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
120 const len_t m = x.nrows();
121 const len_t n = x.ncols();
123 throw std::runtime_error(
"'x' must have more rows than cols");
125 const grid g = x.get_grid();
// Scratch vectors for the QR helpers: qraux has length n, work starts at m.
127 cpuvec<REAL> qraux(n);
128 cpuvec<REAL> work(m);
// Factor x via qr_internals; x is passed by name and is later handed to
// qr_Q together with qraux.
129 qr_internals(
false, x, qraux, work);
// n x n matrices on x's grid with x's blocking factors. Judging by the names,
// R holds the triangular factor and u_R the left singular vectors of R --
// the code that fills them is not visible; confirm.
131 mpimat<REAL> R(g, n, n, x.bf_rows(), x.bf_cols());
134 mpimat<REAL> u_R(g, n, n, x.bf_rows(), x.bf_cols());
// Build Q from the factored x and qraux, writing it into u.
140 qr_Q(x, qraux, u, work);
// Presumably x = u * u_R (no transposes), with x reused as output storage --
// TODO confirm the matmult argument order and how the product reaches u.
142 matmult(
false,
false, (REAL)1.0, u, u_R, x);
// Singular-values-only SVD for an x with more rows than columns (per the
// error message below): factor x with qr_internals, clear one triangle of the
// leading n x n block, then run gesvd with jobs 'N','N' on that block.
// NOTE(review): the condition guarding the throw, the declarations of `info`
// and `tmp`, any resizing of `s`, and the statement controlled by the lwork
// check are on lines not visible in this view.
146 template <
typename REAL>
147 void tssvd(mpimat<REAL> &x, cpuvec<REAL> &s)
149 const len_t m = x.nrows();
150 const len_t n = x.ncols();
152 throw std::runtime_error(
"'x' must have more rows than cols");
154 const grid g = x.get_grid();
// Scratch vectors for the QR helper: qraux has length n, work starts at m.
157 cpuvec<REAL> qraux(n);
158 cpuvec<REAL> work(m);
159 qr_internals(
false, x, qraux, work);
// Operates on the n x n leading block of x in place with the 'L' selector --
// presumably zeroing the lower triangle so only R remains; confirm against
// fml::mpi_utils::tri2zero.
161 fml::mpi_utils::tri2zero(
'L',
false, g, n, n, x.data_ptr(), x.desc_ptr());
// Size query: lwork = -1 with `&tmp` as the work array. The u / vt data and
// descriptor arguments are NULL because no vectors are requested.
166 fml::scalapack::gesvd(
'N',
'N', n, n, x.data_ptr(), x.desc_ptr(),
167 s.data_ptr(), NULL, NULL, NULL, NULL, &tmp, -1, &info);
168 int lwork = (int) tmp;
// Compares the required size with the existing `work` length; presumably the
// buffer is grown on the hidden line that follows.
// NOTE(review): lwork is int -- if work.size() is unsigned this is a
// mixed-signedness comparison; confirm.
169 if (lwork > work.size())
// Actual decomposition of the n x n block; singular values land in s.
172 fml::scalapack::gesvd(
'N',
'N', n, n, x.data_ptr(), x.desc_ptr(),
173 s.data_ptr(), NULL, NULL, NULL, NULL, work.data_ptr(), lwork, &info);
174 fml::linalgutils::check_info(info,
"gesvd");
// SVD for an x with more columns than rows (per the error message below),
// routed through an LQ factorization of x. This overload also produces u and vt.
// NOTE(review): the condition guarding the throw, the declarations of `lqaux`
// and `work`, and the steps that fill L, vt_L, s and u are on lines not
// visible in this view.
179 template <
typename REAL>
180 void sfsvd(mpimat<REAL> &x, cpuvec<REAL> &s, mpimat<REAL> &u, mpimat<REAL> &vt)
182 const len_t m = x.nrows();
183 const len_t n = x.ncols();
185 throw std::runtime_error(
"'x' must have more cols than rows");
187 const grid g = x.get_grid();
// Factor x via lq_internals; x is later handed to lq_Q together with lqaux.
191 lq_internals(x, lqaux, work);
// m x m matrices on x's grid with x's blocking factors. Judging by the names,
// L holds the triangular factor and vt_L the right singular vectors of L --
// the code that fills them is not visible; confirm.
193 mpimat<REAL> L(g, m, m, x.bf_rows(), x.bf_cols());
196 mpimat<REAL> vt_L(g, m, m, x.bf_rows(), x.bf_cols());
// Build Q from the factored x and lqaux, writing it into vt.
202 lq_Q(x, lqaux, vt, work);
// Presumably x = vt_L * vt (no transposes), with x reused as output storage --
// TODO confirm the matmult argument order and how the product reaches vt.
204 matmult(
false,
false, (REAL)1.0, vt_L, vt, x);
// Singular-values-only SVD for an x with more columns than rows (per the
// error message below): factor x with lq_internals, clear one triangle of the
// leading m x m block, then run gesvd with jobs 'N','N' on that block.
// NOTE(review): the condition guarding the throw, the declarations of
// `lqaux`, `work`, `info` and `tmp`, any resizing of `s`, and the statement
// controlled by the lwork check are on lines not visible in this view.
208 template <
typename REAL>
209 void sfsvd(mpimat<REAL> &x, cpuvec<REAL> &s)
211 const len_t m = x.nrows();
212 const len_t n = x.ncols();
214 throw std::runtime_error(
"'x' must have more cols than rows");
216 const grid g = x.get_grid();
221 lq_internals(x, lqaux, work);
// Operates on the m x m leading block of x in place with the 'U' selector --
// presumably zeroing the upper triangle so only L remains; confirm against
// fml::mpi_utils::tri2zero.
223 fml::mpi_utils::tri2zero(
'U',
false, g, m, m, x.data_ptr(), x.desc_ptr());
// Size query: lwork = -1 with `&tmp` as the work array. The u / vt data and
// descriptor arguments are NULL because no vectors are requested.
228 fml::scalapack::gesvd(
'N',
'N', m, m, x.data_ptr(), x.desc_ptr(),
229 s.data_ptr(), NULL, NULL, NULL, NULL, &tmp, -1, &info);
230 int lwork = (int) tmp;
// Compares the required size with the existing `work` length; presumably the
// buffer is grown on the hidden line that follows.
// NOTE(review): lwork is int -- if work.size() is unsigned this is a
// mixed-signedness comparison; confirm.
231 if (lwork > work.size())
// Actual decomposition of the m x m block; singular values land in s.
234 fml::scalapack::gesvd(
'N',
'N', m, m, x.data_ptr(), x.desc_ptr(),
235 s.data_ptr(), NULL, NULL, NULL, NULL, work.data_ptr(), lwork, &info);
236 fml::linalgutils::check_info(info,
"gesvd");
// NOTE(review): the signature and most of the body of this template are on
// lines not visible in this view. The one visible statement runs
// err::check_grid on x, u and vt, so the function takes at least those three
// matrices.
273 template <
typename REAL>
276 err::check_grid(x, u, vt);
// NOTE(review): only the template header of this definition is visible in
// this view; its signature and body cannot be described from here.
287 template <
typename REAL>
// NOTE(review): the signature of this template is on a line not visible in
// this view; the body refers to x, s, u and vt. The declarations of `s_d`,
// `ev_d` and `cp`, and the conditions selecting between the paired statements
// below, are likewise not visible.
325 template <
typename REAL>
328 err::check_grid(x, u, vt);
330 const len_t m = x.
nrows();
331 const len_t n = x.
ncols();
332 const len_t minmn = std::min(m, n);
334 const grid g = x.get_grid();
// Replace every entry of s_d with the square root of its absolute value.
// Presumably s_d points at s's storage and holds eigenvalues of a
// crossproduct at this point -- confirm against the hidden lines above.
355 for (len_t i=0; i<s.
size(); i++)
356 s_d[i] = sqrt(fabs(s_d[i]));
// Local (per-process) dimensions, taken from either vt or cp; which pair
// applies is decided on lines not shown.
358 len_t m_local, n_local;
362 m_local = vt.nrows_local();
363 n_local = vt.ncols_local();
368 m_local = cp.nrows_local();
369 n_local = cp.ncols_local();
// Walk the local block with i varying fastest (index i + m_local*j).
373 for (len_t j=0; j<n_local; j++)
376 for (len_t i=0; i<m_local; i++)
// Map local indices to global ones using x's blocking factors and this
// process's grid shape and coordinates.
378 const int gi = fml::bcutils::l2g(i, x.bf_rows(), g.
nprow(), g.
myrow());
379 const int gj = fml::bcutils::l2g(j, x.bf_cols(), g.
npcol(), g.
mycol());
// Entries whose global row and column are both below minmn are divided by
// the value in s_d for their global column.
// NOTE(review): there is no guard against s_d[gj] being zero.
381 if (gi < minmn && gj < minmn)
382 ev_d[i + m_local*j] /= s_d[gj];
// Two alternative products: presumably u = x * vt on one path and
// vt = cp^T * x on the other -- TODO confirm the matmult argument order and
// the branch conditions.
388 matmult(
false,
false, (REAL)1.0, x, vt, u);
392 matmult(
true,
false, (REAL)1.0, cp, x, vt);
// NOTE(review): the signature of this template and most of its body are on
// lines not visible in this view; the visible part refers to x and s.
396 template <
typename REAL>
399 const len_t m = x.
nrows();
400 const len_t n = x.
ncols();
402 const grid g = x.get_grid();
// Replace every entry of s_d with the square root of its absolute value.
// Presumably s_d points at s's storage -- its declaration is not visible.
415 for (len_t i=0; i<s.
size(); i++)
416 s_d[i] = sqrt(fabs(s_d[i]));
// Helper for the randomized SVD routines below. It fills `omega` with uniform
// random values from `seed`, forms a product with x, and alternates
// matmult / qr_internals / qr_Q steps, repeating the inner pair q times.
// The structure matches a randomized range finder with q power iterations.
// NOTE(review): the rest of the parameter list (the callers pass a fifth
// argument, QY) and the declarations of omega, Y, Z, QZ, qraux and work are
// on lines not visible in this view.
423 template <
typename REAL>
424 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
mpimat<REAL> &x,
427 const len_t m = x.
nrows();
428 const len_t n = x.
ncols();
430 const int mb = x.bf_rows();
431 const int nb = x.bf_cols();
433 auto g = x.get_grid();
// Random test matrix seeded by `seed`.
436 omega.fill_runif(seed);
// Presumably Y = x * omega; then factor Y and extract its Q into QY.
448 matmult(
false,
false, (REAL)1.0, x, omega, Y);
449 qr_internals(
false, Y, qraux, work);
450 qr_Q(Y, qraux, QY, work);
// q refinement rounds.
452 for (
int i=0; i<q; i++)
// Presumably Z = x^T * QY (first operand flagged `true`); factor it and
// extract Q into QZ.
454 matmult(
true,
false, (REAL)1.0, x, QY, Z);
455 qr_internals(
false, Z, qraux, work);
456 qr_Q(Z, qraux, QZ, work);
// Presumably Y = x * QZ; factor it and refresh QY.
458 matmult(
false,
false, (REAL)1.0, x, QZ, Y);
459 qr_internals(
false, Y, qraux, work);
460 qr_Q(Y, qraux, QY, work);
// NOTE(review): the signature of this template and the steps after the final
// product are on lines not visible in this view; the body refers to seed, k,
// q and x.
486 template <
typename REAL>
490 const len_t m = x.
nrows();
491 const len_t n = x.
ncols();
// Work matrices on x's grid with x's blocking factors: QY is m x 2k and
// B is 2k x n.
493 mpimat<REAL> QY(x.get_grid(), m, 2*k, x.bf_rows(), x.bf_cols());
494 mpimat<REAL> B(x.get_grid(), 2*k, n, x.bf_rows(), x.bf_cols());
// rsvd_A fills QY from x (see rsvd_A for the iteration).
497 rsvd_A(seed, k, q, x, QY);
// Presumably B = QY^T * x (first operand flagged `true`) -- TODO confirm the
// matmult argument order.
500 matmult(
true,
false, (REAL)1.0, QY, x, B);
// NOTE(review): the signature of this template, the declaration of `uB`, and
// the steps between the two products are on lines not visible in this view;
// the body refers to seed, k, q, x, u and vt.
508 template <
typename REAL>
512 err::check_grid(x, u, vt);
514 const len_t m = x.
nrows();
515 const len_t n = x.
ncols();
// Work matrices on x's grid with x's blocking factors: QY is m x 2k and
// B is 2k x n.
517 mpimat<REAL> QY(x.get_grid(), m, 2*k, x.bf_rows(), x.bf_cols());
518 mpimat<REAL> B(x.get_grid(), 2*k, n, x.bf_rows(), x.bf_cols());
// rsvd_A fills QY from x (see rsvd_A for the iteration).
521 rsvd_A(seed, k, q, x, QY);
// Presumably B = QY^T * x (first operand flagged `true`).
524 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Presumably u = QY * uB -- TODO confirm the matmult argument order and
// where uB is produced.
531 matmult(
false,
false, (REAL)1.0, QY, uB, u);