(git:0de0cc2)
parallel_gemm_api.F
Go to the documentation of this file.
1 !--------------------------------------------------------------------------------------------------!
2 ! CP2K: A general program to perform molecular dynamics simulations !
3 ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 ! !
5 ! SPDX-License-Identifier: GPL-2.0-or-later !
6 !--------------------------------------------------------------------------------------------------!
7 
8 ! **************************************************************************************************
9 !> \brief basic linear algebra operations for full matrices: parallel matrix-matrix multiplication (GEMM) dispatched to ScaLAPACK or COSMA
10 !> \par History
11 !> 08.2002 split out of qs_blacs [fawzi]
12 !> \author Fawzi Mohamed
13 ! **************************************************************************************************
15  USE iso_c_binding, ONLY: c_char,&
16  c_double,&
17  c_int,&
18  c_loc,&
19  c_ptr
21  USE cp_cfm_types, ONLY: cp_cfm_type
23  USE cp_fm_types, ONLY: cp_fm_get_mm_type,&
24  cp_fm_type
25  USE input_constants, ONLY: do_cosma,&
27  USE kinds, ONLY: dp
29 #include "./base/base_uses.f90"
30 
! Module-wide defaults: explicit typing everywhere, and every entity private unless exported below.
   IMPLICIT NONE
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'parallel_gemm_api'

   PUBLIC :: parallel_gemm

   !> Generic parallel GEMM entry point. The specific procedure is selected from the argument
   !> types: real cp_fm_type matrices with REAL scalars go to parallel_gemm_fm, and complex
   !> cp_cfm_type matrices with COMPLEX scalars go to parallel_gemm_cfm.
   INTERFACE parallel_gemm
      MODULE PROCEDURE parallel_gemm_fm, parallel_gemm_cfm
   END INTERFACE parallel_gemm
42 
43 CONTAINS
44 
! **************************************************************************************************
!> \brief Distributed real matrix-matrix multiplication
!>        matrix_c = beta*matrix_c + alpha*op(matrix_a)*op(matrix_b).
!>        The backend is chosen at run time by cp_fm_get_mm_type():
!>        ScaLAPACK (through cp_fm_gemm) or COSMA (through cosma_pdgemm).
!> \param transa operation flag for matrix_a (passed through unchanged to the backend)
!> \param transb operation flag for matrix_b (passed through unchanged to the backend)
!> \param m number of rows of op(matrix_a) and of matrix_c (usual (P)GEMM convention)
!> \param n number of columns of op(matrix_b) and of matrix_c (usual (P)GEMM convention)
!> \param k inner dimension shared by op(matrix_a) and op(matrix_b)
!> \param alpha scaling factor of the product
!> \param matrix_a first operand
!> \param matrix_b second operand
!> \param beta scaling factor applied to the previous content of matrix_c
!> \param matrix_c result; the backend writes into its distributed local data
!> \param a_first_col optional first column of the sub-matrix of matrix_a (passed through)
!> \param a_first_row optional first row of the sub-matrix of matrix_a (passed through)
!> \param b_first_col optional first column of the sub-matrix of matrix_b (passed through)
!> \param b_first_row optional first row of the sub-matrix of matrix_b (passed through)
!> \param c_first_col optional first column of the sub-matrix of matrix_c (passed through)
!> \param c_first_row optional first row of the sub-matrix of matrix_c (passed through)
!> \note Aborts when COSMA is selected but CP2K was built without __COSMA, and when
!>       cp_fm_get_mm_type() returns a value other than do_scalapack or do_cosma.
! **************************************************************************************************
   SUBROUTINE parallel_gemm_fm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                               matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                               c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_fm'

      INTEGER                                            :: handle, handle1, my_multi

      CALL timeset(routineN, handle)

      my_multi = cp_fm_get_mm_type()

      SELECT CASE (my_multi)
      CASE (do_scalapack)
         CALL timeset(routineN//"_gemm", handle1)
         CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                         a_first_col=a_first_col, &
                         a_first_row=a_first_row, &
                         b_first_col=b_first_col, &
                         b_first_row=b_first_row, &
                         c_first_col=c_first_col, &
                         c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_cosma)
