fml  0.1-0
Fused Matrix Library
qr.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_LINALG_QR_H
6 #define FML_MPI_LINALG_QR_H
7 #pragma once
8 
9 
10 #include "../../_internals/linalgutils.hh"
11 #include "../../cpu/cpuvec.hh"
12 
13 #include "../copy.hh"
14 #include "../mpimat.hh"
15 
16 #include "internals/err.hh"
17 #include "internals/scalapack.hh"
18 
19 
20 namespace fml
21 {
22 namespace linalg
23 {
24  namespace
25  {
26  template <typename REAL>
27  void qr_internals(const bool pivot, mpimat<REAL> &x, cpuvec<REAL> &qraux, cpuvec<REAL> &work)
28  {
29  const len_t m = x.nrows();
30  const len_t n = x.ncols();
31  const len_t minmn = std::min(m, n);
32 
33  const int *descx = x.desc_ptr();
34 
35  int info = 0;
36  qraux.resize(minmn);
37 
38  REAL tmp;
39  if (pivot)
40  fml::scalapack::geqpf(m, n, NULL, descx, NULL, NULL, &tmp, -1, &info);
41  else
42  fml::scalapack::geqrf(m, n, NULL, descx, NULL, &tmp, -1, &info);
43 
44  int lwork = std::max((int) tmp, 1);
45  if (lwork > work.size())
46  work.resize(lwork);
47 
48  if (pivot)
49  {
50  cpuvec<int> p(n);
51  p.fill_zero();
52  fml::scalapack::geqpf(m, n, x.data_ptr(), descx, p.data_ptr(),
53  qraux.data_ptr(), work.data_ptr(), lwork, &info);
54  }
55  else
56  fml::scalapack::geqrf(m, n, x.data_ptr(), descx, qraux.data_ptr(),
57  work.data_ptr(), lwork, &info);
58 
59  if (info != 0)
60  {
61  if (pivot)
62  fml::linalgutils::check_info(info, "geqpf");
63  else
64  fml::linalgutils::check_info(info, "geqrf");
65  }
66  }
67  }
68 
94  template <typename REAL>
95  void qr(const bool pivot, mpimat<REAL> &x, cpuvec<REAL> &qraux)
96  {
97  cpuvec<REAL> work;
98  qr_internals(pivot, x, qraux, work);
99  }
100 
122  template <typename REAL>
123  void qr_Q(const mpimat<REAL> &QR, const cpuvec<REAL> &qraux, mpimat<REAL> &Q, cpuvec<REAL> &work)
124  {
125  err::check_grid(QR, Q);
126 
127  const len_t m = QR.nrows();
128  const len_t n = QR.ncols();
129  const len_t minmn = std::min(m, n);
130 
131  const int *descQR = QR.desc_ptr();
132 
133  Q.resize(m, minmn);
134  const int *descQ = Q.desc_ptr();
135 
136  int info = 0;
137  REAL tmp;
138  fml::scalapack::orgqr(m, minmn, minmn, NULL, descQR, NULL,
139  &tmp, -1, &info);
140 
141  int lwork = (int) tmp;
142  if (lwork > work.size())
143  work.resize(lwork);
144 
145  fml::scalapack::lacpy('A', m, minmn, QR.data_ptr(), descQR, Q.data_ptr(),
146  descQ);
147 
148  fml::scalapack::orgqr(m, minmn, minmn, Q.data_ptr(), descQR,
149  qraux.data_ptr(), work.data_ptr(), lwork, &info);
150  fml::linalgutils::check_info(info, "orgqr");
151  }
152 
172  template <typename REAL>
173  void qr_R(const mpimat<REAL> &QR, mpimat<REAL> &R)
174  {
175  err::check_grid(QR, R);
176 
177  const len_t m = QR.nrows();
178  const len_t n = QR.ncols();
179  const len_t minmn = std::min(m, n);
180 
181  R.resize(minmn, n);
182  R.fill_zero();
183  fml::scalapack::lacpy('U', m, n, QR.data_ptr(), QR.desc_ptr(), R.data_ptr(),
184  R.desc_ptr());
185  }
186 
187 
188 
189  namespace
190  {
191  template <typename REAL>
192  void lq_internals(mpimat<REAL> &x, cpuvec<REAL> &lqaux, cpuvec<REAL> &work)
193  {
194  const len_t m = x.nrows();
195  const len_t n = x.ncols();
196  const len_t minmn = std::min(m, n);
197 
198  const int *descx = x.desc_ptr();
199 
200  int info = 0;
201  lqaux.resize(minmn);
202 
203  REAL tmp;
204  fml::scalapack::gelqf(m, n, NULL, descx, NULL, &tmp, -1, &info);
205  int lwork = std::max((int) tmp, 1);
206  if (lwork > work.size())
207  work.resize(lwork);
208 
209  fml::scalapack::gelqf(m, n, x.data_ptr(), descx, lqaux.data_ptr(),
210  work.data_ptr(), lwork, &info);
211 
212  if (info != 0)
213  fml::linalgutils::check_info(info, "gelqf");
214  }
215  }
216 
240  template <typename REAL>
241  void lq(mpimat<REAL> &x, cpuvec<REAL> &lqaux)
242  {
243  cpuvec<REAL> work;
244  lq_internals(x, lqaux, work);
245  }
246 
266  template <typename REAL>
267  void lq_L(const mpimat<REAL> &LQ, mpimat<REAL> &L)
268  {
269  err::check_grid(LQ, L);
270 
271  const len_t m = LQ.nrows();
272  const len_t n = LQ.ncols();
273  const len_t minmn = std::min(m, n);
274 
275  L.resize(m, minmn);
276  L.fill_zero();
277 
278  fml::scalapack::lacpy('L', m, n, LQ.data_ptr(), LQ.desc_ptr(), L.data_ptr(),
279  L.desc_ptr());
280  }
281 
303  template <typename REAL>
304  void lq_Q(const mpimat<REAL> &LQ, const cpuvec<REAL> &lqaux, mpimat<REAL> &Q, cpuvec<REAL> &work)
305  {
306  err::check_grid(LQ, Q);
307 
308  const len_t m = LQ.nrows();
309  const len_t n = LQ.ncols();
310  const len_t minmn = std::min(m, n);
311 
312  const int *descLQ = LQ.desc_ptr();
313 
314  Q.resize(minmn, n);
315  const int *descQ = Q.desc_ptr();
316 
317  int info = 0;
318  REAL tmp;
319  fml::scalapack::orglq(minmn, n, minmn, NULL, descLQ, NULL,
320  &tmp, -1, &info);
321 
322  int lwork = (int) tmp;
323  if (lwork > work.size())
324  work.resize(lwork);
325 
326  fml::scalapack::lacpy('A', minmn, n, LQ.data_ptr(), descLQ, Q.data_ptr(),
327  descQ);
328 
329  fml::scalapack::orglq(minmn, n, minmn, Q.data_ptr(), descQ,
330  lqaux.data_ptr(), work.data_ptr(), lwork, &info);
331  fml::linalgutils::check_info(info, "orglq");
332  }
333 }
334 }
335 
336 
337 #endif
fml::mpimat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: mpimat.hh:326
fml::mpimat
Matrix class for data distributed over MPI in the 2-d block cyclic format.
Definition: mpimat.hh:40
fml::univec::data_ptr
T * data_ptr()
Pointer to the internal array.
Definition: univec.hh:28
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::cpuvec::resize
void resize(len_t size)
Resize the internal object storage.
Definition: cpuvec.hh:210
fml::linalg::qr
void qr(const bool pivot, cpumat< REAL > &x, cpuvec< REAL > &qraux)
Computes the QR decomposition.
Definition: qr.hh:94
fml::cpuvec
Vector class for data held on a single CPU.
Definition: cpuvec.hh:31
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml::linalg::lq_L
void lq_L(const cpumat< REAL > &LQ, cpumat< REAL > &L)
Recover the L matrix from an LQ decomposition.
Definition: qr.hh:247
fml
Core namespace.
Definition: dimops.hh:10
fml::mpimat::fill_zero
void fill_zero()
Set all values to zero.
Definition: mpimat.hh:565
fml::linalg::qr_Q
void qr_Q(const cpumat< REAL > &QR, const cpuvec< REAL > &qraux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from a QR decomposition.
Definition: qr.hh:120
fml::linalg::qr_R
void qr_R(const cpumat< REAL > &QR, cpumat< REAL > &R)
Recover the R matrix from a QR decomposition.
Definition: qr.hh:162
fml::linalg::lq
void lq(cpumat< REAL > &x, cpuvec< REAL > &lqaux)
Computes the LQ decomposition.
Definition: qr.hh:223
fml::linalg::lq_Q
void lq_Q(const cpumat< REAL > &LQ, const cpuvec< REAL > &lqaux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from an LQ decomposition.
Definition: qr.hh:279