5 #ifndef FML_MPI_DIMOPS_H
6 #define FML_MPI_DIMOPS_H
12 #include "../_internals/dimops.hh"
14 #include "../cpu/internals/vecops.hh"
15 #include "../cpu/cpuvec.hh"
17 #include "internals/bcutils.hh"
// NOTE(review): the declaration line and the setup of x_d / s_d are not visible
// in this excerpt. From the accumulation below this looks like a row-sum
// routine over the local block of a distributed matrix -- confirm against the
// full file.
34 template <
typename REAL>
// m is the full row count of x; it is also the element count handed to the
// reduction at the end.
38 const len_t m = x.
nrows();
// Extents of the locally stored block; the local buffer is indexed as
// x_d[i + m_local*j], i.e. column by column with leading dimension m_local.
39 const len_t m_local = x.nrows_local();
40 const len_t n_local = x.ncols_local();
41 const len_t mb = x.bf_rows();
43 const grid g = x.get_grid();
// Every iteration over j updates the same s_d[gi] entries, so the iterations
// of this outer loop are not independent of each other.
49 for (len_t j=0; j<n_local; j++)
51 for (len_t i=0; i<m_local; i++)
// Map local row i to an index into the length-m output
// (presumably local-to-global row index -- TODO confirm l2g semantics).
53 const len_t gi = fml::bcutils::l2g(i, mb, g.
nprow(), g.
myrow());
54 s_d[gi] += x_d[i + m_local*j];
// Combine the m partial results across the grid.
// NOTE(review): 'A' presumably selects all processes -- confirm.
58 x.get_grid().allreduce(m, 1, s_d,
'A');
71 template <
typename REAL>
// NOTE(review): the declaration line and the setup of x_d / s_d are not visible
// in this excerpt. From the per-column sum below this looks like a column-sum
// routine over the local block of a distributed matrix -- confirm against the
// full file.
88 template <
typename REAL>
// n is the full column count of x; it is also the element count handed to the
// reduction at the end.
92 const len_t n = x.
ncols();
93 const len_t m_local = x.nrows_local();
94 const len_t n_local = x.ncols_local();
95 const len_t nb = x.bf_cols();
97 const grid g = x.get_grid();
// The loop over local columns is run in parallel only when the local block has
// more than fml::omp::OMP_MIN_SIZE elements.
103 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
104 for (len_t j=0; j<n_local; j++)
// Map local column j to an index into the length-n output
// (presumably local-to-global column index -- TODO confirm l2g semantics).
106 const len_t gj = fml::bcutils::l2g(j, nb, g.
npcol(), g.
mycol());
// Sum the m_local entries of local column j into s_d[gj].
// NOTE(review): whether sum() overwrites or accumulates into its last
// argument is not visible here.
107 fml::vecops::cpu::sum(m_local, x_d + m_local*j, s_d[gj]);
// Combine the n partial results across the grid.
// NOTE(review): 'A' presumably selects all processes -- confirm.
110 x.get_grid().allreduce(n, 1, s_d,
'A');
123 template <
typename REAL>
// NOTE(review): the declaration line, the condition guarding the throw, the
// definitions of x_d / s_d and the condition of the first (+=) branch are not
// visible in this excerpt. This looks like the row-wise counterpart of
// colsweep -- confirm against the full file.
//
// Visible logic: each local element x_d[i + m_local*j] is combined in place
// with s_d[l2g(i, x.bf_rows(), g.nprow(), g.myrow())], i.e. the value chosen
// depends on the row index i only.
145 template <
typename REAL>
149 throw std::runtime_error(
"non-conformal arguments");
151 const len_t m_local = x.nrows_local();
152 const len_t n_local = x.ncols_local();
156 const auto g = x.get_grid();
// First branch: add the row's sweep value to every element.
// Each j loop below is parallelised only when the local block has more than
// fml::omp::OMP_MIN_SIZE elements.
160 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
161 for (len_t j=0; j<n_local; j++)
164 for (len_t i=0; i<m_local; i++)
165 x_d[i + m_local*j] += s_d[bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow())];
// Subtract the row's sweep value.
168 else if (op == SWEEP_SUB)
170 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
171 for (len_t j=0; j<n_local; j++)
174 for (len_t i=0; i<m_local; i++)
175 x_d[i + m_local*j] -= s_d[bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow())];
// Multiply by the row's sweep value.
178 else if (op == SWEEP_MUL)
180 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
181 for (len_t j=0; j<n_local; j++)
184 for (len_t i=0; i<m_local; i++)
185 x_d[i + m_local*j] *= s_d[bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow())];
// Divide by the row's sweep value; no check for zero entries is visible.
188 else if (op == SWEEP_DIV)
190 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
191 for (len_t j=0; j<n_local; j++)
194 for (len_t i=0; i<m_local; i++)
195 x_d[i + m_local*j] /= s_d[bcutils::l2g(i, x.bf_rows(), g.nprow(), g.myrow())];
/// @brief Combine every local element of x, in place, with a per-column value
///        taken from s.
///
/// The element x_d[i + m_local*j] is combined with
/// s_d[l2g(j, x.bf_cols(), g.npcol(), g.mycol())], so the value chosen depends
/// on the column index j only.
///
/// @param x  Distributed matrix; its local buffer (x.data_ptr()) is modified.
/// @param s  Sweep values; must hold exactly x.ncols() elements.
/// @param op Selects the arithmetic operation. SWEEP_SUB, SWEEP_MUL and
///           SWEEP_DIV are tested in the visible branches.
/// @throws std::runtime_error if s.size() != x.ncols().
///
/// NOTE(review): the condition guarding the first (+=) branch is not visible
/// in this excerpt.
215 template <
typename REAL>
216 static inline void colsweep(mpimat<REAL> &x,
const cpuvec<REAL> &s,
const sweep_op op)
// Reject a sweep vector whose length differs from the global column count.
218 if (s.size() != x.ncols())
219 throw std::runtime_error(
"non-conformal arguments");
221 const len_t m_local = x.nrows_local();
222 const len_t n_local = x.ncols_local();
223 REAL *x_d = x.data_ptr();
224 const REAL *s_d = s.data_ptr();
226 const auto g = x.get_grid();
// First branch: add the column's sweep value to every element.
// Each j loop below is parallelised only when the local block has more than
// fml::omp::OMP_MIN_SIZE elements.
230 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
231 for (len_t j=0; j<n_local; j++)
234 for (len_t i=0; i<m_local; i++)
// The index into s_d does not depend on i; it is re-evaluated per element.
235 x_d[i + m_local*j] += s_d[bcutils::l2g(j, x.bf_cols(), g.npcol(), g.mycol())];
// Subtract the column's sweep value.
238 else if (op == SWEEP_SUB)
240 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
241 for (len_t j=0; j<n_local; j++)
244 for (len_t i=0; i<m_local; i++)
245 x_d[i + m_local*j] -= s_d[bcutils::l2g(j, x.bf_cols(), g.npcol(), g.mycol())];
// Multiply by the column's sweep value.
248 else if (op == SWEEP_MUL)
250 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
251 for (len_t j=0; j<n_local; j++)
254 for (len_t i=0; i<m_local; i++)
255 x_d[i + m_local*j] *= s_d[bcutils::l2g(j, x.bf_cols(), g.npcol(), g.mycol())];
// Divide by the column's sweep value; no check for zero entries is visible.
258 else if (op == SWEEP_DIV)
260 #pragma omp parallel for if(m_local*n_local > fml::omp::OMP_MIN_SIZE)
261 for (len_t j=0; j<n_local; j++)
264 for (len_t i=0; i<m_local; i++)
265 x_d[i + m_local*j] /= s_d[bcutils::l2g(j, x.bf_cols(), g.npcol(), g.mycol())];
/// @brief Helper that reduces one local column into `mean`.
///
/// @param g       Process grid used for the reduction.
/// @param j       Local column index; the column starts at x + m_local*j.
/// @param m       Not used in the lines visible here.
/// @param m_local Number of locally stored entries in the column.
/// @param x       Local buffer, indexed with leading dimension m_local.
/// @param mean    Output, written by sum() and then reduced in place.
///
/// NOTE(review): the initialisation of `mean` and any division by m are not
/// visible in this excerpt; confirm against the full file before relying on
/// the result being a mean rather than a sum.
274 template <
typename REAL>
275 static inline void col_mean(
const grid g,
const len_t j,
const len_t m,
const len_t m_local,
const REAL *x, REAL &mean)
// Local part: sum of the m_local entries of column j.
278 fml::vecops::cpu::sum(m_local, x + m_local*j, mean);
// Combine the single value across the grid.
// NOTE(review): 'C' presumably restricts the reduction to one process
// column -- confirm.
279 g.allreduce(1, 1, &mean,
'C');
/// @brief Helper that computes a variance-style quantity for one local column
///        around a supplied mean.
///
/// @param g       Process grid used for the reduction.
/// @param j       Local column index; entries are read as x[i + m_local*j].
/// @param m       Divisor in the final formula (as m and m-1).
/// @param m_local Number of locally stored entries in the column.
/// @param x       Local buffer, indexed with leading dimension m_local.
/// @param mean    Value subtracted from every entry.
/// @param work    Scratch buffer; entries 0 and 1 are read and reduced.
/// @param var     Output: (work[0] - work[1]*work[1]/m) / (m-1).
///
/// NOTE(review): the initialisation of `work` and the update of work[1] are
/// not visible in this excerpt.
283 template <
typename REAL>
284 static inline void col_var(
const grid g,
const len_t j,
const len_t m,
const len_t m_local,
const REAL *x,
const REAL &mean, REAL *work, REAL &var)
// Accumulate squared deviations of the local entries into work[0].
289 for (len_t i = 0; i<m_local; i++)
291 REAL diff = x[i + m_local*j] - mean;
292 work[0] += diff*diff;
// Combine both scratch entries across the grid in a single call.
// NOTE(review): 'C' presumably restricts the reduction to one process
// column -- confirm.
296 g.allreduce(2, 1, work,
'C');
// The denominator is m-1, so m == 1 makes it zero.
298 var = (work[0] - work[1]*work[1]/m) / (m-1);
/// @brief For each local column of x, obtain a value from col_mean() and add
///        its negation to every local entry of that column, in place.
///
/// @param x Distributed matrix; its local buffer (x.data_ptr()) is modified.
///
/// NOTE(review): the declaration of `mean` is not visible in this excerpt.
303 template <
typename REAL>
304 static inline void center(mpimat<REAL> &x)
306 REAL *x_d = x.data_ptr();
307 const len_t m = x.nrows();
308 const len_t m_local = x.nrows_local();
309 const len_t n_local = x.ncols_local();
311 grid g = x.get_grid();
// One col_mean() call per local column.
313 for (len_t j=0; j<n_local; j++)
316 col_mean(g, j, m, m_local, x_d, mean);
// Shift the m_local local entries of column j by -mean.
317 fml::vecops::cpu::sweep_add(-mean, m_local, x_d + m_local*j);
/// @brief For each local column of x, multiply every local entry by
///        1/sqrt(var), where var comes from col_var(); done in place.
///
/// The mean from col_mean() is passed to col_var() only; the visible code does
/// not subtract it from the data.
///
/// @param x Distributed matrix; its local buffer (x.data_ptr()) is modified.
///
/// NOTE(review): the declarations of `mean`, `var` and `work` are not visible
/// in this excerpt.
321 template <
typename REAL>
322 static inline void scale(mpimat<REAL> &x)
324 REAL *x_d = x.data_ptr();
325 const len_t m = x.nrows();
326 const len_t m_local = x.nrows_local();
327 const len_t n_local = x.ncols_local();
329 grid g = x.get_grid();
// One col_mean() and one col_var() call per local column.
333 for (len_t j=0; j<n_local; j++)
336 col_mean(g, j, m, m_local, x_d, mean);
339 col_var(g, j, m, m_local, x_d, mean, work, var);
// Reuse `var` to hold the multiplier; a zero var gives a division by zero.
340 var = (REAL)1.0/sqrt(var);
341 fml::vecops::cpu::sweep_mul(var, m_local, x_d + m_local*j);
/// @brief For each local column of x, replace every local entry v by
///        (v - mean) / sqrt(var), in place, with mean from col_mean() and var
///        from col_var().
///
/// @param x Distributed matrix; its local buffer (x.data_ptr()) is modified.
///
/// NOTE(review): the declarations of `mean`, `var` and `work` are not visible
/// in this excerpt.
345 template <
typename REAL>
346 static inline void center_and_scale(mpimat<REAL> &x)
348 REAL *x_d = x.data_ptr();
349 const len_t m = x.nrows();
350 const len_t m_local = x.nrows_local();
351 const len_t n_local = x.ncols_local();
353 grid g = x.get_grid();
// One col_mean() and one col_var() call per local column.
357 for (len_t j=0; j<n_local; j++)
360 col_mean(g, j, m, m_local, x_d, mean);
363 col_var(g, j, m, m_local, x_d, mean, work, var);
// Shift and rescale in a single pass; sqrt(var) appears in the per-element
// expression and a zero var gives a division by zero.
366 for (len_t i=0; i<m_local; i++)
367 x_d[i + m_local*j] = (x_d[i + m_local*j] - mean) / sqrt(var);
383 template <
typename REAL>
386 if (rm_mean && rm_sd)
387 internals::center_and_scale(x);
389 internals::center(x);