27 dist->nshards = nshards;
30 dist->length = length;
31 dist->index2coord = malloc(length *
sizeof(
int));
32 memcpy(dist->index2coord, coords, length *
sizeof(
int));
35 int cart_dims[1], cart_periods[1], cart_coords[1];
37 assert(dist->nranks == cart_dims[0]);
38 assert(dist->my_rank == cart_coords[0]);
41 for (
int i = 0;
i < length;
i++) {
42 assert(0 <= coords[
i] && coords[
i] < dist->nranks);
43 if (coords[
i] == dist->my_rank) {
49 dist->local_indicies = malloc(dist->nlocals *
sizeof(
int));
51 for (
int i = 0;
i < length;
i++) {
52 if (coords[
i] == dist->my_rank) {
53 dist->local_indicies[j++] =
i;
56 assert(j == dist->nlocals);
64 free(dist->index2coord);
65 free(dist->local_indicies);
/*
 * Returns the larger of the two given integers. The C standard library has
 * no integer counterpart to fmax(), hence this helper. When both arguments
 * are equal, that common value is returned.
 */
static inline int imax(int x, int y) {
  if (y >= x) {
    return y;
  }
  return x;
}
81 const double target = (double)
imax(nrows, 1) / (double)
imax(ncols, 1);
82 int best_nrow_shards = nshards;
83 double best_error = fabs(log(target / (
double)nshards));
85 for (
int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) {
86 const int ncol_shards = nshards / nrow_shards;
87 if (nrow_shards * ncol_shards != nshards)
89 const double ratio = (double)nrow_shards / (
double)ncol_shards;
90 const double error = fabs(log(target / ratio));
91 if (error < best_error) {
93 best_nrow_shards = nrow_shards;
96 return best_nrow_shards;
104 const int nrows,
const int ncols,
105 const int row_dist[nrows],
106 const int col_dist[ncols]) {
107 assert(omp_get_num_threads() == 1);
115 const int row_dim_remains[2] = {1, 0};
118 const int col_dim_remains[2] = {0, 1};
123 const int ncol_shards = nshards / nrow_shards;
128 assert(*dist_out == NULL);
137 assert(dist->ref_count > 0);
146 assert(dist->ref_count > 0);
148 if (dist->ref_count == 0) {
160 const int **row_dist) {
161 assert(dist->ref_count > 0);
162 *nrows = dist->rows.length;
163 *row_dist = dist->rows.index2coord;
171 const int **col_dist) {
172 assert(dist->ref_count > 0);
173 *ncols = dist->cols.length;
174 *col_dist = dist->cols.index2coord;
182 const int row,
const int col) {
183 assert(dist->ref_count > 0);
184 assert(0 <= row && row < dist->rows.length);
185 assert(0 <= col && col < dist->cols.length);
186 int coords[2] = {dist->rows.index2coord[row], dist->cols.index2coord[col]};
int dbm_distribution_stored_coords(const dbm_distribution_t *dist, const int row, const int col)
Returns the MPI rank on which the given block should be stored.
static void dbm_dist_1d_free(dbm_dist_1d_t *dist)
Private routine for releasing a one-dimensional distribution.
void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm, const int nrows, const int ncols, const int row_dist[nrows], const int col_dist[ncols])
Creates a new two-dimensional distribution.
static int find_best_nrow_shards(const int nshards, const int nrows, const int ncols)
Private routine for finding the optimal number of shard rows.
static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length, const int coords[length], const dbm_mpi_comm_t comm, const int nshards)
Private routine for creating a new one-dimensional distribution.
void dbm_distribution_hold(dbm_distribution_t *dist)
Increases the reference counter of the given distribution.
void dbm_distribution_release(dbm_distribution_t *dist)
Decreases the reference counter of the given distribution.
void dbm_distribution_row_dist(const dbm_distribution_t *dist, int *nrows, const int **row_dist)
Returns the rows of the given distribution.
void dbm_distribution_col_dist(const dbm_distribution_t *dist, int *ncols, const int **col_dist)
Returns the columns of the given distribution.
static int imax(int x, int y)
Returns the larger of two given integers (the C standard library provides no such function).
static const float SHARDS_PER_THREAD
int dbm_mpi_comm_rank(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_rank.
int dbm_mpi_cart_rank(const dbm_mpi_comm_t comm, const int coords[])
Wrapper around MPI_Cart_rank.
int dbm_mpi_comm_size(const dbm_mpi_comm_t comm)
Wrapper around MPI_Comm_size.
dbm_mpi_comm_t dbm_mpi_cart_sub(const dbm_mpi_comm_t comm, const int remain_dims[])
Wrapper around MPI_Cart_sub.
void dbm_mpi_comm_free(dbm_mpi_comm_t *comm)
Wrapper around MPI_Comm_free.
void dbm_mpi_cart_get(const dbm_mpi_comm_t comm, int maxdims, int dims[], int periods[], int coords[])
Wrapper around MPI_Cart_get.
dbm_mpi_comm_t dbm_mpi_comm_f2c(const int fortran_comm)
Wrapper around MPI_Comm_f2c.
static void const int const int i
Internal struct for storing a one-dimensional distribution.
Internal struct for storing a two-dimensional distribution.