(git:374b731)
Loading...
Searching...
No Matches
local_gemm_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
9 USE iso_c_binding, ONLY: c_loc, &
10 c_null_ptr, &
11 c_ptr
12#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
14 USE spla, ONLY: spla_pu_host, &
15 spla_pu_gpu, &
16 spla_op_none, &
17 spla_op_transpose, &
18 spla_op_conj_transpose, &
19 spla_ctx_create, &
20 spla_ctx_destroy, &
21 spla_dgemm, &
22 spla_sgemm, &
23 spla_cgemm, &
24 spla_zgemm, &
25 spla_ctx_set_op_threshold_gpu, &
26 spla_success
27#endif
28
30
31#include "./base/base_uses.f90"
32
IMPLICIT NONE

PRIVATE

CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'

! Public entry points of the local GEMM wrapper (all defined below in this module).
PUBLIC :: local_gemm, &
          local_gemm_create, &
          local_gemm_destroy, &
          local_gemm_set_op_threshold_gpu, &
          local_gemm_set_library

! Processing-unit selectors that callers pass as `pu` to local_gemm_create.
! NOTE(review): the names come from the module's public symbol index, but the numeric values are
! not visible in this listing; 0 (host) and 1 (gpu) are assumed - confirm against the upstream file.
INTEGER, PARAMETER, PUBLIC :: &
   local_gemm_pu_host = 0, &
   local_gemm_pu_gpu = 1

! Selected GEMM backend. It is overwritten by local_gemm_set_library and compared against
! do_dgemm_spla to decide whether the SPLA code path is taken.
! NOTE(review): the default 1 is presumably the plain BLAS backend constant - confirm.
INTEGER, PRIVATE :: do_dgemm = 1
50
51CONTAINS
52
! **************************************************************************************************
!> \brief Double precision matrix product C := alpha*op(A)*op(B) + beta*C.
!>        The product is delegated to spla_dgemm when the code is built with SPLA offloading,
!>        a context is supplied and the SPLA backend has been selected with
!>        local_gemm_set_library; in every other case BLAS dgemm is called.
!> \param opA operation applied to A: 'N' (none), 'T' (transpose) or 'C' (same as 'T' for real data)
!> \param opB operation applied to B, same convention as opA
!> \param m number of rows of op(A) and of C
!> \param n number of columns of op(B) and of C
!> \param k number of columns of op(A) and of rows of op(B)
!> \param alpha scalar factor applied to op(A)*op(B)
!> \param A first input matrix
!> \param lda leading dimension of A
!> \param B second input matrix
!> \param ldb leading dimension of B
!> \param beta scalar factor applied to the incoming C
!> \param C matrix that is updated in place
!> \param ldc leading dimension of C
!> \param ctx optional SPLA context; when it is absent BLAS dgemm is always used
! **************************************************************************************************
SUBROUTINE local_gemm(opA, opB, m, n, k, &
                      alpha, A, lda, B, ldb, &
                      beta, C, ldc, ctx)
   CHARACTER, INTENT(in)                              :: opA
   CHARACTER, INTENT(in)                              :: opB
   INTEGER, INTENT(in)                                :: m
   INTEGER, INTENT(in)                                :: n
   INTEGER, INTENT(in)                                :: k
   REAL(8), INTENT(in)                                :: alpha
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   REAL(8), DIMENSION(*), INTENT(in), TARGET          :: A
#else
   REAL(8), DIMENSION(:, :), INTENT(in), TARGET       :: A
#endif
   INTEGER, INTENT(in)                                :: lda
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   REAL(8), DIMENSION(*), INTENT(in), TARGET          :: B
#else
   REAL(8), DIMENSION(:, :), INTENT(in), TARGET       :: B
#endif
   INTEGER, INTENT(in)                                :: ldb
   REAL(8), INTENT(in)                                :: beta
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   REAL(8), DIMENSION(*), INTENT(inout), TARGET       :: C
#else
   REAL(8), DIMENSION(:, :), INTENT(inout), TARGET    :: C
#endif
   INTEGER, INTENT(in)                                :: ldc
   TYPE(c_ptr), OPTIONAL, INTENT(inout)               :: ctx

   CHARACTER(LEN=*), PARAMETER :: routineN = 'local_gemm'

   INTEGER                                            :: handle
! no point of using SPLA offloading on CPU ONLY nodes
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   INTEGER                                            :: spla_error, spla_op_a, spla_op_b
   LOGICAL                                            :: ops_are_valid
#endif

   CALL timeset(routineN, handle)

! no point of using SPLA offloading on CPU ONLY nodes
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   IF (PRESENT(ctx) .AND. do_dgemm == do_dgemm_spla) THEN

      ! Translate the BLAS style operation characters. BLAS accepts either letter case, and for
      ! real matrices 'C' means the same as 'T'. Unknown characters must not leave the SPLA
      ! operation codes undefined, so they are caught by the assertion below.
      ops_are_valid = .TRUE.

      SELECT CASE (opA)
      CASE ('N', 'n')
         spla_op_a = spla_op_none
      CASE ('T', 't', 'C', 'c')
         spla_op_a = spla_op_transpose
      CASE DEFAULT
         spla_op_a = spla_op_none
         ops_are_valid = .FALSE.
      END SELECT

      SELECT CASE (opB)
      CASE ('N', 'n')
         spla_op_b = spla_op_none
      CASE ('T', 't', 'C', 'c')
         spla_op_b = spla_op_transpose
      CASE DEFAULT
         spla_op_b = spla_op_none
         ops_are_valid = .FALSE.
      END SELECT

      CPASSERT(ops_are_valid)

