(git:ed6f26b)
Loading...
Searching...
No Matches
dbm_multiply_cpu.c
Go to the documentation of this file.
1/*----------------------------------------------------------------------------*/
2/* CP2K: A general program to perform molecular dynamics simulations */
3/* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4/* */
5/* SPDX-License-Identifier: BSD-3-Clause */
6/*----------------------------------------------------------------------------*/
7
8#include <assert.h>
9#include <stddef.h>
10#include <string.h>
11
12#if defined(__LIBXSMM)
13#include <libxsmm.h>
14#if !defined(DBM_LIBXSMM_PREFETCH)
15// #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_AL2_AHEAD
16#define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE
17#endif
18#if LIBXSMM_VERSION4(1, 17, 0, 3710) > LIBXSMM_VERSION_NUMBER
19#define libxsmm_dispatch_gemm libxsmm_dispatch_gemm_v2
20#endif
21#endif
22
23#include "dbm_hyperparams.h"
24#include "dbm_multiply_cpu.h"
25
/*******************************************************************************
 * \brief Prototype for BLAS dgemm.
 *        Declared with all-pointer arguments because the Fortran calling
 *        convention passes every argument by reference.
 * \author Ole Schuett
 ******************************************************************************/
void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
            const int *k, const double *alpha, const double *a, const int *lda,
            const double *b, const int *ldb, const double *beta, double *c,
            const int *ldc);

/*******************************************************************************
 * \brief Private convenient wrapper to hide Fortran nature of dgemm_.
 *        Takes every scalar by value and forwards its address to dgemm_, so
 *        call sites can pass literals and expressions directly.
 *        Semantics are those of BLAS dgemm:
 *        C := alpha * op(A) * op(B) + beta * C.
 * \author Ole Schuett
 ******************************************************************************/
static inline void dbm_dgemm(const char transa, const char transb, const int m,
                             const int n, const int k, const double alpha,
                             const double *a, const int lda, const double *b,
                             const int ldb, const double beta, double *c,
                             const int ldc) {
  // The by-value parameters live in this stack frame for the duration of the
  // call, so handing out their addresses to dgemm_ is safe.
  dgemm_(/*transa=*/&transa, /*transb=*/&transb,          //
         /*m=*/&m, /*n=*/&n, /*k=*/&k, /*alpha=*/&alpha, //
         /*a=*/a, /*lda=*/&lda,                          //
         /*b=*/b, /*ldb=*/&ldb,                          //
         /*beta=*/&beta, /*c=*/c, /*ldc=*/&ldc);
}
48
#if defined(__LIBXSMM)
/*******************************************************************************
 * \brief Private helper: Szudzik's elegant pairing of two unsigned integers.
 *        Unsigned arithmetic makes any overflow wrap around (well-defined)
 *        instead of invoking undefined behavior.
 * \author Ole Schuett
 ******************************************************************************/
static inline unsigned int szudzik_pair(const unsigned int x,
                                        const unsigned int y) {
  if (x >= y) {
    return x * x + x + y;
  }
  return x + y * y;
}

/*******************************************************************************
 * \brief Private hash function based on Szudzik's elegant pairing.
 *        Pairs (m, n) first and then pairs that result with k. Uses unsigned
 *        int so the result is a positive number even after overflow.
 *        https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
 *        https://stackoverflow.com/a/13871379
 *        http://szudzik.com/ElegantPairing.pdf
 * \author Ole Schuett
 ******************************************************************************/
