(git:33f85d8)
Loading...
Searching...
No Matches
local_gemm_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
9 USE iso_c_binding, ONLY: c_null_ptr, &
10 c_ptr
11#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
13 USE iso_c_binding, ONLY: c_associated, &
14 c_loc
15 USE spla, ONLY: spla_pu_host, &
16 spla_pu_gpu, &
17 spla_op_none, &
18 spla_op_transpose, &
19 spla_op_conj_transpose, &
20 spla_ctx_create, &
21 spla_ctx_destroy, &
22 spla_dgemm, &
23 spla_sgemm, &
24 spla_cgemm, &
25 spla_zgemm, &
26 spla_ctx_set_op_threshold_gpu, &
27 spla_success
28#endif
29
31
32#include "./base/base_uses.f90"
33
34 IMPLICIT NONE
35
36 PRIVATE
37
38 CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
39
40 PUBLIC :: local_gemm_ctxt_type, &
42
43 INTEGER, PARAMETER, PUBLIC :: &
46
47 INTEGER, PRIVATE :: do_dgemm = 1
48
50 TYPE(c_ptr) :: spla_context = c_null_ptr
51 CONTAINS
52 PROCEDURE, pass(ctx), non_overridable :: create => local_gemm_create
53 PROCEDURE, pass(ctx), non_overridable :: destroy => local_gemm_destroy
54 PROCEDURE, pass(ctx), non_overridable :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
55 PROCEDURE, pass(ctx), non_overridable :: gemm => local_gemm
56 END TYPE
57
58CONTAINS
59
! **************************************************************************************************
!> \brief Double precision matrix-matrix multiplication, C := alpha*op(A)*op(B) + beta*C.
!>        When compiled with __SPLA and __OFFLOAD_GEMM and SPLA is the selected library the
!>        product is delegated to spla_dgemm, otherwise to the BLAS routine dgemm.
!> \param opA operation applied to A: 'N' (none) or 'T'/'C' (transpose), upper or lower case
!> \param opB operation applied to B, same values as opA
!> \param m forwarded as the m argument of the gemm backend (rows of op(A) and C in BLAS terms)
!> \param n forwarded as the n argument of the gemm backend (columns of op(B) and C)
!> \param k forwarded as the k argument of the gemm backend (columns of op(A), rows of op(B))
!> \param alpha scalar factor of the product op(A)*op(B)
!> \param A left-hand matrix
!> \param lda leading dimension of A, forwarded unchanged
!> \param B right-hand matrix
!> \param ldb leading dimension of B, forwarded unchanged
!> \param beta scalar factor applied to the incoming C
!> \param C matrix that is updated in place
!> \param ldc leading dimension of C, forwarded unchanged
!> \param ctx gemm context; its spla_context handle is only used on the SPLA path
! **************************************************************************************************
   SUBROUTINE local_gemm(opA, opB, m, n, k, &
                         alpha, A, lda, B, ldb, &
                         beta, C, ldc, ctx)
      CHARACTER, INTENT(in)                              :: opA
      CHARACTER, INTENT(in)                              :: opB
      INTEGER, INTENT(in)                                :: m
      INTEGER, INTENT(in)                                :: n
      INTEGER, INTENT(in)                                :: k
      REAL(8), INTENT(in)                                :: alpha
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(in), TARGET          :: A
#else
      REAL(8), DIMENSION(:, :), INTENT(in), TARGET       :: A
#endif
      INTEGER, INTENT(in)                                :: lda
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(in), TARGET          :: B
#else
      REAL(8), DIMENSION(:, :), INTENT(in), TARGET       :: B
#endif

      INTEGER, INTENT(in)                                :: ldb
      REAL(8), INTENT(in)                                :: beta
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(inout), TARGET       :: C
#else
      REAL(8), DIMENSION(:, :), INTENT(inout), TARGET    :: C
#endif
      INTEGER, INTENT(in)                                :: ldc
      CLASS(local_gemm_ctxt_type), INTENT(inout)         :: ctx

      CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'

      INTEGER                                            :: handle
! no point of using SPLA offloading on CPU ONLY nodes
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: spla_op_A, spla_op_B, spla_error
#endif

      CALL timeset(routineN, handle)

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      IF (do_dgemm == do_dgemm_spla) THEN

         ! Translate the BLAS-style operation flags. Every value dgemm accepts is handled so
         ! that the SPLA operation codes can never be used uninitialised; for real matrices
         ! the conjugate transpose ('C') is the plain transpose.
         SELECT CASE (opA)
         CASE ('N', 'n')
            spla_op_A = spla_op_none
         CASE ('T', 't', 'C', 'c')
            spla_op_A = spla_op_transpose
         CASE DEFAULT
            spla_op_A = spla_op_none
            CPASSERT(.FALSE.)
         END SELECT

         SELECT CASE (opB)
         CASE ('N', 'n')
            spla_op_B = spla_op_none
         CASE ('T', 't', 'C', 'c')
            spla_op_B = spla_op_transpose
         CASE DEFAULT
            spla_op_B = spla_op_none
            CPASSERT(.FALSE.)
         END SELECT

#if __GNUC__ >= 9
         ! c_loc below hands raw addresses to SPLA, so the storage must be contiguous
         CPASSERT(IS_CONTIGUOUS(A))
         CPASSERT(IS_CONTIGUOUS(B))
         CPASSERT(IS_CONTIGUOUS(C))
