fml  0.1-0
Fused Matrix Library
matmult.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_MATMULT_H
6 #define FML_GPU_LINALG_MATMULT_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../arch/arch.hh"
15 
16 #include "../gpumat.hh"
17 
18 #include "internals/err.hh"
19 
20 
21 namespace fml
22 {
23 namespace linalg
24 {
42  template <typename REAL>
43  void matmult(const bool transx, const bool transy, const REAL alpha,
44  const gpumat<REAL> &x, const gpumat<REAL> &y, gpumat<REAL> &ret)
45  {
46  err::check_card(x, y, ret);
47 
48  const len_t mx = x.nrows();
49  const len_t my = y.nrows();
50 
51  int m, n, k;
52  fml::linalgutils::matmult_params(transx, transy, mx, x.ncols(),
53  my, y.ncols(), &m, &n, &k);
54 
55  if (m != ret.nrows() || n != ret.ncols())
56  ret.resize(m, n);
57 
58  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
59  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
60 
61  gpublas_status_t check = gpublas::gemm(x.get_card()->blas_handle(),
62  cbtransx, cbtransy, m, n, k, alpha, x.data_ptr(), mx, y.data_ptr(),
63  my, (REAL)0, ret.data_ptr(), m);
64  gpublas::err::check_ret(check, "gemm");
65  }
66 
67 
68 
70  template <typename REAL>
71  gpumat<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
72  const gpumat<REAL> &x, const gpumat<REAL> &y)
73  {
74  gpumat<REAL> ret(x.get_card());
75  matmult(transx, transy, alpha, x, y, ret);
76 
77  return ret;
78  }
79 
80 
81 
83  template <typename REAL>
84  void matmult(const bool transx, const bool transy, const REAL alpha,
85  const gpumat<REAL> &x, const gpuvec<REAL> &y, gpuvec<REAL> &ret)
86  {
87  err::check_card(x, y, ret);
88 
89  const len_t mx = x.nrows();
90  const len_t my = y.size();
91 
92  int m, n, k;
93  fml::linalgutils::matmult_params(transx, transy, mx, x.ncols(),
94  my, 1, &m, &n, &k);
95  auto c = x.get_card();
96  int len = std::max(m, n);
97  if (len != ret.size())
98  ret.resize(len);
99 
100  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
101  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
102 
103  gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
104  m, n, k, alpha, x.data_ptr(), mx, y.data_ptr(), my, (REAL)0,
105  ret.data_ptr(), m);
106  gpublas::err::check_ret(check, "gemm");
107  }
108 
109 
110 
112  template <typename REAL>
113  gpuvec<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
114  const gpumat<REAL> &x, const gpuvec<REAL> &y)
115  {
116  gpuvec<REAL> ret(x.get_card());
117  matmult(transx, transy, alpha, x, y, ret);
118 
119  return ret;
120  }
121 
122 
123 
125  template <typename REAL>
126  void matmult(const bool transx, const bool transy, const REAL alpha,
127  const gpuvec<REAL> &x, const gpumat<REAL> &y, gpuvec<REAL> &ret)
128  {
129  err::check_card(x, y, ret);
130 
131  const len_t mx = x.size();
132  const len_t my = y.nrows();
133 
134  int m, n, k;
135  fml::linalgutils::matmult_params(transx, transy, mx, 1,
136  my, y.ncols(), &m, &n, &k);
137  auto c = x.get_card();
138  int len = std::max(m, n);
139  if (len != ret.size())
140  ret.resize(len);
141 
142  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
143  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
144 
145  gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
146  m, n, k, alpha, x.data_ptr(), mx, y.data_ptr(), my, (REAL)0,
147  ret.data_ptr(), m);
148  gpublas::err::check_ret(check, "gemm");
149  }
150 
151 
152 
154  template <typename REAL>
155  gpuvec<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
156  const gpuvec<REAL> &x, const gpumat<REAL> &y)
157  {
158  gpuvec<REAL> ret(x.get_card());
159  matmult(transx, transy, alpha, x, y, ret);
160 
161  return ret;
162  }
163 }
164 }
165 
166 
167 #endif
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::gpuvec
Vector class for data held on a single GPU.
Definition: gpuvec.hh:32
fml::gpuvec::resize
void resize(len_t size)
Resize the internal object storage.
Definition: gpuvec.hh:225
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26
fml::gpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: gpumat.hh:256
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43