(git:e0f0c17)
Loading...
Searching...
No Matches
skala_gpw_functional.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
8! **************************************************************************************************
9!> \brief Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
10! **************************************************************************************************
12 USE cell_types, ONLY: cell_type
19 USE kinds, ONLY: default_path_length,&
21 dp
26 USE pw_methods, ONLY: pw_scale,&
29 USE pw_types, ONLY: pw_c1d_gs_type,&
51 USE xc_util, ONLY: xc_pw_divergence,&
53#include "./base/base_uses.f90"
54
55 IMPLICIT NONE
56
57 PRIVATE
58
59 CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_functional'
60 INTEGER, PARAMETER, PRIVATE :: ngrad_per_point = 10
61
63
64 TYPE(skala_torch_model_type), SAVE :: cached_model
65 CHARACTER(len=default_path_length), SAVE :: cached_model_path = ""
66 LOGICAL, SAVE :: cached_model_loaded = .false.
67 INTEGER, SAVE :: cached_model_cuda_device = -3
68 INTEGER, SAVE :: logged_cuda_device = -3, &
69 logged_cuda_device_count = -1, &
70 logged_cuda_nproc = -1, &
71 logged_cuda_request = -3
72
73CONTAINS
74
! **************************************************************************************************
!> \brief Return true if the GAUXC subsection requests the CP2K-native GPW grid path.
!> \param xc_section XC input section; an unassociated pointer is accepted and yields .FALSE.
!> \return value of GAUXC%NATIVE_GRID, or .FALSE. when there is no XC section or no GAUXC section
! **************************************************************************************************
   FUNCTION xc_section_uses_native_skala_grid(xc_section) RESULT(uses_native_grid)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      LOGICAL                                            :: uses_native_grid

      TYPE(section_vals_type), POINTER                   :: gauxc_section

      uses_native_grid = .FALSE.
      ! Guard before looking up the GAUXC subsection: a missing XC section means "no native grid".
      IF (.NOT. ASSOCIATED(xc_section)) RETURN

      gauxc_section => get_gauxc_section(xc_section)
      IF (ASSOCIATED(gauxc_section)) THEN
         CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=uses_native_grid)
      END IF

   END FUNCTION xc_section_uses_native_skala_grid

