5 #ifndef FML_CPU_LINALG_MATMULT_H
6 #define FML_CPU_LINALG_MATMULT_H
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
16 #include "../cpumat.hh"
18 #include "internals/blas.hh"
42 template <
typename REAL>
43 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
47 const len_t mx = x.
nrows();
48 const len_t my = y.
nrows();
50 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(), my,
51 y.
ncols(), &m, &n, &k);
56 const char ctransx = transx ?
'T' :
'N';
57 const char ctransy = transy ?
'T' :
'N';
59 fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
67 template <
typename REAL>
72 matmult(transx, transy, alpha, x, y, ret);
80 template <
typename REAL>
81 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
85 const len_t mx = x.
nrows();
86 const len_t my = y.
size();
88 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(), my,
91 int len = std::max(m, n);
92 if (len != ret.
size())
95 const char ctransx = transx ?
'T' :
'N';
96 const char ctransy = transy ?
'T' :
'N';
98 fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
106 template <
typename REAL>
111 matmult(transx, transy, alpha, x, y, ret);
119 template <
typename REAL>
120 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
124 const len_t mx = x.
size();
125 const len_t my = y.
nrows();
127 fml::linalgutils::matmult_params(transx, transy, mx, 1, my,
128 y.
ncols(), &m, &n, &k);
130 int len = std::max(m, n);
131 if (len != ret.
size())
134 const char ctransx = transx ?
'T' :
'N';
135 const char ctransy = transy ?
'T' :
'N';
137 fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
145 template <
typename REAL>
150 matmult(transx, transy, alpha, x, y, ret);