5 #ifndef FML_MPI_LINALG_LINALG_BLAS_H
6 #define FML_MPI_LINALG_LINALG_BLAS_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../mpimat.hh"
16 #include "linalg_err.hh"
42 template <
typename REAL>
45 err::check_grid(x, y, ret);
48 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(), y.
nrows(), y.
ncols(), &m, &n);
53 char ctransx = transx ?
'T' :
'N';
54 char ctransy = transy ?
'T' :
'N';
56 fml::pblas::geadd(ctransy, m, n, beta, y.
data_ptr(), y.desc_ptr(), (REAL) 0.0f, ret.
data_ptr(), ret.desc_ptr());
57 fml::pblas::geadd(ctransx, m, n, alpha, x.
data_ptr(), x.desc_ptr(), (REAL) 1.0f, ret.
data_ptr(), ret.desc_ptr());
61 template <
typename REAL>
64 err::check_grid(x, y);
67 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(), y.
nrows(), y.
ncols(), &m, &n);
69 const grid g = x.get_grid();
71 add(transx, transy, alpha, beta, x, y, ret);
96 template <
typename REAL>
99 err::check_grid(x, y);
102 fml::linalgutils::matmult_params(transx, transy, x.
nrows(), x.
ncols(), y.
nrows(), y.
ncols(), &m, &n, &k);
104 const grid g = x.get_grid();
107 const char ctransx = transx ?
'T' :
'N';
108 const char ctransy = transy ?
'T' :
'N';
110 fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
112 (REAL)0, ret.
data_ptr(), ret.desc_ptr());
137 template <
typename REAL>
140 err::check_grid(x, y, ret);
143 fml::linalgutils::matmult_params(transx, transy, x.
nrows(), x.
ncols(), y.
nrows(), y.
ncols(), &m, &n, &k);
148 const char ctransx = transx ?
'T' :
'N';
149 const char ctransy = transy ?
'T' :
'N';
151 fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
153 (REAL)0, ret.
data_ptr(), ret.desc_ptr());
177 template <
typename REAL>
180 err::check_grid(x, ret);
182 const len_t n = x.
ncols();
188 fml::pblas::syrk(
'L',
'T', n, x.
nrows(), alpha, x.
data_ptr(), x.desc_ptr(), (REAL) 0, ret.
data_ptr(), ret.desc_ptr());
192 template <
typename REAL>
195 const len_t n = x.
ncols();
196 const grid g = x.get_grid();
225 template <
typename REAL>
228 err::check_grid(x, ret);
230 const len_t m = x.
nrows();
236 fml::pblas::syrk(
'L',
'N', m, x.
ncols(), alpha, x.
data_ptr(), x.desc_ptr(), (REAL) 0, ret.
data_ptr(), ret.desc_ptr());
240 template <
typename REAL>
243 const len_t n = x.
nrows();
244 const grid g = x.get_grid();
272 template <
typename REAL>
275 err::check_grid(x, tx);
277 const len_t m = x.
nrows();
278 const len_t n = x.
ncols();
283 fml::pblas::tran(n, m, 1.f, x.
data_ptr(), x.desc_ptr(), 0.f, tx.
data_ptr(), tx.desc_ptr());
287 template <
typename REAL>
290 const len_t m = x.
nrows();
291 const len_t n = x.
ncols();
292 const grid g = x.get_grid();