5 #ifndef FML_PAR_GPU_PARMAT_H
6 #define FML_PAR_GPU_PARMAT_H
10 #include "../../gpu/card.hh"
11 #include "../../gpu/gpumat.hh"
12 #include "../../gpu/gpuvec.hh"
14 #include "../internals/parmat.hh"
// parmat_gpu: a matrix distributed by rows over an MPI communicator, with each
// rank's local block held on a GPU card.
// NOTE(review): the class head (name line, base class, access specifiers) is
// not visible here; the base is presumably the parmat template from
// "../internals/parmat.hh" — confirm.
19 template <
typename REAL>
// Construct an nrows x ncols matrix over mpi_comm; this rank's local rows are
// sized on gpu_card.
25 parmat_gpu(
comm mpi_comm, card_sp_t gpu_card,
const len_global_t nrows,
const len_t ncols);
// As above, but also takes nb4_ (presumably the number of global rows owned
// by preceding ranks, supplied by the caller instead of computed — TODO confirm).
26 parmat_gpu(
comm mpi_comm, card_sp_t gpu_card,
const len_global_t nrows,
const len_t ncols,
const len_global_t nb4_);
// Print the whole matrix by sending every rank's rows to rank 0 for output.
// ndigits: forwarded to the underlying matrix/row print calls (default 4).
// add_final_blank: presumably controls the trailing blank line (default true).
28 void print(uint8_t ndigits=4,
bool add_final_blank=
true);
// Fill the matrix with evenly spaced values between start and stop.
30 void fill_linspace(
const REAL start,
const REAL stop);
// Return the GPU card that holds this rank's local data.
34 card_sp_t get_card()
const {
return this->data.get_card();};
// Constructor: an nrows x ncols matrix distributed by rows, with this rank's
// block sized on gpu_card (mpi_comm and gpu_card are parameters per the
// in-class declaration).
40 template <
typename REAL>
42 const len_global_t nrows,
const len_t ncols)
// Set the global row count before asking for the local dimension;
// get_local_dim() presumably derives this rank's share from m_global.
46 this->m_global = nrows;
47 len_t nrows_local = this->get_local_dim();
// Size the local (nrows_local x ncols) block on the GPU card.
48 this->data.resize(gpu_card, nrows_local, ncols);
// Presumably computes and stores the number of global rows held by lower
// ranks (defined outside this view) — TODO confirm.
50 this->num_preceding_rows();
// Constructor: like the (nrows, ncols) overload, but the caller also passes
// nb4_ (presumably the preceding-row count, so it need not be computed).
// NOTE(review): nb4_ is not referenced in the lines visible here; confirm it
// is stored in a line outside this view.
55 template <
typename REAL>
57 const len_global_t nrows,
const len_t ncols,
const len_global_t nb4_)
// Global row count must be set before get_local_dim() is queried.
61 this->m_global = nrows;
62 len_t nrows_local = this->get_local_dim();
// Size this rank's local (nrows_local x ncols) block on the GPU card.
63 this->data.resize(gpu_card, nrows_local, ncols);
// print: output the whole distributed matrix from rank 0.
// Rank 0 first prints its own local block. Then, for each other rank in
// order, that rank sends its local row count followed by its rows, one
// message per row, and rank 0 receives and prints each row as it arrives.
// NOTE(review): the rank guards (presumably myrank == 0 / myrank == rank) and
// the declaration of the row buffer pv are not visible here — confirm.
// NOTE(review): one message per row is latency-bound for tall local blocks;
// consider sending a rank's block in one message.
70 template <
typename REAL>
// Every row has n entries; n is the per-row message length below.
73 len_t n = this->data.ncols();
76 int myrank = this->r.rank();
// Rank 0's own block; false suppresses the blank line between blocks.
78 this->data.print(ndigits,
false);
// Visit the other ranks in rank order so output follows global row order.
80 for (
int rank=1; rank<this->r.size(); rank++)
// Sender side: announce the local row count to rank 0 ...
84 len_t m = this->data.nrows();
85 this->r.send(1, &m, 0);
// ... then copy each local row into pv and send its n entries to rank 0.
87 for (
int i=0; i<m; i++)
89 this->data.get_row(i, pv);
90 this->r.send(n, pv.data_ptr(), 0);
// Receiver side (rank 0): receive the row count from `rank` ...
96 this->r.recv(1, &m, rank);
// ... then receive and print that many rows.
98 for (
int i=0; i<m; i++)
100 this->r.recv(n, pv.data_ptr(), rank);
101 pv.print(ndigits,
false);
// Trailing blank line printed by rank 0 (presumably only when add_final_blank).
110 this->r.printf(0,
"\n");
// fill_linspace: fill the matrix with evenly spaced values from start to stop
// across all m_global * n entries of the global matrix.
117 template <
typename REAL>
// Presumably the start == stop case: every entry equals start — TODO confirm
// the guarding condition, which is not visible here.
121 this->fill_val(start);
124 const len_t m_local = this->data.nrows();
125 const len_t n = this->data.ncols();
// Step between consecutive values. The global count m_global*n (not the local
// one) is used so that all ranks share the same spacing.
// NOTE(review): if the matrix has exactly one element the denominator is 0;
// confirm this path is unreachable in that case.
127 const REAL v = (stop-start)/((REAL) this->m_global*n - 1);
// Presumably checks for a GPU error after the fill — TODO confirm.
132 this->data.c->check();
138 template <
typename REAL>
148 template <
typename REAL>