5 #ifndef FML_TSSVD_GPU_QR_ALLREDUCE_H
6 #define FML_TSSVD_GPU_QR_ALLREDUCE_H
10 #include <fml/gpu/card.hh>
13 #include "../globals.hh"
14 #include "../mpi_utils.hh"
15 #include "../restrict.hh"
// Status value returned by the most recent fml::gpulapack call made in this
// file; it is assigned in qrworksize() and in custom_op_qr().
25 cusolverStatus_t check;
// Kernel launch configuration: written by qr_global_init() and consumed by the
// kernel_stack<<<griddim, blockdim>>> launch in custom_op_qr().
// NOTE(review): these look like non-inline variable definitions inside a
// guarded header; if so, including this header from more than one translation
// unit would produce duplicate definitions -- confirm against the enclosing
// scope and build setup.
26 dim3 griddim, blockdim;
// Releases, through the card handle `c`, the four buffers that
// qr_global_init() obtains with c->mem_alloc(): tallboy<REAL>, work<REAL>,
// qraux<REAL> and info_dev.
// NOTE(review): whether the pointers are reset after being freed is not
// visible in this view -- confirm before calling this twice.
31 template <
typename REAL>
32 void qr_global_cleanup()
// Stacked-matrix buffer (mtb * _n elements of REAL).
34 c->mem_free(tallboy<REAL>);
// geqrf workspace (lwork elements of REAL).
37 c->mem_free(work<REAL>);
// geqrf auxiliary output (minmn elements of REAL).
40 c->mem_free(qraux<REAL>);
// Single int that receives the geqrf info value.
43 c->mem_free(info_dev);
// Asks fml::gpulapack::geqrf_buflen for the workspace length needed to
// factor an m-by-n matrix whose leading dimension is m.
//
// m, n: dimensions forwarded to the buffer-length query.
// Returns the reported length, raised to 1 if the query reported less.
// NOTE(review): `tmp` and `lwork` are declared outside the lines visible here;
// `tmp` presumably only supplies a typed pointer argument -- confirm.
50 template <
typename REAL>
51 static inline int qrworksize(
const int m,
const int n)
// The status is stored in the shared `check` variable.
// NOTE(review): no inspection of `check` is visible in this view -- confirm
// that a failed query is handled.
56 check = fml::gpulapack::geqrf_buflen(c->lapack_handle(), m, n, &tmp, m, &lwork);
// Never report a workspace length below 1.
59 return std::max(lwork, 1);
// Prepares the shared state used by custom_op_qr(): the kernel launch
// configuration and the four buffers obtained from c->mem_alloc().
//
// c_: card handle.  NOTE(review): its assignment to `c` is not visible in this
//     view -- confirm.
// m, n: dimensions passed to the launch-configuration helper.
// NOTE(review): `_m`, `_n` and `mtb` are assigned outside the visible lines;
// presumably _m = m, _n = n, and mtb is the row count of the stacked matrix
// (kernel_stack writes rows 0..m-1 and m..2m-1) -- confirm.
64 template <
typename REAL>
65 void qr_global_init(fml::card_sp_t c_,
int m,
int n)
// Launch configuration later used for kernel_stack.
68 blockdim = fml::kernel_launcher::dim_block2();
69 griddim = fml::kernel_launcher::dim_grid(m, n);
// Element count of the qraux buffer allocated below.
73 minmn = std::min(_m, _n);
// Buffer holding the stacked matrix: mtb * _n elements.  The (size_t) cast
// binds to `mtb`, so the whole product is evaluated in size_t.
78 tallboy<REAL> = (REAL*) c->mem_alloc((
size_t)mtb*_n*
sizeof(REAL));
// Workspace length for geqrf on the mtb-by-_n stacked matrix (at least 1).
80 lwork = qrworksize<REAL>(mtb, _n);
81 work<REAL> = (REAL*) c->mem_alloc((
size_t)lwork*
sizeof(REAL));
// Auxiliary geqrf output, minmn elements.
83 qraux<REAL> = (REAL*) c->mem_alloc((
size_t)minmn*
sizeof(REAL));
// One int; custom_op_qr() zeroes it before geqrf and copies it back to the
// host afterwards.
85 info_dev = (
int*) c->mem_alloc(
sizeof(
int));
// CUDA kernel that writes `a` and `b` into `tallboy`, one element of each per
// thread: element (i, j) of `a` goes to row i of column j, and element (i, j)
// of `b` goes to row m+i of column j.
//
// m:       stride between consecutive columns of a and b, and the row offset
//          at which b's elements are placed.
// n:       not referenced in the visible lines.
// mtb:     stride between consecutive columns of tallboy.
// a, b:    inputs, read at index i + m*j.
// tallboy: output, written at i + mtb*j and m+i + mtb*j.
// NOTE(review): any bounds guard on (i, j) is outside the lines visible here;
// confirm that threads with i >= m or j >= n do not perform the writes.
90 template <
typename REAL>
91 __global__
void kernel_stack(
const len_t m,
const len_t n,
const len_t mtb,
92 const REAL *a,
const REAL *b, REAL *tallboy)
// Thread coordinates: i from the x dimension, j from the y dimension.
94 int i = blockDim.x*blockIdx.x + threadIdx.x;
95 int j = blockDim.y*blockIdx.y + threadIdx.y;
// a occupies rows 0..m-1 of column j, b occupies rows m..2m-1.
99 tallboy[i + mtb*j] = a[i + m*j];
100 tallboy[m+i + mtb*j] = b[i + m*j];
// Reduction operator registered with MPI_Op_create() in qr_allreduce().
// It stacks the two operands into tallboy<REAL>, runs geqrf on the stacked
// matrix, and overwrites the second operand with data copied back from the
// factored buffer.
//
// a_, b_:     operand buffers handed over by MPI.  NOTE(review): the typed
//             pointers `a` and `b` used below are defined outside the visible
//             lines, presumably as casts of a_ and b_ -- confirm.
// len, dtype: not referenced in the visible lines.
// Relies on the state set up by qr_global_init() (_m, _n, mtb, lwork, griddim,
// blockdim and the allocated buffers).
106 template <
typename REAL>
107 void custom_op_qr(
void *a_,
void *b_,
int *len, MPI_Datatype *dtype)
// Fill tallboy with a in rows 0.._m-1 and b in rows _m..2*_m-1.
// NOTE(review): no launch-error check is visible after this line -- confirm.
115 kernel_stack<<<griddim, blockdim>>>(_m, _n, mtb, a, b, tallboy<REAL>);
// Zero the info slot, factor the mtb-by-_n stacked matrix in place, then copy
// the info value back into the host variable `info`.
118 c->mem_set(info_dev, 0,
sizeof(
int));
119 check = fml::gpulapack::geqrf(c->lapack_handle(), mtb, _n, tallboy<REAL>, mtb, qraux<REAL>, work<REAL>, lwork, info_dev);
120 c->mem_gpu2cpu(&info, info_dev,
sizeof(
int));
// NOTE(review): how `info` and `check` are acted on is outside the visible
// lines; qr_allreduce() tests internals::badinfo, which is presumably set
// here -- confirm.
// Zero all _m*_n elements of b, then copy from tallboy (leading dimension mtb)
// into b (leading dimension _m).  The 'U' selector presumably means the upper
// triangle, as in LAPACK lacpy, which would leave the R factor in b -- confirm.
125 c->mem_set(b, 0, (
size_t)_m*_n*
sizeof(REAL));
126 fml::gpu_utils::lacpy(
'U', _m, _n, tallboy<REAL>, mtb, b, _m);
// Reduces the m-by-n matrices held by the ranks of `comm` with the
// custom_op_qr<REAL> operator, using MPI_Allreduce or MPI_Reduce.
//
// root: defs::REDUCE_TO_ALL selects MPI_Allreduce; otherwise it is passed to
//       MPI_Reduce as the root rank.
// m, n: matrix dimensions; m*n is the element count of the MPI datatype.
// a:    send buffer.
// b:    receive buffer.
// comm: communicator used for the reduction.
// NOTE(review): the parameter list continues past the visible lines; the `c_`
// used below is presumably a trailing card-handle parameter -- confirm.
// Throws std::runtime_error when internals::badinfo is set after the
// reduction.
132 template <
typename REAL>
133 void qr_allreduce(
const int root,
const int m,
const int n,
134 const REAL *
const restrict a, REAL *
const restrict b, MPI_Comm comm,
// Set up the launch configuration and buffers used by the reduction operator.
139 internals::qr_global_init<REAL>(c_, m, n);
// Describe one whole matrix (m*n elements) as a single MPI datatype so that
// the reduction calls below can pass a count of 1.
142 MPI_Datatype mat_type;
143 mpi::contig_type(m*n, a, &mat_type);
// The operator is registered with MPI as commutative.
147 const int commutative = 1;
149 MPI_Op_create((MPI_User_function*) internals::custom_op_qr<REAL>, commutative, &op);
// NOTE(review): the `else` between the two calls is not in the visible lines;
// presumably exactly one of them runs -- confirm.
150 if (root == defs::REDUCE_TO_ALL)
151 mpi_ret = MPI_Allreduce(a, b, 1, mat_type, op, comm);
153 mpi_ret = MPI_Reduce(a, b, 1, mat_type, op, root, comm);
// NOTE(review): no MPI_Op_free(&op) appears in the visible lines -- confirm
// that the operator is released.
157 MPI_Type_free(&mat_type);
// Release the buffers before reporting errors.
159 internals::qr_global_cleanup<REAL>();
// Check the MPI return code first, then the flag tested below.
161 mpi::check_MPI_ret(mpi_ret);
162 if (internals::badinfo)
163 throw std::runtime_error(
"unrecoverable error with LAPACK function geqrf() occurred during reduction");