(git:936074a)
Loading...
Searching...
No Matches
local_gemm_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
9 USE iso_c_binding, ONLY: c_null_ptr, &
10 c_ptr
11#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
13 USE iso_c_binding, ONLY: c_associated, &
14 c_loc
15 USE spla, ONLY: spla_pu_host, &
16 spla_pu_gpu, &
17 spla_op_none, &
18 spla_op_transpose, &
19 spla_op_conj_transpose, &
20 spla_ctx_create, &
21 spla_ctx_destroy, &
22 spla_dgemm, &
23 spla_sgemm, &
24 spla_cgemm, &
25 spla_zgemm, &
26 spla_ctx_set_op_threshold_gpu, &
27 spla_success
28#endif
29
32
33#include "./base/base_uses.f90"
34
35 IMPLICIT NONE
36
37 PRIVATE
38
39 CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
40
41 PUBLIC :: local_gemm_ctxt_type, &
43
44 INTEGER, PARAMETER, PUBLIC :: &
47
48 INTEGER, PRIVATE :: do_dgemm = 1
49
51 TYPE(c_ptr) :: spla_context = c_null_ptr
52 CONTAINS
53 PROCEDURE, pass(ctx), non_overridable :: create => local_gemm_create
54 PROCEDURE, pass(ctx), non_overridable :: destroy => local_gemm_destroy
55 PROCEDURE, pass(ctx), non_overridable :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
56 PROCEDURE, pass(ctx), non_overridable :: gemm => local_gemm
57 END TYPE
58
59CONTAINS
60
! **************************************************************************************************
!> \brief Double-precision general matrix multiply. When compiled with __SPLA and __OFFLOAD_GEMM
!>        and the module switch do_dgemm equals do_dgemm_spla, the call goes to spla_dgemm using
!>        the context handle; otherwise it goes to the BLAS routine dgemm. Arguments are
!>        forwarded in the usual dgemm order.
!> \param opA operation on A: 'N' (none) or 'T'/'C' (transpose), case-insensitive
!> \param opB operation on B: 'N' (none) or 'T'/'C' (transpose), case-insensitive
!> \param m forwarded unchanged as the m argument of dgemm / spla_dgemm
!> \param n forwarded unchanged as the n argument of dgemm / spla_dgemm
!> \param k forwarded unchanged as the k argument of dgemm / spla_dgemm
!> \param alpha scalar factor forwarded to dgemm / spla_dgemm
!> \param A first input matrix (assumed-size in SPLA builds, assumed-shape otherwise)
!> \param lda leading dimension passed along with A
!> \param B second input matrix (assumed-size in SPLA builds, assumed-shape otherwise)
!> \param ldb leading dimension passed along with B
!> \param beta scalar factor forwarded to dgemm / spla_dgemm
!> \param C matrix updated in place (assumed-size in SPLA builds, assumed-shape otherwise)
!> \param ldc leading dimension passed along with C
!> \param ctx gemm context; its spla_context handle is used on the SPLA path only
! **************************************************************************************************
   SUBROUTINE local_gemm(opA, opB, m, n, k, &
                         alpha, A, lda, B, ldb, &
                         beta, C, ldc, ctx)
      CHARACTER, INTENT(in)                              :: opA
      CHARACTER, INTENT(in)                              :: opB
      INTEGER, INTENT(in)                                :: m
      INTEGER, INTENT(in)                                :: n
      INTEGER, INTENT(in)                                :: k
      REAL(8), INTENT(in)                                :: alpha
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(in), TARGET          :: A
#else
      REAL(8), DIMENSION(:, :), INTENT(in), TARGET       :: A
#endif
      INTEGER, INTENT(in)                                :: lda
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(in), TARGET          :: B
#else
      REAL(8), DIMENSION(:, :), INTENT(in), TARGET       :: B
#endif
      INTEGER, INTENT(in)                                :: ldb
      REAL(8), INTENT(in)                                :: beta
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(inout), TARGET       :: C
#else
      REAL(8), DIMENSION(:, :), INTENT(inout), TARGET    :: C
