fml  0.1-0
Fused Matrix Library
trace.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_TRACE_H
6 #define FML_MPI_LINALG_TRACE_H
7 #pragma once
8 
9 
10 #include "../internals/bcutils.hh"
11 
12 #include "../mpimat.hh"
13 
14 
15 namespace fml
16 {
17 namespace linalg
18 {
28  template <typename REAL>
29  REAL trace(const mpimat<REAL> &x)
30  {
31  const REAL *x_d = x.data_ptr();
32  const len_t minmn = std::min(x.nrows(), x.ncols());
33  const len_t m_local = x.nrows_local();
34  const int mb = x.bf_rows();
35  const int nb = x.bf_cols();
36  const grid g = x.get_grid();
37 
38  REAL tr = 0;
39  for (len_t gi=0; gi<minmn; gi++)
40  {
41  const len_local_t i = fml::bcutils::g2l(gi, mb, g.nprow());
42  const len_local_t j = fml::bcutils::g2l(gi, nb, g.npcol());
43 
44  const int pr = fml::bcutils::g2p(gi, mb, g.nprow());
45  const int pc = fml::bcutils::g2p(gi, nb, g.npcol());
46 
47  if (pr == g.myrow() && pc == g.mycol())
48  tr += x_d[i + m_local*j];
49  }
50 
51  g.allreduce(1, 1, &tr, 'A');
52 
53  return tr;
54  }
55 }
56 }
57 
58 
59 #endif
fml::grid::allreduce
void allreduce(const int m, const int n, int *x, const char scope='A', const blacsops op=BLACS_SUM) const
Sum reduce operation across all processes in the grid.
Definition: grid.hh:420
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::grid::mycol
int mycol() const
The process column (0-based index) of the calling process.
Definition: grid.hh:129
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::linalg::trace
REAL trace(const cpumat< REAL > &x)
Computes the trace, i.e. the sum of the diagonal.
Definition: trace.hh:27
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::grid::npcol
int npcol() const
The number of process columns in the BLACS context.
Definition: grid.hh:125
fml::grid::myrow
int myrow() const
The process row (0-based index) of the calling process.
Definition: grid.hh:127
fml::grid::nprow
int nprow() const
The number of process rows in the BLACS context.
Definition: grid.hh:123