(git:b5558c7)
Loading...
Searching...
No Matches
dbm_distribution.c
Go to the documentation of this file.
/*----------------------------------------------------------------------------*/
/* CP2K: A general program to perform molecular dynamics simulations */
/* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
/* */
/* SPDX-License-Identifier: BSD-3-Clause */
/*----------------------------------------------------------------------------*/
7#include "dbm_distribution.h"
8#include "dbm_hyperparams.h"
9#include "dbm_internal.h"
10
11#include <assert.h>
12#include <math.h>
13#include <omp.h>
14#include <stdbool.h>
15#include <stddef.h>
16#include <stdlib.h>
17#include <string.h>
18
19/*******************************************************************************
20 * \brief Private routine for creating a new one dimensional distribution.
21 * \author Ole Schuett
22 ******************************************************************************/
23static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
24 const int coords[length], const cp_mpi_comm_t comm,
25 const int nshards) {
26 dist->comm = comm;
27 dist->nshards = nshards;
28 dist->my_rank = cp_mpi_comm_rank(comm);
29 dist->nranks = cp_mpi_comm_size(comm);
30 dist->length = length;
31 dist->index2coord = malloc(length * sizeof(int));
32 assert(dist->index2coord != NULL || length == 0);
33 if (length != 0) {
34 memcpy(dist->index2coord, coords, length * sizeof(int));
35 }
36
37 // Check that cart coordinates and ranks are equivalent.
38 int cart_dims[1], cart_periods[1], cart_coords[1];
39 cp_mpi_cart_get(comm, 1, cart_dims, cart_periods, cart_coords);
40 assert(dist->nranks == cart_dims[0]);
41 assert(dist->my_rank == cart_coords[0]);
42
43 // Count local rows/columns.
44 for (int i = 0; i < length; i++) {
45 assert(0 <= coords[i] && coords[i] < dist->nranks);
46 if (coords[i] == dist->my_rank) {
47 dist->nlocals++;
48 }
49 }
50
51 // Store local rows/columns.
52 dist->local_indicies = malloc(dist->nlocals * sizeof(int));
53 assert(dist->local_indicies != NULL || dist->nlocals == 0);
54 int j = 0;
55 for (int i = 0; i < length; i++) {
56 if (coords[i] == dist->my_rank) {
57 dist->local_indicies[j++] = i;
58 }
59 }
60 assert(j == dist->nlocals);
61}
62
63/*******************************************************************************
64 * \brief Private routine for releasing a one dimensional distribution.
65 * \author Ole Schuett
66 ******************************************************************************/
67static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
68 free(dist->index2coord);
69 free(dist->local_indicies);
70 cp_mpi_comm_free(&dist->comm);
71}
72
/*******************************************************************************
 * \brief Private routine for finding the optimal number of shard rows.
 *
 * Picks the divisor r of nshards for which the shard grid r x (nshards / r)
 * has an aspect ratio closest, on a logarithmic scale, to nrows : ncols.
 * Empty dimensions are treated as having length one. On ties the initial
 * candidate r = nshards is kept; otherwise the smallest tied divisor wins.
 *
 * \param nshards  Total number of shards (expected > 0).
 * \param nrows    Number of block rows of the matrix.
 * \param ncols    Number of block columns of the matrix.
 * \return         Number of shard rows; always a divisor of nshards.
 * \author Ole Schuett
 ******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
                                 const int ncols) {
  const double target = imax(nrows, 1) / (double)imax(ncols, 1);

  // Incumbent: nshards shard rows by a single shard column. Only strictly
  // better candidates replace it.
  int winner = nshards;
  double winner_error = fabs(log(target / (double)nshards));

  for (int rows = 1; rows <= nshards; ++rows) {
    if (nshards % rows != 0) {
      continue; // rows must divide nshards evenly.
    }
    const int cols = nshards / rows;
    const double ratio = (double)rows / (double)cols;
    const double error = fabs(log(target / ratio));
    if (error < winner_error) {
      winner = rows;
      winner_error = error;
    }
  }
  return winner;
}

98/*******************************************************************************
99 * \brief Creates a new two dimensional distribution.
100 * \author Ole Schuett
101 ******************************************************************************/
102void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
103 const int nrows, const int ncols,
104 const int row_dist[nrows],
105 const int col_dist[ncols]) {
106 assert(omp_get_num_threads() == 1);
107 dbm_distribution_t *dist = calloc(1, sizeof(dbm_distribution_t));
108 dist->ref_count = 1;
109
110 dist->comm = cp_mpi_comm_f2c(fortran_comm);
111 dist->my_rank = cp_mpi_comm_rank(dist->comm);
112 dist->nranks = cp_mpi_comm_size(dist->comm);
113
114 const int row_dim_remains[2] = {1, 0};
115 const cp_mpi_comm_t row_comm = cp_mpi_cart_sub(dist->comm, row_dim_remains);
116
117 const int col_dim_remains[2] = {0, 1};
118 const cp_mpi_comm_t col_comm = cp_mpi_cart_sub(dist->comm, col_dim_remains);
119
120 const int nshards = DBM_SHARDS_PER_THREAD * omp_get_max_threads();
121 const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
122 const int ncol_shards = nshards / nrow_shards;
123
124 dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
125 dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);
126
127 assert(*dist_out == NULL);
128 *dist_out = dist;
129}
130
131/*******************************************************************************
132 * \brief Increases the reference counter of the given distribution.
133 * \author Ole Schuett
134 ******************************************************************************/
135void dbm_distribution_hold(dbm_distribution_t *dist) {
136 assert(dist->ref_count > 0);
137 dist->ref_count++;
138}
139
140/*******************************************************************************
141 * \brief Decreases the reference counter of the given distribution.
142 * \author Ole Schuett
143 ******************************************************************************/
144void dbm_distribution_release(dbm_distribution_t *dist) {
145 assert(dist->ref_count > 0);
146 dist->ref_count--;
147 if (dist->ref_count == 0) {
148 dbm_dist_1d_free(&dist->rows);
149 dbm_dist_1d_free(&dist->cols);
150 free(dist);
151 }
152}
153
154/*******************************************************************************
155 * \brief Returns the rows of the given distribution.
156 * \author Ole Schuett
157 ******************************************************************************/
158void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows,
159 const int **row_dist) {
160 assert(dist->ref_count > 0);
161 *nrows = dist->rows.length;
162 *row_dist = dist->rows.index2coord;
163}
164
165/*******************************************************************************
166 * \brief Returns the columns of the given distribution.
167 * \author Ole Schuett
168 ******************************************************************************/
169void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols,
170 const int **col_dist) {
171 assert(dist->ref_count > 0);
172 *ncols = dist->cols.length;
173 *col_dist = dist->cols.index2coord;
174}
175
176/*******************************************************************************
177 * \brief Returns the MPI rank on which the given block should be stored.
178 * \author Ole Schuett
179 ******************************************************************************/
181 const int row, const int col) {
182 assert(dist->ref_count > 0);
183 assert(0 <= row && row < dist->rows.length);
184 assert(0 <= col && col < dist->cols.length);
185 int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
186 return cp_mpi_cart_rank(dist->comm, coords);
187}
188
189// EOF
int cp_mpi_comm_size(const cp_mpi_comm_t comm)
Wrapper around MPI_Comm_size.
Definition cp_mpi.c:113
int cp_mpi_cart_rank(const cp_mpi_comm_t comm, const int coords[])
Wrapper around MPI_Cart_rank.
Definition cp_mpi.c:184
void cp_mpi_cart_get(const cp_mpi_comm_t comm, int maxdims, int dims[], int periods[], int coords[])
Wrapper around MPI_Cart_get.
Definition cp_mpi.c:166
void cp_mpi_comm_free(cp_mpi_comm_t *comm)
Wrapper around MPI_Comm_free.
Definition cp_mpi.c:217
cp_mpi_comm_t cp_mpi_comm_f2c(const int fortran_comm)
Wrapper around MPI_Comm_f2c.
Definition cp_mpi.c:69
cp_mpi_comm_t cp_mpi_cart_sub(const cp_mpi_comm_t comm, const int remain_dims[])
Wrapper around MPI_Cart_sub.
Definition cp_mpi.c:200
int cp_mpi_comm_rank(const cp_mpi_comm_t comm)
Wrapper around MPI_Comm_rank.
Definition cp_mpi.c:96
int cp_mpi_comm_t
Definition cp_mpi.h:18
int dbm_distribution_stored_coords(const dbm_distribution_t *dist, const int row, const int col)
Returns the MPI rank on which the given block should be stored.
static void dbm_dist_1d_free(dbm_dist_1d_t *dist)
Private routine for releasing a one dimensional distribution.
static int find_best_nrow_shards(const int nshards, const int nrows, const int ncols)
Private routine for finding the optimal number of shard rows.
static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, const int coords[length], const cp_mpi_comm_t comm, const int nshards)
Private routine for creating a new one dimensional distribution.
#define DBM_SHARDS_PER_THREAD
static int imax(int x, int y)
Returns the larger of two given integers (missing from the C standard)
static void const int const int i
Internal struct for storing a one dimensional distribution.
Internal struct for storing a two dimensional distribution.