#if defined(__COSMA)
         CALL timeset(routineN//"_cosma", handle1)
         CALL cosma_pdgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
                           matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
                           a_first_col=a_first_col, &
                           a_first_row=a_first_row, &
                           b_first_col=b_first_col, &
                           b_first_row=b_first_row, &
                           c_first_col=c_first_col, &
                           c_first_row=c_first_row)
         CALL timestop(handle1)
#else
         cpabort("CP2K compiled without the COSMA library.")
#endif
      CASE DEFAULT
         ! Without this branch an unrecognised selection would silently skip the product
         ! and leave matrix_c unchanged.
         cpabort("Unknown matrix multiplication library selected.")
      END SELECT

      CALL timestop(handle)

   END SUBROUTINE parallel_gemm_fm
115 
! **************************************************************************************************
!> \brief Distributed complex matrix-matrix multiplication
!>        matrix_c = beta*matrix_c + alpha*op(matrix_a)*op(matrix_b).
!>        The backend is chosen at run time by cp_fm_get_mm_type():
!>        ScaLAPACK (through cp_cfm_gemm) or COSMA (through cosma_pzgemm).
!> \param transa operation flag for matrix_a (passed through unchanged to the backend)
!> \param transb operation flag for matrix_b (passed through unchanged to the backend)
!> \param m number of rows of op(matrix_a) and of matrix_c (usual (P)GEMM convention)
!> \param n number of columns of op(matrix_b) and of matrix_c (usual (P)GEMM convention)
!> \param k inner dimension shared by op(matrix_a) and op(matrix_b)
!> \param alpha complex scaling factor of the product
!> \param matrix_a first operand
!> \param matrix_b second operand
!> \param beta complex scaling factor applied to the previous content of matrix_c
!> \param matrix_c result; the backend writes into its distributed local data
!> \param a_first_col optional first column of the sub-matrix of matrix_a (passed through)
!> \param a_first_row optional first row of the sub-matrix of matrix_a (passed through)
!> \param b_first_col optional first column of the sub-matrix of matrix_b (passed through)
!> \param b_first_row optional first row of the sub-matrix of matrix_b (passed through)
!> \param c_first_col optional first column of the sub-matrix of matrix_c (passed through)
!> \param c_first_row optional first row of the sub-matrix of matrix_c (passed through)
!> \note Aborts when COSMA is selected but CP2K was built without __COSMA, and when
!>       cp_fm_get_mm_type() returns a value other than do_scalapack or do_cosma.
! **************************************************************************************************
   SUBROUTINE parallel_gemm_cfm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                                matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                                c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
      COMPLEX(KIND=dp), INTENT(IN)                       :: beta
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER                        :: routineN = 'parallel_gemm_cfm'

      INTEGER                                            :: handle, handle1, my_multi

      CALL timeset(routineN, handle)

      my_multi = cp_fm_get_mm_type()

      SELECT CASE (my_multi)
      CASE (do_scalapack)
         CALL timeset(routineN//"_gemm", handle1)
         CALL cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                          a_first_col=a_first_col, &
                          a_first_row=a_first_row, &
                          b_first_col=b_first_col, &
                          b_first_row=b_first_row, &
                          c_first_col=c_first_col, &
                          c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_cosma)
