5 #ifndef FML_GPU_LINALG_SOLVE_H
6 #define FML_GPU_LINALG_SOLVE_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpuscalar.hh"
18 #include "../gpumat.hh"
19 #include "../gpuvec.hh"
21 #include "internals/err.hh"
// Internal helper that validates a system and then calls gpulapack::getrs
// on the matrix `x` and the right-hand-side buffer `y_d`.
//
// NOTE(review): this excerpt omits lines of the original definition (the
// function braces, the conditions guarding the two throws, and the
// declarations of `info` and `p`). The comments below describe only what the
// visible lines show.
//
// Parameters:
//   x     - matrix whose data_ptr() is handed to getrs; its row count
//           becomes `n`.
//   ylen  - length of the right-hand side; not used on any visible line
//           (presumably compared against `n` before the second throw —
//           TODO confirm).
//   nrhs  - number of right-hand sides, forwarded to getrs.
//   y_d   - right-hand-side buffer forwarded to getrs (presumably a device
//           pointer, given the `_d` suffix — TODO confirm).
//
// Throws std::runtime_error with the two messages below; the triggering
// conditions are not visible here.
31 template <
typename REAL>
32 void solver(gpumat<REAL> &x, len_t ylen, len_t nrhs, REAL *y_d)
// `n` is the row count of x; it is reused below as the size argument to getrs
// and again after x.data_ptr() and after y_d.
34 const len_t n = x.nrows();
// Error raised for a non-square `x` (the guarding condition is not shown).
36 throw std::runtime_error(
"'x' must be a square matrix");
// Error raised when `y` is incompatible with `x` (the guarding condition is
// not shown).
38 throw std::runtime_error(
"rhs 'y' must be compatible with data matrix 'x'");
// Card object obtained from x; used to build info_device and to fetch the
// lapack handle.
41 auto c = x.get_card();
// Checks `info` under the "getrf" label. NOTE(review): presumably `info` was
// set by an LU factorization step on a line not visible here — TODO confirm.
45 fml::linalgutils::check_info(info,
"getrf");
// Scalar wrapper constructed from the card and the current `info`; its
// data_ptr() is passed as the last argument to getrs.
48 gpuscalar<int> info_device(c, info);
// Call getrs with GPUBLAS_OP_N. `n` follows both x.data_ptr() and y_d
// (presumably the leading dimensions — confirm against gpulapack::getrs).
// `p` is declared on a line not shown (presumably the pivot vector).
50 gpulapack_status_t check = gpulapack::getrs(c->lapack_handle(), GPUBLAS_OP_N,
51 n, nrhs, x.data_ptr(), n, p.data_ptr(), y_d, n, info_device.data_ptr());
// Copy the value held by info_device back into `info` before running checks.
53 info_device.get_val(&info);
// First check the status returned by the getrs call itself, then the `info`
// code, both reported under the "getrs" label.
54 gpulapack::err::check_ret(check,
"getrs");
55 fml::linalgutils::check_info(info,
"getrs");
// NOTE(review): the signature and most of the body of this template are not
// visible in this excerpt; only the template header and one call are shown.
78 template <
typename REAL>
// Runs err::check_card on `x` and `y`. Presumably this verifies that both
// objects belong to the same card — TODO confirm against internals/err.hh.
81 err::check_card(x, y);
// NOTE(review): the signature and most of the body of this template are not
// visible in this excerpt; only the template header and one call are shown.
86 template <
typename REAL>
// Runs err::check_card on `x` and `y`. Presumably this verifies that both
// objects belong to the same card — TODO confirm against internals/err.hh.
89 err::check_card(x, y);