fml  0.1-0
Fused Matrix Library
linalg_blas.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_LINALG_BLAS_H
6 #define FML_GPU_LINALG_LINALG_BLAS_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../arch/arch.hh"
15 
16 #include "../gpumat.hh"
17 
18 #include "linalg_err.hh"
19 
20 
21 namespace fml
22 {
23 namespace linalg
24 {
38  template <typename REAL>
39  REAL dot(const gpuvec<REAL> &x, const gpuvec<REAL> &y)
40  {
41  err::check_card(x, y);
42  const len_t len = std::min(x.size(), y.size());
43 
44  len_t m, n, k;
45  fml::linalgutils::matmult_params(true, false, len, 1, len, 1, &m, &n, &k);
46 
47  REAL d;
48  gpuscalar<REAL> d_device(x.get_card());
49  gpublas_status_t check = gpublas::gemm(x.get_card()->blas_handle(),
50  GPUBLAS_OP_T, GPUBLAS_OP_N, m, n, k, (REAL)1, x.data_ptr(), len,
51  y.data_ptr(), len, (REAL)0, d_device.data_ptr(), 1);
52  gpublas::err::check_ret(check, "gemm");
53 
54  d_device.get_val(&d);
55 
56  return d;
57  }
58 
59  template <typename REAL>
60  REAL dot(const gpuvec<REAL> &x)
61  {
62  return dot(x, x);
63  }
64 
65 
66 
83  template <typename REAL>
84  void add(const bool transx, const bool transy, const REAL alpha,
85  const REAL beta, const gpumat<REAL> &x, const gpumat<REAL> &y,
86  gpumat<REAL> &ret)
87  {
88  err::check_card(x, y, ret);
89 
90  len_t m, n;
91  fml::linalgutils::matadd_params(transx, transy, x.nrows(), x.ncols(),
92  y.nrows(), y.ncols(), &m, &n);
93 
94  if (ret.nrows() != m || ret.ncols() != n)
95  ret.resize(m, n);
96 
97  auto c = x.get_card();
98  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
99  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
100 
101  gpublas_status_t check = gpublas::geam(c->blas_handle(), cbtransx, cbtransy,
102  m, n, alpha, x.data_ptr(), x.nrows(), beta, y.data_ptr(), y.nrows(),
103  ret.data_ptr(), m);
104  gpublas::err::check_ret(check, "geam");
105  }
106 
108  template <typename REAL>
109  gpumat<REAL> add(const bool transx, const bool transy, const REAL alpha,
110  const REAL beta, const gpumat<REAL> &x, const gpumat<REAL> &y)
111  {
112  err::check_card(x, y);
113 
114  len_t m, n;
115  fml::linalgutils::matadd_params(transx, transy, x.nrows(), x.ncols(),
116  y.nrows(), y.ncols(), &m, &n);
117 
118  auto c = x.get_card();
119  gpumat<REAL> ret(c, m, n);
120  add(transx, transy, alpha, beta, x, y, ret);
121  return ret;
122  }
123 
124 
125 
143  template <typename REAL>
144  void matmult(const bool transx, const bool transy, const REAL alpha,
145  const gpumat<REAL> &x, const gpumat<REAL> &y, gpumat<REAL> &ret)
146  {
147  err::check_card(x, y, ret);
148 
149  const len_t mx = x.nrows();
150  const len_t my = y.nrows();
151 
152  int m, n, k;
153  fml::linalgutils::matmult_params(transx, transy, mx, x.ncols(),
154  my, y.ncols(), &m, &n, &k);
155 
156  if (m != ret.nrows() || n != ret.ncols())
157  ret.resize(m, n);
158 
159  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
160  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
161 
162  gpublas_status_t check = gpublas::gemm(x.get_card()->blas_handle(),
163  cbtransx, cbtransy, m, n, k, alpha, x.data_ptr(), mx, y.data_ptr(),
164  my, (REAL)0, ret.data_ptr(), m);
165  gpublas::err::check_ret(check, "gemm");
166  }
167 
169  template <typename REAL>
170  gpumat<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
171  const gpumat<REAL> &x, const gpumat<REAL> &y)
172  {
173  gpumat<REAL> ret(x.get_card());
174  matmult(transx, transy, alpha, x, y, ret);
175 
176  return ret;
177  }
178 
180  template <typename REAL>
181  void matmult(const bool transx, const bool transy, const REAL alpha,
182  const gpumat<REAL> &x, const gpuvec<REAL> &y, gpuvec<REAL> &ret)
183  {
184  err::check_card(x, y, ret);
185 
186  const len_t mx = x.nrows();
187  const len_t my = y.size();
188 
189  int m, n, k;
190  fml::linalgutils::matmult_params(transx, transy, mx, x.ncols(),
191  my, 1, &m, &n, &k);
192  auto c = x.get_card();
193  int len = std::max(m, n);
194  if (len != ret.size())
195  ret.resize(len);
196 
197  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
198  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
199 
200  gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
201  m, n, k, alpha, x.data_ptr(), mx, y.data_ptr(), my, (REAL)0,
202  ret.data_ptr(), m);
203  gpublas::err::check_ret(check, "gemm");
204  }
205 
207  template <typename REAL>
208  gpuvec<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
209  const gpumat<REAL> &x, const gpuvec<REAL> &y)
210  {
211  gpuvec<REAL> ret(x.get_card());
212  matmult(transx, transy, alpha, x, y, ret);
213 
214  return ret;
215  }
216 
218  template <typename REAL>
219  void matmult(const bool transx, const bool transy, const REAL alpha,
220  const gpuvec<REAL> &x, const gpumat<REAL> &y, gpuvec<REAL> &ret)
221  {
222  err::check_card(x, y, ret);
223 
224  const len_t mx = x.size();
225  const len_t my = y.nrows();
226 
227  int m, n, k;
228  fml::linalgutils::matmult_params(transx, transy, mx, 1,
229  my, y.ncols(), &m, &n, &k);
230  auto c = x.get_card();
231  int len = std::max(m, n);
232  if (len != ret.size())
233  ret.resize(len);
234 
235  gpublas_operation_t cbtransx = transx ? GPUBLAS_OP_T : GPUBLAS_OP_N;
236  gpublas_operation_t cbtransy = transy ? GPUBLAS_OP_T : GPUBLAS_OP_N;
237 
238  gpublas_status_t check = gpublas::gemm(c->blas_handle(), cbtransx, cbtransy,
239  m, n, k, alpha, x.data_ptr(), mx, y.data_ptr(), my, (REAL)0,
240  ret.data_ptr(), m);
241  gpublas::err::check_ret(check, "gemm");
242  }
243 
245  template <typename REAL>
246  gpuvec<REAL> matmult(const bool transx, const bool transy, const REAL alpha,
247  const gpuvec<REAL> &x, const gpumat<REAL> &y)
248  {
249  gpuvec<REAL> ret(x.get_card());
250  matmult(transx, transy, alpha, x, y, ret);
251 
252  return ret;
253  }
254 
255 
256 
274  template <typename REAL>
275  void crossprod(const REAL alpha, const gpumat<REAL> &x, gpumat<REAL> &ret)
276  {
277  err::check_card(x, ret);
278 
279  const len_t m = x.nrows();
280  const len_t n = x.ncols();
281 
282  if (n != ret.nrows() || n != ret.ncols())
283  ret.resize(n, n);
284 
285  matmult(true, false, alpha, x, x, ret);
286  }
287 
289  template <typename REAL>
290  gpumat<REAL> crossprod(const REAL alpha, const gpumat<REAL> &x)
291  {
292  const len_t n = x.ncols();
293  gpumat<REAL> ret(x.get_card(), n, n);
294 
295  crossprod(alpha, x, ret);
296 
297  return ret;
298  }
299 
300 
301 
319  template <typename REAL>
320  void tcrossprod(const REAL alpha, const gpumat<REAL> &x, gpumat<REAL> &ret)
321  {
322  err::check_card(x, ret);
323 
324  const len_t m = x.nrows();
325  const len_t n = x.ncols();
326 
327  if (m != ret.nrows() || m != ret.ncols())
328  ret.resize(m, m);
329 
330  matmult(false, true, alpha, x, x, ret);
331  }
332 
334  template <typename REAL>
335  gpumat<REAL> tcrossprod(const REAL alpha, const gpumat<REAL> &x)
336  {
337  const len_t m = x.nrows();
338  gpumat<REAL> ret(x.get_card(), m, m);
339 
340  tcrossprod(alpha, x, ret);
341 
342  return ret;
343  }
344 
345 
346 
363  template <typename REAL>
364  void xpose(const gpumat<REAL> &x, gpumat<REAL> &tx)
365  {
366  err::check_card(x, tx);
367 
368  const len_t m = x.nrows();
369  const len_t n = x.ncols();
370 
371  if (m != tx.ncols() || n != tx.nrows())
372  tx.resize(n, m);
373 
374  auto cbh = x.get_card()->blas_handle();
375 
376  gpublas_status_t check = gpublas::geam(cbh, GPUBLAS_OP_T, GPUBLAS_OP_N, n, m, (REAL)1.0, x.data_ptr(), m, (REAL) 0.0, tx.data_ptr(), n, tx.data_ptr(), n);
377  gpublas::err::check_ret(check, "geam");
378  }
379 
381  template <typename REAL>
383  {
384  gpumat<REAL> tx(x.get_card(), x.ncols(), x.nrows());
385  xpose(x, tx);
386  return tx;
387  }
388 }
389 }
390 
391 
392 #endif
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: linalg_blas.hh:379
fml::linalg::xpose
void xpose(const cpumat< REAL > &x, cpumat< REAL > &tx)
Computes the transpose out-of-place (i.e. in a copy).
Definition: linalg_blas.hh:463
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::gpuvec
Vector class for data held on a single GPU.
Definition: gpuvec.hh:32
fml::gpuvec::resize
void resize(len_t size)
Resize the internal object storage.
Definition: gpuvec.hh:224
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::univec::size
len_t size() const
Number of elements in the vector.
Definition: univec.hh:26
fml::gpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: gpumat.hh:253
fml::linalg::add
void add(const bool transx, const bool transy, const REAL alpha, const REAL beta, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Returns alpha*op(x) + beta*op(y) where op(A) is A or A^T.
Definition: linalg_blas.hh:77
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35
fml::gpuscalar
Definition: gpuscalar.hh:16
fml::linalg::dot
REAL dot(const cpuvec< REAL > &x, const cpuvec< REAL > &y)
Computes the dot product of two vectors, i.e. the sum of the product of the elements.
Definition: linalg_blas.hh:40
fml::linalg::tcrossprod
void tcrossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x*x^T.
Definition: linalg_blas.hh:423
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: linalg_blas.hh:257