5 #ifndef FML_GPU_LINALG_LINALG_CHOL_H
6 #define FML_GPU_LINALG_LINALG_CHOL_H
12 #include "../../_internals/linalgutils.hh"
14 #include "../arch/arch.hh"
16 #include "../internals/gpu_utils.hh"
17 #include "../internals/gpuscalar.hh"
19 #include "../gpumat.hh"
21 #include "linalg_err.hh"
47 template <
typename REAL>
50 const len_t n = x.
nrows();
52 throw std::runtime_error(
"'x' must be a square matrix");
54 auto c = x.get_card();
55 const auto fill = GPUBLAS_FILL_L;
58 gpulapack_status_t check = gpulapack::potrf_buflen(c->lapack_handle(), fill, n,
60 gpulapack::err::check_ret(check,
"potrf_bufferSize");
66 check = gpulapack::potrf(c->lapack_handle(), fill, n, x.
data_ptr(), n,
67 work.data_ptr(), lwork, info_device.data_ptr());
69 info_device.get_val(&info);
70 gpulapack::err::check_ret(check,
"potrf");
72 fml::linalgutils::check_info(info,
"potrf");
74 throw std::runtime_error(
"chol: leading minor of order " + std::to_string(info) +
" is not positive definite");
76 fml::gpu_utils::tri2zero(
'U',
false, n, n, x.
data_ptr(), n);