fml  0.1-0
Fused Matrix Library
crossprod.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_CROSSPROD_H
6 #define FML_MPI_LINALG_CROSSPROD_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../mpimat.hh"
15 
16 #include "internals/err.hh"
17 #include "internals/pblas.hh"
18 
19 
20 namespace fml
21 {
22 namespace linalg
23 {
43  template <typename REAL>
44  void crossprod(const REAL alpha, const mpimat<REAL> &x, mpimat<REAL> &ret)
45  {
46  err::check_grid(x, ret);
47 
48  const len_t n = x.ncols();
49 
50  if (n != ret.nrows() || n != ret.ncols())
51  ret.resize(n, n);
52 
53  ret.fill_zero();
54  fml::pblas::syrk('L', 'T', n, x.nrows(), alpha, x.data_ptr(), x.desc_ptr(), (REAL) 0, ret.data_ptr(), ret.desc_ptr());
55  }
56 
58  template <typename REAL>
59  mpimat<REAL> crossprod(const REAL alpha, const mpimat<REAL> &x)
60  {
61  const len_t n = x.ncols();
62  const grid g = x.get_grid();
63  mpimat<REAL> ret(g, n, n, x.bf_rows(), x.bf_cols());
64 
65  crossprod(alpha, x, ret);
66 
67  return ret;
68  }
69 
70 
71 
91  template <typename REAL>
92  void tcrossprod(const REAL alpha, const mpimat<REAL> &x, mpimat<REAL> &ret)
93  {
94  err::check_grid(x, ret);
95 
96  const len_t m = x.nrows();
97 
98  if (m != ret.nrows() || m != ret.ncols())
99  ret.resize(m, m);
100 
101  ret.fill_zero();
102  fml::pblas::syrk('L', 'N', m, x.ncols(), alpha, x.data_ptr(), x.desc_ptr(), (REAL) 0, ret.data_ptr(), ret.desc_ptr());
103  }
104 
106  template <typename REAL>
107  mpimat<REAL> tcrossprod(const REAL alpha, const mpimat<REAL> &x)
108  {
109  const len_t n = x.nrows();
110  const grid g = x.get_grid();
111  mpimat<REAL> ret(g, n, n, x.bf_rows(), x.bf_cols());
112 
113  tcrossprod(alpha, x, ret);
114 
115  return ret;
116  }
117 }
118 }
119 
120 
121 #endif
fml::mpimat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: mpimat.hh:326
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: crossprod.hh:37
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::mpimat::fill_zero
void fill_zero()
Set all values to zero.
Definition: mpimat.hh:565
fml::linalg::tcrossprod
void tcrossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x*x^T.
Definition: crossprod.hh:81