fml  0.1-0
Fused Matrix Library
qr_allreduce.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_PAR_CPU_LINALG_QR_ALLREDUCE_H
6 #define FML_PAR_CPU_LINALG_QR_ALLREDUCE_H
7 #pragma once
8 
9 
10 #include "../../../_internals/arraytools/src/arraytools.hpp"
11 #include "../../../_internals/omp.hh"
12 #include "../../../_internals/restrict.hh"
13 
14 #include "../../internals/mpi_utils.hh"
15 
16 #include "../../../cpu/linalg/lapack.hh"
17 #include "../../../cpu/cpumat.hh"
18 
19 
20 namespace fml
21 {
22 namespace tsqr
23 {
24  namespace internals
25  {
// Shared state for the QR reduction. It lives at namespace scope because
// custom_op_qr() is handed to MPI as an MPI_User_function, whose signature
// has no parameter for passing extra context.

// Set to true by custom_op_qr() when geqrf reports a nonzero info.
inline bool badinfo = false;

// Dimensions of each local matrix (_m x _n), min(_m, _n), and the row count
// of the stacked matrix used during the reduction (mtb = 2*_m).
inline int _m = 0;
inline int _n = 0;
inline int minmn = 0;
inline int mtb = 0;

// Length of the work<REAL> buffer, as obtained from qrworksize().
inline int lwork = 0;

// Column-major mtb x _n buffer holding the two stacked reduction operands.
template <typename REAL>
inline REAL *tallboy = nullptr;

// geqrf workspace of length lwork.
template <typename REAL>
inline REAL *work = nullptr;

// Householder scalars (the TAU argument) written by geqrf.
template <typename REAL>
inline REAL *qraux = nullptr;
36 
37 
38 
39  template <typename REAL>
40  void qr_global_cleanup()
41  {
42  arraytools::free(tallboy<REAL>);
43  tallboy<REAL> = NULL;
44 
45  arraytools::free(work<REAL>);
46  work<REAL> = NULL;
47 
48  arraytools::free(qraux<REAL>);
49  qraux<REAL> = NULL;
50  }
51 
52 
53 
54  template <typename REAL>
55  static inline int qrworksize(const int m, const int n)
56  {
57  REAL tmp;
58 
59  int info;
60  fml::lapack::geqrf(m, n, NULL, m, NULL, &tmp, -1, &info);
61  int lwork = (int) tmp;
62 
63  return std::max(lwork, 1);
64  }
65 
66 
67 
68  template <typename REAL>
69  void qr_global_init(int m, int n)
70  {
71  _m = m;
72  _n = n;
73  minmn = std::min(_m, _n);
74  mtb = 2*_m;
75 
76  badinfo = false;
77 
78  arraytools::alloc(mtb*_n, &(tallboy<REAL>));
79  lwork = qrworksize<REAL>(mtb, _n);
80  arraytools::alloc(lwork, &(work<REAL>));
81  arraytools::alloc(minmn, &(qraux<REAL>));
82 
83  arraytools::check_alloc(tallboy<REAL>, work<REAL>, qraux<REAL>);
84  }
85 
86 
87 
88  template <typename REAL>
89  void custom_op_qr(void *a_, void *b_, int *len, MPI_Datatype *dtype)
90  {
91  (void)len;
92  (void)dtype;
93 
94  REAL *a = (REAL*)a_;
95  REAL *b = (REAL*)b_;
96 
97  #pragma omp parallel for default(shared) if(_m*_n > omp::OMP_MIN_SIZE)
98  for (int j=0; j<_n; j++)
99  {
100  #pragma omp simd
101  for (int i=0; i<_m; i++)
102  tallboy<REAL>[i + mtb*j] = a[i + _m*j];
103 
104  #pragma omp simd
105  for (int i=0; i<_m; i++)
106  tallboy<REAL>[_m+i + mtb*j] = b[i + _m*j];
107  }
108 
109  int info = 0;
110  fml::lapack::geqrf(mtb, _n, tallboy<REAL>, mtb, qraux<REAL>, work<REAL>, lwork, &info);
111  if (info != 0)
112  badinfo = true;
113 
114  for (int j=0; j<_n; j++)
115  {
116  #pragma omp for simd
117  for (int i=0; i<=j; i++)
118  b[i + _m*j] = tallboy<REAL>[i + mtb*j];
119 
120  #pragma omp for simd
121  for (int i=j+1; i<_m; i++)
122  b[i + _m*j] = (REAL) 0.f;
123  }
124  }
125  }
126 
127 
128 
129  template <typename REAL>
130  void qr_allreduce(const int root, const int m, const int n,
131  const REAL *const restrict a, REAL *const restrict b, MPI_Comm comm)
132  {
133  int mpi_ret;
134 
135  internals::qr_global_init<REAL>(m, n);
136 
137  // custom data type
138  MPI_Datatype mat_type;
139  mpi::contig_type(m*n, a, &mat_type);
140 
141  // custom op + reduce
142  MPI_Op op;
143  const int commutative = 1;
144 
145  MPI_Op_create((MPI_User_function*) internals::custom_op_qr<REAL>, commutative, &op);
146  if (root == mpi::REDUCE_TO_ALL)
147  mpi_ret = MPI_Allreduce(a, b, 1, mat_type, op, comm);
148  else
149  mpi_ret = MPI_Reduce(a, b, 1, mat_type, op, root, comm);
150 
151  // cleanup and return
152  MPI_Op_free(&op);
153  MPI_Type_free(&mat_type);
154 
155  internals::qr_global_cleanup<REAL>();
156 
157  mpi::check_MPI_ret(mpi_ret);
158  if (internals::badinfo)
159  throw std::runtime_error("unrecoverable error with LAPACK function geqrf() occurred during reduction");
160  }
161 }
162 }
163 
164 
165 #endif
fml
Core namespace.
Definition: dimops.hh:10