66 INTEGER,
DIMENSION(:, :),
POINTER :: glob_loc_list
67 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: glob_cell_v
68 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a, unique_list_a
71 CHARACTER(LEN=*),
PARAMETER :: routinen =
'setup_allegro_arrays'
73 INTEGER :: handle, i, iend, igrp, ikind, ilist, &
74 ipair, istart, jkind, nkinds, nlocal, &
76 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: temp_unique_list_a, work_list, work_list2
77 INTEGER,
DIMENSION(:, :),
POINTER ::
list
78 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: rwork_list
79 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
83 cpassert(.NOT.
ASSOCIATED(glob_loc_list))
84 cpassert(.NOT.
ASSOCIATED(glob_loc_list_a))
85 cpassert(.NOT.
ASSOCIATED(unique_list_a))
86 cpassert(.NOT.
ASSOCIATED(glob_cell_v))
87 CALL timeset(routinen, handle)
89 nkinds =
SIZE(potparm%pot, 1)
90 DO ilist = 1, nonbonded%nlists
91 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
92 npairs = neighbor_kind_pair%npairs
93 IF (npairs == 0) cycle
94 kind_group_loop1:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
95 istart = neighbor_kind_pair%grp_kind_start(igrp)
96 iend = neighbor_kind_pair%grp_kind_end(igrp)
97 ikind = neighbor_kind_pair%ij_kind(1, igrp)
98 jkind = neighbor_kind_pair%ij_kind(2, igrp)
99 pot => potparm%pot(ikind, jkind)%pot
100 npairs = iend - istart + 1
102 DO i = 1,
SIZE(pot%type)
103 IF (pot%type(i) ==
allegro_type) npairs_tot = npairs_tot + npairs
105 END DO kind_group_loop1
107 ALLOCATE (work_list(npairs_tot))
108 ALLOCATE (work_list2(npairs_tot))
109 ALLOCATE (glob_loc_list(2, npairs_tot))
110 ALLOCATE (glob_cell_v(3, npairs_tot))
113 DO ilist = 1, nonbonded%nlists
114 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
115 npairs = neighbor_kind_pair%npairs
116 IF (npairs == 0) cycle
117 kind_group_loop2:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
118 istart = neighbor_kind_pair%grp_kind_start(igrp)
119 iend = neighbor_kind_pair%grp_kind_end(igrp)
120 ikind = neighbor_kind_pair%ij_kind(1, igrp)
121 jkind = neighbor_kind_pair%ij_kind(2, igrp)
122 list => neighbor_kind_pair%list
123 cvi = neighbor_kind_pair%cell_vector
124 pot => potparm%pot(ikind, jkind)%pot
125 npairs = iend - istart + 1
127 cell_v = matmul(cell%hmat, cvi)
128 DO i = 1,
SIZE(pot%type)
132 glob_loc_list(:, npairs_tot + ipair) =
list(:, istart - 1 + ipair)
133 glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
135 npairs_tot = npairs_tot + npairs
138 END DO kind_group_loop2
141 CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
142 DO ipair = 1, npairs_tot
143 work_list2(ipair) = glob_loc_list(2, work_list(ipair))
145 glob_loc_list(2, :) = work_list2
146 DEALLOCATE (work_list2)
147 ALLOCATE (rwork_list(3, npairs_tot))
148 DO ipair = 1, npairs_tot
149 rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
151 glob_cell_v = rwork_list
152 DEALLOCATE (rwork_list)
153 DEALLOCATE (work_list)
154 ALLOCATE (glob_loc_list_a(npairs_tot))
155 glob_loc_list_a = glob_loc_list(1, :)
156 ALLOCATE (temp_unique_list_a(npairs_tot))
158 temp_unique_list_a(1) = glob_loc_list_a(1)
159 DO ipair = 2, npairs_tot
160 IF (glob_loc_list_a(ipair - 1) /= glob_loc_list_a(ipair))
THEN
162 temp_unique_list_a(nlocal) = glob_loc_list_a(ipair)
165 ALLOCATE (unique_list_a(nlocal))
166 unique_list_a(:) = temp_unique_list_a(:nlocal)
167 DEALLOCATE (temp_unique_list_a)
168 CALL timestop(handle)
222 potparm, allegro, glob_loc_list_a, r_last_update_pbc, &
223 pot_allegro, fist_nonbond_env, unique_list_a, para_env, use_virial)
231 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
232 TYPE(
pos_type),
DIMENSION(:),
POINTER :: r_last_update_pbc
233 REAL(kind=
dp) :: pot_allegro
235 INTEGER,
DIMENSION(:),
POINTER :: unique_list_a
237 LOGICAL,
INTENT(IN) :: use_virial
239 CHARACTER(LEN=*),
PARAMETER :: routinen =
'allegro_energy_store_force_virial'
241 INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
242 ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, &
243 nedges, nloc_size, npairs, nunique
244 INTEGER(kind=int_8),
ALLOCATABLE ::
atom_types(:), temp_atom_types(:)
245 INTEGER(kind=int_8),
ALLOCATABLE,
DIMENSION(:, :) :: edge_index, t_edge_index, temp_edge_index
246 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: work_list
247 INTEGER,
DIMENSION(:, :),
POINTER ::
list, sort_list
248 LOGICAL,
ALLOCATABLE :: use_atom(:)
249 REAL(kind=
dp) :: drij, rab2_max, rij(3)
250 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts, lattice, &
251 new_edge_cell_shifts, pos
252 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
253 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: atomic_energy, forces, virial
254 REAL(kind=
dp),
DIMENSION(:, :, :),
POINTER :: virial3d
255 REAL(kind=
sp),
ALLOCATABLE,
DIMENSION(:, :) :: lattice_sp, new_edge_cell_shifts_sp, &
257 REAL(kind=
sp),
DIMENSION(:, :),
POINTER :: atomic_energy_sp, forces_sp
262 TYPE(
torch_tensor_type) :: atom_types_tensor, atomic_energy_tensor, forces_tensor, &
263 lattice_tensor, new_edge_cell_shifts_tensor, pos_tensor, t_edge_index_tensor, &
266 CALL timeset(routinen, handle)
268 NULLIFY (atomic_energy, forces, atomic_energy_sp, forces_sp, virial3d, virial)
269 n_atoms =
SIZE(particle_set)
270 ALLOCATE (use_atom(n_atoms))
273 DO ikind = 1,
SIZE(atomic_kind_set)
274 DO jkind = 1,
SIZE(atomic_kind_set)
275 pot => potparm%pot(ikind, jkind)%pot
276 DO i = 1,
SIZE(pot%type)
279 IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
280 particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .true.
285 n_atoms_use = count(use_atom)
289 IF (.NOT.
ASSOCIATED(allegro_data))
THEN
290 ALLOCATE (allegro_data)
292 NULLIFY (allegro_data%use_indices, allegro_data%force)
293 CALL torch_model_load(allegro_data%model, pot%set(1)%allegro%allegro_file_name)
296 IF (
ASSOCIATED(allegro_data%force))
THEN
297 IF (
SIZE(allegro_data%force, 2) /= n_atoms_use)
THEN
298 DEALLOCATE (allegro_data%force, allegro_data%use_indices)
301 IF (.NOT.
ASSOCIATED(allegro_data%force))
THEN
302 ALLOCATE (allegro_data%force(3, n_atoms_use))
303 ALLOCATE (allegro_data%use_indices(n_atoms_use))
307 DO iat = 1, n_atoms_use
308 IF (use_atom(iat))
THEN
309 iat_use = iat_use + 1
310 allegro_data%use_indices(iat_use) = iat
316 ALLOCATE (edge_index(2,
SIZE(glob_loc_list_a)))
317 ALLOCATE (edge_cell_shifts(3,
SIZE(glob_loc_list_a)))
318 ALLOCATE (temp_atom_types(
SIZE(glob_loc_list_a)))
320 DO ilist = 1, nonbonded%nlists
321 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
322 npairs = neighbor_kind_pair%npairs
323 IF (npairs == 0) cycle
324 kind_group_loop_allegro:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
325 istart = neighbor_kind_pair%grp_kind_start(igrp)
326 iend = neighbor_kind_pair%grp_kind_end(igrp)
327 ikind = neighbor_kind_pair%ij_kind(1, igrp)
328 jkind = neighbor_kind_pair%ij_kind(2, igrp)
329 list => neighbor_kind_pair%list
330 cvi = neighbor_kind_pair%cell_vector
331 pot => potparm%pot(ikind, jkind)%pot
332 DO i = 1,
SIZE(pot%type)
334 rab2_max = pot%set(i)%allegro%rcutsq
335 cell_v = matmul(cell%hmat, cvi)
336 pot => potparm%pot(ikind, jkind)%pot
337 allegro => pot%set(i)%allegro
338 npairs = iend - istart + 1
339 IF (npairs /= 0)
THEN
340 ALLOCATE (sort_list(2, npairs), work_list(npairs))
341 sort_list =
list(:, istart:iend)
344 CALL sort(sort_list(1, :), npairs, work_list)
346 work_list(ipair) = sort_list(2, work_list(ipair))
348 sort_list(2, :) = work_list
351 DO ipair = 1, npairs - 1
352 IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
355 junique = sort_list(1, ipair)
357 DO iunique = 1, nunique
359 IF (glob_loc_list_a(ifirst) > atom_a) cycle
360 DO mpair = ifirst,
SIZE(glob_loc_list_a)
361 IF (glob_loc_list_a(mpair) == atom_a)
EXIT
364 DO mpair = ifirst,
SIZE(glob_loc_list_a)
365 IF (glob_loc_list_a(mpair) /= atom_a)
EXIT
369 IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
370 DO WHILE (ipair <= npairs)
371 IF (sort_list(1, ipair) /= junique)
EXIT
372 atom_b = sort_list(2, ipair)
373 rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
374 drij = dot_product(rij, rij)
376 IF (drij <= rab2_max)
THEN
378 edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
379 edge_cell_shifts(:, nedges) = cvi
383 IF (ipair <= npairs) junique = sort_list(1, ipair)
385 DEALLOCATE (sort_list, work_list)
388 END DO kind_group_loop_allegro
391 allegro => pot%set(1)%allegro
393 ALLOCATE (temp_edge_index(2, nedges))
394 temp_edge_index(:, :) = edge_index(:, :nedges)
395 ALLOCATE (new_edge_cell_shifts(3, nedges))
396 new_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
397 DEALLOCATE (edge_cell_shifts)
399 ALLOCATE (t_edge_index(nedges, 2))
401 t_edge_index(:, :) = transpose(temp_edge_index)
402 DEALLOCATE (temp_edge_index, edge_index)
403 ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
404 lattice(:, :) = cell%hmat/pot%set(1)%allegro%unit_cell_val
405 lattice_sp(:, :) = real(lattice, kind=
sp)
407 ALLOCATE (pos(3, n_atoms_use),
atom_types(n_atoms_use))
408 DO iat = 1, n_atoms_use
409 IF (.NOT. use_atom(iat)) cycle
410 iat_use = iat_use + 1
412 DO i = 1,
SIZE(allegro%type_names_torch)
413 IF (particle_set(iat)%atomic_kind%element_symbol == allegro%type_names_torch(i))
THEN
418 pos(:, iat) = r_last_update_pbc(iat)%r(:)/allegro%unit_coords_val
423 IF (allegro%do_allegro_sp)
THEN
424 ALLOCATE (new_edge_cell_shifts_sp(3, nedges), pos_sp(3, n_atoms_use))
425 new_edge_cell_shifts_sp(:, :) = real(new_edge_cell_shifts(:, :), kind=
sp)
426 pos_sp(:, :) = real(pos(:, :), kind=
sp)
427 DEALLOCATE (pos, new_edge_cell_shifts)
456 CALL torch_dict_get(outputs,
"atomic_energy", atomic_energy_tensor)
458 IF (allegro%do_allegro_sp)
THEN
461 allegro_data%force(:, :) = real(forces_sp(:, :), kind=
dp)*allegro%unit_forces_val
462 DO iat_use = 1,
SIZE(unique_list_a)
463 i = unique_list_a(iat_use)
464 pot_allegro = pot_allegro + real(atomic_energy_sp(1, i), kind=
dp)*allegro%unit_energy_val
466 DEALLOCATE (new_edge_cell_shifts_sp, pos_sp)
471 allegro_data%force(:, :) = forces(:, :)*allegro%unit_forces_val
472 DO iat_use = 1,
SIZE(unique_list_a)
473 i = unique_list_a(iat_use)
474 pot_allegro = pot_allegro + atomic_energy(1, i)*allegro%unit_energy_val
476 DEALLOCATE (pos, new_edge_cell_shifts)
484 allegro_data%virial(:, :) = reshape(virial3d, (/3, 3/))*allegro%unit_energy_val
493 IF (use_virial) allegro_data%virial(:, :) = allegro_data%virial/real(para_env%num_pe,
dp)
494 CALL timestop(handle)