5 #ifndef FML_MPI_LINALG_QR_H
6 #define FML_MPI_LINALG_QR_H
10 #include "../../_internals/linalgutils.hh"
11 #include "../../cpu/cpuvec.hh"
14 #include "../mpimat.hh"
16 #include "internals/err.hh"
17 #include "internals/scalapack.hh"
// Shared driver for the QR factorization of the distributed matrix `x`.
// From the visible calls: the ScaLAPACK wrappers geqpf / geqrf are run on x's
// data using x's descriptor, with `qraux` and `work` passed as the auxiliary
// and workspace buffers.
// NOTE(review): this excerpt omits some original lines (the braces, the
// declarations of `tmp`, `info` and `p`, and whatever branches on `pivot`).
// Presumably `pivot` selects geqpf (pivoted QR) over geqrf -- confirm against
// the full file.
26 template <
typename REAL>
27 void qr_internals(
const bool pivot, mpimat<REAL> &x, cpuvec<REAL> &qraux, cpuvec<REAL> &work)
// Dimensions of x and the smaller of the two.
29 const len_t m = x.nrows();
30 const len_t n = x.ncols();
31 const len_t minmn = std::min(m, n);
// Descriptor of x, handed to every ScaLAPACK wrapper call below.
33 const int *descx = x.desc_ptr();
// Calls made with NULL buffers and lwork = -1; the value they write to `tmp`
// is what sizes the workspace below. Presumably the usual LAPACK/ScaLAPACK
// workspace-size query convention -- confirm against the wrapper.
40 fml::scalapack::geqpf(m, n, NULL, descx, NULL, NULL, &tmp, -1, &info);
42 fml::scalapack::geqrf(m, n, NULL, descx, NULL, &tmp, -1, &info);
// Requested workspace length: `tmp` converted to int, clamped to at least 1.
44 int lwork = std::max((
int) tmp, 1);
// True when `work` is smaller than the requested length. The guarded statement
// is not part of this excerpt (presumably it enlarges `work`).
45 if (lwork > work.size())
// The factorization calls proper: same shape as the calls above, now with the
// real data, qraux and workspace buffers. geqpf additionally receives `p`,
// which is declared outside this excerpt.
52 fml::scalapack::geqpf(m, n, x.data_ptr(), descx, p.data_ptr(),
53 qraux.data_ptr(), work.data_ptr(), lwork, &info);
56 fml::scalapack::geqrf(m, n, x.data_ptr(), descx, qraux.data_ptr(),
57 work.data_ptr(), lwork, &info);
// Hand `info` to check_info together with the name of the routine that set it.
62 fml::linalgutils::check_info(info,
"geqpf");
64 fml::linalgutils::check_info(info,
"geqrf");
// NOTE(review): the signature of this template is not part of this excerpt, so
// its name and parameter list cannot be documented from here (presumably the
// public QR entry point -- confirm against the full file).
94 template <
typename REAL>
// The only visible statement: forward pivot, x, qraux and work unchanged to
// qr_internals.
98 qr_internals(pivot, x, qraux, work);
// NOTE(review): the signature of this template is not part of this excerpt.
// Judging only from the visible body, it takes a factored matrix `QR`, the
// `qraux` vector, an output matrix `Q` and a `work` buffer, and runs the
// ScaLAPACK wrapper orgqr on Q's data (presumably to form the explicit Q
// factor -- confirm against the full file).
122 template <
typename REAL>
// Passes QR and Q to err::check_grid before anything else (presumably a check
// that both matrices are distributed on the same process grid).
125 err::check_grid(QR, Q);
// Dimensions of QR and the smaller of the two.
127 const len_t m = QR.
nrows();
128 const len_t n = QR.
ncols();
129 const len_t minmn = std::min(m, n);
// Descriptors of the input and of the output matrix.
131 const int *descQR = QR.desc_ptr();
134 const int *descQ = Q.desc_ptr();
// Call with NULL buffers; its remaining arguments are cut off in this excerpt.
// `tmp` is read right after it, so presumably this is the workspace-size
// query.
138 fml::scalapack::orgqr(m, minmn, minmn, NULL, descQR, NULL,
// NOTE(review): unlike qr_internals, the requested length is not clamped to at
// least 1 here.
141 int lwork = (int) tmp;
// True when `work` is smaller than requested; the guarded statement is not part
// of this excerpt (presumably it enlarges `work`).
142 if (lwork > work.size())
// orgqr on Q's data with qraux and the workspace.
// NOTE(review): Q's data is paired with QR's descriptor (descQR) here, and
// descQ has no visible use in this excerpt; the parallel orglq call further
// down in this file pairs Q's data with descQ. Confirm which descriptor is
// intended.
148 fml::scalapack::orgqr(m, minmn, minmn, Q.
data_ptr(), descQR,
149 qraux.data_ptr(), work.data_ptr(), lwork, &info);
// Hand `info` to check_info together with the routine name.
150 fml::linalgutils::check_info(info,
"orgqr");
// NOTE(review): the signature of this template is not part of this excerpt.
// Judging only from the visible body, it takes a factored matrix `QR` and an
// output matrix `R`, and copies from QR's data into R's data with lacpy.
172 template <
typename REAL>
// Passes QR and R to err::check_grid before anything else (presumably a check
// that both matrices are distributed on the same process grid).
175 err::check_grid(QR, R);
// Dimensions of QR and the smaller of the two; minmn has no visible use in this
// excerpt, but lines are missing, so it may be used there.
177 const len_t m = QR.
nrows();
178 const len_t n = QR.
ncols();
179 const len_t minmn = std::min(m, n);
// lacpy with the 'U' selector, source QR and destination R (presumably the
// upper-triangular part, per the LAPACK lacpy convention). The trailing
// arguments of this call are cut off in this excerpt.
183 fml::scalapack::lacpy(
'U', m, n, QR.
data_ptr(), QR.desc_ptr(), R.
data_ptr(),
// NOTE(review): the signature of this template is not part of this excerpt
// (presumably lq_internals, which is called with (x, lqaux, work) elsewhere in
// this file -- confirm). The visible body sizes a workspace for the ScaLAPACK
// wrapper gelqf on matrix `x` and checks the resulting `info`.
191 template <
typename REAL>
// Dimensions of x and the smaller of the two.
194 const len_t m = x.
nrows();
195 const len_t n = x.
ncols();
196 const len_t minmn = std::min(m, n);
// Descriptor of x, passed to the gelqf wrapper.
198 const int *descx = x.desc_ptr();
// Call with NULL buffers and lwork = -1; the value written to `tmp` sizes the
// workspace below (presumably the standard workspace-size query convention).
204 fml::scalapack::gelqf(m, n, NULL, descx, NULL, &tmp, -1, &info);
// Requested workspace length: `tmp` converted to int, clamped to at least 1.
205 int lwork = std::max((
int) tmp, 1);
// True when `work` is smaller than requested; the guarded statement is not part
// of this excerpt (presumably it enlarges `work`).
206 if (lwork > work.size())
// Tail of a call whose first line is missing from this excerpt (presumably the
// gelqf call on the real data).
210 work.data_ptr(), lwork, &info);
// Hand `info` to check_info together with the routine name.
213 fml::linalgutils::check_info(info,
"gelqf");
// NOTE(review): the signature of this template is not part of this excerpt, so
// its name and parameter list cannot be documented from here (presumably the
// public LQ entry point -- confirm against the full file).
240 template <
typename REAL>
// The only visible statement: forward x, lqaux and work unchanged to
// lq_internals.
244 lq_internals(x, lqaux, work);
// NOTE(review): the signature of this template is not part of this excerpt.
// Judging only from the visible body, it takes a factored matrix `LQ` and an
// output matrix `L`, and copies from LQ's data into L's data with lacpy.
266 template <
typename REAL>
// Passes LQ and L to err::check_grid before anything else (presumably a check
// that both matrices are distributed on the same process grid).
269 err::check_grid(LQ, L);
// Dimensions of LQ and the smaller of the two; minmn has no visible use in this
// excerpt, but lines are missing, so it may be used there.
271 const len_t m = LQ.
nrows();
272 const len_t n = LQ.
ncols();
273 const len_t minmn = std::min(m, n);
// lacpy with the 'L' selector, source LQ and destination L (presumably the
// lower-triangular part, per the LAPACK lacpy convention). The trailing
// arguments of this call are cut off in this excerpt.
278 fml::scalapack::lacpy(
'L', m, n, LQ.
data_ptr(), LQ.desc_ptr(), L.
data_ptr(),
// NOTE(review): the signature of this template is not part of this excerpt.
// Judging only from the visible body, it takes a factored matrix `LQ`, the
// `lqaux` vector, an output matrix `Q` and a `work` buffer, and runs the
// ScaLAPACK wrapper orglq on Q's data (presumably to form the explicit Q
// factor -- confirm against the full file).
303 template <
typename REAL>
// Passes LQ and Q to err::check_grid before anything else (presumably a check
// that both matrices are distributed on the same process grid).
306 err::check_grid(LQ, Q);
// Dimensions of LQ and the smaller of the two.
308 const len_t m = LQ.
nrows();
309 const len_t n = LQ.
ncols();
310 const len_t minmn = std::min(m, n);
// Descriptors of the input and of the output matrix.
312 const int *descLQ = LQ.desc_ptr();
315 const int *descQ = Q.desc_ptr();
// Call with NULL buffers and LQ's descriptor; its remaining arguments are cut
// off in this excerpt. `tmp` is read right after it, so presumably this is the
// workspace-size query.
319 fml::scalapack::orglq(minmn, n, minmn, NULL, descLQ, NULL,
// NOTE(review): the requested length is not clamped to at least 1 here, unlike
// the std::max(..., 1) used in the factorization drivers of this file.
322 int lwork = (int) tmp;
// True when `work` is smaller than requested; the guarded statement is not part
// of this excerpt (presumably it enlarges `work`).
323 if (lwork > work.size())
// orglq on Q's data with Q's own descriptor, lqaux and the workspace.
329 fml::scalapack::orglq(minmn, n, minmn, Q.
data_ptr(), descQ,
330 lqaux.
data_ptr(), work.data_ptr(), lwork, &info);
// Hand `info` to check_info together with the routine name.
331 fml::linalgutils::check_info(info,
"orglq");