5 #ifndef FML_MPI_MPIHELPERS_H
6 #define FML_MPI_MPIHELPERS_H
12 #include "../_internals/arraytools/src/arraytools.hpp"
14 #include "../cpu/cpumat.hh"
15 #include "../cpu/cpuvec.hh"
17 #include "internals/bcutils.hh"
45 template <
typename REAL_IN,
typename REAL_OUT>
48 grid g = mpi.get_grid();
52 const len_local_t m_local = mpi.nrows_local();
53 const len_local_t n_local = mpi.ncols_local();
55 const int mb = mpi.bf_rows();
57 const len_t m = mpi.
nrows();
58 const len_t n = mpi.
ncols();
68 if (m_local > 0 && n_local > 0)
70 for (len_local_t j=0; j<n_local; j++)
72 const int gj = fml::bcutils::l2g(j, mpi.bf_cols(), g.
npcol(), g.
mycol());
74 for (len_local_t i=0; i<m_local; i+=mb)
76 const int gi = fml::bcutils::l2g(i, mpi.bf_rows(), g.
nprow(), g.
myrow());
78 for (
int ii=0; ii<mb && ii+i<m_local; ii++)
79 gbl[gi+ii + m*gj] = (REAL_OUT) sub[i+ii + m_local*j];
90 template <
typename REAL>
125 template <
typename REAL_IN,
typename REAL_OUT>
128 const grid g = mpi.get_grid();
132 bool i_am_ret = (g.
myrow() == rdest && g.
mycol() == cdest) ?
true :
false;
134 const len_local_t m_local = mpi.nrows_local();
136 const int mb = mpi.bf_rows();
137 const int nb = mpi.bf_cols();
139 const len_t m = mpi.
nrows();
140 const len_t n = mpi.
ncols();
151 const REAL_IN *sub = mpi.
data_ptr();
156 for (len_t gj=0; gj<n; gj+=nb)
158 const int pc = fml::bcutils::g2p(gj, nb, g.
npcol());
159 const len_t j = fml::bcutils::g2l(gj, nb, g.
npcol());
160 const len_t col_copylen = std::min(nb, n-gj);
162 for (len_t gi=0; gi<m; gi+=mb)
164 const int pr = fml::bcutils::g2p(gi, mb, g.
nprow());
165 const len_t i = fml::bcutils::g2l(gi, mb, g.
nprow());
166 const len_t row_copylen = std::min(mb, m-gi);
172 for (
int jj=0; jj<col_copylen; jj++)
173 arraytools::copy(row_copylen, sub + i+m_local*(j+jj), gbl + gi+m*(gj+jj));
176 g.
recv(row_copylen, col_copylen, m, gbl + gi+m*gj, pr, pc);
180 for (len_t jj=0; jj<col_copylen; jj++)
182 for (len_t ii=0; ii<row_copylen; ii++)
183 tmp_d[ii + mb*jj] = (REAL_OUT) sub[i+ii + m_local*(j+jj)];
186 g.
send(row_copylen, col_copylen, mb, tmp_d, rdest, cdest);
195 template <
typename REAL>
228 template <
typename REAL_IN,
typename REAL_OUT>
231 const len_t m = cpu.
nrows();
232 const len_t n = cpu.
ncols();
239 const grid g = mpi.get_grid();
241 const len_local_t m_local = mpi.nrows_local();
242 const len_local_t n_local = mpi.ncols_local();
243 const int mb = mpi.bf_rows();
245 const REAL_IN *gbl = cpu.
data_ptr();
248 if (m_local > 0 && n_local > 0)
250 for (len_local_t j=0; j<n_local; j++)
252 const int gj = fml::bcutils::l2g(j, mpi.bf_cols(), g.
npcol(), g.
mycol());
254 for (len_local_t i=0; i<m_local; i+=mb)
256 const int gi = fml::bcutils::l2g(i, mpi.bf_rows(), g.
nprow(), g.
myrow());
258 for (
int ii=0; ii<mb && ii+i<m_local; ii++)
259 sub[i+ii + m_local*j] = (REAL_OUT) gbl[gi+ii + m*gj];
283 template <
typename REAL_IN,
typename REAL_OUT>
286 if (mpi_in.get_grid().
ictxt() != mpi_out.get_grid().
ictxt())
287 throw std::runtime_error(
"mpimat objects must be distributed on the same process grid");
291 size_t len = (size_t) mpi_in.nrows_local() * mpi_in.ncols_local();