(git:e7e05ae)
dbm_multiply_cpu.c
Go to the documentation of this file.
1 /*----------------------------------------------------------------------------*/
2 /* CP2K: A general program to perform molecular dynamics simulations */
3 /* Copyright 2000-2024 CP2K developers group <https://cp2k.org> */
4 /* */
5 /* SPDX-License-Identifier: BSD-3-Clause */
6 /*----------------------------------------------------------------------------*/
7 
8 #include <assert.h>
9 #include <stddef.h>
10 #include <string.h>
11 
12 #if defined(__LIBXSMM)
13 #include <libxsmm.h>
14 #if !defined(DBM_LIBXSMM_PREFETCH)
15 //#define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_AL2_AHEAD
16 #define DBM_LIBXSMM_PREFETCH LIBXSMM_GEMM_PREFETCH_NONE
17 #endif
18 #if LIBXSMM_VERSION4(1, 17, 0, 3710) > LIBXSMM_VERSION_NUMBER
19 #define libxsmm_dispatch_gemm libxsmm_dispatch_gemm_v2
20 #endif
21 #endif
22 
23 #include "dbm_hyperparams.h"
24 #include "dbm_multiply_cpu.h"
25 
/*******************************************************************************
 * \brief Prototype for BLAS dgemm.
 *        Fortran passes every argument by reference, hence all pointers.
 * \author Ole Schuett
 ******************************************************************************/
void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
            const int *k, const double *alpha, const double *a, const int *lda,
            const double *b, const int *ldb, const double *beta, double *c,
            const int *ldc);

/*******************************************************************************
 * \brief Private convenient wrapper to hide Fortran nature of dgemm_.
 *        Takes the scalar arguments by value and forwards their addresses to
 *        dgemm_, so callers can pass literals such as 'N' or 1.0 directly.
 * \author Ole Schuett
 ******************************************************************************/
