(git:e7e05ae)
dbm_mpi.c
Go to the documentation of this file.
1 /*----------------------------------------------------------------------------*/
2 /* CP2K: A general program to perform molecular dynamics simulations */
3 /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 /* */
5 /* SPDX-License-Identifier: BSD-3-Clause */
6 /*----------------------------------------------------------------------------*/
7 
8 #include <assert.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 
13 #include "dbm_mpi.h"
14 
#if defined(__parallel)
/*******************************************************************************
 * \brief Check given MPI status and upon failure abort with a nice message.
 *
 *        The status expression is evaluated exactly once and its value is
 *        translated with MPI_Error_string for the diagnostic. The expansion is
 *        a braced block rather than a bare `if`, so a following `else` cannot
 *        silently bind to it; call sites may omit the trailing semicolon.
 * \author Ole Schuett
 ******************************************************************************/
#define CHECK(status)                                                          \
  {                                                                            \
    const int dbm_mpi_check_status = (status);                                 \
    if (dbm_mpi_check_status != MPI_SUCCESS) {                                 \
      char dbm_mpi_check_msg[MPI_MAX_ERROR_STRING];                            \
      int dbm_mpi_check_len = 0;                                               \
      if (MPI_Error_string(dbm_mpi_check_status, dbm_mpi_check_msg,            \
                           &dbm_mpi_check_len) != MPI_SUCCESS) {               \
        dbm_mpi_check_len = 0; /* print an empty description instead */       \
      }                                                                        \
      fprintf(stderr, "MPI error in %s:%i: %.*s\n", __FILE__, __LINE__,        \
              dbm_mpi_check_len, dbm_mpi_check_msg);                           \
      MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);                                 \
    }                                                                          \
  }
#endif
26 
/*******************************************************************************
 * \brief Wrapper around MPI_Init.
 *        Without __parallel this does nothing.
 * \param argc Pointer to the argument count, forwarded to MPI_Init.
 * \param argv Pointer to the argument vector, forwarded to MPI_Init.
 * \author Ole Schuett
 ******************************************************************************/
void dbm_mpi_init(int *argc, char ***argv) {
#if !defined(__parallel)
  // Serial build: nothing to initialize, just silence unused parameters.
  (void)argc;
  (void)argv;
#else
  CHECK(MPI_Init(argc, argv));
#endif
}
39 
/*******************************************************************************
 * \brief Wrapper around MPI_Finalize.
 *        Without __parallel this does nothing.
 * \author Ole Schuett
 ******************************************************************************/
void dbm_mpi_finalize(void) {
#if defined(__parallel)
  CHECK(MPI_Finalize());
#endif
}
49 
50 /*******************************************************************************
51  * \brief Returns MPI_COMM_WORLD.
52  * \author Ole Schuett
53  ******************************************************************************/
55 #if defined(__parallel)
56  return MPI_COMM_WORLD;
57 #else
58  return -1;
59 #endif
60 }
61 
62 /*******************************************************************************
63  * \brief Wrapper around MPI_Comm_f2c.
64  * \author Ole Schuett
65  ******************************************************************************/
66 dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm) {
67 #if defined(__parallel)
68  return MPI_Comm_f2c(fortran_comm);
69 #else
70  (void)fortran_comm; // mark used
71  return -1;
72 #endif
73 }
74 
75 /*******************************************************************************
76  * \brief Wrapper around MPI_Comm_c2f.
77  * \author Ole Schuett
78  ******************************************************************************/
80 #if defined(__parallel)
81  return MPI_Comm_c2f(comm);
82 #else
83  (void)comm; // mark used
84  return -1;
85 #endif
86 }
87 
88 /*******************************************************************************
89  * \brief Wrapper around MPI_Comm_rank.
90  * \author Ole Schuett
91  ******************************************************************************/
93 #if defined(__parallel)
94  int rank;
95  CHECK(MPI_Comm_rank(comm, &rank));
96  return rank;
97 #else
98  (void)comm; // mark used
99  return 0;
100 #endif
101 }
102 
103 /*******************************************************************************
104  * \brief Wrapper around MPI_Comm_size.
105  * \author Ole Schuett
106  ******************************************************************************/
108 #if defined(__parallel)
109  int nranks;
110  CHECK(MPI_Comm_size(comm, &nranks));
111  return nranks;
112 #else
113  (void)comm; // mark used
114  return 1;
115 #endif
116 }
117 
/*******************************************************************************
 * \brief Wrapper around MPI_Dims_create.
 *        Without MPI, all nnodes are placed in the first dimension and every
 *        other dimension gets extent 1.
 * \param nnodes Number of nodes to distribute.
 * \param ndims  Number of dimensions; dims must hold at least ndims entries.
 * \param dims   Output: extent per dimension.
 * \author Ole Schuett
 ******************************************************************************/