! **************************************************************************************************
!> \brief Abort unless the XC input lies within the scope the native SKALA GPW path supports.
!> \details XC_FUNCTIONAL must contain a GAUXC subsection and nothing else. If GAUXC%NATIVE_GRID
!>          is enabled, GAUXC%FUNCTIONAL must additionally be PBE (case-insensitive) and
!>          GAUXC%MODEL must be neither NONE nor blank.
!> \param xc_section XC input section to validate
! **************************************************************************************************
   SUBROUTINE ensure_native_skala_grid_scope(xc_section)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section

      CHARACTER(len=default_path_length)                 :: model_key, model_name
      CHARACTER(len=default_string_length)               :: functional_key, functional_name
      INTEGER                                            :: nfun
      LOGICAL                                            :: use_native
      TYPE(section_vals_type), POINTER                   :: functionals, gauxc_section, xc_fun

      NULLIFY (gauxc_section)
      IF (.NOT. ASSOCIATED(xc_section)) THEN
         cpabort("Native SKALA GPW requires an XC section")
      END IF

      functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
      IF (.NOT. ASSOCIATED(functionals)) THEN
         cpabort("Native SKALA GPW requires an XC_FUNCTIONAL section")
      END IF

      ! Count the functional subsections, stopping at the first index for which
      ! section_vals_get_subs_vals2 hands back an unassociated pointer; remember GAUXC.
      nfun = 0
      DO
         xc_fun => section_vals_get_subs_vals2(functionals, i_section=nfun + 1)
         IF (.NOT. ASSOCIATED(xc_fun)) EXIT
         nfun = nfun + 1
         IF (xc_fun%section%name == "GAUXC") gauxc_section => xc_fun
      END DO

      IF (.NOT. ASSOCIATED(gauxc_section)) THEN
         cpabort("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
      END IF
      IF (nfun /= 1) THEN
         cpabort("Native SKALA GPW requires GAUXC to be the only XC functional")
      END IF

      ! The remaining restrictions only apply when the native grid path is switched on.
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=use_native)
      IF (.NOT. use_native) RETURN

      ! Character comparisons blank-pad the shorter operand, so no TRIM is needed here.
      CALL section_vals_val_get(gauxc_section, "FUNCTIONAL", c_val=functional_name)
      functional_key = adjustl(functional_name)
      CALL uppercase(functional_key)
      IF (functional_key /= "PBE") THEN
         cpabort("Native SKALA GPW currently requires GAUXC%FUNCTIONAL PBE")
      END IF

      CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_name)
      model_key = adjustl(model_name)
      CALL uppercase(model_key)
      SELECT CASE (model_key)
      CASE ("NONE", "")
         cpabort("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
      END SELECT

   END SUBROUTINE ensure_native_skala_grid_scope
152
! **************************************************************************************************
!> \brief Evaluate the SKALA XC energy on a CP2K GPW real-space grid and, unless only the energy is
!>        requested, back-propagate through the Torch model to build the XC potentials.
!> \param vxc_rho allocated and filled per spin when just_energy is not .TRUE.; untouched otherwise
!> \param vxc_tau allocated and filled per spin when just_energy is not .TRUE.; untouched otherwise
!> \param exc SKALA exchange-correlation energy
!> \param rho_r real-space density, one grid per spin; more than one grid selects the spin path
!> \param rho_g reciprocal-space density; must be associated
!> \param tau kinetic-energy density; must be associated
!> \param xc_section XC input section holding XC_FUNCTIONAL%GAUXC and XC_GRID settings
!> \param weights forwarded to skala_gpw_feature_build
!> \param pw_pool pool providing temporary and output grids
!> \param particle_set forwarded to skala_gpw_feature_build
!> \param cell forwarded to skala_gpw_feature_build
!> \param compute_virial must be .FALSE.; the native stress path aborts otherwise
!> \param virial_xc always returned as zero
!> \param just_energy when .TRUE., the backward pass and potential construction are skipped
!> \param atom_force when present it is zeroed; if gradients are computed, the explicit derivative
!>        with respect to atom-center coordinates is added on rank 0 of the grid communicator
! **************************************************************************************************
   SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
                             weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
                             just_energy, atom_force)
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau
      REAL(kind=dp), INTENT(OUT)                         :: exc
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
      TYPE(section_vals_type), POINTER                   :: xc_section
      TYPE(pw_r3d_rs_type), POINTER                      :: weights
      TYPE(pw_pool_type), POINTER                        :: pw_pool
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      LOGICAL, INTENT(IN)                                :: compute_virial
      REAL(kind=dp), DIMENSION(3, 3), INTENT(OUT)        :: virial_xc
      LOGICAL, INTENT(IN), OPTIONAL                      :: just_energy
      REAL(kind=dp), DIMENSION(:, :), INTENT(OUT), &
         OPTIONAL                                        :: atom_force

      CHARACTER(len=default_path_length)                 :: model_path
      INTEGER                                            :: cuda_device, deriv_id, handle, &
                                                            requested_cuda_device, smooth_id
      LOGICAL                                            :: chunk_routing, energy_only, is_root, &
                                                            print_diagnostics, use_atom_chunks, &
                                                            use_cuda, want_forces
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: d_density, d_kin
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: d_grad
      TYPE(section_vals_type), POINTER                   :: gauxc_section
      TYPE(skala_gpw_feature_type)                       :: features
      TYPE(torch_tensor_type)                            :: coord_grad_t, exc_t
      TYPE(xc_rho_cflags_type)                           :: needs
      TYPE(xc_rho_set_type)                              :: rho_set

      exc = 0.0_dp
      virial_xc = 0.0_dp
      energy_only = .FALSE.
      IF (PRESENT(just_energy)) energy_only = just_energy
      want_forces = PRESENT(atom_force)
      IF (want_forces) atom_force = 0.0_dp

      ! Reject requests this experimental path cannot serve.
      IF (compute_virial) THEN
         CALL cp_abort(__location__, &
                       "Native SKALA GPW stress/virial is not implemented yet.")
      END IF
      IF (.NOT. ASSOCIATED(rho_g)) THEN
         CALL cp_abort(__location__, &
                       "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
      END IF
      IF (.NOT. ASSOCIATED(tau)) THEN
         CALL cp_abort(__location__, &
                       "Native SKALA GPW requires the kinetic-energy density.")
      END IF

      ! Resolve the model file and the native-grid options from XC_FUNCTIONAL%GAUXC.
      CALL get_skala_model_path(xc_section, model_path)
      gauxc_section => get_gauxc_section(xc_section)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=use_cuda)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
                                i_val=requested_cuda_device)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
                                l_val=use_atom_chunks)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
                                l_val=chunk_routing)
      ! Either keyword switches on both atom chunking and chunk routing.
      use_atom_chunks = use_atom_chunks .OR. chunk_routing
      chunk_routing = use_atom_chunks
      IF (use_atom_chunks .AND. want_forces) THEN
         CALL cp_abort(__location__, &
                       "Native SKALA GPW atom chunks are not implemented for atom forces yet.")
      END IF

      ! Torch runs on CUDA only if NATIVE_GRID_USE_CUDA is set. NOTE(review): the original remark
      ! here said the portable SKALA export used by the regtests builds ragged-index tensors on CPU.
      CALL torch_use_cuda(use_cuda)
      is_root = (rho_r(1)%pw_grid%para%group%mepos == 0)
      cuda_device = configure_native_grid_cuda(use_cuda, requested_cuda_device, &
                                               rho_r(1)%pw_grid%para%group)
      CALL ensure_model_loaded(model_path, cuda_device)

      ! Density ingredients: rho, its gradient and tau (spin-resolved when more than one spin grid).
      IF (SIZE(rho_r) == 1) THEN
         needs%rho = .TRUE.
         needs%drho = .TRUE.
         needs%tau = .TRUE.
      ELSE
         needs%rho_spin = .TRUE.
         needs%drho_spin = .TRUE.
         needs%tau_spin = .TRUE.
      END IF

      CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=deriv_id)
      CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=smooth_id)

      CALL xc_rho_set_create(rho_set, rho_r(1)%pw_grid%bounds_local, &
                             rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
                             drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
                             tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
      CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, deriv_id, smooth_id, pw_pool)

      CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                   requires_grad=(.NOT. energy_only), weights=weights, &
                                   requires_coordinate_grad=want_forces, &
                                   use_atom_chunks=use_atom_chunks, &
                                   route_atom_chunks=chunk_routing)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_DIAGNOSTICS", l_val=print_diagnostics)
      IF (print_diagnostics) CALL print_native_grid_diagnostics(features, is_root)

      ! Forward pass; in atom-chunk mode the per-rank energies are summed over the grid communicator.
      CALL skala_torch_model_get_exc(cached_model, features%inputs, &
                                     features%grid_weights_t, exc_t, exc)
      IF (features%uses_atom_chunks) CALL rho_r(1)%pw_grid%para%group%sum(exc)

      IF (.NOT. energy_only) THEN
         CALL timeset("skala_gpw_backward", handle)
         CALL torch_tensor_backward_scalar(exc_t)
         CALL timestop(handle)

         CALL timeset("skala_gpw_grad_fetch", handle)
         IF (features%uses_atom_chunks) THEN
            CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
                                                   d_density, d_grad, d_kin)
         ELSE
            CALL fetch_local_feature_grads(features, d_density, d_grad, d_kin)
         END IF
         IF (want_forces) CALL add_explicit_coordinate_force(atom_force, features, coord_grad_t, is_root)
         CALL timestop(handle)

         CALL timeset("skala_gpw_vxc_unpack", handle)
         CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                           d_density, d_grad, d_kin, deriv_id)
         CALL timestop(handle)

         CALL timeset("skala_gpw_grad_release", handle)
         DEALLOCATE (d_density, d_grad, d_kin)
         IF (want_forces) CALL torch_tensor_release(coord_grad_t)
         CALL timestop(handle)
      END IF

      CALL timeset("skala_gpw_cleanup", handle)
      CALL torch_tensor_release(exc_t)
      CALL skala_gpw_feature_release(features)
      CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
      ! Torch is put back into CUDA mode unconditionally on exit.
      CALL torch_use_cuda(.TRUE.)
      CALL timestop(handle)

   END SUBROUTINE skala_gpw_eval
