fml  0.1-0
Fused Matrix Library
linalg_blas.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_LINALG_BLAS_H
6 #define FML_MPI_LINALG_LINALG_BLAS_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../mpimat.hh"
15 
16 #include "linalg_err.hh"
17 #include "pblas.hh"
18 
19 
20 namespace fml
21 {
22 namespace linalg
23 {
42  template <typename REAL>
43  void add(const bool transx, const bool transy, const REAL alpha, const REAL beta, const mpimat<REAL> &x, const mpimat<REAL> &y, mpimat<REAL> &ret)
44  {
45  err::check_grid(x, y, ret);
46 
47  len_t m, n;
48  fml::linalgutils::matadd_params(transx, transy, x.nrows(), x.ncols(), y.nrows(), y.ncols(), &m, &n);
49 
50  if (ret.nrows() != m || ret.ncols() != n)
51  ret.resize(m, n);
52 
53  char ctransx = transx ? 'T' : 'N';
54  char ctransy = transy ? 'T' : 'N';
55 
56  fml::pblas::geadd(ctransy, m, n, beta, y.data_ptr(), y.desc_ptr(), (REAL) 0.0f, ret.data_ptr(), ret.desc_ptr());
57  fml::pblas::geadd(ctransx, m, n, alpha, x.data_ptr(), x.desc_ptr(), (REAL) 1.0f, ret.data_ptr(), ret.desc_ptr());
58  }
59 
61  template <typename REAL>
62  mpimat<REAL> add(const bool transx, const bool transy, const REAL alpha, const REAL beta, const mpimat<REAL> &x, const mpimat<REAL> &y)
63  {
64  err::check_grid(x, y);
65 
66  len_t m, n;
67  fml::linalgutils::matadd_params(transx, transy, x.nrows(), x.ncols(), y.nrows(), y.ncols(), &m, &n);
68 
69  const grid g = x.get_grid();
70  mpimat<REAL> ret(g, m, n, x.bf_rows(), x.bf_cols());
71  add(transx, transy, alpha, beta, x, y, ret);
72  return ret;
73  }
74 
75 
76 
96  template <typename REAL>
97  mpimat<REAL> matmult(const bool transx, const bool transy, const REAL alpha, const mpimat<REAL> &x, const mpimat<REAL> &y)
98  {
99  err::check_grid(x, y);
100 
101  len_t m, n, k;
102  fml::linalgutils::matmult_params(transx, transy, x.nrows(), x.ncols(), y.nrows(), y.ncols(), &m, &n, &k);
103 
104  const grid g = x.get_grid();
105  mpimat<REAL> ret(g, m, n, x.bf_rows(), x.bf_cols());
106 
107  const char ctransx = transx ? 'T' : 'N';
108  const char ctransy = transy ? 'T' : 'N';
109 
110  fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
111  x.data_ptr(), x.desc_ptr(), y.data_ptr(), y.desc_ptr(),
112  (REAL)0, ret.data_ptr(), ret.desc_ptr());
113 
114  return ret;
115  }
116 
137  template <typename REAL>
138  void matmult(const bool transx, const bool transy, const REAL alpha, const mpimat<REAL> &x, const mpimat<REAL> &y, mpimat<REAL> &ret)
139  {
140  err::check_grid(x, y, ret);
141 
142  len_t m, n, k;
143  fml::linalgutils::matmult_params(transx, transy, x.nrows(), x.ncols(), y.nrows(), y.ncols(), &m, &n, &k);
144 
145  if (m != ret.nrows() || n != ret.ncols())
146  ret.resize(m, n);
147 
148  const char ctransx = transx ? 'T' : 'N';
149  const char ctransy = transy ? 'T' : 'N';
150 
151  fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
152  x.data_ptr(), x.desc_ptr(), y.data_ptr(), y.desc_ptr(),
153  (REAL)0, ret.data_ptr(), ret.desc_ptr());
154  }
155 
156 
157 
177  template <typename REAL>
178  void crossprod(const REAL alpha, const mpimat<REAL> &x, mpimat<REAL> &ret)
179  {
180  err::check_grid(x, ret);
181 
182  const len_t n = x.ncols();
183 
184  if (n != ret.nrows() || n != ret.ncols())
185  ret.resize(n, n);
186 
187  ret.fill_zero();
188  fml::pblas::syrk('L', 'T', n, x.nrows(), alpha, x.data_ptr(), x.desc_ptr(), (REAL) 0, ret.data_ptr(), ret.desc_ptr());
189  }
190 
192  template <typename REAL>
193  mpimat<REAL> crossprod(const REAL alpha, const mpimat<REAL> &x)
194  {
195  const len_t n = x.ncols();
196  const grid g = x.get_grid();
197  mpimat<REAL> ret(g, n, n, x.bf_rows(), x.bf_cols());
198 
199  crossprod(alpha, x, ret);
200 
201  return ret;
202  }
203 
204 
205 
225  template <typename REAL>
226  void tcrossprod(const REAL alpha, const mpimat<REAL> &x, mpimat<REAL> &ret)
227  {
228  err::check_grid(x, ret);
229 
230  const len_t m = x.nrows();
231 
232  if (m != ret.nrows() || m != ret.ncols())
233  ret.resize(m, m);
234 
235  ret.fill_zero();
236  fml::pblas::syrk('L', 'N', m, x.ncols(), alpha, x.data_ptr(), x.desc_ptr(), (REAL) 0, ret.data_ptr(), ret.desc_ptr());
237  }
238 
240  template <typename REAL>
241  mpimat<REAL> tcrossprod(const REAL alpha, const mpimat<REAL> &x)
242  {
243  const len_t n = x.nrows();
244  const grid g = x.get_grid();
245  mpimat<REAL> ret(g, n, n, x.bf_rows(), x.bf_cols());
246 
247  tcrossprod(alpha, x, ret);
248 
249  return ret;
250  }
251 
252 
253 
272  template <typename REAL>
273  void xpose(const mpimat<REAL> &x, mpimat<REAL> &tx)
274  {
275  err::check_grid(x, tx);
276 
277  const len_t m = x.nrows();
278  const len_t n = x.ncols();
279 
280  if (m != tx.ncols() || n != tx.nrows())
281  tx.resize(n, m);
282 
283  fml::pblas::tran(n, m, 1.f, x.data_ptr(), x.desc_ptr(), 0.f, tx.data_ptr(), tx.desc_ptr());
284  }
285 
287  template <typename REAL>
289  {
290  const len_t m = x.nrows();
291  const len_t n = x.ncols();
292  const grid g = x.get_grid();
293 
294  mpimat<REAL> tx(g, n, m, x.bf_rows(), x.bf_cols());
295  xpose(x, tx);
296  return tx;
297  }
298 }
299 }
300 
301 
302 #endif
fml::mpimat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: mpimat.hh:323
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: linalg_blas.hh:379
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::linalg::xpose
void xpose(const cpumat< REAL > &x, cpumat< REAL > &tx)
Computes the transpose out-of-place (i.e. in a copy).
Definition: linalg_blas.hh:463
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::mpimat::fill_zero
void fill_zero()
Set all values to zero.
Definition: mpimat.hh:562
fml::linalg::add
void add(const bool transx, const bool transy, const REAL alpha, const REAL beta, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Returns alpha*op(x) + beta*op(y) where op(A) is A or A^T.
Definition: linalg_blas.hh:77
fml::linalg::tcrossprod
void tcrossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x*x^T.
Definition: linalg_blas.hh:423
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: linalg_blas.hh:257