5 #ifndef FML_GPU_LINALG_LINALG_QR_H
6 #define FML_GPU_LINALG_LINALG_QR_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpu_utils.hh"
17 #include "../internals/gpuscalar.hh"
19 #include "../gpumat.hh"
20 #include "../gpuvec.hh"
22 #include "linalg_err.hh"
23 #include "linalg_blas.hh"
// Runs the GPU geqrf factorization on `x`: queries the workspace length with
// geqrf_buflen, calls geqrf using the LAPACK handle taken from x's card, and
// then checks both the returned status and the device-side `info` value.
//
// Parameters:
//   pivot - pivoting request; its only visible use is the throw below, whose
//           guarding condition is not shown in this listing.
//   x     - matrix handed to geqrf through x.data_ptr(), with leading
//           dimension m = x.nrows().
//   qraux - vector whose device pointer is handed to geqrf.
//   work  - workspace vector; its size is compared with the queried length
//           and its device pointer is handed to geqrf.
//
// Throws std::runtime_error("pivoting not supported at this time").
// NOTE(review): check_ret / check_info presumably raise on failure - confirm
// against linalg_err.hh and linalgutils.hh.
// NOTE(review): this listing is discontinuous (the embedded line numbers
// skip), so the declarations of `lwork` and `info`, the guard on the throw and
// the body of the size check are not visible here.
32 template <
typename REAL>
33 void qr_internals(
const bool pivot, gpumat<REAL> &x, gpuvec<REAL> &qraux, gpuvec<REAL> &work)
36 throw std::runtime_error(
"pivoting not supported at this time");
38 const len_t m = x.nrows();
39 const len_t n = x.ncols();
40 const len_t minmn = std::min(m, n);
41 auto c = x.get_card();
// Workspace query: the required length is written to `lwork`.
46 gpulapack_status_t check = gpulapack::geqrf_buflen(c->lapack_handle(), m,
47 n, x.data_ptr(), m, &lwork);
48 gpulapack::err::check_ret(check,
"geqrf_bufferSize");
50 if (lwork > work.size())
// Device-side scalar built from `info`; its pointer is passed to geqrf and
// the value is copied back into `info` after the call.
54 gpuscalar<int> info_device(c, info);
56 check = gpulapack::geqrf(c->lapack_handle(), m, n, x.data_ptr(), m,
57 qraux.data_ptr(), work.data_ptr(), lwork, info_device.data_ptr());
59 info_device.get_val(&info);
// `check` holds the status of the geqrf call above, so it is reported under
// that routine's name, matching the check_info label that follows.
60 gpulapack::err::check_ret(check,
"geqrf");
61 fml::linalgutils::check_info(info,
"geqrf");
// Fragment of a function template whose signature is not visible in this
// listing. The visible statements pass `x` and `qraux` to err::check_card and
// then forward `pivot`, `x`, `qraux` and `work` to qr_internals.
// NOTE(review): check_card presumably verifies that its arguments live on the
// same card - confirm against linalg_err.hh.
87 template <
typename REAL>
90 err::check_card(x, qraux);
// `work` is not declared in the visible lines of this fragment.
92 qr_internals(pivot, x, qraux, work);
// Fragment of a function template whose signature is not visible in this
// listing. The visible statements check the cards of `QR`, `qraux`, `Q` and
// `work`, query the orgqr workspace length, call orgqr on Q's data with
// dimensions (m, minmn, minmn) and leading dimension m, and then check the
// returned status and the device-side `info` value.
//
// NOTE(review): this listing is discontinuous (the embedded line numbers
// skip); the declarations of `lwork`, `info` and `info_device`, the body of
// the size check, and whatever prepares `Q` before the orgqr call are not
// visible here.
114 template <
typename REAL>
118 err::check_card(QR, qraux, Q, work);
120 const len_t m = QR.
nrows();
121 const len_t n = QR.
ncols();
122 const len_t minmn = std::min(m, n);
124 auto c = QR.get_card();
// Workspace query: the required length is written to `lwork`.
// NOTE(review): no check of this status is visible before `check` is
// reassigned by the orgqr call below - confirm whether one exists in the
// lines missing from this listing.
127 gpulapack_status_t check = gpulapack::orgqr_buflen(c->lapack_handle(),
128 m, minmn, minmn, QR.
data_ptr(), m, qraux.data_ptr(), &lwork);
130 if (lwork > work.size())
// The query above uses QR's pointer while this call uses Q's; presumably Q
// is filled from QR in lines not shown - TODO confirm.
141 check = gpulapack::orgqr(c->lapack_handle(), m, minmn, minmn, Q.
data_ptr(),
142 m, qraux.data_ptr(), work.data_ptr(), lwork, info_device.data_ptr());
144 info_device.get_val(&info);
// Both checks concern the orgqr call above, so they are labelled with that
// routine's name.
145 gpulapack::err::check_ret(check,
"orgqr");
146 fml::linalgutils::check_info(info,
"orgqr");
// Fragment of a function template whose signature and remaining body are not
// visible in this listing. The visible statements pass `QR` and `R` to
// err::check_card and read QR's row count, column count and their minimum.
166 template <
typename REAL>
169 err::check_card(QR, R);
171 const len_t m = QR.
nrows();
172 const len_t n = QR.
ncols();
// Smaller of the two dimensions; its use lies in lines not shown here.
173 const len_t minmn = std::min(m, n);
// Fragment of a function template whose signature is not visible in this
// listing. The visible statements pass `x` and `lqaux` to err::check_card and
// then call qr_internals with pivoting disabled (first argument `false`).
204 template <
typename REAL>
207 err::check_card(x, lqaux);
// `tx` and `work` are not declared in the visible lines.
// NOTE(review): `tx` is presumably a transposed copy of `x` - confirm.
213 qr_internals(
false, tx, lqaux, work);
// Fragment of a function template whose signature and remaining body are not
// visible in this listing. The visible statements pass `LQ` and `L` to
// err::check_card and read LQ's row count, column count and their minimum.
236 template <
typename REAL>
239 err::check_card(LQ, L);
241 const len_t m = LQ.
nrows();
242 const len_t n = LQ.
ncols();
// Smaller of the two dimensions; its use lies in lines not shown here.
243 const len_t minmn = std::min(m, n);
// Fragment of a function template whose signature is not visible in this
// listing. The visible statements check the cards of `LQ`, `lqaux`, `Q` and
// `work`, read LQ's dimensions, query the ormqr workspace length, call ormqr,
// and then check the returned status and the device-side `info` value.
//
// NOTE(review): this listing is discontinuous (the embedded line numbers
// skip); `QR`, `lwork`, `info` and `info_device` are declared in lines not
// shown, and both ormqr argument lists are only partly visible.
270 template <
typename REAL>
274 err::check_card(LQ, lqaux, Q, work);
279 const len_t m = LQ.
nrows();
280 const len_t n = LQ.
ncols();
281 const len_t minmn = std::min(m, n);
// NOTE(review): `QR` is presumably a transposed copy of `LQ` created in the
// lines missing above - confirm.
283 auto c = QR.get_card();
// Workspace query issued with GPUBLAS_SIDE_RIGHT and GPUBLAS_OP_T and the
// dimensions (minmn, n, n); the rest of the argument list is not shown.
286 gpulapack_status_t check = gpulapack::ormqr_buflen(c->lapack_handle(),
287 GPUBLAS_SIDE_RIGHT, GPUBLAS_OP_T, minmn, n, n, QR.
data_ptr(), QR.
nrows(),
290 if (lwork > work.size())
// The ormqr call passes Q's data with leading dimension minmn; the middle of
// its argument list is not shown in this listing.
299 check = gpulapack::ormqr(c->lapack_handle(), GPUBLAS_SIDE_RIGHT,
301 Q.
data_ptr(), minmn, work.data_ptr(), lwork, info_device.data_ptr());
// Copy the device-side status back into `info` before the checks below.
303 info_device.get_val(&info);
304 gpulapack::err::check_ret(check,
"ormqr");
305 fml::linalgutils::check_info(info,
"ormqr");