fml  0.1-0
Fused Matrix Library
chol.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_CHOL_H
6 #define FML_GPU_LINALG_CHOL_H
7 #pragma once
8 
9 
10 #include "../../_internals/linalgutils.hh"
11 
12 #include "../arch/arch.hh"
13 
14 #include "../internals/gpu_utils.hh"
15 #include "../internals/gpuscalar.hh"
16 
17 #include "../gpumat.hh"
18 
19 
20 namespace fml
21 {
22 namespace linalg
23 {
43  template <typename REAL>
44  void chol(gpumat<REAL> &x)
45  {
46  const len_t n = x.nrows();
47  if (n != x.ncols())
48  throw std::runtime_error("'x' must be a square matrix");
49 
50  auto c = x.get_card();
51  const auto fill = GPUBLAS_FILL_L;
52 
53  int lwork;
54  gpulapack_status_t check = gpulapack::potrf_buflen(c->lapack_handle(), fill, n,
55  x.data_ptr(), n, &lwork);
56  gpulapack::err::check_ret(check, "potrf_bufferSize");
57 
58  gpuvec<REAL> work(c, lwork);
59 
60  int info = 0;
61  gpuscalar<int> info_device(c, info);
62  check = gpulapack::potrf(c->lapack_handle(), fill, n, x.data_ptr(), n,
63  work.data_ptr(), lwork, info_device.data_ptr());
64 
65  info_device.get_val(&info);
66  gpulapack::err::check_ret(check, "potrf");
67  if (info < 0)
68  fml::linalgutils::check_info(info, "potrf");
69  else if (info > 0)
70  throw std::runtime_error("chol: leading minor of order " + std::to_string(info) + " is not positive definite");
71 
72  fml::gpu_utils::tri2zero('U', false, n, n, x.data_ptr(), n);
73  }
74 }
75 }
76 
77 
78 #endif
fml::gpuvec
Vector class for data held on a single GPU.
Definition: gpuvec.hh:32
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::chol
void chol(cpumat< REAL > &x)
Compute the Choleski factorization.
Definition: chol.hh:46
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35
fml::gpuscalar
Definition: gpuscalar.hh:16