fml  0.1-0
Fused Matrix Library
grid.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_MPI_GRID_H
6 #define FML_MPI_GRID_H
7 #pragma once
8 
9 
10 #define OMPI_SKIP_MPICXX 1
11 #include <mpi.h>
12 
13 #include <cmath>
14 #include <stdexcept>
15 
16 #include "../_internals/print.hh"
17 
18 #include "internals/blacs_prototypes.h"
19 
20 
namespace fml
{
  /**
   * @brief Supported process grid shapes for 2-dimensional BLACS grids.
   */
  enum gridshape
  {
    /// A square process grid, or as close to square as the number of MPI
    /// ranks allows when that number is not a perfect square.
    PROC_GRID_SQUARE,
    /// A grid with 1 row and as many columns as there are MPI ranks.
    PROC_GRID_WIDE,
    /// A grid with 1 column and as many rows as there are MPI ranks.
    PROC_GRID_TALL
  };
  
  /**
   * @brief Supported operations in reduce/allreduce.
   */
  enum blacsops
  {
    /// Element-wise sum.
    BLACS_SUM,
    /// Element-wise maximum.
    BLACS_MAX,
    /// Element-wise minimum.
    BLACS_MIN
  };
}
62 
63 
64 
65 namespace fml
66 {
70  class grid
71  {
72  public:
73  // constructors/destructor and comm management
74  grid();
75  grid(const gridshape gridtype);
76  void set(const int blacs_context);
77  void exit();
78  void finalize(const bool mpi_continue=false);
79 
80  // utilities
81  void printf(const int row, const int col, const char *fmt, ...) const;
82  void info() const;
83  bool rank0() const;
84  bool ingrid() const;
85 
86  // send/recv
87  void send(const int m, const int n, const int *x, const int rdest=0, const int cdest=0) const;
88  void send(const int m, const int n, const float *x, const int rdest=0, const int cdest=0) const;
89  void send(const int m, const int n, const double *x, const int rdest=0, const int cdest=0) const;
90  void send(const int m, const int n, const int ldx, const int *x, const int rdest=0, const int cdest=0) const;
91  void send(const int m, const int n, const int ldx, const float *x, const int rdest=0, const int cdest=0) const;
92  void send(const int m, const int n, const int ldx, const double *x, const int rdest=0, const int cdest=0) const;
93 
94  void recv(const int m, const int n, int *x, const int rsrc=0, const int csrc=0) const;
95  void recv(const int m, const int n, float *x, const int rsrc=0, const int csrc=0) const;
96  void recv(const int m, const int n, double *x, const int rsrc=0, const int csrc=0) const;
97  void recv(const int m, const int n, const int ldx, int *x, const int rsrc=0, const int csrc=0) const;
98  void recv(const int m, const int n, const int ldx, float *x, const int rsrc=0, const int csrc=0) const;
99  void recv(const int m, const int n, const int ldx, double *x, const int rsrc=0, const int csrc=0) const;
100 
101  // collectives
102  void barrier(const char scope='A') const;
103 
104  void allreduce(const int m, const int n, int *x, const char scope='A', const blacsops op=BLACS_SUM) const;
105  void allreduce(const int m, const int n, float *x, const char scope='A', const blacsops op=BLACS_SUM) const;
106  void allreduce(const int m, const int n, double *x, const char scope='A', const blacsops op=BLACS_SUM) const;
107 
108  void reduce(const int m, const int n, int *x, const char scope='A', const blacsops op=BLACS_SUM, const int rdest=0, const int cdest=0) const;
109  void reduce(const int m, const int n, float *x, const char scope='A', const blacsops op=BLACS_SUM, const int rdest=0, const int cdest=0) const;
110  void reduce(const int m, const int n, double *x, const char scope='A', const blacsops op=BLACS_SUM, const int rdest=0, const int cdest=0) const;
111 
112  void bcast(const int m, const int n, int *x, const char scope='A', const int rsrc=0, const int csrc=0) const;
113  void bcast(const int m, const int n, float *x, const char scope='A', const int rsrc=0, const int csrc=0) const;
114  void bcast(const int m, const int n, double *x, const char scope='A', const int rsrc=0, const int csrc=0) const;
115 
116 
119  int ictxt() const {return _ictxt;};
121  int nprocs() const {return _nprocs;};
123  int nprow() const {return _nprow;};
125  int npcol() const {return _npcol;};
127  int myrow() const {return _myrow;};
129  int mycol() const {return _mycol;};
131 
133  bool valid_grid() const {return (_ictxt!=UNINITIALIZED_GRID && _ictxt!=EXITED_GRID);};
134 
135  protected:
136  int _ictxt;
137  int _nprocs;
138  int _nprow;
139  int _npcol;
140  int _myrow;
141  int _mycol;
142 
143  private:
144  static const int UNINITIALIZED_GRID = -1;
145  static const int EXITED_GRID = -11;
146 
147  void squarish(int *nr, int *nc) const;
148  };
149 }
150 
151 
152 
153 // -----------------------------------------------------------------------------
154 // public
155 // -----------------------------------------------------------------------------
156 
157 // constructors/destructor and grid management
158 
163 {
164  _ictxt = UNINITIALIZED_GRID;
165  _nprocs = _nprow = _npcol = _myrow = _mycol = -1;
166 }
167 
168 
169 
179 inline fml::grid::grid(const fml::gridshape gridtype)
180 {
181  char order = 'R';
182 
183  int mypnum;
184  Cblacs_pinfo(&mypnum, &_nprocs);
185 
186  Cblacs_get(-1, 0, &_ictxt);
187 
188  if (gridtype == PROC_GRID_SQUARE)
189  {
190  int nr, nc;
191  squarish(&nr, &nc);
192 
193  Cblacs_gridinit(&_ictxt, &order, nr, nc);
194  }
195  else if (gridtype == PROC_GRID_TALL)
196  Cblacs_gridinit(&_ictxt, &order, _nprocs, 1);
197  else if (gridtype == PROC_GRID_WIDE)
198  Cblacs_gridinit(&_ictxt, &order, 1, _nprocs);
199  else
200  throw std::runtime_error("Process grid should be one of PROC_GRID_SQUARE, PROC_GRID_TALL, or PROC_GRID_WIDE");
201 
202  Cblacs_gridinfo(_ictxt, &_nprow, &_npcol, &_myrow, &_mycol);
203 }
204 
205 
206 
212 inline void fml::grid::set(const int blacs_context)
213 {
214  _ictxt = blacs_context;
215  Cblacs_gridinfo(_ictxt, &_nprow, &_npcol, &_myrow, &_mycol);
216 
217  if (_nprow == -1)
218  throw std::runtime_error("context handle does not point at a valid context");
219 
220  _nprocs = _nprow * _npcol;
221 }
222 
223 
224 
226 inline void fml::grid::exit()
227 {
228  if (_ictxt != EXITED_GRID && _ictxt != UNINITIALIZED_GRID)
229  Cblacs_gridexit(_ictxt);
230 
231  _ictxt = EXITED_GRID;
232  _nprocs = _nprow = _npcol = _myrow = _mycol = -1;
233 }
234 
235 
236 
242 inline void fml::grid::finalize(const bool mpi_continue)
243 {
244  exit();
245 
246  int cont = (int) mpi_continue;
247  Cblacs_exit(cont);
248 }
249 
250 
251 
252 // utilities
253 
263 inline void fml::grid::printf(const int row, const int col, const char *fmt, ...) const
264 {
265  if (_myrow == row && _mycol == col)
266  {
267  va_list args;
268 
269  va_start(args, fmt);
270  fml::print::vprintf(fmt, args);
271  va_end(args);
272  }
273 }
274 
275 
276 
281 inline void fml::grid::info() const
282 {
283  printf(0, 0, "## Grid %d %dx%d\n\n", _ictxt, _nprow, _npcol);
284 }
285 
286 
287 
292 inline bool fml::grid::rank0() const
293 {
294  return (_myrow==0 && _mycol==0);
295 }
296 
297 
298 
303 inline bool fml::grid::ingrid() const
304 {
305  return !(_myrow==-1 && _mycol==-1);
306 }
307 
308 
309 
310 // send/recv
311 
320 inline void fml::grid::send(const int m, const int n, const int *x, const int rdest, const int cdest) const
322 {
323  Cigesd2d(_ictxt, m, n, x, m, rdest, cdest);
324 }
325 
326 inline void fml::grid::send(const int m, const int n, const float *x, const int rdest, const int cdest) const
327 {
328  Csgesd2d(_ictxt, m, n, x, m, rdest, cdest);
329 }
330 
331 inline void fml::grid::send(const int m, const int n, const double *x, const int rdest, const int cdest) const
332 {
333  Cdgesd2d(_ictxt, m, n, x, m, rdest, cdest);
334 }
335 
336 inline void fml::grid::send(const int m, const int n, const int ldx, const int *x, const int rdest, const int cdest) const
337 {
338  Cigesd2d(_ictxt, m, n, x, ldx, rdest, cdest);
339 }
340 
341 inline void fml::grid::send(const int m, const int n, const int ldx, const float *x, const int rdest, const int cdest) const
342 {
343  Csgesd2d(_ictxt, m, n, x, ldx, rdest, cdest);
344 }
345 
346 inline void fml::grid::send(const int m, const int n, const int ldx, const double *x, const int rdest, const int cdest) const
347 {
348  Cdgesd2d(_ictxt, m, n, x, ldx, rdest, cdest);
349 }
351 
352 
353 
362 inline void fml::grid::recv(const int m, const int n, int *x, const int rsrc, const int csrc) const
364 {
365  Cigerv2d(_ictxt, m, n, x, m, rsrc, csrc);
366 }
367 
368 inline void fml::grid::recv(const int m, const int n, float *x, const int rsrc, const int csrc) const
369 {
370  Csgerv2d(_ictxt, m, n, x, m, rsrc, csrc);
371 }
372 
373 inline void fml::grid::recv(const int m, const int n, double *x, const int rsrc, const int csrc) const
374 {
375  Cdgerv2d(_ictxt, m, n, x, m, rsrc, csrc);
376 }
377 
378 inline void fml::grid::recv(const int m, const int n, const int ldx, int *x, const int rsrc, const int csrc) const
379 {
380  Cigerv2d(_ictxt, m, n, x, ldx, rsrc, csrc);
381 }
382 
383 inline void fml::grid::recv(const int m, const int n, const int ldx, float *x, const int rsrc, const int csrc) const
384 {
385  Csgerv2d(_ictxt, m, n, x, ldx, rsrc, csrc);
386 }
387 
388 inline void fml::grid::recv(const int m, const int n, const int ldx, double *x, const int rsrc, const int csrc) const
389 {
390  Cdgerv2d(_ictxt, m, n, x, ldx, rsrc, csrc);
391 }
393 
394 
395 
396 // collectives
397 
404 inline void fml::grid::barrier(const char scope) const
405 {
406  Cblacs_barrier(_ictxt, &scope);
407 }
408 
409 
410 
419 inline void fml::grid::allreduce(const int m, const int n, int *x, const char scope, const blacsops op) const
421 {
422  reduce(m, n, x, scope, op, -1, -1);
423 }
424 
425 inline void fml::grid::allreduce(const int m, const int n, float *x, const char scope, const blacsops op) const
426 {
427  reduce(m, n, x, scope, op, -1, -1);
428 }
429 
430 inline void fml::grid::allreduce(const int m, const int n, double *x, const char scope, const blacsops op) const
431 {
432  reduce(m, n, x, scope, op, -1, -1);
433 }
435 
436 
437 
448 inline void fml::grid::reduce(const int m, const int n, int *x, const char scope, const blacsops op, const int rdest, const int cdest) const
450 {
451  char top = ' ';
452 
453  if (op == BLACS_SUM)
454  Cigsum2d(_ictxt, &scope, &top, m, n, x, m, rdest, cdest);
455  else if (op == BLACS_MAX)
456  Cigamx2d(_ictxt, &scope, &top, m, n, x, m, NULL, NULL, -1, rdest, cdest);
457  else if (op == BLACS_MIN)
458  Cigamn2d(_ictxt, &scope, &top, m, n, x, m, NULL, NULL, -1, rdest, cdest);
459 }
460 
461 inline void fml::grid::reduce(const int m, const int n, float *x, const char scope, const blacsops op, const int rdest, const int cdest) const
462 {
463  char top = ' ';
464 
465  if (op == BLACS_SUM)
466  Csgsum2d(_ictxt, &scope, &top, m, n, x, m, rdest, cdest);
467  else if (op == BLACS_MAX)
468  Csgamx2d(_ictxt, &scope, &top, m, n, x, m, NULL, NULL, -1, rdest, cdest);
469  else if (op == BLACS_MIN)
470  Csgamn2d(_ictxt, &scope, &top, m, n, x, m, NULL, NULL, -1, rdest, cdest);
471 }
472 
473 inline void fml::grid::reduce(const int m, const int n, double *x, const char scope, const blacsops op, const int rdest, const int cdest) const
474 {
475  char top = ' ';
476 
477  if (op == BLACS_SUM)
478  Cdgsum2d(_ictxt, &scope, &top, m, n, x, m, rdest, cdest);
479  else if (op == BLACS_MAX)
480  Cdgamx2d(_ictxt, &scope, &top, m, n, x, m, NULL, NULL, -1, rdest, cdest);
481  else if (op == BLACS_MIN)
482  Cdgamn2d(_ictxt, &scope, &top, m, n, x, m, NULL, NULL, -1, rdest, cdest);
483 }
485 
486 
487 
498 inline void fml::grid::bcast(const int m, const int n, int *x, const char scope, const int rsrc, const int csrc) const
500 {
501  char top = ' ';
502  if (rsrc == _myrow && csrc == _mycol)
503  Cigebs2d(_ictxt, &scope, &top, m, n, x, m);
504  else
505  Cigebr2d(_ictxt, &scope, &top, m, n, x, m, rsrc, csrc);
506 }
507 
508 inline void fml::grid::bcast(const int m, const int n, float *x, const char scope, const int rsrc, const int csrc) const
509 {
510  char top = ' ';
511  if (rsrc == _myrow && csrc == _mycol)
512  Csgebs2d(_ictxt, &scope, &top, m, n, x, m);
513  else
514  Csgebr2d(_ictxt, &scope, &top, m, n, x, m, rsrc, csrc);
515 }
516 
517 inline void fml::grid::bcast(const int m, const int n, double *x, const char scope, const int rsrc, const int csrc) const
518 {
519  char top = ' ';
520  if (rsrc == _myrow && csrc == _mycol)
521  Cdgebs2d(_ictxt, &scope, &top, m, n, x, m);
522  else
523  Cdgebr2d(_ictxt, &scope, &top, m, n, x, m, rsrc, csrc);
524 }
526 
527 
528 
529 // -----------------------------------------------------------------------------
530 // private
531 // -----------------------------------------------------------------------------
532 
533 inline void fml::grid::squarish(int *nr, int *nc) const
534 {
535  int n = (int) sqrt((double) _nprocs);
536  n = (n<1)?1:n; // suppresses bogus compiler warning
537 
538  for (int i=0; i<n; i++)
539  {
540  (*nc) = n - i;
541  (*nr) = _nprocs % (*nc);
542  if ((*nr) == 0)
543  break;
544  }
545 
546  (*nr) = _nprocs / (*nc);
547 }
548 
549 
550 #endif
fml::grid::ictxt
int ictxt() const
Definition: grid.hh:119
fml::gridshape
gridshape
Supported process grid shapes for 2-dimensional BLACS grids.
Definition: grid.hh:34
fml::PROC_GRID_TALL
@ PROC_GRID_TALL
A grid with 1 column and as many rows as there are MPI ranks.
Definition: grid.hh:48
fml::grid::allreduce
void allreduce(const int m, const int n, int *x, const char scope='A', const blacsops op=BLACS_SUM) const
Reduce operation (sum, max, or min) across all processes in the given scope.
Definition: grid.hh:420
fml::grid
2-dimensional MPI process grid.
Definition: grid.hh:70
fml::grid::send
void send(const int m, const int n, const int *x, const int rdest=0, const int cdest=0) const
Point-to-point send. Should be matched by a corresponding 'recv' call.
Definition: grid.hh:321
fml::grid::barrier
void barrier(const char scope='A') const
Execute a barrier across the specified scope of the BLACS grid.
Definition: grid.hh:404
fml::grid::mycol
int mycol() const
The process column (0-based index) of the calling process.
Definition: grid.hh:129
fml::grid::bcast
void bcast(const int m, const int n, int *x, const char scope='A', const int rsrc=0, const int csrc=0) const
Broadcast.
Definition: grid.hh:499
fml::grid::info
void info() const
Print some brief information about the BLACS grid. The printing is done by row 0 and col 0.
Definition: grid.hh:281
fml::grid::reduce
void reduce(const int m, const int n, int *x, const char scope='A', const blacsops op=BLACS_SUM, const int rdest=0, const int cdest=0) const
Reduce operation (sum, max, or min).
Definition: grid.hh:449
fml::grid::finalize
void finalize(const bool mpi_continue=false)
Shuts down BLACS, and optionally MPI.
Definition: grid.hh:242
fml::PROC_GRID_SQUARE
@ PROC_GRID_SQUARE
Definition: grid.hh:44
fml::grid::printf
void printf(const int row, const int col, const char *fmt,...) const
Helper wrapper around the C standard I/O 'printf()' function. Conceptually similar to guarding a normal 'printf()' call with a check that the calling process is at the given process row and column.
Definition: grid.hh:263
fml::grid::exit
void exit()
Exits the BLACS grid, but does not shutdown BLACS/MPI.
Definition: grid.hh:226
fml::grid::grid
grid()
Create a new grid object. Does not initialize any BLACS or MPI data.
Definition: grid.hh:162
fml::blacsops
blacsops
Supported operations in reduce/allreduce.
Definition: grid.hh:55
fml::PROC_GRID_WIDE
@ PROC_GRID_WIDE
A grid with 1 row and as many columns as there are MPI ranks.
Definition: grid.hh:46
fml::grid::ingrid
bool ingrid() const
Check if the executing process is in the grid, i.e., if neither the process row nor column are -1.
Definition: grid.hh:303
fml::grid::recv
void recv(const int m, const int n, int *x, const int rsrc=0, const int csrc=0) const
Point-to-point receive. Should be matched by a corresponding 'send' call.
Definition: grid.hh:363
fml
Core namespace.
Definition: dimops.hh:10
fml::grid::set
void set(const int blacs_context)
Create a grid object from an existing BLACS process grid.
Definition: grid.hh:212
fml::grid::rank0
bool rank0() const
Check if the executing process is rank 0, i.e., if the process row and column are 0.
Definition: grid.hh:292
fml::grid::valid_grid
bool valid_grid() const
Is the BLACS grid valid?
Definition: grid.hh:133
fml::grid::npcol
int npcol() const
The number of process columns in the BLACS context.
Definition: grid.hh:125
fml::grid::nprocs
int nprocs() const
The total number of processes bound to the BLACS context.
Definition: grid.hh:121
fml::grid::myrow
int myrow() const
The process row (0-based index) of the calling process.
Definition: grid.hh:127
fml::grid::nprow
int nprow() const
The number of process rows in the BLACS context.
Definition: grid.hh:123