5 #ifndef FML_CPU_LINALG_LINALG_BLAS_H
6 #define FML_CPU_LINALG_LINALG_BLAS_H
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
18 #include "../cpumat.hh"
39 template <
typename REAL>
42 const len_t n = std::min(x.
size(), y.
size());
47 #pragma omp simd reduction(+:d)
48 for (len_t i=0; i<n; i++)
54 template <
typename REAL>
76 template <
typename REAL>
77 void add(
const bool transx,
const bool transy,
const REAL alpha,
82 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(),
92 if (!transx && !transy)
94 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
95 for (len_t j=0; j<n; j++)
98 for (len_t i=0; i<m; i++)
99 ret_d[i + m*j] = alpha*x_d[i + m*j] + beta*y_d[i + m*j];
102 else if (transx && transy)
104 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
105 for (len_t j=0; j<m; j++)
108 for (len_t i=0; i<n; i++)
109 ret_d[j + m*i] = alpha*x_d[i + n*j] + beta*y_d[i + n*j];
112 else if (transx && !transy)
114 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
115 for (len_t j=0; j<n; j++)
118 for (len_t i=0; i<m; i++)
119 ret_d[i + m*j] = alpha*x_d[j + n*i] + beta*y_d[i + m*j];
122 else if (!transx && transy)
124 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
125 for (len_t j=0; j<n; j++)
128 for (len_t i=0; i<m; i++)
129 ret_d[i + m*j] = alpha*x_d[i + m*j] + beta*y_d[j + n*i];
135 template <
typename REAL>
140 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(),
144 add(transx, transy, alpha, beta, x, y, ret);
164 template <
typename REAL>
165 void prod(
const bool transx,
const bool transy,
const REAL alpha,
170 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(),
180 if (!transx && !transy)
182 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
183 for (len_t j=0; j<n; j++)
186 for (len_t i=0; i<m; i++)
187 ret_d[i + m*j] = alpha*x_d[i + m*j] * beta*y_d[i + m*j];
190 else if (transx && transy)
192 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
193 for (len_t j=0; j<m; j++)
196 for (len_t i=0; i<n; i++)
197 ret_d[j + m*i] = alpha*x_d[i + n*j] * beta*y_d[i + n*j];
200 else if (transx && !transy)
202 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
203 for (len_t j=0; j<n; j++)
206 for (len_t i=0; i<m; i++)
207 ret_d[i + m*j] = alpha*x_d[j + n*i] * beta*y_d[i + m*j];
210 else if (!transx && transy)
212 #pragma omp parallel for if(m*n > fml::omp::OMP_MIN_SIZE)
213 for (len_t j=0; j<n; j++)
216 for (len_t i=0; i<m; i++)
217 ret_d[i + m*j] = alpha*x_d[i + m*j] * beta*y_d[j + n*i];
223 template <
typename REAL>
228 fml::linalgutils::matadd_params(transx, transy, x.
nrows(), x.
ncols(),
232 prod(transx, transy, alpha, beta, x, y, ret);
256 template <
typename REAL>
257 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
261 const len_t mx = x.
nrows();
262 const len_t my = y.
nrows();
264 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(), my,
265 y.
ncols(), &m, &n, &k);
270 const char ctransx = transx ?
'T' :
'N';
271 const char ctransy = transy ?
'T' :
'N';
273 fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
279 template <
typename REAL>
284 matmult(transx, transy, alpha, x, y, ret);
290 template <
typename REAL>
291 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
295 const len_t mx = x.
nrows();
296 const len_t my = y.
size();
298 fml::linalgutils::matmult_params(transx, transy, mx, x.
ncols(), my,
301 int len = std::max(m, n);
302 if (len != ret.
size())
305 const char ctransx = transx ?
'T' :
'N';
306 const char ctransy = transy ?
'T' :
'N';
308 fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
314 template <
typename REAL>
319 matmult(transx, transy, alpha, x, y, ret);
325 template <
typename REAL>
326 void matmult(
const bool transx,
const bool transy,
const REAL alpha,
330 const len_t mx = x.
size();
331 const len_t my = y.
nrows();
333 fml::linalgutils::matmult_params(transx, transy, mx, 1, my,
334 y.
ncols(), &m, &n, &k);
336 int len = std::max(m, n);
337 if (len != ret.
size())
340 const char ctransx = transx ?
'T' :
'N';
341 const char ctransy = transy ?
'T' :
'N';
343 fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
349 template <
typename REAL>
354 matmult(transx, transy, alpha, x, y, ret);
378 template <
typename REAL>
381 const len_t m = x.
nrows();
382 const len_t n = x.
ncols();
388 fml::blas::syrk(
'L',
'T', n, m, alpha, x.
data_ptr(), m, (REAL)0.0, ret.
data_ptr(), n);
392 template <
typename REAL>
395 const len_t n = x.
ncols();
422 template <
typename REAL>
425 const len_t m = x.
nrows();
426 const len_t n = x.
ncols();
432 fml::blas::syrk(
'L',
'N', m, n, alpha, x.
data_ptr(), m, (REAL)0.0, ret.
data_ptr(), m);
435 template <
typename REAL>
438 const len_t m = x.
nrows();
462 template <
typename REAL>
465 const len_t m = x.
nrows();
466 const len_t n = x.
ncols();
471 const int blocksize = 8;
475 #pragma omp parallel for shared(tx) schedule(dynamic, 1) if(m*n > fml::omp::OMP_MIN_SIZE)
476 for (
int j=0; j<n; j+=blocksize)
478 for (
int i=0; i<m; i+=blocksize)
480 for (
int col=j; col<j+blocksize && col<n; ++col)
482 for (
int row=i; row<i+blocksize && row<m; ++row)
483 tx_d[col + n*row] = x_d[row + m*col];
490 template <
typename REAL>