(git:ed6f26b)
Loading...
Searching...
No Matches
dbm_distribution.c
Go to the documentation of this file.
1/*----------------------------------------------------------------------------*/
2/* CP2K: A general program to perform molecular dynamics simulations */
3/* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4/* */
5/* SPDX-License-Identifier: BSD-3-Clause */
6/*----------------------------------------------------------------------------*/
7
8#include <assert.h>
9#include <math.h>
10#include <omp.h>
11#include <stdbool.h>
12#include <stddef.h>
13#include <stdlib.h>
14#include <string.h>
15
16#include "dbm_distribution.h"
17#include "dbm_hyperparams.h"
18#include "dbm_internal.h"
19
20/*******************************************************************************
21 * \brief Private routine for creating a new one dimensional distribution.
22 * \author Ole Schuett
23 ******************************************************************************/
24static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
25 const int coords[length], const dbm_mpi_comm_t comm,
26 const int nshards) {
27 dist->comm = comm;
28 dist->nshards = nshards;
29 dist->my_rank = dbm_mpi_comm_rank(comm);
30 dist->nranks = dbm_mpi_comm_size(comm);
31 dist->length = length;
32 dist->index2coord = malloc(length * sizeof(int));
33 assert(dist->index2coord != NULL);
34 memcpy(dist->index2coord, coords, length * sizeof(int));
35
36 // Check that cart coordinates and ranks are equivalent.
37 int cart_dims[1], cart_periods[1], cart_coords[1];
38 dbm_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
39 assert(dist->nranks == cart_dims[0]);
40 assert(dist->my_rank == cart_coords[0]);
41
42 // Count local rows/columns.
43 for (int i = 0; i < length; i++) {
44 assert(0 <= coords[i] && coords[i] < dist->nranks);
45 if (coords[i] == dist->my_rank) {
46 dist->nlocals++;
47 }
48 }
49
50 // Store local rows/columns.
51 dist->local_indicies = malloc(dist->nlocals * sizeof(int));
52 assert(dist->local_indicies != NULL);
53 int j = 0;
54 for (int i = 0; i < length; i++) {
55 if (coords[i] == dist->my_rank) {
56 dist->local_indicies[j++] = i;
57 }
58 }
59 assert(j == dist->nlocals);
60}
61
62/*******************************************************************************
63 * \brief Private routine for releasing a one dimensional distribution.
64 * \author Ole Schuett
65 ******************************************************************************/
66static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
67 free(dist->index2coord);
68 free(dist->local_indicies);
69 dbm_mpi_comm_free(&dist->comm);
70}
71
/*******************************************************************************
 * \brief Private routine for finding the optimal number of shard rows.
 *
 *        Considers every factorization nshards = nrow_shards * ncol_shards
 *        and picks the one whose ratio nrow_shards / ncol_shards is closest,
 *        on a logarithmic scale, to nrows / ncols (each clamped to >= 1).
 *        The nshards x 1 layout is the initial candidate; a later candidate
 *        replaces the current best only if it is strictly better.
 *
 * \param nshards  Total number of shards.
 * \param nrows    Number of rows.
 * \param ncols    Number of columns.
 * \return         The selected number of shard rows.
 * \author Ole Schuett
 ******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
                                 const int ncols) {
  const double aspect = imax(nrows, 1) / (double)imax(ncols, 1);

  // Reference candidate: all shards along the row dimension.
  int winner = nshards;
  double winner_error = fabs(log(aspect / (double)nshards));

  for (int r = 1; r <= nshards; ++r) {
    if (nshards % r != 0) {
      continue; // r is not a divisor of nshards.
    }
    const int c = nshards / r;
    const double layout_ratio = (double)r / (double)c;
    const double candidate_error = fabs(log(aspect / layout_ratio));
    if (candidate_error < winner_error) {
      winner = r;
      winner_error = candidate_error;
    }
  }
  return winner;
}
96
97/*******************************************************************************
98 * \brief Creates a new two dimensional distribution.
99 * \author Ole Schuett
100 ******************************************************************************/
101void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
102 const int nrows, const int ncols,
103 const int row_dist[nrows],
104 const int col_dist[ncols]) {
105 assert(omp_get_num_threads() == 1);
106 dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
107 dist->ref_count = 1;
108
109 dist->comm = dbm_mpi_comm_f2c(fortran_comm);
110 dist->my_rank = dbm_mpi_comm_rank(dist->comm);
111 dist->nranks = dbm_mpi_comm_size(dist->comm);
112
113 const int row_dim_remains[2] = {1, 0};
114 const dbm_mpi_comm_t row_comm = dbm_mpi_cart_sub(dist->comm, row_dim_remains);
115
116 const int col_dim_remains[2] = {0, 1};
117 const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);
118
119 const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads();
120 const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
121 const int ncol_shards = nshards / nrow_shards;
122
123 dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
124 dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
125
126 assert(*dist_out == NULL);
127 *dist_out = dist;
128}
129
130/*******************************************************************************
131 * \brief Increases the reference counter of the given distribution.
132 * \author Ole Schuett
133 ******************************************************************************/
134void dbm_distribution_hold(dbm_distribution_t *dist) {
135 assert(dist->ref_count > 0);
136 dist->ref_count++;
137}
138
139/*******************************************************************************
140 * \brief Decreases the reference counter of the given distribution.
141 * \author Ole Schuett
142 ******************************************************************************/
143void dbm_distribution_release(dbm_distribution_t *dist) {
144 assert(dist->ref_count > 0);
145 dist->ref_count--;
146 if (dist->ref_count == 0) {
147 dbm_dist_1d_free(&dist->rows);
148 dbm_dist_1d_free(&dist->cols);
149 free(dist);
150 }
151}
152
153/*******************************************************************************
154 * \brief Returns the rows of the given distribution.
155 * \author Ole Schuett
156 ******************************************************************************/
157void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
158 const int **row_dist) {
159 assert(dist->ref_count > 0);
160 *nrows = dist->rows.length;
161 *row_dist = dist->rows.index2coord;
162}
163
164/*******************************************************************************
165 * \brief Returns the columns of the given distribution.
166 * \author Ole Schuett
167 ******************************************************************************/
168void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
169 const int **col_dist) {
170 assert(dist->ref_count > 0);
171 *ncols = dist->cols.length;
172 *col_dist = dist->cols.index2coord;
173}
174
175/*******************************************************************************
176 * \brief Returns the MPI rank on which the given block should be stored.
177 * \author Ole Schuett
178 ******************************************************************************/
180 const int row, const int col) {
181 assert(dist->ref_count > 0);
182 assert(0 <= row && row < dist->rows.length);
183 assert(0 <= col && col < dist->cols.length);
184 int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
185 return dbm_mpi_cart_rank(dist->comm, coords);
186}
187
188// EOF
int dbm_distribution_stored_coords(const dbm_distribution_t *dist, const int row, const int col)
Returns the MPI rank on which the given block should be stored.
static void dbm_dist_1d_free(dbm_dist_1d_t *dist)
Private routine for releasing a one dimensional distribution.
static int find_best_nrow_shards(const int nshards, const int nrows, const int ncols)
Private routine for finding the optimal number of shard rows.
static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, const int coords[length], const dbm_mpi_comm_t comm, const int nshards)
Private routine for creating a new one dimensional distribution.
#define DBM_SHARDS_PER_THREAD
static int imax(int x, int y)
Returns the larger of two given integers (missing from the C standard)
int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_rank.
Definition dbm_mpi.c:95
int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[])
Wrapper around MPI_Cart_rank.
Definition dbm_mpi.c:179
int dbm_mpi_comm_size(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_size.
Definition dbm_mpi.c:110
dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm, const int remain_dims[])
Wrapper around MPI_Cart_sub.
Definition dbm_mpi.c:195
void dbm_mpi_comm_free(dbm_mpi_comm_t *comm)
Wrapper around MPI_Comm_free.
Definition dbm_mpi.c:212
void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[], int periods[], int coords[])
Wrapper around MPI_Cart_get.
Definition dbm_mpi.c:161
dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm)
Wrapper around MPI_Comm_f2c.
Definition dbm_mpi.c:69
int dbm_mpi_comm_t
Definition dbm_mpi.h:19
static void const int const int i
Internal struct for storing a one dimensional distribution.
Internal struct for storing a two dimensional distribution.