fml  0.1-0
Fused Matrix Library
linalg_misc.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_LINALG_MISC_H
6 #define FML_MPI_LINALG_LINALG_MISC_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../cpu/cpuvec.hh"
13 
14 #include "../mpimat.hh"
15 
16 #include "linalg_qr.hh"
17 #include "linalg_svd.hh"
18 
19 
20 namespace fml
21 {
22 namespace linalg
23 {
42  template <typename REAL>
43  void det(mpimat<REAL> &x, int &sign, REAL &modulus)
44  {
45  if (!x.is_square())
46  throw std::runtime_error("'x' must be a square matrix");
47 
48  cpuvec<int> p;
49  int info;
50  lu(x, p, info);
51 
52  if (info != 0)
53  {
54  if (info > 0)
55  {
56  sign = 1;
57  modulus = -INFINITY;
58  return;
59  }
60  else
61  return;
62  }
63 
64 
65  // get determinant
66  REAL mod = 0.0;
67  int sgn = 1;
68 
69  const len_t m_local = x.nrows_local();
70  const len_t n_local = x.ncols_local();
71 
72  const int *ipiv = p.data_ptr();
73  const REAL *a = x.data_ptr();
74  const grid g = x.get_grid();
75 
76  for (len_t i=0; i<m_local; i++)
77  {
78  len_t gi = fml::bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow());
79 
80  if (ipiv[i] != (gi + 1))
81  sgn = -sgn;
82  }
83 
84  for (len_t j=0; j<n_local; j++)
85  {
86  for (len_t i=0; i<m_local; i++)
87  {
88  len_t gi = fml::bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow());
89  len_t gj = fml::bcutils::l2g(j, x.bf_cols(), g.npcol(), g.mycol());
90 
91  if (gi == gj)
92  {
93  const REAL d = a[i + m_local*j];
94  if (d < 0)
95  {
96  mod += log(-d);
97  sgn *= -1;
98  }
99  else
100  mod += log(d);
101  }
102  }
103  }
104 
105  g.allreduce(1, 1, &mod);
106 
107  sgn = (sgn<0 ? 1 : 0);
108  g.allreduce(1, 1, &sgn, 'C');
109  sgn = (sgn%2==0 ? 1 : -1);
110 
111  modulus = mod;
112  sign = sgn;
113  }
114 
115 
116 
126  template <typename REAL>
127  REAL trace(const mpimat<REAL> &x)
128  {
129  const REAL *x_d = x.data_ptr();
130  const len_t minmn = std::min(x.nrows(), x.ncols());
131  const len_t m_local = x.nrows_local();
132  const int mb = x.bf_rows();
133  const int nb = x.bf_cols();
134  const grid g = x.get_grid();
135 
136  REAL tr = 0;
137  for (len_t gi=0; gi<minmn; gi++)
138  {
139  const len_local_t i = fml::bcutils::g2l(gi, mb, g.nprow());
140  const len_local_t j = fml::bcutils::g2l(gi, nb, g.npcol());
141 
142  const int pr = fml::bcutils::g2p(gi, mb, g.nprow());
143  const int pc = fml::bcutils::g2p(gi, nb, g.npcol());
144 
145  if (pr == g.myrow() && pc == g.mycol())
146  tr += x_d[i + m_local*j];
147  }
148 
149  g.allreduce(1, 1, &tr, 'A');
150 
151  return tr;
152  }
153 }
154 }
155 
156 
157 #endif
fml::grid::allreduce
void allreduce(const int m, const int n, int *x, const char scope='A', const blacsops op=BLACS_SUM) const
Sum reduce operation across all processes in the grid.
Definition: grid.hh:420
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::grid::mycol
int mycol() const
The process column (0-based index) of the calling process.
Definition: grid.hh:129
fml::unimat::is_square
bool is_square() const
Is the matrix square?
Definition: unimat.hh:34
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::linalg::trace
REAL trace(const cpumat< REAL > &x)
Computes the trace, i.e. the sum of the diagonal.
Definition: linalg_misc.hh:109
fml::linalg::lu
void lu(cpumat< REAL > &x, cpuvec< int > &p, int &info)
Computes the PLU factorization with partial pivoting.
Definition: linalg_lu.hh:48
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::linalg::det
void det(cpumat< REAL > &x, int &sign, REAL &modulus)
Computes the determinant in logarithmic form.
Definition: linalg_misc.hh:46
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::grid::npcol
int npcol() const
The number of process columns in the BLACS context.
Definition: grid.hh:125
fml::grid::myrow
int myrow() const
The process row (0-based index) of the calling process.
Definition: grid.hh:127
fml::grid::nprow
int nprow() const
The number of process rows in the BLACS context.
Definition: grid.hh:123