(git:06f838d)
Loading...
Searching...
No Matches
skala_gpw_features.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
8! **************************************************************************************************
9!> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
10! **************************************************************************************************
12 USE cell_types, ONLY: cell_type,&
13 pbc
15 USE kinds, ONLY: dp,&
16 int_8
20 USE pw_types, ONLY: pw_r3d_rs_type
21 USE torch_api, ONLY: &
27#include "./base/base_uses.f90"
28
29 IMPLICIT NONE
30
31 PRIVATE
32
! Module-wide constants.
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'

   ! Absolute tolerance used when comparing cached grid, cell, atom and weight data
   ! against the current state to decide whether the layout cache can be reused.
   REAL(KIND=dp), PARAMETER, PRIVATE :: layout_tol = 1.0e-12_dp

   ! Reals packed per grid point for density-dependent data:
   ! 2 spin densities + 2x3 density-gradient components + 2 kinetic-energy densities.
   INTEGER, PARAMETER, PRIVATE :: ndynamic_per_point = 10

   ! Reals packed per grid point for geometry data: 3 coordinates + 1 integration weight.
   INTEGER, PARAMETER, PRIVATE :: nstatic_per_point = 4

   ! Scale factor applied to the chunk feature counts/displacements to obtain the
   ! chunk gradient counts/displacements (presumably one value per dynamic input - confirm).
   INTEGER, PARAMETER, PRIVATE :: ngrad_per_point = 10
37
39
! **************************************************************************************************
!> \brief Static (density-independent) layout data that is cached between feature builds:
!>        grid decomposition, atom ownership of grid points, communication counts and the
!>        Torch tensors created from those arrays.
! **************************************************************************************************
   TYPE skala_gpw_layout_cache_type
      ! --- Atom chunk assigned to this rank (set from the build_atom_chunks results) ---
      INTEGER :: chunk_atom_begin = 1
      INTEGER :: chunk_atom_end = 0
      INTEGER :: chunk_feature_begin = 1
      INTEGER :: chunk_feature_count = 0
      INTEGER :: chunk_natom = 0
      ! --- Sizes: atoms, gathered grid points, local grid points, ranks ---
      INTEGER :: natom = 0
      INTEGER :: nflat = 0
      INTEGER :: nflat_local = 0
      INTEGER :: nproc = 0
      ! --- Snapshot of the grid decomposition used to validate the cache ---
      INTEGER, DIMENSION(2, 3) :: bo = 0
      INTEGER, DIMENSION(2, 3) :: bounds = 0
      INTEGER, DIMENSION(3) :: npts = 0
      ! --- Per-rank counts/displacements for the gather of dynamic data ---
      INTEGER, ALLOCATABLE, DIMENSION(:) :: dynamic_counts, dynamic_displs
      ! --- Per-rank counts/displacements of the atom chunks and their gradients ---
      INTEGER, ALLOCATABLE, DIMENSION(:) :: chunk_feature_counts, chunk_feature_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: chunk_grad_counts, chunk_grad_displs
      ! --- Per-rank local point counts/displacements and the gathered-point -> feature-row map ---
      INTEGER, ALLOCATABLE, DIMENSION(:) :: feature_counts, feature_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: global_to_feature
      ! --- Routing tables for atom-chunk communication (filled outside rebuild_layout_cache) ---
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_dynamic_recv_counts, route_dynamic_recv_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_dynamic_send_counts, route_dynamic_send_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_grad_return_recv_counts, &
                                            route_grad_return_recv_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_grad_return_send_counts, &
                                            route_grad_return_send_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_local_dest
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_meta_recv_counts, route_meta_recv_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_meta_send_counts, route_meta_send_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_point_recv_counts, route_point_recv_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_point_send_counts, route_point_send_displs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: route_send_local_rows
      ! --- Feature row (1-based) of every local grid point, indexed by (i, j, k) ---
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
      ! --- Number of grid points owned by each atom, plus chunk-restricted variants ---
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: chunk_atomic_grid_sizes
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: chunk_feature_indices
      ! --- Zero-based feature row of every local grid point ---
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: local_feature_indices
      ! --- Zero-extent arrays whose second extent carries the largest per-atom grid size ---
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: chunk_atomic_grid_size_bound_shape
      ! --- Torch dictionaries holding the static inputs (chunk and full variants) ---
      TYPE(torch_dict_type) :: chunk_static_inputs
      TYPE(torch_dict_type) :: static_inputs
      ! --- Torch tensors; the static ones are created in build_static_layout_tensors from ---
      ! --- the arrays of the same name without the "_t" suffix ---
      TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t, atomic_grid_sizes_t, &
                                 atomic_grid_weights_t
      TYPE(torch_tensor_type) :: chunk_atomic_grid_size_bound_shape_t, &
                                 chunk_atomic_grid_sizes_t, chunk_atomic_grid_weights_t
      TYPE(torch_tensor_type) :: chunk_coarse_0_atomic_coords_t, chunk_density_t, &
                                 chunk_feature_indices_t, chunk_grad_t
      TYPE(torch_tensor_type) :: chunk_grid_coords_t, chunk_grid_weights_t, chunk_kin_t
      TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t, density_t
      TYPE(torch_tensor_type) :: grid_coords_t, grid_weights_t
      TYPE(torch_tensor_type) :: grad_t, kin_t
      TYPE(torch_tensor_type) :: local_feature_indices_t
      ! --- Real-valued snapshot of grid, weights and cell used to validate the cache ---
      REAL(KIND=dp) :: dvol = 0.0_dp
      REAL(KIND=dp) :: weight_sum = 0.0_dp
      REAL(KIND=dp) :: weight_sumsq = 0.0_dp
      REAL(KIND=dp), DIMENSION(3, 3) :: cell_hmat = 0.0_dp
      REAL(KIND=dp), DIMENSION(3, 3) :: dh = 0.0_dp
      ! --- Integration weights per feature row (full and chunk variants) ---
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, chunk_atomic_grid_weights
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: chunk_grid_weights, grid_weights
      ! --- Coordinates: raw atom positions (for cache validation), wrapped atom positions ---
      ! --- (coarse_0_atomic_coords) and grid-point coordinates per feature row ---
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: chunk_coarse_0_atomic_coords
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: chunk_grid_coords
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: coarse_0_atomic_coords
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: grid_coords
      ! --- State flags ---
      LOGICAL :: active = .FALSE.
      LOGICAL :: has_weights = .FALSE.
      LOGICAL :: chunk_dynamic_tensors_active = .FALSE.
      LOGICAL :: chunk_static_tensors_active = .FALSE.
      LOGICAL :: dynamic_tensors_active = .FALSE.
      LOGICAL :: static_tensors_active = .FALSE.
   END TYPE skala_gpw_layout_cache_type