#if defined(__COSMA)
         CALL timeset(routineN//"_cosma", handle1)
         CALL cosma_pzgemm(transa=transa, transb=transb, m=m, n=n, k=k, alpha=alpha, &
                           matrix_a=matrix_a, matrix_b=matrix_b, beta=beta, matrix_c=matrix_c, &
                           a_first_col=a_first_col, &
                           a_first_row=a_first_row, &
                           b_first_col=b_first_col, &
                           b_first_row=b_first_row, &
                           c_first_col=c_first_col, &
                           c_first_row=c_first_row)
         CALL timestop(handle1)
#else
         cpabort("CP2K compiled without the COSMA library.")
#endif
      CASE DEFAULT
         ! Without this branch an unrecognised selection would silently skip the product
         ! and leave matrix_c unchanged.
         cpabort("Unknown matrix multiplication library selected.")
      END SELECT

      CALL timestop(handle)

   END SUBROUTINE parallel_gemm_cfm
186 
187 #if defined(__COSMA)
! **************************************************************************************************
!> \brief Fortran wrapper for the C routine cosma_pdgemm (real, double precision).
!>        The C routine takes a PDGEMM-style argument list: for each matrix, the address of its
!>        first local element, the global start row/column, and the address of its descriptor.
!> \param transa operation flag for matrix_a, forwarded as a single C character
!> \param transb operation flag for matrix_b, forwarded as a single C character
!> \param m number of rows of op(matrix_a) and of matrix_c
!> \param n number of columns of op(matrix_b) and of matrix_c
!> \param k inner dimension of the product
!> \param alpha scaling factor of the product
!> \param matrix_a first operand (local_data(1,1) and matrix_struct%descriptor are handed to C)
!> \param matrix_b second operand (local_data(1,1) and matrix_struct%descriptor are handed to C)
!> \param beta scaling factor of the previous content of matrix_c
!> \param matrix_c result matrix (local_data(1,1) and matrix_struct%descriptor are handed to C)
!> \param a_first_col optional start column in matrix_a, defaults to 1
!> \param a_first_row optional start row in matrix_a, defaults to 1
!> \param b_first_col optional start column in matrix_b, defaults to 1
!> \param b_first_row optional start row in matrix_b, defaults to 1
!> \param c_first_col optional start column in matrix_c, defaults to 1
!> \param c_first_row optional start row in matrix_c, defaults to 1
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cosma_pdgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                           a_first_col, a_first_row, b_first_col, b_first_row, &
                           c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      INTEGER                                            :: col_a, col_b, col_c, row_a, row_b, row_c
      TYPE(C_PTR)                                        :: data_a, data_b, data_c, &
                                                            desc_a, desc_b, desc_c

      INTERFACE
         SUBROUTINE cosma_pdgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
                                   b, ib, jb, descb, beta, c, ic, jc, descc) &
            BIND(C, name="cosma_pdgemm")
            IMPORT :: c_ptr, c_int, c_double, c_char
            CHARACTER(KIND=C_CHAR)                     :: transa
            CHARACTER(KIND=C_CHAR)                     :: transb
            INTEGER(KIND=C_INT)                        :: m
            INTEGER(KIND=C_INT)                        :: n
            INTEGER(KIND=C_INT)                        :: k
            REAL(KIND=c_double)                        :: alpha
            TYPE(C_PTR), VALUE                         :: a
            INTEGER(KIND=C_INT)                        :: ia
            INTEGER(KIND=C_INT)                        :: ja
            TYPE(C_PTR), VALUE                         :: desca
            TYPE(C_PTR), VALUE                         :: b
            INTEGER(KIND=C_INT)                        :: ib
            INTEGER(KIND=C_INT)                        :: jb
            TYPE(C_PTR), VALUE                         :: descb
            REAL(KIND=c_double)                        :: beta
            TYPE(C_PTR), VALUE                         :: c
            INTEGER(KIND=C_INT)                        :: ic
            INTEGER(KIND=C_INT)                        :: jc
            TYPE(C_PTR), VALUE                         :: descc
         END SUBROUTINE cosma_pdgemm_c
      END INTERFACE

      ! Every start index defaults to 1; a caller-supplied value overrides it.
      row_a = 1
      col_a = 1
      row_b = 1
      col_b = 1
      row_c = 1
      col_c = 1
      IF (PRESENT(a_first_row)) row_a = a_first_row
      IF (PRESENT(a_first_col)) col_a = a_first_col
      IF (PRESENT(b_first_row)) row_b = b_first_row
      IF (PRESENT(b_first_col)) col_b = b_first_col
      IF (PRESENT(c_first_row)) row_c = c_first_row
      IF (PRESENT(c_first_col)) col_c = c_first_col

      ! C addresses of each matrix's first local element and of its descriptor array.
      data_a = c_loc(matrix_a%local_data(1, 1))
      desc_a = c_loc(matrix_a%matrix_struct%descriptor(1))
      data_b = c_loc(matrix_b%local_data(1, 1))
      desc_b = c_loc(matrix_b%matrix_struct%descriptor(1))
      data_c = c_loc(matrix_c%local_data(1, 1))
      desc_c = c_loc(matrix_c%matrix_struct%descriptor(1))

      CALL cosma_pdgemm_c(transa, transb, m, n, k, alpha, &
                          data_a, row_a, col_a, desc_a, &
                          data_b, row_b, col_b, desc_b, &
                          beta, &
                          data_c, row_c, col_c, desc_c)

   END SUBROUTINE cosma_pdgemm
