5 #ifndef FML_GPU_LINALG_INVERT_H
6 #define FML_GPU_LINALG_INVERT_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpuscalar.hh"
17 #include "../internals/gpu_utils.hh"
20 #include "../gpumat.hh"
21 #include "../gpuvec.hh"
// Fragment of a routine templated on REAL that operates on a matrix `x`.
// Judging by the include guard (FML_GPU_LINALG_INVERT_H) and the "getrf" /
// "getrs" labels below, this is presumably an LU-based matrix inversion --
// TODO confirm against the full file.
// NOTE(review): the function signature, its braces, the getrf call and part of
// the getrs argument list are not visible in this listing. The leading integers
// on some lines look like residue from a numbered listing, and the gaps between
// them suggest missing lines. The comments below describe only what is visible.
48 template <
typename REAL>
// Error raised with a message saying 'x' must be square; the guarding
// condition itself is not visible here.
52 throw std::runtime_error(
"'x' must be a square matrix");
// Card object obtained from x; its lapack_handle() is passed to getrs below.
55 auto c = x.get_card();
// Validate the `info` value under the label "getrf". The call that produced
// `info` is not visible in this listing.
59 linalgutils::check_info(info,
"getrf");
// n is the row count of x.
62 const len_t n = x.
nrows();
// Call gpulapack::getrs using the card's LAPACK handle, GPUBLAS_OP_N and n;
// the remaining arguments are cut off in this listing. The returned status is
// kept in `check` and inspected below.
69 gpulapack_status_t check = gpulapack::getrs(c->lapack_handle(), GPUBLAS_OP_N, n,
// Read the value from info_device into `info` (presumably a device-side status
// flag copied back to the host -- TODO confirm).
72 info_device.get_val(&info);
// Two checks labelled "getrs": first the status returned by the call, then the
// `info` value just read back.
73 gpulapack::err::check_ret(check,
"getrs");
74 fml::linalgutils::check_info(info,
"getrs");
// Fragment of a second routine templated on REAL that operates on a matrix `x`
// and reads the flags `upper` and `unit_diag`. It issues a trsm call and then
// calls gpu_utils::tri2zero on x's data; presumably a triangular-matrix
// inversion -- TODO confirm against the full file.
// NOTE(review): the function signature, its braces and part of the trsm
// argument list are not visible in this listing. The leading integers on some
// lines look like residue from a numbered listing. The comments below describe
// only what is visible.
100 template <
typename REAL>
// Error raised with a message saying 'x' must be square; the guarding
// condition itself is not visible here.
104 throw std::runtime_error(
"'x' must be a square matrix");
// n is the row count of x.
106 const len_t n = x.
nrows();
// Map the boolean flags onto the BLAS enums: `upper` selects GPUBLAS_FILL_U vs
// GPUBLAS_FILL_L, `unit_diag` selects GPUBLAS_DIAG_UNIT vs GPUBLAS_DIAG_NON_UNIT.
110 gpublas_fillmode_t uplo = (upper ? GPUBLAS_FILL_U : GPUBLAS_FILL_L);
111 gpublas_diagtype_t diag = (unit_diag ? GPUBLAS_DIAG_UNIT : GPUBLAS_DIAG_NON_UNIT);
// Call gpublas::trsm with the BLAS handle of x's card, GPUBLAS_SIDE_LEFT, the
// enums chosen above, GPUBLAS_OP_N, dimensions n x n, a scalar of (REAL)1 and
// x's data pointer. The remaining arguments are cut off in this listing.
113 gpublas_status_t check = gpublas::trsm(x.get_card()->blas_handle(),
114 GPUBLAS_SIDE_LEFT, uplo, GPUBLAS_OP_N, diag, n, n, (REAL)1, x.
data_ptr(),
// Check the status returned by the trsm call, under the label "trsm".
117 gpublas::err::check_ret(check,
"trsm");
// Note the deliberate-looking swap: cuplo is 'L' when `upper` is true and 'U'
// otherwise, i.e. the letter for the triangle opposite to the one selected by
// `upper`.
120 char cuplo = (upper ?
'L' :
'U');
// Pass that opposite-triangle letter to gpu_utils::tri2zero together with
// `false`, the n x n dimensions, x's data pointer and n.
// NOTE(review): presumably this zeroes the unused triangle of x while leaving
// the diagonal alone -- confirm against tri2zero's definition.
121 gpu_utils::tri2zero(cuplo,
false, n, n, x.
data_ptr(), n);