111
113 INTEGER :: chunk_feature_count = 0, nflat = 0, &
114 nflat_local = 0
115 TYPE(torch_dict_type) :: inputs
116 TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
117 TYPE(torch_tensor_type) :: atomic_grid_sizes_t
118 TYPE(torch_tensor_type) :: atomic_grid_weights_t
119 TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
120 TYPE(torch_tensor_type) :: density_t
121 TYPE(torch_tensor_type) :: grad_t
122 TYPE(torch_tensor_type) :: grid_coords_t
123 TYPE(torch_tensor_type) :: grid_weights_t
124 TYPE(torch_tensor_type) :: kin_t
125 TYPE(torch_tensor_type) :: local_feature_indices_t
126 INTEGER, ALLOCATABLE, DIMENSION(:) :: chunk_grad_counts, chunk_grad_displs, &
127 chunk_return_positions, &
128 chunk_return_ranks, chunk_return_rows, &
129 route_grad_return_recv_counts, &
130 route_grad_return_recv_displs, &
131 route_grad_return_send_counts, &
132 route_grad_return_send_displs, &
133 route_point_recv_counts, &
134 route_point_recv_displs, &
135 route_point_send_counts, &
136 route_point_send_displs, &
137 route_send_local_rows
138 INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
139 INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes
140 INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
141 REAL(kind=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, grid_weights
142 REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: chunk_density, chunk_kin, &
143 coarse_0_atomic_coords, density, &
144 grid_coords, kin
145 REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :, :) :: chunk_grad, grad
146 REAL(kind=dp) :: electron_count = 0.0_dp, &
147 grid_weight_sum = 0.0_dp, &
148 spin_moment = 0.0_dp
149 LOGICAL :: active = .false., owns_coordinate_tensor = .false., &
150 owns_dynamic_tensors = .true., &
151 owns_static_tensors = .true., &
152 uses_atom_chunk_routing = .false., &
153 uses_atom_chunks = .false.
155
! Module-level layout cache; SAVE keeps it alive between calls. It is reused while
! layout_cache_matches reports a match and rebuilt by rebuild_layout_cache otherwise.
156 TYPE(skala_gpw_layout_cache_type), SAVE :: cached_layout
157
158CONTAINS
159
! **************************************************************************************************
!> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid.
!>        Local density, density gradient and kinetic-energy density are packed per grid point,
!>        then either routed to atom chunks or gathered from all ranks and reordered into the
!>        atom-sorted feature rows defined by the cached layout.
!> \param features feature bundle; released on entry, filled and marked active on exit
!> \param rho_set source of rho/drho/tau (one spin channel) or rhoa/rhob/drhoa/drhob/tau_a/tau_b
!> \param rho_r real-space densities; only the number of entries (1 or 2 spin channels) and the
!>        grid of the first entry are used here
!> \param particle_set atoms, forwarded to the layout cache
!> \param cell simulation cell, forwarded to the layout cache
!> \param requires_grad forwarded to add_feature_tensors (default .FALSE.)
!> \param weights optional integration weights, forwarded to the layout cache
!> \param requires_coordinate_grad forwarded to copy_cached_layout and add_feature_tensors
!>        (default .FALSE.)
!> \param use_atom_chunks request atom chunks; honoured only with more than one rank and a
!>        non-empty chunk on this rank (default .FALSE.)
!> \param route_atom_chunks when atom chunks are usable, send the dynamic data through
!>        route_atom_chunk_dynamics instead of the global gather (default .FALSE.)
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                      requires_grad, weights, requires_coordinate_grad, &
                                      use_atom_chunks, route_atom_chunks)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
      TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN)     :: rho_r
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      LOGICAL, INTENT(IN), OPTIONAL                      :: requires_grad
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL, INTENT(IN), OPTIONAL                      :: requires_coordinate_grad, &
                                                            use_atom_chunks, route_atom_chunks

      INTEGER                                            :: handle, i, idir, ipt, j, k, local_row, &
                                                            nflat, nflat_local, nspins, offset, &
                                                            row, timer
      INTEGER, DIMENSION(2, 3)                           :: bo
      LOGICAL :: can_use_atom_chunks, my_requires_coordinate_grad, my_requires_grad, &
         my_route_atom_chunks, my_use_atom_chunks
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:)           :: global_dynamic, local_dynamic
      REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: rho, rhoa, rhob, tau_a, tau_b, tau_total
      TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
      TYPE(pw_grid_type), POINTER                        :: pw_grid

      CALL timeset("skala_gpw_feature_build", handle)

      ! Resolve the optional switches; every one of them defaults to .FALSE.
      my_requires_grad = .FALSE.
      my_requires_coordinate_grad = .FALSE.
      my_use_atom_chunks = .FALSE.
      my_route_atom_chunks = .FALSE.
      IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
      IF (PRESENT(requires_coordinate_grad)) my_requires_coordinate_grad = requires_coordinate_grad
      IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
      IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks

      CPASSERT(ASSOCIATED(cell))
      CPASSERT(ASSOCIATED(particle_set))
      CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
      CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
      pw_grid => rho_r(1)%pw_grid

      nspins = SIZE(rho_r)
      bo = pw_grid%bounds_local
      nflat_local = pw_grid%ngpts_local

      ! Drop whatever the bundle held from a previous build.
      CALL timeset("skala_gpw_pre_release", timer)
      CALL skala_gpw_feature_release(features)
      CALL timestop(timer)

      CALL timeset("skala_gpw_layout_cache", timer)
      CALL ensure_layout_cache(pw_grid, particle_set, cell, weights)
      CALL timestop(timer)
      nflat = cached_layout%nflat
      can_use_atom_chunks = my_use_atom_chunks .AND. cached_layout%nproc > 1 .AND. &
                            cached_layout%chunk_feature_count > 0
      ALLOCATE (local_dynamic(ndynamic_per_point*nflat_local))
      local_dynamic = 0.0_dp

      ! Pack ndynamic_per_point reals per local grid point in (i fastest, k slowest) order:
      ! [rho_a, rho_b, grad rho_a (x,y,z), grad rho_b (x,y,z), tau_a, tau_b].
      CALL timeset("skala_gpw_pack_local", timer)
      local_row = 0
      IF (nspins == 1) THEN
         ! Spin-restricted input: both spin channels receive half of the total quantity.
         CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
         DO k = bo(1, 3), bo(2, 3)
            DO j = bo(1, 2), bo(2, 2)
               DO i = bo(1, 1), bo(2, 1)
                  offset = ndynamic_per_point*local_row
                  local_row = local_row + 1
                  local_dynamic(offset + 1:offset + 2) = 0.5_dp*rho(i, j, k)
                  DO idir = 1, 3
                     local_dynamic(offset + 2 + idir) = 0.5_dp*drho(idir)%array(i, j, k)
                     local_dynamic(offset + 5 + idir) = 0.5_dp*drho(idir)%array(i, j, k)
                  END DO
                  local_dynamic(offset + 9:offset + 10) = 0.5_dp*tau_total(i, j, k)
               END DO
            END DO
         END DO
      ELSE
         CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
                             tau_a=tau_a, tau_b=tau_b)
         DO k = bo(1, 3), bo(2, 3)
            DO j = bo(1, 2), bo(2, 2)
               DO i = bo(1, 1), bo(2, 1)
                  offset = ndynamic_per_point*local_row
                  local_row = local_row + 1
                  local_dynamic(offset + 1) = rhoa(i, j, k)
                  local_dynamic(offset + 2) = rhob(i, j, k)
                  DO idir = 1, 3
                     local_dynamic(offset + 2 + idir) = drhoa(idir)%array(i, j, k)
                     local_dynamic(offset + 5 + idir) = drhob(idir)%array(i, j, k)
                  END DO
                  local_dynamic(offset + 9) = tau_a(i, j, k)
                  local_dynamic(offset + 10) = tau_b(i, j, k)
               END DO
            END DO
         END DO
      END IF
      CALL timestop(timer)

      CALL timeset("skala_gpw_copy_layout", timer)
      CALL copy_cached_layout(features, my_requires_coordinate_grad)
      CALL timestop(timer)

      IF (can_use_atom_chunks .AND. my_route_atom_chunks) THEN
         ! Routed path: the dynamic data goes straight to the atom chunks.
         CALL timeset("skala_gpw_route_dyn", timer)
         CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group)
         features%uses_atom_chunk_routing = .TRUE.
         features%uses_atom_chunks = .TRUE.
         CALL timestop(timer)
      ELSE
         ! Gather path: every rank receives the packed data of all ranks ...
         ALLOCATE (global_dynamic(ndynamic_per_point*nflat))
         CALL timeset("skala_gpw_allgatherv", timer)
         CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
                                            cached_layout%dynamic_counts, &
                                            cached_layout%dynamic_displs)
         CALL timestop(timer)

         ! ... and scatters it into the atom-sorted feature rows.
         CALL timeset("skala_gpw_reorder_dyn", timer)
         ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
                   features%kin(nflat, 2))
         features%density = 0.0_dp
         features%grad = 0.0_dp
         features%kin = 0.0_dp

         DO ipt = 1, nflat
            row = cached_layout%global_to_feature(ipt)
            offset = ndynamic_per_point*(ipt - 1)
            features%density(row, 1) = global_dynamic(offset + 1)
            features%density(row, 2) = global_dynamic(offset + 2)
            features%grad(row, :, 1) = global_dynamic(offset + 3:offset + 5)
            features%grad(row, :, 2) = global_dynamic(offset + 6:offset + 8)
            features%kin(row, 1) = global_dynamic(offset + 9)
            features%kin(row, 2) = global_dynamic(offset + 10)
         END DO
         CALL timestop(timer)
      END IF

      ! Weighted sums of (rho_a + rho_b) and (rho_a - rho_b). With routed chunks each rank
      ! only holds its chunk, so the partial sums are reduced over the group.
      CALL timeset("skala_gpw_feature_sums", timer)
      IF (features%uses_atom_chunks) THEN
         features%electron_count = SUM((features%chunk_density(:, 1) + &
                                        features%chunk_density(:, 2))* &
                                       cached_layout%chunk_grid_weights)
         features%spin_moment = SUM((features%chunk_density(:, 1) - &
                                     features%chunk_density(:, 2))* &
                                    cached_layout%chunk_grid_weights)
         CALL pw_grid%para%group%sum(features%electron_count)
         CALL pw_grid%para%group%sum(features%spin_moment)
      ELSE
         features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
                                       features%grid_weights)
         features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
                                    features%grid_weights)
      END IF
      features%grid_weight_sum = SUM(features%grid_weights)
      CALL timestop(timer)

      CALL timeset("skala_gpw_tensor_update", timer)
      ! Gather path with usable chunks: cut this rank's chunk out of the gathered arrays.
      IF (can_use_atom_chunks .AND. .NOT. features%uses_atom_chunks) THEN
         CALL extract_atom_chunk_dynamics(features)
         features%uses_atom_chunks = .TRUE.
      END IF
      CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
                               features%uses_atom_chunks)
      CALL timestop(timer)
      features%active = .TRUE.

      IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic)
      DEALLOCATE (local_dynamic)
      CALL timestop(handle)

   END SUBROUTINE skala_gpw_feature_build
347
! **************************************************************************************************
!> \brief Make sure the module-level layout cache describes the given grid, atoms, cell and
!>        weights: it is kept when layout_cache_matches reports a match and rebuilt otherwise.
!> \param pw_grid real-space grid descriptor, forwarded to the match test and the rebuild
!> \param particle_set atoms, forwarded to the match test and the rebuild
!> \param cell simulation cell, forwarded to the match test and the rebuild
!> \param weights optional integration weights, forwarded unchanged (also when absent)
! **************************************************************************************************
   SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights

      INTEGER                                            :: timer
      LOGICAL                                            :: reusable

      ! An absent OPTIONAL dummy may be passed on to another OPTIONAL dummy, so the
      ! present/absent cases share a single code path.
      CALL timeset("skala_gpw_layout_match", timer)
      reusable = layout_cache_matches(pw_grid, particle_set, cell, weights)
      CALL timestop(timer)

      IF (.NOT. reusable) THEN
         CALL timeset("skala_gpw_layout_rebuild", timer)
         CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights)
         CALL timestop(timer)
      END IF

   END SUBROUTINE ensure_layout_cache