#endif
      INTEGER, INTENT(in)                                :: ldc
      CLASS(local_gemm_ctxt_type), INTENT(inout)         :: ctx

      CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'

      INTEGER                                            :: handle
! SPLA-only locals: there is no point in SPLA offloading on CPU-only builds
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: spla_error, spla_op_A, spla_op_B
#endif

      CALL timeset(routineN, handle)

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      IF (do_dgemm == do_dgemm_spla) THEN

         ! Translate the BLAS-style operation characters. Every value is either mapped or
         ! rejected, so the SPLA operation codes can never reach spla_dgemm undefined.
         ! For real matrices 'C' (conjugate transpose) is the same as 'T'.
         SELECT CASE (opA)
         CASE ('N', 'n')
            spla_op_A = spla_op_none
         CASE ('T', 't', 'C', 'c')
            spla_op_A = spla_op_transpose
         CASE DEFAULT
            cpabort("local_gemm: unsupported opA '"//opA//"'")
         END SELECT

         SELECT CASE (opB)
         CASE ('N', 'n')
            spla_op_B = spla_op_none
         CASE ('T', 't', 'C', 'c')
            spla_op_B = spla_op_transpose
         CASE DEFAULT
            cpabort("local_gemm: unsupported opB '"//opB//"'")
         END SELECT

#if __GNUC__ >= 9
         cpassert(is_contiguous(A))
         cpassert(is_contiguous(B))
         cpassert(is_contiguous(C))
#endif

         ! NOTE(review): the listing this was taken from skips a source line right here;
         ! confirm that no statement (e.g. a device-activation call) was dropped.
         spla_error = spla_dgemm(spla_op_A, spla_op_B, &
                                 m, n, k, alpha, &
                                 c_loc(A), lda, &
                                 c_loc(B), ldb, &
                                 beta, c_loc(C), ldc, ctx%spla_context)
         IF (spla_error /= spla_success) &
            cpabort("spla_dgemm failed: "//cp_to_string(spla_error))
      ELSE
#endif
         CALL dgemm(opA, opB, m, n, k, alpha, &
                    A, lda, &
                    B, ldb, beta, C, ldc)
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      END IF
#else
      mark_used(ctx)
#endif
      CALL timestop(handle)

   END SUBROUTINE local_gemm
154
! **************************************************************************************************
!> \brief Set up a context for gemm offloading. In SPLA builds with do_dgemm equal to
!>        do_dgemm_spla a SPLA context is created; in every other case the handle is left null.
!> \param ctx context to initialise
!> \param pu processing unit handed to spla_ctx_create, i.e. where the {s,d,c,z}gemm runs
!>           (presumably local_gemm_pu_host or local_gemm_pu_gpu - confirm against callers)
! **************************************************************************************************
   SUBROUTINE local_gemm_create(ctx, pu)
      CLASS(local_gemm_ctxt_type), INTENT(out)           :: ctx
      INTEGER, INTENT(in)                                :: pu

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: ierr

      ! A handle that is already associated is left untouched.
      ! NOTE(review): ctx is INTENT(out), so the component default (c_null_ptr) presumably
      ! applies on entry and this guard is not expected to trigger - confirm.
      IF (c_associated(ctx%spla_context)) RETURN

      ! Plain BLAS was selected: no SPLA context is needed.
      IF (do_dgemm /= do_dgemm_spla) THEN
         ctx%spla_context = c_null_ptr
         RETURN
      END IF

      ! NOTE(review): the listing this was taken from skips a source line right here;
      ! confirm that no statement (e.g. a device-activation call) was dropped.
      ierr = spla_ctx_create(ctx%spla_context, pu)
      IF (ierr /= spla_success) THEN
         cpabort("spla_ctx_create failed: "//cp_to_string(ierr))
      END IF
#else
      ! Without SPLA offloading the context only ever carries a null handle.
      ctx%spla_context = c_null_ptr
      mark_used(pu)
#endif
   END SUBROUTINE local_gemm_create
183
! **************************************************************************************************
!> \brief Release the resources held by a gemm context and reset its handle to null.
!>        Safe to call on a context that holds no SPLA handle (never created, created for
!>        plain BLAS, or already destroyed).
!> \param ctx context whose SPLA handle, if any, is destroyed
! **************************************************************************************************
   SUBROUTINE local_gemm_destroy(ctx)
      CLASS(local_gemm_ctxt_type), INTENT(inout)         :: ctx

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: error_

      ! Decide on the handle itself rather than on the global do_dgemm switch: a context
      ! created while SPLA was selected must still be released if the library choice was
      ! changed afterwards, and a null handle must never be passed to spla_ctx_destroy.
      IF (c_associated(ctx%spla_context)) THEN
         ! NOTE(review): the listing this was taken from skips a source line right here;
         ! confirm that no statement (e.g. a device-activation call) was dropped.
         error_ = spla_ctx_destroy(ctx%spla_context)
         IF (error_ /= spla_success) &
            cpabort("spla_ctx_destroy failed: "//cp_to_string(error_))
      END IF
#endif
      ctx%spla_context = c_null_ptr
   END SUBROUTINE local_gemm_destroy
204
! **************************************************************************************************
!> \brief Forward an operation threshold to the SPLA context via spla_ctx_set_op_threshold_gpu.
!>        Does nothing when the context holds no SPLA handle or when the code is built
!>        without __SPLA / __OFFLOAD_GEMM.
!> \param ctx gemm context whose SPLA handle receives the setting
!> \param opThresholdGPU threshold value passed through unchanged
!>        (NOTE(review): units/meaning are defined by SPLA - see its documentation)
! **************************************************************************************************
   SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
      CLASS(local_gemm_ctxt_type), INTENT(INOUT)         :: ctx
      INTEGER, INTENT(in)                                :: opThresholdGPU

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: error__

      ! local_gemm_create leaves the handle null unless SPLA was selected; in that case
      ! there is nothing to configure, so skip the call instead of handing SPLA a null handle.
      IF (c_associated(ctx%spla_context)) THEN
         ! NOTE(review): the listing this was taken from skips a source line right here;
         ! confirm that no statement (e.g. a device-activation call) was dropped.
         error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opThresholdGPU)
         ! Report failures the same way the other SPLA calls in this module do
         ! instead of silently discarding the return code.
         IF (error__ /= spla_success) &
            cpabort("spla_ctx_set_op_threshold_gpu failed: "//cp_to_string(error__))
      END IF
#else
      mark_used(ctx)
      mark_used(opThresholdGPU)
#endif
   END SUBROUTINE local_gemm_set_op_threshold_gpu
224
! **************************************************************************************************
!> \brief Select which gemm implementation this module dispatches to by storing the choice
!>        in the module variable do_dgemm.
!> \param dgemm_library library selector; the other routines compare it against do_dgemm_spla
! **************************************************************************************************
   SUBROUTINE local_gemm_set_library(dgemm_library)
      INTEGER, INTENT(IN)                                :: dgemm_library

      ! Module-wide switch: affects every context, including ones created earlier.
      do_dgemm = dgemm_library

   END SUBROUTINE local_gemm_set_library
234
235END MODULE local_gemm_api
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenient wrapper to hide Fortran nature of dgemm_, swapping a and b.
various routines to log and control the output. The idea is that decisions about where to log should ...
collects all constants needed in input so that they can be used without circular dependencies
integer, parameter, public do_dgemm_spla
subroutine, public local_gemm_set_library(dgemm_library)
...
integer, parameter, public local_gemm_pu_gpu
subroutine local_gemm(opa, opb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ctx)
...
integer, parameter, public local_gemm_pu_host
Fortran API for the offload package, which is written in C.
Definition offload_api.F:12
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()