fml  0.1-0
Fused Matrix Library
invert.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_INVERT_H
6 #define FML_GPU_LINALG_INVERT_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../arch/arch.hh"
15 
16 #include "../internals/gpuscalar.hh"
17 #include "../internals/gpu_utils.hh"
18 
19 #include "../copy.hh"
20 #include "../gpumat.hh"
21 #include "../gpuvec.hh"
22 
23 #include "lu.hh"
24 
25 
26 namespace fml
27 {
28 namespace linalg
29 {
48  template <typename REAL>
50  {
51  if (!x.is_square())
52  throw std::runtime_error("'x' must be a square matrix");
53 
54  // Factor x = LU
55  auto c = x.get_card();
56  gpuvec<int> p(c);
57  int info;
58  lu(x, p, info);
59  linalgutils::check_info(info, "getrf");
60 
61  // Invert
62  const len_t n = x.nrows();
63  const len_t nrhs = n;
64  gpumat<REAL> inv(c, n, nrhs);
65  inv.fill_eye();
66 
67  gpuscalar<int> info_device(c, info);
68 
69  gpulapack_status_t check = gpulapack::getrs(c->lapack_handle(), GPUBLAS_OP_N, n,
70  nrhs, x.data_ptr(), n, p.data_ptr(), inv.data_ptr(), n, info_device.data_ptr());
71 
72  info_device.get_val(&info);
73  gpulapack::err::check_ret(check, "getrs");
74  fml::linalgutils::check_info(info, "getrs");
75 
76  copy::gpu2gpu(inv, x);
77  }
78 
79 
80 
100  template <typename REAL>
101  void trinv(const bool upper, const bool unit_diag, gpumat<REAL> &x)
102  {
103  if (!x.is_square())
104  throw std::runtime_error("'x' must be a square matrix");
105 
106  const len_t n = x.nrows();
107  gpumat<REAL> inv(x.get_card(), n, n);
108  inv.fill_eye();
109 
110  gpublas_fillmode_t uplo = (upper ? GPUBLAS_FILL_U : GPUBLAS_FILL_L);
111  gpublas_diagtype_t diag = (unit_diag ? GPUBLAS_DIAG_UNIT : GPUBLAS_DIAG_NON_UNIT);
112 
113  gpublas_status_t check = gpublas::trsm(x.get_card()->blas_handle(),
114  GPUBLAS_SIDE_LEFT, uplo, GPUBLAS_OP_N, diag, n, n, (REAL)1, x.data_ptr(),
115  n, inv.data_ptr(), n);
116 
117  gpublas::err::check_ret(check, "trsm");
118  copy::gpu2gpu(inv, x);
119 
120  char cuplo = (upper ? 'L' : 'U');
121  gpu_utils::tri2zero(cuplo, false, n, n, x.data_ptr(), n);
122  }
123 }
124 }
125 
126 
127 #endif
fml::unimat::is_square
bool is_square() const
Is the matrix square?
Definition: unimat.hh:34
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::gpuvec
Vector class for data held on a single GPU.
Definition: gpuvec.hh:32
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::linalg::lu
void lu(cpumat< REAL > &x, cpuvec< int > &p, int &info)
Computes the PLU factorization with partial pivoting.
Definition: lu.hh:48
fml::copy::gpu2gpu
void gpu2gpu(const gpuvec< REAL_IN > &gpu_in, gpuvec< REAL_OUT > &gpu_out)
Copy data from a GPU object to another.
Definition: copy.hh:203
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml::gpumat::fill_eye
void fill_eye()
Set diagonal entries to 1 and non-diagonal entries to 0.
Definition: gpumat.hh:455
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::trinv
void trinv(const bool upper, const bool unit_diag, cpumat< REAL > &x)
Compute the matrix inverse of a triangular matrix.
Definition: invert.hh:87
fml::linalg::invert
void invert(cpumat< REAL > &x)
Compute the matrix inverse.
Definition: invert.hh:46
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35
fml::gpuscalar
Definition: gpuscalar.hh:16