383
! **************************************************************************************************
!> \brief Check whether the current static layout cache can be reused.
!>        The cache is reusable only if it is active and its stored atom count, grid
!>        decomposition, grid metric, cell matrix, atom positions and weight signature agree
!>        with the arguments (real quantities within layout_tol).
!> \param pw_grid real-space grid descriptor to compare against
!> \param particle_set atoms whose number and positions are compared against the cache
!> \param cell simulation cell whose hmat is compared against the cache
!> \param weights optional integration weights, forwarded to layout_weights_match
!> \return .TRUE. if every check passes, .FALSE. otherwise
! **************************************************************************************************
   FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights) RESULT(matches)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL                                            :: matches

      INTEGER                                            :: idir
      LOGICAL                                            :: same_decomposition, same_metric

      matches = .FALSE.
      IF (.NOT. cached_layout%active) RETURN

      ! Integer bookkeeping: atom count, local point count, rank count and grid bounds.
      same_decomposition = cached_layout%natom == SIZE(particle_set) .AND. &
                           cached_layout%nflat_local == pw_grid%ngpts_local .AND. &
                           cached_layout%nproc == pw_grid%para%group%num_pe .AND. &
                           ALL(cached_layout%bo == pw_grid%bounds_local) .AND. &
                           ALL(cached_layout%bounds == pw_grid%bounds) .AND. &
                           ALL(cached_layout%npts == pw_grid%npts)
      IF (.NOT. same_decomposition) RETURN

      ! Real-valued grid metric and cell matrix. The tests are written as "no deviation
      ! exceeds the tolerance" so that the comparison semantics stay those of ">".
      same_metric = .NOT. (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) .AND. &
                    .NOT. ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol) .AND. &
                    .NOT. ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)
      IF (.NOT. same_metric) RETURN

      IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN

      ! Atom positions, compared one Cartesian direction at a time over all atoms.
      DO idir = 1, 3
         IF (ANY(ABS(cached_layout%atom_coords(idir, :) - particle_set(:)%r(idir)) > &
                 layout_tol)) RETURN
      END DO

      ! The weight signature decides; an absent weights argument is forwarded as absent.
      matches = layout_weights_match(pw_grid, weights)

   END FUNCTION layout_cache_matches
429
! **************************************************************************************************
!> \brief Check whether current optional integration weights match the cached static tensors.
!>        The comparison uses the signature returned by weights_signature (a has-weights flag,
!>        a sum and a sum of squares) rather than the individual weight values.
!> \param pw_grid grid descriptor; not used by this routine
!> \param weights optional integration weights whose signature is compared with the cache
!> \return .TRUE. if the flag agrees and both sums agree within layout_tol
! **************************************************************************************************
   FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL                                            :: matches

      LOGICAL                                            :: current_has_weights
      REAL(kind=dp)                                      :: current_sum, current_sumsq

      MARK_USED(pw_grid)

      IF (PRESENT(weights)) THEN
         CALL weights_signature(weights, current_has_weights, current_sum, current_sumsq)
      ELSE
         CALL weights_signature(has_weights=current_has_weights, weight_sum=current_sum, &
                                weight_sumsq=current_sumsq)
      END IF

      ! The sums are only inspected once the has-weights flags agree.
      matches = .FALSE.
      IF (cached_layout%has_weights .EQV. current_has_weights) THEN
         IF (.NOT. (ABS(cached_layout%weight_sum - current_sum) > layout_tol)) THEN
            matches = .NOT. (ABS(cached_layout%weight_sumsq - current_sumsq) > layout_tol)
         END IF
      END IF

   END FUNCTION layout_weights_match
460
! **************************************************************************************************
!> \brief Build the static SKALA layout cache.
!>        Every local grid point is assigned to an owning atom (nearest_atom); coordinates and
!>        integration weights are gathered from all ranks and sorted so that the points of each
!>        atom are contiguous. Atom chunks, routing tables and the static Torch tensors are then
!>        derived from this ordering and stored in the module-level cache.
!> \param pw_grid real-space grid descriptor (local bounds, dvol, dh, communicator)
!> \param particle_set atoms; must contain at least one atom
!> \param cell simulation cell used for periodic wrapping and stored for cache validation
!> \param weights optional integration weights; when present and associated they scale dvol
! **************************************************************************************************
   SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights

      INTEGER :: i, iatom, ipt, j, k, local_row, max_grid_size, natom, nflat, nflat_local, nproc, &
         owner, pe, pe_index, phase_handle, row, static_base
      INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
         chunk_atom_end, feature_counts, feature_displs, global_owner, local_owner, &
         local_to_global, static_counts, static_displs
      INTEGER, DIMENSION(2, 3)                           :: bo
      LOGICAL                                            :: has_weights, use_weights
      REAL(kind=dp)                                      :: weight_sum, weight_sumsq
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:)           :: global_static, local_static
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
      REAL(kind=dp), DIMENSION(3)                        :: grid_point, owner_coord

      CALL release_layout_cache(cached_layout)

      natom = SIZE(particle_set)
      ! Every grid point needs an owning atom; without atoms the owner lookup below would
      ! index atom_coords_pbc out of bounds and MAXVAL would act on an empty array.
      CPASSERT(natom > 0)
      bo = pw_grid%bounds_local
      nflat_local = pw_grid%ngpts_local
      nproc = pw_grid%para%group%num_pe
      pe_index = pw_grid%para%group%mepos + 1

      IF (PRESENT(weights)) THEN
         CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
      ELSE
         CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
                                weight_sumsq=weight_sumsq)
      END IF

      ! Decide once whether the weight array may be dereferenced inside the grid loop.
      use_weights = .FALSE.
      IF (PRESENT(weights)) use_weights = ASSOCIATED(weights)

      ALLOCATE (local_owner(nflat_local), local_static(nstatic_per_point*nflat_local), &
                feature_counts(nproc), feature_displs(nproc), static_counts(nproc), &
                static_displs(nproc), atom_coords_pbc(3, natom))
      ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
                                            bo(1, 2):bo(2, 2), &
                                            bo(1, 3):bo(2, 3)))
      cached_layout%feature_index = 0
      local_static = 0.0_dp
      DO iatom = 1, natom
         atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
      END DO

      ! Local pass: owner atom, coordinates and weight of every local grid point, packed as
      ! nstatic_per_point reals per point in (i fastest, k slowest) order.
      CALL timeset("skala_gpw_layout_local", phase_handle)
      local_row = 0
      DO k = bo(1, 3), bo(2, 3)
         DO j = bo(1, 2), bo(2, 2)
            DO i = bo(1, 1), bo(2, 1)
               local_row = local_row + 1
               static_base = nstatic_per_point*(local_row - 1)
               grid_point = grid_coordinate(pw_grid, [i, j, k])
               owner = nearest_atom(grid_point, atom_coords_pbc, cell)
               local_owner(local_row) = owner
               ! Temporarily the local row; replaced by the feature row further below.
               cached_layout%feature_index(i, j, k) = local_row

               owner_coord = atom_coords_pbc(:, owner)
               local_static(static_base + 1:static_base + 3) = &
                  nearest_image_coordinate(owner_coord, grid_point, cell)
               IF (use_weights) THEN
                  local_static(static_base + 4) = pw_grid%dvol*weights%array(i, j, k)
               ELSE
                  local_static(static_base + 4) = pw_grid%dvol
               END IF
            END DO
         END DO
      END DO
      CALL timestop(phase_handle)

      ! SKALA groups all grid points by atom. This ordering is static while the
      ! grid, cell, atom positions, and optional integration weights are unchanged.
      CALL timeset("skala_gpw_layout_gather", phase_handle)
      CALL pw_grid%para%group%allgather(nflat_local, feature_counts)
      feature_displs(1) = 0
      DO pe = 2, nproc
         feature_displs(pe) = feature_displs(pe - 1) + feature_counts(pe - 1)
      END DO
      static_counts(:) = nstatic_per_point*feature_counts
      static_displs(:) = nstatic_per_point*feature_displs
      nflat = SUM(feature_counts)
      ALLOCATE (global_owner(nflat), global_static(nstatic_per_point*nflat))
      CALL pw_grid%para%group%allgatherv(local_owner, global_owner, feature_counts, &
                                         feature_displs)
      CALL pw_grid%para%group%allgatherv(local_static, global_static, static_counts, &
                                         static_displs)
      CALL timestop(phase_handle)

      ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
                cached_layout%chunk_feature_displs(nproc), &
                cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
                cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
                cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
                cached_layout%route_dynamic_recv_counts(nproc), &
                cached_layout%route_dynamic_recv_displs(nproc), &
                cached_layout%route_dynamic_send_counts(nproc), &
                cached_layout%route_dynamic_send_displs(nproc), &
                cached_layout%route_grad_return_recv_counts(nproc), &
                cached_layout%route_grad_return_recv_displs(nproc), &
                cached_layout%route_grad_return_send_counts(nproc), &
                cached_layout%route_grad_return_send_displs(nproc), &
                cached_layout%route_meta_recv_counts(nproc), &
                cached_layout%route_meta_recv_displs(nproc), &
                cached_layout%route_meta_send_counts(nproc), &
                cached_layout%route_meta_send_displs(nproc), &
                cached_layout%route_point_recv_counts(nproc), &
                cached_layout%route_point_recv_displs(nproc), &
                cached_layout%route_point_send_counts(nproc), &
                cached_layout%route_point_send_displs(nproc), &
                cached_layout%global_to_feature(nflat), cached_layout%atomic_grid_sizes(natom), &
                cached_layout%local_feature_indices(nflat_local), atom_offset(natom + 1), &
                atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
                local_to_global(nflat_local))
      cached_layout%feature_counts(:) = feature_counts
      cached_layout%feature_displs(:) = feature_displs
      cached_layout%dynamic_counts(:) = ndynamic_per_point*feature_counts
      cached_layout%dynamic_displs(:) = ndynamic_per_point*feature_displs
      cached_layout%atomic_grid_sizes = 0_int_8

      CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
      ! Count the points of each atom and turn the counts into 1-based start rows.
      DO ipt = 1, nflat
         cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
            cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
      END DO
      atom_offset(1) = 1
      DO iatom = 1, natom
         atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
      END DO
      ! Next free feature row of each atom, advanced while the points are distributed.
      atom_position(:) = atom_offset(1:natom)
      max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
      CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
                             chunk_atom_begin, chunk_atom_end, &
                             cached_layout%chunk_feature_counts, &
                             cached_layout%chunk_feature_displs)
      cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
      cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
      cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
      cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
      cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
      cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
      cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
                                  cached_layout%chunk_atom_begin + 1

      ! atomic_grid_size_bound_shape has a zero first extent: it stores no values, only its
      ! second extent (the largest per-atom grid size) is meaningful.
      ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
                cached_layout%atomic_grid_weights(nflat), &
                cached_layout%coarse_0_atomic_coords(3, natom), &
                cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
                cached_layout%atom_coords(3, natom))
      cached_layout%grid_coords = 0.0_dp
      cached_layout%grid_weights = 0.0_dp
      cached_layout%atomic_grid_weights = 0.0_dp
      cached_layout%atomic_grid_size_bound_shape = 0_int_8

      DO iatom = 1, natom
         cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
         cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
      END DO

      ! Distribute the gathered points over the atom-sorted feature rows.
      DO ipt = 1, nflat
         owner = global_owner(ipt)
         row = atom_position(owner)
         atom_position(owner) = atom_position(owner) + 1
         cached_layout%global_to_feature(ipt) = row
         static_base = nstatic_per_point*(ipt - 1)
         cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
         cached_layout%grid_weights(row) = global_static(static_base + 4)
         cached_layout%atomic_grid_weights(row) = cached_layout%grid_weights(row)
      END DO

      ! The points of this rank occupy one contiguous range of the gathered order, so the
      ! local row -> feature row map is a slice of global_to_feature.
      local_to_global(:) = cached_layout%global_to_feature(feature_displs(pe_index) + 1: &
                                                           feature_displs(pe_index) + nflat_local)

      ! Replace the temporary local rows stored in feature_index by the feature rows.
      DO k = bo(1, 3), bo(2, 3)
         DO j = bo(1, 2), bo(2, 2)
            DO i = bo(1, 1), bo(2, 1)
               cached_layout%feature_index(i, j, k) = &
                  local_to_global(cached_layout%feature_index(i, j, k))
            END DO
         END DO
      END DO
      ! Zero-based variant of the same map.
      cached_layout%local_feature_indices(:) = INT(local_to_global - 1, kind=int_8)
      CALL timestop(phase_handle)

      CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
      CALL build_atom_chunk_routes(cached_layout, local_to_global, pw_grid%para%group)
      CALL build_atom_chunk_layout(cached_layout)
      CALL timestop(phase_handle)

      ! Record the state that layout_cache_matches compares against on later calls.
      cached_layout%natom = natom
      cached_layout%nflat = nflat
      cached_layout%nflat_local = nflat_local
      cached_layout%nproc = nproc
      cached_layout%bo = bo
      cached_layout%bounds = pw_grid%bounds
      cached_layout%npts = pw_grid%npts
      cached_layout%dvol = pw_grid%dvol
      cached_layout%dh = pw_grid%dh
      cached_layout%cell_hmat = cell%hmat
      cached_layout%weight_sum = weight_sum
      cached_layout%weight_sumsq = weight_sumsq
      cached_layout%has_weights = has_weights

      CALL timeset("skala_gpw_layout_tensors", phase_handle)
      CALL build_static_layout_tensors(cached_layout)
      CALL timestop(phase_handle)
      cached_layout%active = .TRUE.

      DEALLOCATE (atom_coords_pbc, atom_offset, atom_position, chunk_atom_begin, chunk_atom_end, &
                  feature_counts, feature_displs, global_owner, global_static, local_owner, &
                  local_static, local_to_global, static_counts, static_displs)

   END SUBROUTINE rebuild_layout_cache
