5 #ifndef FML_MPI_LINALG_LINALG_INVERT_H
6 #define FML_MPI_LINALG_LINALG_INVERT_H
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
18 #include "../mpimat.hh"
20 #include "linalg_err.hh"
21 #include "linalg_lu.hh"
22 #include "scalapack.hh"
48 template <
typename REAL>
52 throw std::runtime_error(
"'x' must be a square matrix");
58 linalgutils::check_info(info,
"getrf");
61 const len_t n = x.
nrows();
64 scalapack::getri(n, x.
data_ptr(), x.desc_ptr(), p.
data_ptr(), &tmp, -1, &liwork, -1, &info);
65 int lwork = std::max(1, (
int)tmp);
69 scalapack::getri(n, x.
data_ptr(), x.desc_ptr(), p.
data_ptr(), work.data_ptr(), lwork, iwork.
data_ptr(), liwork, &info);
70 linalgutils::check_info(info,
"getri");
91 template <
typename REAL>
95 throw std::runtime_error(
"'x' must be a square matrix");
97 const len_t n = x.
nrows();
100 char uplo = (upper ?
'U' :
'L');
101 char diag = (unit_diag ?
'U' :
'N');
102 scalapack::trtri(uplo, diag, n, x.
data_ptr(), x.desc_ptr(), &info);
103 linalgutils::check_info(info,
"trtri");
105 uplo = (uplo ==
'U' ?
'L' :
'U');
106 mpi_utils::tri2zero(uplo,
false, x.get_grid(), n, n, x.
data_ptr(), x.desc_ptr());