#if __GNUC__ >= 9
      ! The raw addresses handed to SPLA are only meaningful for contiguous storage.
      CPASSERT(IS_CONTIGUOUS(A))
      CPASSERT(IS_CONTIGUOUS(B))
      CPASSERT(IS_CONTIGUOUS(C))
#endif

      spla_error = spla_dgemm(spla_op_a, spla_op_b, &
                              m, n, k, alpha, &
                              c_loc(A), lda, &
                              c_loc(B), ldb, &
                              beta, c_loc(C), ldc, ctx)
      CPASSERT(spla_error == spla_success)
   ELSE
#endif
      CALL dgemm(opA, opB, m, n, k, alpha, &
                 A, lda, &
                 B, ldb, beta, C, ldc)
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   END IF
#else
   MARK_USED(ctx)
#endif
   CALL timestop(handle)

END SUBROUTINE local_gemm
145
! **************************************************************************************************
!> \brief create a context for handling gemm offloading
!> \param ctx newly created context; set to c_null_ptr when the code is built without SPLA
!>            offloading or when the SPLA backend is not the selected one
!> \param pu processing unit on which to run the {s,d,c,z}gemm (forwarded to spla_ctx_create)
! **************************************************************************************************
SUBROUTINE local_gemm_create(ctx, pu)
   TYPE(c_ptr), INTENT(out)                           :: ctx
   INTEGER, INTENT(in)                                :: pu

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   INTEGER                                            :: error_

   IF (do_dgemm == do_dgemm_spla) THEN
      error_ = spla_ctx_create(ctx, pu)
      CPASSERT(error_ == spla_success)
   ELSE
      ! SPLA is compiled in but another backend is selected: hand back an empty handle.
      ctx = c_null_ptr
   END IF
#else
   MARK_USED(pu)
   MARK_USED(ctx)
   ctx = c_null_ptr
#endif
END SUBROUTINE local_gemm_create
172
! **************************************************************************************************
!> \brief release resources associated to a gemm context
!> \param ctx handle; it is passed to spla_ctx_destroy only when the code is built with SPLA
!>            offloading and the SPLA backend is selected, and it is always c_null_ptr on return
! **************************************************************************************************
SUBROUTINE local_gemm_destroy(ctx)
   TYPE(c_ptr), INTENT(inout)                         :: ctx

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   INTEGER                                            :: error_

   IF (do_dgemm == do_dgemm_spla) THEN
      error_ = spla_ctx_destroy(ctx)
      CPASSERT(error_ == spla_success)
   END IF
#else
   MARK_USED(ctx)
#endif
   ! Reset the handle unconditionally so that a stale context cannot be reused by the caller.
   ctx = c_null_ptr
END SUBROUTINE local_gemm_destroy
194
! **************************************************************************************************
!> \brief forward a GPU operation threshold to the SPLA context; does nothing when the code is
!>        built without SPLA offloading
!> \param ctx gemm context handle, passed on to spla_ctx_set_op_threshold_gpu
!> \param opThresholdGPU threshold value, passed on unchanged to spla_ctx_set_op_threshold_gpu
! **************************************************************************************************
SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
   TYPE(c_ptr)                                        :: ctx
   INTEGER, INTENT(in)                                :: opThresholdGPU

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
   INTEGER                                            :: error__

   ! NOTE(review): the returned status is stored but never checked, unlike in local_gemm_create
   ! and local_gemm_destroy. This is kept as is, because the routine is not guarded by the
   ! backend selection and may therefore be reached with an empty context - confirm the intent.
   error__ = spla_ctx_set_op_threshold_gpu(ctx, opThresholdGPU)
#else
   MARK_USED(ctx)
   MARK_USED(opThresholdGPU)
#endif
END SUBROUTINE local_gemm_set_op_threshold_gpu
214
! **************************************************************************************************
!> \brief select the backend used by the routines of this module; the value is stored in the
!>        module variable do_dgemm, which local_gemm, local_gemm_create and local_gemm_destroy
!>        compare against do_dgemm_spla
!> \param dgemm_library identifier of the backend to use from now on
! **************************************************************************************************
SUBROUTINE local_gemm_set_library(dgemm_library)
   INTEGER, INTENT(IN)                                :: dgemm_library

   do_dgemm = dgemm_library

END SUBROUTINE local_gemm_set_library
224
225END MODULE local_gemm_api
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenient wrapper to hide Fortran nature of dgemm_, swapping a and b.
collects all constants needed in input so that they can be used without circular dependencies
integer, parameter, public do_dgemm_spla
subroutine, public local_gemm_set_library(dgemm_library)
...
subroutine, public local_gemm(opa, opb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ctx)
...
integer, parameter, public local_gemm_pu_gpu
subroutine, public local_gemm_set_op_threshold_gpu(ctx, opthresholdgpu)
...
subroutine, public local_gemm_destroy(ctx)
release resources associated to a gemm context
subroutine, public local_gemm_create(ctx, pu)
create a context for handling gemm offloading
integer, parameter, public local_gemm_pu_host
Fortran API for the offload package, which is written in C.
Definition offload_api.F:12
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()