5 #ifndef FML_GPU_LINALG_LINALG_BLAS_H
6 #define FML_GPU_LINALG_LINALG_BLAS_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../gpumat.hh"
18 #include "linalg_err.hh"
38 template <
typename REAL>
41 err::check_card(x, y);
42 const len_t len = std::min(x.
size(), y.
size());
45 fml::linalgutils::matmult_params(
true,
false, len, 1, len, 1, &m, &n, &k);
49 gpublas_status_t check = gpublas::gemm(x.get_card()->blas_handle(),
50 GPUBLAS_OP_T, GPUBLAS_OP_N, m, n, k, (REAL)1, x.
data_ptr(), len,
51 y.
data_ptr(), len, (REAL)0, d_device.data_ptr(), 1);
52 gpublas::err::check_ret(check,
"gemm");
59 template <
typename REAL>
83 template <
typename REAL>
84 void add(
const bool transx,
const bool transy,
const REAL alpha,
88 err::check_card(x, y, ret);
91 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(),
97 auto c = x.get_card();
98 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
99 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
101 gpublas_status_t check = gpublas::geam(c->blas_handle(), cbtransx, cbtransy,
104 gpublas::err::check_ret(check,
"geam");
108 template <
typename REAL>
112 err::check_card(x, y);
115 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(),
118 auto c = x.get_card();
120 add(transx, transy, alpha, beta, x, y, ret);
143 template <
typename REAL>
144 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
147 err::check_card(x, y, ret);
149 const len_t mx = x.
nrows();
150 const len_t my = y.
nrows();
153 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(),
154 my, y.
ncols(), &m, &n, &k);
159 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
160 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
162 gpublas_status_t check = gpublas::gemm(x.get_card()->blas_handle(),
165 gpublas::err::check_ret(check,
"gemm");
169 template <
typename REAL>
174 matmult(transx, transy, alpha, x, y, ret);
180 template <
typename REAL>
181 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
184 err::check_card(x, y, ret);
186 const len_t mx = x.
nrows();
187 const len_t my = y.
size();
190 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(),
192 auto c = x.get_card();
193 int len = std::max(m, n);
194 if (len != ret.
size())
197 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
198 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
200 gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
203 gpublas::err::check_ret(check,
"gemm");
207 template <
typename REAL>
212 matmult(transx, transy, alpha, x, y, ret);
218 template <
typename REAL>
219 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
222 err::check_card(x, y, ret);
224 const len_t mx = x.
size();
225 const len_t my = y.
nrows();
228 fml::linalgutils::matmult_params(transx, transy, mx, 1,
229 my, y.
ncols(), &m, &n, &k);
230 auto c = x.get_card();
231 int len = std::max(m, n);
232 if (len != ret.
size())
235 gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
236 gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
238 gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
241 gpublas::err::check_ret(check,
"gemm");
245 template <
typename REAL>
250 matmult(transx, transy, alpha, x, y, ret);
274 template <
typename REAL>
277 err::check_card(x, ret);
279 const len_t m = x.
nrows();
280 const len_t n = x.
ncols();
285 matmult(
true,
false, alpha, x, x, ret);
289 template <
typename REAL>
292 const len_t n = x.
ncols();
319 template <
typename REAL>
322 err::check_card(x, ret);
324 const len_t m = x.
nrows();
325 const len_t n = x.
ncols();
330 matmult(
false,
true, alpha, x, x, ret);
334 template <
typename REAL>
337 const len_t m = x.
nrows();
363 template <
typename REAL>
366 err::check_card(x, tx);
368 const len_t m = x.
nrows();
369 const len_t n = x.
ncols();
374 auto cbh = x.get_card()->blas_handle();
376 gpublas_status_t check = gpublas::geam(cbh, GPUBLAS_OP_T, GPUBLAS_OP_N, n, m, (REAL)1.0, x.
data_ptr(), m, (REAL) 0.0, tx.
data_ptr(), n, tx.
data_ptr(), n);
377 gpublas::err::check_ret(check,
"geam");
381 template <
typename REAL>