(git:1f9fd2c)
Loading...
Searching...
No Matches
dbm_multiply.c
Go to the documentation of this file.
1/*----------------------------------------------------------------------------*/
2/* CP2K: A general program to perform molecular dynamics simulations */
3/* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4/* */
5/* SPDX-License-Identifier: BSD-3-Clause */
6/*----------------------------------------------------------------------------*/
7#include "dbm_multiply.h"
8#include "../offload/offload_mempool.h"
9#include "../offload/offload_runtime.h"
10#include "dbm_hyperparams.h"
11#include "dbm_internal.h"
12#include "dbm_library.h"
13#include "dbm_multiply_comm.h"
14#include "dbm_multiply_cpu.h"
15#include "dbm_multiply_gpu.h"
16
17#include <assert.h>
18#include <limits.h>
19#include <math.h>
20#include <omp.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24
25/*******************************************************************************
26 * \brief Private routine for computing the max filter threshold for each row.
27 * \author Ole Schuett
28 ******************************************************************************/
29static float *compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix,
30 const double filter_eps) {
31 const int nrows = (trans) ? matrix->ncols : matrix->nrows;
32 int *nblocks_per_row = calloc(nrows, sizeof(int));
33 float *row_max_eps = malloc(nrows * sizeof(float));
34 assert((nblocks_per_row != NULL && row_max_eps != NULL) || nrows == 0);
35
36#pragma omp parallel
37 {
38#pragma omp for
39 for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
40 dbm_shard_t *shard = &matrix->shards[ishard];
41 for (int iblock = 0; iblock < shard->nblocks; iblock++) {
42 const dbm_block_t *blk = &shard->blocks[iblock];
43 const int row = (trans) ? blk->col : blk->row;
44#pragma omp atomic
45 ++nblocks_per_row[row];
46 }
47 }
48#pragma omp master
49 cp_mpi_sum_int(nblocks_per_row, nrows, matrix->dist->comm);
50#pragma omp barrier
51#pragma omp for
52 for (int i = 0; i < nrows; i++) {
53 const float f =
54 ((float)filter_eps) / ((float)imax(1, nblocks_per_row[i]));
55 row_max_eps[i] = f * f;
56 }
57 } // end of omp parallel region
58
59 free(nblocks_per_row);
60 return row_max_eps; // Ownership of row_max_eps transfers to caller.
61}
62
63/*******************************************************************************
64 * \brief Private struct for storing the context of the multiplication backend.
65 * \author Ole Schuett
66 ******************************************************************************/
67typedef struct {
68#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
69 dbm_multiply_gpu_context_t gpu;
70#endif
71 int cpu_options; // Binary or'ed dbm_multiply_cpu_options (enum).
73
74/*******************************************************************************
75 * \brief Private routine for initializing the multiplication backend.
76 * \author Ole Schuett
77 ******************************************************************************/
79 backend_context_t *const ctx = calloc(1, sizeof(backend_context_t));
80 // BLAS and LIBXS benefit in general from DBM_MULTIPLY_TASK_REORDER.
82
83#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
84 dbm_multiply_gpu_start(DBM_MAX_BATCH_SIZE, dbm_get_num_shards(matrix_c),
85 matrix_c->shards, &ctx->gpu);
86#else
87 (void)matrix_c; // mark as used
88#endif
89
90 return ctx;
91}
92
93/*******************************************************************************
94 * \brief Private routine for handing newly arrived packs to the backend.
95 * \author Ole Schuett
96 ******************************************************************************/
97static bool backend_upload_packs(const dbm_pack_t *pack_a,
98 const dbm_pack_t *pack_b,
99 backend_context_t *ctx) {
100#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
101 return dbm_multiply_gpu_upload_packs(pack_a, pack_b, &ctx->gpu);
102#else
103 (void)pack_a; // mark as used
104 (void)pack_b;
105 (void)ctx;
106 return false;
107#endif
108}
109
110/*******************************************************************************
111 * \brief Private routine for sending a batch to the multiplication backend.
112 * \author Ole Schuett
113 ******************************************************************************/
114static void backend_process_batch(const int ntasks,
115 const dbm_task_t batch[ntasks],
116 const double alpha, const dbm_pack_t *pack_a,
117 const dbm_pack_t *pack_b, const int kshard,
118 dbm_shard_t *shard_c, const bool finish,
119 const bool force_cpu,
120 backend_context_t *ctx) {
121 if (NULL != ctx) {
122#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
123 if (!force_cpu) {
124 dbm_multiply_gpu_process_batch(ntasks, batch, alpha, shard_c, kshard,
125 finish, &ctx->gpu);
126 } else
127#endif
128 {
129 (void)kshard;
130 (void)finish;
131 (void)force_cpu;
132 dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
133 shard_c, ctx->cpu_options);
134 }
135 } else { // Validate against host (aka CPU).
136 dbm_multiply_cpu_process_batch(ntasks, batch, alpha, pack_a, pack_b,
138 }
139}
140
141/*******************************************************************************
142 * \brief Private routine for shutting down the multiplication backend.
143 * \author Ole Schuett
144 ******************************************************************************/
146#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_DBM)
147 dbm_multiply_gpu_stop(&ctx->gpu);
148#endif
149 free(ctx);
150}
151
152/*******************************************************************************
153 * \brief Private routine for multiplying two packs (C += alpha * A * B).
154 *
155 * Blocks in each pack are grouped by shard (free_index % nshards) and sorted
156 * by sum_index within each group. The algorithm:
157 * 1. Builds shard-boundary lookup tables for A (rows) and B (cols).
158 * 2. For each (shard_row, shard_col) pair, determines the contiguous A and B
159 * block ranges belonging to that shard.
160 * 3. Performs a merge-join over sum_index: advances A and B cursors in
161 * lockstep, caching the B sub-range for each sum_index so that multiple A
162 * blocks with the same sum_index reuse it without rescanning.
163 * 4. Applies a norm-based filter (alpha^2 * norm_a * norm_b < eps) for early
164 * rejection before looking up or allocating the C block.
165 * 5. Accumulates matching pairs into a batched GEMM task list, flushing to the
166 * backend (CPU or GPU) every DBM_MAX_BATCH_SIZE tasks.
167 *
168 * \author Ole Schuett and Hans Pabst
169 ******************************************************************************/
170static void multiply_packs(const bool transa, const bool transb,
171 const double alpha, const dbm_pack_t *pack_a,
172 const dbm_pack_t *pack_b,
173 const dbm_matrix_t *matrix_a,
174 const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
175 const float *rows_max_eps,
176 const bool retain_sparsity, const bool force_cpu,
177 int64_t *flop, backend_context_t *ctx) {
178 // For validation, FLOPS do not count, and relying on ctx is not necessary.
179 backend_context_t *const context = (NULL != flop ? ctx : NULL);
180 const float alpha2 = (float)(alpha * alpha);
181 int64_t flop_sum = 0;
182
183 const int nshard_rows = matrix_c->dist->rows.nshards;
184 const int nshard_cols = matrix_c->dist->cols.nshards;
185 int *shard_row_start = calloc(nshard_rows, sizeof(int));
186 int *shard_col_start = calloc(nshard_cols, sizeof(int));
187 assert(NULL != shard_row_start && NULL != shard_col_start);
188
189 const int *sum_index_sizes_a =
190 (transa) ? matrix_a->row_sizes : matrix_a->col_sizes;
191 const int *sum_index_sizes_b =
192 (transb) ? matrix_b->col_sizes : matrix_b->row_sizes;
193 const int *free_index_sizes_a =
194 (transa) ? matrix_a->col_sizes : matrix_a->row_sizes;
195 const int *free_index_sizes_b =
196 (transb) ? matrix_b->row_sizes : matrix_b->col_sizes;
197
198#pragma omp parallel reduction(+ : flop_sum)
199 {
200 // Thread-private array covering given work in piece-wise fashion.
201 dbm_task_t *batch =
203
204 // Blocks are ordered first by shard. Creating lookup tables of boundaries.
205#pragma omp for nowait
206 for (int iblock = 1; iblock < pack_a->nblocks; iblock++) {
207 const int shard_row = pack_a->blocks[iblock].free_index % nshard_rows;
208 const int prev_shard_row =
209 pack_a->blocks[iblock - 1].free_index % nshard_rows;
210 if (prev_shard_row != shard_row) {
211 shard_row_start[shard_row] = iblock;
212 }
213 }
214#pragma omp for
215 for (int jblock = 1; jblock < pack_b->nblocks; jblock++) {
216 const int shard_col = pack_b->blocks[jblock].free_index % nshard_cols;
217 const int prev_shard_col =
218 pack_b->blocks[jblock - 1].free_index % nshard_cols;
219 if (prev_shard_col != shard_col) {
220 shard_col_start[shard_col] = jblock;
221 }
222 }
223
224#pragma omp for collapse(2) DBM_OMP_SCHEDULE
225 for (int shard_row = 0; shard_row < nshard_rows; shard_row++) {
226 for (int shard_col = 0; shard_col < nshard_cols; shard_col++) {
227 const int ishard = shard_row * nshard_cols + shard_col;
228 dbm_shard_t *const shard_c = &matrix_c->shards[ishard];
229 int ntasks = 0;
230
231 // Determine contiguous block ranges for this shard in A and B.
232 // Use a merge-join to find pairs of blocks with matching sum indices.
233 // This utilizes that blocks within a shard are ordered by sum_index.
234 const int iblock_start = shard_row_start[shard_row];
235 int iblock_end = pack_a->nblocks;
236 for (int t = iblock_start; t < pack_a->nblocks; ++t) {
237 if (pack_a->blocks[t].free_index % nshard_rows != shard_row) {
238 iblock_end = t;
239 break;
240 }
241 }
242 const int jblock_start = shard_col_start[shard_col];
243 int jblock_end = pack_b->nblocks;
244 for (int t = jblock_start; t < pack_b->nblocks; ++t) {
245 if (pack_b->blocks[t].free_index % nshard_cols != shard_col) {
246 jblock_end = t;
247 break;
248 }
249 }
250 if (iblock_start >= iblock_end || jblock_start >= jblock_end) {
251 backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
252 shard_c, true, force_cpu, context);
253 continue;
254 }
255
256 // Merge over sum_index (both ranges sorted by sum_index).
257 // Cache the B sub-range for each sum_index so that multiple A blocks
258 // sharing the same sum_index reuse it without re-scanning B.
259 int i = iblock_start, j = jblock_start, last_sum_index = -1;
260 int b_range_start = -1, b_range_end = -1;
261
262 while (i < iblock_end) {
263 const dbm_pack_block_t *blk_a = &pack_a->blocks[i];
264 const int sum_a = blk_a->sum_index;
265
266 // Advance j until sum_b >= sum_a.
267 while (j < jblock_end && pack_b->blocks[j].sum_index < sum_a) {
268 ++j;
269 }
270 if (j >= jblock_end) {
271 break; // No more matches possible.
272 }
273
274 const int sum_b = pack_b->blocks[j].sum_index;
275 if (sum_b > sum_a) {
276 ++i;
277 continue; // Need next A block with higher sum_index.
278 }
279
280 // sum_a == sum_b: establish (or reuse) B range with this sum_index.
281 if (sum_a != last_sum_index) {
282 b_range_start = j;
283 int t = j + 1;
284 while (t < jblock_end && pack_b->blocks[t].sum_index == sum_a) {
285 ++t;
286 }
287 b_range_end = t;
288 last_sum_index = sum_a;
289 }
290
291 // Iterate over B blocks in current sum_index range.
292 for (int jb = b_range_start; jb < b_range_end; ++jb) {
293 const dbm_pack_block_t *const blk_b = &pack_b->blocks[jb];
294
295 // Norm filter first (early reject).
296 const float result_norm = alpha2 * blk_a->norm * blk_b->norm;
297 if (result_norm < rows_max_eps[blk_a->free_index]) {
298 continue;
299 }
300
301 // Check block sizes.
302 const int m = free_index_sizes_a[blk_a->free_index];
303 const int n = free_index_sizes_b[blk_b->free_index];
304 const int k = sum_index_sizes_a[sum_a];
305 assert(m == matrix_c->row_sizes[blk_a->free_index]);
306 assert(n == matrix_c->col_sizes[blk_b->free_index]);
307 assert(k == sum_index_sizes_b[blk_b->sum_index]);
308
309 if (m == 0 || n == 0 || k == 0) {
310 continue;
311 }
312
313 // Get C block.
314 const int row = blk_a->free_index, col = blk_b->free_index;
315 dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row, col);
316 if (blk_c == NULL) {
317 if (retain_sparsity) {
318 continue;
319 }
320 assert(dbm_get_shard_index(matrix_c, row, col) == ishard);
321 assert(dbm_get_stored_coordinates(matrix_c, row, col) ==
322 matrix_c->dist->my_rank);
323 blk_c = dbm_shard_promise_new_block(shard_c, row, col, m * n);
324 }
325
326 // Count flops.
327 const int64_t task_flops = 2LL * m * n * k;
328 flop_sum += task_flops;
330
331 // Add block multiplication to batch.
332 dbm_task_t *const tptr = &batch[ntasks];
333 tptr->offset_a = blk_a->offset;
334 tptr->offset_b = blk_b->offset;
335 tptr->offset_c = blk_c->offset;
336 tptr->m = m;
337 tptr->n = n;
338 tptr->k = k;
339 ++ntasks;
340
341 if (ntasks == DBM_MAX_BATCH_SIZE) {
342 backend_process_batch(ntasks, batch, alpha, pack_a, pack_b,
343 ishard, shard_c, false, force_cpu, context);
344 ntasks = 0;
345 }
346 }
347
348 // Advance i; if next A block has same sum_index, B range is reused.
349 ++i;
350 }
351 backend_process_batch(ntasks, batch, alpha, pack_a, pack_b, ishard,
352 shard_c, true, force_cpu, context);
353 }
354 }
355
357 }
358
359 free(shard_row_start);
360 free(shard_col_start);
361
362 if (NULL != flop) {
363 *flop += flop_sum;
364 }
365}
366
367/*******************************************************************************
368 * \brief Performs a multiplication of two dbm_matrix_t matrices.
369 * See dbm_matrix.h for details.
370 * \author Ole Schuett
371 ******************************************************************************/
372void dbm_multiply(const bool transa, const bool transb, const double alpha,
373 const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b,
374 const double beta, dbm_matrix_t *matrix_c,
375 const bool retain_sparsity, const double filter_eps,
376 int64_t *flop) {
377 assert(omp_get_num_threads() == 1);
378 assert(matrix_a != NULL && matrix_b != NULL && matrix_c != NULL);
379
380 // Throughout the matrix multiplication code the "sum_index" and "free_index"
381 // denote the summation (aka dummy) and free index from the Einstein notation.
382 const int num_sum_index_a = (transa) ? matrix_a->nrows : matrix_a->ncols;
383 const int num_sum_index_b = (transb) ? matrix_b->ncols : matrix_b->nrows;
384 const int num_free_index_a = (transa) ? matrix_a->ncols : matrix_a->nrows;
385 const int num_free_index_b = (transb) ? matrix_b->nrows : matrix_b->ncols;
386
387 // Sanity check matrix dimensions.
388 assert(num_sum_index_a == num_sum_index_b);
389 assert(num_free_index_a == matrix_c->nrows);
390 assert(num_free_index_b == matrix_c->ncols);
391
392 // Prepare matrix_c (host).
393 dbm_scale(matrix_c, beta);
394
395 // Determine if validation shall be performed.
396 const char *const maxeps_env = getenv("DBM_MULTIPLY_MAXEPS");
397 const char *const verify_env = getenv("DBM_MULTIPLY_VERIFY");
398 const double maxeps = (NULL == maxeps_env ? 1E-1 : fabs(atof(maxeps_env)));
399 const int verify =
400 (NULL == verify_env ? (NULL == maxeps_env ? 0 : 1) : atoi(verify_env));
401 dbm_matrix_t *matrix_d = NULL;
402 if (0 != verify) {
403 dbm_distribution_t *const dist_shared = matrix_c->dist;
404 dbm_create(&matrix_d, dist_shared, matrix_c->name, matrix_c->nrows,
405 matrix_c->ncols, matrix_c->row_sizes, matrix_c->col_sizes);
406 dbm_copy(matrix_d, matrix_c);
407 }
408
409 // Compute filter thresholds for each row.
410 float *rows_max_eps = compute_rows_max_eps(transa, matrix_a, filter_eps);
411
412 // Start uploading matrix_c to the GPU.
413 backend_context_t *ctx = backend_start(matrix_c);
414
415 // Redistribute matrix_a and matrix_b across MPI ranks.
416 dbm_comm_iterator_t *iter =
417 dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
418
419 // Count flops if requested.
420 if (NULL != flop) {
421 *flop = 0;
422 }
423
424 // Main loop.
425 dbm_pack_t *pack_a, *pack_b;
426 while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
427 const bool uploaded = backend_upload_packs(pack_a, pack_b, ctx);
428 (void)uploaded; // mark used
429 multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
430 matrix_c, rows_max_eps, retain_sparsity, false /*!uploaded*/,
431 flop, ctx);
432 }
433
434 // Wait for all other MPI ranks to complete, then release resources.
436 backend_stop(ctx);
437
438 if (NULL != matrix_d) {
439 ctx = backend_start(matrix_d);
440 iter =
441 dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_d);
442 while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
443 multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
444 matrix_d, rows_max_eps, retain_sparsity, true, NULL, ctx);
445 }
447 backend_stop(ctx);
448 const double epsilon = dbm_maxeps(matrix_d, matrix_c);
449 if (maxeps < epsilon) {
450 if (1 == verify) {
451 fprintf(stderr, "WARN ACC/LIBDBM: diff=%g\n", epsilon);
452 } else {
453 fprintf(stderr, "ERROR ACC/LIBDBM: diff=%g\n", epsilon);
454 exit(EXIT_FAILURE);
455 }
456 }
457 dbm_release(matrix_d);
458 }
459
460 // Release filter thresholds.
461 free(rows_max_eps);
462
463 // Final filter pass.
464 dbm_filter(matrix_c, filter_eps);
465}
466
467// EOF
void cp_mpi_sum_int(int *values, const int count, const cp_mpi_comm_t comm)
Wrapper around MPI_Allreduce for op MPI_SUM and datatype MPI_INT.
Definition cp_mpi.c:317
#define DBM_MAX_BATCH_SIZE
static int imax(int x, int y)
Returns the larger of two given integers (missing from the C standard)
void dbm_library_counter_increment(const int m, const int n, const int k)
Add given block multiplication to stats. This routine is thread-safe.
Definition dbm_library.c:97
double dbm_maxeps(const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b)
Calculates maximum relative difference between matrix_a and matrix_b.
Definition dbm_matrix.c:565
static int dbm_get_shard_index(const dbm_matrix_t *matrix, const int row, const int col)
Internal routine for getting a block's shard index.
Definition dbm_matrix.h:245
static int dbm_get_num_shards(const dbm_matrix_t *matrix)
Internal routine that returns the number of shards for given matrix.
Definition dbm_matrix.h:237
static float * compute_rows_max_eps(const bool trans, const dbm_matrix_t *matrix, const double filter_eps)
Private routine for computing the max filter threshold for each row.
static backend_context_t * backend_start(const dbm_matrix_t *matrix_c)
Private routine for initializing the multiplication backend.
static bool backend_upload_packs(const dbm_pack_t *pack_a, const dbm_pack_t *pack_b, backend_context_t *ctx)
Private routine for handing newly arrived packs to the backend.
static void backend_stop(backend_context_t *ctx)
Private routine for shutting down the multiplication backend.
static void multiply_packs(const bool transa, const bool transb, const double alpha, const dbm_pack_t *pack_a, const dbm_pack_t *pack_b, const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c, const float *rows_max_eps, const bool retain_sparsity, const bool force_cpu, int64_t *flop, backend_context_t *ctx)
Private routine for multiplying two packs (C += alpha * A * B).
static void backend_process_batch(const int ntasks, const dbm_task_t batch[ntasks], const double alpha, const dbm_pack_t *pack_a, const dbm_pack_t *pack_b, const int kshard, dbm_shard_t *shard_c, const bool finish, const bool force_cpu, backend_context_t *ctx)
Private routine for sending a batch to the multiplication backend.
dbm_comm_iterator_t * dbm_comm_iterator_start(const bool transa, const bool transb, const dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b, const dbm_matrix_t *matrix_c)
Internal routine for creating a communication iterator.
void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter)
Internal routine for releasing the given communication iterator.
bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a, dbm_pack_t **pack_b)
Internal routine for retrieving next pair of packs of given iterator.
void dbm_multiply_cpu_process_batch(int ntasks, const dbm_task_t batch[ntasks], double alpha, const dbm_pack_t *pack_a, const dbm_pack_t *pack_b, dbm_shard_t *shard_c, int options)
Internal routine for executing the tasks in given batch on the CPU.
@ DBM_MULTIPLY_BLAS_LIBRARY
@ DBM_MULTIPLY_TASK_REORDER
dbm_block_t * dbm_shard_promise_new_block(dbm_shard_t *shard, const int row, const int col, const int block_size)
Internal routine for allocating the metadata of a new block.
Definition dbm_shard.c:205
dbm_block_t * dbm_shard_lookup(const dbm_shard_t *shard, const int row, const int col)
Internal routine for looking up a block from a shard.
Definition dbm_shard.c:181
static void const int const int i
void offload_mempool_host_free(const void *memory)
Internal routine for releasing memory back to the pool.
void * offload_mempool_host_malloc(const size_t size)
Internal routine for allocating host memory from the pool.
Private struct for storing the context of the multiplication backend.
Internal struct for storing a block's metadata.
Definition dbm_shard.h:20
Internal struct for storing a communication iterator.
Internal struct for storing a two dimensional distribution.
Internal struct for storing a matrix.
Definition dbm_matrix.h:19
int * row_sizes
Definition dbm_matrix.h:24
char * name
Definition dbm_matrix.h:21
int * col_sizes
Definition dbm_matrix.h:25
dbm_shard_t * shards
Definition dbm_matrix.h:27
dbm_distribution_t * dist
Definition dbm_matrix.h:20
Internal struct for storing a dbm_block_t plus its norm.
Internal struct for storing a pack - essentially a shard for MPI.
dbm_pack_block_t * blocks
Internal struct for storing a matrix shard.
Definition dbm_shard.h:30
dbm_block_t * blocks
Definition dbm_shard.h:33
Internal struct for storing a task, ie. a single block multiplication.