(git:374b731)
Loading...
Searching...
No Matches
dbm_distribution.c
Go to the documentation of this file.
1/*----------------------------------------------------------------------------*/
2/* CP2K: A general program to perform molecular dynamics simulations */
3/* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4/* */
5/* SPDX-License-Identifier: BSD-3-Clause */
6/*----------------------------------------------------------------------------*/
7
8#include <assert.h>
9#include <math.h>
10#include <omp.h>
11#include <stdbool.h>
12#include <stddef.h>
13#include <stdlib.h>
14#include <string.h>
15
16#include "dbm_distribution.h"
17#include "dbm_hyperparams.h"
18
19/*******************************************************************************
20 * \brief Private routine for creating a new one dimensional distribution.
21 * \author Ole Schuett
22 ******************************************************************************/
23static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
24 const int coords[length], const dbm_mpi_comm_t comm,
25 const int nshards) {
26 dist->comm = comm;
27 dist->nshards = nshards;
28 dist->my_rank = dbm_mpi_comm_rank(comm);
29 dist->nranks = dbm_mpi_comm_size(comm);
30 dist->length = length;
31 dist->index2coord = malloc(length * sizeof(int));
32 memcpy(dist->index2coord, coords, length * sizeof(int));
33
34 // Check that cart coordinates and ranks are equivalent.
35 int cart_dims[1], cart_periods[1], cart_coords[1];
36 dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
37 assert(dist->nranks == cart_dims[0]);
38 assert(dist->my_rank == cart_coords[0]);
39
40 // Count local rows/columns.
41 for (int i = 0; i < length; i++) {
42 assert(0 <= coords[i] && coords[i] < dist->nranks);
43 if (coords[i] == dist->my_rank) {
44 dist->nlocals++;
45 }
46 }
47
48 // Store local rows/columns.
49 dist->local_indicies = malloc(dist->nlocals * sizeof(int));
50 int j = 0;
51 for (int i = 0; i < length; i++) {
52 if (coords[i] == dist->my_rank) {
53 dist->local_indicies[j++] = i;
54 }
55 }
56 assert(j == dist->nlocals);
57}
58
59/*******************************************************************************
60 * \brief Private routine for releasing a one dimensional distribution.
61 * \author Ole Schuett
62 ******************************************************************************/
63static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
64 free(dist->index2coord);
65 free(dist->local_indicies);
66 dbm_mpi_comm_free(&dist->comm);
67}
68
/*******************************************************************************
 * \brief Returns the larger of two given integers (missing from the C standard)
 * \author Ole Schuett
 ******************************************************************************/