319
! **************************************************************************************************
!> \brief Add the explicit SKALA derivative with respect to atom-center coordinates.
!> \details The gradient tensor is fetched on every rank so the handle is always defined; only
!>          the rank flagged by root_rank accumulates it into atom_force.
!> \param atom_force accumulator, must have the same shape as the coordinate gradient
!> \param features feature block whose coarse_0_atomic_coords_t tensor was differentiated
!> \param atom_coord_grad_t receives the gradient tensor handle (released by the caller)
!> \param root_rank .TRUE. on the single rank that should add the contribution
! **************************************************************************************************
   SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
      REAL(kind=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t
      LOGICAL, INTENT(IN)                                :: root_rank

      REAL(kind=dp), DIMENSION(:, :), POINTER            :: coord_grad

      NULLIFY (coord_grad)
      CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
      IF (.NOT. root_rank) RETURN

      CALL torch_tensor_data_ptr(atom_coord_grad_t, coord_grad)
      cpassert(ALL(SHAPE(atom_force) == SHAPE(coord_grad)))
      atom_force = atom_force + coord_grad

   END SUBROUTINE add_explicit_coordinate_force
345
! **************************************************************************************************
!> \brief Copy the full-grid Torch feature gradients into this rank's local grid order.
!> \param features feature block; feature_index maps each local grid point to its full-grid row
!> \param density_grad on exit (nflat_local, 2) gradient w.r.t. the density feature
!> \param grad_grad on exit (nflat_local, 3, 2) gradient w.r.t. the density-gradient feature
!> \param kin_grad on exit (nflat_local, 2) gradient w.r.t. the kinetic-energy feature
! **************************************************************************************************
   SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: density_grad
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: grad_grad
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: kin_grad

      INTEGER                                            :: ix, iy, iz, n_out, src
      REAL(kind=dp), DIMENSION(:, :), POINTER            :: full_density, full_kin
      REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: full_grad
      TYPE(torch_tensor_type)                            :: t_density, t_grad, t_kin

      NULLIFY (full_density, full_grad, full_kin)
      CALL get_feature_grad_views(features, t_density, t_grad, t_kin, &
                                  full_density, full_grad, full_kin)

      ! The autograd views must cover all nflat flattened grid points.
      cpassert(SIZE(full_density, 1) == features%nflat)
      cpassert(SIZE(full_density, 2) == 2)
      cpassert(SIZE(full_grad, 1) == features%nflat)
      cpassert(SIZE(full_grad, 2) == 3)
      cpassert(SIZE(full_grad, 3) == 2)
      cpassert(SIZE(full_kin, 1) == features%nflat)
      cpassert(SIZE(full_kin, 2) == 2)

      ALLOCATE (density_grad(features%nflat_local, 2))
      ALLOCATE (grad_grad(features%nflat_local, 3, 2))
      ALLOCATE (kin_grad(features%nflat_local, 2))

      ! Walk the local index box with the first index fastest, gathering each full-grid row.
      n_out = 0
      DO iz = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
         DO iy = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
            DO ix = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
               src = features%feature_index(ix, iy, iz)
               cpassert(src >= 1 .AND. src <= features%nflat)
               n_out = n_out + 1
               density_grad(n_out, :) = full_density(src, :)
               grad_grad(n_out, :, :) = full_grad(src, :, :)
               kin_grad(n_out, :) = full_kin(src, :)
            END DO
         END DO
      END DO
      cpassert(n_out == features%nflat_local)

      CALL torch_tensor_release(t_density)
      CALL torch_tensor_release(t_grad)
      CALL torch_tensor_release(t_kin)

   END SUBROUTINE fetch_local_feature_grads
401
! **************************************************************************************************
!> \brief Pack atom-chunk Torch gradients into a flat CP2K communication buffer.
!> \details Each point occupies ngrad_per_point consecutive slots: 2 density-gradient values,
!>          6 gradient-gradient values (Cartesian index fastest, then spin), 2 kinetic values.
!> \param features feature block holding the chunk gradients
!> \param target buffer of ngrad_per_point*chunk_feature_count entries; zeroed, then filled
!> \param route_to_return_positions if .TRUE., place chunk row i at chunk_return_positions(i)
! **************************************************************************************************
   SUBROUTINE pack_atom_chunk_grads(features, target, route_to_return_positions)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:), &
         INTENT(INOUT)                                   :: target
      LOGICAL, INTENT(IN)                                :: route_to_return_positions

      INTEGER                                            :: dest, offset, src
      REAL(kind=dp), DIMENSION(:, :), POINTER            :: g_density, g_kin
      REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: g_grad
      TYPE(torch_tensor_type)                            :: t_density, t_grad, t_kin

      NULLIFY (g_density, g_grad, g_kin)
      CALL get_feature_grad_views(features, t_density, t_grad, t_kin, &
                                  g_density, g_grad, g_kin)
      cpassert(SIZE(target) == ngrad_per_point*features%chunk_feature_count)
      cpassert(SIZE(g_density, 1) == features%chunk_feature_count)
      cpassert(SIZE(g_density, 2) == 2)
      cpassert(SIZE(g_grad, 1) == features%chunk_feature_count)
      cpassert(SIZE(g_grad, 2) == 3)
      cpassert(SIZE(g_grad, 3) == 2)
      cpassert(SIZE(g_kin, 1) == features%chunk_feature_count)
      cpassert(SIZE(g_kin, 2) == 2)

      target(:) = 0.0_dp
      DO src = 1, features%chunk_feature_count
         dest = src
         IF (route_to_return_positions) THEN
            dest = features%chunk_return_positions(src)
            cpassert(dest >= 1 .AND. dest <= features%chunk_feature_count)
         END IF
         offset = ngrad_per_point*(dest - 1)
         target(offset + 1:offset + 2) = g_density(src, :)
         ! Column-major flattening of the 3x2 slice: (1,1),(2,1),(3,1),(1,2),(2,2),(3,2).
         target(offset + 3:offset + 8) = RESHAPE(g_grad(src, :, :), [6])
         target(offset + 9:offset + 10) = g_kin(src, :)
      END DO

      CALL torch_tensor_release(t_density)
      CALL torch_tensor_release(t_grad)
      CALL torch_tensor_release(t_kin)

   END SUBROUTINE pack_atom_chunk_grads