static inline void dbm_dgemm(const char transa, const char transb, const int m,
                             const int n, const int k, const double alpha,
                             const double *a, const int lda, const double *b,
                             const int ldb, const double beta, double *c,
                             const int ldc) {
  // Keep all by-reference scalars together in one local object; its members
  // stay alive for the whole duration of the dgemm_ call.
  const struct {
    char transa, transb;
    int m, n, k, lda, ldb, ldc;
    double alpha, beta;
  } args = {transa, transb, m, n, k, lda, ldb, ldc, alpha, beta};

  dgemm_(&args.transa, &args.transb, &args.m, &args.n, &args.k, &args.alpha,
         a, &args.lda, b, &args.ldb, &args.beta, c, &args.ldc);
}
48 
49 /*******************************************************************************
50  * \brief Private hash function based on Szudzik's elegant pairing.
51  * Using unsigned int to return a positive number even after overflow.
52  * https://en.wikipedia.org/wiki/Pairing_function#Other_pairing_functions
53  * https://stackoverflow.com/a/13871379
54  * http://szudzik.com/ElegantPairing.pdf
55  * \author Ole Schuett
56  ******************************************************************************/
57 static inline unsigned int hash(const dbm_task_t task) {
58  const unsigned int m = task.m, n = task.n, k = task.k;
59  const unsigned int mn = (m >= n) ? m * m + m + n : m + n * n;
60  const unsigned int mnk = (mn >= k) ? mn * mn + mn + k : mn + k * k;
61  return mnk;
62 }
63 
64 /*******************************************************************************
65  * \brief Internal routine for executing the tasks in given batch on the CPU.
66  * \author Ole Schuett
67  ******************************************************************************/
68 void dbm_multiply_cpu_process_batch(const int ntasks, dbm_task_t batch[ntasks],
69  const double alpha,
70  const dbm_pack_t *pack_a,
71  const dbm_pack_t *pack_b,
72  dbm_shard_t *shard_c) {
73 
74  if (0 >= ntasks) { // nothing to do
75  return;
76  }
78 
79 #if defined(__LIBXSMM)
80 
81  // Sort tasks approximately by m,n,k via bucket sort.
82  int buckets[BATCH_NUM_BUCKETS];
83  memset(buckets, 0, BATCH_NUM_BUCKETS * sizeof(int));
84  for (int itask = 0; itask < ntasks; ++itask) {
85  const int i = hash(batch[itask]) % BATCH_NUM_BUCKETS;
86  ++buckets[i];
87  }
88  for (int i = 1; i < BATCH_NUM_BUCKETS; ++i) {
89  buckets[i] += buckets[i - 1];
90  }
91  assert(buckets[BATCH_NUM_BUCKETS - 1] == ntasks);
92  int batch_order[ntasks];
93  for (int itask = 0; itask < ntasks; ++itask) {
94  const int i = hash(batch[itask]) % BATCH_NUM_BUCKETS;
95  --buckets[i];
96  batch_order[buckets[i]] = itask;
97  }
98 
99  // Prepare arguments for libxsmm's kernel-dispatch.
100  const int flags = LIBXSMM_GEMM_FLAG_TRANS_B; // transa = "N", transb = "T"
101  const int prefetch = DBM_LIBXSMM_PREFETCH;
102  int kernel_m = 0, kernel_n = 0, kernel_k = 0;
103  dbm_task_t task_next = batch[batch_order[0]];
104 
105 #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
106  double *data_a_next = NULL, *data_b_next = NULL, *data_c_next = NULL;
107 #endif
108 #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
109  libxsmm_gemmfunction kernel_func = NULL;
110 #else
111  libxsmm_dmmfunction kernel_func = NULL;
112  const double beta = 1.0;
113 #endif
114 
115  // Loop over tasks.
116  for (int itask = 0; itask < ntasks; ++itask) {
117  const dbm_task_t task = task_next;
118  task_next = batch[batch_order[(itask + 1) < ntasks ? (itask + 1) : itask]];
119 
120  if (task.m != kernel_m || task.n != kernel_n || task.k != kernel_k) {
121 #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
122  const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(
123  task.m, task.n, task.k, task.m /*lda*/, task.n /*ldb/transb*/,
124  task.m /*ldc*/, LIBXSMM_DATATYPE_F64 /*aprec*/,
125  LIBXSMM_DATATYPE_F64 /*bprec*/, LIBXSMM_DATATYPE_F64 /*cprec*/,
126  LIBXSMM_DATATYPE_F64 /*calcp*/);
127  kernel_func = (LIBXSMM_FEQ(1.0, alpha)
128  ? libxsmm_dispatch_gemm(shape, (libxsmm_bitfield)flags,
129  (libxsmm_bitfield)prefetch)
130  : NULL);
131 #else
132  kernel_func = libxsmm_dmmdispatch(task.m, task.n, task.k, NULL /*lda*/,
133  NULL /*ldb*/, NULL /*ldc*/, &alpha,
134  &beta, &flags, &prefetch);
135 #endif
136  kernel_m = task.m;
137  kernel_n = task.n;
138  kernel_k = task.k;
139  }
140 
141  // gemm_param wants non-const data even for A and B
142  double *const data_a = pack_a->data + task.offset_a;
143  double *const data_b = pack_b->data + task.offset_b;
144  double *const data_c = shard_c->data + task.offset_c;
145 
146  if (kernel_func != NULL) {
147 #if LIBXSMM_VERSION2(1, 17) < LIBXSMM_VERSION_NUMBER
148  libxsmm_gemm_param gemm_param;
149  gemm_param.a.primary = data_a;
150  gemm_param.b.primary = data_b;
151  gemm_param.c.primary = data_c;
152 #if (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
153  gemm_param.a.quaternary = pack_a->data + task_next.offset_a;
154  gemm_param.b.quaternary = pack_b->data + task_next.offset_b;
155  gemm_param.c.quaternary = shard_c->data + task_next.offset_c;
156 #endif
157  kernel_func(&gemm_param);
158 #elif (LIBXSMM_GEMM_PREFETCH_NONE != DBM_LIBXSMM_PREFETCH)
159  kernel_func(data_a, data_b, data_c, pack_a->data + task_next.offset_a,
160  pack_b->data + task_next.offset_b,
161  shard_c->data + task_next.offset_c);
162 #else
163  kernel_func(data_a, data_b, data_c);
164 #endif
165  } else {
166  dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
167  task.n, 1.0, data_c, task.m);
168  }
169  }
170 #else
171  // Fallback to BLAS when libxsmm is not available.
172  for (int itask = 0; itask < ntasks; ++itask) {
173  const dbm_task_t task = batch[itask];
174  const double *data_a = &pack_a->data[task.offset_a];
175  const double *data_b = &pack_b->data[task.offset_b];
176  double *data_c = &shard_c->data[task.offset_c];
177  dbm_dgemm('N', 'T', task.m, task.n, task.k, alpha, data_a, task.m, data_b,
178  task.n, 1.0, data_c, task.m);
179  }
180 #endif
181 }
182 
183 // EOF
static const int BATCH_NUM_BUCKETS
static void dbm_dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Private convenient wrapper to hide Fortran nature of dgemm_.
static unsigned int hash(const dbm_task_t task)
Private hash function based on Szudzik's elegant pairing. Using unsigned int to return a positive number even after overflow.
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
Prototype for BLAS dgemm.
void dbm_multiply_cpu_process_batch(const int ntasks, dbm_task_t batch[ntasks], const double alpha, const dbm_pack_t *pack_a, const dbm_pack_t *pack_b, dbm_shard_t *shard_c)
Internal routine for executing the tasks in given batch on the CPU.
void dbm_shard_allocate_promised_blocks(dbm_shard_t *shard)
Internal routine for allocating and zeroing any promised block's data.
Definition: dbm_shard.c:203
static void const int const int i
real(dp), dimension(3) c
Definition: ai_eri_debug.F:31
real(dp), dimension(3) a
Definition: ai_eri_debug.F:31
real(dp), dimension(3) b
Definition: ai_eri_debug.F:31
Internal struct for storing a pack - essentially a shard for MPI.
Internal struct for storing a matrix shard.
Definition: dbm_shard.h:30
double * data
Definition: dbm_shard.h:44
Internal struct for storing a task, ie. a single block multiplication.