5 #ifndef FML_GPU_LINALG_CHOL_H
6 #define FML_GPU_LINALG_CHOL_H
10 #include "../../_internals/linalgutils.hh"
12 #include "../arch/arch.hh"
14 #include "../internals/gpu_utils.hh"
15 #include "../internals/gpuscalar.hh"
17 #include "../gpumat.hh"
43 template <
typename REAL>
46 const len_t n = x.
nrows();
48 throw std::runtime_error(
"'x' must be a square matrix");
50 auto c = x.get_card();
51 const auto fill = GPUBLAS_FILL_L;
54 gpulapack_status_t check = gpulapack::potrf_buflen(c->lapack_handle(), fill, n,
56 gpulapack::err::check_ret(check,
"potrf_bufferSize");
62 check = gpulapack::potrf(c->lapack_handle(), fill, n, x.
data_ptr(), n,
63 work.data_ptr(), lwork, info_device.data_ptr());
65 info_device.get_val(&info);
66 gpulapack::err::check_ret(check,
"potrf");
68 fml::linalgutils::check_info(info,
"potrf");
70 throw std::runtime_error(
"chol: leading minor of order " + std::to_string(info) +
" is not positive definite");
72 fml::gpu_utils::tri2zero(
'U',
false, n, n, x.
data_ptr(), n);