290 
! **************************************************************************************************
!> \brief Fortran wrapper for the C routine cosma_pzgemm (complex, double precision).
!>        The C routine takes a PZGEMM-style argument list. alpha and beta are passed by address,
!>        each as two consecutive doubles holding (real part, imaginary part).
!> \param transa operation flag for matrix_a, forwarded as a single C character
!> \param transb operation flag for matrix_b, forwarded as a single C character
!> \param m number of rows of op(matrix_a) and of matrix_c
!> \param n number of columns of op(matrix_b) and of matrix_c
!> \param k inner dimension of the product
!> \param alpha complex scaling factor of the product
!> \param matrix_a first operand (local_data(1,1) and matrix_struct%descriptor are handed to C)
!> \param matrix_b second operand (local_data(1,1) and matrix_struct%descriptor are handed to C)
!> \param beta complex scaling factor of the previous content of matrix_c
!> \param matrix_c result matrix (local_data(1,1) and matrix_struct%descriptor are handed to C)
!> \param a_first_col optional start column in matrix_a, defaults to 1
!> \param a_first_row optional start row in matrix_a, defaults to 1
!> \param b_first_col optional start column in matrix_b, defaults to 1
!> \param b_first_row optional start row in matrix_b, defaults to 1
!> \param c_first_col optional start column in matrix_c, defaults to 1
!> \param c_first_row optional start row in matrix_c, defaults to 1
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE cosma_pzgemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                           a_first_col, a_first_row, b_first_col, b_first_row, &
                           c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      COMPLEX(KIND=dp), INTENT(IN)                       :: alpha
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_a, matrix_b
      COMPLEX(KIND=dp), INTENT(IN)                       :: beta
      TYPE(cp_cfm_type), INTENT(IN)                      :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      INTEGER                                            :: col_a, col_b, col_c, row_a, row_b, row_c
      REAL(KIND=dp), DIMENSION(2), TARGET                :: alpha_re_im, beta_re_im
      TYPE(C_PTR)                                        :: data_a, data_b, data_c, &
                                                            desc_a, desc_b, desc_c

      INTERFACE
         SUBROUTINE cosma_pzgemm_c(transa, transb, m, n, k, alpha, a, ia, ja, desca, &
                                   b, ib, jb, descb, beta, c, ic, jc, descc) &
            BIND(C, name="cosma_pzgemm")
            IMPORT :: c_ptr, c_int, c_char
            CHARACTER(KIND=C_CHAR)                     :: transa
            CHARACTER(KIND=C_CHAR)                     :: transb
            INTEGER(KIND=C_INT)                        :: m
            INTEGER(KIND=C_INT)                        :: n
            INTEGER(KIND=C_INT)                        :: k
            TYPE(C_PTR), VALUE                         :: alpha
            TYPE(C_PTR), VALUE                         :: a
            INTEGER(KIND=C_INT)                        :: ia
            INTEGER(KIND=C_INT)                        :: ja
            TYPE(C_PTR), VALUE                         :: desca
            TYPE(C_PTR), VALUE                         :: b
            INTEGER(KIND=C_INT)                        :: ib
            INTEGER(KIND=C_INT)                        :: jb
            TYPE(C_PTR), VALUE                         :: descb
            TYPE(C_PTR), VALUE                         :: beta
            TYPE(C_PTR), VALUE                         :: c
            INTEGER(KIND=C_INT)                        :: ic
            INTEGER(KIND=C_INT)                        :: jc
            TYPE(C_PTR), VALUE                         :: descc
         END SUBROUTINE cosma_pzgemm_c
      END INTERFACE

      ! Every start index defaults to 1; a caller-supplied value overrides it.
      row_a = 1
      col_a = 1
      row_b = 1
      col_b = 1
      row_c = 1
      col_c = 1
      IF (PRESENT(a_first_row)) row_a = a_first_row
      IF (PRESENT(a_first_col)) col_a = a_first_col
      IF (PRESENT(b_first_row)) row_b = b_first_row
      IF (PRESENT(b_first_col)) col_b = b_first_col
      IF (PRESENT(c_first_row)) row_c = c_first_row
      IF (PRESENT(c_first_col)) col_c = c_first_col

      ! The C side reads each complex scalar as a (real, imaginary) pair of doubles.
      alpha_re_im = [REAL(alpha, KIND=dp), AIMAG(alpha)]
      beta_re_im = [REAL(beta, KIND=dp), AIMAG(beta)]

      ! C addresses of each matrix's first local element and of its descriptor array.
      data_a = c_loc(matrix_a%local_data(1, 1))
      desc_a = c_loc(matrix_a%matrix_struct%descriptor(1))
      data_b = c_loc(matrix_b%local_data(1, 1))
      desc_b = c_loc(matrix_b%matrix_struct%descriptor(1))
      data_c = c_loc(matrix_c%local_data(1, 1))
      desc_c = c_loc(matrix_c%matrix_struct%descriptor(1))

      CALL cosma_pzgemm_c(transa, transb, m, n, k, c_loc(alpha_re_im), &
                          data_a, row_a, col_a, desc_a, &
                          data_b, row_b, col_b, desc_b, &
                          c_loc(beta_re_im), &
                          data_c, row_c, col_c, desc_c)

   END SUBROUTINE cosma_pzgemm
399 #endif
400 
401 END MODULE parallel_gemm_api
Basic linear algebra operations for complex full matrices.
subroutine, public cp_cfm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, c_first_col, c_first_row)
Performs one of the matrix-matrix operations: matrix_c = alpha * op1( matrix_a ) * op2( matrix_b ) + ...
Represents a complex full matrix distributed on many processors.
Definition: cp_cfm_types.F:12
basic linear algebra operations for full matrices
subroutine, public cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, c_first_col, c_first_row)
computes matrix_c = beta * matrix_c + alpha * ( matrix_a ** transa ) * ( matrix_b ** transb )
represent a full matrix distributed on many processors
Definition: cp_fm_types.F:15
integer function, public cp_fm_get_mm_type()
...
Definition: cp_fm_types.F:2623
collects all constants needed in input so that they can be used without circular dependencies
integer, parameter, public do_cosma
integer, parameter, public do_scalapack
Defines the basic variable types.
Definition: kinds.F:23
integer, parameter, public dp
Definition: kinds.F:34
Fortran API for the offload package, which is written in C.
Definition: offload_api.F:12
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()
Definition: offload_api.F:174
basic linear algebra operations for full matrices