fml  0.1-0
Fused Matrix Library
add.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_ADD_H
6 #define FML_GPU_LINALG_ADD_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../arch/arch.hh"
15 
16 #include "../gpumat.hh"
17 
18 #include "internals/err.hh"
19 
20 
21 namespace fml
22 {
23 namespace linalg
24 {
41  template <typename REAL>
42  void add(const bool transx, const bool transy, const REAL alpha,
43  const REAL beta, const gpumat<REAL> &x, const gpumat<REAL> &y,
44  gpumat<REAL> &ret)
45  {
46  err::check_card(x, y, ret);
47 
48  len_t m, n;
49  fml::linalgutils::matadd_params(transx, transy, x.nrows(), x.ncols(),
50  y.nrows(), y.ncols(), &m, &n);
51 
52  if (ret.nrows() != m || ret.ncols() != n)
53  ret.resize(m, n);
54 
55  auto c = x.get_card();
56  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
57  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
58 
59  gpublas_status_t check = gpublas::geam(c->blas_handle(), cbtransx, cbtransy,
60  m, n, alpha, x.data_ptr(), x.nrows(), beta, y.data_ptr(), y.nrows(),
61  ret.data_ptr(), m);
62  gpublas::err::check_ret(check, "geam");
63  }
64 
65 
66 
68  template <typename REAL>
69  gpumat<REAL> add(const bool transx, const bool transy, const REAL alpha,
70  const REAL beta, const gpumat<REAL> &x, const gpumat<REAL> &y)
71  {
72  err::check_card(x, y);
73 
74  len_t m, n;
75  fml::linalgutils::matadd_params(transx, transy, x.nrows(), x.ncols(),
76  y.nrows(), y.ncols(), &m, &n);
77 
78  auto c = x.get_card();
79  gpumat<REAL> ret(c, m, n);
80  add(transx, transy, alpha, beta, x, y, ret);
81  return ret;
82  }
83 }
84 }
85 
86 
87 #endif
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::gpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: gpumat.hh:256
fml::linalg::add
void add(const bool transx, const bool transy, const REAL alpha, const REAL beta, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Returns alpha*op(x) + beta*op(y) where op(A) is A or A^T.
Definition: add.hh:35
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35