5 #ifndef FML_GPU_LINALG_SVD_H
6 #define FML_GPU_LINALG_SVD_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpu_utils.hh"
17 #include "../internals/gpuscalar.hh"
18 #include "../internals/kernelfuns.hh"
21 #include "../gpumat.hh"
22 #include "../gpuvec.hh"
24 #include "crossprod.hh"
// Internal driver shared by the SVD routines in this file. It queries the
// gesvd workspace size, builds scratch buffers on x's card, and calls
// gpulapack::gesvd on x with outputs s, u and vt.
//
// nu, nv: feed the choice of jobu/jobvt. Only the `nu == 0 && nv == 0` test is
//         visible in this excerpt; the rest of the selection is not shown.
// Returns an int. NOTE(review): presumably the `info` value read back from
// the device below -- the return statement is not visible here; confirm.
37 template <
typename REAL>
38 int svd_internals(
const int nu,
const int nv, gpumat<REAL> &x, gpuvec<REAL> &s, gpumat<REAL> &u, gpumat<REAL> &vt)
40 auto c = x.get_card();
42 const len_t m = x.nrows();
43 const len_t n = x.ncols();
44 const len_t minmn = std::min(m, n);
// Job codes handed to gesvd; they are assigned in lines not shown here.
48 signed char jobu, jobvt;
49 if (nu == 0 && nv == 0)
// Workspace query for an m x n problem; the size is written through &lwork.
64 gpulapack_status_t check = gpulapack::gesvd_buflen(c->lapack_handle(), m, n,
65 x.data_ptr(), &lwork);
66 gpulapack::err::check_ret(check,
"gesvd_bufferSize");
// Scratch vectors built from card c with lengths lwork and minmn-1.
68 gpuvec<REAL> work(c, lwork);
69 gpuvec<REAL> rwork(c, minmn-1);
// info_device's pointer is passed to gesvd and read back into `info` after.
72 gpuscalar<int> info_device(c, info);
// m, m and minmn follow x, u and vt in the argument list (presumably their
// leading dimensions -- confirm against the gpulapack::gesvd wrapper).
74 check = gpulapack::gesvd(c->lapack_handle(), jobu, jobvt, m, n, x.data_ptr(),
75 m, s.data_ptr(), u.data_ptr(), m, vt.data_ptr(), minmn, work.data_ptr(),
76 lwork, rwork.data_ptr(), info_device.data_ptr());
// info is fetched first, then the API status of the gesvd call is checked.
78 info_device.get_val(&info);
79 gpulapack::err::check_ret(check,
"gesvd");
// Fragment of a templated routine whose signature is not part of this
// excerpt. The visible lines request a factorization with nu = nv = 0.
105 template <
typename REAL>
// x and s are passed through err::check_card first (presumably a same-card
// check -- confirm against its definition).
108 err::check_card(x, s);
// The same `ignored` object is passed in both the u and vt positions.
111 int info = svd_internals(0, 0, x, s, ignored, ignored);
// The returned info code goes to the shared checker, tagged "gesvd".
112 fml::linalgutils::check_info(info,
"gesvd");
// Fragment of a templated routine (signature not visible here) that fills s
// together with u and vt. Two svd_internals call sites are visible; the
// branching that selects between them lies outside this excerpt.
116 template <
typename REAL>
119 err::check_card(x, s, u, vt);
// First call site: factor x itself, with u and vt in their natural positions.
123 int info = svd_internals(1, 1, x, s, u, vt);
124 fml::linalgutils::check_info(info,
"gesvd");
// Second call site: works on tx, with outputs passed as (v, u).
// NOTE(review): tx is presumably the transpose of x, used when x has fewer
// rows than columns -- confirm against the omitted lines.
130 int info = svd_internals(1, 1, tx, s, v, u);
132 fml::linalgutils::check_info(info,
"gesvd");
// Fragment of a templated routine (signature not visible here). The visible
// steps are: factor x with qr_internals, run svd_internals on R, build the Q
// factor into u with qr_Q, then multiply u by u_R.
140 template <
typename REAL>
143 const len_t m = x.
nrows();
144 const len_t n = x.
ncols();
// The condition guarding this throw is not visible in this excerpt.
146 throw std::runtime_error(
"'x' must have more rows than cols");
148 auto c = x.get_card();
// qraux and work are declared in lines not shown here.
152 qr_internals(
false, x, qraux, work);
// R and u_R are prepared in lines not shown here.
158 int info = svd_internals(1, 1, R, s, u_R, vt);
159 fml::linalgutils::check_info(info,
"gesvd");
162 qr_Q(x, qraux, u, work);
// NOTE(review): presumably the last matmult argument is the output, so the
// product u * u_R lands in x -- confirm against matmult's signature.
164 matmult(
false,
false, (REAL)1.0, u, u_R, x);
// Values-only variant: fills s from x without producing u or vt.
// Visible steps: QR-factor x via qr_internals, apply tri2zero to the leading
// n x n part of x, then call gesvd on that n x n block with both job codes
// 'N' and NULL pointers in the u and vt positions.
// Throws std::runtime_error with the message below; the guarding condition is
// not visible in this excerpt.
168 template <
typename REAL>
169 void tssvd(gpumat<REAL> &x, gpuvec<REAL> &s)
171 const len_t m = x.nrows();
172 const len_t n = x.ncols();
174 throw std::runtime_error(
"'x' must have more rows than cols");
176 auto c = x.get_card();
// Both vectors start out constructed from the card alone; they are passed to
// qr_internals here and reused for gesvd further down.
179 gpuvec<REAL> qraux(c);
180 gpuvec<REAL> work(c);
181 qr_internals(
false, x, qraux, work);
// NOTE(review): presumably clears the part of the n x n block below the
// diagonal so only R remains -- confirm against tri2zero's definition.
// The trailing m is passed after the data pointer (presumably the leading
// dimension of x).
183 fml::gpu_utils::tri2zero(
'L',
false, n, n, x.data_ptr(), m);
// Workspace query for an n x n problem; the size is written through &lwork.
186 gpulapack_status_t check = gpulapack::gesvd_buflen(c->lapack_handle(), n, n,
187 x.data_ptr(), &lwork);
188 gpulapack::err::check_ret(check,
"gesvd_bufferSize");
// The bodies of these two size checks are not visible here; presumably they
// grow work and qraux when the existing buffers are too small.
190 if (lwork > work.size())
192 if (m-1 > qraux.size())
196 gpuscalar<int> info_device(c, info);
// work serves as the gesvd workspace; qraux is passed in the slot that
// svd_internals fills with its rwork buffer.
198 check = gpulapack::gesvd(c->lapack_handle(),
'N',
'N', n, n, x.data_ptr(),
199 m, s.data_ptr(), NULL, m, NULL, 1, work.data_ptr(), lwork,
200 qraux.data_ptr(), info_device.data_ptr());
// info is read back first; then both the API status and info are checked.
202 info_device.get_val(&info);
203 gpulapack::err::check_ret(check,
"gesvd");
204 fml::linalgutils::check_info(info,
"gesvd");
// Variant taking x, s, u and vt. Only the opening of the body is visible: it
// reads the dimensions, contains the throw below, builds tx from xpose(x) and
// declares v on x's card. The rest of the body is not part of this excerpt.
209 template <
typename REAL>
210 void sfsvd(gpumat<REAL> &x, gpuvec<REAL> &s, gpumat<REAL> &u, gpumat<REAL> &vt)
212 const len_t m = x.nrows();
213 const len_t n = x.ncols();
// The condition guarding this throw is not visible in this excerpt.
215 throw std::runtime_error(
"'x' must have more cols than rows");
// tx holds the result of xpose(x), leaving x itself as the argument.
217 gpumat<REAL> tx =
xpose(x);
// v is constructed from x's card only; how it is used is not shown here.
218 gpumat<REAL> v(x.get_card());
// Values-only overload taking x and s. Only the dimension reads and the throw
// are visible; the guard condition and the remaining body are not part of
// this excerpt.
223 template <
typename REAL>
224 void sfsvd(gpumat<REAL> &x, gpuvec<REAL> &s)
226 const len_t m = x.nrows();
227 const len_t n = x.ncols();
229 throw std::runtime_error(
"'x' must have more cols than rows");
// Fragment of a templated routine; its signature and the rest of its body are
// not part of this excerpt. The only visible statement passes x, s, u and vt
// to err::check_card.
269 template <
typename REAL>
272 err::check_card(x, s, u, vt);
// Fragment of a templated routine; its signature and the rest of its body are
// not part of this excerpt. The only visible statement passes x and s to
// err::check_card.
283 template <
typename REAL>
286 err::check_card(x, s);
// CUDA kernel: each thread derives an (i, j) pair from its block and thread
// indices and divides data[i + m*j] by v[j], so every entry that shares a j
// is scaled by the same element of v.
//
// m:    multiplier applied to j when forming the offset into data.
// n:    not referenced in the visible lines. NOTE(review): presumably used,
//       together with m, in a bounds guard that is not part of this excerpt.
// data: buffer updated in place.
// v:    read-only divisors, indexed by j.
300 template <
typename REAL>
301 __global__
void kernel_sweep_cols_div(
const len_t m,
const len_t n, REAL *data,
const REAL *v)
// i comes from the x dimension of the launch, j from the y dimension.
303 int i = blockDim.x*blockIdx.x + threadIdx.x;
304 int j = blockDim.y*blockIdx.y + threadIdx.y;
307 data[i + m*j] /= v[j];
// Fragment of a templated routine (signature not visible here) that fills s,
// u and vt. Visible steps: launch kernel_root_abs over s, divide a
// minmn x minmn buffer (ev_d) by s with kernel_sweep_cols_div, then call
// matmult at two sites whose selecting branch is not shown.
336 template <
typename REAL>
339 err::check_card(x, s, u, vt);
341 const len_t m = x.
nrows();
342 const len_t n = x.
ncols();
343 const len_t minmn = std::min(m, n);
345 auto c = x.get_card();
// Launch over s using s's own grid/block configuration.
// NOTE(review): going by its name, kernel_root_abs presumably replaces each
// entry with sqrt(|entry|) -- confirm against kernelfuns.
365 auto sgrid = s.get_griddim();
366 auto sblock = s.get_blockdim();
367 fml::kernelfuns::kernel_root_abs<<<sgrid, sblock>>>(s.
size(), s.
data_ptr());
// The launch configuration comes from x, while the kernel is told the buffer
// is minmn x minmn. ev_d is assigned in lines not shown here.
375 auto xgrid = x.get_griddim();
376 auto xblock = x.get_blockdim();
377 kernel_sweep_cols_div<<<xgrid, xblock>>>(minmn, minmn, ev_d, s.
data_ptr());
// First call site: x and vt as operands, u in the last position (presumably
// the output -- confirm against matmult's signature).
381 matmult(
false,
false, (REAL)1.0, x, vt, u);
// Second call site: cp (flagged true, presumably transposed) and x as
// operands, vt in the last position. cp is defined in lines not shown here.
385 matmult(
true,
false, (REAL)1.0, cp, x, vt);
// Fragment of a templated routine (signature not visible here) that takes x
// and s only. The steps that fill s are not part of this excerpt; the final
// visible step launches kernel_root_abs over s.
389 template <
typename REAL>
392 err::check_card(x, s);
394 const len_t m = x.
nrows();
395 const len_t n = x.
ncols();
397 auto c = x.get_card();
// Launch over s using s's own grid/block configuration.
409 fml::kernelfuns::kernel_root_abs<<<s.get_griddim(), s.get_blockdim()>>>(s.
size(), s.
data_ptr());
// Helper used by the randomized SVD routines below (both call it with
// (seed, k, q, x, QY)). The parameter list is truncated in this excerpt: the
// parameter after x is not visible, nor are the declarations of omega, Y, Z,
// QY, QZ, qraux and work.
//
// seed: passed to omega.fill_runif.
// k:    not referenced in the visible lines.
// q:    number of iterations of the loop below.
// x:    input matrix; it appears only as a matmult operand in the visible
//       lines.
416 template <
typename REAL>
417 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
gpumat<REAL> &x,
420 const len_t m = x.
nrows();
421 const len_t n = x.
ncols();
// Fill omega using the caller's seed.
424 omega.fill_runif(seed);
// Initial pass: multiply x by omega into Y (assuming matmult's last argument
// is the output), QR-factor Y, and build its Q factor into QY.
436 matmult(
false,
false, (REAL)1.0, x, omega, Y);
437 qr_internals(
false, Y, qraux, work);
438 qr_Q(Y, qraux, QY, work);
// q iterations of two multiply / QR / build-Q steps.
440 for (
int i=0; i<q; i++)
// Step 1: x (flagged true, presumably transposed) times QY into Z, then QR of
// Z with its Q factor built into QZ.
442 matmult(
true,
false, (REAL)1.0, x, QY, Z);
443 qr_internals(
false, Z, qraux, work);
444 qr_Q(Z, qraux, QZ, work);
// Step 2: x times QZ into Y, then QR of Y with its Q factor built into QY.
446 matmult(
false,
false, (REAL)1.0, x, QZ, Y);
447 qr_internals(
false, Y, qraux, work);
448 qr_Q(Y, qraux, QY, work);
// Fragment of a templated routine (signature not visible here). Visible
// steps: obtain QY from rsvd_A, then multiply QY (flagged true, presumably
// transposed) by x into B. What happens to B afterwards is not shown here.
474 template <
typename REAL>
478 const len_t m = x.
nrows();
479 const len_t n = x.
ncols();
// QY is declared in lines not shown here.
485 rsvd_A(seed, k, q, x, QY);
488 matmult(
true,
false, (REAL)1.0, QY, x, B);
// Fragment of a templated routine (signature not visible here). Visible
// steps: obtain QY from rsvd_A, multiply QY (flagged true, presumably
// transposed) by x into B, and later multiply QY by uB into u. The lines that
// produce uB from B are not part of this excerpt.
497 template <
typename REAL>
501 const len_t m = x.
nrows();
502 const len_t n = x.
ncols();
// QY is declared in lines not shown here.
508 rsvd_A(seed, k, q, x, QY);
511 matmult(
true,
false, (REAL)1.0, QY, x, B);
// u sits in the last position (presumably the output -- confirm against
// matmult's signature).
518 matmult(
false,
false, (REAL)1.0, QY, uB, u);