5 #ifndef FML_GPU_LINALG_TRACE_H
6 #define FML_GPU_LINALG_TRACE_H
10 #include "../arch/arch.hh"
12 #include "../internals/gpuscalar.hh"
14 #include "../gpumat.hh"
/**
 * Device kernel: accumulate the sum of the diagonal entries of a matrix
 * into a single device-side scalar.
 *
 * The matrix is addressed as data[i + m*i] for diagonal element i, i.e. a
 * column-major layout with m rows and no padding between columns.
 *
 * Launch contract: the kernel derives a row index from the x dimension and a
 * column index from the y dimension, and only threads with row == column
 * contribute. Each diagonal element is therefore added exactly once as long
 * as the launch covers every (row, column) pair of the m x n index space
 * exactly once.
 *
 * @param m    Number of rows (also used as the column stride).
 * @param n    Number of columns.
 * @param data Device pointer to the matrix entries (read only).
 * @param tr   Device pointer to the accumulator. The kernel only adds to it,
 *             so the caller must set *tr to its starting value (normally 0)
 *             before the launch.
 *
 * NOTE: atomicAdd must have an overload for REAL; for double this needs
 * compute capability 6.0 or newer.
 */
template <typename REAL>
__global__ void kernel_trace(const len_t m, const len_t n, const REAL *data, REAL *tr)
{
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
  const int j = blockDim.y*blockIdx.y + threadIdx.y;
  
  // Only in-bounds threads sitting on the diagonal contribute.
  if (i < m && j < n && i == j)
  {
    // Form the offset in size_t: i + m*i exceeds the range of a 32-bit
    // integer once m is larger than about 46340.
    const size_t diag = (size_t)i + (size_t)m * (size_t)i;
    atomicAdd(tr, data[diag]);
  }
}
45 template <
typename REAL>
48 const len_t m = x.
nrows();
49 const len_t n = x.
ncols();
50 auto c = x.get_card();
55 kernel_trace<<<x.get_griddim(), x.get_blockdim()>>>(m, n, x.
data_ptr(),