5 #ifndef FML_CPU_LINALG_LINALG_INVERT_H
6 #define FML_CPU_LINALG_LINALG_INVERT_H
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
16 #include "../internals/cpu_utils.hh"
18 #include "../cpumat.hh"
19 #include "../cpuvec.hh"
21 #include "internals/lapack.hh"
45 template <
typename REAL>
48 const len_t n = x.
nrows();
50 throw std::runtime_error(
"'x' must be a square matrix");
56 linalgutils::check_info(info,
"getrf");
61 int lwork = (int) tmp;
64 lapack::getri(n, x.
data_ptr(), n, p.
data_ptr(), work.data_ptr(), lwork, &info);
65 linalgutils::check_info(info,
"getri");
86 template <
typename REAL>
90 throw std::runtime_error(
"'x' must be a square matrix");
92 const len_t n = x.
nrows();
95 char uplo = (upper ?
'U' :
'L');
96 char diag = (unit_diag ?
'U' :
'N');
98 linalgutils::check_info(info,
"trtri");
100 uplo = (uplo ==
'U' ?
'L' :
'U');
101 cpu_utils::tri2zero(uplo,
false, n, n, x.
data_ptr(), n);