void dbm_mpi_dims_create(const int nnodes, const int ndims, int dims[]) {
#if defined(__parallel)
  CHECK(MPI_Dims_create(nnodes, ndims, dims));
#else
  // Start at 0 so that ndims <= 0 never touches dims (no out-of-bounds write).
  for (int i = 0; i < ndims; i++) {
    dims[i] = (i == 0) ? nnodes : 1;
  }
#endif
}
132 
133 /*******************************************************************************
134  * \brief Wrapper around MPI_Cart_create.
135  * \author Ole Schuett
136  ******************************************************************************/
138  const int ndims, const int dims[],
139  const int periods[], const int reorder) {
140 #if defined(__parallel)
141  dbm_mpi_comm_t comm_cart;
142  CHECK(MPI_Cart_create(comm_old, ndims, dims, periods, reorder, &comm_cart));
143  return comm_cart;
144 #else
145  (void)comm_old; // mark used
146  (void)ndims;
147  (void)dims;
148  (void)periods;
149  (void)reorder;
150  return -1;
151 #endif
152 }
153 
154 /*******************************************************************************
155  * \brief Wrapper around MPI_Cart_get.
156  * \author Ole Schuett
157  ******************************************************************************/
158 void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[],
159  int periods[], int coords[]) {
160 #if defined(__parallel)
161  CHECK(MPI_Cart_get(comm, maxdims, dims, periods, coords));
162 #else
163  (void)comm; // mark used
164  for (int i = 0; i < maxdims; i++) {
165  dims[i] = 1;
166  periods[i] = 1;
167  coords[i] = 0;
168  }
169 #endif
170 }
171 
172 /*******************************************************************************
173  * \brief Wrapper around MPI_Cart_rank.
174  * \author Ole Schuett
175  ******************************************************************************/
176 int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[]) {
177 #if defined(__parallel)
178  int rank;
179  CHECK(MPI_Cart_rank(comm, coords, &rank));
180  return rank;
181 #else
182  (void)comm; // mark used
183  (void)coords;
184  return 0;
185 #endif
186 }
187 
188 /*******************************************************************************
189  * \brief Wrapper around MPI_Cart_sub.
190  * \author Ole Schuett
191  ******************************************************************************/
193  const int remain_dims[]) {
194 #if defined(__parallel)
195  dbm_mpi_comm_t newcomm;
196  CHECK(MPI_Cart_sub(comm, remain_dims, &newcomm));
197  return newcomm;
198 #else
199  (void)comm; // mark used
200  (void)remain_dims;
201  return -1;
202 #endif
203 }
204 
205 /*******************************************************************************
206  * \brief Wrapper around MPI_Comm_free.
207  * \author Ole Schuett
208  ******************************************************************************/
210 #if defined(__parallel)
211  CHECK(MPI_Comm_free(comm));
212 #else
213  (void)comm; // mark used
214 #endif
215 }
216 
217 /*******************************************************************************
218  * \brief Wrapper around MPI_Comm_compare.
219  * \author Ole Schuett
220  ******************************************************************************/
222  const dbm_mpi_comm_t comm2) {
223 #if defined(__parallel)
224  int res;
225  CHECK(MPI_Comm_compare(comm1, comm2, &res));
226  return res == MPI_IDENT || res == MPI_CONGRUENT || res == MPI_SIMILAR;
227 #else
228  (void)comm1; // mark used
229  (void)comm2;
230  return true;
231 #endif
232 }
233 
234 /*******************************************************************************
235  * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_INT.
236  * \author Ole Schuett
237  ******************************************************************************/
238 void dbm_mpi_max_int(int *values, const int count, const dbm_mpi_comm_t comm) {
239 #if defined(__parallel)
240  int *recvbuf = malloc(count * sizeof(int));
241  CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_MAX, comm));
242  memcpy(values, recvbuf, count * sizeof(int));
243  free(recvbuf);
244 #else
245  (void)comm; // mark used
246  (void)values;
247  (void)count;
248 #endif
249 }
250 
251 /*******************************************************************************
252  * \brief Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_DOUBLE.
253  * \author Ole Schuett
254  ******************************************************************************/
255 void dbm_mpi_max_double(double *values, const int count,
256  const dbm_mpi_comm_t comm) {
257 #if defined(__parallel)
258  double *recvbuf = malloc(count * sizeof(double));
259  CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_MAX, comm));
260  memcpy(values, recvbuf, count * sizeof(double));
261  free(recvbuf);
262 #else
263  (void)comm; // mark used
264  (void)values;
265  (void)count;
266 #endif
267 }
268 
269 /*******************************************************************************
270  * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT.
271  * \author Ole Schuett
272  ******************************************************************************/
273 void dbm_mpi_sum_int(int *values, const int count, const dbm_mpi_comm_t comm) {
274 #if defined(__parallel)
275  int *recvbuf = malloc(count * sizeof(int));
276  CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_SUM, comm));
277  memcpy(values, recvbuf, count * sizeof(int));
278  free(recvbuf);
279 #else
280  (void)comm; // mark used
281  (void)values;
282  (void)count;
283 #endif
284 }
285 
286 /*******************************************************************************
287  * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT64_T.
288  * \author Ole Schuett
289  ******************************************************************************/
290 void dbm_mpi_sum_int64(int64_t *values, const int count,
291  const dbm_mpi_comm_t comm) {
292 #if defined(__parallel)
293  int64_t *recvbuf = malloc(count * sizeof(int64_t));
294  CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT64_T, MPI_SUM, comm));
295  memcpy(values, recvbuf, count * sizeof(int64_t));
296  free(recvbuf);
297 #else
298  (void)comm; // mark used
299  (void)values;
300  (void)count;
301 #endif
302 }
303 
304 /*******************************************************************************
305  * \brief Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_DOUBLE.
306  * \author Ole Schuett
307  ******************************************************************************/
308 void dbm_mpi_sum_double(double *values, const int count,
309  const dbm_mpi_comm_t comm) {
310 #if defined(__parallel)
311  double *recvbuf = malloc(count * sizeof(double));
312  CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_SUM, comm));
313  memcpy(values, recvbuf, count * sizeof(double));
314  free(recvbuf);
315 #else
316  (void)comm; // mark used
317  (void)values;
318  (void)count;
319 #endif
320 }
321 
322 /*******************************************************************************
323  * \brief Wrapper around MPI_Sendrecv for datatype MPI_BYTE.
324  * \author Ole Schuett
325  ******************************************************************************/
326 int dbm_mpi_sendrecv_byte(const void *sendbuf, const int sendcount,
327  const int dest, const int sendtag, void *recvbuf,
328  const int recvcount, const int source,
329  const int recvtag, const dbm_mpi_comm_t comm) {
330 #if defined(__parallel)
331  MPI_Status status;
332  CHECK(MPI_Sendrecv(sendbuf, sendcount, MPI_BYTE, dest, sendtag, recvbuf,
333  recvcount, MPI_BYTE, source, recvtag, comm, &status))
334  int count_received;
335  CHECK(MPI_Get_count(&status, MPI_BYTE, &count_received));
336  return count_received;
337 #else
338  (void)sendbuf; // mark used
339  (void)sendcount;
340  (void)dest;
341  (void)sendtag;
342  (void)recvbuf;
343  (void)recvcount;
344  (void)source;
345  (void)recvtag;
346  (void)comm;
347  fprintf(stderr, "Error: dbm_mpi_sendrecv_byte not available without MPI\n");
348  abort();
349 #endif
350 }
351 
352 /*******************************************************************************
353  * \brief Wrapper around MPI_Sendrecv for datatype MPI_DOUBLE.
354  * \author Ole Schuett
355  ******************************************************************************/
356 int dbm_mpi_sendrecv_double(const double *sendbuf, const int sendcount,
357  const int dest, const int sendtag, double *recvbuf,
358  const int recvcount, const int source,
359  const int recvtag, const dbm_mpi_comm_t comm) {
360 #if defined(__parallel)
361  MPI_Status status;
362  CHECK(MPI_Sendrecv(sendbuf, sendcount, MPI_DOUBLE, dest, sendtag, recvbuf,
363  recvcount, MPI_DOUBLE, source, recvtag, comm, &status))
364  int count_received;
365  CHECK(MPI_Get_count(&status, MPI_DOUBLE, &count_received));
366  return count_received;
367 #else
368  (void)sendbuf; // mark used
369  (void)sendcount;
370  (void)dest;
371  (void)sendtag;
372  (void)recvbuf;
373  (void)recvcount;
374  (void)source;
375  (void)recvtag;
376  (void)comm;
377  fprintf(stderr, "Error: dbm_mpi_sendrecv_double not available without MPI\n");
378  abort();
379 #endif
380 }
381 
382 /*******************************************************************************
383  * \brief Wrapper around MPI_Alltoall for datatype MPI_INT.
384  * \author Ole Schuett
385  ******************************************************************************/
386 void dbm_mpi_alltoall_int(const int *sendbuf, const int sendcount, int *recvbuf,
387  const int recvcount, const dbm_mpi_comm_t comm) {
388 #if defined(__parallel)
389  CHECK(MPI_Alltoall(sendbuf, sendcount, MPI_INT, recvbuf, recvcount, MPI_INT,
390  comm));
391 #else
392  (void)comm; // mark used
393  assert(sendcount == recvcount);
394  memcpy(recvbuf, sendbuf, sendcount * sizeof(int));
395 #endif
396 }
397 
398 /*******************************************************************************
399  * \brief Wrapper around MPI_Alltoallv for datatype MPI_BYTE.
400  * \author Ole Schuett
401  ******************************************************************************/
402 void dbm_mpi_alltoallv_byte(const void *sendbuf, const int *sendcounts,
403  const int *sdispls, void *recvbuf,
404  const int *recvcounts, const int *rdispls,
405  const dbm_mpi_comm_t comm) {
406 #if defined(__parallel)
407  CHECK(MPI_Alltoallv(sendbuf, sendcounts, sdispls, MPI_BYTE, recvbuf,
408  recvcounts, rdispls, MPI_BYTE, comm));
409 #else
410  (void)comm; // mark used
411  assert(sendcounts[0] == recvcounts[0]);
412  assert(sdispls[0] == 0 && rdispls[0] == 0);
413  memcpy(recvbuf, sendbuf, sendcounts[0]);
414 #endif
415 }
416 
417 /*******************************************************************************
418  * \brief Wrapper around MPI_Alltoallv for datatype MPI_DOUBLE.
419  * \author Ole Schuett
420  ******************************************************************************/
421 void dbm_mpi_alltoallv_double(const double *sendbuf, const int *sendcounts,
422  const int *sdispls, double *recvbuf,
423  const int *recvcounts, const int *rdispls,
424  const dbm_mpi_comm_t comm) {
425 #if defined(__parallel)
426  CHECK(MPI_Alltoallv(sendbuf, sendcounts, sdispls, MPI_DOUBLE, recvbuf,
427  recvcounts, rdispls, MPI_DOUBLE, comm));
428 #else
429  (void)comm; // mark used
430  assert(sendcounts[0] == recvcounts[0]);
431  assert(sdispls[0] == 0 && rdispls[0] == 0);
432  memcpy(recvbuf, sendbuf, sendcounts[0] * sizeof(double));
433 #endif
434 }
435 
436 // EOF
int dbm_mpi_sendrecv_byte(const void *sendbuf, const int sendcount, const int dest, const int sendtag, void *recvbuf, const int recvcount, const int source, const int recvtag, const dbm_mpi_comm_t comm)
Wrapper around MPI_Sendrecv for datatype MPI_BYTE.
Definition: dbm_mpi.c:326
void dbm_mpi_finalize()
Wrapper around MPI_Finalize.
Definition: dbm_mpi.c:44
int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_rank.
Definition: dbm_mpi.c:92
int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[])
Wrapper around MPI_Cart_rank.
Definition: dbm_mpi.c:176
bool dbm_mpi_comms_are_similar(const dbm_mpi_comm_t comm1, const dbm_mpi_comm_t comm2)
Wrapper around MPI_Comm_compare.
Definition: dbm_mpi.c:221
int dbm_mpi_comm_size(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_size.
Definition: dbm_mpi.c:107
void dbm_mpi_alltoallv_byte(const void *sendbuf, const int *sendcounts, const int *sdispls, void *recvbuf, const int *recvcounts, const int *rdispls, const dbm_mpi_comm_t comm)
Wrapper around MPI_Alltoallv for datatype MPI_BYTE.
Definition: dbm_mpi.c:402
void dbm_mpi_init(int *argc, char ***argv)
Wrapper around MPI_Init.
Definition: dbm_mpi.c:31
void dbm_mpi_sum_double(double *values, const int count, const dbm_mpi_comm_t comm)
Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_DOUBLE.
Definition: dbm_mpi.c:308
dbm_mpi_comm_t dbm_mpi_cart_create(const dbm_mpi_comm_t comm_old, const int ndims, const int dims[], const int periods[], const int reorder)
Wrapper around MPI_Cart_create.
Definition: dbm_mpi.c:137
void dbm_mpi_sum_int64(int64_t *values, const int count, const dbm_mpi_comm_t comm)
Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT64_T.
Definition: dbm_mpi.c:290
void dbm_mpi_max_double(double *values, const int count, const dbm_mpi_comm_t comm)
Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_DOUBLE.
Definition: dbm_mpi.c:255
void dbm_mpi_alltoall_int(const int *sendbuf, const int sendcount, int *recvbuf, const int recvcount, const dbm_mpi_comm_t comm)
Wrapper around MPI_Alltoall for datatype MPI_INT.
Definition: dbm_mpi.c:386
int dbm_mpi_sendrecv_double(const double *sendbuf, const int sendcount, const int dest, const int sendtag, double *recvbuf, const int recvcount, const int source, const int recvtag, const dbm_mpi_comm_t comm)
Wrapper around MPI_Sendrecv for datatype MPI_DOUBLE.
Definition: dbm_mpi.c:356
void dbm_mpi_max_int(int *values, const int count, const dbm_mpi_comm_t comm)
Wrapper around MPI_Allreduce for op MPI_MAX and datatype MPI_INT.
Definition: dbm_mpi.c:238
dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm, const int remain_dims[])
Wrapper around MPI_Cart_sub.
Definition: dbm_mpi.c:192
void dbm_mpi_dims_create(const int nnodes, const int ndims, int dims[])
Wrapper around MPI_Dims_create.
Definition: dbm_mpi.c:122
int dbm_mpi_comm_c2f(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_c2f.
Definition: dbm_mpi.c:79
void dbm_mpi_comm_free(dbm_mpi_comm_t *comm)
Wrapper around MPI_Comm_free.
Definition: dbm_mpi.c:209
void dbm_mpi_alltoallv_double(const double *sendbuf, const int *sendcounts, const int *sdispls, double *recvbuf, const int *recvcounts, const int *rdispls, const dbm_mpi_comm_t comm)
Wrapper around MPI_Alltoallv for datatype MPI_DOUBLE.
Definition: dbm_mpi.c:421
void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[], int periods[], int coords[])
Wrapper around MPI_Cart_get.
Definition: dbm_mpi.c:158
dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm)
Wrapper around MPI_Comm_f2c.
Definition: dbm_mpi.c:66
dbm_mpi_comm_t dbm_mpi_get_comm_world()
Returns MPI_COMM_WORLD.
Definition: dbm_mpi.c:54
void dbm_mpi_sum_int(int *values, const int count, const dbm_mpi_comm_t comm)
Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT.
Definition: dbm_mpi.c:273
int dbm_mpi_comm_t
Definition: dbm_mpi.h:18
static void const int const int i