53#include "./base/base_uses.f90"
! Name of this module as a character constant.
59 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'skala_gpw_functional'
! Number of values stored per grid point in the flat gradient buffers:
! 2 density + 6 density-gradient + 2 kinetic entries (layout in pack_atom_chunk_grads).
60 INTEGER,
PARAMETER,
PRIVATE :: ngrad_per_point = 10
! Cached model and the bookkeeping that ensure_model_loaded compares and updates
! (path, loaded flag, CUDA device).
64 TYPE(skala_torch_model_type),
SAVE :: cached_model
65 CHARACTER(len=default_path_length),
SAVE :: cached_model_path =
""
66 LOGICAL,
SAVE :: cached_model_loaded = .false.
! NOTE(review): -3 looks like a "never set" sentinel for the device ids below - confirm.
67 INTEGER,
SAVE :: cached_model_cuda_device = -3
! Values of the last device report written by configure_native_grid_cuda;
! when all four are unchanged the report is not repeated.
68 INTEGER,
SAVE :: logged_cuda_device = -3, &
69 logged_cuda_device_count = -1, &
70 logged_cuda_nproc = -1, &
71 logged_cuda_request = -3
82 LOGICAL :: uses_native_grid
86 uses_native_grid = .false.
87 gauxc_section => get_gauxc_section(xc_section)
88 IF (
ASSOCIATED(gauxc_section))
THEN
101 CHARACTER(len=default_path_length) :: model_key, model_name
102 CHARACTER(len=default_string_length) :: functional_key, functional_name
103 INTEGER :: ifun, nfun
104 LOGICAL :: native_grid
107 NULLIFY (gauxc_section)
108 IF (.NOT.
ASSOCIATED(xc_section))
THEN
109 cpabort(
"Native SKALA GPW requires an XC section")
113 IF (.NOT.
ASSOCIATED(functionals))
THEN
114 cpabort(
"Native SKALA GPW requires an XC_FUNCTIONAL section")
122 IF (.NOT.
ASSOCIATED(xc_fun))
EXIT
124 IF (xc_fun%section%name ==
"GAUXC") gauxc_section => xc_fun
127 IF (.NOT.
ASSOCIATED(gauxc_section))
THEN
128 cpabort(
"Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
131 cpabort(
"Native SKALA GPW requires GAUXC to be the only XC functional")
135 IF (.NOT. native_grid)
RETURN
138 functional_key = adjustl(functional_name)
140 IF (trim(functional_key) /=
"PBE")
THEN
141 cpabort(
"Native SKALA GPW currently requires GAUXC%FUNCTIONAL PBE")
145 model_key = adjustl(model_name)
147 IF (trim(model_key) ==
"NONE" .OR. trim(model_key) ==
"")
THEN
148 cpabort(
"Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
!> \brief Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
!> \param exc energy output
!> \param compute_virial must be false; a true value aborts (stress/virial not implemented)
!> \param virial_xc 3x3 output; where it is assigned is not visible in this listing
!> \param just_energy optional; treated as .false. when absent
!> \param atom_force optional; when present it is zeroed on entry and later passed to
!>        add_explicit_coordinate_force
!> \note Aborts when rho_g or tau is not associated, and when atom chunks are requested
!>       together with atom_force.
!> \note This listing omits a number of original source lines (see the gaps in the embedded
!>       line numbers); only statements that are visible here are described.
171 SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
172 weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
173 just_energy, atom_force)
175 REAL(kind=
dp),
INTENT(OUT) :: exc
184 LOGICAL,
INTENT(IN) :: compute_virial
185 REAL(kind=
dp),
DIMENSION(3, 3),
INTENT(OUT) :: virial_xc
186 LOGICAL,
INTENT(IN),
OPTIONAL :: just_energy
187 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(OUT), &
188 OPTIONAL :: atom_force
190 CHARACTER(len=default_path_length) :: model_path
191 INTEGER :: native_grid_cuda_device, nspins, &
192 phase_handle, selected_cuda_device, &
193 xc_deriv_method_id, xc_rho_smooth_id
194 LOGICAL :: lsd, my_just_energy, native_grid_atom_chunk_routing, native_grid_atom_chunks, &
195 native_grid_diagnostics, native_grid_use_cuda, needs_atom_force
196 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: density_grad, kin_grad
197 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :) :: grad_grad
! just_energy defaults to .false. when the argument is absent.
206 my_just_energy = .false.
207 IF (
PRESENT(just_energy)) my_just_energy = just_energy
! Forces are requested simply by passing atom_force; the array is zeroed before use.
208 needs_atom_force =
PRESENT(atom_force)
209 IF (needs_atom_force) atom_force = 0.0_dp
! Unsupported or incomplete inputs abort here.
211 IF (compute_virial)
THEN
212 CALL cp_abort(__location__, &
213 "Native SKALA GPW stress/virial is not implemented yet.")
215 IF (.NOT.
ASSOCIATED(rho_g))
THEN
216 CALL cp_abort(__location__, &
217 "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
219 IF (.NOT.
ASSOCIATED(tau))
THEN
220 CALL cp_abort(__location__, &
221 "Native SKALA GPW requires the kinetic-energy density.")
226 CALL get_skala_model_path(xc_section, model_path)
227 gauxc_section => get_gauxc_section(xc_section)
230 i_val=native_grid_cuda_device)
232 l_val=native_grid_atom_chunks)
234 l_val=native_grid_atom_chunk_routing)
! After these two statements both flags hold the OR of the two input values,
! i.e. enabling either option enables both.
! NOTE(review): this makes the two flags indistinguishable downstream - confirm this is intended.
235 native_grid_atom_chunk_routing = native_grid_atom_chunk_routing .OR. native_grid_atom_chunks
236 native_grid_atom_chunks = native_grid_atom_chunks .OR. native_grid_atom_chunk_routing
! Atom chunks cannot be combined with atom forces.
237 IF (native_grid_atom_chunks .AND. needs_atom_force)
THEN
238 CALL cp_abort(__location__, &
239 "Native SKALA GPW atom chunks are not implemented for atom forces yet.")
! Choose the CUDA device for this rank, then make sure the model for that path/device is loaded.
243 selected_cuda_device = configure_native_grid_cuda( &
244 native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
245 CALL ensure_model_loaded(model_path, selected_cuda_device)
! Flag the rho_spin, drho_spin and tau_spin components as needed.
248 needs%rho_spin = .true.
249 needs%drho_spin = .true.
250 needs%tau_spin = .true.
261 rho_r(1)%pw_grid%bounds_local, &
266 xc_deriv_method_id, xc_rho_smooth_id, pw_pool)
269 requires_grad=(.NOT. my_just_energy), weights=weights, &
270 requires_coordinate_grad=needs_atom_force, &
271 use_atom_chunks=native_grid_atom_chunks, &
272 route_atom_chunks=native_grid_atom_chunk_routing)
! Optional diagnostics; the second argument is true only where mepos == 0.
273 CALL section_vals_val_get(gauxc_section,
"NATIVE_GRID_DIAGNOSTICS", l_val=native_grid_diagnostics)
274 IF (native_grid_diagnostics)
THEN
275 CALL print_native_grid_diagnostics(features, rho_r(1)%pw_grid%para%group%mepos == 0)
278 features%grid_weights_t, exc_tensor, exc)
! With atom chunks, exc is combined over the grid's group via group%sum
! (presumably a sum reduction across ranks - confirm).
279 IF (features%uses_atom_chunks)
CALL rho_r(1)%pw_grid%para%group%sum(exc)
! The backward/gradient phases start here and are entered only when my_just_energy is false;
! the matching END IF is not visible in this listing.
281 IF (.NOT. my_just_energy)
THEN
282 CALL timeset(
"skala_gpw_backward", phase_handle)
284 CALL timestop(phase_handle)
286 CALL timeset(
"skala_gpw_grad_fetch", phase_handle)
! Gradients come from the atom-chunk gather or from the local feature tensors
! (presumably the two branches of an IF/ELSE; the ELSE line is not visible here).
287 IF (features%uses_atom_chunks)
THEN
288 CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
289 density_grad, grad_grad, kin_grad)
291 CALL fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
293 IF (needs_atom_force)
THEN
294 CALL add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, &
295 rho_r(1)%pw_grid%para%group%mepos == 0)
297 CALL timestop(phase_handle)
299 CALL timeset(
"skala_gpw_vxc_unpack", phase_handle)
300 CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
301 density_grad, grad_grad, kin_grad, &
303 CALL timestop(phase_handle)
305 CALL timeset(
"skala_gpw_grad_release", phase_handle)
306 DEALLOCATE (density_grad, grad_grad, kin_grad)
308 CALL timestop(phase_handle)
311 CALL timeset(
"skala_gpw_cleanup", phase_handle)
316 CALL timestop(phase_handle)
!> \brief Add the coordinate gradient array to atom_force, element by element.
!> \param atom_force updated in place; both extents are asserted to equal those of atom_coord_grad
!> \param features declaration not visible in this listing
!> \param atom_coord_grad_t declaration not visible in this listing
!> \param root_rank logical input; how it is used is not visible in this listing
!> \note The statements that associate atom_coord_grad are not part of this listing.
327 SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
328 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(INOUT) :: atom_force
331 LOGICAL,
INTENT(IN) :: root_rank
333 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: atom_coord_grad
335 NULLIFY (atom_coord_grad)
! Shapes must agree before the whole-array update below.
339 cpassert(
SIZE(atom_force, 1) ==
SIZE(atom_coord_grad, 1))
340 cpassert(
SIZE(atom_force, 2) ==
SIZE(atom_coord_grad, 2))
! The gradient is added as is; no sign change or scaling appears in this statement.
341 atom_force(:, :) = atom_force(:, :) + atom_coord_grad(:, :)
344 END SUBROUTINE add_explicit_coordinate_force
!> \brief Copy the rows selected by features%feature_index out of the full gradient views
!>        into freshly allocated local arrays.
!> \param features provides nflat, nflat_local and the feature_index lookup
!> \param density_grad allocated here as (nflat_local, 2)
!> \param grad_grad allocated here as (nflat_local, 3, 2)
!> \param kin_grad allocated here as (nflat_local, 2)
!> \note Some original lines (loop terminators, tensor handling) are not part of this listing.
353 SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
355 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :), &
356 INTENT(OUT) :: density_grad
357 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :), &
358 INTENT(OUT) :: grad_grad
359 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :), &
360 INTENT(OUT) :: kin_grad
362 INTEGER :: i, j, k, local_row, row
363 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: density_grad_all, kin_grad_all
364 REAL(kind=
dp),
DIMENSION(:, :, :),
POINTER :: grad_grad_all
367 NULLIFY (density_grad_all, grad_grad_all, kin_grad_all)
368 CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
369 density_grad_all, grad_grad_all, kin_grad_all)
! The views must cover all nflat rows with the fixed trailing extents 2 / (3, 2) / 2.
370 cpassert(
SIZE(density_grad_all, 1) == features%nflat)
371 cpassert(
SIZE(density_grad_all, 2) == 2)
372 cpassert(
SIZE(grad_grad_all, 1) == features%nflat)
373 cpassert(
SIZE(grad_grad_all, 2) == 3)
374 cpassert(
SIZE(grad_grad_all, 3) == 2)
375 cpassert(
SIZE(kin_grad_all, 1) == features%nflat)
376 cpassert(
SIZE(kin_grad_all, 2) == 2)
378 ALLOCATE (density_grad(features%nflat_local, 2), &
379 grad_grad(features%nflat_local, 3, 2), &
380 kin_grad(features%nflat_local, 2))
! Walk feature_index with i innermost; local_row counts visited points and row is the
! matching row in the full views.
! NOTE(review): local_row must be zero before this loop; its initialisation is not visible
! in this listing - verify.
382 DO k = lbound(features%feature_index, 3), ubound(features%feature_index, 3)
383 DO j = lbound(features%feature_index, 2), ubound(features%feature_index, 2)
384 DO i = lbound(features%feature_index, 1), ubound(features%feature_index, 1)
385 local_row = local_row + 1
386 row = features%feature_index(i, j, k)
387 cpassert(row >= 1 .AND. row <= features%nflat)
388 density_grad(local_row, :) = density_grad_all(row, :)
389 grad_grad(local_row, :, :) = grad_grad_all(row, :, :)
390 kin_grad(local_row, :) = kin_grad_all(row, :)
! Every local point must have been visited exactly once.
394 cpassert(local_row == features%nflat_local)
400 END SUBROUTINE fetch_local_feature_grads
!> \brief Pack the chunk gradient views into a flat buffer with ngrad_per_point values per point.
!> \param features provides chunk_feature_count and chunk_return_positions
!> \param TARGET flat buffer; its size is asserted to be ngrad_per_point*chunk_feature_count
!> \param route_to_return_positions when true, row irow is written to the slot given by
!>        features%chunk_return_positions(irow)
!> \note Per-point layout: 1-2 density_grad(:, 1:2), 3-5 grad_grad(:, 1:3, 1),
!>       6-8 grad_grad(:, 1:3, 2), 9-10 kin_grad(:, 1:2).
!> \note The branch taken when route_to_return_positions is false is not visible in this
!>       listing (presumably point_pos = irow - confirm).
408 SUBROUTINE pack_atom_chunk_grads(features, TARGET, route_to_return_positions)
410 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:), &
411 INTENT(INOUT) ::
target
412 LOGICAL,
INTENT(IN) :: route_to_return_positions
414 INTEGER :: base, irow, point_pos
415 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: chunk_density_grad, chunk_kin_grad
416 REAL(kind=
dp),
DIMENSION(:, :, :),
POINTER :: chunk_grad_grad
419 NULLIFY (chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
420 CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
421 chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
! Buffer and views must all describe chunk_feature_count points.
422 cpassert(
SIZE(
TARGET) == ngrad_per_point*features%chunk_feature_count)
423 cpassert(
SIZE(chunk_density_grad, 1) == features%chunk_feature_count)
424 cpassert(
SIZE(chunk_density_grad, 2) == 2)
425 cpassert(
SIZE(chunk_grad_grad, 1) == features%chunk_feature_count)
426 cpassert(
SIZE(chunk_grad_grad, 2) == 3)
427 cpassert(
SIZE(chunk_grad_grad, 3) == 2)
428 cpassert(
SIZE(chunk_kin_grad, 1) == features%chunk_feature_count)
429 cpassert(
SIZE(chunk_kin_grad, 2) == 2)
432 DO irow = 1, features%chunk_feature_count
433 IF (route_to_return_positions)
THEN
434 point_pos = features%chunk_return_positions(irow)
435 cpassert(point_pos >= 1 .AND. point_pos <= features%chunk_feature_count)
! base is the zero-based offset of slot point_pos in the flat buffer.
439 base = ngrad_per_point*(point_pos - 1)
440 target(base + 1:base + 2) = chunk_density_grad(irow, :)
441 target(base + 3) = chunk_grad_grad(irow, 1, 1)
442 target(base + 4) = chunk_grad_grad(irow, 2, 1)
443 target(base + 5) = chunk_grad_grad(irow, 3, 1)
444 target(base + 6) = chunk_grad_grad(irow, 1, 2)
445 target(base + 7) = chunk_grad_grad(irow, 2, 2)
446 target(base + 8) = chunk_grad_grad(irow, 3, 2)
447 target(base + 9:base + 10) = chunk_kin_grad(irow, :)
454 END SUBROUTINE pack_atom_chunk_grads
!> \brief Provide pointer views (density_grad, grad_grad, kin_grad) for the three gradient tensors.
!> \param features declaration not visible in this listing
!> \param density_grad_t tensor argument (INOUT)
!> \param grad_grad_t tensor argument (INOUT)
!> \param kin_grad_t tensor argument (INOUT)
!> \param density_grad rank-2 pointer output
!> \param grad_grad rank-3 pointer output
!> \param kin_grad rank-2 pointer output
!> \note Only the initial NULLIFY of the outputs is visible here; the statements that fill the
!>       tensors and associate the pointers are not part of this listing.
466 SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
467 density_grad, grad_grad, kin_grad)
469 TYPE(
torch_tensor_type),
INTENT(INOUT) :: density_grad_t, grad_grad_t, kin_grad_t
470 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: density_grad
471 REAL(kind=
dp),
DIMENSION(:, :, :),
POINTER :: grad_grad
472 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: kin_grad
474 NULLIFY (density_grad, grad_grad, kin_grad)
482 END SUBROUTINE get_feature_grad_views
!> \brief Bring gradients computed on atom chunks back into local-grid row order.
!> \param features must have uses_atom_chunks set and chunk_feature_count > 0 (asserted)
!> \param group communicator object used for alltoall / allgatherv
!> \param density_grad allocated here as (nflat_local, 2)
!> \param grad_grad allocated here as (nflat_local, 3, 2)
!> \param kin_grad allocated here as (nflat_local, 2)
!> \note Two paths are visible: with features%uses_atom_chunk_routing the packed chunk
!>       gradients are exchanged with group%alltoall and placed via route_send_local_rows;
!>       otherwise they are collected with group%allgatherv and selected via feature_index.
!> \note The ELSE/END IF/END DO lines and the remainder of the argument list are not part of
!>       this listing.
492 SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
497 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :), &
498 INTENT(OUT) :: density_grad, kin_grad
499 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :, :), &
500 INTENT(OUT) :: grad_grad
502 INTEGER :: base, i, j, k, local_row, &
503 nflat_local, phase_handle, &
505 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:) :: chunk_grad_buffer, global_grad_buffer, &
506 recv_grad_buffer, send_grad_buffer
508 cpassert(features%uses_atom_chunks)
509 cpassert(features%chunk_feature_count > 0)
511 nflat_local = features%nflat_local
! Routed path: point counts must match the chunk size (receive) and local size (send).
512 IF (features%uses_atom_chunk_routing)
THEN
513 cpassert(sum(features%route_point_recv_counts) == features%chunk_feature_count)
514 cpassert(sum(features%route_point_send_counts) == nflat_local)
516 ALLOCATE (send_grad_buffer(ngrad_per_point*features%chunk_feature_count), &
517 recv_grad_buffer(ngrad_per_point*nflat_local))
519 CALL timeset(
"skala_gpw_grad_torch_pack", phase_handle)
! Pack with routing enabled so rows land at their return positions in the send buffer.
520 CALL pack_atom_chunk_grads(features, send_grad_buffer, .true.)
521 CALL timestop(phase_handle)
523 CALL timeset(
"skala_gpw_grad_route_comm", phase_handle)
524 CALL group%alltoall(send_grad_buffer, features%route_grad_return_send_counts, &
525 features%route_grad_return_send_displs, recv_grad_buffer, &
526 features%route_grad_return_recv_counts, &
527 features%route_grad_return_recv_displs)
528 CALL timestop(phase_handle)
530 CALL timeset(
"skala_gpw_grad_route_scatter", phase_handle)
531 ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
532 kin_grad(nflat_local, 2))
533 density_grad = 0.0_dp
! Received slot point_pos belongs to local row route_send_local_rows(point_pos);
! the per-point layout is the one written by pack_atom_chunk_grads.
536 DO point_pos = 1, nflat_local
537 local_row = features%route_send_local_rows(point_pos)
538 cpassert(local_row >= 1 .AND. local_row <= nflat_local)
539 base = ngrad_per_point*(point_pos - 1)
540 density_grad(local_row, :) = recv_grad_buffer(base + 1:base + 2)
541 grad_grad(local_row, 1, 1) = recv_grad_buffer(base + 3)
542 grad_grad(local_row, 2, 1) = recv_grad_buffer(base + 4)
543 grad_grad(local_row, 3, 1) = recv_grad_buffer(base + 5)
544 grad_grad(local_row, 1, 2) = recv_grad_buffer(base + 6)
545 grad_grad(local_row, 2, 2) = recv_grad_buffer(base + 7)
546 grad_grad(local_row, 3, 2) = recv_grad_buffer(base + 8)
547 kin_grad(local_row, :) = recv_grad_buffer(base + 9:base + 10)
549 CALL timestop(phase_handle)
551 DEALLOCATE (recv_grad_buffer, send_grad_buffer)
! Non-routed path (presumably the ELSE branch; the ELSE line is not visible here):
! gather all nflat points, then pick the local rows through feature_index.
553 ALLOCATE (chunk_grad_buffer(ngrad_per_point*features%chunk_feature_count), &
554 global_grad_buffer(ngrad_per_point*features%nflat))
555 CALL timeset(
"skala_gpw_grad_torch_pack", phase_handle)
556 CALL pack_atom_chunk_grads(features, chunk_grad_buffer, .false.)
557 CALL timestop(phase_handle)
559 CALL timeset(
"skala_gpw_grad_allgatherv", phase_handle)
560 CALL group%allgatherv(chunk_grad_buffer, global_grad_buffer, &
561 features%chunk_grad_counts, features%chunk_grad_displs)
562 CALL timestop(phase_handle)
564 CALL timeset(
"skala_gpw_grad_scatter", phase_handle)
565 ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
566 kin_grad(nflat_local, 2))
! NOTE(review): local_row must be zero before this loop; its initialisation is not visible
! in this listing - verify.
568 DO k = lbound(features%feature_index, 3), ubound(features%feature_index, 3)
569 DO j = lbound(features%feature_index, 2), ubound(features%feature_index, 2)
570 DO i = lbound(features%feature_index, 1), ubound(features%feature_index, 1)
571 local_row = local_row + 1
572 row = features%feature_index(i, j, k)
573 cpassert(row >= 1 .AND. row <= features%nflat)
574 base = ngrad_per_point*(row - 1)
575 density_grad(local_row, :) = global_grad_buffer(base + 1:base + 2)
576 grad_grad(local_row, 1, 1) = global_grad_buffer(base + 3)
577 grad_grad(local_row, 2, 1) = global_grad_buffer(base + 4)
578 grad_grad(local_row, 3, 1) = global_grad_buffer(base + 5)
579 grad_grad(local_row, 1, 2) = global_grad_buffer(base + 6)
580 grad_grad(local_row, 2, 2) = global_grad_buffer(base + 7)
581 grad_grad(local_row, 3, 2) = global_grad_buffer(base + 8)
582 kin_grad(local_row, :) = global_grad_buffer(base + 9:base + 10)
586 CALL timestop(phase_handle)
587 DEALLOCATE (chunk_grad_buffer, global_grad_buffer)
591 END SUBROUTINE fetch_and_gather_atom_chunk_grads
!> \brief Fill vxc_rho and vxc_tau on the real-space grid from per-point feature gradients.
!> \param vxc_rho allocated here with nspins entries; grids are taken from pw_pool
!> \param vxc_tau allocated here with nspins entries; grids are taken from pw_pool
!> \param rho_r supplies the grid (local bounds, dvol, spherical flag) through rho_r(1)%pw_grid
!> \param pw_pool pool used to create and give back grids
!> \param density_grad per-point gradient, second index 1:2
!> \param grad_grad per-point gradient, second index 1:3, third index 1:2
!> \param kin_grad per-point gradient, second index 1:2
!> \param xc_deriv_method_id forwarded to xc_pw_divergence
!> \note All values are scaled by 1/dvol. For nspins == 1 the two columns are averaged
!>       (factor 0.5); otherwise column ispin is used.
!> \note Loop terminators, the setting of ipt/nspins and the remainder of the argument list
!>       are not part of this listing.
604 SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
605 density_grad, grad_grad, kin_grad, &
607 TYPE(
pw_r3d_rs_type),
DIMENSION(:),
POINTER :: vxc_rho, vxc_tau, rho_r
609 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(IN) :: density_grad
610 REAL(kind=
dp),
DIMENSION(:, :, :),
INTENT(IN) :: grad_grad
611 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(IN) :: kin_grad
612 INTEGER,
INTENT(IN) :: xc_deriv_method_id
614 INTEGER :: i, ipt, ispin, j, k, nspins
615 INTEGER,
DIMENSION(2, 3) :: bo
616 REAL(kind=
dp) :: dvol_inv
! Local grid bounds and the inverse volume element of the first density grid.
621 bo = rho_r(1)%pw_grid%bounds_local
622 dvol_inv = 1.0_dp/rho_r(1)%pw_grid%dvol
624 ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
626 CALL pw_pool%create_pw(vxc_rho(ispin))
627 CALL pw_pool%create_pw(vxc_tau(ispin))
! Work grids; tmp_g is created only when the grid is not spherical.
633 CALL pw_pool%create_pw(vxc_g)
634 IF (.NOT. rho_r(1)%pw_grid%spherical)
CALL pw_pool%create_pw(tmp_g)
639 CALL pw_pool%create_pw(grad_pw(i))
! Loop over the local grid with i innermost.
644 DO k = bo(1, 3), bo(2, 3)
645 DO j = bo(1, 2), bo(2, 2)
646 DO i = bo(1, 1), bo(2, 1)
648 IF (nspins == 1)
THEN
649 vxc_rho(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
650 (density_grad(ipt, 1) + density_grad(ipt, 2))
651 vxc_tau(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
652 (kin_grad(ipt, 1) + kin_grad(ipt, 2))
653 grad_pw(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
654 (grad_grad(ipt, 1, 1) + grad_grad(ipt, 1, 2))
655 grad_pw(2)%array(i, j, k) = 0.5_dp*dvol_inv* &
656 (grad_grad(ipt, 2, 1) + grad_grad(ipt, 2, 2))
657 grad_pw(3)%array(i, j, k) = 0.5_dp*dvol_inv* &
658 (grad_grad(ipt, 3, 1) + grad_grad(ipt, 3, 2))
660 vxc_rho(ispin)%array(i, j, k) = dvol_inv*density_grad(ipt, ispin)
661 vxc_tau(ispin)%array(i, j, k) = dvol_inv*kin_grad(ipt, ispin)
662 grad_pw(1)%array(i, j, k) = dvol_inv*grad_grad(ipt, 1, ispin)
663 grad_pw(2)%array(i, j, k) = dvol_inv*grad_grad(ipt, 2, ispin)
664 grad_pw(3)%array(i, j, k) = dvol_inv*grad_grad(ipt, 3, ispin)
! The three grad_pw components go through xc_pw_divergence with vxc_rho(ispin) as the
! real-space argument. NOTE(review): sign/accumulation convention is defined in
! xc_pw_divergence and is not visible here.
673 CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))
676 CALL pw_pool%give_back_pw(grad_pw(i))
! Work grids are returned only if they were actually created.
680 IF (
ASSOCIATED(vxc_g%pw_grid))
CALL pw_pool%give_back_pw(vxc_g)
681 IF (
ASSOCIATED(tmp_g%pw_grid))
CALL pw_pool%give_back_pw(tmp_g)
683 END SUBROUTINE build_vxc_from_feature_grads
!> \brief Write summary numbers of the native-grid feature bundle to unit iw.
!> \param features source of electron_count, spin_moment, grid_weight_sum and, with atom
!>        chunks, chunk_feature_count
!> \param print_active when false the routine returns without writing anything
!> \note How iw is obtained, and the rest of the atom-chunk output line, are not part of
!>       this listing.
690 SUBROUTINE print_native_grid_diagnostics(features, print_active)
692 LOGICAL,
INTENT(IN) :: print_active
696 IF (.NOT. print_active)
RETURN
699 WRITE (unit=iw, fmt=
"(/,T2,A,1X,ES19.11)") &
700 "SKALA_GPW| Native grid feature electrons", features%electron_count
701 WRITE (unit=iw, fmt=
"(T2,A,1X,ES19.11)") &
702 "SKALA_GPW| Native grid feature spin moment", features%spin_moment
703 WRITE (unit=iw, fmt=
"(T2,A,1X,ES19.11)") &
704 "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum
! Extra line only when the features were built with atom chunks.
705 IF (features%uses_atom_chunks)
THEN
706 WRITE (unit=iw, fmt=
"(T2,A,1X,I0,1X,A,1X,I0)") &
707 "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
711 END SUBROUTINE print_native_grid_diagnostics
!> \brief Choose the CUDA device for this rank and report the choice once per configuration.
!> \param use_cuda when false the function returns immediately
!> \param requested_device negative: pick mod(group%mepos, cuda_device_count);
!>        otherwise use this device id as given
!> \param group communicator object (mepos, num_pe, allgather are used)
!> \return selected_device
!> \note Aborts when the selected id is >= cuda_device_count.
!> \note The value returned on the early exits and the call that fills cuda_device_count are
!>       set in lines that are not part of this listing - verify against the full source.
720 FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group)
RESULT(selected_device)
721 LOGICAL,
INTENT(IN) :: use_cuda
722 INTEGER,
INTENT(IN) :: requested_device
726 INTEGER :: cuda_device_count, iw, pe, selected_device
727 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: selected_devices
731 IF (.NOT. use_cuda)
RETURN
734 cuda_device_count = 0
! With at least one device: negative request spreads ranks over the devices by rank index.
738 IF (cuda_device_count > 0)
THEN
739 IF (requested_device < 0)
THEN
740 selected_device = mod(group%mepos, cuda_device_count)
742 selected_device = requested_device
745 IF (selected_device >= cuda_device_count)
THEN
746 CALL cp_abort(__location__, &
747 "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
748 "CP2K offload device range.")
! Collect every rank's choice so that the rank with mepos == 0 can print the full table.
752 ALLOCATE (selected_devices(group%num_pe))
753 CALL group%allgather(selected_device, selected_devices)
! Only mepos == 0 continues, and only if something changed since the last report.
755 IF (group%mepos /= 0)
RETURN
756 IF (selected_device == logged_cuda_device .AND. &
757 cuda_device_count == logged_cuda_device_count .AND. &
758 group%num_pe == logged_cuda_nproc .AND. &
759 requested_device == logged_cuda_request)
RETURN
762 IF (selected_device >= 0)
THEN
763 WRITE (unit=iw, fmt=
"(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
764 "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
765 "of", cuda_device_count,
"requested", requested_device
767 WRITE (unit=iw, fmt=
"(/,T2,A)") &
768 "SKALA_GPW| Native grid Torch CUDA requested, but no CP2K offload device is visible"
! One "rank:device" pair per rank on a single line (rank printed zero-based as pe - 1).
770 WRITE (unit=iw, fmt=
"(T2,A)", advance=
"NO") &
771 "SKALA_GPW| Native grid Torch CUDA rank devices"
772 DO pe = 1, group%num_pe
773 WRITE (unit=iw, fmt=
"(1X,I0,A,I0)", advance=
"NO") pe - 1,
":", selected_devices(pe)
775 WRITE (unit=iw, fmt=*)
! Remember what was reported so an identical configuration is not logged again.
777 logged_cuda_device = selected_device
778 logged_cuda_device_count = cuda_device_count
779 logged_cuda_nproc = group%num_pe
780 logged_cuda_request = requested_device
782 END FUNCTION configure_native_grid_cuda
!> \brief Keep the module-level cached model in step with the requested path and CUDA device.
!> \param model_path path compared (after trim) with cached_model_path
!> \param cuda_device device id compared with cached_model_cuda_device
!> \note Returns without further work when a model is already cached for the same path and
!>       device. Otherwise the cache bookkeeping is reset and rewritten; the release/load
!>       calls themselves are in lines that are not part of this listing.
789 SUBROUTINE ensure_model_loaded(model_path, cuda_device)
790 CHARACTER(len=*),
INTENT(IN) :: model_path
791 INTEGER,
INTENT(IN) :: cuda_device
793 IF (cached_model_loaded)
THEN
794 IF (trim(cached_model_path) == trim(model_path) .AND. &
795 cached_model_cuda_device == cuda_device)
RETURN
797 cached_model_loaded = .false.
! Record what the cache now corresponds to.
801 cached_model_path = model_path
802 cached_model_cuda_device = cuda_device
803 cached_model_loaded = .true.
805 END SUBROUTINE ensure_model_loaded
!> \brief Resolve the model file path for the GAUXC section.
!> \param xc_section section searched for the GAUXC subsection (declaration not visible here)
!> \param model_path resulting path
!> \note Aborts when there is no GAUXC section, when the model key is "NONE" or blank, and when
!>       the key is "SKALA" but GAUXC_SKALA_MODEL cannot be read or is empty.
!> \note Any nonzero status from get_environment_variable is treated as failure.
!> \note The statements that read the model value from the section are not part of this listing.
812 SUBROUTINE get_skala_model_path(xc_section, model_path)
814 CHARACTER(len=default_path_length),
INTENT(OUT) :: model_path
816 CHARACTER(len=default_path_length) :: model_key
817 INTEGER :: env_status
820 gauxc_section => get_gauxc_section(xc_section)
821 IF (.NOT.
ASSOCIATED(gauxc_section))
THEN
822 cpabort(
"Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
! model_key is a left-adjusted copy used only for the keyword comparisons below.
826 model_key = adjustl(model_path)
828 IF (trim(model_key) ==
"NONE" .OR. trim(model_key) ==
"")
THEN
829 cpabort(
"Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
! The keyword SKALA means: take the path from the GAUXC_SKALA_MODEL environment variable.
830 ELSE IF (trim(model_key) ==
"SKALA")
THEN
831 CALL get_environment_variable(
"GAUXC_SKALA_MODEL", model_path, status=env_status)
832 IF (env_status /= 0 .OR. len_trim(model_path) == 0)
THEN
833 cpabort(
"MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
837 END SUBROUTINE get_skala_model_path
!> \brief Return a pointer to the functional subsection whose section name is "GAUXC".
!> \param xc_section section to search (declaration not visible in this listing)
!> \return gauxc_section; left nullified when xc_section or functionals is not associated
!> \note The loop that walks the functional subsections is only partly visible here; it exits
!>       when xc_fun is not associated.
844 FUNCTION get_gauxc_section(xc_section)
RESULT(gauxc_section)
851 NULLIFY (gauxc_section)
852 IF (.NOT.
ASSOCIATED(xc_section))
RETURN
855 IF (.NOT.
ASSOCIATED(functionals))
RETURN
861 IF (.NOT.
ASSOCIATED(xc_fun))
EXIT
862 IF (xc_fun%section%name ==
"GAUXC")
THEN
863 gauxc_section => xc_fun
868 END FUNCTION get_gauxc_section
Handles all functions related to the CELL.
various routines to log and control the output. The idea is that decisions about where to log should ...
integer function, public cp_logger_get_default_io_unit(logger)
returns the unit nr for the ionode (-1 on all other processors) skips as well checks if the procs cal...
Defines the basic variable types.
integer, parameter, public dp
integer, parameter, public default_string_length
integer, parameter, public default_path_length
Interface to the message passing library MPI.
Fortran API for the offload package, which is written in C.
subroutine, public offload_set_chosen_device(device_id)
Selects the chosen device to be used.
integer function, public offload_get_device_count()
Returns the number of available devices.
Define the data structure for the particle information.
Manages a pool of grids (to be used for example as tmp objects), but can also be used to instantiate ...
Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
subroutine, public skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, requires_grad, weights, requires_coordinate_grad, use_atom_chunks, route_atom_chunks)
Build a flat SKALA molecular feature dictionary from a local GPW grid.
subroutine, public skala_gpw_feature_release(features)
Release Torch objects and backing arrays owned by a feature bundle.
Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
subroutine, public ensure_native_skala_grid_scope(xc_section)
Enforce the currently implemented native SKALA GPW input scope.
subroutine, public skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, weights, pw_pool, particle_set, cell, compute_virial, virial_xc, just_energy, atom_force)
Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
logical function, public xc_section_uses_native_skala_grid(xc_section)
Return true if the GAUXC subsection requests the CP2K-native GPW grid path.
Small CP2K wrapper around the SKALA TorchScript functional protocol.
subroutine, public skala_torch_model_release(model)
Release a loaded SKALA TorchScript model.
subroutine, public skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
Evaluate the weighted SKALA exchange-correlation energy.
subroutine, public skala_torch_model_load(model, filename)
Load a SKALA TorchScript model and its feature metadata.
Utilities for string manipulations.
elemental subroutine, public uppercase(string)
Convert all lower case characters in a string to upper case.
subroutine, public torch_use_cuda(use_cuda)
Select whether Torch wrappers should use CUDA when available.
subroutine, public torch_tensor_backward_scalar(tensor)
Runs autograd on a scalar Torch tensor.
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
subroutine, public xc_rho_set_create(rho_set, local_bounds, rho_cutoff, drho_cutoff, tau_cutoff)
allocates and does (minimal) initialization of a rho_set
subroutine, public xc_rho_set_release(rho_set, pw_pool)
releases the given rho_set
subroutine, public xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, xc_deriv_method_id, xc_rho_smooth_id, pw_pool, spinflip)
updates the given rho set with the density given by rho_r (and rho_g). The rho set will contain the c...
contains utility functions for the xc package
subroutine, public xc_pw_divergence(xc_deriv_method_id, pw_to_deriv, tmp_g, vxc_g, vxc_r)
Calculates the divergence of pw_to_deriv.
elemental logical function, public xc_requires_tmp_g(xc_deriv_id)
...
Type defining parameters related to the simulation cell.
Manages a pool of grids (to be used for example as tmp objects), but can also be used to instantiate ...
contains a flag for each component of xc_rho_set, so that you can use it to tell which components you...
represent a density, with all the representation and data needed to perform a functional evaluation