5 #ifndef FML_GPU_LINALG_MATMULT_H
6 #define FML_GPU_LINALG_MATMULT_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../gpumat.hh"
18 #include "internals/err.hh"
/**
 * @brief Multiplies x and y (each optionally transposed) via gpublas::gemm.
 *
 * @param[in] transx If true, x is handed to gemm with GPUBLAS_OP_T,
 * otherwise with GPUBLAS_OP_N.
 * @param[in] transy Same as transx, but for y.
 * @param[in] alpha Scalar of the template type REAL. NOTE(review): its use is
 * not visible in this block; presumably the gemm scaling factor - confirm.
 *
 * NOTE(review): the declarations of x, y and ret are not visible here. Both
 * x and y are queried with nrows()/ncols(); ret is presumably the output
 * object - confirm against the full signature.
 */
42 template <
typename REAL>
43 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
// Validate the three objects together before any computation.
// NOTE(review): presumably checks that they share one card - confirm in err.hh.
46 err::check_card(x, y, ret);
// Row counts of the untransposed inputs.
48 const len_t mx = x.
nrows();
49 const len_t my = y.
nrows();
// Fills m, n and k through the pointer arguments from the transpose flags and
// the input dimensions. NOTE(review): the declarations of m, n, k are not
// visible here; presumably they are the gemm problem sizes - confirm.
52 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(),
53 my, y.
ncols(), &m, &n, &k);
// Map the boolean transpose flags onto the BLAS operation enum.
58 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
59 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
// The BLAS handle is taken from the card attached to x.
// NOTE(review): the remaining gemm arguments are not visible in this block.
61 gpublas_status_t check = gpublas::gemm(x.get_card()->blas_handle(),
// Hand the returned status to the gpublas error checker, labelled "gemm".
64 gpublas::err::check_ret(check,
"gemm");
// Templated on REAL. The visible statement forwards transx, transy, alpha,
// x, y and ret unchanged to the matmult overload that accepts `ret` as an
// argument.
// NOTE(review): the signature and the creation of `ret` are not visible in
// this block; presumably `ret` is built locally and returned - confirm.
70 template <
typename REAL>
75 matmult(transx, transy, alpha, x, y, ret);
/**
 * @brief Multiplies x and y (each optionally transposed) via gpublas::gemm,
 * where y is measured with size() rather than nrows()/ncols().
 *
 * @param[in] transx If true, x is handed to gemm with GPUBLAS_OP_T,
 * otherwise with GPUBLAS_OP_N.
 * @param[in] transy Same as transx, but for y.
 * @param[in] alpha Scalar of the template type REAL. NOTE(review): its use is
 * not visible in this block; presumably the gemm scaling factor - confirm.
 *
 * NOTE(review): the declarations of x, y and ret are not visible here. y and
 * ret are queried with size(), so they are presumably vector-like objects -
 * confirm against the full signature.
 */
83 template <
typename REAL>
84 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
// Validate the three objects together before any computation.
// NOTE(review): presumably checks that they share one card - confirm in err.hh.
87 err::check_card(x, y, ret);
// x contributes its row count; y contributes its total length.
89 const len_t mx = x.
nrows();
90 const len_t my = y.
size();
// NOTE(review): the trailing arguments of this call are not visible here; by
// analogy with the sibling overloads it presumably fills m, n and k - confirm.
93 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(),
// Card object of x; its BLAS handle is used for the gemm call below.
95 auto c = x.get_card();
// Compare the larger of m and n with the current length of ret.
// NOTE(review): the statement guarded by this `if` is not visible; presumably
// it resizes ret to `len` - confirm.
// NOTE(review): std::max needs <algorithm>; confirm it is included, directly
// or transitively.
96 int len = std::max(m, n);
97 if (len != ret.
size())
// Map the boolean transpose flags onto the BLAS operation enum.
100 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
101 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
// NOTE(review): the remaining gemm arguments are not visible in this block.
103 gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
// Hand the returned status to the gpublas error checker, labelled "gemm".
106 gpublas::err::check_ret(check,
"gemm");
// Templated on REAL. The visible statement forwards transx, transy, alpha,
// x, y and ret unchanged to the matmult overload that accepts `ret` as an
// argument.
// NOTE(review): the signature and the creation of `ret` are not visible in
// this block; presumably `ret` is built locally and returned - confirm.
112 template <
typename REAL>
117 matmult(transx, transy, alpha, x, y, ret);
/**
 * @brief Multiplies x and y (each optionally transposed) via gpublas::gemm,
 * where x is measured with size() and treated as having a single column.
 *
 * @param[in] transx If true, x is handed to gemm with GPUBLAS_OP_T,
 * otherwise with GPUBLAS_OP_N.
 * @param[in] transy Same as transx, but for y.
 * @param[in] alpha Scalar of the template type REAL. NOTE(review): its use is
 * not visible in this block; presumably the gemm scaling factor - confirm.
 *
 * NOTE(review): the declarations of x, y and ret are not visible here. x and
 * ret are queried with size(), so they are presumably vector-like objects -
 * confirm against the full signature.
 */
125 template <
typename REAL>
126 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
// Validate the three objects together before any computation.
// NOTE(review): presumably checks that they share one card - confirm in err.hh.
129 err::check_card(x, y, ret);
// x contributes its total length; y contributes its row count.
131 const len_t mx = x.
size();
132 const len_t my = y.
nrows();
// x is described to matmult_params as mx rows by 1 column. The call fills m,
// n and k through the pointer arguments. NOTE(review): the declarations of
// m, n, k are not visible here - confirm.
135 fml::linalgutils::matmult_params(transx, transy, mx, 1,
136 my, y.
ncols(), &m, &n, &k);
// Card object of x; its BLAS handle is used for the gemm call below.
137 auto c = x.get_card();
// Compare the larger of m and n with the current length of ret.
// NOTE(review): the statement guarded by this `if` is not visible; presumably
// it resizes ret to `len` - confirm.
// NOTE(review): std::max needs <algorithm>; confirm it is included, directly
// or transitively.
138 int len = std::max(m, n);
139 if (len != ret.
size())
// Map the boolean transpose flags onto the BLAS operation enum.
142 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
143 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
// NOTE(review): the remaining gemm arguments are not visible in this block.
145 gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
// Hand the returned status to the gpublas error checker, labelled "gemm".
148 gpublas::err::check_ret(check,
"gemm");
// Templated on REAL. The visible statement forwards transx, transy, alpha,
// x, y and ret unchanged to the matmult overload that accepts `ret` as an
// argument.
// NOTE(review): the signature and the creation of `ret` are not visible in
// this block; presumably `ret` is built locally and returned - confirm.
154 template <
typename REAL>
159 matmult(transx, transy, alpha, x, y, ret);