5 #ifndef FML_PAR_CPU_LINALG_QR_ALLREDUCE_H
6 #define FML_PAR_CPU_LINALG_QR_ALLREDUCE_H
10 #include "../../../_internals/arraytools/src/arraytools.hpp"
11 #include "../../../_internals/omp.hh"
12 #include "../../../_internals/restrict.hh"
14 #include "../../internals/mpi_utils.hh"
16 #include "../../../cpu/linalg/lapack.hh"
17 #include "../../../cpu/cpumat.hh"
// Module-level scratch state shared by qr_global_init(), the MPI reduction
// callback custom_op_qr(), and qr_global_cleanup(). An MPI_User_function
// receives no user-data pointer, which is presumably why this state lives in
// globals instead of being passed explicitly. Nothing here is synchronized,
// so at most one qr_allreduce() may be in flight per process.

// Set by custom_op_qr() when geqrf() reports a nonzero info.
inline bool badinfo = false;

// Local block dimensions (m x n), min(m, n), and the row count of the stacked
// panel that custom_op_qr() builds from two m-row blocks.
inline int _m = 0, _n = 0, minmn = 0, mtb = 0;

// Length (in elements) of the geqrf() workspace held in work<REAL>.
inline int lwork = 0;

// Per-precision scratch buffers, allocated by qr_global_init<REAL>() and
// released by qr_global_cleanup<REAL>().
template <typename REAL>
inline REAL *tallboy = nullptr;
template <typename REAL>
inline REAL *work = nullptr;
template <typename REAL>
inline REAL *qraux = nullptr;
39 template <
typename REAL>
40 void qr_global_cleanup()
42 arraytools::free(tallboy<REAL>);
45 arraytools::free(work<REAL>);
48 arraytools::free(qraux<REAL>);
54 template <
typename REAL>
55 static inline int qrworksize(
const int m,
const int n)
60 fml::lapack::geqrf(m, n, NULL, m, NULL, &tmp, -1, &info);
61 int lwork = (int) tmp;
63 return std::max(lwork, 1);
68 template <
typename REAL>
69 void qr_global_init(
int m,
int n)
73 minmn = std::min(_m, _n);
78 arraytools::alloc(mtb*_n, &(tallboy<REAL>));
79 lwork = qrworksize<REAL>(mtb, _n);
80 arraytools::alloc(lwork, &(work<REAL>));
81 arraytools::alloc(minmn, &(qraux<REAL>));
83 arraytools::check_alloc(tallboy<REAL>, work<REAL>, qraux<REAL>);
88 template <
typename REAL>
89 void custom_op_qr(
void *a_,
void *b_,
int *len, MPI_Datatype *dtype)
97 #pragma omp parallel for default(shared) if(_m*_n > omp::OMP_MIN_SIZE)
98 for (
int j=0; j<_n; j++)
101 for (
int i=0; i<_m; i++)
102 tallboy<REAL>[i + mtb*j] = a[i + _m*j];
105 for (
int i=0; i<_m; i++)
106 tallboy<REAL>[_m+i + mtb*j] = b[i + _m*j];
110 fml::lapack::geqrf(mtb, _n, tallboy<REAL>, mtb, qraux<REAL>, work<REAL>, lwork, &info);
114 for (
int j=0; j<_n; j++)
117 for (
int i=0; i<=j; i++)
118 b[i + _m*j] = tallboy<REAL>[i + mtb*j];
121 for (
int i=j+1; i<_m; i++)
122 b[i + _m*j] = (REAL) 0.f;
129 template <
typename REAL>
130 void qr_allreduce(
const int root,
const int m,
const int n,
131 const REAL *
const restrict a, REAL *
const restrict b, MPI_Comm comm)
135 internals::qr_global_init<REAL>(m, n);
138 MPI_Datatype mat_type;
139 mpi::contig_type(m*n, a, &mat_type);
143 const int commutative = 1;
145 MPI_Op_create((MPI_User_function*) internals::custom_op_qr<REAL>, commutative, &op);
146 if (root == mpi::REDUCE_TO_ALL)
147 mpi_ret = MPI_Allreduce(a, b, 1, mat_type, op, comm);
149 mpi_ret = MPI_Reduce(a, b, 1, mat_type, op, root, comm);
153 MPI_Type_free(&mat_type);
155 internals::qr_global_cleanup<REAL>();
157 mpi::check_MPI_ret(mpi_ret);
158 if (internals::badinfo)
159 throw std::runtime_error(
"unrecoverable error with LAPACK function geqrf() occurred during reduction");