fml  0.1-0
Fused Matrix Library
qr_allreduce_gpudirect.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_TSSVD_GPU_QR_ALLREDUCE_H
6 #define FML_TSSVD_GPU_QR_ALLREDUCE_H
7 #pragma once
8 
9 
#include <algorithm>
#include <stdexcept>

#include <fml/gpu/card.hh>

#include "../defs.hh"
#include "../globals.hh"
#include "../mpi_utils.hh"
#include "../restrict.hh"
16 
17 
18 namespace fml
19 {
20 namespace tsqr
21 {
22  namespace internals
23  {
24  fml::card_sp_t c;
25  cusolverStatus_t check;
26  dim3 griddim, blockdim;
27  int *info_dev;
28 
29 
30 
31  template <typename REAL>
32  void qr_global_cleanup()
33  {
34  c->mem_free(tallboy<REAL>);
35  tallboy<REAL> = NULL;
36 
37  c->mem_free(work<REAL>);
38  work<REAL> = NULL;
39 
40  c->mem_free(qraux<REAL>);
41  qraux<REAL> = NULL;
42 
43  c->mem_free(info_dev);
44 
45  c.reset();
46  }
47 
48 
49 
50  template <typename REAL>
51  static inline int qrworksize(const int m, const int n)
52  {
53  REAL tmp;
54 
55  int info;
56  check = fml::gpulapack::geqrf_buflen(c->lapack_handle(), m, n, &tmp, m, &lwork);
57  // TODO check
58 
59  return std::max(lwork, 1);
60  }
61 
62 
63 
64  template <typename REAL>
65  void qr_global_init(fml::card_sp_t c_, int m, int n)
66  {
67  c = c_;
68  blockdim = fml::kernel_launcher::dim_block2();
69  griddim = fml::kernel_launcher::dim_grid(m, n);
70 
71  _m = m;
72  _n = n;
73  minmn = std::min(_m, _n);
74  mtb = 2*_m;
75 
76  badinfo = false;
77 
78  tallboy<REAL> = (REAL*) c->mem_alloc((size_t)mtb*_n*sizeof(REAL));
79 
80  lwork = qrworksize<REAL>(mtb, _n);
81  work<REAL> = (REAL*) c->mem_alloc((size_t)lwork*sizeof(REAL));
82 
83  qraux<REAL> = (REAL*) c->mem_alloc((size_t)minmn*sizeof(REAL));
84 
85  info_dev = (int*) c->mem_alloc(sizeof(int));
86  }
87 
88 
89 
90  template <typename REAL>
91  __global__ void kernel_stack(const len_t m, const len_t n, const len_t mtb,
92  const REAL *a, const REAL *b, REAL *tallboy)
93  {
94  int i = blockDim.x*blockIdx.x + threadIdx.x;
95  int j = blockDim.y*blockIdx.y + threadIdx.y;
96 
97  if (i < m && j < n)
98  {
99  tallboy[i + mtb*j] = a[i + m*j];
100  tallboy[m+i + mtb*j] = b[i + m*j];
101  }
102  }
103 
104 
105 
106  template <typename REAL>
107  void custom_op_qr(void *a_, void *b_, int *len, MPI_Datatype *dtype)
108  {
109  (void)len;
110  (void)dtype;
111 
112  REAL *a = (REAL*)a_;
113  REAL *b = (REAL*)b_;
114 
115  kernel_stack<<<griddim, blockdim>>>(_m, _n, mtb, a, b, tallboy<REAL>);
116 
117  int info;
118  c->mem_set(info_dev, 0, sizeof(int));
119  check = fml::gpulapack::geqrf(c->lapack_handle(), mtb, _n, tallboy<REAL>, mtb, qraux<REAL>, work<REAL>, lwork, info_dev);
120  c->mem_gpu2cpu(&info, info_dev, sizeof(int));
121  // TODO check
122  if (info != 0)
123  badinfo = true;
124 
125  c->mem_set(b, 0, (size_t)_m*_n*sizeof(REAL));
126  fml::gpu_utils::lacpy('U', _m, _n, tallboy<REAL>, mtb, b, _m);
127  }
128  }
129 
130 
131 
132  template <typename REAL>
133  void qr_allreduce(const int root, const int m, const int n,
134  const REAL *const restrict a, REAL *const restrict b, MPI_Comm comm,
135  fml::card_sp_t c_)
136  {
137  int mpi_ret;
138 
139  internals::qr_global_init<REAL>(c_, m, n);
140 
141  // custom data type
142  MPI_Datatype mat_type;
143  mpi::contig_type(m*n, a, &mat_type);
144 
145  // custom op + reduce
146  MPI_Op op;
147  const int commutative = 1;
148 
149  MPI_Op_create((MPI_User_function*) internals::custom_op_qr<REAL>, commutative, &op);
150  if (root == defs::REDUCE_TO_ALL)
151  mpi_ret = MPI_Allreduce(a, b, 1, mat_type, op, comm);
152  else
153  mpi_ret = MPI_Reduce(a, b, 1, mat_type, op, root, comm);
154 
155  // cleanup and return
156  MPI_Op_free(&op);
157  MPI_Type_free(&mat_type);
158 
159  internals::qr_global_cleanup<REAL>();
160 
161  mpi::check_MPI_ret(mpi_ret);
162  if (internals::badinfo)
163  throw std::runtime_error("unrecoverable error with LAPACK function geqrf() occurred during reduction");
164  }
165 }
166 }
167 
168 
169 #endif
fml
Core namespace.
Definition: dimops.hh:10