!> \brief Transforms the two-index block sab with sphi_a (applied transposed from the
!>        left) and sphi_b (applied from the right), i.e. abint = sphi_a^T * sab * sphi_b,
!>        evaluated as two consecutive dgemm calls through the work array cpp.
!> \param abint  result; the last dgemm of either branch writes its nsgfa x nsgfb part
!>               with beta = 0, so previous content of that part is not accumulated
!> \param sab    input block; ncoa rows and ncob columns of it enter the products
!> \param sphi_a left transformation, used as "T" with ncoa as contraction length and
!>               nsgfa result rows
!> \param sphi_b right transformation, used as "N" with ncob as contraction length and
!>               nsgfb result columns
!> \param ncoa   contraction length on the sphi_a side
!> \param ncob   contraction length on the sphi_b side
!> \param nsgfa  number of rows of the result
!> \param nsgfb  number of columns of the result
!> \note NOTE(review): this excerpt has gaps (embedded line numbers jump); the
!>       declaration of handle, the ELSE/END IF of the branch, any cleanup and the
!>       END SUBROUTINE are not visible here.
47 SUBROUTINE ab_contract(abint, sab, sphi_a, sphi_b, ncoa, ncob, nsgfa, nsgfb)
49 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(INOUT) :: abint
50 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(IN) :: sab, sphi_a, sphi_b
51 INTEGER,
INTENT(IN) :: ncoa, ncob, nsgfa, nsgfb
53 CHARACTER(LEN=*),
PARAMETER :: routinen =
'ab_contract'
! Intermediate (half-transformed) matrix; its shape depends on the branch taken below.
56 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: cpp
58 CALL timeset(routinen, handle)
! sab must provide at least ncob columns, since ncob columns are read by dgemm below.
60 cpassert(ncob <=
SIZE(sab, 2))
! Choose the cheaper multiplication order. The left-hand side is the multiply count of
! (sphi_a^T * sab) * sphi_b  = nsgfa*ncob*ncoa + nsgfa*nsgfb*ncob,
! the right-hand side that of
! sphi_a^T * (sab * sphi_b)  = ncoa*nsgfb*ncob + nsgfa*nsgfb*ncoa.
62 IF ((nsgfa*ncob*(ncoa + nsgfb)) <= (nsgfb*ncoa*(ncob + nsgfa)))
THEN
63 ALLOCATE (cpp(nsgfa, ncob))
! Step 1: cpp(nsgfa, ncob) = sphi_a^T * sab
65 CALL dgemm(
"T",
"N", nsgfa, ncob, ncoa, 1._dp, sphi_a,
SIZE(sphi_a, 1), sab,
SIZE(sab, 1), 0.0_dp, cpp, nsgfa)
! Step 2: abint(nsgfa, nsgfb) = cpp * sphi_b
67 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1._dp, cpp, nsgfa, sphi_b,
SIZE(sphi_b, 1), 0.0_dp, &
68 abint,
SIZE(abint, 1))
! NOTE(review): the statements from here on presumably belong to the ELSE branch;
! the ELSE line itself is not visible in this excerpt.
70 ALLOCATE (cpp(ncoa, nsgfb))
! Step 1: cpp(ncoa, nsgfb) = sab * sphi_b
72 CALL dgemm(
"N",
"N", ncoa, nsgfb, ncob, 1._dp, sab,
SIZE(sab, 1), sphi_b,
SIZE(sphi_b, 1), 0.0_dp, cpp, ncoa)
! Step 2: abint(nsgfa, nsgfb) = sphi_a^T * cpp
74 CALL dgemm(
"T",
"N", nsgfa, nsgfb, ncoa, 1._dp, sphi_a,
SIZE(sphi_a, 1), cpp, ncoa, 0.0_dp, &
75 abint,
SIZE(abint, 1))
!> \brief Transforms the three-index block sabc with sphi_a, sphi_b and sphi_c using
!>        dgemm calls and a rank-3 work array tmp.
!> \param abcint  rank-3 array; asserted below to have extents nsgfa and nsgfb in its
!>                first two dimensions (NOTE(review): presumably the output of the
!>                final dgemm, whose continuation line is not visible here)
!> \param sabc    rank-3 input block, multiplied by sphi_a^T in the first dgemm
!> \param sphi_a  transformation applied transposed in the first step
!> \param sphi_b  transformation applied in the second step
!> \param sphi_c  transformation applied in the last step
!> \param ncoa    contraction length of the sphi_a step
!> \param ncob    contraction length of the sphi_b step
!> \param ncoc    contraction length of the sphi_c step
!> \param nsgfa   result extent produced by the sphi_a step
!> \param nsgfb   result extent produced by the sphi_b step
!> \param nsgfc   result extent produced by the sphi_c step
!> \note NOTE(review): this excerpt has gaps; the end of the argument list, the
!>       assignments of m1, m2, m3 and mx, the loop over i, the last argument line of
!>       the final dgemm and the END SUBROUTINE are not visible here.
99 SUBROUTINE abc_contract(abcint, sabc, sphi_a, sphi_b, sphi_c, ncoa, ncob, ncoc, &
102 REAL(kind=
dp),
DIMENSION(:, :, :) :: abcint, sabc
103 REAL(kind=
dp),
DIMENSION(:, :) :: sphi_a, sphi_b, sphi_c
104 INTEGER,
INTENT(IN) :: ncoa, ncob, ncoc, nsgfa, nsgfb, nsgfc
106 CHARACTER(LEN=*),
PARAMETER :: routinen =
'abc_contract'
108 INTEGER :: handle, i, m1, m2, m3, msphia, msphib, &
! Work array holding the partially transformed data between the dgemm steps.
110 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :) :: tmp
112 CALL timeset(routinen, handle)
! The result array must match the requested extents in its first two dimensions.
114 cpassert(
SIZE(abcint, 1) == nsgfa)
115 cpassert(
SIZE(abcint, 2) == nsgfb)
! Leading dimensions of the transformation matrices as passed to dgemm.
117 msphia =
SIZE(sphi_a, 1)
118 msphib =
SIZE(sphi_b, 1)
119 msphic =
SIZE(sphi_c, 1)
! m1, m2, m3 and mx are set on lines not visible in this excerpt
! (presumably from the extents of sabc - TODO confirm).
! tmp has one slab more than m3: the first dgemm writes starting at slab 2, and the
! sphi_b step reads slab i + 1 while writing slab i.
127 ALLOCATE (tmp(nsgfa, mx, m3 + 1))
! Step 1: sphi_a^T (nsgfa x ncoa) times sabc viewed as a matrix with leading
! dimension m1 and m2*m3 columns; the result is stored starting at tmp(:, :, 2).
129 CALL dgemm(
"T",
"N", nsgfa, m2*m3, ncoa, 1._dp, sphi_a, msphia, sabc, m1, 0.0_dp, tmp(:, :, 2), nsgfa)
! Step 2: tmp(:, :, i) = tmp(:, :, i + 1) * sphi_b, an nsgfa x nsgfb result.
! NOTE(review): presumably executed inside a loop over i; the loop header is not
! visible in this excerpt.
131 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1._dp, tmp(:, :, i + 1), nsgfa, sphi_b, msphib, &
132 0.0_dp, tmp(:, :, i), nsgfa)
! Step 3: tmp viewed as a matrix with nsgfa*nsgfb rows and leading dimension nsgfa*mx
! times sphi_c (ncoc x nsgfc). The output argument sits on a continuation line that is
! not visible in this excerpt.
134 CALL dgemm(
"N",
"N", nsgfa*nsgfb, nsgfc, ncoc, 1._dp, tmp, nsgfa*mx, sphi_c, msphic, 0.0_dp, &
139 CALL timestop(handle)
!> \brief Transforms the four-index block sabcd with sphi_a, sphi_d, sphi_c and sphi_b
!>        (in that order) using dgemm, and stores the result slice by slice in abcdint.
!> \param abcdint  rank-4 result; slices abcdint(:, :, isgfc, isgfd) are assigned below
!> \param sabcd    rank-4 input block, multiplied by sphi_a^T in the first dgemm
!> \param sphi_a   transformation applied transposed in the first step
!> \param sphi_b   transformation applied in the last step
!> \param sphi_c   transformation applied in the third step
!> \param sphi_d   transformation applied in the second step
!> \param ncoa     contraction length of the sphi_a step
!> \param ncob     contraction length of the sphi_b step
!> \param ncoc     contraction length of the sphi_c step
!> \param ncod     contraction length of the sphi_d step
!> \param nsgfa    result extent produced by the sphi_a step
!> \param nsgfb    result extent produced by the sphi_b step
!> \param nsgfc    result extent produced by the sphi_c step
!> \param nsgfd    result extent produced by the sphi_d step
!> \note NOTE(review): this excerpt has gaps; the end of the INTEGER argument
!>       declaration, the assignments of m1..m4, the loops over isgfd/isgfc, the last
!>       argument line of the first dgemm and the END SUBROUTINE are not visible here.
161 SUBROUTINE abcd_contract(abcdint, sabcd, sphi_a, sphi_b, sphi_c, sphi_d, ncoa, ncob, &
162 ncoc, ncod, nsgfa, nsgfb, nsgfc, nsgfd)
164 REAL(kind=
dp),
DIMENSION(:, :, :, :), &
165 INTENT(INOUT) :: abcdint
166 REAL(kind=
dp),
DIMENSION(:, :, :, :),
INTENT(IN) :: sabcd
167 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(IN) :: sphi_a, sphi_b, sphi_c, sphi_d
168 INTEGER,
INTENT(IN) :: ncoa, ncob, ncoc, ncod, nsgfa, nsgfb, &
171 CHARACTER(LEN=*),
PARAMETER :: routinen =
'abcd_contract'
173 INTEGER :: handle, isgfc, isgfd, m1, m2, m3, m4, &
174 msphia, msphib, msphic, msphid
! Contiguous scratch copies used as dgemm input/output for single slices.
175 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: temp_cccc, work_cpcc
176 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :) :: temp_cpcc, work_cppc
! Intermediates after each transformation step (see the ALLOCATE below for shapes).
177 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :, :) :: cpcc, cppc, cppp
179 CALL timeset(routinen, handle)
! Leading dimensions of the transformation matrices as passed to dgemm.
181 msphia =
SIZE(sphi_a, 1)
182 msphib =
SIZE(sphi_b, 1)
183 msphic =
SIZE(sphi_c, 1)
184 msphid =
SIZE(sphi_d, 1)
! m1..m4 are set on lines not visible in this excerpt
! (presumably the extents of sabcd - TODO confirm).
! cppp: first index transformed; cppc: first and fourth; cpcc: first, third and fourth.
191 ALLOCATE (cppp(nsgfa, m2, m3, m4), cppc(nsgfa, m2, m3, nsgfd), &
192 cpcc(nsgfa, m2, nsgfc, nsgfd))
194 ALLOCATE (work_cppc(nsgfa, m2, m3), temp_cpcc(nsgfa, m2, nsgfc))
198 ALLOCATE (work_cpcc(nsgfa, m2), temp_cccc(nsgfa, nsgfb))
! Step 1: sphi_a^T times sabcd viewed as a matrix with leading dimension m1 and
! m2*m3*m4 columns. The output argument is on a continuation line not visible here
! (presumably cppp, which the next dgemm reads - TODO confirm).
202 CALL dgemm(
"T",
"N", nsgfa, m2*m3*m4, ncoa, 1._dp, sphi_a, msphia, sabcd, m1, &
! Step 2: cppc = cppp (viewed as nsgfa*m2*m3 rows) * sphi_d, giving nsgfd columns.
204 CALL dgemm(
"N",
"N", nsgfa*m2*m3, nsgfd, ncod, 1._dp, cppp, nsgfa*m2*m3, &
205 sphi_d, msphid, 0.0_dp, cppc, nsgfa*m2*m3)
! Step 3, for one value of isgfd (NOTE(review): presumably inside a loop over isgfd,
! loop header not visible): copy the slice into a contiguous buffer, multiply it
! (viewed as nsgfa*m2 rows) by sphi_c, and store the nsgfc-column result in cpcc.
208 work_cppc(:, :, :) = cppc(:, :, :, isgfd)
209 CALL dgemm(
"N",
"N", nsgfa*m2, nsgfc, ncoc, 1._dp, work_cppc, nsgfa*m2, &
210 sphi_c, msphic, 0.0_dp, temp_cpcc, nsgfa*m2)
211 cpcc(:, :, :, isgfd) = temp_cpcc(:, :, :)
! Step 4, for one (isgfc, isgfd) pair (NOTE(review): presumably inside loops over both
! indices, loop headers not visible): multiply the 2D slice by sphi_b and store the
! nsgfa x nsgfb result into the matching slice of abcdint.
213 work_cpcc(:, :) = cpcc(:, :, isgfc, isgfd)
214 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1._dp, work_cpcc, nsgfa, sphi_b, &
215 msphib, 0.0_dp, temp_cccc, nsgfa)
216 abcdint(:, :, isgfc, isgfd) = temp_cccc(:, :)
! Release all intermediates and scratch buffers.
220 DEALLOCATE (cpcc, cppc, cppp)
221 DEALLOCATE (work_cpcc, work_cppc, temp_cpcc, temp_cccc)
223 CALL timestop(handle)
250 nsgfa, nsgfb, nsgfc, cpp_buffer, ccp_buffer, prefac, pstfac)
! Routine 'abc_contract_xsmm' (name taken from routinen below; the first line of the
! SUBROUTINE statement is not visible in this excerpt).
! Purpose: transforms slices of the flat array sabc with sphi_a and sphi_b into
! ccp_buffer, using dispatched small-matrix kernels when available and dgemm otherwise,
! then applies sphi_c in one final dgemm:
!   abcint = alpha * ccp_buffer * sphi_c + beta * abcint.
! Arguments as visible here:
!   abcint                 rank-3 array updated by the final dgemm
!   sabc                   assumed-size input, addressed in slices of ncoa*ncob elements
!   sphi_a, sphi_b, sphi_c contiguous 2D transformation matrices. Unlike a "T" call,
!                          sphi_a is used with "N" and leading dimension nsgfa below,
!                          so it is expected as an nsgfa x ncoa matrix here
!   ncoa, ncob, ncoc       contraction lengths of the three steps
!   nsgfa, nsgfb, nsgfc    result extents of the three steps
!   cpp_buffer, ccp_buffer caller-owned work buffers, (re)allocated here when missing
!                          or too small and otherwise reused
!   prefac, pstfac         optional values for alpha and beta of the final dgemm
! NOTE(review): this excerpt has gaps; several declarations (handle, i, rc1, rc2),
! the default values of alpha/beta, INTERFACE/END lines, #else/#endif lines, loop
! headers, ELSE/END IF lines and the END SUBROUTINE are not visible here.
252 REAL(kind=
dp),
DIMENSION(:, :, :) :: abcint
253 REAL(kind=
dp),
DIMENSION(*),
INTENT(IN),
TARGET :: sabc
254 REAL(kind=
dp),
DIMENSION(:, :),
CONTIGUOUS,
INTENT(IN),
TARGET :: sphi_a, sphi_b, sphi_c
255 INTEGER,
INTENT(IN) :: ncoa, ncob, ncoc, nsgfa, nsgfb, nsgfc
256 REAL(kind=
dp),
DIMENSION(:),
ALLOCATABLE,
TARGET :: cpp_buffer, ccp_buffer
257 REAL(kind=
dp),
INTENT(IN),
OPTIONAL :: prefac, pstfac
259 CHARACTER(LEN=*),
PARAMETER :: routinen =
'abc_contract_xsmm'
261 REAL(kind=
dp) :: alpha, beta
262 INTEGER(KIND=int_8) :: cpp_size, ccp_size
! One kernel configuration per multiplication step of the sphi_a/sphi_b transformation.
266 TYPE(libxs_gemm_config_t) :: cfg1, cfg2
! C-interoperable interface of mkl_cblas_jit_create_dgemm; its address is handed to
! libxs_gemm_dispatch via c_funloc below.
270 FUNCTION mkl_cblas_jit_create_dgemm(jitter, &
271 layout, transa, transb, m, n, k, &
272 alpha, lda, ldb, beta, ldc) &
273 result(status)
BIND(C)
274 use,
INTRINSIC :: iso_c_binding, only: c_ptr, c_int, c_double
275 INTEGER(C_INT) :: status
276 TYPE(c_ptr) :: jitter
277 INTEGER(C_INT),
VALUE :: layout, transa, transb
278 INTEGER(C_INT),
VALUE :: m, n, k, lda, ldb, ldc
279 REAL(c_double),
VALUE :: alpha, beta
! C-interoperable interface of mkl_jit_get_dgemm_ptr: takes a jitter handle by value
! and returns a C function pointer.
281 FUNCTION mkl_jit_get_dgemm_ptr(jitter)
RESULT(ptr)
BIND(C)
282 use,
INTRINSIC :: iso_c_binding, only: c_funptr, c_ptr
283 TYPE(c_funptr) :: ptr
284 TYPE(c_ptr),
INTENT(IN),
VALUE :: jitter
288#if defined(__LIBXSMM)
! Interface of libxsmm_dispatch_gemm, only declared when __LIBXSMM is defined.
290 FUNCTION libxsmm_dispatch_gemm(shape, flags, prefetch) &
292 use,
INTRINSIC :: iso_c_binding, only: c_funptr, c_int
293 INTEGER(C_INT),
INTENT(IN) :: shape(10)
294 INTEGER(C_INT),
INTENT(IN),
VALUE :: flags, prefetch
! Explicit interface binding the name dgemm_blas to the external symbol "dgemm_",
! so that its address can be passed with c_funloc.
300 SUBROUTINE dgemm_blas(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) &
301 BIND(C, NAME="dgemm_")
302 use,
INTRINSIC :: iso_c_binding, only: c_int, c_char, c_double
303 CHARACTER(1, C_CHAR),
INTENT(IN) :: transa, transb
304 INTEGER(C_INT),
INTENT(IN) :: m, n, k, lda, ldb, ldc
305 REAL(c_double),
INTENT(IN) :: alpha, beta, a(*), b(*)
306 REAL(c_double),
INTENT(INOUT) :: c(*)
311 CALL timeset(routinen, handle)
! Optional scaling factors override alpha/beta of the final dgemm; the defaults are
! assigned on lines not visible in this excerpt.
314 IF (
PRESENT(prefac)) alpha = prefac
317 IF (
PRESENT(pstfac)) beta = pstfac
! Same cost comparison as used to pick the multiplication order: the intermediate is
! nsgfa x ncob when sphi_a is applied first, and ncoa x nsgfb when sphi_b is applied
! first (NOTE(review): the ELSE between the two assignments is not visible here).
320 IF ((nsgfa*ncob*(ncoa + nsgfb)) <= (ncoa*nsgfb*(ncob + nsgfa)))
THEN
321 cpp_size = nsgfa*ncob
324 cpp_size = ncoa*nsgfb
! ccp_buffer holds the sphi_a/sphi_b-transformed data for all ncoc slices.
328 ccp_size = nsgfa*nsgfb*ncoc
! Grow-only buffer management: allocate when missing, reallocate when too small,
! otherwise keep the caller's existing allocation.
329 IF (.NOT.
ALLOCATED(ccp_buffer))
THEN
330 ALLOCATE (ccp_buffer(ccp_size))
331 ELSE IF (
SIZE(ccp_buffer) < ccp_size)
THEN
332 DEALLOCATE (ccp_buffer)
333 ALLOCATE (ccp_buffer(ccp_size))
336 IF (.NOT.
ALLOCATED(cpp_buffer))
THEN
337 ALLOCATE (cpp_buffer(cpp_size))
338 ELSE IF (
SIZE(cpp_buffer) < cpp_size)
THEN
339 DEALLOCATE (cpp_buffer)
340 ALLOCATE (cpp_buffer(cpp_size))
! Kernel setup for the "sphi_a first" order. The integer arguments mirror the dgemm
! fallback further below (nsgfa, ncob, ncoa with leading dimensions nsgfa, ncoa, nsgfa);
! NOTE(review): presumably m, n, k, lda, ldb, ldc, alpha, beta - confirm against the
! libxs_gemm_dispatch interface.
345 rc1 = libxs_gemm_dispatch(cfg1, libxs_datatype_f64,
'N',
'N', &
346 nsgfa, ncob, ncoa, nsgfa, ncoa, nsgfa, 1.0d0, 0.0d0 &
348 , jit_create_dgemm=c_funloc(mkl_cblas_jit_create_dgemm) &
349 , jit_get_dgemm=c_funloc(mkl_jit_get_dgemm_ptr) &
351#if defined(__LIBXSMM)
352 , xgemm_dispatch=c_funloc(libxsmm_dispatch_gemm) &
354 , dgemm_blas=c_funloc(dgemm_blas))
! Second kernel of the "sphi_a first" order: (nsgfa x ncob) * (ncob x nsgfb).
355 rc2 = libxs_gemm_dispatch(cfg2, libxs_datatype_f64,
'N',
'N', &
356 nsgfa, nsgfb, ncob, nsgfa, ncob, nsgfa, 1.0d0, 0.0d0 &
358 , jit_create_dgemm=c_funloc(mkl_cblas_jit_create_dgemm) &
359 , jit_get_dgemm=c_funloc(mkl_jit_get_dgemm_ptr) &
361#if defined(__LIBXSMM)
362 , xgemm_dispatch=c_funloc(libxsmm_dispatch_gemm) &
364 , dgemm_blas=c_funloc(dgemm_blas))
! Kernel setup for the "sphi_b first" order (NOTE(review): presumably the ELSE branch
! of an order test not visible here): (ncoa x ncob) * (ncob x nsgfb).
366 rc1 = libxs_gemm_dispatch(cfg1, libxs_datatype_f64,
'N',
'N', &
367 ncoa, nsgfb, ncob, ncoa, ncob, ncoa, 1.0d0, 0.0d0 &
369 , jit_create_dgemm=c_funloc(mkl_cblas_jit_create_dgemm) &
370 , jit_get_dgemm=c_funloc(mkl_jit_get_dgemm_ptr) &
372#if defined(__LIBXSMM)
373 , xgemm_dispatch=c_funloc(libxsmm_dispatch_gemm) &
375 , dgemm_blas=c_funloc(dgemm_blas))
! Second kernel of the "sphi_b first" order: (nsgfa x ncoa) * (ncoa x nsgfb).
376 rc2 = libxs_gemm_dispatch(cfg2, libxs_datatype_f64,
'N',
'N', &
377 nsgfa, nsgfb, ncoa, nsgfa, ncoa, nsgfa, 1.0d0, 0.0d0 &
379 , jit_create_dgemm=c_funloc(mkl_cblas_jit_create_dgemm) &
380 , jit_get_dgemm=c_funloc(mkl_jit_get_dgemm_ptr) &
382#if defined(__LIBXSMM)
383 , xgemm_dispatch=c_funloc(libxsmm_dispatch_gemm) &
385 , dgemm_blas=c_funloc(dgemm_blas))
! The dispatched kernels are used only when both results are nonzero.
! NOTE(review): confirm that a nonzero return of libxs_gemm_dispatch signals success.
387 IF (0 /= rc1 .AND. 0 /= rc2)
THEN
! Kernel path, "sphi_a first": cpp = sphi_a * sabc-slice, then ccp-slice = cpp * sphi_b.
! i addresses slices of ncoa*ncob input and nsgfa*nsgfb output elements with a zero
! offset for i = 0 (NOTE(review): loop header over i not visible in this excerpt).
390 CALL libxs_gemm_call(cfg1, c_loc(sphi_a(1, 1)), &
391 c_loc(sabc(i*ncoa*ncob + 1)), c_loc(cpp_buffer(1)))
392 CALL libxs_gemm_call(cfg2, c_loc(cpp_buffer(1)), &
393 c_loc(sphi_b(1, 1)), c_loc(ccp_buffer(i*nsgfa*nsgfb + 1)))
! Kernel path, "sphi_b first": cpp = sabc-slice * sphi_b, then ccp-slice = sphi_a * cpp.
397 CALL libxs_gemm_call(cfg1, c_loc(sabc(i*ncoa*ncob + 1)), &
398 c_loc(sphi_b(1, 1)), c_loc(cpp_buffer(1)))
399 CALL libxs_gemm_call(cfg2, c_loc(sphi_a(1, 1)), &
400 c_loc(cpp_buffer(1)), c_loc(ccp_buffer(i*nsgfa*nsgfb + 1)))
! dgemm fallback, "sphi_a first": same two products as the kernel path above.
! sphi_a is passed untransposed with leading dimension nsgfa.
407 CALL dgemm(
"N",
"N", nsgfa, ncob, ncoa, 1.0_dp, sphi_a, nsgfa, sabc(i*ncoa*ncob + 1), &
408 ncoa, 0.0_dp, cpp_buffer, nsgfa)
409 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1.0_dp, cpp_buffer, nsgfa, sphi_b, ncob, 0.0_dp, &
410 ccp_buffer(i*nsgfa*nsgfb + 1), nsgfa)
! dgemm fallback, "sphi_b first".
414 CALL dgemm(
"N",
"N", ncoa, nsgfb, ncob, 1.0_dp, sabc(i*ncoa*ncob + 1), ncoa, sphi_b, &
415 ncob, 0.0_dp, cpp_buffer, ncoa)
416 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncoa, 1.0_dp, sphi_a, nsgfa, cpp_buffer, ncoa, 0.0_dp, &
417 ccp_buffer(i*nsgfa*nsgfb + 1), nsgfa)
! Final step over all slices at once: ccp_buffer viewed as (nsgfa*nsgfb) x ncoc times
! sphi_c (ncoc x nsgfc), scaled by alpha and added to beta * abcint.
424 CALL dgemm(
"N",
"N", nsgfa*nsgfb, nsgfc, ncoc, alpha, ccp_buffer, nsgfa*nsgfb, &
425 sphi_c, ncoc, beta, abcint, nsgfa*nsgfb)
427 CALL timestop(handle)