fml  0.1-0
Fused Matrix Library
solve.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_SOLVE_H
6 #define FML_GPU_LINALG_SOLVE_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../arch/arch.hh"
15 
16 #include "../internals/gpuscalar.hh"
17 
18 #include "../gpumat.hh"
19 #include "../gpuvec.hh"
20 
21 #include "internals/err.hh"
22 #include "lu.hh"
23 
24 
25 namespace fml
26 {
27 namespace linalg
28 {
29  namespace
30  {
31  template <typename REAL>
32  void solver(gpumat<REAL> &x, len_t ylen, len_t nrhs, REAL *y_d)
33  {
34  const len_t n = x.nrows();
35  if (!x.is_square())
36  throw std::runtime_error("'x' must be a square matrix");
37  if (n != ylen)
38  throw std::runtime_error("rhs 'y' must be compatible with data matrix 'x'");
39 
40  // Factor x = LU
41  auto c = x.get_card();
42  gpuvec<int> p(c);
43  int info;
44  lu(x, p, info);
45  fml::linalgutils::check_info(info, "getrf");
46 
47  // Solve xb = y
48  gpuscalar<int> info_device(c, info);
49 
50  gpulapack_status_t check = gpulapack::getrs(c->lapack_handle(), GPUBLAS_OP_N,
51  n, nrhs, x.data_ptr(), n, p.data_ptr(), y_d, n, info_device.data_ptr());
52 
53  info_device.get_val(&info);
54  gpulapack::err::check_ret(check, "getrs");
55  fml::linalgutils::check_info(info, "getrs");
56  }
57  }
58 
78  template <typename REAL>
80  {
81  err::check_card(x, y);
82  solver(x, y.size(), 1, y.data_ptr());
83  }
84 
86  template <typename REAL>
88  {
89  err::check_card(x, y);
90  solver(x, y.nrows(), y.ncols(), y.data_ptr());
91  }
92 }
93 }
94 
95 
96 #endif
fml::linalg::solve
void solve(cpumat< REAL > &x, cpuvec< REAL > &y)
Solve a system of equations.
Definition: solve.hh:63
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::gpuvec
Vector class for data held on a single GPU.
Definition: gpuvec.hh:32
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::linalg::lu
void lu(cpumat< REAL > &x, cpuvec< int > &p, int &info)
Computes the PLU factorization with partial pivoting.
Definition: lu.hh:48
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35