5 #ifndef FML_GPU_LINALG_LINALG_SVD_H
6 #define FML_GPU_LINALG_LINALG_SVD_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpu_utils.hh"
17 #include "../internals/gpuscalar.hh"
18 #include "../internals/kernelfuns.hh"
21 #include "../gpumat.hh"
22 #include "../gpuvec.hh"
24 #include "linalg_blas.hh"
25 #include "linalg_eigen.hh"
26 #include "linalg_qr.hh"
// Shared gesvd driver used by the SVD routines in this file.
//
// Queries the gesvd workspace size on the card that owns x, allocates device
// scratch buffers, runs gesvd, and copies the status integer back to the host.
//
// nu, nv : nu == 0 && nv == 0 selects the first branch below; callers in this
//          file pass (0, 0) for singular values only and (1, 1) when they
//          also want singular vectors.
// x      : m x n input; its device buffer is handed to gesvd with leading
//          dimension m.
//          NOTE(review): gesvd backends typically destroy their input -
//          confirm whether callers may rely on x afterwards.
// s      : device vector whose buffer receives the singular values.
// u, vt  : device matrices whose buffers are passed as the left / right
//          singular-vector outputs.
// Returns an int that callers feed to check_info(); presumably the gesvd
// `info` value read back from the device - confirm.
35 template <
typename REAL>
36 int svd_internals(
const int nu,
const int nv, gpumat<REAL> &x, gpuvec<REAL> &s, gpumat<REAL> &u, gpumat<REAL> &vt)
38 auto c = x.get_card();
40 const len_t m = x.nrows();
41 const len_t n = x.ncols();
42 const len_t minmn = std::min(m, n);
// Job codes forwarded to gesvd; which values they take depends on nu/nv.
46 signed char jobu, jobvt;
47 if (nu == 0 && nv == 0)
// Ask the backend how much workspace gesvd needs for an m x n problem.
62 gpulapack_status_t check = gpulapack::gesvd_buflen(c->lapack_handle(), m, n,
63 x.data_ptr(), &lwork);
64 gpulapack::err::check_ret(check,
"gesvd_bufferSize");
// Scratch buffers on the same card: `work` is sized by the query above,
// `rwork` holds min(m, n) - 1 elements.
66 gpuvec<REAL> work(c, lwork);
67 gpuvec<REAL> rwork(c, minmn-1);
// Device-resident copy of `info` so gesvd can write its status on the GPU.
70 gpuscalar<int> info_device(c, info);
// Leading dimensions: x and u use m, vt uses minmn.
72 check = gpulapack::gesvd(c->lapack_handle(), jobu, jobvt, m, n, x.data_ptr(),
73 m, s.data_ptr(), u.data_ptr(), m, vt.data_ptr(), minmn, work.data_ptr(),
74 lwork, rwork.data_ptr(), info_device.data_ptr());
// Fetch the status to the host first, then validate the library return code.
76 info_device.get_val(&info);
77 gpulapack::err::check_ret(check,
"gesvd");
// Singular values only.
//
// Verifies that x and s live on the same card, then runs the shared gesvd
// driver with nu = nv = 0. `ignored` is passed for both singular-vector
// arguments, and the returned status is validated by check_info() under the
// routine name "gesvd".
103 template <
typename REAL>
106 err::check_card(x, s);
// (0, 0): no singular vectors requested, so the u / vt slots get a dummy.
109 int info = svd_internals(0, 0, x, s, ignored, ignored);
110 fml::linalgutils::check_info(info,
"gesvd");
// Singular values plus singular vectors.
//
// Verifies that x, s, u and vt share a card. Two calls to the gesvd driver
// are visible, each followed by check_info() under the name "gesvd": one
// factors x directly, the other factors `tx`.
// NOTE(review): `tx` is presumably the transpose of x, with the branch chosen
// by the shape of x - confirm against the surrounding condition.
114 template <
typename REAL>
117 err::check_card(x, s, u, vt);
// Direct path: outputs land in u and vt as named.
121 int info = svd_internals(1, 1, x, s, u, vt);
122 fml::linalgutils::check_info(info,
"gesvd");
// Path on tx: the output roles swap. The driver's left-vector output is
// written to v and its right-vector output is written to u.
128 int info = svd_internals(1, 1, tx, s, v, u);
130 fml::linalgutils::check_info(info,
"gesvd");
// SVD with singular vectors via a QR factorization first.
//
// Steps visible here: QR-factor x in place (qr_internals), take the SVD of R
// through the gesvd driver (left vectors in u_R, right factor in vt), form Q
// into u with qr_Q, then multiply u by u_R.
// Throws std::runtime_error with the message below when the shape check
// guarding it fails.
138 template <
typename REAL>
141 const len_t m = x.
nrows();
142 const len_t n = x.
ncols();
144 throw std::runtime_error(
"'x' must have more rows than cols");
146 auto c = x.get_card();
// Factor x; qraux and work are passed on to qr_Q below.
150 qr_internals(
false, x, qraux, work);
// SVD of the triangular factor R rather than of x itself.
156 int info = svd_internals(1, 1, R, s, u_R, vt);
157 fml::linalgutils::check_info(info,
"gesvd");
// Expand the factored x into the explicit Q, stored in u.
160 qr_Q(x, qraux, u, work);
// Product of u and u_R; presumably written into x (last argument), which is
// used as scratch here - confirm against matmult's signature.
162 matmult(
false,
false, (REAL)1.0, u, u_R, x);
// Singular values only, via a QR factorization first.
//
// x : factored in place by qr_internals; the leading n x n part of its buffer
//     (leading dimension m) is then passed to gesvd.
// s : device vector whose buffer receives the singular values.
//
// Throws std::runtime_error with the message below when the shape check
// guarding it fails. Backend failures are reported through
// gpulapack::err::check_ret and a bad gesvd status through check_info().
166 template <
typename REAL>
167 void tssvd(gpumat<REAL> &x, gpuvec<REAL> &s)
169 const len_t m = x.nrows();
170 const len_t n = x.ncols();
172 throw std::runtime_error(
"'x' must have more rows than cols");
174 auto c = x.get_card();
// Scratch vectors for the QR step; reused below as gesvd's work / rwork.
177 gpuvec<REAL> qraux(c);
178 gpuvec<REAL> work(c);
179 qr_internals(
false, x, qraux, work);
// Applied to the leading n x n block with side 'L'; by its name this
// presumably zeroes the lower triangle so only R remains - confirm in gpu_utils.
181 fml::gpu_utils::tri2zero(
'L',
false, n, n, x.data_ptr(), m);
// Workspace query for the n x n problem.
184 gpulapack_status_t check = gpulapack::gesvd_buflen(c->lapack_handle(), n, n,
185 x.data_ptr(), &lwork);
186 gpulapack::err::check_ret(check,
"gesvd_bufferSize");
// Size checks on the reused buffers before handing them to gesvd.
188 if (lwork > work.size())
// NOTE(review): this compares against m-1, whereas svd_internals sizes rwork
// as min(m, n) - 1; for the n x n problem solved here n-1 would appear to be
// enough - confirm.
190 if (m-1 > qraux.size())
194 gpuscalar<int> info_device(c, info);
// Job codes 'N', 'N' with null u / vt pointers: no singular vectors are
// requested. The matrix is n x n with leading dimension m.
196 check = gpulapack::gesvd(c->lapack_handle(),
'N',
'N', n, n, x.data_ptr(),
197 m, s.data_ptr(), NULL, m, NULL, 1, work.data_ptr(), lwork,
198 qraux.data_ptr(), info_device.data_ptr());
// Fetch the status to the host, then validate both the library return code
// and the status itself.
200 info_device.get_val(&info);
201 gpulapack::err::check_ret(check,
"gesvd");
202 fml::linalgutils::check_info(info,
"gesvd");
// SVD with singular vectors for the "more cols than rows" case named in the
// error message below.
//
// Builds tx = xpose(x) and a fresh matrix v on the same card as x.
// NOTE(review): these are presumably passed on to the tall-matrix routine,
// with v standing in for one of the outputs - confirm against the lines that
// follow.
// Throws std::runtime_error when the shape check guarding the throw fails.
207 template <
typename REAL>
208 void sfsvd(gpumat<REAL> &x, gpuvec<REAL> &s, gpumat<REAL> &u, gpumat<REAL> &vt)
210 const len_t m = x.nrows();
211 const len_t n = x.ncols();
213 throw std::runtime_error(
"'x' must have more cols than rows");
// Work on a transposed copy; x itself is not modified by these two lines.
215 gpumat<REAL> tx =
xpose(x);
216 gpumat<REAL> v(x.get_card());
// Singular-values-only variant for the "more cols than rows" case named in
// the error message below.
//
// Reads the dimensions of x and throws std::runtime_error when the shape
// check guarding the throw fails.
221 template <
typename REAL>
222 void sfsvd(gpumat<REAL> &x, gpuvec<REAL> &s)
224 const len_t m = x.nrows();
225 const len_t n = x.ncols();
227 throw std::runtime_error(
"'x' must have more cols than rows");
// Variant producing singular vectors: starts by verifying that x, s, u and vt
// all belong to the same card.
267 template <
typename REAL>
270 err::check_card(x, s, u, vt);
// Singular-values-only variant: starts by verifying that x and s belong to
// the same card.
281 template <
typename REAL>
284 err::check_card(x, s);
// SVD with singular vectors, computed with custom kernels and matrix products
// rather than a direct gesvd call.
//
// Steps visible here: check that all arguments share a card, transform s in
// place with kernel_root_abs, run kernel_sweep_cols_div over a
// minmn x minmn problem, then form u and vt by two matmult calls.
// NOTE(review): this looks like the crossproduct / eigendecomposition route
// to the SVD (linalg_eigen.hh is included) - confirm.
320 template <
typename REAL>
323 err::check_card(x, s, u, vt);
325 const len_t m = x.
nrows();
326 const len_t n = x.
ncols();
327 const len_t minmn = std::min(m, n);
329 auto c = x.get_card();
// In-place kernel over all s.size() entries of s, launched with the vector's
// own grid/block configuration. By its name it presumably maps each entry to
// sqrt(|s_i|) - confirm in kernelfuns.hh.
349 auto sgrid = s.get_griddim();
350 auto sblock = s.get_blockdim();
351 fml::kernelfuns::kernel_root_abs<<<sgrid, sblock>>>(s.
size(), s.
data_ptr());
// Column sweep launched with x's grid/block configuration; by its name it
// presumably divides columns by the entries of a vector - confirm.
359 auto xgrid = x.get_griddim();
360 auto xblock = x.get_blockdim();
361 fml::kernelfuns::kernel_sweep_cols_div<<<xgrid, xblock>>>(minmn, minmn,
// Assuming matmult's bool flags mark transposition of its first and second
// operands and the last argument is the output: u = x * vt.
366 matmult(
false,
false, (REAL)1.0, x, vt, u);
// Under the same assumption: vt = cp^T * x.
370 matmult(
true,
false, (REAL)1.0, cp, x, vt);
// Singular-values-only counterpart of the kernel-based routine.
//
// Checks that x and s share a card, reads the dimensions and the card of x,
// and finishes by transforming s in place with kernel_root_abs.
374 template <
typename REAL>
377 err::check_card(x, s);
379 const len_t m = x.
nrows();
380 const len_t n = x.
ncols();
382 auto c = x.get_card();
// In-place kernel over all s.size() entries, launched with the vector's own
// grid/block configuration. By its name it presumably maps each entry to
// sqrt(|s_i|) - confirm in kernelfuns.hh.
394 fml::kernelfuns::kernel_root_abs<<<s.get_griddim(), s.get_blockdim()>>>(s.
size(), s.
data_ptr());
// First stage of the randomized SVD: builds an orthonormal factor QY from
// random sampling of x, refined by q rounds of alternating products.
//
// seed : seed passed to fill_runif() for the random test matrix omega.
// k    : NOTE(review): presumably the number of sampled columns / target
//        rank used to size omega - confirm.
// q    : number of refinement iterations of the loop below.
// x    : input matrix; only read by the matmult calls visible here.
//
// The matmult comments assume its bool flags mark transposition of the first
// and second operands and that the last argument receives the product.
401 template <
typename REAL>
402 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
gpumat<REAL> &x,
405 const len_t m = x.
nrows();
406 const len_t n = x.
ncols();
// Random test matrix, reproducible for a given seed.
409 omega.fill_runif(seed);
// Initial sample Y = x * omega, then QY = Q factor of Y.
421 matmult(
false,
false, (REAL)1.0, x, omega, Y);
422 qr_internals(
false, Y, qraux, work);
423 qr_Q(Y, qraux, QY, work);
// Each pass re-orthogonalizes after multiplying by x^T and then by x; qraux
// and work are reused across every factorization.
425 for (
int i=0; i<q; i++)
// Z = x^T * QY, then QZ = Q factor of Z.
427 matmult(
true,
false, (REAL)1.0, x, QY, Z);
428 qr_internals(
false, Z, qraux, work);
429 qr_Q(Z, qraux, QZ, work);
// Y = x * QZ, then QY = Q factor of Y.
431 matmult(
false,
false, (REAL)1.0, x, QZ, Y);
432 qr_internals(
false, Y, qraux, work);
433 qr_Q(Y, qraux, QY, work);
// Randomized SVD, singular values only (no u / vt appear in the lines below).
//
// Obtains QY from rsvd_A(seed, k, q, x, QY) and then forms the projected
// matrix B from QY and x.
// NOTE(review): B is presumably decomposed afterwards to obtain the singular
// values - confirm against the lines that follow.
459 template <
typename REAL>
463 const len_t m = x.
nrows();
464 const len_t n = x.
ncols();
470 rsvd_A(seed, k, q, x, QY);
// Assuming the first bool flag transposes the first operand and the last
// argument is the output: B = QY^T * x.
473 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Randomized SVD that also produces left singular vectors.
//
// Obtains QY from rsvd_A(seed, k, q, x, QY), forms the projected matrix B
// from QY and x, and finally maps the left vectors uB back through QY into u.
// NOTE(review): uB is presumably the left-vector output of an SVD of B -
// confirm against the lines between the two products.
482 template <
typename REAL>
486 const len_t m = x.
nrows();
487 const len_t n = x.
ncols();
493 rsvd_A(seed, k, q, x, QY);
// Assuming the first bool flag transposes the first operand and the last
// argument is the output: B = QY^T * x.
496 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Under the same assumption: u = QY * uB.
503 matmult(
false,
false, (REAL)1.0, QY, uB, u);