5 #ifndef FML_CPU_LINALG_LINALG_SVD_H
6 #define FML_CPU_LINALG_LINALG_SVD_H
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
17 #include "../cpumat.hh"
18 #include "../cpuvec.hh"
20 #include "internals/blas.hh"
21 #include "internals/lapack.hh"
22 #include "crossprod.hh"
// Shared worker for the CPU SVD wrappers: factors x with LAPACK gesdd.
// Parameters:
//   nu, nv - singular-vector request flags/counts. The visible branches test
//            (nu == 0 && nv == 0) and (nu <= minmn && nv <= minmn).
//            NOTE(review): the branch bodies, which presumably select jobz and
//            ldvt, are not part of this excerpt; confirm against the full file.
//   x      - input matrix, passed to gesdd as m x n with leading dimension m.
//            gesdd overwrites its contents.
//   s      - receives the singular values.
//   u, vt  - receive the left singular vectors and the transposed right
//            singular vectors (leading dimensions m and ldvt respectively).
// Returns an int that callers treat as the LAPACK info code (they pass it to
// check_info labelled "gesdd"). The return statement itself is not visible.
35 template <
typename REAL>
36 int svd_internals(
const int nu,
const int nv, cpumat<REAL> &x, cpuvec<REAL> &s,
37 cpumat<REAL> &u, cpumat<REAL> &vt)
43 const len_t m = x.nrows();
44 const len_t n = x.ncols();
45 const len_t minmn = std::min(m, n);
49 if (nu == 0 && nv == 0)
54 else if (nu <= minmn && nv <= minmn)
// gesdd (divide and conquer) requires an integer workspace of 8*min(m,n).
68 cpuvec<int> iwork(8*minmn);
// Workspace query: lwork = -1 makes gesdd write the optimal workspace size
// into tmp instead of factoring.
71 fml::lapack::gesdd(jobz, m, n, x.data_ptr(), m, s.data_ptr(), u.data_ptr(), m, vt.data_ptr(), ldvt, &tmp, -1, iwork.data_ptr(), &info);
// The size is reported as a floating-point value and truncated to int here.
72 int lwork = (int) tmp;
73 cpuvec<REAL> work(lwork);
// Actual factorization using the queried workspace.
75 fml::lapack::gesdd(jobz, m, n, x.data_ptr(), m, s.data_ptr(), u.data_ptr(), m, vt.data_ptr(), ldvt, work.data_ptr(), lwork, iwork.data_ptr(), &info);
// Fragment of the values-only SVD wrapper. It runs svd_internals with
// nu = nv = 0 and hands the returned LAPACK info code to check_info,
// labelled "gesdd".
// NOTE(review): the function signature and the declaration of `ignored` are
// not part of this excerpt. The same `ignored` object is passed as both u and
// vt, which is presumably safe because no vectors are requested.
100 template <
typename REAL>
104 int info = svd_internals(0, 0, x, s, ignored, ignored);
105 fml::linalgutils::check_info(info,
"gesdd");
// Fragment of the SVD wrapper that also fills u and vt. It runs svd_internals
// with nu = nv = 1 and hands the returned LAPACK info code to check_info,
// labelled "gesdd".
// NOTE(review): the function signature and any sizing of u and vt are not
// part of this excerpt.
109 template <
typename REAL>
112 int info = svd_internals(1, 1, x, s, u, vt);
113 fml::linalgutils::check_info(info,
"gesdd");
// Fragment of a tall-skinny SVD that also produces singular vectors.
// Visible steps:
//   1. Read the dimensions of x and throw std::runtime_error for an
//      unsupported shape. The guarding condition is not in this excerpt.
//   2. QR-factor x (qr_internals).
//   3. Form the explicit Q factor into u from x and qraux (qr_Q).
//   4. Multiply u by u_R, writing the product into x.
// NOTE(review): the lines that build R and its SVD (u_R, s, vt) and the
// declarations of qraux and work are not in this excerpt; confirm against the
// full file.
120 template <
typename REAL>
123 const len_t m = x.
nrows();
124 const len_t n = x.
ncols();
126 throw std::runtime_error(
"'x' must have more rows than cols");
130 qr_internals(
false, x, qraux, work);
141 qr_Q(x, qraux, u, work);
// Presumably x = 1.0 * u * u_R, with neither operand transposed; the last
// argument is used as the output elsewhere in this file.
143 matmult(
false,
false, (REAL)1.0, u, u_R, x);
// Singular values of a tall-skinny matrix x (m rows, n cols), computed via QR.
// If x = QR with Q having orthonormal columns, x and the n x n factor R have
// the same singular values, so only R is handed to gesdd.
// x is modified: it is factored, partly zeroed, and then destroyed by gesdd.
// Throws std::runtime_error for an unsupported shape; the guarding condition
// is not visible in this excerpt. The gesdd info code is passed to
// check_info.
// NOTE(review): the declarations of tmp and info are not part of this excerpt.
147 template <
typename REAL>
148 void tssvd(cpumat<REAL> &x, cpuvec<REAL> &s)
150 const len_t m = x.nrows();
151 const len_t n = x.ncols();
153 throw std::runtime_error(
"'x' must have more rows than cols");
// Scratch buffers for qr_internals. `work` (length m) is reused below as the
// gesdd workspace when it is large enough.
157 cpuvec<REAL> qraux(n);
158 cpuvec<REAL> work(m);
159 qr_internals(
false, x, qraux, work);
// Clear the lower triangle of the leading n x n block (leading dimension m).
// The QR presumably left auxiliary data there, so after this step the block
// holds only R.
// NOTE(review): this assumes diag=false keeps the diagonal; confirm against
// cpu_utils::tri2zero.
161 fml::cpu_utils::tri2zero(
'L',
false, n, n, x.data_ptr(), m);
// gesdd needs an integer workspace of 8*min(n, n) = 8*n for the n x n R.
164 cpuvec<int> iwork(8*n);
// Workspace query (lwork = -1): the optimal size is written into tmp.
// jobz = 'N' computes values only, so the u and vt pointers are NULL.
167 fml::lapack::gesdd(
'N', n, n, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
168 1, &tmp, -1, iwork.data_ptr(), &info);
169 int lwork = (int) tmp;
// Reuse `work` unless gesdd asks for more. The growth step is not shown in
// this excerpt; it is presumably a resize to lwork.
170 if (lwork > work.size())
173 fml::lapack::gesdd(
'N', n, n, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
174 1, work.data_ptr(), lwork, iwork.data_ptr(), &info);
175 fml::linalgutils::check_info(info,
"gesdd");
// Short-fat SVD (x has m rows and n cols) that also produces singular
// vectors. It mirrors the QR-based tall-skinny routine, using LQ instead.
// Visible steps:
//   1. Read the dimensions of x and throw std::runtime_error for an
//      unsupported shape. The guarding condition is not in this excerpt.
//   2. LQ-factor x (lq_internals).
//   3. Allocate the m x m matrices L and vt_L.
//   4. Form the explicit Q factor into vt from x and lqaux (lq_Q).
//   5. Multiply vt_L by vt, writing the product into x.
// NOTE(review): the declarations of lqaux and work, and the step that fills L
// and takes its SVD (u, s, vt_L), are not in this excerpt; confirm.
180 template <
typename REAL>
181 void sfsvd(cpumat<REAL> &x, cpuvec<REAL> &s, cpumat<REAL> &u, cpumat<REAL> &vt)
183 const len_t m = x.nrows();
184 const len_t n = x.ncols();
186 throw std::runtime_error(
"'x' must have more cols than rows");
190 lq_internals(x, lqaux, work);
192 cpumat<REAL> L(m, m);
195 cpumat<REAL> vt_L(m, m);
201 lq_Q(x, lqaux, vt, work);
// Presumably x = 1.0 * vt_L * vt, with neither operand transposed.
203 matmult(
false,
false, (REAL)1.0, vt_L, vt, x);
// Singular values of a short-fat matrix x (m rows, n cols), computed via LQ.
// If x = LQ with Q having orthonormal rows, x and the m x m factor L have the
// same singular values, so only L is handed to gesdd.
// x is modified: it is factored, partly zeroed, and then destroyed by gesdd.
// Throws std::runtime_error for an unsupported shape; the guarding condition
// is not visible here. The gesdd info code is passed to check_info.
// NOTE(review): the declarations of lqaux, work, tmp and info are not part of
// this excerpt.
207 template <
typename REAL>
208 void sfsvd(cpumat<REAL> &x, cpuvec<REAL> &s)
210 const len_t m = x.nrows();
211 const len_t n = x.ncols();
213 throw std::runtime_error(
"'x' must have more cols than rows");
219 lq_internals(x, lqaux, work);
// Clear the upper triangle of the leading m x m block (leading dimension m)
// so that the block holds only L.
// NOTE(review): this assumes diag=false keeps the diagonal; confirm against
// cpu_utils::tri2zero.
221 fml::cpu_utils::tri2zero(
'U',
false, m, m, x.data_ptr(), m);
// NOTE(review): gesdd below factors an m x m matrix, which needs only 8*m
// integers. 8*n over-allocates when n > m; this is harmless, not a
// correctness bug.
224 cpuvec<int> iwork(8*n);
// Workspace query (lwork = -1). jobz = 'N' computes values only, so the u and
// vt pointers are NULL.
227 fml::lapack::gesdd(
'N', m, m, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
228 1, &tmp, -1, iwork.data_ptr(), &info);
229 int lwork = (int) tmp;
// Reuse `work` unless gesdd asks for more. The growth step is not shown in
// this excerpt.
230 if (lwork > work.size())
233 fml::lapack::gesdd(
'N', m, m, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
234 1, work.data_ptr(), lwork, iwork.data_ptr(), &info);
235 fml::linalgutils::check_info(info,
"gesdd");
// NOTE(review): these are two template headers whose declarations are not
// part of this excerpt. They are left unchanged.
272 template <
typename REAL>
284 template <
typename REAL>
// Fragment of an SVD computed from a cross-product (crossprod.hh is included
// by this file, and `cp` is used below). Visible steps:
//   1. Read m, n and minmn = min(m, n).
//   2. Replace each s_d[i] by sqrt(|s_d[i]|).
//   3. Divide column j of the minmn x minmn array ev_d by s_d[j].
//   4. Compute two matrix products, writing into u and vt.
// NOTE(review): the origin of s_d and ev_d is not visible in this excerpt.
// They are presumably the eigenvalues and eigenvectors of the cross-product
// cp; confirm.
322 template <
typename REAL>
326 const len_t m = x.
nrows();
327 const len_t n = x.
ncols();
328 const len_t minmn = std::min(m, n);
// If s_d holds cross-product eigenvalues (the squares of the singular values),
// the square root recovers the singular values. fabs guards against tiny
// negative values, e.g. from rounding.
350 for (len_t i=0; i<s.
size(); i++)
351 s_d[i] = sqrt(fabs(s_d[i]));
// Scale column j by 1 / s_d[j] (column-major index i + minmn*j). This runs in
// parallel only when minmn*minmn exceeds fml::omp::OMP_MIN_SIZE.
// NOTE(review): a zero entry in s_d makes this a division by zero.
359 #pragma omp parallel for if(minmn*minmn > fml::omp::OMP_MIN_SIZE)
360 for (len_t j=0; j<minmn; j++)
363 for (len_t i=0; i<minmn; i++)
364 ev_d[i + minmn*j] /= s_d[j];
// Presumably u = x * vt, with no transposes.
369 matmult(
false,
false, (REAL)1.0, x, vt, u);
// Presumably vt = t(cp) * x.
373 matmult(
true,
false, (REAL)1.0, cp, x, vt);
// Fragment of the values-only cross-product SVD. It reads the dimensions of x,
// then replaces each s_d[i] by sqrt(|s_d[i]|).
// NOTE(review): the computation that fills s_d is not part of this excerpt.
// It presumably produces cross-product eigenvalues.
377 template <
typename REAL>
380 const len_t m = x.
nrows();
381 const len_t n = x.
ncols();
// fabs guards against small negative values (e.g. from rounding) before the
// square root.
395 for (len_t i=0; i<s.
size(); i++)
396 s_d[i] = sqrt(fabs(s_d[i]));
// Range finder for the randomized SVD. It builds an orthonormal basis QY for
// an approximation of the range of x. Callers in this file pass QY as the
// fifth argument.
// Parameters visible here:
//   seed - seed passed to omega.fill_runif, which fills the random test
//          matrix.
//   k    - not referenced in the visible lines. NOTE(review): presumably the
//          target rank that sizes omega; confirm.
//   q    - number of power (subspace) iterations.
//   x    - the input matrix.
// Each product with x or t(x) is re-orthonormalized by a QR step. This
// appears to follow the randomized subspace iteration of Halko, Martinsson &
// Tropp (2011), Algorithm 4.4.
// NOTE(review): the rest of the parameter list and the allocations of omega,
// Y, Z, QY, QZ, qraux and work are not part of this excerpt.
403 template <
typename REAL>
404 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
cpumat<REAL> &x,
407 const len_t m = x.
nrows();
408 const len_t n = x.
ncols();
411 omega.fill_runif(seed);
// Y = x * omega, then QY is the orthonormal Q factor of Y.
423 matmult(
false,
false, (REAL)1.0, x, omega, Y);
424 qr_internals(
false, Y, qraux, work);
425 qr_Q(Y, qraux, QY, work);
// Power iterations: Z = t(x) * QY and orthonormalize into QZ, then
// Y = x * QZ and orthonormalize back into QY.
427 for (
int i=0; i<q; i++)
429 matmult(
true,
false, (REAL)1.0, x, QY, Z);
430 qr_internals(
false, Z, qraux, work);
431 qr_Q(Z, qraux, QZ, work);
433 matmult(
false,
false, (REAL)1.0, x, QZ, Y);
434 qr_internals(
false, Y, qraux, work);
435 qr_Q(Y, qraux, QY, work);
// Fragment of the values-only randomized SVD. It obtains the range basis QY
// from rsvd_A, then forms B = t(QY) * x.
// NOTE(review): the signature, the declarations of QY and B, and the SVD of B
// that presumably yields s are not part of this excerpt.
461 template <
typename REAL>
465 const len_t m = x.
nrows();
466 const len_t n = x.
ncols();
472 rsvd_A(seed, k, q, x, QY);
// Presumably B = t(QY) * x: x projected onto the basis QY.
475 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Fragment of the randomized SVD that also produces left singular vectors.
// It obtains the range basis QY from rsvd_A, forms B = t(QY) * x, and later
// lifts B's left singular vectors back with u = QY * uB.
// NOTE(review): the signature, the declarations of QY, B and uB, and the SVD
// of B (which presumably fills s, uB and vt) are not part of this excerpt.
484 template <
typename REAL>
488 const len_t m = x.
nrows();
489 const len_t n = x.
ncols();
495 rsvd_A(seed, k, q, x, QY);
// Presumably B = t(QY) * x.
498 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Presumably u = QY * uB: maps the small SVD's left vectors back to x's space.
505 matmult(
false,
false, (REAL)1.0, QY, uB, u);