455
! **************************************************************************************************
!> \brief Obtain the autograd gradients of the SKALA dynamic feature tensors as Fortran views.
!> \details The three gradient tensor handles are returned alongside the views; the views point
!>          into those tensors, so the handles should only be released once the views are unused.
!> \param features feature block whose density_t, grad_t and kin_t tensors were differentiated
!> \param density_grad_t receives the gradient tensor of features%density_t
!> \param grad_grad_t receives the gradient tensor of features%grad_t
!> \param kin_grad_t receives the gradient tensor of features%kin_t
!> \param density_grad rank-2 pointer view of density_grad_t
!> \param grad_grad rank-3 pointer view of grad_grad_t
!> \param kin_grad rank-2 pointer view of kin_grad_t
! **************************************************************************************************
   SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
                                     density_grad, grad_grad, kin_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(torch_tensor_type), INTENT(INOUT)             :: density_grad_t, grad_grad_t, kin_grad_t
      REAL(kind=dp), DIMENSION(:, :), POINTER            :: density_grad, kin_grad
      REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: grad_grad

      NULLIFY (density_grad)
      NULLIFY (grad_grad)
      NULLIFY (kin_grad)

      ! First fetch all three .grad tensors ...
      CALL torch_tensor_grad(features%density_t, density_grad_t)
      CALL torch_tensor_grad(features%grad_t, grad_grad_t)
      CALL torch_tensor_grad(features%kin_t, kin_grad_t)

      ! ... then map their storage onto Fortran pointers.
      CALL torch_tensor_data_ptr(density_grad_t, density_grad)
      CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
      CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)

   END SUBROUTINE get_feature_grad_views
