fml 0.1-0
Fused Matrix Library
solve.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_CPU_LINALG_LINALG_SOLVE_H
6 #define FML_CPU_LINALG_LINALG_SOLVE_H
7 #pragma once
8 
9 
10 #include <cmath>
11 #include <stdexcept>
12 
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
15 
16 #include "../cpumat.hh"
17 #include "../cpuvec.hh"
18 
19 #include "internals/lapack.hh"
20 
21 
22 namespace fml
23 {
24 namespace linalg
25 {
26  namespace
27  {
28  template <typename REAL>
29  void solver(cpumat<REAL> &x, len_t ylen, len_t nrhs, REAL *y_d)
30  {
31  const len_t n = x.nrows();
32  if (!x.is_square())
33  throw std::runtime_error("'x' must be a square matrix");
34  if (n != ylen)
35  throw std::runtime_error("rhs 'y' must be compatible with data matrix 'x'");
36 
37  int info;
38  cpuvec<int> p(n);
39  fml::lapack::gesv(n, nrhs, x.data_ptr(), n, p.data_ptr(), y_d, n, &info);
40  fml::linalgutils::check_info(info, "gesv");
41  }
42  }
43 
62  template <typename REAL>
64  {
65  solver(x, y.size(), 1, y.data_ptr());
66  }
67 
69  template <typename REAL>
71  {
72  solver(x, y.nrows(), y.ncols(), y.data_ptr());
73  }
74 }
75 }
76 
77 
78 #endif
fml::cpumat
Matrix class for data held on a single CPU.
Definition: cpumat.hh:36
fml::linalg::solve
void solve(cpumat< REAL > &x, cpuvec< REAL > &y)
Solve a square system of linear equations.
Definition: solve.hh:63
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26