fml  0.1-0
Fused Matrix Library
comm.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_PAR_COMM_H
6 #define FML_PAR_COMM_H
7 #pragma once
8 
9 
10 #define OMPI_SKIP_MPICXX 1
11 #include <mpi.h>
12 
13 #include <cstdarg>
14 #include <cstdint>
15 #include <cstdio>
16 #include <stdexcept>
17 #include <typeinfo>
18 #include <vector>
19 
20 
21 namespace fml
22 {
24  class comm
25  {
26  public:
27  // constructors/destructor and comm management
28  comm(MPI_Comm comm=MPI_COMM_WORLD);
29  void set(MPI_Comm comm);
30  comm create(MPI_Group group);
31  comm split(int color, int key);
32  void free();
33  void finalize();
34 
35  // utilities
36  void printf(int rank, const char *fmt, ...) const;
37  void info() const;
38  bool rank0() const;
39  std::vector<int> jid(const int n) const;
40 
41  // send/recv
42  template <typename T>
43  void send(int n, const T *data, int dest, int tag=0) const;
44  template <typename T>
45  void isend(int n, const T *data, int dest, int tag=0) const;
46  template <typename T>
47  void recv(int n, T *data, int source, int tag=0) const;
48  template <typename T>
49  void irecv(int n, T *data, int source, int tag=0) const;
50 
51  // collectives
52  void barrier() const;
53  template <typename T>
54  void allreduce(int n, T *data, MPI_Op op=MPI_SUM) const;
55  template <typename T>
56  void reduce(int n, T *data, MPI_Op op=MPI_SUM, int root=0) const;
57  template <typename T>
58  void bcast(int n, T *data, int root) const;
59 
62  MPI_Comm get_comm() const {return _comm;}
64  int rank() const {return _rank;}
66  int size() const {return _size;}
68  int localrank() const {return _localrank;}
70  int localsize() const {return _localsize;}
72 
73  protected:
74  MPI_Comm _comm;
75  int _rank;
76  int _size;
77  int _localrank;
78  int _localsize;
79 
80  private:
81  void init();
82  void set_metadata();
83  void check_ret(const int ret) const;
84  template <typename T>
85  MPI_Datatype mpi_type_lookup(const T *x) const;
86  };
87 }
88 
89 
90 
91 // -----------------------------------------------------------------------------
92 // public
93 // -----------------------------------------------------------------------------
94 
95 // constructors/destructor and comm management
96 
103 inline fml::comm::comm(MPI_Comm comm)
104 {
105  init();
106 
107  _comm = comm;
108 
109  set_metadata();
110 }
111 
112 
113 
119 inline void fml::comm::set(MPI_Comm comm)
120 {
121  _comm = comm;
122  set_metadata();
123 }
124 
125 
126 
133 inline fml::comm fml::comm::create(MPI_Group group)
134 {
135  MPI_Comm newcomm;
136  int mpi_ret = MPI_Comm_create(_comm, group, &newcomm);
137  check_ret(mpi_ret);
138 
139  fml::comm ret(newcomm);
140  return ret;
141 }
142 
143 
144 
152 inline fml::comm fml::comm::split(int color, int key)
153 {
154  MPI_Comm newcomm;
155  int mpi_ret = MPI_Comm_split(_comm, color, key, &newcomm);
156  check_ret(mpi_ret);
157 
158  fml::comm ret(newcomm);
159  return ret;
160 }
161 
162 
163 
165 inline void fml::comm::free()
166 {
167  int mpi_ret = MPI_Comm_free(&_comm);
168  check_ret(mpi_ret);
169 
170  _comm = MPI_COMM_NULL;
171  _rank = -1;
172  _size = -1;
173 }
174 
175 
176 
178 inline void fml::comm::finalize()
179 {
180  int ret = MPI_Finalize();
181  check_ret(ret);
182 }
183 
184 
185 
186 // utilities
187 
197 inline void fml::comm::printf(int rank, const char *fmt, ...) const
198 {
199  if (_rank == rank)
200  {
201  va_list args;
202 
203  va_start(args, fmt);
204  vfprintf(stdout, fmt, args);
205  va_end(args);
206  }
207 }
208 
209 
210 
215 inline void fml::comm::info() const
216 {
217  printf(0, "## MPI on %d ranks\n\n", _size);
218 }
219 
220 
221 
223 inline bool fml::comm::rank0() const
224 {
225  return (_rank == 0);
226 }
227 
228 
229 
238 inline std::vector<int> fml::comm::jid(const int n) const
239 {
240  std::vector<int> ret;
241 
242  if (n > _size)
243  {
244  int local = n / _size;
245  int rem = n % _size;
246 
247  if (rem == 0 || (_rank < (_size - rem)))
248  {
249  ret.resize(local);
250  for (int i=0; i<local; i++)
251  ret[i] = i + (_rank*local);
252  }
253  else
254  {
255  ret.resize(local+1);
256  for (int i=0; i<=local; i++)
257  ret[i] = i + (_rank*(local+1)) - (_size - rem);
258  }
259  }
260  else
261  {
262  if (n > _rank)
263  {
264  ret.resize(1);
265  ret[0] = _rank;
266  }
267  else
268  ret.resize(0);
269  }
270 
271  return ret;
272 }
273 
274 
275 
276 // send/recv
277 
286 template <typename T>
288 inline void fml::comm::send(int n, const T *data, int dest, int tag) const
289 {
290  MPI_Datatype type = mpi_type_lookup(data);
291  int ret = MPI_Send(data, n, type, dest, tag, _comm);
292  check_ret(ret);
293 }
294 
295 template <typename T>
296 inline void fml::comm::isend(int n, const T *data, int dest, int tag) const
297 {
298  MPI_Datatype type = mpi_type_lookup(data);
299  MPI_Request req; // request handle is not exposed to the caller
300  int ret = MPI_Isend(data, n, type, dest, tag, _comm, &req);
301  check_ret(ret);
302 }
303 
304 
305 
315 template <typename T>
317 inline void fml::comm::recv(int n, T *data, int source, int tag) const
318 {
319  MPI_Datatype type = mpi_type_lookup(data);
320  int ret = MPI_Recv(data, n, type, source, tag, _comm, MPI_STATUS_IGNORE);
321  check_ret(ret);
322 }
323 
324 template <typename T>
325 inline void fml::comm::irecv(int n, T *data, int source, int tag) const
326 {
327  MPI_Datatype type = mpi_type_lookup(data);
328  MPI_Request req; // request handle is not exposed to the caller
329  int ret = MPI_Irecv(data, n, type, source, tag, _comm, &req);
330  check_ret(ret);
331 }
332 
333 
334 
335 // collectives
336 
340 inline void fml::comm::barrier() const
341 {
342  int ret = MPI_Barrier(_comm);
343  check_ret(ret);
344 }
345 
346 
347 
355 template <typename T>
357 inline void fml::comm::allreduce(int n, T *data, MPI_Op op) const
358 {
359  MPI_Datatype type = mpi_type_lookup(data);
360  int ret = MPI_Allreduce(MPI_IN_PLACE, data, n, type, op, _comm);
361  check_ret(ret);
362 }
364 
365 
366 
374 template <typename T>
376 inline void fml::comm::reduce(int n, T *data, MPI_Op op, int root) const
377 {
378  MPI_Datatype type = mpi_type_lookup(data);
379  // MPI_IN_PLACE is only valid at the root for MPI_Reduce
380  int ret;
381  if (_rank == root)
382   ret = MPI_Reduce(MPI_IN_PLACE, data, n, type, op, root, _comm);
383  else
384   ret = MPI_Reduce(data, NULL, n, type, op, root, _comm);
385  check_ret(ret);
386 }
387 
388 
389 
393 template <typename T>
395 inline void fml::comm::bcast(int n, T *data, int root) const
396 {
397  MPI_Datatype type = mpi_type_lookup(data);
398  int ret = MPI_Bcast(data, n, type, root, _comm);
399  check_ret(ret);
400 }
402 
403 
404 
405 // -----------------------------------------------------------------------------
406 // private
407 // -----------------------------------------------------------------------------
408 
409 inline void fml::comm::init()
410 {
411  int ret;
412  int flag;
413 
414  ret = MPI_Initialized(&flag);
415  check_ret(ret);
416 
417  if (!flag)
418  {
419  ret = MPI_Init(NULL, NULL);
420  check_ret(ret);
421  }
422 }
423 
424 
425 
426 inline void fml::comm::set_metadata()
427 {
428  int ret;
429 
430  ret = MPI_Comm_rank(_comm, &_rank);
431  check_ret(ret);
432  ret = MPI_Comm_size(_comm, &_size);
433  check_ret(ret);
434 
435  MPI_Comm localcomm;
436  ret = MPI_Comm_split_type(_comm, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &localcomm);
437  check_ret(ret);
438 
439  ret = MPI_Comm_rank(localcomm, &_localrank);
440  check_ret(ret);
441  ret = MPI_Comm_size(localcomm, &_localsize);
442  check_ret(ret);
443  MPI_Comm_free(&localcomm); // release the temporary node-local communicator
444 }
445 
446 
447 inline void fml::comm::check_ret(const int ret) const
448 {
449  if (ret != MPI_SUCCESS && _rank == 0)
450  {
451  int slen;
452  char s[MPI_MAX_ERROR_STRING];
453 
454  MPI_Error_string(ret, s, &slen);
455  throw std::runtime_error(s);
456  }
457 }
458 
459 
460 
461 template <typename T>
462 inline MPI_Datatype fml::comm::mpi_type_lookup(const T *x) const
463 {
464  (void) x;
465 
466  // C types
467  if (typeid(T) == typeid(char))
468  return MPI_CHAR;
469  else if (typeid(T) == typeid(double))
470  return MPI_DOUBLE;
471  else if (typeid(T) == typeid(float))
472  return MPI_FLOAT;
473  else if (typeid(T) == typeid(int))
474  return MPI_INT;
475  else if (typeid(T) == typeid(long))
476  return MPI_LONG;
477  else if (typeid(T) == typeid(long double))
478  return MPI_LONG_DOUBLE;
479  else if (typeid(T) == typeid(long long))
480  return MPI_LONG_LONG_INT;
481  else if (typeid(T) == typeid(short))
482  return MPI_SHORT;
483  else if (typeid(T) == typeid(unsigned int))
484  return MPI_UNSIGNED;
485  else if (typeid(T) == typeid(unsigned char))
486  return MPI_UNSIGNED_CHAR;
487  else if (typeid(T) == typeid(unsigned long))
488  return MPI_UNSIGNED_LONG;
489  else if (typeid(T) == typeid(unsigned short))
490  return MPI_UNSIGNED_SHORT;
493 
494  // stdint types
495  else if (typeid(T) == typeid(int8_t))
496  return MPI_INT8_T;
497  else if (typeid(T) == typeid(int16_t))
498  return MPI_INT16_T;
499  else if (typeid(T) == typeid(int32_t))
500  return MPI_INT32_T;
501  else if (typeid(T) == typeid(int64_t))
502  return MPI_INT64_T;
503  else if (typeid(T) == typeid(uint8_t))
504  return MPI_UINT8_T;
505  else if (typeid(T) == typeid(uint16_t))
506  return MPI_UINT16_T;
507  else if (typeid(T) == typeid(uint32_t))
508  return MPI_UINT32_T;
509  else if (typeid(T) == typeid(uint64_t))
510  return MPI_UINT64_T;
511 
512  else
513  return MPI_DATATYPE_NULL;
514 }
515 
516 
517 
518 #endif
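
A hedged usage sketch (not part of comm.hh): a minimal program built on the class above, assuming the header is reachable as fml/par/comm.hh (adjust the include path to your installation), compiled with an MPI C++ wrapper such as mpicxx, and launched with mpirun/mpiexec.

  #include <fml/par/comm.hh>

  int main()
  {
    // constructing the object initializes MPI if needed and wraps MPI_COMM_WORLD
    fml::comm c;

    c.info();                                       // rank 0 prints "## MPI on N ranks"
    c.printf(0, "hello from rank 0 of %d\n", c.size());
    c.barrier();                                    // synchronize all ranks

    c.finalize();                                   // calls MPI_Finalize
    return 0;
  }

The short fragments shown with the member descriptions below assume an fml::comm object c constructed this way.
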
fml::comm::send
void send(int n, const T *data, int dest, int tag=0) const
Point-to-point send. Should be matched by a corresponding 'recv' call.
Definition: comm.hh:288
fml::comm::finalize
void finalize()
Shut down MPI.
Definition: comm.hh:178
fml::comm::recv
void recv(int n, T *data, int source, int tag=0) const
Point-to-point receive. Should be matched by a corresponding 'send' call.
Definition: comm.hh:317
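A sketch of a matched pair (assuming c from the usage sketch above and at least two ranks): rank 0 sends four doubles to rank 1, which posts the matching receive.

  double buf[4] = {1.0, 2.0, 3.0, 4.0};
  if (c.rank() == 0)
    c.send(4, buf, 1);        // blocking send to rank 1, default tag 0
  else if (c.rank() == 1)
    c.recv(4, buf, 0);        // blocking receive from rank 0, default tag 0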
fml::comm::size
int size() const
Total number of ranks in the MPI communicator. The same across all ranks.
Definition: comm.hh:66
fml::comm::localsize
int localsize() const
Total number of ranks within the node. Can vary across nodes.
Definition: comm.hh:70
fml::comm::jid
std::vector< int > jid(const int n) const
Partition the job indices 0 to n-1 across ranks; returns the indices owned by the calling process.
Definition: comm.hh:238
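For example (a sketch assuming c from the usage sketch above), with 2 ranks and n = 5 the implementation above gives rank 0 the indices {0, 1} and rank 1 the indices {2, 3, 4}:

  std::vector<int> mine = c.jid(5);
  for (int i : mine)
    c.printf(c.rank(), "rank %d owns job %d\n", c.rank(), i);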
fml::comm::allreduce
void allreduce(int n, T *data, MPI_Op op=MPI_SUM) const
Reduce operation (sum by default) across all processes in the MPI communicator; every rank receives the result.
Definition: comm.hh:357
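A sketch (assuming c from the usage sketch above): each rank contributes a partial value and, after the call, every rank holds the global sum, computed in place.

  double partial = (double) c.rank();
  c.allreduce(1, &partial);                  // op defaults to MPI_SUM
  c.printf(0, "global sum = %f\n", partial);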
fml::comm::comm
comm(MPI_Comm comm=MPI_COMM_WORLD)
Create a new comm object, using 'MPI_COMM_WORLD' as the default communicator.
Definition: comm.hh:103
fml::comm::bcast
void bcast(int n, T *data, int root) const
Broadcast from the root rank to all other ranks in the communicator.
Definition: comm.hh:395
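A sketch (assuming c from the usage sketch above): rank 0 fills a small buffer and broadcasts it, after which every rank holds the same values.

  int params[3] = {0, 0, 0};
  if (c.rank0())
  {
    params[0] = 10;
    params[1] = 20;
    params[2] = 30;
  }
  c.bcast(3, params, 0);                     // root is rank 0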
fml::comm::set
void set(MPI_Comm comm)
Change communicator to an existing one.
Definition: comm.hh:119
fml::comm::reduce
void reduce(int n, T *data, MPI_Op op=MPI_SUM, int root=0) const
Reduce operation (sum by default) across all processes in the MPI communicator; the result is stored on the root rank.
Definition: comm.hh:376
fml::comm::barrier
void barrier() const
Execute a barrier.
Definition: comm.hh:340
fml::comm
MPI communicator data and helpers.
Definition: comm.hh:24
fml::comm::free
void free()
Destroy the communicator.
Definition: comm.hh:165
fml::comm::create
comm create(MPI_Group group)
Create a new communicator from an MPI group.
Definition: comm.hh:133
fml::comm::printf
void printf(int rank, const char *fmt,...) const
Helper wrapper around the C standard I/O 'printf()' function. Conceptually similar to guarding an ordinary 'printf()' call with a check on the calling rank.
Definition: comm.hh:197
fml::comm::rank0
bool rank0() const
Check if the executing process is rank 0.
Definition: comm.hh:223
fml
Core namespace.
Definition: dimops.hh:10
fml::comm::info
void info() const
Print some brief information about the MPI communicator. The printing is done by rank 0.
Definition: comm.hh:215
fml::comm::split
comm split(int color, int key)
Create a new communicator by splitting the current one on color/key.
Definition: comm.hh:152
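A sketch (assuming c from the usage sketch above): split the ranks into "even" and "odd" sub-communicators, keyed by the original rank.

  fml::comm sub = c.split(c.rank() % 2, c.rank());
  sub.printf(0, "this sub-communicator has %d ranks\n", sub.size());
  sub.free();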
fml::comm::get_comm
MPI_Comm get_comm() const
Return the underlying MPI communicator handle.
Definition: comm.hh:62
fml::comm::localrank
int localrank() const
Calling process rank (0-based index) within the node.
Definition: comm.hh:68
fml::comm::rank
int rank() const
Calling process rank (0-based index) in the MPI communicator.
Definition: comm.hh:64