(git:0de0cc2)
local_gemm_api.F
Go to the documentation of this file.
1 !--------------------------------------------------------------------------------------------------!
2 ! CP2K: A general program to perform molecular dynamics simulations !
3 ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 ! !
5 ! SPDX-License-Identifier: GPL-2.0-or-later !
6 !--------------------------------------------------------------------------------------------------!
7 
USE iso_c_binding, ONLY: c_associated, &
                         c_loc, &
                         c_null_ptr, &
                         c_ptr
12 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
14  USE spla, ONLY: spla_pu_host, &
15  spla_pu_gpu, &
16  spla_op_none, &
17  spla_op_transpose, &
18  spla_op_conj_transpose, &
19  spla_ctx_create, &
20  spla_ctx_destroy, &
21  spla_dgemm, &
22  spla_sgemm, &
23  spla_cgemm, &
24  spla_zgemm, &
25  spla_ctx_set_op_threshold_gpu, &
26  spla_success
27 #endif
28 
30 
31 #include "./base/base_uses.f90"
32 
33  IMPLICIT NONE
34 
35  PRIVATE
36 
37  CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'local_gemm_api'
38 
39  PUBLIC :: local_gemm, &
44 
45  INTEGER, PARAMETER, PUBLIC :: &
46  local_gemm_pu_host = 0, &
48 
! Module-wide selection of the GEMM library. It is set by local_gemm_set_library and compared
! against do_dgemm_spla by local_gemm, local_gemm_create and local_gemm_destroy.
! NOTE(review): the initial value 1 presumably denotes the plain BLAS path in input_constants - confirm.
49  INTEGER, PRIVATE :: do_dgemm = 1
50 
51 CONTAINS
52 
! **************************************************************************************************
!> \brief Double-precision matrix product C := alpha*op(A)*op(B) + beta*C.
!>
!>        In builds with __SPLA and __OFFLOAD_GEMM the product is computed by spla_dgemm
!>        when ctx is present and the SPLA library has been selected with
!>        local_gemm_set_library. In every other case BLAS dgemm is called with the
!>        arguments unchanged.
!> \param opA operation applied to A: 'N'/'n' (none) or 'T'/'t'/'C'/'c' (transpose;
!>            the conjugate transpose is identical to the transpose for real data)
!> \param opB operation applied to B, same encoding as opA
!> \param m number of rows of op(A) and of C
!> \param n number of columns of op(B) and of C
!> \param k number of columns of op(A) and rows of op(B)
!> \param alpha scalar applied to op(A)*op(B)
!> \param A matrix A
!> \param lda leading dimension of A
!> \param B matrix B
!> \param ldb leading dimension of B
!> \param beta scalar applied to C on input
!> \param C input/output matrix, overwritten with the result
!> \param ldc leading dimension of C
!> \param ctx optional SPLA context obtained from local_gemm_create
! **************************************************************************************************
   SUBROUTINE local_gemm(opA, opB, m, n, k, &
                         alpha, A, lda, B, ldb, &
                         beta, C, ldc, ctx)
      CHARACTER, INTENT(in) :: opa
      CHARACTER, INTENT(in) :: opb
      INTEGER, INTENT(in) :: m
      INTEGER, INTENT(in) :: n
      INTEGER, INTENT(in) :: k
      REAL(8), INTENT(in) :: alpha
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      ! assumed-size: the raw storage is handed to SPLA through c_loc
      REAL(8), DIMENSION(*), INTENT(in), TARGET :: a
#else
      REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: a
#endif
      INTEGER, INTENT(in) :: lda
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(in), TARGET :: b
#else
      REAL(8), DIMENSION(:, :), INTENT(in), TARGET :: b
#endif
      INTEGER, INTENT(in) :: ldb
      REAL(8), INTENT(in) :: beta
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      REAL(8), DIMENSION(*), INTENT(inout), TARGET :: c
#else
      REAL(8), DIMENSION(:, :), INTENT(inout), TARGET :: c
#endif
      INTEGER, INTENT(in) :: ldc
      TYPE(c_ptr), OPTIONAL, INTENT(inout) :: ctx

      INTEGER :: handle
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER :: spla_op_a, spla_op_b, spla_error
#endif
      CHARACTER(LEN=*), PARAMETER :: routinen = 'local_gemm'

      CALL timeset(routinen, handle)

      ! no point of using SPLA offloading on CPU ONLY nodes
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      IF (PRESENT(ctx) .AND. do_dgemm == do_dgemm_spla) THEN

         ! Translate the BLAS operation characters into SPLA constants. BLAS dgemm accepts
         ! either case and treats 'C' like 'T' for real data. Accept the same set here so
         ! that spla_dgemm never receives an undefined operation code.
         SELECT CASE (opa)
         CASE ('N', 'n')
            spla_op_a = spla_op_none
         CASE ('T', 't', 'C', 'c')
            spla_op_a = spla_op_transpose
         CASE DEFAULT
            cpabort("local_gemm: unsupported value for opA")
         END SELECT

         SELECT CASE (opb)
         CASE ('N', 'n')
            spla_op_b = spla_op_none
         CASE ('T', 't', 'C', 'c')
            spla_op_b = spla_op_transpose
         CASE DEFAULT
            cpabort("local_gemm: unsupported value for opB")
         END SELECT

#if __GNUC__ >= 9
         cpassert(is_contiguous(a))
         cpassert(is_contiguous(b))
         cpassert(is_contiguous(c))
