fml  0.1-0
Fused Matrix Library
parmat.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_PAR_PARMAT_H
6 #define FML_PAR_PARMAT_H
7 #pragma once
8 
9 
#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <random>
#include <typeinfo>

#include "../_internals/types.hh"
#include "comm.hh"
17 
18 
19 template <class MAT, class VEC, typename REAL>
20 class parmat
21 {
22  public:
23  parmat(){};
24  parmat(comm &mpi_comm, MAT &_data);
25 
26  void print(uint8_t ndigits=4, bool add_final_blank=true);
27  void info();
28 
29  void fill_zero();
30  void fill_one();
31  void fill_val(const REAL v);
32  void fill_linspace(const REAL start, const REAL stop);
33  void fill_eye();
34  void fill_diag(const VEC &d);
35  // void fill_runif(const uint32_t seed, const REAL min=0, const REAL max=1);
36  // void fill_runif(const REAL min=0, const REAL max=1);
37  // void fill_rnorm(const uint32_t seed, const REAL mean=0, const REAL sd=1);
38  // void fill_rnorm(const REAL mean=0, const REAL sd=1);
39 
40  // void diag(cpuvec<REAL> &v);
41  // void antidiag(cpuvec<REAL> &v);
42  void scale(const REAL s);
43  // void rev_rows();
44  void rev_cols();
45 
46  bool any_inf() const;
47  bool any_nan() const;
48 
49  len_global_t nrows() const {return m_global;};
50  len_local_t nrows_local() const {return data.nrows();};
51  len_local_t ncols() const {return data.ncols();};
52  comm get_comm() const {return r;};
53  MAT& data_obj() const {return data;};
54  MAT& data_obj() {return data;};
55 
56  protected:
57  MAT data;
58  len_global_t m_global;
59  comm r;
60  len_global_t nb4;
61  void num_preceding_rows();
62  len_t get_local_dim();
63 };
64 
65 
66 
67 template <class MAT, class VEC, typename REAL>
68 parmat<MAT, VEC, REAL>::parmat(comm &mpi_comm, MAT &_data)
69 {
70  r = mpi_comm;
71  data = _data;
72 
73  m_global = (len_global_t) data.nrows();
74  r.allreduce(1, &(m_global));
75  num_preceding_rows();
76 }
77 
78 
79 
80 template <class MAT, class VEC, typename REAL>
81 void parmat<MAT, VEC, REAL>::print(uint8_t ndigits, bool add_final_blank)
82 {
83  len_t n = data.ncols();
84  VEC pv(n);
85 
86  int myrank = r.rank();
87  if (myrank == 0)
88  data.print(ndigits, false);
89 
90  for (int rank=1; rank<r.size(); rank++)
91  {
92  if (rank == myrank)
93  {
94  len_t m = data.nrows();
95  r.send(1, &m, 0);
96 
97  REAL *d = data.data_ptr();
98  for (int i=0; i<m; i++)
99  {
100  data.get_row(i, pv);
101  r.send(n, pv.data_ptr(), 0);
102  }
103  }
104  else if (myrank == 0)
105  {
106  len_t m;
107  r.recv(1, &m, rank);
108 
109  for (int i=0; i<m; i++)
110  {
111  r.recv(n, pv.data_ptr(), rank);
112  pv.print(ndigits, false);
113  }
114  }
115 
116  r.barrier();
117  }
118 
119  if (add_final_blank)
120  {
121  r.printf(0, "\n");
122  r.barrier();
123  }
124 }
125 
126 
127 
128 template <class MAT, class VEC, typename REAL>
130 {
131  r.printf(0, "# parmat");
132  r.printf(0, " %" PRIu64 "x%d", m_global, data.ncols());
133  r.printf(0, " type=%s", typeid(REAL).name());
134  r.printf(0, "\n");
135 }
136 
137 
138 
139 template <class MAT, class VEC, typename REAL>
141 {
142  data.fill_zero();
143 }
144 
145 
146 
147 template <class MAT, class VEC, typename REAL>
149 {
150  data.fill_one();
151 }
152 
153 
154 
155 template <class MAT, class VEC, typename REAL>
156 void parmat<MAT, VEC, REAL>::fill_val(const REAL v)
157 {
158  data.fill_val(v);
159 }
160 
161 
162 
163 template <class MAT, class VEC, typename REAL>
164 void parmat<MAT, VEC, REAL>::scale(const REAL v)
165 {
166  data.scale(v);
167 }
168 
169 
170 
171 template <class MAT, class VEC, typename REAL>
173 {
174  data.rev_cols();
175 }
176 
177 
178 
179 template <class MAT, class VEC, typename REAL>
181 {
182  int ret = (int) data.any_inf();
183  r.allreduce(1, &ret);
184  return (bool) ret;
185 }
186 
187 
188 
189 template <class MAT, class VEC, typename REAL>
191 {
192  int ret = (int) data.any_nan();
193  r.allreduce(1, &ret);
194  return (bool) ret;
195 }
196 
197 
198 
// -----------------------------------------------------------------------------
// protected helpers
// -----------------------------------------------------------------------------
202 
203 template <class MAT, class VEC, typename REAL>
205 {
206  int myrank = r.rank();
207  int size = r.size();;
208 
209  nb4 = 0;
210  len_t m_local = data.nrows();
211 
212  for (int rank=1; rank<size; rank++)
213  {
214  if (myrank == (rank - 1))
215  {
216  len_global_t nb4_send = nb4 + ((len_global_t) m_local);
217  r.send(1, &nb4_send, rank);
218  }
219  else if (myrank == rank)
220  {
221  len_global_t nr_prev_rank;
222  r.recv(1, &nr_prev_rank, rank-1);
223 
224  nb4 += nr_prev_rank;
225  }
226  }
227 }
228 
229 
230 
231 template <class MAT, class VEC, typename REAL>
233 {
234  len_t local = m_global / r.size();
235  len_t rem = (len_t) (m_global - (len_global_t) local*r.size());
236  if (r.rank()+1 <= rem)
237  local++;
238 
239  return local;
240 }
241 
242 
243 #endif
MPI communicator data and helpers.
Definition: comm.hh:24
void barrier() const
Execute a barrier.
Definition: comm.hh:341
Definition: parmat.hh:20
void send(int n, const T *data, int dest, int tag=0) const
Point-to-point send. Should be matched by a corresponding 'recv' call.
Definition: comm.hh:289
void allreduce(int n, T *data, MPI_Op op=MPI_SUM) const
Sum reduce operation across all processes in the MPI communicator.
Definition: comm.hh:358
int rank() const
Calling process rank (0-based index) in the MPI communicator.
Definition: comm.hh:66
int size() const
Total number of ranks in the MPI communicator.
Definition: comm.hh:68
void printf(int rank, const char *fmt,...) const
Helper wrapper around the C standard I/O 'printf()' function. Conceptually similar to guarding a norm...
Definition: comm.hh:196
void recv(int n, T *data, int source, int tag=0) const
Point-to-point receive. Should be matched by a corresponding 'send' call.
Definition: comm.hh:318