static inline int imax(int x, int y) {
  if (y >= x) {
    return y;
  }
  return x;
}
74
/*******************************************************************************
 * \brief Private routine for finding the optimal number of shard rows.
 *
 *        Among the divisors d of nshards, picks the one for which the shard
 *        grid d x (nshards / d) has a row/column ratio closest, on a log
 *        scale, to nrows / ncols. Zero-sized dimensions are treated as 1.
 * \param nshards  Total number of shards.
 * \param nrows    Number of block rows.
 * \param ncols    Number of block columns.
 * \return Number of shard rows; always a divisor of nshards.
 * \author Ole Schuett
 ******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
                                 const int ncols) {
  // Desired aspect ratio of the shard grid.
  const double wanted = (double)imax(nrows, 1) / (double)imax(ncols, 1);

  // Begin with the nshards x 1 layout; only a strictly better divisor wins.
  int result = nshards;
  double lowest = fabs(log(wanted / (double)nshards));

  int candidate = 1;
  while (candidate <= nshards) {
    const int partner = nshards / candidate;
    if (candidate * partner == nshards) {
      const double mismatch =
          fabs(log(wanted / ((double)candidate / (double)partner)));
      if (mismatch < lowest) {
        lowest = mismatch;
        result = candidate;
      }
    }
    candidate++;
  }
  return result;
}
98
99/*******************************************************************************
100 * \brief Creates a new two dimensional distribution.
101 * \author Ole Schuett
102 ******************************************************************************/
103void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
104 const int nrows, const int ncols,
105 const int row_dist[nrows],
106 const int col_dist[ncols]) {
107 assert(omp_get_num_threads() == 1);
108 dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
109 dist->ref_count = 1;
110
111 dist->comm = dbm_mpi_comm_f2c(fortran_comm);
112 dist->my_rank = dbm_mpi_comm_rank(dist->comm);
113 dist->nranks = dbm_mpi_comm_size(dist->comm);
114
115 const int row_dim_remains[2] = {1, 0};
116 const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains);
117
118 const int col_dim_remains[2] = {0, 1};
119 const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);
120
121 const int nshards = SHARDS_PER_THREAD * omp_get_max_threads();
122 const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
123 const int ncol_shards = nshards / nrow_shards;
124
125 dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
126 dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
127
128 assert(*dist_out == NULL);
129 *dist_out = dist;
130}
131
132/*******************************************************************************
133 * \brief Increases the reference counter of the given distribution.
134 * \author Ole Schuett
135 ******************************************************************************/
136void dbm_distribution_hold(dbm_distribution_t *dist) {
137 assert(dist->ref_count > 0);
138 dist->ref_count++;
139}
140
141/*******************************************************************************
142 * \brief Decreases the reference counter of the given distribution.
143 * \author Ole Schuett
144 ******************************************************************************/
145void dbm_distribution_release(dbm_distribution_t *dist) {
146 assert(dist->ref_count > 0);
147 dist->ref_count--;
148 if (dist->ref_count == 0) {
149 dbm_dist_1d_free(&dist->rows);
150 dbm_dist_1d_free(&dist->cols);
151 free(dist);
152 }
153}
154
155/*******************************************************************************
156 * \brief Returns the rows of the given distribution.
157 * \author Ole Schuett
158 ******************************************************************************/
159void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
160 const int **row_dist) {
161 assert(dist->ref_count > 0);
162 *nrows = dist->rows.length;
163 *row_dist = dist->rows.index2coord;
164}
165
166/*******************************************************************************
167 * \brief Returns the columns of the given distribution.
168 * \author Ole Schuett
169 ******************************************************************************/
170void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
171 const int **col_dist) {
172 assert(dist->ref_count > 0);
173 *ncols = dist->cols.length;
174 *col_dist = dist->cols.index2coord;
175}
176
177/*******************************************************************************
178 * \brief Returns the MPI rank on which the given block should be stored.
179 * \author Ole Schuett
180 ******************************************************************************/
182 const int row, const int col) {
183 assert(dist->ref_count > 0);
184 assert(0 <= row && row < dist->rows.length);
185 assert(0 <= col && col < dist->cols.length);
186 int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
187 return dbm_mpi_cart_rank(dist->comm, coords);
188}
189
190// EOF
int dbm_distribution_stored_coords(const dbm_distribution_t *dist, const int row, const int col)
Returns the MPI rank on which the given block should be stored.
static void dbm_dist_1d_free(dbm_dist_1d_t *dist)
Private routine for releasing a one dimensional distribution.
static int find_best_nrow_shards(const int nshards, const int nrows, const int ncols)
Private routine for finding the optimal number of shard rows.
static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, const int coords[length], const dbm_mpi_comm_t comm, const int nshards)
Private routine for creating a new one dimensional distribution.
static int imax(int x, int y)
Returns the larger of two given integers (missing from the C standard)
static const float SHARDS_PER_THREAD
int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_rank.
Definition dbm_mpi.c:92
int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[])
Wrapper around MPI_Cart_rank.
Definition dbm_mpi.c:176
int dbm_mpi_comm_size(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_size.
Definition dbm_mpi.c:107
dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm, const int remain_dims[])
Wrapper around MPI_Cart_sub.
Definition dbm_mpi.c:192
void dbm_mpi_comm_free(dbm_mpi_comm_t *comm)
Wrapper around MPI_Comm_free.
Definition dbm_mpi.c:209
void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[], int periods[], int coords[])
Wrapper around MPI_Cart_get.
Definition dbm_mpi.c:158
dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm)
Wrapper around MPI_Comm_f2c.
Definition dbm_mpi.c:66
int dbm_mpi_comm_t
Definition dbm_mpi.h:18
static void const int const int i
Internal struct for storing a one dimensional distribution.
Internal struct for storing a two dimensional distribution.