5 #ifndef FML_MPI_LINALG_LINALG_QR_H
6 #define FML_MPI_LINALG_LINALG_QR_H
12 #include "../../_internals/linalgutils.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/bcutils.hh"
16 #include "../internals/mpi_utils.hh"
19 #include "../mpimat.hh"
21 #include "linalg_blas.hh"
22 #include "linalg_err.hh"
23 #include "scalapack.hh"
32 template <
typename REAL>
33 void qr_internals(
const bool pivot, mpimat<REAL> &x, cpuvec<REAL> &qraux, cpuvec<REAL> &work)
35 const len_t m = x.nrows();
36 const len_t n = x.ncols();
37 const len_t minmn = std::min(m, n);
39 const int *descx = x.desc_ptr();
46 fml::scalapack::geqpf(m, n, NULL, descx, NULL, NULL, &tmp, -1, &info);
48 fml::scalapack::geqrf(m, n, NULL, descx, NULL, &tmp, -1, &info);
50 int lwork = std::max((
int) tmp, 1);
51 if (lwork > work.size())
58 fml::scalapack::geqpf(m, n, x.data_ptr(), descx, p.data_ptr(),
59 qraux.data_ptr(), work.data_ptr(), lwork, &info);
62 fml::scalapack::geqrf(m, n, x.data_ptr(), descx, qraux.data_ptr(),
63 work.data_ptr(), lwork, &info);
68 fml::linalgutils::check_info(info,
"geqpf");
70 fml::linalgutils::check_info(info,
"geqrf");
100 template <
typename REAL>
104 qr_internals(pivot, x, qraux, work);
128 template <
typename REAL>
131 err::check_grid(QR, Q);
133 const len_t m = QR.
nrows();
134 const len_t n = QR.
ncols();
135 const len_t minmn = std::min(m, n);
137 const int *descQR = QR.desc_ptr();
140 const int *descQ = Q.desc_ptr();
144 fml::scalapack::orgqr(m, minmn, minmn, NULL, descQR, NULL,
147 int lwork = (int) tmp;
148 if (lwork > work.size())
154 fml::scalapack::orgqr(m, minmn, minmn, Q.
data_ptr(), descQR,
155 qraux.data_ptr(), work.data_ptr(), lwork, &info);
156 fml::linalgutils::check_info(info,
"orgqr");
178 template <
typename REAL>
181 err::check_grid(QR, R);
183 const len_t m = QR.
nrows();
184 const len_t n = QR.
ncols();
185 const len_t minmn = std::min(m, n);
189 fml::scalapack::lacpy(
'U', m, n, QR.
data_ptr(), QR.desc_ptr(), R.
data_ptr(),
197 template <
typename REAL>
200 const len_t m = x.
nrows();
201 const len_t n = x.
ncols();
202 const len_t minmn = std::min(m, n);
204 const int *descx = x.desc_ptr();
210 fml::scalapack::gelqf(m, n, NULL, descx, NULL, &tmp, -1, &info);
211 int lwork = std::max((
int) tmp, 1);
212 if (lwork > work.size())
216 work.data_ptr(), lwork, &info);
219 fml::linalgutils::check_info(info,
"gelqf");
246 template <
typename REAL>
250 lq_internals(x, lqaux, work);
272 template <
typename REAL>
275 err::check_grid(LQ, L);
277 const len_t m = LQ.
nrows();
278 const len_t n = LQ.
ncols();
279 const len_t minmn = std::min(m, n);
284 fml::scalapack::lacpy(
'L', m, n, LQ.
data_ptr(), LQ.desc_ptr(), L.
data_ptr(),
309 template <
typename REAL>
312 err::check_grid(LQ, Q);
314 const len_t m = LQ.
nrows();
315 const len_t n = LQ.
ncols();
316 const len_t minmn = std::min(m, n);
318 const int *descLQ = LQ.desc_ptr();
321 const int *descQ = Q.desc_ptr();
325 fml::scalapack::orglq(minmn, n, minmn, NULL, descLQ, NULL,
328 int lwork = (int) tmp;
329 if (lwork > work.size())
335 fml::scalapack::orglq(minmn, n, minmn, Q.
data_ptr(), descQ,
336 lqaux.
data_ptr(), work.data_ptr(), lwork, &info);
337 fml::linalgutils::check_info(info,
"orglq");