5 #ifndef FML_MPI_LINALG_INVERT_H
6 #define FML_MPI_LINALG_INVERT_H
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/mpi_utils.hh"
17 #include "../mpimat.hh"
19 #include "internals/scalapack.hh"
// General matrix inversion routine, templated on the scalar type REAL.
// NOTE(review): the function signature, its braces and the declarations of
// `info`, `p`, `tmp`, `liwork`, `work` and `iwork` are not visible in this
// excerpt -- confirm them against the full file before relying on these notes.
46 template <
typename REAL>
// Raised when 'x' is not square; the guarding condition itself is not visible
// here -- TODO confirm.
50 throw std::runtime_error(
"'x' must be a square matrix");
// Validate the status code reported under the "getrf" label. Presumably this
// follows an LU factorisation of x that filled the pivot vector `p` -- the
// call that sets `info` is not visible here.
56 linalgutils::check_info(info,
"getrf");
// Order of the matrix, taken from its row count.
59 const len_t n = x.
nrows();
// First getri call passes -1 for both workspace lengths. Per the ScaLAPACK
// p?getri convention this is a workspace-size query: the required sizes come
// back through `tmp` and `liwork`.
62 scalapack::getri(n, x.
data_ptr(), x.desc_ptr(), p.
data_ptr(), &tmp, -1, &liwork, -1, &info);
// Convert the reported size to an int, clamped to at least 1.
63 int lwork = std::max(1, (
int)tmp);
// Second getri call performs the actual computation using the `work` and
// `iwork` buffers (presumably sized to lwork / liwork between the two calls --
// their allocation is not visible here).
67 scalapack::getri(n, x.
data_ptr(), x.desc_ptr(), p.
data_ptr(), work.data_ptr(), lwork, iwork.
data_ptr(), liwork, &info);
// Validate the status code of the inversion step under the "getri" label.
68 linalgutils::check_info(info,
"getri");
// Triangular matrix inversion routine, templated on the scalar type REAL.
// NOTE(review): the function signature, its braces and the declarations of
// `upper`, `unit_diag` and `info` are not visible in this excerpt -- confirm
// them against the full file before relying on these notes.
89 template <
typename REAL>
// Raised when 'x' is not square; the guarding condition itself is not visible
// here -- TODO confirm.
93 throw std::runtime_error(
"'x' must be a square matrix");
// Order of the matrix, taken from its row count.
95 const len_t n = x.
nrows();
// Translate the boolean `upper` into the character flag passed to trtri:
// 'U' when true, 'L' otherwise.
98 char uplo = (upper ?
'U' :
'L');
// Translate the boolean `unit_diag` into the character flag passed to trtri:
// 'U' when true, 'N' otherwise.
99 char diag = (unit_diag ?
'U' :
'N');
// Run the triangular inversion on x's data using x's descriptor.
100 scalapack::trtri(uplo, diag, n, x.
data_ptr(), x.desc_ptr(), &info);
// Validate the status code of the inversion under the "trtri" label.
101 linalgutils::check_info(info,
"trtri");
// Flip the triangle selector so the next call targets the opposite triangle
// from the one that was just processed.
103 uplo = (uplo ==
'U' ?
'L' :
'U');
// Presumably zeroes out that opposite triangle of the n-by-n matrix; the
// meaning of the `false` argument (likely "leave the diagonal alone") is not
// visible here -- TODO confirm against mpi_utils::tri2zero.
104 mpi_utils::tri2zero(uplo,
false, x.get_grid(), n, n, x.
data_ptr(), x.desc_ptr());