fml  0.1-0
Fused Matrix Library
matmult.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_MATMULT_H
6 #define FML_MPI_LINALG_MATMULT_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../mpimat.hh"
15 
16 #include "internals/err.hh"
17 #include "internals/pblas.hh"
18 
19 
20 namespace fml
21 {
22 namespace linalg
23 {
43  template <typename REAL>
44  mpimat<REAL> matmult(const bool transx, const bool transy, const REAL alpha, const mpimat<REAL> &x, const mpimat<REAL> &y)
45  {
46  err::check_grid(x, y);
47 
48  len_t m, n, k;
49  fml::linalgutils::matmult_params(transx, transy, x.nrows(), x.ncols(), y.nrows(), y.ncols(), &m, &n, &k);
50 
51  const grid g = x.get_grid();
52  mpimat<REAL> ret(g, m, n, x.bf_rows(), x.bf_cols());
53 
54  const char ctransx = transx ? 'T' : 'N';
55  const char ctransy = transy ? 'T' : 'N';
56 
57  fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
58  x.data_ptr(), x.desc_ptr(), y.data_ptr(), y.desc_ptr(),
59  (REAL)0, ret.data_ptr(), ret.desc_ptr());
60 
61  return ret;
62  }
63 
64 
65 
86  template <typename REAL>
87  void matmult(const bool transx, const bool transy, const REAL alpha, const mpimat<REAL> &x, const mpimat<REAL> &y, mpimat<REAL> &ret)
88  {
89  err::check_grid(x, y, ret);
90 
91  len_t m, n, k;
92  fml::linalgutils::matmult_params(transx, transy, x.nrows(), x.ncols(), y.nrows(), y.ncols(), &m, &n, &k);
93 
94  if (m != ret.nrows() || n != ret.ncols())
95  ret.resize(m, n);
96 
97  const char ctransx = transx ? 'T' : 'N';
98  const char ctransy = transy ? 'T' : 'N';
99 
100  fml::pblas::gemm(ctransx, ctransy, m, n, k, alpha,
101  x.data_ptr(), x.desc_ptr(), y.data_ptr(), y.desc_ptr(),
102  (REAL)0, ret.data_ptr(), ret.desc_ptr());
103  }
104 }
105 }
106 
107 
108 #endif
fml::mpimat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: mpimat.hh:326
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43