5 #ifndef FML_PAR_GPU_LINALG_QR_ALLREDUCE_NOGPUDIRECT_H
6 #define FML_PAR_GPU_LINALG_QR_ALLREDUCE_NOGPUDIRECT_H
10 #include "../../../_internals/arraytools/src/arraytools.hpp"
11 #include "../../../_internals/restrict.hh"
13 #include "../../../gpu/card.hh"
15 #include "../../internals/mpi_utils.hh"
// Shared state used by the functions in this listing. The listing's embedded
// line numbers skip, so not every declaration is visible here.
//
// check: receives the status returned by the fml::gpulapack calls
// (geqrf_buflen and geqrf) made below.
26 cusolverStatus_t check;
// Launch configuration for kernel_stack; both are assigned in qr_global_init().
27 dim3 griddim, blockdim;
// _m, _n: matrix dimensions used for the transfer sizes and the geqrf call.
// minmn:  std::min(_m, _n), the element count of the qraux buffer.
// mtb:    row count / leading dimension of the tallboy buffer.
31 int _m, _n, minmn, mtb;
// Five template headers follow; the declarations they introduce are not
// visible in this listing.
// NOTE(review): presumably these declare the variable templates
// tallboy<REAL>, work<REAL>, qraux<REAL>, a<REAL> and b<REAL> that the
// functions below reference — confirm against the full file.
34 template <
typename REAL>
36 template <
typename REAL>
38 template <
typename REAL>
41 template <
typename REAL>
43 template <
typename REAL>
// Releases the buffers that qr_global_init() obtained from the card object:
// tallboy<REAL>, work<REAL>, qraux<REAL> and info_dev are each handed to
// c->mem_free().
// NOTE(review): the listing skips lines between the calls (e.g. 52-53), so
// additional statements may exist there — confirm against the full file.
48 template <
typename REAL>
49 void qr_global_cleanup()
51 c->mem_free(tallboy<REAL>);
54 c->mem_free(work<REAL>);
57 c->mem_free(qraux<REAL>);
60 c->mem_free(info_dev);
// Returns the workspace length to use for a geqrf call on an m-by-n problem.
//
// m, n: problem dimensions; m is also passed as the leading dimension.
// Returns the value reported by fml::gpulapack::geqrf_buflen, but never less
// than 1.
//
// The declarations of `tmp` and `lwork` are not visible in this listing.
// NOTE(review): `tmp` looks like a placeholder standing in for the matrix
// argument of the size query — confirm.
67 template <
typename REAL>
68 static inline int qrworksize(
const int m,
const int n)
// The returned status is stored in the shared `check`; no inspection of it is
// visible in this listing.
73 check = fml::gpulapack::geqrf_buflen(c->lapack_handle(), m, n, &tmp, m, &lwork);
// Clamp to at least one element.
78 return std::max(lwork, 1);
// Prepares the shared state for a reduction: computes the kernel launch
// configuration and allocates the tallboy, work, qraux and info_dev buffers
// through the card object.
//
// c_:   card handle.
// m, n: dimensions passed to fml::kernel_launcher::dim_grid().
//
// NOTE(review): the assignments of c, _m, _n and mtb are not visible in this
// listing (its line numbers skip 89-91 and 93-96). kernel_stack writes rows
// i and m+i of tallboy, so mtb presumably has to be at least 2*m — confirm.
84 template <
typename REAL>
84 void qr_global_init(fml::card_sp_t c_,
int m,
int n)
// Launch configuration later used for kernel_stack.
87 blockdim = fml::kernel_launcher::dim_block2();
88 griddim = fml::kernel_launcher::dim_grid(m, n);
92 minmn = std::min(_m, _n);
// tallboy holds mtb*_n elements; the product is widened to size_t before
// multiplying.
97 tallboy<REAL> = (REAL*) c->mem_alloc((
size_t)mtb*_n*
sizeof(REAL));
// Workspace sized by the geqrf size query for an mtb-by-_n problem.
99 lwork = qrworksize<REAL>(mtb, _n);
100 work<REAL> = (REAL*) c->mem_alloc((
size_t)lwork*
sizeof(REAL));
// qraux holds minmn elements.
102 qraux<REAL> = (REAL*) c->mem_alloc((
size_t)minmn*
sizeof(REAL));
// Single int that receives the geqrf info value.
104 info_dev = (
int*) c->mem_alloc(
sizeof(
int));
// CUDA kernel that stacks two matrices vertically into `tallboy`.
//
// Each thread derives a row index i from the x dimension and a column index j
// from the y dimension, then copies a[i + m*j] to tallboy[i + mtb*j] and
// b[i + m*j] to tallboy[m+i + mtb*j]. The inputs are indexed with leading
// dimension m and the output with leading dimension mtb, so b lands m rows
// below a in every column.
//
// m:       leading dimension of a and b, and the row offset applied to b.
// n:       not referenced in the visible lines.
// mtb:     leading dimension of tallboy.
// a, b:    input matrices (read only).
// tallboy: output buffer.
//
// NOTE(review): no bounds guard on i and j is visible (listing line 116 is
// absent). Confirm the full file restricts the copy to i < m and j < n, since
// the launch grid may overshoot the matrix.
109 template <
typename REAL>
110 __global__
void kernel_stack(
const len_t m,
const len_t n,
const len_t mtb,
111 const REAL *a,
const REAL *b, REAL *tallboy)
113 int i = blockDim.x*blockIdx.x + threadIdx.x;
114 int j = blockDim.y*blockIdx.y + threadIdx.y;
118 tallboy[i + mtb*j] = a[i + m*j];
119 tallboy[m+i + mtb*j] = b[i + m*j];
// User-defined MPI reduction operation (registered through MPI_Op_create in
// qr_allreduce). It combines two host-side matrices by stacking them on the
// device, running geqrf on the stacked matrix, and writing the 'U' part of
// the result back into the second operand.
//
// a_:    first operand (host memory), interpreted as REAL*.
// b_:    second operand (host memory), interpreted as REAL*; overwritten with
//        the result.
// len:   not referenced in the visible lines.
// dtype: not referenced in the visible lines.
//
// All byte counts are widened to size_t before multiplying, matching the
// mem_set call below; the previous `_m*_n*sizeof(REAL)` multiplied two ints
// first and could overflow for large matrices.
//
// NOTE(review): this listing skips lines 144-147; presumably `info` is
// examined there and internals::badinfo is set on failure — confirm.
125 template <
typename REAL>
126 void custom_op_qr(
void *a_,
void *b_,
int *len, MPI_Datatype *dtype)
131 REAL *a_cpu = (REAL*)a_;
132 REAL *b_cpu = (REAL*)b_;
// Stage both operands into the device buffers a<REAL> and b<REAL>.
134 c->mem_cpu2gpu(a<REAL>, a_cpu, (size_t)_m*_n*
sizeof(REAL));
135 c->mem_cpu2gpu(b<REAL>, b_cpu, (size_t)_m*_n*
sizeof(REAL));
// Stack a on top of b inside tallboy.
138 kernel_stack<<<griddim, blockdim>>>(_m, _n, mtb, a<REAL>, b<REAL>, tallboy<REAL>);
// Clear the device-side info value, factor the mtb-by-_n stacked matrix, then
// fetch info back to the host.
141 c->mem_set(info_dev, 0,
sizeof(
int));
142 check = fml::gpulapack::geqrf(c->lapack_handle(), mtb, _n, tallboy<REAL>, mtb, qraux<REAL>, work<REAL>, lwork, info_dev);
143 c->mem_gpu2cpu(&info, info_dev,
sizeof(
int));
// Zero b<REAL> first so that everything outside the copied 'U' part is 0.
148 c->mem_set(b<REAL>, 0, (
size_t)_m*_n*
sizeof(REAL));
149 fml::gpu_utils::lacpy(
'U', _m, _n, tallboy<REAL>, mtb, b<REAL>, _m);
// Return the result to the host buffer MPI handed in.
151 c->mem_gpu2cpu(b_cpu, b<REAL>, (size_t)_m*_n*
sizeof(REAL));
// Reduces matrices across an MPI communicator using custom_op_qr, staging the
// data through host memory.
//
// root: destination rank, or mpi::REDUCE_TO_ALL to use MPI_Allreduce.
// m, n: matrix dimensions.
// a_:   input matrix; copied to the host with mem_gpu2cpu before the reduce.
// b_:   output matrix; receives the reduced result via mem_cpu2gpu.
// comm: MPI communicator.
// The parameter list continues past what this listing shows; the body uses a
// card handle named c_.
//
// Throws std::runtime_error when internals::badinfo is set after the
// reduction. mpi_ret is passed to mpi::check_MPI_ret beforehand.
//
// Byte counts are widened to size_t before multiplying; the previous
// `m*n*sizeof(REAL)` multiplied two ints first and could overflow for large
// matrices.
157 template <
typename REAL>
158 void qr_allreduce(
const int root,
const int m,
const int n,
159 REAL *
const restrict a_, REAL *
const restrict b_, MPI_Comm comm,
// Set up launch configuration and scratch buffers, then publish the caller's
// pointers for use inside custom_op_qr.
164 internals::qr_global_init<REAL>(c_, m, n);
166 internals::a<REAL> = a_;
167 internals::b<REAL> = b_;
// Host staging buffers for the MPI call.
170 arraytools::alloc(m, n, &a_cpu);
171 arraytools::alloc(m, n, &b_cpu);
172 arraytools::check_alloc(a_cpu, b_cpu);
174 c_->mem_gpu2cpu(a_cpu, a_, (size_t)m*n*
sizeof(REAL));
// Describe one whole matrix as a single MPI element.
178 MPI_Datatype mat_type;
179 mpi::contig_type(m*n, a_cpu, &mat_type);
// The operation is registered as commutative.
183 const int commutative = 1;
185 MPI_Op_create((MPI_User_function*) internals::custom_op_qr<REAL>, commutative, &op);
// NOTE(review): listing line 188 is absent; presumably it is the `else` that
// pairs the MPI_Reduce call with this condition — confirm.
186 if (root == mpi::REDUCE_TO_ALL)
187 mpi_ret = MPI_Allreduce(a_cpu, b_cpu, 1, mat_type, op, comm);
189 mpi_ret = MPI_Reduce(a_cpu, b_cpu, 1, mat_type, op, root, comm);
// Move the reduced result from the host staging buffer to the output matrix.
192 c_->mem_cpu2gpu(internals::b<REAL>, b_cpu, (size_t)m*n*
sizeof(REAL));
// NOTE(review): no MPI_Op_free for `op` is visible (listing line 196 is
// absent) — confirm the operation handle is released.
195 MPI_Type_free(&mat_type);
197 internals::qr_global_cleanup<REAL>();
199 arraytools::free(a_cpu);
200 arraytools::free(b_cpu);
// Errors are reported only after all resources above have been released.
202 mpi::check_MPI_ret(mpi_ret);
203 if (internals::badinfo)
204 throw std::runtime_error(
"unrecoverable error with LAPACK function geqrf() occurred during reduction");