5 #ifndef FML_PAR_GPU_LINALG_QR_H
6 #define FML_PAR_GPU_LINALG_QR_H
10 #include "../parmat.hh"
12 #include "../../../gpu/linalg/linalg_invert.hh"
13 #include "../../../gpu/linalg/linalg_qr.hh"
15 #include "../../../gpu/copy.hh"
18 #include "qr_allreduce.hh"
// Internal R-factor computation for the distributed matrix x.
//
// Visible steps: n is taken as x.ncols(), then tsqr::qr_allreduce combines the
// per-rank data in R_local into R, using x's MPI communicator and R's GPU card.
//
// NOTE(review): the lines between the computation of `n` and the allreduce are
// not visible here; presumably they factor this rank's local block of x (with
// qraux as workspace) and extract its R factor into R_local — confirm against
// the full file.
//
// @param root    Forwarded as the first argument of tsqr::qr_allreduce; callers
//                pass either their own `root` argument or mpi::REDUCE_TO_ALL.
// @param x       Distributed input matrix. Its column count n sizes the reduced
//                result, and it supplies the communicator. It may be modified by
//                the elided local factorization (it is taken by non-const
//                reference) — TODO confirm.
// @param R       Output; its data pointer is passed to the allreduce after
//                R_local's, presumably as the receive buffer — TODO confirm.
// @param R_local Per-rank buffer whose data pointer is passed to the allreduce,
//                presumably as the send buffer.
// @param qraux   Not referenced in the visible lines; presumably workspace for
//                the local QR.
27 template <
typename REAL>
28 void qr_R(
const int root, parmat_gpu<REAL> &x, gpumat<REAL> &R,
29 gpumat<REAL> &R_local, gpuvec<REAL> &qraux)
31 const len_t n = x.ncols();
// Reduce the per-rank factors across the communicator; the (n, n) arguments
// presumably give the dimensions of the R factor.
37 tsqr::qr_allreduce(root, n, n, R_local.data_ptr(), R.data_ptr(),
38 x.get_comm().get_comm(), R.get_card());
// Compute the R factor of the QR factorization of the distributed matrix x.
// The GPU workspaces needed by the internal routine are allocated here, on R's
// card.
//
// @param root Forwarded unchanged to internals::qr_R (and from there to
//             tsqr::qr_allreduce). Presumably it is the rank that receives R,
//             or mpi::REDUCE_TO_ALL for every rank — TODO confirm.
// @param x    Distributed input matrix; it must not have fewer rows than
//             columns. Taken by non-const reference, so it may be modified by
//             the internal routine — TODO confirm.
// @param R    Output, filled by internals::qr_R.
// @throws std::runtime_error("impossible dimensions") if x.nrows() < x.ncols().
//
// NOTE(review): <stdexcept> is not among the visible includes; confirm it is
// included (directly or transitively) for std::runtime_error.
44 template <
typename REAL>
45 void qr_R(
const int root, parmat_gpu<REAL> &x, gpumat<REAL> &R)
// ncols is cast to len_global_t so it can be compared against nrows, which is
// presumably the global (len_global_t) row count.
47 if (x.nrows() < (len_global_t)x.ncols())
48 throw std::runtime_error(
"impossible dimensions");
// Workspaces for the internal routine, created on the same card as R.
50 gpumat<REAL> R_local(R.get_card());
51 gpuvec<REAL> qraux(R.get_card());
53 internals::qr_R(root, x, R, R_local, qraux);
// Internal helper for computing the Q factor of the distributed QR.
//
// Visible steps:
//   1. internals::qr_R is run on x_cpy with mpi::REDUCE_TO_ALL as the root
//      (presumably so that every rank obtains R).
//   2. trinv is applied to R (presumably an in-place triangular inverse).
//
// NOTE(review): part of the parameter list and the rest of the body are not
// visible here. Presumably x_cpy is a scratch copy of x (the R computation
// takes its matrix by non-const reference), and the missing lines form Q from
// x and the inverted R — confirm against the full file.
//
// @param x       Input matrix (const); not referenced in the visible lines.
// @param x_cpy   Matrix passed to the R computation.
// @param R       Receives R from the R computation; then passed to trinv.
// @param R_local Per-rank workspace forwarded to the R computation.
// @param qraux   Workspace forwarded to the R computation.
60 template <
typename REAL>
61 void qr_Q(
const parmat_gpu<REAL> &x, parmat_gpu<REAL> &x_cpy,
62 gpumat<REAL> &R, gpumat<REAL> &R_local, gpuvec<REAL> &qraux,
66 internals::qr_R(mpi::REDUCE_TO_ALL, x_cpy, R, R_local, qraux);
// Presumably the boolean arguments select an upper-triangular, non-unit-
// diagonal inverse — confirm against linalg_invert.hh.
67 trinv(
true,
false, R);
72 template <
typename REAL>
73 void qr_Q(parmat_gpu<REAL> &x, gpuvec<REAL> &qraux, parmat_gpu<REAL> &Q)
75 gpumat<REAL> R, R_local;
77 qr_R(mpi::REDUCE_TO_ALL, x, R, R_local, qraux);
78 trinv(
true,
false, R);