5 #ifndef FML_MPI_LINALG_MATMULT_H
6 #define FML_MPI_LINALG_MATMULT_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../mpimat.hh"
16 #include "internals/err.hh"
17 #include "internals/pblas.hh"
// Distributed matrix product computed with PBLAS gemm and stored in `ret`.
// The transx / transy flags select whether each operand is transposed.
// NOTE(review): the signature (original lines 44-45), lines 52-53 and 58,
// and the closing brace are not visible in this view. This summary is
// assembled from the visible statements only. Presumably the result is
// ret = alpha * op(x) * op(y); confirm against the full source.
// @tparam REAL  scalar type of the matrices. It is also used below to form
//               the zero coefficient passed to gemm.
43 template <
typename REAL>
// Validate the two operands before computing anything.
// NOTE(review): presumably checks that x and y live on the same process
// grid; confirm in internals/err.hh.
46 err::check_grid(x, y);
// Obtain the gemm dimensions m, n, k (written through the out-pointers)
// from the transpose flags and the shapes of x and y.
// NOTE(review): presumably also rejects non-conformable shapes; confirm in
// _internals/linalgutils.hh.
49 fml::linalgutils::matmult_params(transx, transy, x.
nrows(), x.
ncols(), y.
nrows(), y.
ncols(), &m, &n, &k);
// Process grid that x is distributed on.
// NOTE(review): its use falls in original lines 52-53, which are not
// visible here. Presumably it is used to set up `ret` on the same grid.
51 const grid g = x.get_grid();
// Map the boolean transpose flags to the character codes gemm expects:
// 'T' = transpose, 'N' = no transpose.
54 const char ctransx = transx ?
'T' :
'N';
55 const char ctransy = transy ?
'T' :
'N';
// Run the distributed gemm. In the standard p?gemm argument order, the
// (REAL)0 argument just before the output pointer/descriptor is beta, so
// ret's prior contents are overwritten rather than accumulated.
// NOTE(review): original line 58 (the operand arguments) is not visible.
57 fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
59 (REAL)0, ret.
data_ptr(), ret.desc_ptr());
// Distributed matrix product computed with PBLAS gemm into a caller-supplied
// `ret` that takes part in the grid check below.
// The transx / transy flags select whether each operand is transposed.
// NOTE(review): the signature (original lines 87-88), lines 93-96 and 101,
// and the closing brace are not visible in this view. This summary is
// assembled from the visible statements only. Presumably the result is
// ret = alpha * op(x) * op(y); confirm against the full source.
// @tparam REAL  scalar type of the matrices. It is also used below to form
//               the zero coefficient passed to gemm.
86 template <
typename REAL>
// Validate the operands and the output before computing anything.
// NOTE(review): presumably checks that x, y and ret all live on the same
// process grid; confirm in internals/err.hh.
89 err::check_grid(x, y, ret);
// Obtain the gemm dimensions m, n, k (written through the out-pointers)
// from the transpose flags and the shapes of x and y.
// NOTE(review): presumably also rejects non-conformable shapes; confirm in
// _internals/linalgutils.hh.
92 fml::linalgutils::matmult_params(transx, transy, x.
nrows(), x.
ncols(), y.
nrows(), y.
ncols(), &m, &n, &k);
// NOTE(review): original lines 93-96 are not visible here. Presumably they
// check or size `ret` against m and n; confirm against the full source.
// Map the boolean transpose flags to the character codes gemm expects:
// 'T' = transpose, 'N' = no transpose.
97 const char ctransx = transx ?
'T' :
'N';
98 const char ctransy = transy ?
'T' :
'N';
// Run the distributed gemm. In the standard p?gemm argument order, the
// (REAL)0 argument just before the output pointer/descriptor is beta, so
// ret's prior contents are overwritten rather than accumulated.
// NOTE(review): original line 101 (the operand arguments) is not visible.
100 fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
102 (REAL)0, ret.
data_ptr(), ret.desc_ptr());