static inline unsigned int hash(const dbm_task_t task) {
  const unsigned int pair_mn = szudzik_pair(task.m, task.n);
  return szudzik_pair(pair_mn, task.k);
}
#endif
65
66/*******************************************************************************
67 * \brief Internal routine for executing the tasks in given batch on the CPU.
68 * \author Ole Schuett
69 ******************************************************************************/
70void dbm_multiply_cpu_process_batch(const int ntasks, dbm_task_t batch[ntasks],
71 const double alpha,
72 const dbm_pack_t *pack_a,
73 const dbm_pack_t *pack_b,
74 dbm_shard_t *shard_c) {
75
76 if (0 >= ntasks) { // nothing to do
77 return;
78 }
80
81#if defined(__LIBXSMM)
82
83 // Sort tasks approximately by m,n,k via bucket sort.
84 int buckets[DBM_BATCH_NUM_BUCKETS] = {0};
85 for (int itask = 0; itask < ntasks; ++itask) {
86 const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
87 ++buckets[i];
88 }
89 for (int i = 1; i < DBM_BATCH_NUM_BUCKETS; ++i) {
90 buckets[i] += buckets[i - 1];
91 }
92 assert(buckets[DBM_BATCH_NUM_BUCKETS - 1] == ntasks);
93 int batch_order[ntasks];
94 for (int itask = 0; itask < ntasks; ++itask) {
95 const int i = hash(batch[itask]) % DBM_BATCH_NUM_BUCKETS;
96 --buckets[i];
97 batch_order[buckets[i]] = itask;
98 }
99
100 // Prepare arguments for libxsmm's kernel-dispatch.
101 const int flags = LIBXSMM_GEMM_FLAG_TRANS_B; // transa = "N", transb = "T"
102 const int prefetch = DBM_LIBXSMM_PREFETCH;
103 int kernel_m = 0, kernel_n = 0, kernel_k = 0;
104 dbm_task_t task_next = batch[batch_order[0]];
105
106#if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
107 double *data_a_next = NULL, *data_b_next = NULL, *data_c_next = NULL;
108#endif
109#if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
110 libxsmm_gemmfunction kernel_func = NULL;
111#else
112 libxsmm_dmmfunction kernel_func = NULL;
113 const double beta = 1.0;
114#endif
115
116 // Loop over tasks.
117 for (int itask = 0; itask < ntasks; ++itask) {
118 const dbm_task_t task = task_next;
119 task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
120
121 if (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k) {
122 if (LIBXSMM_SMM(task.m, task.n, task.m, 1 /*assume in-$, no RFO*/,
123 sizeof(double))) {
124#if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
125 const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(
126 task.m, task.n, task.k, task.m /*lda*/, task.n /*ldb*/,
127 task.m /*ldc*/, LIBXSMM_DATATYPE_F64 /*aprec*/,
128 LIBXSMM_DATATYPE_F64 /*bprec*/, LIBXSMM_DATATYPE_F64 /*cprec*/,
129 LIBXSMM_DATATYPE_F64 /*calcp*/);
130 kernel_func =
131 (LIBXSMM_FEQ(1.0, alpha)
132 ? libxsmm_dispatch_gemm(shape, (libxsmm_bitfield)flags,
133 (libxsmm_bitfield)prefetch)
134 : NULL);
135#else
136 kernel_func = libxsmm_dmmdispatch(task.m, task.n, task.k, NULL /*lda*/,
137 NULL /*ldb*/, NULL /*ldc*/, &alpha,
138 &beta, &flags, &prefetch);
139#endif
140 } else {
141 kernel_func = NULL;
142 }
143 kernel_m = task.m;
144 kernel_n = task.n;
145 kernel_k = task.k;
146 }
147
148 // gemm_param wants non-const data even for A and B
149 double *const data_a = pack_a->data + task.offset_a;
150 double *const data_b = pack_b->data + task.offset_b;
151 double *const data_c = shard_c->data + task.offset_c;
152
153 if (kernel_func != NULL) {
154#if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
155 libxsmm_gemm_param gemm_param;
156 gemm_param.a.primary = data_a;
157 gemm_param.b.primary = data_b;
158 gemm_param.c.primary = data_c;
159#if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
160 gemm_param.a.quaternary = pack_a->data + task_next.offset_a;
161 gemm_param.b.quaternary = pack_b->data + task_next.offset_b;
162 gemm_param.c.quaternary = shard_c->data + task_next.offset_c;
163#endif
164 kernel_func(&gemm_param);
165#elif (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
166 kernel_func(data_a, data_b, data_c, pack_a->data + task_next.offset_a,
167 pack_b->data + task_next.offset_b,
168 shard_c->data + task_next.offset_c);
169#else
170 kernel_func(data_a, data_b, data_c);
171#endif
172 } else {
173 dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
174 task.n, 1.0, data_c, task.m);
175 }
176 }
177#else
178 // Fallback to BLAS when libxsmm is not available.
179 for (int itask = 0; itask < ntasks; ++itask) {
180 const dbm_task_t task = batch[itask];
181 const double *data_a = &pack_a->data[task.offset_a];
182 const double *data_b = &pack_b->data[task.offset_b];
183 double *data_c = &shard_c->data[task.offset_c];
184 dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
185 task.n, 1.0, data_c, task.m);
186 }
187#endif
188}
189
190// EOF
#define DBM_BATCH_NUM_BUCKETS
static void dbm_dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Private convenient wrapper to hide Fortran nature of dgemm_.
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
Prototype for BLAS dgemm.
void dbm_multiply_cpu_process_batch(const int ntasks, dbm_task_t batch[ntasks], const double alpha, const dbm_pack_t *pack_a, const dbm_pack_t *pack_b, dbm_shard_t *shard_c)
Internal routine for executing the tasks in given batch on the CPU.
void dbm_shard_allocate_promised_blocks(dbm_shard_t *shard)
Internal routine for allocating and zeroing any promised block's data.
Definition dbm_shard.c:231
static unsigned int hash(const unsigned int row, const unsigned int col)
Private hash function based on Cantor pairing function. https://en.wikipedia.org/wiki/Pairing_functio...
Definition dbm_shard.c:139
static void const int const int i
Internal struct for storing a pack - essentially a shard for MPI.
double * data
Internal struct for storing a matrix shard.
Definition dbm_shard.h:30
double * data
Definition dbm_shard.h:43
Internal struct for storing a task, ie. a single block multiplication.