! ab_contract: contract a two-index block with two coefficient matrices using two DGEMM calls:
!   abint(1:nsgfa, 1:nsgfb) = TRANSPOSE(sphi_a(1:ncoa, 1:nsgfa)) * sab(1:ncoa, 1:ncob) * sphi_b(1:ncob, 1:nsgfb)
! Arguments
!   abint  - result; overwritten (both final DGEMMs use beta = 0), leading dimension SIZE(abint, 1).
!   sab    - input block; only its leading ncoa x ncob part is used (leading dimension SIZE(sab, 1)).
!   sphi_a - used transposed ("T"); only its leading ncoa x nsgfa part is read.
!   sphi_b - only its leading ncob x nsgfb part is read.
!   ncoa, ncob, nsgfa, nsgfb - matrix extents as above.
! NOTE(review): sphi_a / sphi_b presumably hold Cartesian-to-spherical contraction coefficients
! (inferred from the names only) - TODO confirm.
! NOTE(review): this listing omits several source lines of the routine (e.g. the declaration of
! handle, the ELSE / END IF of the branch below and the lines that close the routine).
46 SUBROUTINE ab_contract(abint, sab, sphi_a, sphi_b, ncoa, ncob, nsgfa, nsgfb)
48 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(INOUT) :: abint
49 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(IN) :: sab, sphi_a, sphi_b
50 INTEGER,
INTENT(IN) :: ncoa, ncob, nsgfa, nsgfb
52 CHARACTER(LEN=*),
PARAMETER :: routinen =
'ab_contract'
! Intermediate product: nsgfa x ncob in the first branch, ncoa x nsgfb in the second.
55 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: cpp
57 CALL timeset(routinen, handle)
! sab must provide at least ncob columns.
59 cpassert(ncob <=
SIZE(sab, 2))
! Choose the cheaper association order of the triple product. Multiply-add counts:
!   (sphi_a^T * sab) * sphi_b : nsgfa*ncoa*ncob + nsgfa*ncob*nsgfb = nsgfa*ncob*(ncoa + nsgfb)
!   sphi_a^T * (sab * sphi_b) : ncoa*ncob*nsgfb + nsgfa*ncoa*nsgfb = nsgfb*ncoa*(ncob + nsgfa)
61 IF ((nsgfa*ncob*(ncoa + nsgfb)) <= (nsgfb*ncoa*(ncob + nsgfa)))
THEN
62 ALLOCATE (cpp(nsgfa, ncob))
! cpp(nsgfa, ncob) = sphi_a^T(nsgfa, ncoa) * sab(ncoa, ncob)
64 CALL dgemm(
"T",
"N", nsgfa, ncob, ncoa, 1._dp, sphi_a,
SIZE(sphi_a, 1), sab,
SIZE(sab, 1), 0.0_dp, cpp, nsgfa)
! abint(nsgfa, nsgfb) = cpp(nsgfa, ncob) * sphi_b(ncob, nsgfb)
66 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1._dp, cpp, nsgfa, sphi_b,
SIZE(sphi_b, 1), 0.0_dp, &
67 abint,
SIZE(abint, 1))
! Other association order (the ELSE line of the branch is not visible in this listing).
69 ALLOCATE (cpp(ncoa, nsgfb))
! cpp(ncoa, nsgfb) = sab(ncoa, ncob) * sphi_b(ncob, nsgfb)
71 CALL dgemm(
"N",
"N", ncoa, nsgfb, ncob, 1._dp, sab,
SIZE(sab, 1), sphi_b,
SIZE(sphi_b, 1), 0.0_dp, cpp, ncoa)
! abint(nsgfa, nsgfb) = sphi_a^T(nsgfa, ncoa) * cpp(ncoa, nsgfb)
73 CALL dgemm(
"T",
"N", nsgfa, nsgfb, ncoa, 1._dp, sphi_a,
SIZE(sphi_a, 1), cpp, ncoa, 0.0_dp, &
74 abint,
SIZE(abint, 1))
! abc_contract: contract a three-index block sabc over its indices a, b and c with sphi_a
! (transposed), sphi_b and sphi_c, using DGEMM, and return the result in abcint.
! NOTE(review): this listing omits lines of the routine: the continuation of the argument list
! after "ncoc, &", the continuation of the INTEGER declaration, the assignments of m1, m2, m3
! and mx, the loop header around the second DGEMM, and the output arguments of the last DGEMM.
! m1..m3 are presumably SIZE(sabc, 1..3) - TODO confirm against the full source.
! NOTE(review): abcint and sabc carry no INTENT attribute.
98 SUBROUTINE abc_contract(abcint, sabc, sphi_a, sphi_b, sphi_c, ncoa, ncob, ncoc, &
101 REAL(kind=
dp),
DIMENSION(:, :, :) :: abcint, sabc
102 REAL(kind=
dp),
DIMENSION(:, :) :: sphi_a, sphi_b, sphi_c
103 INTEGER,
INTENT(IN) :: ncoa, ncob, ncoc, nsgfa, nsgfb, nsgfc
105 CHARACTER(LEN=*),
PARAMETER :: routinen =
'abc_contract'
107 INTEGER :: handle, i, m1, m2, m3, msphia, msphib, &
109 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :) :: tmp
111 CALL timeset(routinen, handle)
! abcint's first two extents must equal nsgfa and nsgfb.
113 cpassert(
SIZE(abcint, 1) == nsgfa)
114 cpassert(
SIZE(abcint, 2) == nsgfb)
! Leading dimensions of the coefficient matrices.
116 msphia =
SIZE(sphi_a, 1)
117 msphib =
SIZE(sphi_b, 1)
118 msphic =
SIZE(sphi_c, 1)
! tmp has one spare slice: the first DGEMM writes from slice 2 onward, and each step of the
! b-contraction reads slice i + 1 and writes slice i.
126 ALLOCATE (tmp(nsgfa, mx, m3 + 1))
! Contract index a: sphi_a^T(nsgfa, ncoa) * sabc viewed as an (ncoa, m2*m3) matrix with leading
! dimension m1; the nsgfa x (m2*m3) result is stored starting at tmp(1, 1, 2).
! NOTE(review): the result spans m2*m3 columns, but the actual argument is the single slice
! tmp(:, :, 2). This relies on the (contiguous) section being passed by address without a copy;
! passing the element tmp(1, 1, 2) would make the sequence association standard-conforming.
! NOTE(review): if m2 = SIZE(sabc, 2), the data for third index i starts nsgfa*m2*(i-1) elements
! after tmp(1, 1, 2), while slice tmp(:, :, i + 1) starts nsgfa*mx*(i-1) elements after it.
! These coincide only when mx == m2 - confirm that mx can never exceed m2 (e.g. nsgfb > m2).
128 CALL dgemm(
"T",
"N", nsgfa, m2*m3, ncoa, 1._dp, sphi_a, msphia, sabc, m1, 0.0_dp, tmp(:, :, 2), nsgfa)
! Contract index b (presumably inside a loop over i; the loop header is not visible):
!   tmp(:, :, i)(nsgfa, nsgfb) = tmp(:, :, i + 1)(nsgfa, ncob) * sphi_b(ncob, nsgfb)
! For an ascending loop, slice i is written only after the previous step has read it,
! so the data shifts down one slice in place.
130 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1._dp, tmp(:, :, i + 1), nsgfa, sphi_b, msphib, &
131 0.0_dp, tmp(:, :, i), nsgfa)
! Contract index c: tmp viewed as an (nsgfa*mx) x (m3 + 1) matrix (leading dimension nsgfa*mx).
! Its first nsgfa*nsgfb rows and first ncoc columns are multiplied by sphi_c(ncoc, nsgfc).
! The output arguments are on a line not visible in this listing.
133 CALL dgemm(
"N",
"N", nsgfa*nsgfb, nsgfc, ncoc, 1._dp, tmp, nsgfa*mx, sphi_c, msphic, 0.0_dp, &
138 CALL timestop(handle)
! abcd_contract: contract a four-index block sabcd with sphi_a (transposed), then sphi_d, sphi_c
! and sphi_b, and store the result in abcdint.
! Intermediates (the first dimension is always nsgfa):
!   cppp(nsgfa, m2, m3, m4)       - after contracting index a
!   cppc(nsgfa, m2, m3, nsgfd)    - after contracting index d
!   cpcc(nsgfa, m2, nsgfc, nsgfd) - after contracting index c
! NOTE(review): this listing omits lines of the routine, including:
!   - the continuation of the INTEGER argument declaration;
!   - the assignments of m1..m4 (presumably SIZE(sabcd, 1..4) - TODO confirm);
!   - the tail of the first DGEMM call;
!   - the DO / END DO lines of the loops over isgfd and isgfc.
160 SUBROUTINE abcd_contract(abcdint, sabcd, sphi_a, sphi_b, sphi_c, sphi_d, ncoa, ncob, &
161 ncoc, ncod, nsgfa, nsgfb, nsgfc, nsgfd)
163 REAL(kind=
dp),
DIMENSION(:, :, :, :), &
164 INTENT(INOUT) :: abcdint
165 REAL(kind=
dp),
DIMENSION(:, :, :, :),
INTENT(IN) :: sabcd
166 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(IN) :: sphi_a, sphi_b, sphi_c, sphi_d
167 INTEGER,
INTENT(IN) :: ncoa, ncob, ncoc, ncod, nsgfa, nsgfb, &
170 CHARACTER(LEN=*),
PARAMETER :: routinen =
'abcd_contract'
172 INTEGER :: handle, isgfc, isgfd, m1, m2, m3, m4, &
173 msphia, msphib, msphic, msphid
174 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: temp_cccc, work_cpcc
175 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :) :: temp_cpcc, work_cppc
176 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :, :) :: cpcc, cppc, cppp
178 CALL timeset(routinen, handle)
! Leading dimensions of the coefficient matrices.
180 msphia =
SIZE(sphi_a, 1)
181 msphib =
SIZE(sphi_b, 1)
182 msphic =
SIZE(sphi_c, 1)
183 msphid =
SIZE(sphi_d, 1)
190 ALLOCATE (cppp(nsgfa, m2, m3, m4), cppc(nsgfa, m2, m3, nsgfd), &
191 cpcc(nsgfa, m2, nsgfc, nsgfd))
! Whole-array work buffers for one isgfd slice (and one (isgfc, isgfd) block). Slices are copied
! in and out so that each DGEMM below receives whole arrays rather than array sections.
! NOTE(review): the intent behind the copies is inferred - TODO confirm.
193 ALLOCATE (work_cppc(nsgfa, m2, m3), temp_cpcc(nsgfa, m2, nsgfc))
197 ALLOCATE (work_cpcc(nsgfa, m2), temp_cccc(nsgfa, nsgfb))
! Contract index a: cppp = sphi_a^T(nsgfa, ncoa) * sabcd viewed as an (ncoa, m2*m3*m4) matrix
! with leading dimension m1 (the remaining arguments are on a line not visible here).
! Then contract index d:
!   cppc(nsgfa*m2*m3, nsgfd) = cppp(nsgfa*m2*m3, ncod) * sphi_d(ncod, nsgfd)
201 CALL dgemm(
"T",
"N", nsgfa, m2*m3*m4, ncoa, 1._dp, sphi_a, msphia, sabcd, m1, &
203 CALL dgemm(
"N",
"N", nsgfa*m2*m3, nsgfd, ncod, 1._dp, cppp, nsgfa*m2*m3, &
204 sphi_d, msphid, 0.0_dp, cppc, nsgfa*m2*m3)
! For each isgfd (loop header not visible), contract index c:
!   cpcc(:, :, :, isgfd)(nsgfa*m2, nsgfc) = cppc(:, :, :, isgfd)(nsgfa*m2, ncoc) * sphi_c(ncoc, nsgfc)
207 work_cppc(:, :, :) = cppc(:, :, :, isgfd)
208 CALL dgemm(
"N",
"N", nsgfa*m2, nsgfc, ncoc, 1._dp, work_cppc, nsgfa*m2, &
209 sphi_c, msphic, 0.0_dp, temp_cpcc, nsgfa*m2)
210 cpcc(:, :, :, isgfd) = temp_cpcc(:, :, :)
! For each (isgfc, isgfd) pair (loop headers not visible), contract index b:
!   abcdint(:, :, isgfc, isgfd)(nsgfa, nsgfb) = cpcc(:, :, isgfc, isgfd)(nsgfa, ncob) * sphi_b(ncob, nsgfb)
! The whole-section assignment requires abcdint's first two extents to equal nsgfa and nsgfb.
212 work_cpcc(:, :) = cpcc(:, :, isgfc, isgfd)
213 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1._dp, work_cpcc, nsgfa, sphi_b, &
214 msphib, 0.0_dp, temp_cccc, nsgfa)
215 abcdint(:, :, isgfc, isgfd) = temp_cccc(:, :)
219 DEALLOCATE (cpcc, cppc, cppp)
220 DEALLOCATE (work_cpcc, work_cppc, temp_cpcc, temp_cccc)
222 CALL timestop(handle)
249 nsgfa, nsgfb, nsgfc, cpp_buffer, ccp_buffer, prefac, pstfac)
! abc_contract_xsmm. The start of the SUBROUTINE statement is not visible in this listing; the
! line above ends its argument list.
! Contracts a three-index block sabc, read as consecutive packed ncoa x ncob slices, with
! sphi_a, sphi_b and sphi_c:
!   for each slice i: ccp_i(nsgfa, nsgfb) = sphi_a(nsgfa, ncoa) * sabc_i(ncoa, ncob) * sphi_b(ncob, nsgfb)
!   then:             abcint = alpha * ccp(nsgfa*nsgfb, ncoc) * sphi_c(ncoc, nsgfc) + beta * abcint
! Layout requirements:
!   - sphi_a is used untransposed ('N') with leading dimension nsgfa;
!   - sphi_b and sphi_c are used with leading dimensions ncob and ncoc;
!   - so all three must be passed packed in exactly those layouts.
! cpp_buffer / ccp_buffer are caller-owned scratch arrays. They are (re)allocated only when
! absent or too small, so they can be reused across calls.
! The optional prefac / pstfac override alpha / beta. The default values of alpha and beta are
! set on lines not visible here (presumably 1 and 0 - TODO confirm).
! NOTE(review): the IF / ELSE / DO / END lines that structure the body are missing from this
! listing, so the branch and loop nesting below is partly inferred.
251 REAL(kind=
dp),
DIMENSION(:, :, :) :: abcint
252 REAL(kind=
dp),
DIMENSION(*),
INTENT(IN),
TARGET :: sabc
253 REAL(kind=
dp),
DIMENSION(:, :),
CONTIGUOUS,
INTENT(IN),
TARGET :: sphi_a, sphi_b, sphi_c
254 INTEGER,
INTENT(IN) :: ncoa, ncob, ncoc, nsgfa, nsgfb, nsgfc
255 REAL(kind=
dp),
DIMENSION(:),
ALLOCATABLE,
TARGET :: cpp_buffer, ccp_buffer
256 REAL(kind=
dp),
INTENT(IN),
OPTIONAL :: prefac, pstfac
258 CHARACTER(LEN=*),
PARAMETER :: routinen =
'abc_contract_xsmm'
260 REAL(kind=
dp) :: alpha, beta
! NOTE(review): these sizes are INTEGER(int_8), but the products assigned to them below are
! evaluated in default INTEGER before conversion, so the wider kind does not guard that product.
261 INTEGER(KIND=int_8) :: cpp_size, ccp_size
265 TYPE(libxs_gemm_config_t) :: cfg1, cfg2
269 CALL timeset(routinen, handle)
272 IF (
PRESENT(prefac)) alpha = prefac
275 IF (
PRESENT(pstfac)) beta = pstfac
! Choose the cheaper association order of the per-slice triple product. Multiply-add counts:
!   (sphi_a * sabc_i) * sphi_b : nsgfa*ncob*(ncoa + nsgfb)
!   sphi_a * (sabc_i * sphi_b) : ncoa*nsgfb*(ncob + nsgfa)
! The intermediate cpp is nsgfa x ncob in the first order and ncoa x nsgfb in the second.
278 IF ((nsgfa*ncob*(ncoa + nsgfb)) <= (ncoa*nsgfb*(ncob + nsgfa)))
THEN
279 cpp_size = nsgfa*ncob
282 cpp_size = ncoa*nsgfb
! ccp_buffer holds one nsgfa x nsgfb block per slice. An existing buffer is kept if it is large
! enough, and replaced only if it is smaller than required.
286 ccp_size = nsgfa*nsgfb*ncoc
287 IF (.NOT.
ALLOCATED(ccp_buffer))
THEN
288 ALLOCATE (ccp_buffer(ccp_size))
289 ELSE IF (
SIZE(ccp_buffer) < ccp_size)
THEN
290 DEALLOCATE (ccp_buffer)
291 ALLOCATE (ccp_buffer(ccp_size))
294 IF (.NOT.
ALLOCATED(cpp_buffer))
THEN
295 ALLOCATE (cpp_buffer(cpp_size))
296 ELSE IF (
SIZE(cpp_buffer) < cpp_size)
THEN
297 DEALLOCATE (cpp_buffer)
298 ALLOCATE (cpp_buffer(cpp_size))
! Dispatch libxs kernels for the two per-slice products of the chosen order (first pair: the
! (sphi_a * sabc_i) * sphi_b order; second pair: the other order).
! The trailing arguments appear to be m, n, k, lda, ldb, ldc, alpha = 1, beta = 0, since they
! match the DGEMM fallback below - TODO confirm against the libxs API.
303 rc1 = libxs_gemm_dispatch(cfg1, libxs_datatype_f64,
'N',
'N', &
304 nsgfa, ncob, ncoa, nsgfa, ncoa, nsgfa, 1.0d0, 0.0d0)
305 rc2 = libxs_gemm_dispatch(cfg2, libxs_datatype_f64,
'N',
'N', &
306 nsgfa, nsgfb, ncob, nsgfa, ncob, nsgfa, 1.0d0, 0.0d0)
308 rc1 = libxs_gemm_dispatch(cfg1, libxs_datatype_f64,
'N',
'N', &
309 ncoa, nsgfb, ncob, ncoa, ncob, ncoa, 1.0d0, 0.0d0)
310 rc2 = libxs_gemm_dispatch(cfg2, libxs_datatype_f64,
'N',
'N', &
311 nsgfa, nsgfb, ncoa, nsgfa, ncoa, nsgfa, 1.0d0, 0.0d0)
! The libxs kernels are used only when both rc1 and rc2 are nonzero. Otherwise the DGEMM calls
! below (presumably in the ELSE branch, not visible) perform both products.
! NOTE(review): confirm that a nonzero return from libxs_gemm_dispatch means success. If it
! follows the common C convention of returning 0 on success, this test is inverted.
313 IF (0 /= rc1 .AND. 0 /= rc2)
THEN
! Per slice i (loop header not visible): the offsets i*ncoa*ncob + 1 and i*nsgfa*nsgfb + 1
! imply that i counts from 0.
!   cpp = sphi_a * sabc_i, then ccp_i = cpp * sphi_b.
316 CALL libxs_gemm_call(cfg1, c_loc(sphi_a(1, 1)), &
317 c_loc(sabc(i*ncoa*ncob + 1)), c_loc(cpp_buffer(1)))
318 CALL libxs_gemm_call(cfg2, c_loc(cpp_buffer(1)), &
319 c_loc(sphi_b(1, 1)), c_loc(ccp_buffer(i*nsgfa*nsgfb + 1)))
! Other order: cpp = sabc_i * sphi_b, then ccp_i = sphi_a * cpp.
323 CALL libxs_gemm_call(cfg1, c_loc(sabc(i*ncoa*ncob + 1)), &
324 c_loc(sphi_b(1, 1)), c_loc(cpp_buffer(1)))
325 CALL libxs_gemm_call(cfg2, c_loc(sphi_a(1, 1)), &
326 c_loc(cpp_buffer(1)), c_loc(ccp_buffer(i*nsgfa*nsgfb + 1)))
! DGEMM fallback with the same two orders and the same packed leading dimensions.
333 CALL dgemm(
"N",
"N", nsgfa, ncob, ncoa, 1.0_dp, sphi_a, nsgfa, sabc(i*ncoa*ncob + 1), &
334 ncoa, 0.0_dp, cpp_buffer, nsgfa)
335 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncob, 1.0_dp, cpp_buffer, nsgfa, sphi_b, ncob, 0.0_dp, &
336 ccp_buffer(i*nsgfa*nsgfb + 1), nsgfa)
340 CALL dgemm(
"N",
"N", ncoa, nsgfb, ncob, 1.0_dp, sabc(i*ncoa*ncob + 1), ncoa, sphi_b, &
341 ncob, 0.0_dp, cpp_buffer, ncoa)
342 CALL dgemm(
"N",
"N", nsgfa, nsgfb, ncoa, 1.0_dp, sphi_a, nsgfa, cpp_buffer, ncoa, 0.0_dp, &
343 ccp_buffer(i*nsgfa*nsgfb + 1), nsgfa)
! Contract index c: abcint = alpha * ccp(nsgfa*nsgfb, ncoc) * sphi_c(ncoc, nsgfc) + beta * abcint.
! abcint is read as well as written when beta /= 0. It is passed with leading dimension
! nsgfa*nsgfb, i.e. its first two extents are assumed to be nsgfa and nsgfb.
! No check of that assumption is visible in this listing.
350 CALL dgemm(
"N",
"N", nsgfa*nsgfb, nsgfc, ncoc, alpha, ccp_buffer, nsgfa*nsgfb, &
351 sphi_c, ncoc, beta, abcint, nsgfa*nsgfb)
353 CALL timestop(handle)
/* C helper named dgemm. Every scalar argument (transa, transb, m, n, k, alpha, lda, ldb, beta,
 * ldc) is taken by value; only the matrices a, b and c are passed as pointers.
 * NOTE(review): only the signature is visible here. How it forwards to the Fortran dgemm_
 * (which presumably expects scalars by reference) is not shown - confirm in the body. */
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenience wrapper that hides the Fortran nature of dgemm_ by swapping the a and b operands.