5 #ifndef FML_PAR_INTERNALS_PARMAT_H
6 #define FML_PAR_INTERNALS_PARMAT_H
15 #include "../../_internals/rand.hh"
16 #include "../../_internals/types.hh"
22 template <
class MAT,
class VEC,
typename REAL>
// Declaration: resize the object. `nrows` is typed as a global row count
// (len_global_t) and `ncols` as a len_t column count.
30 void resize(len_global_t nrows, len_t ncols);
// Declaration: print the object. `ndigits` defaults to 4 and
// `add_final_blank` defaults to true.
// NOTE(review): presumably ndigits is the number of digits printed per value
// and add_final_blank emits a trailing blank line -- confirm in the definition.
32 void print(uint8_t ndigits=4,
bool add_final_blank=
true);
// Declaration: fill using the single value `v`.
// NOTE(review): presumably every element is set to v -- the definition is not
// visible here, confirm.
36 void fill_val(
const REAL v);
// Declaration: fill from the endpoints `start` and `stop`.
// NOTE(review): presumably evenly spaced values over the global matrix --
// the definition is not visible here, confirm.
37 void fill_linspace(
const REAL start,
const REAL stop);
// Declaration: fill using the vector `d`, taken by const reference.
// NOTE(review): presumably d supplies the diagonal entries -- confirm.
39 void fill_diag(
const VEC &d);
// Declaration: seeded overload of fill_runif. `min` defaults to 0 and `max`
// defaults to 1.
// NOTE(review): presumably fills with uniform random values between min and
// max -- confirm the exact interval in the underlying MAT implementation.
40 void fill_runif(
const uint32_t seed,
const REAL min=0,
const REAL max=1);
// Declaration: overload of fill_runif without a caller-supplied seed. `min`
// defaults to 0 and `max` defaults to 1.
41 void fill_runif(
const REAL min=0,
const REAL max=1);
// Declaration: seeded overload of fill_rnorm. `mean` defaults to 0 and `sd`
// defaults to 1.
// NOTE(review): presumably fills with normally distributed random values --
// confirm in the underlying MAT implementation.
42 void fill_rnorm(
const uint32_t seed,
const REAL mean=0,
const REAL sd=1);
// Declaration: overload of fill_rnorm without a caller-supplied seed. `mean`
// defaults to 0 and `sd` defaults to 1.
43 void fill_rnorm(
const REAL mean=0,
const REAL sd=1);
// Declaration: scale using the factor `s`.
// NOTE(review): presumably multiplies every element by s -- confirm.
47 void scale(
const REAL s);
// Declaration: const element read addressed by a single global index `i`;
// returns the value as REAL.
54 REAL get(
const len_global_t i)
const;
// Declaration: const element read addressed by global row `i` and column `j`;
// returns the value as REAL.
55 REAL get(
const len_global_t i,
const len_t j)
const;
// Declaration: element write addressed by a single global index `i`; `v` is
// the value to store.
56 void set(
const len_global_t i,
const REAL v);
// Declaration: element write addressed by global row `i` and column `j`; `v`
// is the value to store.
57 void set(
const len_global_t i,
const len_t j,
const REAL v);
// Returns the stored global row count (m_global).
65 len_global_t nrows()
const {
return m_global;};
// Returns the number of rows of the local data object (data.nrows()).
66 len_local_t nrows_local()
const {
return data.nrows();};
// Returns the number of columns of the local data object (data.ncols()).
67 len_local_t ncols()
const {
return data.ncols();};
// Returns nb4, the global row offset at which this process's local rows start
// (the element accessors map global row `row` to local row `row - nb4`).
68 len_global_t nrows_before()
const {
return nb4;};
// Returns the communicator object `r`, by value.
69 comm get_comm()
const {
return r;};
// Read-only access to the local data object; returns a const reference to
// `data`.
70 const MAT& data_obj()
const {
return data;};
// Mutable access to the local data object; returns a non-const reference to
// `data`, so callers can modify the local block directly.
71 MAT& data_obj() {
return data;};
// Global row count; this is the value returned by nrows().
75 len_global_t m_global;
// Declaration of the helper that exchanges row offsets between neighbouring
// ranks (its definition sends nb4 plus the local row count to the next rank).
79 void num_preceding_rows();
// Declaration of the helper whose result resize() uses as the local row count.
80 len_t get_local_dim();
// Declaration: bounds check for a single global index `i`.
81 void check_index(
const len_global_t i)
const;
// Declaration: bounds check for a (global row `i`, column `j`) pair.
82 void check_index(
const len_global_t i,
const len_t j)
const;
// Out-of-class member definition whose signature is not visible in this view.
// NOTE(review): presumably a constructor that receives the local data object
// -- confirm.
88 template <
class MAT,
class VEC,
typename REAL>
// Set the global row count from the row count of the local data object.
94 m_global = (len_global_t) data.nrows();
// Out-of-class member definition whose signature is not visible in this view.
// NOTE(review): presumably the copy constructor from another object `x` --
// confirm.
101 template <
class MAT,
class VEC,
typename REAL>
// Copy all four pieces of state from x through its public accessors:
// the local data object, ...
104 this->data = x.data_obj();
// ... the global row count, ...
105 this->m_global = x.nrows();
// ... the communicator, ...
106 this->r = x.get_comm();
// ... and the row offset of the local block.
107 this->nb4 = x.nrows_before();
// Out-of-class member definition whose signature is not visible in this view;
// the body uses `nrows` and `ncols`, matching the declared resize().
112 template <
class MAT,
class VEC,
typename REAL>
// Record the requested global row count first: get_local_dim() reads m_global.
115 this->m_global = nrows;
// Local row count for this process, as computed by get_local_dim().
116 len_t m_local = this->get_local_dim();
// Resize the local data object to m_local rows by ncols columns.
118 this->data.resize(m_local, ncols);
// Run the row-offset exchange again now that the local row count has changed.
119 num_preceding_rows();
// Out-of-class member definition whose signature is not visible in this view.
// The visible statements print a one-line summary through the communicator's
// printf.
// NOTE(review): the leading 0 argument is presumably the rank that prints --
// confirm against comm::printf.
124 template <
class MAT,
class VEC,
typename REAL>
// Summary prefix.
127 r.printf(0,
"# parmat");
// Dimensions: global row count (formatted with PRIu64) x local column count.
128 r.printf(0,
" %" PRIu64
"x%d", m_global, data.ncols());
// Element type, using the implementation-defined name from typeid.
129 r.printf(0,
" type=%s",
typeid(REAL).name());
135 template <
class MAT,
class VEC,
typename REAL>
143 template <
class MAT,
class VEC,
typename REAL>
// Out-of-class member definition whose signature is not visible in this view;
// the body uses `seed`, `min` and `max`, matching the seeded fill_runif().
151 template <
class MAT,
class VEC,
typename REAL>
// Forward unchanged to the local data object's fill_runif.
154 data.fill_runif(seed, min, max);
// Out-of-class member definition whose signature is not visible in this view;
// the body uses `min` and `max`, matching the unseeded fill_runif().
157 template <
class MAT,
class VEC,
typename REAL>
// Build a seed from fml::rand::get_seed() plus this process's rank.
// NOTE(review): presumably so that each rank uses a different seed -- confirm.
160 uint32_t seed = fml::rand::get_seed() + r.rank();
// Forward to the local data object's fill_runif with that seed.
161 data.fill_runif(seed, min, max);
// Out-of-class member definition whose signature is not visible in this view;
// the body uses `seed`, `mean` and `sd`, matching the seeded fill_rnorm().
166 template <
class MAT,
class VEC,
typename REAL>
// Forward unchanged to the local data object's fill_rnorm.
169 data.fill_rnorm(seed, mean, sd);
// Out-of-class member definition whose signature is not visible in this view;
// the body uses `mean` and `sd`, matching the unseeded fill_rnorm().
172 template <
class MAT,
class VEC,
typename REAL>
// Build a seed from fml::rand::get_seed() plus this process's rank.
// NOTE(review): presumably so that each rank uses a different seed -- confirm.
175 uint32_t seed = fml::rand::get_seed() + r.rank();
// Forward to the local data object's fill_rnorm with that seed.
176 data.fill_rnorm(seed, mean, sd);
181 template <
class MAT,
class VEC,
typename REAL>
189 template <
class MAT,
class VEC,
typename REAL>
// Out-of-class member definition whose signature and return statement are not
// visible in this view.
197 template <
class MAT,
class VEC,
typename REAL>
// Local result of data.any_inf(), converted to int so it can be reduced.
200 int ret = (int) data.any_inf();
// Combine the per-process flags across the communicator.
// NOTE(review): the reduction operator is not visible here; presumably a sum,
// so a nonzero result means at least one process reported an inf -- confirm.
201 r.allreduce(1, &ret);
// Out-of-class member definition whose signature and return statement are not
// visible in this view.
207 template <
class MAT,
class VEC,
typename REAL>
// Local result of data.any_nan(), converted to int so it can be reduced.
210 int ret = (int) data.any_nan();
// Combine the per-process flags across the communicator.
// NOTE(review): the reduction operator is not visible here; presumably a sum,
// so a nonzero result means at least one process reported a NaN -- confirm.
211 r.allreduce(1, &ret);
// Out-of-class member definition whose signature, the declaration of `ret` and
// the return statement are not visible in this view. The body matches the
// single-index get().
217 template <
class MAT,
class VEC,
typename REAL>
// Decompose the linear index i = row + j * m_global, i.e. i is treated as a
// column-major index over the global matrix.
222 len_global_t row = i % m_global;
223 len_t j = (len_t) (i / m_global);
// Only the process whose local block holds global row `row` reads the value;
// nb4 is the global index of its first local row.
226 if (row >= nb4 && row < nb4+data.nrows())
227 ret = data.get(row-nb4, j);
// Combine `ret` across the communicator so every process ends up with a result.
// NOTE(review): this only yields the element if ret starts at zero on the
// other processes and the reduction is a sum -- neither is visible here.
231 r.allreduce(1, &ret);
// Out-of-class member definition whose signature, the declaration of `ret` and
// the return statement are not visible in this view. The body matches the
// (row, column) get().
235 template <
class MAT,
class VEC,
typename REAL>
// Only the process whose local block holds global row i reads the value; the
// local row index is i - nb4.
241 if (i >= nb4 && i < nb4+data.nrows())
242 ret = data.get(i-nb4, j);
// Combine `ret` across the communicator so every process ends up with a result.
// NOTE(review): this only yields the element if ret starts at zero on the
// other processes and the reduction is a sum -- neither is visible here.
246 r.allreduce(1, &ret);
// Out-of-class member definition whose signature is not visible in this view.
// The body matches the single-index set().
250 template <
class MAT,
class VEC,
typename REAL>
// Decompose the linear index i = row + j * m_global, i.e. i is treated as a
// column-major index over the global matrix.
255 len_global_t row = i % m_global;
256 len_t j = (len_t) (i / m_global);
// Only the process that holds global row `row` writes; all others do nothing
// in the visible code.
258 if (row >= nb4 && row < nb4+data.nrows())
259 data.set(row-nb4, j, v);
// Out-of-class member definition whose signature is not visible in this view.
// The body matches the (row, column) set().
262 template <
class MAT,
class VEC,
typename REAL>
// Only the process that holds global row i writes, using local row i - nb4.
267 if (i >= nb4 && i < nb4+data.nrows())
268 data.set(i-nb4, j, v);
// Out-of-class member definition whose signature is not visible in this view.
// NOTE(review): presumably operator== against another object `x` -- confirm.
273 template <
class MAT,
class VEC,
typename REAL>
// 1 if this process's local block differs from x's local block, else 0.
276 int neq_count = (int) (data != x.data_obj());
// Combine the per-process mismatch flags across the communicator.
// NOTE(review): reduction operator not visible here; presumably a sum.
277 r.allreduce(1, &neq_count);
// Equal only if no process reported a mismatch.
279 return (neq_count == 0);
// Out-of-class member definition whose signature is not visible in this view.
// NOTE(review): presumably operator!= against another object `x` -- confirm.
282 template <
class MAT,
class VEC,
typename REAL>
// Defined as the negation of the equality comparison with x.
285 return !(*
this == x);
// Out-of-class member definition whose signature and return statement are not
// visible in this view.
// NOTE(review): presumably the copy-assignment operator from `x` -- confirm.
290 template <
class MAT,
class VEC,
typename REAL>
// Copy all four pieces of state from x through its public accessors:
// the local data object, ...
293 this->data = x.data_obj();
// ... the global row count, ...
294 this->m_global = x.nrows();
// ... the communicator, ...
295 this->r = x.get_comm();
// ... and the row offset of the local block.
296 this->nb4 = x.nrows_before();
// Out-of-class member definition whose signature, braces and some statements
// are not visible in this view. The visible code passes a running row offset
// from each rank to the next one.
// NOTE(review): the initialisation of nb4 and the use of the received value
// are not visible here; presumably nb4 starts at 0 and the received value
// becomes this rank's nb4 -- confirm.
304 template <
class MAT,
class VEC,
typename REAL>
307 int myrank = r.rank();
// The second ';' on the next line is an empty statement and has no effect.
308 int size = r.size();;
// Row count of this process's local block.
311 len_t m_local = data.nrows();
// Walk the receiving ranks 1 .. size-1 in order.
313 for (
int rank=1; rank<size; rank++)
// The rank just before `rank` is the sender for this step.
315 if (myrank == (rank - 1))
// Offset passed on: this rank's own offset plus its local row count.
317 len_global_t nb4_send = nb4 + ((len_global_t) m_local);
318 r.send(1, &nb4_send, rank);
// `rank` itself is the receiver for this step.
320 else if (myrank == rank)
322 len_global_t nr_prev_rank;
// Receive the offset sent by rank - 1.
323 r.recv(1, &nr_prev_rank, rank-1);
// Out-of-class member definition whose signature, the body of the `if` and the
// return statement are not visible in this view.
332 template <
class MAT,
class VEC,
typename REAL>
// Base share: global row count divided by the number of processes (integer
// division, narrowed to len_t).
335 len_t local = m_global / r.size();
// Rows left over after every process receives `local` rows.
336 len_t rem = (len_t) (m_global - (len_global_t) local*r.size());
// True for the ranks 0 .. rem-1.
// NOTE(review): presumably those ranks take one extra row each -- the guarded
// statement is not visible here, confirm.
337 if (r.rank()+1 <= rem)
// Out-of-class member definition whose signature is not visible in this view.
// The body matches the single-index check_index().
345 template <
class MAT,
class VEC,
typename REAL>
// Valid linear indices are 0 .. m_global * ncols - 1; anything else throws
// std::runtime_error.
// NOTE(review): if len_global_t is unsigned (m_global is printed with PRIu64
// elsewhere in this file), the `i < 0` test can never be true -- confirm.
348 if (i < 0 || i >= (m_global * data.ncols()))
349 throw std::runtime_error(
"index out of bounds");
// Out-of-class member definition whose signature is not visible in this view.
// The body matches the (row, column) check_index().
352 template <
class MAT,
class VEC,
typename REAL>
// The row must lie in 0 .. m_global-1 and the column in 0 .. ncols-1; anything
// else throws std::runtime_error.
// NOTE(review): if len_global_t is unsigned (m_global is printed with PRIu64
// elsewhere in this file), the `i < 0` test can never be true -- confirm.
355 if (i < 0 || i >= m_global || j < 0 || j >= data.ncols())
356 throw std::runtime_error(
"index out of bounds");