(git:e7e05ae)
dbm_distribution.c
Go to the documentation of this file.
1 /*----------------------------------------------------------------------------*/
2 /* CP2K: A general program to perform molecular dynamics simulations */
3 /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 /* */
5 /* SPDX-License-Identifier: BSD-3-Clause */
6 /*----------------------------------------------------------------------------*/
7 
8 #include <assert.h>
9 #include <math.h>
10 #include <omp.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdlib.h>
14 #include <string.h>
15 
16 #include "dbm_distribution.h"
17 #include "dbm_hyperparams.h"
18 
19 /*******************************************************************************
20  * \brief Private routine for creating a new one dimensional distribution.
21  * \author Ole Schuett
22  ******************************************************************************/
23 static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
24  const int coords[length], const dbm_mpi_comm_t comm,
25  const int nshards) {
26  dist->comm = comm;
27  dist->nshards = nshards;
28  dist->my_rank = dbm_mpi_comm_rank(comm);
29  dist->nranks = dbm_mpi_comm_size(comm);
30  dist->length = length;
31  dist->index2coord = malloc(length * sizeof(int));
32  memcpy(dist->index2coord, coords, length * sizeof(int));
33 
34  // Check that cart coordinates and ranks are equivalent.
35  int cart_dims[1], cart_periods[1], cart_coords[1];
36  dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
37  assert(dist->nranks == cart_dims[0]);
38  assert(dist->my_rank == cart_coords[0]);
39 
40  // Count local rows/columns.
41  for (int i = 0; i < length; i++) {
42  assert(0 <= coords[i] && coords[i] < dist->nranks);
43  if (coords[i] == dist->my_rank) {
44  dist->nlocals++;
45  }
46  }
47 
48  // Store local rows/columns.
49  dist->local_indicies = malloc(dist->nlocals * sizeof(int));
50  int j = 0;
51  for (int i = 0; i < length; i++) {
52  if (coords[i] == dist->my_rank) {
53  dist->local_indicies[j++] = i;
54  }
55  }
56  assert(j == dist->nlocals);
57 }
58 
59 /*******************************************************************************
60  * \brief Private routine for releasing a one dimensional distribution.
61  * \author Ole Schuett
62  ******************************************************************************/
63 static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
64  free(dist->index2coord);
65  free(dist->local_indicies);
66  dbm_mpi_comm_free(&dist->comm);
67 }
68 
/*******************************************************************************
 * \brief Returns the larger of two given integers (missing from the C standard)
 * \author Ole Schuett
 ******************************************************************************/
