5 #ifndef FML_PAR_CPU_PARMAT_H
6 #define FML_PAR_CPU_PARMAT_H
10 #include "../../_internals/omp.hh"
12 #include "../../cpu/cpumat.hh"
13 #include "../../cpu/cpuvec.hh"
15 #include "../internals/parmat.hh"
// Declaration fragment of the parmat_cpu class template, parameterised on the
// element type REAL.
// NOTE(review): the class head, access specifiers and any other members are
// not visible in this excerpt.
20 template <
typename REAL>
// Constructor taking a communicator, a row count typed len_global_t, and a
// column count typed len_t.
27 parmat_cpu(
comm mpi_comm,
const len_global_t nrows,
const len_t ncols);
// Overload with an extra len_global_t argument nb4_.
// NOTE(review): presumably the caller-supplied value for the nb4 member that
// fill_linspace() reads as a row offset - confirm against the definition.
28 parmat_cpu(
comm mpi_comm,
const len_global_t nrows,
const len_t ncols,
const len_global_t nb4_);
// Prints the matrix. ndigits defaults to 4 and add_final_blank defaults to
// true.
30 void print(uint8_t ndigits=4,
bool add_final_blank=
true);
// Fills the matrix from the two bounds start and stop (see the definition
// for the exact per-element formula).
35 void fill_linspace(
const REAL start,
const REAL stop);
43 template <
typename REAL>
// Constructor definition fragment.
// NOTE(review): the signature line is not visible in this excerpt; this is
// presumably the (mpi_comm, nrows, ncols) overload - confirm.
53 template <
typename REAL>
// Store the requested row count as the global row count.
58 this->m_global = nrows;
// get_local_dim() supplies the row count used for the local storage, which is
// then sized to nrows_local x ncols.
59 len_t nrows_local = this->get_local_dim();
60 this->data.resize(nrows_local, ncols);
// Any return value is discarded here.
// NOTE(review): presumably this call computes and stores the nb4 row offset
// as a side effect - confirm in parmat.hh.
62 this->num_preceding_rows();
// Constructor definition fragment.
// NOTE(review): the signature line is not visible in this excerpt; this is
// presumably the overload that also receives nb4_ - confirm.
67 template <
typename REAL>
// Store the requested row count as the global row count.
72 this->m_global = nrows;
// get_local_dim() supplies the row count used for the local storage, which is
// then sized to nrows_local x ncols.
73 len_t nrows_local = this->get_local_dim();
74 this->data.resize(nrows_local, ncols);
// NOTE(review): no num_preceding_rows() call is visible in this fragment, and
// no assignment from nb4_ is visible either - confirm where nb4 gets set.
// Definition fragment of the matrix printer: rows held by ranks 1..size-1 are
// sent to rank 0 one rank at a time, and rank 0 prints them as they arrive.
// NOTE(review): the signature, braces, the declaration of the row buffer pv,
// and some branch conditions are not visible in this excerpt.
97 template <
typename REAL>
// Column count of the local block; also the element count of every row
// message exchanged below.
100 len_t n = this->data.ncols();
103 int myrank = this->r.rank();
// Print the local block, always passing false as the second argument.
// NOTE(review): the guard around this call is not visible - presumably only
// rank 0 executes it; confirm.
105 this->data.print(ndigits,
false);
// Walk the remaining ranks in increasing order.
107 for (
int rank=1; rank<this->r.size(); rank++)
// Sending side: first send the local row count, then each row in turn, all
// with 0 as the last send() argument.
// NOTE(review): the condition selecting the sender is not visible -
// presumably myrank == rank; the last argument of send()/recv() appears to
// be the peer rank. Confirm both.
111 len_t m = this->data.nrows();
112 this->r.send(1, &m, 0);
114 for (
int i=0; i<m; i++)
// Copy row i into pv and send its n entries.
116 this->data.get_row(i, pv);
117 this->r.send(n, pv.data_ptr(), 0);
// Receiving side, taken when this process is rank 0: read the row count, then
// receive and print that many rows.
120 else if (myrank == 0)
123 this->r.recv(1, &m, rank);
125 for (
int i=0; i<m; i++)
127 this->r.recv(n, pv.data_ptr(), rank);
128 pv.print(ndigits,
false);
// Emit a single newline through the communicator's printf with first
// argument 0.
// NOTE(review): presumably gated on add_final_blank; the guard is not
// visible here - confirm.
137 this->r.printf(0,
"\n");
// Definition fragment of fill_linspace(start, stop): each local element gets
// start plus a fixed step times an index built from the local row, the nb4
// offset and the column.
// NOTE(review): the signature, braces and branch conditions are not visible
// in this excerpt.
144 template <
typename REAL>
// Constant fill with the value start.
// NOTE(review): the condition guarding this call is not visible - presumably
// start == stop; confirm. If it is not guarded that way, the divisor below is
// zero when m_global*n == 1.
148 this->fill_val(start);
151 const len_t m_local = this->data.nrows();
152 const len_t n = this->data.ncols();
// Step between consecutive index values. The cast binds to m_global only, so
// the product with n is carried out in REAL.
154 const REAL v = (stop-start)/((REAL) this->m_global*n - 1);
155 REAL *d_p = this->data.data_ptr();
// Columns are distributed over OpenMP threads only when the local element
// count exceeds fml::omp::OMP_MIN_SIZE.
157 #pragma omp parallel for if(m_local*n > fml::omp::OMP_MIN_SIZE)
158 for (len_t j=0; j<n; j++)
161 for (len_t i=0; i<m_local; i++)
// The local buffer is addressed as i + m_local*j; the value uses the index
// (i + nb4) + m_global*j. Only i is cast to REAL, so m_global*j is evaluated
// in the operands' own arithmetic type before being added.
163 d_p[i + m_local*j] = v*((REAL) i + this->nb4 + this->m_global*j) + start;
171 template <
typename REAL>
// Definition fragment of a routine that writes every local element from
// either an entry of d or zero.
// NOTE(review): the function name and signature, the declarations of d and
// d_p, the branch conditions and the closing lines are not visible in this
// excerpt; presumably this is a diagonal fill - confirm.
181 template <
typename REAL>
184 const len_t m_local = this->data.nrows();
185 const len_t n = this->data.ncols();
186 REAL *x_p = this->data.data_ptr();
190 for (len_t j=0; j<n; j++)
192 for (len_t i=0; i<m_local; i++)
// Shift the local row index by nb4 to obtain gi.
194 const len_global_t gi = i + this->nb4;
// Index d_p by gi modulo d.size(), so the index wraps around when gi reaches
// d.size().
// NOTE(review): an if/else between this store and the zero store below is
// presumably present in the full file (otherwise the second store would
// always overwrite the first) - confirm.
196 x_p[i + m_local*j] = d_p[gi % d.
size()];
198 x_p[i + m_local*j] = 0;