5 #ifndef FML_GPU_LINALG_LINALG_EIGEN_H
6 #define FML_GPU_LINALG_LINALG_EIGEN_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpuscalar.hh"
19 #include "../gpumat.hh"
20 #include "../gpuvec.hh"
22 #include "linalg_err.hh"
// eig_sym_internals -- shared driver for the symmetric eigensolver wrappers;
// runs gpulapack::syevd on the contents of x.
//
// NOTE(review): this span is a lossy extraction of the original header. The
// leading integer on each line is the original file's line number, and several
// original lines are absent:
//   - braces;
//   - the guard condition in front of the throw;
//   - the declarations of `n`, `lwork` and `info`;
//   - the branch that chooses between the two `jobz` assignments;
//   - everything after original line 65, including the return statement.
// The comments below describe only the statements that are present.
//
// Parameters (as declared):
//   only_values - presumably true for eigenvalues only and false to also
//                 request eigenvectors; the statement reading it is not
//                 visible here.
//   x           - its buffer (x.data_ptr()) is passed to syevd as the input
//                 matrix with order n and leading dimension n. Per cuSOLVER
//                 syevd semantics that buffer is overwritten (with the
//                 eigenvectors when jobz requests them).
//   values      - values.data_ptr() is passed as syevd's eigenvalue output.
//   vectors     - not referenced by any visible statement; presumably filled
//                 in the missing tail -- confirm.
// Returns: int (declared return type); the return statement is not visible.
// Throws: std::runtime_error("'x' must be a square matrix") under a condition
//         not visible here. Failures are also reported through
//         gpulapack::err::check_ret and fml::linalgutils::check_info, whose
//         exact behavior is defined elsewhere.
31 template <
typename REAL>
32 int eig_sym_internals(
const bool only_values, gpumat<REAL> &x,
33 gpuvec<REAL> &values, gpumat<REAL> &vectors)
36 throw std::runtime_error(
"'x' must be a square matrix");
// The card taken from x provides the LAPACK handle and is passed to the
// constructors of the work buffer and the info scalar below.
38 auto c = x.get_card();
// jobz selects values-only (NOVECTOR) or values-and-vectors (VECTOR). The
// if/else that picks one of the two assignments below is missing from this
// extraction.
43 cusolverEigMode_t jobz;
45 jobz = CUSOLVER_EIG_MODE_NOVECTOR;
47 jobz = CUSOLVER_EIG_MODE_VECTOR;
// Phase 1: workspace query. syevd_buflen writes the required workspace size
// into lwork. GPUBLAS_FILL_L presumably selects the lower triangle of x
// (cuSOLVER CUBLAS_FILL_MODE_LOWER) -- confirm against gpulapack.
50 gpulapack_status_t check = gpulapack::syevd_buflen(c->lapack_handle(), jobz,
51 GPUBLAS_FILL_L, n, x.data_ptr(), n, values.data_ptr(), &lwork);
52 gpulapack::err::check_ret(check,
"syevd_bufferSize");
// Allocate the lwork-element workspace on card c.
54 gpuvec<REAL> work(c, lwork);
// Scalar built on card c from info; its data_ptr() is where syevd writes its
// status code (presumably device-resident memory -- confirm in gpuscalar).
57 gpuscalar<int> info_device(c, info);
// Phase 2: the actual solve. It uses the same jobz, fill mode and dimensions
// as the workspace query.
59 check = gpulapack::syevd(c->lapack_handle(), jobz, GPUBLAS_FILL_L,
60 n, x.data_ptr(), n, values.data_ptr(), work.data_ptr(), lwork,
61 info_device.data_ptr());
// Copy syevd's info code back into the host-side int. Then validate both the
// API status (check) and the solver's info code.
63 info_device.get_val(&info);
64 gpulapack::err::check_ret(check,
"syevd");
65 fml::linalgutils::check_info(info,
"syevd");
// Values-only symmetric eigensolver wrapper (fragment).
// NOTE(review): the function signature (original lines 100-101) and the
// declaration of `ignored` (original lines 103-104) are missing from this
// extraction.
// Checks x and values with err::check_card (presumably that both live on the
// same card). Then it calls eig_sym_internals with only_values = true.
// `ignored` is passed in the vectors slot, presumably as a placeholder since
// no eigenvectors are requested.
99 template <
typename REAL>
102 err::check_card(x, values);
105 int info = eig_sym_internals(
true, x, values, ignored);
// NOTE(review): eig_sym_internals already runs check_info(info, "syevd") on
// its own info code. If it returns that same value, this second call is
// redundant -- confirm before removing.
106 fml::linalgutils::check_info(info,
"syevd");
// Symmetric eigensolver wrapper returning eigenvalues and eigenvectors
// (fragment).
// NOTE(review): the function signature (original lines 111-112) is missing
// from this extraction.
// Checks x, values and vectors with err::check_card (presumably that all
// three live on the same card). Then it calls eig_sym_internals with
// only_values = false so that eigenvectors are requested too.
110 template <
typename REAL>
// Validate every GPU object handed to the solver. This call previously passed
// `values` twice, so `vectors` was never checked even though it is passed to
// eig_sym_internals below.
113 err::check_card(x, values, vectors);
115 int info = eig_sym_internals(
false, x, values, vectors);
// NOTE(review): eig_sym_internals already runs check_info(info, "syevd") on
// its own info code. If it returns that same value, this second call is
// redundant -- confirm before removing.
116 fml::linalgutils::check_info(info,
"syevd");