5 #ifndef FML_GPU_LINALG_DET_H
6 #define FML_GPU_LINALG_DET_H
12 #include "../arch/arch.hh"
14 #include "../internals/gpuscalar.hh"
16 #include "../gpuvec.hh"
17 #include "../gpumat.hh"
// Kernel (partial view): rewrites every entry of `ipiv` as a +1/-1 marker and
// accumulates the markers into `*sgn`, then collapses `*sgn` to 1 or -1.
// The braces and any guarding conditions of this kernel are not visible in
// this excerpt, so the notes below describe only the statements shown.
//
//   n    - length argument; its use is not visible here
//          (presumably the bound for `i` -- TODO confirm).
//   ipiv - overwritten in place: -1 where ipiv[i] != i+1, otherwise 1.
//          Presumably the 1-based pivot vector from an LU factorization --
//          TODO confirm against the caller.
//   sgn  - single int updated with atomicAdd and then rewritten by parity.
28 static __global__
void kernel_lu_pivot_sgn(
const len_t n,
int *ipiv,
int *sgn)
// Flat global thread index built from the x dimension only.
30 int i = blockDim.x*blockIdx.x + threadIdx.x;
// Compare the stored entry against its 1-based position; a mismatch is
// recorded as -1, a match as 1. The original value is lost after this line.
34 ipiv[i] = (ipiv[i] != (i+1) ? -1 : 1);
// Atomic add, since all participating threads update the same address.
35 atomicAdd(sgn, ipiv[i]);
// Collapse the accumulated value: even -> 1, odd -> -1.
// NOTE(review): a sum of k terms that are each +1 or -1 has the same parity
// as k (plus the initial value of *sgn), so the parity of this sum does not
// by itself reveal how many entries were -1 -- verify the intended sign logic.
// NOTE(review): the condition guarding this statement is not visible here.
// Nothing shown orders it after the atomicAdd calls issued by other thread
// blocks -- confirm there is no cross-block race on *sgn.
38 (*sgn) = ((*sgn)%2 == 0 ? 1 : -1);
// Kernel template (partial view): selects the diagonal positions (i == j) of
// an m-by-n index space over `x`. The work done for each selected entry is
// not visible in this excerpt, nor are the kernel's braces.
//
//   REAL - element type of `x` and `mod`.
//   m, n - extents used to bound the i and j indices.
//   x    - read-only input data.
//   mod  - output pointer; how it is written is not visible here
//          (presumably receives the determinant's modulus -- TODO confirm).
//   sgn  - single int, rewritten to 1 or -1 by parity at the end.
42 template <
typename REAL>
43 __global__
void kernel_det_mod(
const len_t m,
const len_t n,
const REAL *x, REAL *mod,
int *sgn)
// Global 2-D thread coordinates: i from the x dimension, j from the y dimension.
45 int i = blockDim.x*blockIdx.x + threadIdx.x;
46 int j = blockDim.y*blockIdx.y + threadIdx.y;
// Only in-range threads that sit on the diagonal pass this test.
48 if (i < m && j < n && i == j)
// This condition holds for the (0,0) thread of every block, not for a single
// thread of the whole grid.
66 if (threadIdx.x == 0 && threadIdx.y == 0)
// Collapse *sgn by parity: even -> 1, odd -> -1.
// NOTE(review): if this statement is governed by the per-block condition
// above, it may execute once per block. The mapping is not idempotent (an
// already-collapsed 1 is odd and would become -1), and nothing visible here
// orders it after updates made by other blocks -- verify.
68 (*sgn) = ((*sgn)%2 == 0 ? 1 : -1);
// Host-side routine (partial view): its signature and much of its body are
// not visible in this excerpt. The visible statements reject an invalid `x`,
// launch the two kernels defined in this file, and read back two scalars.
// NOTE(review): presumably this is the determinant entry point taking a GPU
// matrix `x` and returning `sign` and `modulus` -- confirm against the full file.
93 template <
typename REAL>
// Raises std::runtime_error; the triggering condition is not visible here, but
// the message indicates it fires when `x` is not square.
97 throw std::runtime_error(
"'x' must be a square matrix");
// Handle obtained from `x`; its later use is not visible in this excerpt.
99 auto c = x.get_card();
// Launch the pivot-sign kernel with the launch geometry reported by `p`,
// passing p.size() as the first argument. The remaining arguments are not
// visible here.
124 kernel_lu_pivot_sgn<<<p.get_griddim(), p.get_blockdim()>>>(p.
size(),
// Launch the diagonal kernel with the launch geometry reported by `x`, passing
// its row/column counts, its data pointer, and the two scalar buffers.
// NOTE(review): no launch-error check is visible after either launch --
// confirm whether one exists in the omitted lines.
126 kernel_det_mod<<<x.get_griddim(), x.get_blockdim()>>>(x.
nrows(), x.
ncols(),
127 x.
data_ptr(), modulus_gpu.data_ptr(), sign_gpu.data_ptr());
// Read the results into `sign` and `modulus` (presumably a device-to-host
// copy -- TODO confirm get_val semantics).
129 sign_gpu.get_val(&sign);
130 modulus_gpu.get_val(&modulus);