fml  0.1-0
Fused Matrix Library
linalg_invert.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_LINALG_INVERT_H
6 #define FML_MPI_LINALG_LINALG_INVERT_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
14 
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
17 
18 #include "../mpimat.hh"
19 
20 #include "linalg_err.hh"
21 #include "linalg_lu.hh"
22 #include "scalapack.hh"
23 
24 
25 namespace fml
26 {
27 namespace linalg
28 {
48  template <typename REAL>
50  {
51  if (!x.is_square())
52  throw std::runtime_error("'x' must be a square matrix");
53 
54  // Factor x = LU
55  cpuvec<int> p;
56  int info;
57  lu(x, p, info);
58  linalgutils::check_info(info, "getrf");
59 
60  // Invert
61  const len_t n = x.nrows();
62  REAL tmp;
63  int liwork;
64  scalapack::getri(n, x.data_ptr(), x.desc_ptr(), p.data_ptr(), &tmp, -1, &liwork, -1, &info);
65  int lwork = std::max(1, (int)tmp);
66  cpuvec<REAL> work(lwork);
67  cpuvec<int> iwork(liwork);
68 
69  scalapack::getri(n, x.data_ptr(), x.desc_ptr(), p.data_ptr(), work.data_ptr(), lwork, iwork.data_ptr(), liwork, &info);
70  linalgutils::check_info(info, "getri");
71  }
72 
73 
74 
91  template <typename REAL>
92  void trinv(const bool upper, const bool unit_diag, mpimat<REAL> &x)
93  {
94  if (!x.is_square())
95  throw std::runtime_error("'x' must be a square matrix");
96 
97  const len_t n = x.nrows();
98 
99  int info;
100  char uplo = (upper ? 'U' : 'L');
101  char diag = (unit_diag ? 'U' : 'N');
102  scalapack::trtri(uplo, diag, n, x.data_ptr(), x.desc_ptr(), &info);
103  linalgutils::check_info(info, "trtri");
104 
105  uplo = (uplo == 'U' ? 'L' : 'U');
106  mpi_utils::tri2zero(uplo, false, x.get_grid(), n, n, x.data_ptr(), x.desc_ptr());
107  }
108 }
109 }
110 
111 
112 #endif
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::unimat::is_square
bool is_square() const
Is the matrix square?
Definition: unimat.hh:34
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::linalg::lu
void lu(cpumat< REAL > &x, cpuvec< int > &p, int &info)
Computes the PLU factorization with partial pivoting.
Definition: linalg_lu.hh:48
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::trinv
void trinv(const bool upper, const bool unit_diag, cpumat< REAL > &x)
Compute the matrix inverse of a triangular matrix.
Definition: linalg_invert.hh:87
fml::linalg::invert
void invert(cpumat< REAL > &x)
Compute the matrix inverse.
Definition: linalg_invert.hh:46