#endif

         ! make the device chosen via offload_set_chosen_device() active before calling SPLA
         CALL offload_activate_chosen_device()
         spla_error = spla_dgemm(spla_op_a, spla_op_b, &
                                 m, n, k, alpha, &
                                 c_loc(a), lda, &
                                 c_loc(b), ldb, &
                                 beta, c_loc(c), ldc, ctx)
         cpassert(spla_error == spla_success)
      ELSE
#endif
         CALL dgemm(opa, opb, m, n, k, alpha, &
                    a, lda, &
                    b, ldb, beta, c, ldc)
#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      END IF
#else
      mark_used(ctx)
#endif

      CALL timestop(handle)

   END SUBROUTINE local_gemm
145 
! **************************************************************************************************
!> \brief Create a context for handling GEMM offloading.
!> \param ctx newly created context. It is C_NULL_PTR unless the code is built with
!>            __SPLA and __OFFLOAD_GEMM and the SPLA library is currently selected.
!> \param pu processing unit on which the {s,d,c,z}gemm runs; forwarded unchanged to
!>           spla_ctx_create (presumably local_gemm_pu_host or local_gemm_pu_gpu - confirm
!>           they match SPLA's processing-unit values)
! **************************************************************************************************
   SUBROUTINE local_gemm_create(ctx, pu)
      TYPE(c_ptr), INTENT(out) :: ctx
      INTEGER, INTENT(in) :: pu

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER :: error_

      IF (do_dgemm == do_dgemm_spla) THEN
         ! make the device chosen via offload_set_chosen_device() active before calling SPLA
         CALL offload_activate_chosen_device()

         error_ = spla_ctx_create(ctx, pu)
         cpassert(error_ == spla_success)
      ELSE
         ! SPLA is not the selected library, so there is no context to create
         ctx = c_null_ptr
      END IF
#else
      mark_used(pu)
      ctx = c_null_ptr
#endif
   END SUBROUTINE local_gemm_create
172 
! **************************************************************************************************
!> \brief Release resources associated with a GEMM context.
!> \param ctx handle obtained from local_gemm_create; it is reset to C_NULL_PTR on return
! **************************************************************************************************
   SUBROUTINE local_gemm_destroy(ctx)
      TYPE(c_ptr), INTENT(inout) :: ctx

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER :: error_

      ! Base the decision on the handle, not on the currently selected library. The library
      ! may have been changed with local_gemm_set_library after the context was created.
      ! Checking the flag instead would leak an SPLA context or pass a null handle to SPLA.
      IF (c_associated(ctx)) THEN
         ! make the device chosen via offload_set_chosen_device() active before calling SPLA
         CALL offload_activate_chosen_device()

         error_ = spla_ctx_destroy(ctx)
         cpassert(error_ == spla_success)
      END IF
#endif
      ctx = c_null_ptr
   END SUBROUTINE local_gemm_destroy
194 
! **************************************************************************************************
!> \brief Set the GPU operation threshold of a GEMM context.
!> \param ctx context obtained from local_gemm_create; nothing is done when it is C_NULL_PTR
!> \param opThresholdGPU value forwarded to spla_ctx_set_op_threshold_gpu
!>        (see the SPLA documentation for its exact meaning)
! **************************************************************************************************
   SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
      TYPE(c_ptr) :: ctx
      INTEGER, INTENT(in) :: opthresholdgpu

#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
      INTEGER :: error__

      ! local_gemm_create returns C_NULL_PTR when SPLA is not the selected library.
      ! Such a handle must not be passed to SPLA.
      IF (c_associated(ctx)) THEN
         ! make the device chosen via offload_set_chosen_device() active before calling SPLA
         CALL offload_activate_chosen_device()
         error__ = spla_ctx_set_op_threshold_gpu(ctx, opthresholdgpu)
         ! check the status like the other context routines in this module do
         cpassert(error__ == spla_success)
      END IF
#else
      mark_used(ctx)
      mark_used(opthresholdgpu)
#endif
   END SUBROUTINE local_gemm_set_op_threshold_gpu
214 
! **************************************************************************************************
!> \brief Select which GEMM library the local_gemm routines use from now on.
!> \param dgemm_library library identifier; the SPLA path is enabled when it equals do_dgemm_spla
! **************************************************************************************************
   SUBROUTINE local_gemm_set_library(dgemm_library)
      INTEGER, INTENT(IN) :: dgemm_library

      ! Module-level state, read by later calls to local_gemm and the context routines.
      do_dgemm = dgemm_library

   END SUBROUTINE local_gemm_set_library
224 
225 END MODULE local_gemm_api
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenience wrapper that hides the Fortran nature of dgemm_, swapping a and b.
collects all constants needed in input so that they can be used without circular dependencies
integer, parameter, public do_dgemm_spla
subroutine, public local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
...
subroutine, public local_gemm_set_library(dgemm_library)
...
integer, parameter, public local_gemm_pu_gpu
subroutine, public local_gemm_destroy(ctx)
Release resources associated with a GEMM context.
subroutine, public local_gemm_create(ctx, pu)
create a context for handling gemm offloading
subroutine, public local_gemm(opA, opB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, ctx)
...
integer, parameter, public local_gemm_pu_host
Fortran API for the offload package, which is written in C.
Definition: offload_api.F:12
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()
Definition: offload_api.F:174