#endif

         ! NOTE(review): this listing appears to skip one source line at this point; the page's
         ! cross-reference index mentions offload_activate_chosen_device() -- confirm whether
         ! the device has to be activated here before calling into SPLA.
         spla_error = spla_dgemm(spla_op_A, spla_op_B, &
                                 m, n, k, alpha, &
                                 c_loc(A), lda, &
                                 c_loc(B), ldb, &
                                 beta, c_loc(C), ldc, ctx%spla_context)
         CPASSERT(spla_error == spla_success)
      ELSE
#endif
         CALL dgemm(opA, opB, m, n, k, alpha, &
                    A, lda, &
                    B, ldb, beta, C, ldc)
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      END IF
#else
      ! the context is only needed by the SPLA path
      MARK_USED(ctx)
#endif
      CALL timestop(handle)

   END SUBROUTINE local_gemm
152
! **************************************************************************************************
!> \brief create a context for handling gemm offloading
!> \param ctx newly created context
!> \param pu processing unit to run the {s,d,c,z}gemm on; forwarded unchanged to spla_ctx_create
!>           (presumably one of local_gemm_pu_host / local_gemm_pu_gpu -- confirm with callers)
!> \note ctx is INTENT(out): its spla_context component is reset to its default (c_null_ptr) on
!>       entry, so a context that is still alive is NOT released here. Call destroy first.
!>       Without __SPLA/__OFFLOAD_GEMM, or when SPLA is not the selected library, the handle
!>       simply stays null.
! **************************************************************************************************
   SUBROUTINE local_gemm_create(ctx, pu)
      CLASS(local_gemm_ctxt_type), INTENT(out)           :: ctx
      INTEGER, INTENT(in)                                :: pu

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: error_

      IF (.NOT. c_associated(ctx%spla_context)) THEN
         IF (do_dgemm == do_dgemm_spla) THEN
            ! NOTE(review): this listing appears to skip one source line at this point; the
            ! page's cross-reference index mentions offload_activate_chosen_device() -- confirm
            ! whether the device has to be activated here before calling into SPLA.
            error_ = spla_ctx_create(ctx%spla_context, pu)
            CPASSERT(error_ == spla_success)
         ELSE
            ctx%spla_context = c_null_ptr
         END IF
      END IF
#else
      MARK_USED(pu)
      ctx%spla_context = c_null_ptr
#endif
   END SUBROUTINE local_gemm_create
180
! **************************************************************************************************
!> \brief release resources associated to a gemm context
!> \param ctx handle; its spla_context component is always null on return
!> \note The SPLA context is released whenever the handle is associated, independently of the
!>       library that is currently selected. A context therefore cannot leak if the selection
!>       changed after creation, and destroying a context that was never created is a no-op.
! **************************************************************************************************
   SUBROUTINE local_gemm_destroy(ctx)
      CLASS(local_gemm_ctxt_type), INTENT(inout)         :: ctx

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: error_

      IF (c_associated(ctx%spla_context)) THEN
         ! NOTE(review): this listing appears to skip one source line at this point; the page's
         ! cross-reference index mentions offload_activate_chosen_device() -- confirm whether
         ! the device has to be activated here before calling into SPLA.
         error_ = spla_ctx_destroy(ctx%spla_context)
         CPASSERT(error_ == spla_success)
      END IF
#endif
      ctx%spla_context = c_null_ptr
   END SUBROUTINE local_gemm_destroy
200
! **************************************************************************************************
!> \brief forwards the GPU operation threshold to the SPLA context
!> \param ctx gemm context whose spla_context handle receives the setting
!> \param opThresholdGPU threshold value, passed unchanged to spla_ctx_set_op_threshold_gpu
!> \note Does nothing when the code is built without __SPLA/__OFFLOAD_GEMM or when the context
!>       holds no SPLA handle. A failure reported by SPLA is treated as an error, in line with
!>       the create and destroy routines.
! **************************************************************************************************
   SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
      CLASS(local_gemm_ctxt_type), INTENT(INOUT)         :: ctx
      INTEGER, INTENT(in)                                :: opThresholdGPU

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER                                            :: error__

      ! a null handle means no SPLA context was created, so there is nothing to configure
      IF (c_associated(ctx%spla_context)) THEN
         ! NOTE(review): this listing appears to skip one source line at this point; the page's
         ! cross-reference index mentions offload_activate_chosen_device() -- confirm whether
         ! the device has to be activated here before calling into SPLA.
         error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opThresholdGPU)
         CPASSERT(error__ == spla_success)
      END IF
#else
      MARK_USED(ctx)
      MARK_USED(opThresholdGPU)
#endif
   END SUBROUTINE local_gemm_set_op_threshold_gpu
220
! **************************************************************************************************
!> \brief selects the gemm backend used by this module by storing the value in the module
!>        variable do_dgemm, which local_gemm, create and destroy compare against do_dgemm_spla
!> \param dgemm_library identifier of the library to use; stored as is, without validation
! **************************************************************************************************
   SUBROUTINE local_gemm_set_library(dgemm_library)
      INTEGER, INTENT(IN)                                :: dgemm_library

      ! module-wide setting: affects every context, not just one
      do_dgemm = dgemm_library

   END SUBROUTINE local_gemm_set_library
230
231END MODULE local_gemm_api
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenient wrapper to hide Fortran nature of dgemm_, swapping a and b.
collects all constants needed in input so that they can be used without circular dependencies
integer, parameter, public do_dgemm_spla
subroutine, public local_gemm_set_library(dgemm_library)
...
integer, parameter, public local_gemm_pu_gpu
subroutine local_gemm(opa, opb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ctx)
...
integer, parameter, public local_gemm_pu_host
Fortran API for the offload package, which is written in C.
Definition offload_api.F:12
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()