686
! **************************************************************************************************
!> \brief Build cached Torch tensors for static SKALA inputs.
!>        Each static array of the cache is turned into a tensor and passed through
!>        torch_tensor_to_device_leaf with .FALSE. as second argument; five of the tensors are
!>        then registered in the static input dictionary. The same is repeated for the
!>        chunk-restricted arrays when this rank owns a non-empty atom chunk.
!> \param cache layout cache whose arrays are already filled; its tensors, dictionaries and
!>        "tensors active" flags are set here
! **************************************************************************************************
   SUBROUTINE build_static_layout_tensors(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      ! Tensors must not be created twice without a release in between.
      CPASSERT(.NOT. cache%static_tensors_active)

      ! --- Full-grid tensors ---
      CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
      CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
      CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
      CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
      CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)

      CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
                                   cache%atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)

      CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
      CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)

      ! The atomic coordinates and the local feature indices are not part of this dictionary.
      CALL torch_dict_create(cache%static_inputs)
      CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
      CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
      CALL torch_dict_insert(cache%static_inputs, "atomic_grid_weights", &
                             cache%atomic_grid_weights_t)
      CALL torch_dict_insert(cache%static_inputs, "atomic_grid_sizes", &
                             cache%atomic_grid_sizes_t)
      CALL torch_dict_insert(cache%static_inputs, "atomic_grid_size_bound_shape", &
                             cache%atomic_grid_size_bound_shape_t)
      cache%static_tensors_active = .TRUE.

      ! Ranks without grid points in their atom chunk have no chunk tensors to build.
      IF (cache%chunk_feature_count <= 0) RETURN

      ! --- Chunk-restricted tensors ---
      CPASSERT(.NOT. cache%chunk_static_tensors_active)

      CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
      CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
      CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
                                   cache%chunk_atomic_grid_weights)
      CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
                                   cache%chunk_atomic_grid_sizes)
      CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
                                   cache%chunk_coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
                                   cache%chunk_atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
      CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)

      CALL torch_dict_create(cache%chunk_static_inputs)
      CALL torch_dict_insert(cache%chunk_static_inputs, "grid_coords", &
                             cache%chunk_grid_coords_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, "grid_weights", &
                             cache%chunk_grid_weights_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_weights", &
                             cache%chunk_atomic_grid_weights_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_sizes", &
                             cache%chunk_atomic_grid_sizes_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_size_bound_shape", &
                             cache%chunk_atomic_grid_size_bound_shape_t)
      cache%chunk_static_tensors_active = .TRUE.

   END SUBROUTINE build_static_layout_tensors
759
! **************************************************************************************************
!> \brief Copy static cached layout arrays into a feature bundle.
!> \param features bundle receiving private copies of the cached index, weight and routing arrays
!> \param needs_coordinate_array if true, the cached coarse_0_atomic_coords array is copied as well
!> \note Reads the module-level cached_layout, which must be active. Every target array is
!>       allocated here, so the bundle is expected to arrive with these components unallocated.
! **************************************************************************************************
   SUBROUTINE copy_cached_layout(features, needs_coordinate_array)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: needs_coordinate_array

      INTEGER                                            :: nrank
      INTEGER, DIMENSION(3)                              :: hi, lo

      CPASSERT(cached_layout%active)

      ! Scalar bookkeeping
      features%nflat = cached_layout%nflat
      features%nflat_local = cached_layout%nflat_local
      features%chunk_feature_count = cached_layout%chunk_feature_count

      ! The grid index map keeps the (possibly non-unit) lower bounds of the cached array
      lo = LBOUND(cached_layout%feature_index)
      hi = UBOUND(cached_layout%feature_index)
      ALLOCATE (features%feature_index(lo(1):hi(1), lo(2):hi(2), lo(3):hi(3)))
      features%feature_index(:, :, :) = cached_layout%feature_index

      ALLOCATE (features%grid_weights(cached_layout%nflat))
      features%grid_weights(:) = cached_layout%grid_weights

      ! Per-rank count/displacement tables, one entry per process
      nrank = cached_layout%nproc

      ALLOCATE (features%chunk_grad_counts(nrank), features%chunk_grad_displs(nrank))
      features%chunk_grad_counts(:) = cached_layout%chunk_grad_counts
      features%chunk_grad_displs(:) = cached_layout%chunk_grad_displs

      ALLOCATE (features%route_grad_return_recv_counts(nrank), &
                features%route_grad_return_recv_displs(nrank))
      features%route_grad_return_recv_counts(:) = cached_layout%route_grad_return_recv_counts
      features%route_grad_return_recv_displs(:) = cached_layout%route_grad_return_recv_displs

      ALLOCATE (features%route_grad_return_send_counts(nrank), &
                features%route_grad_return_send_displs(nrank))
      features%route_grad_return_send_counts(:) = cached_layout%route_grad_return_send_counts
      features%route_grad_return_send_displs(:) = cached_layout%route_grad_return_send_displs

      ALLOCATE (features%route_point_recv_counts(nrank), features%route_point_recv_displs(nrank))
      features%route_point_recv_counts(:) = cached_layout%route_point_recv_counts
      features%route_point_recv_displs(:) = cached_layout%route_point_recv_displs

      ALLOCATE (features%route_point_send_counts(nrank), features%route_point_send_displs(nrank))
      features%route_point_send_counts(:) = cached_layout%route_point_send_counts
      features%route_point_send_displs(:) = cached_layout%route_point_send_displs

      ! Send-order permutation, one entry per local grid row
      ALLOCATE (features%route_send_local_rows(cached_layout%nflat_local))
      features%route_send_local_rows(:) = cached_layout%route_send_local_rows

      IF (needs_coordinate_array) THEN
         ALLOCATE (features%coarse_0_atomic_coords(3, cached_layout%natom))
         features%coarse_0_atomic_coords(:, :) = cached_layout%coarse_0_atomic_coords
      END IF

   END SUBROUTINE copy_cached_layout
