fml  0.1-0
Fused Matrix Library
blas.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_PAR_GPU_LINALG_BLAS_H
6 #define FML_PAR_GPU_LINALG_BLAS_H
7 #pragma once
8 
9 
10 #include "../../../gpu/linalg/linalg_blas.hh"
11 
12 #include "../parmat.hh"
13 
14 
15 namespace fml
16 {
17 namespace linalg
18 {
36  template <typename REAL>
37  void matmult(const parmat_gpu<REAL> &x, const gpumat<REAL> &y,
38  parmat_gpu<REAL> &ret)
39  {
40  err::check_card(x, y, ret);
41 
42  linalg::matmult(false, false, (REAL)1.0, x.data_obj(), y, ret.data_obj());
43  }
44 
// Convenience form of matmult(x, y, ret): builds the result object locally
// and fills it by calling the three-argument matmult().
//
// NOTE(review): the declarator line (listing line 47) is not present in this
// extract, so the return type and the types of `x` and `y` cannot be read
// from here. Presumably this is the overload returning parmat_gpu<REAL> for
// a (parmat_gpu, gpumat) pair -- confirm against the real header.
46  template <typename REAL>
48  {
// NOTE(review): the result is constructed with x.ncols() as its fourth
// argument; if that argument is the column count of the product, y.ncols()
// looks like the intended value -- TODO confirm.
49  parmat_gpu<REAL> ret(x.get_comm(), x.get_card(), x.nrows(), x.ncols(), x.nrows_before());
50  matmult(x, y, ret);
// NOTE(review): no `return ret;` is visible in this body. If the function is
// declared with a non-void return type, flowing off the end is undefined
// behaviour -- verify and add the return.
51  }
52 
54  template <typename REAL>
55  void matmult(const parmat_gpu<REAL> &x, const parmat_gpu<REAL> &y,
56  gpumat<REAL> &ret)
57  {
58  err::check_card(x, y, ret);
59 
60  linalg::matmult(true, false, (REAL)1.0, x.data_obj(), y.data_obj(), ret);
61  x.get_comm().allreduce(ret.nrows()*ret.ncols(), ret.data_ptr());
62  }
63 
// Convenience form of matmult(x, y, ret): allocates an x.ncols() by y.ncols()
// gpumat on x's card and fills it by calling the three-argument matmult().
//
// NOTE(review): the declarator line (listing line 66) is not present in this
// extract, so the return type and the types of `x` and `y` cannot be read
// from here. Presumably this is the overload returning gpumat<REAL> for a
// pair of parmat_gpu arguments -- confirm against the real header.
65  template <typename REAL>
67  {
68  gpumat<REAL> ret(x.get_card(), x.ncols(), y.ncols());
69  matmult(x, y, ret);
// NOTE(review): no `return ret;` is visible in this body. If the function is
// declared with a non-void return type, flowing off the end is undefined
// behaviour -- verify and add the return.
70  }
71 
72 
73 
91  template <typename REAL>
92  void crossprod(const REAL alpha, const parmat_gpu<REAL> &x, gpumat<REAL> &ret)
93  {
94  err::check_card(x, ret);
95 
96  const len_t n = x.ncols();
97  if (n != ret.nrows() || n != ret.ncols())
98  ret.resize(n, n);
99 
100  linalg::crossprod(alpha, x.data_obj(), ret);
101 
102  comm r = x.get_comm();
103  r.allreduce(n*n, ret.data_ptr());
104  }
105 
107  template <typename REAL>
108  gpumat<REAL> crossprod(const REAL alpha, const parmat_gpu<REAL> &x)
109  {
110  const len_t n = x.ncols();
111  gpumat<REAL> ret(x.get_card(), n, n);
112 
113  crossprod(alpha, x, ret);
114  return ret;
115  }
116 }
117 }
118 
119 
120 #endif
fml::parmat_gpu
Definition: parmat.hh:20
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: crossprod.hh:37
fml::comm::allreduce
void allreduce(int n, T *data, MPI_Op op=MPI_SUM) const
Sum reduce operation across all processes in the MPI communicator.
Definition: comm.hh:357
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::comm
MPI communicator data and helpers.
Definition: comm.hh:24
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::gpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: gpumat.hh:256
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43