5 #ifndef FML_PAR_GPU_LINALG_SVD_H
6 #define FML_PAR_GPU_LINALG_SVD_H
10 #include "../parmat.hh"
15 #include "../../../gpu/linalg/linalg_blas.hh"
16 #include "../../../gpu/linalg/linalg_invert.hh"
17 #include "../../../gpu/linalg/linalg_qr.hh"
18 #include "../../../gpu/linalg/linalg_svd.hh"
20 #include "../../../gpu/copy.hh"
// Forms the crossproduct of the data object held by x and reduces the result
// across the communicator that x belongs to.
51 template <
typename REAL>
// n is the column count of x; n*n is the element count handed to the reduction.
54 const len_t n = x.ncols();
// Crossproduct of x's data object, called with a leading scalar of 1.0.
56 auto cp =
crossprod((REAL)1.0, x.data_obj());
// Allreduce over the n*n entries behind cp's data pointer.
// NOTE(review): presumably an in-place sum across ranks -- confirm against the comm class.
57 x.get_comm().allreduce(n*n, cp.data_ptr());
// Starts with the same crossproduct + allreduce step, then launches two GPU
// kernels on s_d / vt_d and finishes with a matrix product written into u.
62 template <
typename REAL>
// n is the column count of x; it sizes the reduction and the n-by-n kernel sweep.
66 const len_t n = x.ncols();
// Crossproduct of x's data object, called with a leading scalar of 1.0.
68 auto cp =
crossprod((REAL)1.0, x.data_obj());
// Allreduce over the n*n entries behind cp's data pointer.
// NOTE(review): presumably an in-place sum across ranks -- confirm against the comm class.
69 x.get_comm().allreduce(n*n, cp.data_ptr());
// Card handle obtained from vt.
72 auto c = vt.get_card();
// Launch configuration comes from s; the kernel receives s.size() and s_d.
// NOTE(review): going by its name, kernel_root_abs replaces each entry with
// sqrt(|entry|) -- confirm in kernelfuns.
76 auto sgrid = s.get_griddim();
77 auto sblock = s.get_blockdim();
78 kernelfuns::kernel_root_abs<<<sgrid, sblock>>>(s.
size(), s_d);
// Launch configuration comes from vt; the kernel receives dimensions n, n plus vt_d and s_d.
// NOTE(review): going by its name, kernel_sweep_cols_div divides the columns of
// vt_d by the entries of s_d -- confirm in kernelfuns.
84 auto xgrid = vt.get_griddim();
85 auto xblock = vt.get_blockdim();
86 kernelfuns::kernel_sweep_cols_div<<<xgrid, xblock>>>(n, n, vt_d, s_d);
// Product of x's data object and vt, with scalar 1.0, stored in u.
// NOTE(review): assumes the two leading bool arguments are transpose flags
// (both off here) -- confirm against matmult's declaration.
88 matmult(
false,
false, (REAL)1.0, x.data_obj(), vt, u);
// Calls qr_R on x with the mpi::REDUCE_TO_ALL mode, writing into R.
// NOTE(review): REDUCE_TO_ALL presumably leaves the result on every rank -- confirm.
126 template <
typename REAL>
130 qr_R(mpi::REDUCE_TO_ALL, x, R);
// QR-based path: factor x's data object, combine the R factors over the
// communicator, run svd on the small matrix, then build u with two products.
135 template <
typename REAL>
// n is the column count of x.
139 const len_t n = x.ncols();
// Reject inputs whose row count (compared as len_global_t) is below the column count.
140 if (x.nrows() < (len_global_t)n)
141 throw std::runtime_error(
"impossible dimensions");
// QR of x's data object with qraux as auxiliary output; x.data_obj() is passed
// non-const, so it may be overwritten by the factorization.
// NOTE(review): meaning of the leading `false` flag is not established here -- confirm against qr().
147 qr(
false, x.data_obj(), qraux);
// Extract the R factor of that factorization into R_local.
148 qr_R(x.data_obj(), R_local);
// Combine the n-by-n R_local buffers over x's underlying communicator into R,
// using the REDUCE_TO_ALL mode.
151 tsqr::qr_allreduce(mpi::REDUCE_TO_ALL, n, n, R_local.
data_ptr(), R.
data_ptr(), x.get_comm().get_comm());
// SVD taken of R_local, with outputs s, u_R and vt.
155 svd(R_local, s, u_R, vt);
// Give u the same row/column counts as x.
158 u.resize(x.nrows(), x.ncols());
// NOTE(review): going by its name, trinv inverts the triangular matrix R in
// place; the meaning of the two bool flags is not established here -- confirm.
159 trinv(
true,
false, R);
// x.data_obj() <- x_cpy * R (scalar 1), overwriting x's data object.
// NOTE(review): assumes the two leading bool arguments are transpose flags -- confirm.
160 matmult(
false,
false, (REAL)1, x_cpy, R, x.data_obj());
// u.data_obj() <- x.data_obj() * u_R (scalar 1).
161 matmult(
false,
false, (REAL)1, x.data_obj(), u_R, u.data_obj());
// Internal helper shared by the rsvd() entry points: takes a seed plus the
// integers k and q, validates dimensions, fills omega from the seed, and
// alternates qr_Q calls with a q-iteration loop.
168 template <
typename REAL>
169 void rsvd_A(
const uint32_t seed,
const int k,
const int q,
// m is the row count of x (len_global_t); n is its column count (len_t).
172 const len_global_t m = x.nrows();
173 const len_t n = x.ncols();
// NOTE(review): this guard only rejects m < n, so m == n is accepted even
// though the message reads "must have m>n" -- confirm which is intended.
174 if (m < (len_global_t)n)
175 throw std::runtime_error(
"must have m>n");
// Second validation failure, reported as "must have k<n".
177 throw std::runtime_error(
"must have k<n");
// Card handle obtained from x's data object.
179 auto c = x.data_obj().get_card();
// Fill omega via fill_runif using the caller-supplied seed.
// NOTE(review): presumably uniform random values, reproducible per seed -- confirm.
192 omega.fill_runif(seed);
// First qr_Q pass; QY is the last argument.
// NOTE(review): presumably QY receives the Q factor of Y -- confirm against internals::qr_Q.
196 internals::qr_Q(Y, Y_tmp, R, R_local, qraux, QY);
// Loop header iterating i from 0 to q-1.
198 for (
int i=0; i<q; i++)
// QR internals applied to Z with qraux and work as auxiliary buffers.
201 linalg::qr_internals(
false, Z, qraux, work);
// Further qr_Q pass with the same argument list as the first one.
205 internals::qr_Q(Y, Y_tmp, R, R_local, qraux, QY);
// rsvd entry point taking a seed and the integers k and q; it allocates QY and
// delegates to internals::rsvd_A.
232 template <
typename REAL>
233 void rsvd(
const uint32_t seed,
const int k,
const int q,
// QY is built on x's communicator and s's card, with x.nrows() rows, 2*k
// columns, and x.nrows_before() as the final constructor argument.
236 parmat_gpu<REAL> QY(x.get_comm(), s.get_card(), x.nrows(), 2*k, x.nrows_before());
// Hand seed, k, q and x to the shared helper, with QY as its last argument.
240 internals::rsvd_A(seed, k, q, x, QY);
// Second rsvd overload taking a seed and the integers k and q; like the other
// one, it allocates QY and delegates to internals::rsvd_A.
250 template <
typename REAL>
251 void rsvd(
const uint32_t seed,
const int k,
const int q,
// QY is built on x's communicator and s's card, with x.nrows() rows, 2*k
// columns, and x.nrows_before() as the final constructor argument.
255 parmat_gpu<REAL> QY(x.get_comm(), s.get_card(), x.nrows(), 2*k, x.nrows_before());
// Hand seed, k, q and x to the shared helper, with QY as its last argument.
259 internals::rsvd_A(seed, k, q, x, QY);