812
! **************************************************************************************************
!> \brief Split the atom-ordered feature rows into contiguous atom chunks.
!> \param atomic_grid_sizes number of feature rows of each atom (one entry per atom)
!> \param atom_offset row offsets with natom + 1 entries; atom_offset(i + 1) - atom_offset(i) is
!>                    the row count of atom i and atom_offset(natom + 1) - 1 is the total row count
!> \param nproc number of chunks (ranks) to fill
!> \param chunk_atom_begin first atom of each chunk; natom + 1 for a chunk without atoms
!> \param chunk_atom_end last atom of each chunk; natom for a chunk without atoms
!> \param chunk_feature_counts number of feature rows in each chunk
!> \param chunk_feature_displs zero-based running row offset of each chunk
!> \note Atoms are assigned greedily in order. Each chunk aims at the average number of rows still
!>       to be distributed over the remaining ranks, while leaving at least one atom for every
!>       rank that follows. With at least as many ranks as atoms left, each rank gets one atom.
! **************************************************************************************************
   SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
                                chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
      INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
      INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_offset
      INTEGER, INTENT(IN)                                :: nproc
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: chunk_atom_begin, chunk_atom_end, &
                                                            chunk_feature_counts, &
                                                            chunk_feature_displs

      INTEGER                                            :: atoms_left, displ, end_atom, &
                                                            max_end_atom, natom, next_atom, &
                                                            nrows, nrows_next, pe, ranks_left, &
                                                            target_rows, total_left

      natom = SIZE(atomic_grid_sizes)
      ! atom_offset(natom + 1) is read below, so the offset table needs one entry past the atoms
      CPASSERT(SIZE(atom_offset) > natom)

      ! Defaults describe an empty chunk: empty atom range (begin = end + 1) and no rows
      chunk_atom_begin = natom + 1
      chunk_atom_end = natom
      chunk_feature_counts = 0
      chunk_feature_displs = 0

      displ = 0
      next_atom = 1
      DO pe = 1, nproc
         chunk_feature_displs(pe) = displ
         IF (next_atom > natom) CYCLE

         ranks_left = nproc - pe + 1
         atoms_left = natom - next_atom + 1
         end_atom = next_atom
         IF (ranks_left < atoms_left) THEN
            ! Last atom this chunk may take while still leaving one atom per following rank
            max_end_atom = natom - ranks_left + 1
            total_left = atom_offset(natom + 1) - atom_offset(next_atom)
            target_rows = MAX(1, NINT(REAL(total_left, KIND=dp)/REAL(ranks_left, KIND=dp)))
            nrows = INT(atomic_grid_sizes(end_atom))
            DO WHILE (end_atom < max_end_atom)
               nrows_next = nrows + INT(atomic_grid_sizes(end_atom + 1))
               ! Stop once the target is reached and taking one more atom would not bring the
               ! chunk closer to it. In every other case (still below the target, or the next
               ! atom lands strictly closer) the atom is absorbed.
               IF (nrows >= target_rows .AND. &
                   ABS(nrows - target_rows) <= ABS(nrows_next - target_rows)) EXIT
               end_atom = end_atom + 1
               nrows = nrows_next
            END DO
         END IF

         chunk_atom_begin(pe) = next_atom
         chunk_atom_end(pe) = end_atom
         chunk_feature_counts(pe) = atom_offset(end_atom + 1) - atom_offset(next_atom)
         displ = displ + chunk_feature_counts(pe)
         next_atom = end_atom + 1
      END DO

      ! Every feature row must have been assigned to exactly one chunk
      CPASSERT(displ == atom_offset(natom + 1) - 1)

   END SUBROUTINE build_atom_chunks
881
! **************************************************************************************************
!> \brief Return the MPI rank owning an atom-ordered feature row.
!> \param row one-based feature row index
!> \param counts number of rows held by each rank
!> \param displs zero-based row offset of each rank
!> \return one-based position of the first rank whose range displs + 1 ... displs + counts contains
!>         the row, or 0 if no rank holds it
! **************************************************************************************************
   FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
      INTEGER, INTENT(IN)                                :: row
      INTEGER, DIMENSION(:), INTENT(IN)                  :: counts, displs
      INTEGER                                            :: owner

      INTEGER                                            :: first_before, irank, last_row

      owner = 0
      DO irank = 1, SIZE(counts)
         first_before = displs(irank)
         last_row = first_before + counts(irank)
         ! Skip ranks whose half-open range (first_before, last_row] does not hold the row
         IF (row <= first_before) CYCLE
         IF (row > last_row) CYCLE
         owner = irank
         EXIT
      END DO

   END FUNCTION feature_row_chunk_owner
905
! **************************************************************************************************
!> \brief Build zero-based displacement arrays from per-rank counts.
!> \param counts number of items per rank
!> \param displs exclusive prefix sum of counts: displs(1) = 0 and
!>               displs(i) = counts(1) + ... + counts(i - 1); only the first SIZE(counts)
!>               entries are written
!> \note An empty counts array leaves displs untouched instead of writing displs(1).
! **************************************************************************************************
   SUBROUTINE counts_to_displs(counts, displs)
      INTEGER, DIMENSION(:), INTENT(IN)                  :: counts
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: displs

      INTEGER                                            :: pe

      ! Nothing to fill for an empty count list; displs(1) would be out of bounds
      IF (SIZE(counts) < 1) RETURN

      displs(1) = 0
      DO pe = 2, SIZE(counts)
         displs(pe) = displs(pe - 1) + counts(pe - 1)
      END DO

   END SUBROUTINE counts_to_displs
923
! **************************************************************************************************
!> \brief Precompute all-to-all routing between local grid rows and atom chunks.
!> \param cache layout cache; chunk_feature_counts/displs are read, route_local_dest and
!>              route_send_local_rows are allocated here and all route_* tables are filled
!> \param local_to_global atom-ordered feature row index of every local grid row
!> \param group communicator used to exchange the per-rank point counts
!> \note The route_*_counts/displs arrays of the cache must already be allocated by the caller.
! **************************************************************************************************
   SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
      INTEGER, DIMENSION(:), INTENT(IN)                  :: local_to_global

      CLASS(mp_comm_type), INTENT(IN)                    :: group

      INTEGER                                            :: irow, nlocal, owner_rank, slot
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: next_slot

      nlocal = SIZE(local_to_global)
      ALLOCATE (cache%route_local_dest(nlocal))
      ALLOCATE (cache%route_send_local_rows(nlocal))
      ALLOCATE (next_slot(SIZE(cache%route_point_send_counts)))
      cache%route_send_local_rows = 0
      cache%route_point_send_counts = 0

      ! Pass 1: find the chunk owner of every local row and count rows per destination rank
      DO irow = 1, nlocal
         owner_rank = feature_row_chunk_owner(local_to_global(irow), &
                                              cache%chunk_feature_counts, &
                                              cache%chunk_feature_displs)
         CPASSERT(owner_rank > 0)
         cache%route_local_dest(irow) = owner_rank
         cache%route_point_send_counts(owner_rank) = cache%route_point_send_counts(owner_rank) + 1
      END DO
      CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)

      ! Pass 2: give each row the next free send-buffer slot of its destination rank, so that
      ! rows for the same rank are contiguous and keep their local order
      next_slot(:) = cache%route_point_send_displs + 1
      DO irow = 1, nlocal
         owner_rank = cache%route_local_dest(irow)
         slot = next_slot(owner_rank)
         cache%route_send_local_rows(slot) = irow
         next_slot(owner_rank) = slot + 1
      END DO

      ! Learn how many points every other rank will send to this rank
      CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
      CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)

      ! Forward direction: point tables scaled by the number of values sent per point
      cache%route_meta_send_counts(:) = 2*cache%route_point_send_counts
      cache%route_meta_recv_counts(:) = 2*cache%route_point_recv_counts
      cache%route_meta_send_displs(:) = 2*cache%route_point_send_displs
      cache%route_meta_recv_displs(:) = 2*cache%route_point_recv_displs
      cache%route_dynamic_send_counts(:) = ndynamic_per_point*cache%route_point_send_counts
      cache%route_dynamic_recv_counts(:) = ndynamic_per_point*cache%route_point_recv_counts
      cache%route_dynamic_send_displs(:) = ndynamic_per_point*cache%route_point_send_displs
      cache%route_dynamic_recv_displs(:) = ndynamic_per_point*cache%route_point_recv_displs

      ! Return direction: send and receive roles are swapped with respect to the point tables
      cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
      cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
      cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
      cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs

      ! Every local row is sent once, this rank receives its whole chunk, all slots are filled
      CPASSERT(SUM(cache%route_point_send_counts) == nlocal)
      CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)
      CPASSERT(ALL(cache%route_send_local_rows > 0))

      DEALLOCATE (next_slot)

   END SUBROUTINE build_atom_chunk_routes
