fml  0.1-0
Fused Matrix Library
matmult.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_CPU_LINALG_MATMULT_H
6 #define FML_CPU_LINALG_MATMULT_H
7 #pragma once
8 
9 
10 #include <cmath>
11 #include <stdexcept>
12 
13 #include "../../_internals/linalgutils.hh"
14 #include "../../_internals/omp.hh"
15 
16 #include "../cpumat.hh"
17 
18 #include "internals/blas.hh"
19 
20 
21 namespace fml
22 {
23 namespace linalg
24 {
42  template <typename REAL>
43  void matmult(const bool transx, const bool transy, const REAL alpha,
44  const cpumat<REAL> &x, const cpumat<REAL> &y, cpumat<REAL> &ret)
45  {
46  len_t m, n, k;
47  const len_t mx = x.nrows();
48  const len_t my = y.nrows();
49 
50  fml::linalgutils::matmult_params(transx, transy, mx, x.ncols(), my,
51  y.ncols(), &m, &n, &k);
52 
53  if (m != ret.nrows() || n != ret.ncols())
54  ret.resize(m, n);
55 
56  const char ctransx = transx ? 'T' : 'N';
57  const char ctransy = transy ? 'T' : 'N';
58 
59  fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
60  x.data_ptr(), mx, y.data_ptr(), my,
61  (REAL)0, ret.data_ptr(), m);
62  }
63 
64 
65 
67  template <typename REAL>
68  cpumat<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
69  const cpumat<REAL> &x, const cpumat<REAL> &y)
70  {
71  cpumat<REAL> ret;
72  matmult(transx, transy, alpha, x, y, ret);
73 
74  return ret;
75  }
76 
77 
78 
80  template <typename REAL>
81  void matmult(const bool transx, const bool transy, const REAL alpha,
82  const cpumat<REAL> &x, const cpuvec<REAL> &y, cpuvec<REAL> &ret)
83  {
84  len_t m, n, k;
85  const len_t mx = x.nrows();
86  const len_t my = y.size();
87 
88  fml::linalgutils::matmult_params(transx, transy, mx, x.ncols(), my,
89  1, &m, &n, &k);
90 
91  int len = std::max(m, n);
92  if (len != ret.size())
93  ret.resize(len);
94 
95  const char ctransx = transx ? 'T' : 'N';
96  const char ctransy = transy ? 'T' : 'N';
97 
98  fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
99  x.data_ptr(), mx, y.data_ptr(), my,
100  (REAL)0, ret.data_ptr(), m);
101  }
102 
103 
104 
106  template <typename REAL>
107  cpuvec<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
108  const cpumat<REAL> &x, const cpuvec<REAL> &y)
109  {
110  cpuvec<REAL> ret;
111  matmult(transx, transy, alpha, x, y, ret);
112 
113  return ret;
114  }
115 
116 
117 
119  template <typename REAL>
120  void matmult(const bool transx, const bool transy, const REAL alpha,
121  const cpuvec<REAL> &x, const cpumat<REAL> &y, cpuvec<REAL> &ret)
122  {
123  len_t m, n, k;
124  const len_t mx = x.size();
125  const len_t my = y.nrows();
126 
127  fml::linalgutils::matmult_params(transx, transy, mx, 1, my,
128  y.ncols(), &m, &n, &k);
129 
130  int len = std::max(m, n);
131  if (len != ret.size())
132  ret.resize(len);
133 
134  const char ctransx = transx ? 'T' : 'N';
135  const char ctransy = transy ? 'T' : 'N';
136 
137  fml::blas::gemm(ctransx, ctransy, m, n, k, alpha,
138  x.data_ptr(), mx, y.data_ptr(), my,
139  (REAL)0, ret.data_ptr(), m);
140  }
141 
142 
143 
145  template <typename REAL>
146  cpuvec<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
147  const cpuvec<REAL> &x, const cpumat<REAL> &y)
148  {
149  cpuvec<REAL> ret;
150  matmult(transx, transy, alpha, x, y, ret);
151 
152  return ret;
153  }
154 }
155 }
156 
157 
158 #endif
fml::cpumat
Matrix class for data held on a single CPU.
Definition: cpumat.hh:36
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::cpuvec::resize
void resize(len_t size)
Resize the internal object storage.
Definition: cpuvec.hh:210
fml::cpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: cpumat.hh:233
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43