483
! **************************************************************************************************
!> \brief Fetch atom-chunk gradients and route them back to their local grid owners.
!> \details Without chunk routing, every rank's packed chunk is all-gathered into a full-grid buffer
!>          and the local rows are picked out through feature_index. With routing, an alltoallv
!>          returns exactly the local points, which are scattered through route_send_local_rows.
!> \param features feature block in atom-chunk mode
!> \param group grid communicator (collective call)
!> \param density_grad on exit (nflat_local, 2)
!> \param grad_grad on exit (nflat_local, 3, 2)
!> \param kin_grad on exit (nflat_local, 2)
! **************************************************************************************************
   SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
                                                kin_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      CLASS(mp_comm_type), INTENT(IN)                    :: group
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: density_grad, kin_grad
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: grad_grad

      INTEGER                                            :: dst_row, handle, ix, iy, iz, n_local, &
                                                            off, pos, src_row
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:)           :: gathered, packed, received

      cpassert(features%uses_atom_chunks)
      cpassert(features%chunk_feature_count > 0)

      n_local = features%nflat_local

      IF (.NOT. features%uses_atom_chunk_routing) THEN
         ! --- Gather path: replicate all chunk gradients, then select this rank's rows. ---
         ALLOCATE (packed(ngrad_per_point*features%chunk_feature_count), &
                   gathered(ngrad_per_point*features%nflat))
         CALL timeset("skala_gpw_grad_torch_pack", handle)
         CALL pack_atom_chunk_grads(features, packed, .FALSE.)
         CALL timestop(handle)

         CALL timeset("skala_gpw_grad_allgatherv", handle)
         CALL group%allgatherv(packed, gathered, &
                               features%chunk_grad_counts, features%chunk_grad_displs)
         CALL timestop(handle)

         CALL timeset("skala_gpw_grad_scatter", handle)
         ALLOCATE (density_grad(n_local, 2), grad_grad(n_local, 3, 2), &
                   kin_grad(n_local, 2))
         dst_row = 0
         DO iz = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
            DO iy = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
               DO ix = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
                  src_row = features%feature_index(ix, iy, iz)
                  cpassert(src_row >= 1 .AND. src_row <= features%nflat)
                  dst_row = dst_row + 1
                  off = ngrad_per_point*(src_row - 1)
                  density_grad(dst_row, :) = gathered(off + 1:off + 2)
                  grad_grad(dst_row, :, :) = RESHAPE(gathered(off + 3:off + 8), [3, 2])
                  kin_grad(dst_row, :) = gathered(off + 9:off + 10)
               END DO
            END DO
         END DO
         CALL timestop(handle)
         DEALLOCATE (packed, gathered)
      ELSE
         ! --- Routed path: send each point's gradients straight back to its owning rank. ---
         cpassert(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
         cpassert(SUM(features%route_point_send_counts) == n_local)

         ALLOCATE (packed(ngrad_per_point*features%chunk_feature_count), &
                   received(ngrad_per_point*n_local))

         CALL timeset("skala_gpw_grad_torch_pack", handle)
         CALL pack_atom_chunk_grads(features, packed, .TRUE.)
         CALL timestop(handle)

         CALL timeset("skala_gpw_grad_route_comm", handle)
         CALL group%alltoall(packed, features%route_grad_return_send_counts, &
                             features%route_grad_return_send_displs, received, &
                             features%route_grad_return_recv_counts, &
                             features%route_grad_return_recv_displs)
         CALL timestop(handle)

         CALL timeset("skala_gpw_grad_route_scatter", handle)
         ALLOCATE (density_grad(n_local, 2), grad_grad(n_local, 3, 2), &
                   kin_grad(n_local, 2))
         density_grad = 0.0_dp
         grad_grad = 0.0_dp
         kin_grad = 0.0_dp
         DO pos = 1, n_local
            dst_row = features%route_send_local_rows(pos)
            cpassert(dst_row >= 1 .AND. dst_row <= n_local)
            off = ngrad_per_point*(pos - 1)
            density_grad(dst_row, :) = received(off + 1:off + 2)
            grad_grad(dst_row, :, :) = RESHAPE(received(off + 3:off + 8), [3, 2])
            kin_grad(dst_row, :) = received(off + 9:off + 10)
         END DO
         CALL timestop(handle)

         DEALLOCATE (received, packed)
      END IF

   END SUBROUTINE fetch_and_gather_atom_chunk_grads
592
! **************************************************************************************************
!> \brief Fill CP2K VXC real-space arrays from Torch feature gradients.
!> \details All feature gradients are scaled by 1/dvol (presumably removing the quadrature weight
!>          folded into the features -- confirm against skala_gpw_feature_build). With one spin grid
!>          the two spin channels are averaged (0.5*(a+b)). The gradient-feature terms are negated
!>          and passed to xc_pw_divergence together with vxc_rho(ispin).
!> \param vxc_rho allocated here with SIZE(rho_r) grids from pw_pool and filled
!> \param vxc_tau allocated here with SIZE(rho_r) grids from pw_pool and filled
!> \param rho_r density grids; only used for the spin count and the local grid layout
!> \param pw_pool pool for output and temporary grids
!> \param density_grad (local point, spin channel) gradient w.r.t. the density feature
!> \param grad_grad (local point, xyz, spin channel) gradient w.r.t. the density-gradient feature
!> \param kin_grad (local point, spin channel) gradient w.r.t. the kinetic-energy feature
!> \param xc_deriv_method_id XC_GRID%XC_DERIV method used for the divergence
! **************************************************************************************************
   SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                           density_grad, grad_grad, kin_grad, &
                                           xc_deriv_method_id)
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau, rho_r
      TYPE(pw_pool_type), POINTER                        :: pw_pool
      REAL(kind=dp), DIMENSION(:, :), INTENT(IN)         :: density_grad
      REAL(kind=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
      REAL(kind=dp), DIMENSION(:, :), INTENT(IN)         :: kin_grad
      INTEGER, INTENT(IN)                                :: xc_deriv_method_id

      INTEGER                                            :: i, idir, ipt, ispin, j, k, npoints, &
                                                            nspins
      INTEGER, DIMENSION(2, 3)                           :: bo
      REAL(kind=dp)                                      :: dvol_inv
      TYPE(pw_c1d_gs_type)                               :: tmp_g, vxc_g
      TYPE(pw_r3d_rs_type), DIMENSION(3)                 :: grad_pw

      nspins = SIZE(rho_r)
      bo = rho_r(1)%pw_grid%bounds_local
      dvol_inv = 1.0_dp/rho_r(1)%pw_grid%dvol

      ! The unpack loop reads one row per local grid point and spin columns 1:2 for both the
      ! closed- and open-shell cases; verify the shapes instead of silently reading out of bounds.
      npoints = PRODUCT(MAX(bo(2, :) - bo(1, :) + 1, 0))
      cpassert(nspins == 1 .OR. nspins == 2)
      cpassert(SIZE(density_grad, 1) >= npoints .AND. SIZE(density_grad, 2) >= 2)
      cpassert(SIZE(grad_grad, 1) >= npoints .AND. SIZE(grad_grad, 2) >= 3 .AND. SIZE(grad_grad, 3) >= 2)
      cpassert(SIZE(kin_grad, 1) >= npoints .AND. SIZE(kin_grad, 2) >= 2)

      ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
      DO ispin = 1, nspins
         CALL pw_pool%create_pw(vxc_rho(ispin))
         CALL pw_pool%create_pw(vxc_tau(ispin))
         CALL pw_zero(vxc_rho(ispin))
         CALL pw_zero(vxc_tau(ispin))
      END DO

      ! Reciprocal-space scratch grids are only needed for some derivative methods / grids;
      ! unused ones stay unassociated and are skipped on give-back below.
      IF (xc_requires_tmp_g(xc_deriv_method_id) .OR. rho_r(1)%pw_grid%spherical) THEN
         CALL pw_pool%create_pw(vxc_g)
         IF (.NOT. rho_r(1)%pw_grid%spherical) CALL pw_pool%create_pw(tmp_g)
      END IF

      DO ispin = 1, nspins
         DO idir = 1, 3
            CALL pw_pool%create_pw(grad_pw(idir))
            CALL pw_zero(grad_pw(idir))
         END DO

         ! Local points are visited in the same order as the flattened gradient rows.
         ipt = 0
         DO k = bo(1, 3), bo(2, 3)
            DO j = bo(1, 2), bo(2, 2)
               DO i = bo(1, 1), bo(2, 1)
                  ipt = ipt + 1
                  IF (nspins == 1) THEN
                     vxc_rho(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                 (density_grad(ipt, 1) + density_grad(ipt, 2))
                     vxc_tau(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                 (kin_grad(ipt, 1) + kin_grad(ipt, 2))
                     grad_pw(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                 (grad_grad(ipt, 1, 1) + grad_grad(ipt, 1, 2))
                     grad_pw(2)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                 (grad_grad(ipt, 2, 1) + grad_grad(ipt, 2, 2))
                     grad_pw(3)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                 (grad_grad(ipt, 3, 1) + grad_grad(ipt, 3, 2))
                  ELSE
                     vxc_rho(ispin)%array(i, j, k) = dvol_inv*density_grad(ipt, ispin)
                     vxc_tau(ispin)%array(i, j, k) = dvol_inv*kin_grad(ipt, ispin)
                     grad_pw(1)%array(i, j, k) = dvol_inv*grad_grad(ipt, 1, ispin)
                     grad_pw(2)%array(i, j, k) = dvol_inv*grad_grad(ipt, 2, ispin)
                     grad_pw(3)%array(i, j, k) = dvol_inv*grad_grad(ipt, 3, ispin)
                  END IF
               END DO
            END DO
         END DO

         ! Negate the gradient-feature terms before taking the divergence into vxc_rho(ispin).
         DO idir = 1, 3
            CALL pw_scale(grad_pw(idir), -1.0_dp)
         END DO
         CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))

         DO idir = 1, 3
            CALL pw_pool%give_back_pw(grad_pw(idir))
         END DO
      END DO

      IF (ASSOCIATED(vxc_g%pw_grid)) CALL pw_pool%give_back_pw(vxc_g)
      IF (ASSOCIATED(tmp_g%pw_grid)) CALL pw_pool%give_back_pw(tmp_g)

   END SUBROUTINE build_vxc_from_feature_grads
684
! **************************************************************************************************
!> \brief Print optional diagnostics for the CP2K-native SKALA GPW feature block.
!> \details Writes the feature electron count, spin moment and grid-weight sum and, when atom
!>          chunks are in use, the number of chunk rows out of the total flattened rows.
!> \param features feature block whose summary fields are printed
!> \param print_active nothing is written unless this is .TRUE.
! **************************************************************************************************
   SUBROUTINE print_native_grid_diagnostics(features, print_active)
      TYPE(skala_gpw_feature_type), INTENT(IN) :: features
      LOGICAL, INTENT(IN) :: print_active

      INTEGER :: iw

      IF (.NOT. print_active) RETURN

      ! NOTE(review): iw is never assigned in this listing -- the statement that should set the
      ! output unit between the RETURN above and the first WRITE appears to be missing, so as
      ! shown the WRITEs use an undefined unit. Confirm against the upstream file and restore it.
      WRITE (unit=iw, fmt="(/,T2,A,1X,ES19.11)") &
         "SKALA_GPW| Native grid feature electrons", features%electron_count
      WRITE (unit=iw, fmt="(T2,A,1X,ES19.11)") &
         "SKALA_GPW| Native grid feature spin moment", features%spin_moment
      WRITE (unit=iw, fmt="(T2,A,1X,ES19.11)") &
         "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum
      ! Only meaningful when the features were built per atom chunk.
      IF (features%uses_atom_chunks) THEN
         WRITE (unit=iw, fmt="(T2,A,1X,I0,1X,A,1X,I0)") &
            "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
            "of", features%nflat
      END IF

   END SUBROUTINE print_native_grid_diagnostics
712
! **************************************************************************************************
!> \brief Configure CUDA device selection for the native SKALA GPW Torch path.
!> \details Returns -1 immediately when use_cuda is .FALSE. (no communication). Otherwise the
!>          device count is 0 if Torch reports no CUDA, else the CP2K offload device count. A
!>          negative requested_device selects round-robin by rank (mepos mod count); a non-negative
!>          one is used as given and aborts if it is outside [0, count). With no visible device the
!>          result stays -1 (CPU fallback), even for an explicit request. The per-rank choices are
!>          all-gathered over group (collective whenever use_cuda is .TRUE.), and rank 0 logs the
!>          device map only when it differs from the last logged configuration (module SAVE state).
!> \param use_cuda value of GAUXC%NATIVE_GRID_USE_CUDA
!> \param requested_device value of GAUXC%NATIVE_GRID_CUDA_DEVICE; negative means automatic
!> \param group grid communicator
!> \return selected CUDA device, or -1 for CPU fallback/no visible CUDA device
! **************************************************************************************************
   FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group) RESULT(selected_device)
      LOGICAL, INTENT(IN) :: use_cuda
      INTEGER, INTENT(IN) :: requested_device

      CLASS(mp_comm_type), INTENT(IN) :: group

      INTEGER :: cuda_device_count, iw, pe, selected_device
      INTEGER, ALLOCATABLE, DIMENSION(:) :: selected_devices

      selected_device = -1

      IF (.NOT. use_cuda) RETURN

      IF (.NOT. torch_cuda_is_available()) THEN
         cuda_device_count = 0
      ELSE
         cuda_device_count = offload_get_device_count()
      END IF
      IF (cuda_device_count > 0) THEN
         IF (requested_device < 0) THEN
            ! Automatic selection: spread ranks over the visible devices.
            selected_device = mod(group%mepos, cuda_device_count)
         ELSE
            selected_device = requested_device
         END IF
      END IF
      IF (selected_device >= cuda_device_count) THEN
         CALL cp_abort(__location__, &
                       "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
                       "CP2K offload device range.")
      END IF
      IF (selected_device >= 0) CALL offload_set_chosen_device(selected_device)

      ! Collective: every rank of group must reach this point when use_cuda is .TRUE.
      ALLOCATE (selected_devices(group%num_pe))
      CALL group%allgather(selected_device, selected_devices)

      ! Only rank 0 logs, and only when the configuration changed since the last log.
      IF (group%mepos /= 0) RETURN
      IF (selected_device == logged_cuda_device .AND. &
          cuda_device_count == logged_cuda_device_count .AND. &
          group%num_pe == logged_cuda_nproc .AND. &
          requested_device == logged_cuda_request) RETURN

      ! NOTE(review): iw is never assigned in this listing -- the statement that should set the
      ! output unit here appears to be missing, so as shown the WRITEs below use an undefined unit.
      ! Confirm against the upstream file and restore it.
      IF (selected_device >= 0) THEN
         WRITE (unit=iw, fmt="(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
            "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
            "of", cuda_device_count, "requested", requested_device
      ELSE
         WRITE (unit=iw, fmt="(/,T2,A)") &
            "SKALA_GPW| Native grid Torch CUDA requested, but no CP2K offload device is visible"
      END IF
      WRITE (unit=iw, fmt="(T2,A)", advance="NO") &
         "SKALA_GPW| Native grid Torch CUDA rank devices"
      DO pe = 1, group%num_pe
         WRITE (unit=iw, fmt="(1X,I0,A,I0)", advance="NO") pe - 1, ":", selected_devices(pe)
      END DO
      WRITE (unit=iw, fmt=*)

      ! Remember what was logged so identical configurations are not printed again.
      logged_cuda_device = selected_device
      logged_cuda_device_count = cuda_device_count
      logged_cuda_nproc = group%num_pe
      logged_cuda_request = requested_device

   END FUNCTION configure_native_grid_cuda
783
! **************************************************************************************************
!> \brief Load the TorchScript SKALA model into the module-level cache, reusing it when possible.
!> \details The cached model is reused only if both the path and the CUDA device match the
!>          previous load; otherwise any cached model is released and the model is reloaded.
!> \param model_path TorchScript model file
!> \param cuda_device device the model is associated with in the cache key (-1 for CPU)
! **************************************************************************************************
   SUBROUTINE ensure_model_loaded(model_path, cuda_device)
      CHARACTER(len=*), INTENT(IN)                       :: model_path
      INTEGER, INTENT(IN)                                :: cuda_device

      LOGICAL                                            :: cache_hit

      ! Fortran blank-pads the shorter operand in character comparisons, so no TRIM is required.
      cache_hit = .FALSE.
      IF (cached_model_loaded) &
         cache_hit = (cached_model_path == model_path) .AND. (cached_model_cuda_device == cuda_device)
      IF (cache_hit) RETURN

      IF (cached_model_loaded) THEN
         CALL skala_torch_model_release(cached_model)
         cached_model_loaded = .FALSE.
      END IF

      CALL skala_torch_model_load(cached_model, trim(model_path))
      cached_model_path = model_path
      cached_model_cuda_device = cuda_device
      cached_model_loaded = .TRUE.

   END SUBROUTINE ensure_model_loaded
806
! **************************************************************************************************
!> \brief Resolve the SKALA TorchScript model path from the GAUXC subsection.
!> \param xc_section XC input section expected to contain an XC_FUNCTIONAL%GAUXC subsection
!> \param model_path resolved model path:
!>        - when GAUXC%MODEL is the keyword SKALA (case-insensitive, leading blanks ignored),
!>          the value of the GAUXC_SKALA_MODEL environment variable;
!>        - otherwise the MODEL value as read from the input.
!> \note Aborts when:
!>       - the GAUXC subsection is missing;
!>       - MODEL is NONE or blank;
!>       - MODEL SKALA is requested and GAUXC_SKALA_MODEL is unset, empty, or too long to fit
!>         into a default_path_length buffer.
! **************************************************************************************************
   SUBROUTINE get_skala_model_path(xc_section, model_path)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      CHARACTER(len=default_path_length), INTENT(OUT)    :: model_path

      CHARACTER(len=default_path_length)                 :: model_key
      INTEGER                                            :: env_status
      TYPE(section_vals_type), POINTER                   :: gauxc_section

      gauxc_section => get_gauxc_section(xc_section)
      IF (.NOT. ASSOCIATED(gauxc_section)) THEN
         cpabort("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
      END IF

      CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_path)

      ! Keyword matching is done on a normalized copy so the path itself is not modified.
      model_key = adjustl(model_path)
      CALL uppercase(model_key)

      IF (trim(model_key) == "NONE" .OR. trim(model_key) == "") THEN
         cpabort("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
      ELSE IF (trim(model_key) == "SKALA") THEN
         CALL get_environment_variable("GAUXC_SKALA_MODEL", model_path, status=env_status)
         ! STATUS == -1: the variable exists but its value was truncated to fit model_path.
         ! Report that explicitly instead of claiming the variable is missing.
         IF (env_status == -1) THEN
            cpabort("GAUXC_SKALA_MODEL is longer than the supported model path length")
         ELSE IF (env_status /= 0 .OR. len_trim(model_path) == 0) THEN
            cpabort("MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
         END IF
      END IF

   END SUBROUTINE get_skala_model_path
838
! **************************************************************************************************
!> \brief Look up the first GAUXC subsection among the non-default XC_FUNCTIONAL subsections.
!> \param xc_section XC input section; may be unassociated
!> \return pointer to the GAUXC subsection values. It is null in any of these cases:
!>         - xc_section is not associated;
!>         - XC_FUNCTIONAL could not be obtained;
!>         - no non-default subsection is named GAUXC.
! **************************************************************************************************
   FUNCTION get_gauxc_section(xc_section) RESULT(gauxc_section)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      TYPE(section_vals_type), POINTER                   :: gauxc_section

      INTEGER                                            :: isub
      TYPE(section_vals_type), POINTER                   :: candidate, functionals

      NULLIFY (gauxc_section)
      IF (.NOT. ASSOCIATED(xc_section)) RETURN

      functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
      IF (.NOT. ASSOCIATED(functionals)) RETURN

      ! section_vals_get_subs_vals2 yields the isub-th non-default subsection, or null past the end.
      isub = 1
      candidate => section_vals_get_subs_vals2(functionals, i_section=isub)
      DO WHILE (ASSOCIATED(candidate))
         IF (candidate%section%name == "GAUXC") THEN
            gauxc_section => candidate
            RETURN
         END IF
         isub = isub + 1
         candidate => section_vals_get_subs_vals2(functionals, i_section=isub)
      END DO

   END FUNCTION get_gauxc_section
869
870END MODULE skala_gpw_functional
Handles all functions related to the CELL.
Definition cell_types.F:15
various routines to log and control the output. The idea is that decisions about where to log should ...
integer function, public cp_logger_get_default_io_unit(logger)
returns the unit nr for the ionode (-1 on all other processors) skips as well checks if the procs cal...
objects that represent the structure of input sections and the data contained in an input section
real(kind=dp) function, public section_get_rval(section_vals, keyword_name)
...
type(section_vals_type) function, pointer, public section_vals_get_subs_vals2(section_vals, i_section, i_rep_section)
returns the values of the n-th non default subsection (null if no such section exists (not so many no...
recursive type(section_vals_type) function, pointer, public section_vals_get_subs_vals(section_vals, subsection_name, i_rep_section, can_return_null)
returns the values of the requested subsection
subroutine, public section_vals_val_get(section_vals, keyword_name, i_rep_section, i_rep_val, n_rep_val, val, l_val, i_val, r_val, c_val, l_vals, i_vals, r_vals, c_vals, explicit)
returns the requested value
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_string_length
Definition kinds.F:57
integer, parameter, public default_path_length
Definition kinds.F:58
Interface to the message passing library MPI.
Fortran API for the offload package, which is written in C.
Definition offload_api.F:12
subroutine, public offload_set_chosen_device(device_id)
Selects the chosen device to be used.
integer function, public offload_get_device_count()
Returns the number of available devices.
Define the data structure for the particle information.
Manages a pool of grids (to be used for example as tmp objects), but can also be used to instantiate ...
Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
subroutine, public skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, requires_grad, weights, requires_coordinate_grad, use_atom_chunks, route_atom_chunks)
Build a flat SKALA molecular feature dictionary from a local GPW grid.
subroutine, public skala_gpw_feature_release(features)
Release Torch objects and backing arrays owned by a feature bundle.
Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
subroutine, public ensure_native_skala_grid_scope(xc_section)
Enforce the currently implemented native SKALA GPW input scope.
subroutine, public skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, weights, pw_pool, particle_set, cell, compute_virial, virial_xc, just_energy, atom_force)
Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
logical function, public xc_section_uses_native_skala_grid(xc_section)
Return true if the GAUXC subsection requests the CP2K-native GPW grid path.
Small CP2K wrapper around the SKALA TorchScript functional protocol.
subroutine, public skala_torch_model_release(model)
Release a loaded SKALA TorchScript model.
subroutine, public skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
Evaluate the weighted SKALA exchange-correlation energy.
subroutine, public skala_torch_model_load(model, filename)
Load a SKALA TorchScript model and its feature metadata.
Utilities for string manipulations.
elemental subroutine, public uppercase(string)
Convert all lower case characters in a string to upper case.
subroutine, public torch_use_cuda(use_cuda)
Select whether Torch wrappers should use CUDA when available.
Definition torch_api.F:1473
subroutine, public torch_tensor_backward_scalar(tensor)
Runs autograd on a scalar Torch tensor.
Definition torch_api.F:1422
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
Definition torch_api.F:1494
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
Definition torch_api.F:1966
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:1580
contains the structure
contains the structure
subroutine, public xc_rho_set_create(rho_set, local_bounds, rho_cutoff, drho_cutoff, tau_cutoff)
allocates and does (minimal) initialization of a rho_set
subroutine, public xc_rho_set_release(rho_set, pw_pool)
releases the given rho_set
subroutine, public xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, xc_deriv_method_id, xc_rho_smooth_id, pw_pool, spinflip)
updates the given rho set with the density given by rho_r (and rho_g). The rho set will contain the c...
contains utility functions for the xc package
Definition xc_util.F:14
subroutine, public xc_pw_divergence(xc_deriv_method_id, pw_to_deriv, tmp_g, vxc_g, vxc_r)
Calculates the divergence of pw_to_deriv.
Definition xc_util.F:253
elemental logical function, public xc_requires_tmp_g(xc_deriv_id)
...
Definition xc_util.F:58
Type defining parameters related to the simulation cell.
Definition cell_types.F:60
Manages a pool of grids (to be used for example as tmp objects), but can also be used to instantiate ...
contains a flag for each component of xc_rho_set, so that you can use it to tell which components you...
represent a density, with all the representation and data needed to perform a functional evaluation