fml  0.1-0
Fused Matrix Library
qr.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_PAR_GPU_LINALG_QR_H
6 #define FML_PAR_GPU_LINALG_QR_H
7 #pragma once
8 
9 
10 #include "../parmat.hh"
11 
12 #include "../../../gpu/linalg/linalg_invert.hh"
13 #include "../../../gpu/linalg/linalg_qr.hh"
14 
15 #include "../../../gpu/copy.hh"
16 
17 #include "blas.hh"
18 #include "qr_allreduce.hh"
19 
20 
21 namespace fml
22 {
23 namespace linalg
24 {
25  namespace internals
26  {
27  template <typename REAL>
28  void qr_R(const int root, parmat_gpu<REAL> &x, gpumat<REAL> &R,
29  gpumat<REAL> &R_local, gpuvec<REAL> &qraux)
30  {
31  const len_t n = x.ncols();
32 
33  linalg::qr(false, x.data_obj(), qraux);
34  linalg::qr_R(x.data_obj(), R_local);
35 
36  R.resize(n, n);
37  tsqr::qr_allreduce(root, n, n, R_local.data_ptr(), R.data_ptr(),
38  x.get_comm().get_comm(), R.get_card());
39  }
40  }
41 
42 
43 
44  template <typename REAL>
45  void qr_R(const int root, parmat_gpu<REAL> &x, gpumat<REAL> &R)
46  {
47  if (x.nrows() < (len_global_t)x.ncols())
48  throw std::runtime_error("impossible dimensions");
49 
50  gpumat<REAL> R_local(R.get_card());
51  gpuvec<REAL> qraux(R.get_card());
52 
53  internals::qr_R(root, x, R, R_local, qraux);
54  }
55 
56 
57 
58  namespace internals
59  {
60  template <typename REAL>
61  void qr_Q(const parmat_gpu<REAL> &x, parmat_gpu<REAL> &x_cpy,
62  gpumat<REAL> &R, gpumat<REAL> &R_local, gpuvec<REAL> &qraux,
63  parmat_gpu<REAL> &Q)
64  {
65  copy::gpu2gpu(x.data_obj(), x_cpy.data_obj());
66  internals::qr_R(mpi::REDUCE_TO_ALL, x_cpy, R, R_local, qraux);
67  trinv(true, false, R);
68  matmult(x, R, Q);
69  }
70  }
71 
72  template <typename REAL>
73  void qr_Q(parmat_gpu<REAL> &x, gpuvec<REAL> &qraux, parmat_gpu<REAL> &Q)
74  {
75  gpumat<REAL> R, R_local;
76 
77  qr_R(mpi::REDUCE_TO_ALL, x, R, R_local, qraux);
78  trinv(true, false, R);
79  matmult(x, R, Q);
80  }
81 }
82 }
83 
84 
85 #endif
fml::linalg::qr
void qr(const bool pivot, cpumat< REAL > &x, cpuvec< REAL > &qraux)
Computes the QR decomposition.
Definition: qr.hh:94
fml::copy::gpu2gpu
void gpu2gpu(const gpuvec< REAL_IN > &gpu_in, gpuvec< REAL_OUT > &gpu_out)
Copy data from a GPU object to another.
Definition: copy.hh:203
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::qr_Q
void qr_Q(const cpumat< REAL > &QR, const cpuvec< REAL > &qraux, cpumat< REAL > &Q, cpuvec< REAL > &work)
Recover the Q matrix from a QR decomposition.
Definition: qr.hh:120
fml::linalg::trinv
void trinv(const bool upper, const bool unit_diag, cpumat< REAL > &x)
Compute the matrix inverse of a triangular matrix.
Definition: invert.hh:87
fml::linalg::qr_R
void qr_R(const cpumat< REAL > &QR, cpumat< REAL > &R)
Recover the R matrix from a QR decomposition.
Definition: qr.hh:162
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43