static inline int imax(int x, int y) {
  if (y > x) {
    return y;
  }
  return x;
}
74 
/*******************************************************************************
 * \brief Private routine for finding the optimal number of shard rows.
 *
 *        Considers every factorization nshards = nrow_shards * ncol_shards and
 *        picks the one whose ratio nrow_shards / ncol_shards is closest, on a
 *        logarithmic scale, to nrows / ncols (both clamped to at least 1).
 *        The initial candidate is nrow_shards = nshards; a later candidate
 *        only wins if it is strictly better.
 * \param nshards  Total number of shards.
 * \param nrows    Number of rows.
 * \param ncols    Number of columns.
 * \return The chosen number of shard rows.
 * \author Ole Schuett
 ******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
                                 const int ncols) {
  const double desired_ratio =
      (double)imax(nrows, 1) / (double)imax(ncols, 1);

  // Starting guess: all shards along the row dimension (one shard column).
  int winner = nshards;
  double winner_error = fabs(log(desired_ratio / (double)nshards));

  for (int candidate = 1; candidate <= nshards; ++candidate) {
    if (nshards % candidate != 0) {
      continue; // candidate does not divide nshards evenly.
    }
    const int partner = nshards / candidate;
    const double shape = (double)candidate / (double)partner;
    const double error = fabs(log(desired_ratio / shape));
    if (error < winner_error) {
      winner = candidate;
      winner_error = error;
    }
  }
  return winner;
}
98 
99 /*******************************************************************************
100  * \brief Creates a new two dimensional distribution.
101  * \author Ole Schuett
102  ******************************************************************************/
103 void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
104  const int nrows, const int ncols,
105  const int row_dist[nrows],
106  const int col_dist[ncols]) {
107  assert(omp_get_num_threads() == 1);
108  dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
109  dist->ref_count = 1;
110 
111  dist->comm = dbm_mpi_comm_f2c(fortran_comm);
112  dist->my_rank = dbm_mpi_comm_rank(dist->comm);
113  dist->nranks = dbm_mpi_comm_size(dist->comm);
114 
115  const int row_dim_remains[2] = {1, 0};
116  const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains);
117 
118  const int col_dim_remains[2] = {0, 1};
119  const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);
120 
121  const int nshards = SHARDS_PER_THREAD * omp_get_max_threads();
122  const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
123  const int ncol_shards = nshards / nrow_shards;
124 
125  dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
126  dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
127 
128  assert(*dist_out == NULL);
129  *dist_out = dist;
130 }
131 
132 /*******************************************************************************
133  * \brief Increases the reference counter of the given distribution.
134  * \author Ole Schuett
135  ******************************************************************************/
137  assert(dist->ref_count > 0);
138  dist->ref_count++;
139 }
140 
141 /*******************************************************************************
142  * \brief Decreases the reference counter of the given distribution.
143  * \author Ole Schuett
144  ******************************************************************************/
146  assert(dist->ref_count > 0);
147  dist->ref_count--;
148  if (dist->ref_count == 0) {
149  dbm_dist_1d_free(&dist->rows);
150  dbm_dist_1d_free(&dist->cols);
151  free(dist);
152  }
153 }
154 
155 /*******************************************************************************
156  * \brief Returns the rows of the given distribution.
157  * \author Ole Schuett
158  ******************************************************************************/
159 void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
160  const int **row_dist) {
161  assert(dist->ref_count > 0);
162  *nrows = dist->rows.length;
163  *row_dist = dist->rows.index2coord;
164 }
165 
166 /*******************************************************************************
167  * \brief Returns the columns of the given distribution.
168  * \author Ole Schuett
169  ******************************************************************************/
170 void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
171  const int **col_dist) {
172  assert(dist->ref_count > 0);
173  *ncols = dist->cols.length;
174  *col_dist = dist->cols.index2coord;
175 }
176 
177 /*******************************************************************************
178  * \brief Returns the MPI rank on which the given block should be stored.
179  * \author Ole Schuett
180  ******************************************************************************/
182  const int row, const int col) {
183  assert(dist->ref_count > 0);
184  assert(0 <= row && row < dist->rows.length);
185  assert(0 <= col && col < dist->cols.length);
186  int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
187  return dbm_mpi_cart_rank(dist->comm, coords);
188 }
189 
190 // EOF
int dbm_distribution_stored_coords(const dbm_distribution_t *dist, const int row, const int col)
Returns the MPI rank on which the given block should be stored.
static void dbm_dist_1d_free(dbm_dist_1d_t *dist)
Private routine for releasing a one dimensional distribution.
void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm, const int nrows, const int ncols, const int row_dist[nrows], const int col_dist[ncols])
Creates a new two dimensional distribution.
static int find_best_nrow_shards(const int nshards, const int nrows, const int ncols)
Private routine for finding the optimal number of shard rows.
static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, const int coords[length], const dbm_mpi_comm_t comm, const int nshards)
Private routine for creating a new one dimensional distribution.
void dbm_distribution_hold(dbm_distribution_t *dist)
Increases the reference counter of the given distribution.
void dbm_distribution_release(dbm_distribution_t *dist)
Decreases the reference counter of the given distribution.
void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows, const int **row_dist)
Returns the rows of the given distribution.
void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols, const int **col_dist)
Returns the columns of the given distribution.
static int imax(int x, int y)
Returns the larger of two given integers (missing from the C standard)
static const float SHARDS_PER_THREAD
int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_rank.
Definition: dbm_mpi.c:92
int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[])
Wrapper around MPI_Cart_rank.
Definition: dbm_mpi.c:176
int dbm_mpi_comm_size(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_size.
Definition: dbm_mpi.c:107
dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm, const int remain_dims[])
Wrapper around MPI_Cart_sub.
Definition: dbm_mpi.c:192
void dbm_mpi_comm_free(dbm_mpi_comm_t *comm)
Wrapper around MPI_Comm_free.
Definition: dbm_mpi.c:209
void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[], int periods[], int coords[])
Wrapper around MPI_Cart_get.
Definition: dbm_mpi.c:158
dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm)
Wrapper around MPI_Comm_f2c.
Definition: dbm_mpi.c:66
int dbm_mpi_comm_t
Definition: dbm_mpi.h:18
static void const int const int i
Internal struct for storing a one dimensional distribution.
Internal struct for storing a two dimensional distribution.