983
! **************************************************************************************************
!> \brief Materialize the current rank's atom chunk static layout.
!> \param cache layout cache; the chunk_* arrays are allocated and filled from the corresponding
!>              full arrays using the chunk row range and the chunk atom range stored in the cache
!> \note Returns without allocating anything if the chunk has no feature rows or no atoms.
! **************************************************************************************************
   SUBROUTINE build_atom_chunk_layout(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      INTEGER                                            :: atom_first, atom_last, irow, nrow, &
                                                            row_first, row_last, widest_grid

      nrow = cache%chunk_feature_count
      IF (nrow <= 0) RETURN
      IF (cache%chunk_natom <= 0) RETURN

      row_first = cache%chunk_feature_begin
      row_last = row_first + nrow - 1
      atom_first = cache%chunk_atom_begin
      atom_last = cache%chunk_atom_end

      ! Row-wise quantities: slice the chunk's row range out of the full arrays
      ALLOCATE (cache%chunk_grid_coords(3, nrow))
      cache%chunk_grid_coords(:, :) = cache%grid_coords(:, row_first:row_last)
      ALLOCATE (cache%chunk_grid_weights(nrow))
      cache%chunk_grid_weights(:) = cache%grid_weights(row_first:row_last)
      ALLOCATE (cache%chunk_atomic_grid_weights(nrow))
      cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(row_first:row_last)

      ! Atom-wise quantities: slice the chunk's atom range
      ALLOCATE (cache%chunk_atomic_grid_sizes(cache%chunk_natom))
      cache%chunk_atomic_grid_sizes(:) = cache%atomic_grid_sizes(atom_first:atom_last)
      ALLOCATE (cache%chunk_coarse_0_atomic_coords(3, cache%chunk_natom))
      cache%chunk_coarse_0_atomic_coords(:, :) = &
         cache%coarse_0_atomic_coords(:, atom_first:atom_last)

      ! Zero-extent array: it holds no data, only its second extent (the largest per-atom grid
      ! size in the chunk) carries information
      widest_grid = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
      ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, widest_grid))
      cache%chunk_atomic_grid_size_bound_shape = 0_int_8

      ! Chunk rows are numbered consecutively, zero-based
      ALLOCATE (cache%chunk_feature_indices(nrow))
      DO irow = 0, nrow - 1
         cache%chunk_feature_indices(irow + 1) = INT(irow, KIND=int_8)
      END DO

   END SUBROUTINE build_atom_chunk_layout
1019
! **************************************************************************************************
!> \brief Send local dynamic feature rows to their atom-chunk owner ranks.
!> \param features bundle whose chunk_density, chunk_grad, chunk_kin and chunk_return_* arrays are
!>                 allocated here and filled with the rows received for this rank's chunk
!> \param local_dynamic flattened dynamic features with ndynamic_per_point consecutive values per
!>                      local row, in the order density (2), gradient (3 x 2), kinetic term (2)
!> \param group communicator used for the two all-to-all exchanges
!> \note Uses the routing tables of the module-level cached_layout. The send order is taken from
!>       route_send_local_rows, the permutation that build_atom_chunk_routes derived from
!>       route_local_dest and route_point_send_displs.
! **************************************************************************************************
   SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: local_dynamic

      CLASS(mp_comm_type), INTENT(IN)                    :: group

      INTEGER                                            :: chunk_row, dyn_base, irow, local_row, &
                                                            meta_base, nchunk, nrecv, nsend, pe, &
                                                            point_pos, row, src_base
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: recv_meta, send_meta
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_dynamic, send_dynamic

      nchunk = cached_layout%chunk_feature_count
      CPASSERT(nchunk > 0)
      nsend = SIZE(cached_layout%route_local_dest)
      nrecv = SUM(cached_layout%route_point_recv_counts)
      CPASSERT(nsend == cached_layout%nflat_local)
      CPASSERT(nrecv == nchunk)
      CPASSERT(SIZE(cached_layout%route_send_local_rows) == nsend)
      ! Every local row contributes ndynamic_per_point values; reject a short input buffer
      CPASSERT(SIZE(local_dynamic) >= ndynamic_per_point*nsend)

      ALLOCATE (send_meta(2*nsend), send_dynamic(ndynamic_per_point*nsend), &
                recv_meta(2*nrecv), recv_dynamic(ndynamic_per_point*nrecv))
      send_meta = 0
      send_dynamic = 0.0_dp

      ! Pack in send-buffer order: slot point_pos carries local row route_send_local_rows(point_pos)
      DO point_pos = 1, nsend
         local_row = cached_layout%route_send_local_rows(point_pos)
         meta_base = 2*(point_pos - 1)
         dyn_base = ndynamic_per_point*(point_pos - 1)
         src_base = ndynamic_per_point*(local_row - 1)
         ! Metadata per point: one-based feature row index and the sender's local row
         send_meta(meta_base + 1) = INT(cached_layout%local_feature_indices(local_row) + 1_int_8)
         send_meta(meta_base + 2) = local_row
         send_dynamic(dyn_base + 1:dyn_base + ndynamic_per_point) = &
            local_dynamic(src_base + 1:src_base + ndynamic_per_point)
      END DO

      CALL group%alltoall(send_meta, cached_layout%route_meta_send_counts, &
                          cached_layout%route_meta_send_displs, recv_meta, &
                          cached_layout%route_meta_recv_counts, &
                          cached_layout%route_meta_recv_displs)
      CALL group%alltoall(send_dynamic, cached_layout%route_dynamic_send_counts, &
                          cached_layout%route_dynamic_send_displs, recv_dynamic, &
                          cached_layout%route_dynamic_recv_counts, &
                          cached_layout%route_dynamic_recv_displs)

      ALLOCATE (features%chunk_density(nchunk, 2), &
                features%chunk_grad(nchunk, 3, 2), &
                features%chunk_kin(nchunk, 2), &
                features%chunk_return_positions(nchunk), &
                features%chunk_return_ranks(nchunk), &
                features%chunk_return_rows(nchunk))
      features%chunk_density = 0.0_dp
      features%chunk_grad = 0.0_dp
      features%chunk_kin = 0.0_dp
      ! Zero marks "not received"; checked after unpacking
      features%chunk_return_positions = 0
      features%chunk_return_ranks = 0
      features%chunk_return_rows = 0

      ! Unpack: place every received point at its row within this rank's chunk and remember
      ! where it came from (receive slot, source rank, source local row)
      DO pe = 1, cached_layout%nproc
         DO irow = 1, cached_layout%route_point_recv_counts(pe)
            point_pos = cached_layout%route_point_recv_displs(pe) + irow
            meta_base = 2*(point_pos - 1)
            dyn_base = ndynamic_per_point*(point_pos - 1)
            row = recv_meta(meta_base + 1)
            local_row = recv_meta(meta_base + 2)
            chunk_row = row - cached_layout%chunk_feature_begin + 1
            CPASSERT(chunk_row >= 1 .AND. chunk_row <= nchunk)
            features%chunk_density(chunk_row, :) = recv_dynamic(dyn_base + 1:dyn_base + 2)
            ! Values 3-8 hold the 3 x 2 gradient block with the first index varying fastest
            features%chunk_grad(chunk_row, :, :) = &
               RESHAPE(recv_dynamic(dyn_base + 3:dyn_base + 8), [3, 2])
            features%chunk_kin(chunk_row, :) = recv_dynamic(dyn_base + 9:dyn_base + 10)
            features%chunk_return_positions(chunk_row) = point_pos
            features%chunk_return_ranks(chunk_row) = pe
            features%chunk_return_rows(chunk_row) = local_row
         END DO
      END DO

      ! Every chunk row must have been delivered by some rank
      CPASSERT(ALL(features%chunk_return_positions > 0))
      CPASSERT(ALL(features%chunk_return_ranks > 0))
      CPASSERT(ALL(features%chunk_return_rows > 0))

      DEALLOCATE (recv_dynamic, recv_meta, send_dynamic, send_meta)

   END SUBROUTINE route_atom_chunk_dynamics
