5 #ifndef FML_CPU_LINALG_LINALG_SVD_H
6 #define FML_CPU_LINALG_LINALG_SVD_H
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
17 #include "../cpumat.hh"
18 #include "../cpuvec.hh"
20 #include "linalg_blas.hh"
21 #include "linalg_eigen.hh"
22 #include "linalg_qr.hh"
// Shared gesdd driver: takes the nu/nv flags, the input matrix x and the
// outputs s, u, vt, and returns an int status. The callers visible in this
// file pass that status to check_info under the routine name "gesdd".
// NOTE(review): this excerpt omits the braces, the branch bodies, and the
// declarations of jobz, ldvt, tmp and info - confirm against the full file.
31 template <
typename REAL>
32 int svd_internals(
const int nu,
const int nv, cpumat<REAL> &x, cpuvec<REAL> &s,
33 cpumat<REAL> &u, cpumat<REAL> &vt)
39 const len_t m = x.nrows();
40 const len_t n = x.ncols();
41 const len_t minmn = std::min(m, n);
// Two visible cases: nu and nv both zero, or both at most min(m, n).
// Presumably each branch selects jobz and sizes the outputs (bodies elided).
45 if (nu == 0 && nv == 0)
50 else if (nu <= minmn && nv <= minmn)
// Integer workspace of 8*min(m, n) entries handed to gesdd.
64 cpuvec<int> iwork(8*minmn);
// First call passes lwork = -1 and the scalar tmp in place of a work array;
// tmp is then truncated to int and used as the length of the real workspace.
67 fml::lapack::gesdd(jobz, m, n, x.data_ptr(), m, s.data_ptr(), u.data_ptr(), m, vt.data_ptr(), ldvt, &tmp, -1, iwork.data_ptr(), &info);
68 int lwork = (int) tmp;
69 cpuvec<REAL> work(lwork);
// Second call repeats the same arguments with the allocated workspace.
71 fml::lapack::gesdd(jobz, m, n, x.data_ptr(), m, s.data_ptr(), u.data_ptr(), m, vt.data_ptr(), ldvt, work.data_ptr(), lwork, iwork.data_ptr(), &info);
// Wrapper around svd_internals with nu = nv = 0 (signature elided in this
// excerpt). The same object `ignored` is passed for both u and vt, so
// presumably only s is wanted by the caller - confirm against the full file.
96 template <
typename REAL>
100 int info = svd_internals(0, 0, x, s, ignored, ignored);
// Hand the returned status to check_info, labelled with the routine "gesdd".
101 fml::linalgutils::check_info(info,
"gesdd");
// Wrapper around svd_internals with nu = nv = 1 and caller-supplied u and vt
// (signature elided in this excerpt).
105 template <
typename REAL>
108 int info = svd_internals(1, 1, x, s, u, vt);
// Hand the returned status to check_info, labelled with the routine "gesdd".
109 fml::linalgutils::check_info(info,
"gesdd");
// NOTE(review): the signature and several statements are elided in this
// excerpt. From the error text and the QR calls this looks like the
// rows > cols SVD variant that also produces u - confirm the name and
// parameters against the full file.
116 template <
typename REAL>
119 const len_t m = x.
nrows();
120 const len_t n = x.
ncols();
// Thrown with this message; the guarding condition is not visible here.
122 throw std::runtime_error(
"'x' must have more rows than cols");
// QR step on x using qraux and work (their declarations are elided).
126 qr_internals(
false, x, qraux, work);
// Presumably builds the Q factor from x/qraux into u - confirm.
137 qr_Q(x, qraux, u, work);
// Untransposed product of u and u_R with scale 1.0; presumably the result is
// written to x (last argument) - confirm against matmult's declaration.
139 matmult(
false,
false, (REAL)1.0, u, u_R, x);
// tssvd(x, s): QR step on x followed by a gesdd call with job 'N' on the
// leading n-by-n part of x, writing to s. x is modified.
// NOTE(review): the guard condition, the tmp/info declarations and the body
// of the workspace-size check are elided in this excerpt.
143 template <
typename REAL>
144 void tssvd(cpumat<REAL> &x, cpuvec<REAL> &s)
146 const len_t m = x.nrows();
147 const len_t n = x.ncols();
// Thrown with this message; the guarding condition is not visible here.
149 throw std::runtime_error(
"'x' must have more rows than cols");
// qraux has n entries; work starts with m entries and is reused by gesdd.
153 cpuvec<REAL> qraux(n);
154 cpuvec<REAL> work(m);
155 qr_internals(
false, x, qraux, work);
// Presumably zeroes the lower triangle of the leading n-by-n block (leading
// dimension m) - confirm against tri2zero.
157 fml::cpu_utils::tri2zero(
'L',
false, n, n, x.data_ptr(), m);
// Integer workspace of 8*n entries for gesdd.
160 cpuvec<int> iwork(8*n);
// Size query: lwork = -1 with the scalar tmp as the work argument. u and vt
// are passed as NULL with job 'N'.
163 fml::lapack::gesdd(
'N', n, n, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
164 1, &tmp, -1, iwork.data_ptr(), &info);
165 int lwork = (int) tmp;
// The statement controlled by this test is elided; presumably it enlarges
// work when the queried size exceeds its current length.
166 if (lwork > work.size())
// Actual factorisation with the real workspace.
169 fml::lapack::gesdd(
'N', n, n, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
170 1, work.data_ptr(), lwork, iwork.data_ptr(), &info);
171 fml::linalgutils::check_info(info,
"gesdd");
// sfsvd(x, s, u, vt): LQ-based counterpart of the QR-based routine above,
// for input with more columns than rows (per the error text).
// NOTE(review): the guard condition, the lqaux/work declarations and the
// statements that fill L and vt_L are elided in this excerpt.
176 template <
typename REAL>
177 void sfsvd(cpumat<REAL> &x, cpuvec<REAL> &s, cpumat<REAL> &u, cpumat<REAL> &vt)
179 const len_t m = x.nrows();
180 const len_t n = x.ncols();
// Thrown with this message; the guarding condition is not visible here.
182 throw std::runtime_error(
"'x' must have more cols than rows");
186 lq_internals(x, lqaux, work);
// Two m-by-m temporaries; how they are populated is not visible here.
188 cpumat<REAL> L(m, m);
191 cpumat<REAL> vt_L(m, m);
// Presumably builds the Q factor of the LQ step into vt - confirm.
197 lq_Q(x, lqaux, vt, work);
// Untransposed product of vt_L and vt with scale 1.0; presumably the result
// is written to x (last argument) - confirm against matmult's declaration.
199 matmult(
false,
false, (REAL)1.0, vt_L, vt, x);
// sfsvd(x, s): LQ step on x followed by a gesdd call with job 'N' on the
// leading m-by-m part of x, writing to s. x is modified.
// NOTE(review): the guard condition, the lqaux/work/tmp/info declarations and
// the body of the workspace-size check are elided in this excerpt.
203 template <
typename REAL>
204 void sfsvd(cpumat<REAL> &x, cpuvec<REAL> &s)
206 const len_t m = x.nrows();
207 const len_t n = x.ncols();
// Thrown with this message; the guarding condition is not visible here.
209 throw std::runtime_error(
"'x' must have more cols than rows");
215 lq_internals(x, lqaux, work);
// Presumably zeroes the upper triangle of the leading m-by-m block (leading
// dimension m) - confirm against tri2zero.
217 fml::cpu_utils::tri2zero(
'U',
false, m, m, x.data_ptr(), m);
// NOTE(review): iwork is sized by n although the gesdd problem below is
// m-by-m; if n > m is enforced above this is larger than 8*m - confirm
// whether 8*m was intended.
220 cpuvec<int> iwork(8*n);
// Size query: lwork = -1 with the scalar tmp as the work argument. u and vt
// are passed as NULL with job 'N'.
223 fml::lapack::gesdd(
'N', m, m, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
224 1, &tmp, -1, iwork.data_ptr(), &info);
225 int lwork = (int) tmp;
// The statement controlled by this test is elided; presumably it enlarges
// work when the queried size exceeds its current length.
226 if (lwork > work.size())
// Actual factorisation with the real workspace.
229 fml::lapack::gesdd(
'N', m, m, x.data_ptr(), m, s.data_ptr(), NULL, m, NULL,
230 1, work.data_ptr(), lwork, iwork.data_ptr(), &info);
231 fml::linalgutils::check_info(info,
"gesdd");
268 template <
typename REAL>
280 template <
typename REAL>
// NOTE(review): the signature and the statements that define s_d, ev_d, cp
// and the branch structure are elided in this excerpt. The visible steps
// (square roots of s, column scaling by s, products with x) suggest an SVD
// obtained from an eigen-decomposition of a crossproduct - confirm the name
// and parameters against the full file.
318 template <
typename REAL>
322 const len_t m = x.
nrows();
323 const len_t n = x.
ncols();
324 const len_t minmn = std::min(m, n);
// Replace each entry of s with the square root of its absolute value.
346 for (len_t i=0; i<s.
size(); i++)
347 s_d[i] = sqrt(fabs(s_d[i]));
// Divide entry (i, j) of the minmn-by-minmn array ev_d (indexed i + minmn*j)
// by s_d[j]. The loop is parallelised only when minmn*minmn exceeds
// OMP_MIN_SIZE.
// NOTE(review): no visible guard against s_d[j] == 0 here.
355 #pragma omp parallel for if(minmn*minmn > fml::omp::OMP_MIN_SIZE)
356 for (len_t j=0; j<minmn; j++)
359 for (len_t i=0; i<minmn; i++)
360 ev_d[i + minmn*j] /= s_d[j];
// Untransposed product of x and vt; presumably written to u. Which branch
// this call belongs to is not visible here.
365 matmult(
false,
false, (REAL)1.0, x, vt, u);
// Product of transposed cp with x; presumably written to vt. Which branch
// this call belongs to is not visible here.
369 matmult(
true,
false, (REAL)1.0, cp, x, vt);
// NOTE(review): the signature and the statements that fill s / define s_d are
// elided in this excerpt; presumably the values-only counterpart of the
// preceding routine - confirm against the full file.
373 template <
typename REAL>
376 const len_t m = x.
nrows();
377 const len_t n = x.
ncols();
// Replace each entry of s with the square root of its absolute value.
391 for (len_t i=0; i<s.
size(); i++)
392 s_d[i] = sqrt(fabs(s_d[i]));
// rsvd_A: helper called by the two routines that follow with (seed, k, q, x,
// QY). It multiplies x by a seeded uniform-random matrix omega, QR-factors
// the result, then runs q rounds that alternate products with transposed x
// and x, each followed by a QR step.
// NOTE(review): the parameter list is truncated here, and the declarations of
// omega, Y, Z, QY, QZ, qraux and work are elided - confirm sizes and which
// argument of matmult / qr_Q receives the output.
399 template <
typename REAL>
400 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
cpumat<REAL> &x,
403 const len_t m = x.
nrows();
404 const len_t n = x.
ncols();
// Fill omega with uniform random values derived from seed.
407 omega.fill_runif(seed);
// Initial stage: product of x and omega (presumably into Y), QR step on Y,
// then presumably the Q factor into QY.
419 matmult(
false,
false, (REAL)1.0, x, omega, Y);
420 qr_internals(
false, Y, qraux, work);
421 qr_Q(Y, qraux, QY, work);
// q refinement rounds.
423 for (
int i=0; i<q; i++)
// Half-round 1: transposed x times QY (presumably into Z), QR step, QZ.
425 matmult(
true,
false, (REAL)1.0, x, QY, Z);
426 qr_internals(
false, Z, qraux, work);
427 qr_Q(Z, qraux, QZ, work);
// Half-round 2: x times QZ (presumably into Y), QR step, QY.
429 matmult(
false,
false, (REAL)1.0, x, QZ, Y);
430 qr_internals(
false, Y, qraux, work);
431 qr_Q(Y, qraux, QY, work);
// NOTE(review): the signature, the declarations of QY and B, and everything
// after the product are elided in this excerpt. The visible part calls rsvd_A
// and then multiplies transposed QY by x - confirm the name and parameters
// against the full file.
457 template <
typename REAL>
461 const len_t m = x.
nrows();
462 const len_t n = x.
ncols();
// Presumably fills QY - confirm against rsvd_A.
468 rsvd_A(seed, k, q, x, QY);
// Product of transposed QY with x; presumably written to B.
471 matmult(
true,
false, (REAL)1.0, QY, x, B);
// NOTE(review): the signature, the declarations of QY, B and uB, and the step
// that produces uB are elided in this excerpt. The visible part calls rsvd_A,
// multiplies transposed QY by x, and later multiplies QY by uB - confirm the
// name and parameters against the full file.
480 template <
typename REAL>
484 const len_t m = x.
nrows();
485 const len_t n = x.
ncols();
// Presumably fills QY - confirm against rsvd_A.
491 rsvd_A(seed, k, q, x, QY);
// Product of transposed QY with x; presumably written to B.
494 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Untransposed product of QY and uB; presumably written to u.
501 matmult(
false,
false, (REAL)1.0, QY, uB, u);