1114
! **************************************************************************************************
!> \brief Extract the current rank's atom chunk from the global dynamic feature arrays.
!> \param features bundle whose density, grad and kin arrays are sliced over the chunk row range
!>                 of the module-level cached_layout into newly allocated chunk_density,
!>                 chunk_grad and chunk_kin
! **************************************************************************************************
   SUBROUTINE extract_atom_chunk_dynamics(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      INTEGER                                            :: first_row, last_row, nrow

      nrow = cached_layout%chunk_feature_count
      CPASSERT(nrow > 0)

      first_row = cached_layout%chunk_feature_begin
      last_row = first_row + nrow - 1

      ALLOCATE (features%chunk_density(nrow, 2))
      features%chunk_density(:, :) = features%density(first_row:last_row, :)

      ALLOCATE (features%chunk_grad(nrow, 3, 2))
      features%chunk_grad(:, :, :) = features%grad(first_row:last_row, :, :)

      ALLOCATE (features%chunk_kin(nrow, 2))
      features%chunk_kin(:, :) = features%kin(first_row:last_row, :)

   END SUBROUTINE extract_atom_chunk_dynamics
1135
! **************************************************************************************************
!> \brief Compute a local signature for optional integration weights.
!> \param weights optional real-space grid of weights; may be absent or unassociated
!> \param has_weights true only if weights is present and associated
!> \param weight_sum sum over the local weights array, 0 without weights
!> \param weight_sumsq sum of squares over the local weights array, 0 without weights
! **************************************************************************************************
   SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL, INTENT(OUT)                               :: has_weights
      REAL(KIND=dp), INTENT(OUT)                         :: weight_sum, weight_sumsq

      ! Signature of "no weights"
      weight_sumsq = 0.0_dp
      weight_sum = 0.0_dp
      has_weights = .FALSE.

      ! Presence has to be tested before association
      IF (.NOT. PRESENT(weights)) RETURN
      IF (.NOT. ASSOCIATED(weights)) RETURN

      has_weights = .TRUE.
      weight_sum = SUM(weights%array)
      weight_sumsq = SUM(weights%array*weights%array)

   END SUBROUTINE weights_signature
1160
! **************************************************************************************************
!> \brief Release cached layout arrays.
!> \param cache layout cache; its Torch tensors and dictionaries are released for every tensor
!>              group flagged as active, all allocatable arrays are deallocated and the scalar
!>              members and flags are reset
! **************************************************************************************************
   SUBROUTINE release_layout_cache(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      ! --- Torch handles: only groups flagged as active are released ---
      IF (cache%dynamic_tensors_active) THEN
         CALL torch_tensor_release(cache%density_t)
         CALL torch_tensor_release(cache%grad_t)
         CALL torch_tensor_release(cache%kin_t)
      END IF

      IF (cache%chunk_dynamic_tensors_active) THEN
         CALL torch_tensor_release(cache%chunk_density_t)
         CALL torch_tensor_release(cache%chunk_grad_t)
         CALL torch_tensor_release(cache%chunk_kin_t)
      END IF

      IF (cache%static_tensors_active) THEN
         CALL torch_tensor_release(cache%grid_coords_t)
         CALL torch_tensor_release(cache%grid_weights_t)
         CALL torch_tensor_release(cache%atomic_grid_weights_t)
         CALL torch_tensor_release(cache%atomic_grid_sizes_t)
         CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
         CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
         CALL torch_tensor_release(cache%local_feature_indices_t)
         CALL torch_dict_release(cache%static_inputs)
      END IF

      IF (cache%chunk_static_tensors_active) THEN
         CALL torch_tensor_release(cache%chunk_grid_coords_t)
         CALL torch_tensor_release(cache%chunk_grid_weights_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
         CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
         CALL torch_tensor_release(cache%chunk_feature_indices_t)
         CALL torch_dict_release(cache%chunk_static_inputs)
      END IF

      ! --- Full layout arrays ---
      IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)
      IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
      IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
      IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
      IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
      IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
      IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
      IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
         DEALLOCATE (cache%atomic_grid_size_bound_shape)
      IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
      IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)

      ! --- Atom-chunk layout arrays ---
      IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
      IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
      IF (ALLOCATED(cache%chunk_atomic_grid_weights)) DEALLOCATE (cache%chunk_atomic_grid_weights)
      IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
      IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
         DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)
      IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
         DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
      IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)

      ! --- Per-rank count/displacement tables ---
      IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
      IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)
      IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
      IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
      IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
      IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
      IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
      IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)

      ! --- Atom-chunk routing tables ---
      IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
      IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
      IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
      IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
      IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
      IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
      IF (ALLOCATED(cache%route_meta_send_counts)) DEALLOCATE (cache%route_meta_send_counts)
      IF (ALLOCATED(cache%route_meta_send_displs)) DEALLOCATE (cache%route_meta_send_displs)
      IF (ALLOCATED(cache%route_meta_recv_counts)) DEALLOCATE (cache%route_meta_recv_counts)
      IF (ALLOCATED(cache%route_meta_recv_displs)) DEALLOCATE (cache%route_meta_recv_displs)
      IF (ALLOCATED(cache%route_dynamic_send_counts)) DEALLOCATE (cache%route_dynamic_send_counts)
      IF (ALLOCATED(cache%route_dynamic_send_displs)) DEALLOCATE (cache%route_dynamic_send_displs)
      IF (ALLOCATED(cache%route_dynamic_recv_counts)) DEALLOCATE (cache%route_dynamic_recv_counts)
      IF (ALLOCATED(cache%route_dynamic_recv_displs)) DEALLOCATE (cache%route_dynamic_recv_displs)
      IF (ALLOCATED(cache%route_grad_return_send_counts)) &
         DEALLOCATE (cache%route_grad_return_send_counts)
      IF (ALLOCATED(cache%route_grad_return_send_displs)) &
         DEALLOCATE (cache%route_grad_return_send_displs)
      IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
         DEALLOCATE (cache%route_grad_return_recv_counts)
      IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
         DEALLOCATE (cache%route_grad_return_recv_displs)

      ! --- Scalars and small fixed-size members ---
      cache%natom = 0
      cache%nproc = 0
      cache%nflat = 0
      cache%nflat_local = 0
      cache%chunk_natom = 0
      cache%chunk_atom_begin = 1
      cache%chunk_atom_end = 0
      cache%chunk_feature_begin = 1
      cache%chunk_feature_count = 0
      cache%npts = 0
      cache%bo = 0
      cache%bounds = 0
      cache%dvol = 0.0_dp
      cache%dh = 0.0_dp
      cache%cell_hmat = 0.0_dp
      cache%weight_sum = 0.0_dp
      cache%weight_sumsq = 0.0_dp

      ! --- Flags: nothing is cached and no tensor group is alive any more ---
      cache%has_weights = .FALSE.
      cache%active = .FALSE.
      cache%dynamic_tensors_active = .FALSE.
      cache%chunk_dynamic_tensors_active = .FALSE.
      cache%static_tensors_active = .FALSE.
      cache%chunk_static_tensors_active = .FALSE.

   END SUBROUTINE release_layout_cache
1282
! **************************************************************************************************
!> \brief Release Torch objects and backing arrays owned by a feature bundle.
!> \param features bundle to release. If it is active, the tensors it owns (according to its
!>                 owns_* flags) and its input dictionary are released and the flags are reset.
!>                 All allocated backing arrays are freed and the row counters are zeroed whether
!>                 or not the bundle is active.
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_release(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      IF (features%active) THEN
         IF (features%owns_dynamic_tensors) THEN
            CALL torch_tensor_release(features%density_t)
            CALL torch_tensor_release(features%grad_t)
            CALL torch_tensor_release(features%kin_t)
         END IF

         ! The coordinate tensor is released with the static set, or on its own when only the
         ! coordinate tensor is owned
         IF (features%owns_static_tensors) THEN
            CALL torch_tensor_release(features%grid_coords_t)
            CALL torch_tensor_release(features%grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_sizes_t)
            CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
            CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
         ELSE IF (features%owns_coordinate_tensor) THEN
            CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
         END IF

         CALL torch_dict_release(features%inputs)

         ! Reset the state flags; the dynamic and static ownership flags go back to .TRUE.
         features%active = .FALSE.
         features%uses_atom_chunks = .FALSE.
         features%owns_coordinate_tensor = .FALSE.
         features%owns_static_tensors = .TRUE.
         features%owns_dynamic_tensors = .TRUE.
      END IF

      ! --- Dynamic feature arrays (full and chunk) ---
      IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
      IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
      IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)
      IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
      IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
      IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)

      ! --- Chunk gradient tables and return addresses ---
      IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
      IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
      IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
      IF (ALLOCATED(features%chunk_return_ranks)) DEALLOCATE (features%chunk_return_ranks)
      IF (ALLOCATED(features%chunk_return_rows)) DEALLOCATE (features%chunk_return_rows)

      ! --- Routing tables ---
      IF (ALLOCATED(features%route_point_send_counts)) &
         DEALLOCATE (features%route_point_send_counts)
      IF (ALLOCATED(features%route_point_send_displs)) &
         DEALLOCATE (features%route_point_send_displs)
      IF (ALLOCATED(features%route_point_recv_counts)) &
         DEALLOCATE (features%route_point_recv_counts)
      IF (ALLOCATED(features%route_point_recv_displs)) &
         DEALLOCATE (features%route_point_recv_displs)
      IF (ALLOCATED(features%route_grad_return_send_counts)) &
         DEALLOCATE (features%route_grad_return_send_counts)
      IF (ALLOCATED(features%route_grad_return_send_displs)) &
         DEALLOCATE (features%route_grad_return_send_displs)
      IF (ALLOCATED(features%route_grad_return_recv_counts)) &
         DEALLOCATE (features%route_grad_return_recv_counts)
      IF (ALLOCATED(features%route_grad_return_recv_displs)) &
         DEALLOCATE (features%route_grad_return_recv_displs)
      IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)

      ! --- Static layout arrays ---
      IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
      IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
      IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
      IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
      IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
      IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
         DEALLOCATE (features%atomic_grid_size_bound_shape)
      IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)

      ! --- Counters and the routing flag are cleared even for an inactive bundle ---
      features%nflat = 0
      features%nflat_local = 0
      features%chunk_feature_count = 0
      features%uses_atom_chunk_routing = .FALSE.

   END SUBROUTINE skala_gpw_feature_release
1357
! **************************************************************************************************
!> \brief Fill features%inputs from the module-level layout cache: clone the cached static
!>        dictionary, copy the cached tensor handles into the bundle and insert the
!>        "density", "grad", "kin" and "coarse_0_atomic_coords" entries.
!> \param features feature bundle whose dictionary, tensor handles and owns_* flags are set
!> \param requires_grad forwarded as requires_grad when the density, grad and kin tensors are
!>        reset from the arrays stored in features
!> \param requires_coordinate_grad if true, a new coordinate tensor is created from
!>        features%coarse_0_atomic_coords and turned into an autograd leaf; asserted to be
!>        incompatible with use_atom_chunks
!> \param use_atom_chunks if true, the chunk_* members of the layout cache and of features
!>        are used instead of the full-layout ones
! **************************************************************************************************
   SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
                                  use_atom_chunks)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: requires_grad, requires_coordinate_grad, &
                                                            use_atom_chunks

      CPASSERT(cached_layout%static_tensors_active)

      ! All ownership flags start cleared; only owns_coordinate_tensor may be raised again
      ! further down, when a dedicated coordinate tensor is created.
      features%owns_static_tensors = .FALSE.
      features%owns_coordinate_tensor = .FALSE.
      features%owns_dynamic_tensors = .FALSE.

      ASSOCIATE (cache => cached_layout)

         IF (use_atom_chunks) THEN
            CPASSERT(cache%chunk_static_tensors_active)

            ! Static inputs: clone the cached chunk dictionary and mirror its tensor handles.
            CALL torch_dict_clone(cache%chunk_static_inputs, features%inputs)
            features%grid_coords_t = cache%chunk_grid_coords_t
            features%grid_weights_t = cache%chunk_grid_weights_t
            features%atomic_grid_weights_t = cache%chunk_atomic_grid_weights_t
            features%atomic_grid_sizes_t = cache%chunk_atomic_grid_sizes_t
            features%atomic_grid_size_bound_shape_t = cache%chunk_atomic_grid_size_bound_shape_t
            features%local_feature_indices_t = cache%chunk_feature_indices_t

            ! Dynamic inputs: reset each cached tensor from the matching chunk array of this
            ! bundle, copy the handle and insert it under its dictionary key.
            CALL torch_tensor_reset_from_array(cache%chunk_density_t, features%chunk_density, &
                                               requires_grad=requires_grad)
            features%density_t = cache%chunk_density_t
            CALL torch_dict_insert(features%inputs, "density", features%density_t)

            CALL torch_tensor_reset_from_array(cache%chunk_grad_t, features%chunk_grad, &
                                               requires_grad=requires_grad)
            features%grad_t = cache%chunk_grad_t
            CALL torch_dict_insert(features%inputs, "grad", features%grad_t)

            CALL torch_tensor_reset_from_array(cache%chunk_kin_t, features%chunk_kin, &
                                               requires_grad=requires_grad)
            features%kin_t = cache%chunk_kin_t
            CALL torch_dict_insert(features%inputs, "kin", features%kin_t)

            cache%chunk_dynamic_tensors_active = .TRUE.
         ELSE
            ! Static inputs: clone the cached full-layout dictionary and mirror its handles.
            CALL torch_dict_clone(cache%static_inputs, features%inputs)
            features%grid_coords_t = cache%grid_coords_t
            features%grid_weights_t = cache%grid_weights_t
            features%atomic_grid_weights_t = cache%atomic_grid_weights_t
            features%atomic_grid_sizes_t = cache%atomic_grid_sizes_t
            features%atomic_grid_size_bound_shape_t = cache%atomic_grid_size_bound_shape_t
            features%local_feature_indices_t = cache%local_feature_indices_t

            ! Dynamic inputs: same reset / copy / insert sequence for the full-layout arrays.
            CALL torch_tensor_reset_from_array(cache%density_t, features%density, &
                                               requires_grad=requires_grad)
            features%density_t = cache%density_t
            CALL torch_dict_insert(features%inputs, "density", features%density_t)

            CALL torch_tensor_reset_from_array(cache%grad_t, features%grad, &
                                               requires_grad=requires_grad)
            features%grad_t = cache%grad_t
            CALL torch_dict_insert(features%inputs, "grad", features%grad_t)

            CALL torch_tensor_reset_from_array(cache%kin_t, features%kin, &
                                               requires_grad=requires_grad)
            features%kin_t = cache%kin_t
            CALL torch_dict_insert(features%inputs, "kin", features%kin_t)

            cache%dynamic_tensors_active = .TRUE.
         END IF

         IF (requires_coordinate_grad) THEN
            ! A dedicated coordinate tensor is built for this bundle instead of reusing the
            ! cached one; this path is not available together with atom chunks.
            CPASSERT(.NOT. use_atom_chunks)
            CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
                                         features%coarse_0_atomic_coords)
            CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
            CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                   features%coarse_0_atomic_coords_t)
            features%owns_coordinate_tensor = .TRUE.
         ELSE IF (use_atom_chunks) THEN
            features%coarse_0_atomic_coords_t = cache%chunk_coarse_0_atomic_coords_t
            CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                   cache%chunk_coarse_0_atomic_coords_t)
         ELSE
            features%coarse_0_atomic_coords_t = cache%coarse_0_atomic_coords_t
            CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                   cache%coarse_0_atomic_coords_t)
         END IF

      END ASSOCIATE

   END SUBROUTINE add_feature_tensors
1444
! **************************************************************************************************
!> \brief Convert an integer grid index triple into a Cartesian position using the grid
!>        step vectors stored in pw_grid%dh.
!> \param pw_grid grid descriptor; its lower bounds (bounds(1, :)) and dh matrix are read
!> \param index grid index triple, expressed in the same index range as pw_grid%bounds
!> \return sum of the three dh columns, each weighted by the index offset from the lower bound
! **************************************************************************************************
   FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      INTEGER, DIMENSION(3), INTENT(IN)                  :: index
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      REAL(KIND=dp), DIMENSION(3)                        :: steps

      ! Number of grid steps away from the lower bound along each grid direction.
      steps = REAL(index - pw_grid%bounds(1, :), KIND=dp)

      coord = steps(1)*pw_grid%dh(:, 1) + &
              steps(2)*pw_grid%dh(:, 2) + &
              steps(3)*pw_grid%dh(:, 3)

   END FUNCTION grid_coordinate
1464
! **************************************************************************************************
!> \brief Return the periodic image of a grid point that lies closest to a given atom.
!> \param owner_coord Cartesian coordinate of the atom the grid point is attached to
!> \param grid_point Cartesian coordinate of the grid point
!> \param cell simulation cell; orthorhombic, hmat, h_inv and perd are read
!> \return owner_coord plus the wrapped displacement from owner_coord to grid_point
! **************************************************************************************************
   FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
      TYPE(cell_type), POINTER                           :: cell
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      INTEGER                                            :: idir
      REAL(KIND=dp), DIMENSION(3)                        :: shift

      ! General cells: delegate the wrapping of the displacement to pbc.
      IF (.NOT. cell%orthorhombic) THEN
         coord = owner_coord + pbc(owner_coord, grid_point, cell)
         RETURN
      END IF

      ! Orthorhombic cells: wrap each Cartesian component on its own, using only the
      ! diagonal of hmat/h_inv; perd(idir) scales the correction for that direction.
      shift = grid_point - owner_coord
      DO idir = 1, 3
         shift(idir) = shift(idir) - cell%hmat(idir, idir)*cell%perd(idir)* &
                       ANINT(cell%h_inv(idir, idir)*shift(idir))
      END DO
      coord = owner_coord + shift

   END FUNCTION nearest_image_coordinate
1492
! **************************************************************************************************
!> \brief Find the atom whose periodic image is closest to a grid point.
!> \param grid_point Cartesian coordinate of the grid point
!> \param atom_coords atomic coordinates, one atom per column (rows 1:3 are x, y, z)
!> \param cell simulation cell; orthorhombic, hmat, h_inv and perd are read
!> \return column index of the closest atom; the lowest index wins on ties, and 1 is
!>         returned if no atom compares closer than HUGE (e.g. atom_coords has no columns)
! **************************************************************************************************
   FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
      TYPE(cell_type), POINTER                           :: cell
      INTEGER                                            :: owner

      INTEGER                                            :: iatom, idir
      REAL(KIND=dp)                                      :: dist2, min_dist2
      REAL(KIND=dp), DIMENSION(3)                        :: delta, inv_length, period

      owner = 1
      min_dist2 = HUGE(1.0_dp)

      IF (cell%orthorhombic) THEN
         ! The per-direction wrapping factors do not depend on the atom, so they are
         ! evaluated once before the atom loop.
         DO idir = 1, 3
            period(idir) = cell%hmat(idir, idir)*cell%perd(idir)
            inv_length(idir) = cell%h_inv(idir, idir)
         END DO

         DO iatom = 1, SIZE(atom_coords, 2)
            delta = grid_point - atom_coords(1:3, iatom)
            delta = delta - period*ANINT(inv_length*delta)
            dist2 = delta(1)*delta(1) + delta(2)*delta(2) + delta(3)*delta(3)
            ! Strict comparison keeps the first atom found at the minimum distance.
            IF (dist2 < min_dist2) THEN
               min_dist2 = dist2
               owner = iatom
            END IF
         END DO
      ELSE
         ! General cells: let pbc provide the wrapped separation vector.
         DO iatom = 1, SIZE(atom_coords, 2)
            delta = pbc(grid_point, atom_coords(:, iatom), cell)
            dist2 = SUM(delta**2)
            IF (dist2 < min_dist2) THEN
               min_dist2 = dist2
               owner = iatom
            END IF
         END DO
      END IF

   END FUNCTION nearest_atom
1538
1539END MODULE skala_gpw_features
Handles all functions related to the CELL.
Definition cell_types.F:15
various utilities that regard array of different kinds: output, allocation,... maybe it is not a good...
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public int_8
Definition kinds.F:54
integer, parameter, public dp
Definition kinds.F:34
Interface to the message passing library MPI.
Define the data structure for the particle information.
Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
subroutine, public skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, requires_grad, weights, requires_coordinate_grad, use_atom_chunks, route_atom_chunks)
Build a flat SKALA molecular feature dictionary from a local GPW grid.
subroutine, public skala_gpw_feature_release(features)
Release Torch objects and backing arrays owned by a feature bundle.
type(skala_gpw_layout_cache_type), save cached_layout
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition torch_api.F:1721
subroutine, public torch_tensor_to_device_leaf(tensor, requires_grad)
Moves a tensor to the active Torch device and makes it an autograd leaf.
Definition torch_api.F:1445
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition torch_api.F:1604
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
Definition torch_api.F:1655
subroutine, public torch_dict_clone(source, target)
Clones a Torch dictionary.
Definition torch_api.F:1627
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:1580
contains the structure
subroutine, public xc_rho_set_get(rho_set, can_return_null, rho, drho, norm_drho, rhoa, rhob, norm_drhoa, norm_drhob, rho_1_3, rhoa_1_3, rhob_1_3, laplace_rho, laplace_rhoa, laplace_rhob, drhoa, drhob, rho_cutoff, drho_cutoff, tau_cutoff, tau, tau_a, tau_b, local_bounds)
returns the various attributes of rho_set
Type defining parameters related to the simulation cell.
Definition cell_types.F:60
represent a pointer to a contiguous 3d array
represent a density, with all the representation and data needed to perform a functional evaluation