64 INTEGER,
DIMENSION(:, :),
POINTER :: glob_loc_list
65 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: glob_cell_v
66 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
69 CHARACTER(LEN=*),
PARAMETER :: routinen =
'setup_nequip_arrays'
71 INTEGER :: handle, i, iend, igrp, ikind, ilist, &
72 ipair, istart, jkind, nkinds, npairs, &
74 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: work_list, work_list2
75 INTEGER,
DIMENSION(:, :),
POINTER ::
list
76 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: rwork_list
77 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
81 cpassert(.NOT.
ASSOCIATED(glob_loc_list))
82 cpassert(.NOT.
ASSOCIATED(glob_loc_list_a))
83 cpassert(.NOT.
ASSOCIATED(glob_cell_v))
84 CALL timeset(routinen, handle)
86 nkinds =
SIZE(potparm%pot, 1)
87 DO ilist = 1, nonbonded%nlists
88 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
89 npairs = neighbor_kind_pair%npairs
90 IF (npairs == 0) cycle
91 kind_group_loop1:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
92 istart = neighbor_kind_pair%grp_kind_start(igrp)
93 iend = neighbor_kind_pair%grp_kind_end(igrp)
94 ikind = neighbor_kind_pair%ij_kind(1, igrp)
95 jkind = neighbor_kind_pair%ij_kind(2, igrp)
96 pot => potparm%pot(ikind, jkind)%pot
97 npairs = iend - istart + 1
99 DO i = 1,
SIZE(pot%type)
100 IF (pot%type(i) ==
nequip_type) npairs_tot = npairs_tot + npairs
102 END DO kind_group_loop1
104 ALLOCATE (work_list(npairs_tot))
105 ALLOCATE (work_list2(npairs_tot))
106 ALLOCATE (glob_loc_list(2, npairs_tot))
107 ALLOCATE (glob_cell_v(3, npairs_tot))
110 DO ilist = 1, nonbonded%nlists
111 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
112 npairs = neighbor_kind_pair%npairs
113 IF (npairs == 0) cycle
114 kind_group_loop2:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
115 istart = neighbor_kind_pair%grp_kind_start(igrp)
116 iend = neighbor_kind_pair%grp_kind_end(igrp)
117 ikind = neighbor_kind_pair%ij_kind(1, igrp)
118 jkind = neighbor_kind_pair%ij_kind(2, igrp)
119 list => neighbor_kind_pair%list
120 cvi = neighbor_kind_pair%cell_vector
121 pot => potparm%pot(ikind, jkind)%pot
122 npairs = iend - istart + 1
124 cell_v = matmul(cell%hmat, cvi)
125 DO i = 1,
SIZE(pot%type)
129 glob_loc_list(:, npairs_tot + ipair) =
list(:, istart - 1 + ipair)
130 glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
132 npairs_tot = npairs_tot + npairs
135 END DO kind_group_loop2
138 CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
139 DO ipair = 1, npairs_tot
140 work_list2(ipair) = glob_loc_list(2, work_list(ipair))
142 glob_loc_list(2, :) = work_list2
143 DEALLOCATE (work_list2)
144 ALLOCATE (rwork_list(3, npairs_tot))
145 DO ipair = 1, npairs_tot
146 rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
148 glob_cell_v = rwork_list
149 DEALLOCATE (rwork_list)
150 DEALLOCATE (work_list)
151 ALLOCATE (glob_loc_list_a(npairs_tot))
152 glob_loc_list_a = glob_loc_list(1, :)
153 CALL timestop(handle)
201 potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
202 pot_nequip, fist_nonbond_env, para_env, use_virial)
210 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
211 TYPE(
pos_type),
DIMENSION(:),
POINTER :: r_last_update_pbc
212 REAL(kind=
dp) :: pot_nequip
215 LOGICAL,
INTENT(IN) :: use_virial
217 CHARACTER(LEN=*),
PARAMETER :: routinen =
'nequip_energy_store_force_virial'
219 INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
220 ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, &
221 nedges, nedges_tot, nloc_size, npairs, nunique
222 INTEGER(kind=int_8),
ALLOCATABLE ::
atom_types(:)
223 INTEGER(kind=int_8),
ALLOCATABLE,
DIMENSION(:, :) :: edge_index, t_edge_index, temp_edge_index
224 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: displ, displ_cell, edge_count, &
225 edge_count_cell, work_list
226 INTEGER,
DIMENSION(:, :),
POINTER ::
list, sort_list
227 LOGICAL,
ALLOCATABLE :: use_atom(:)
228 REAL(kind=
dp) :: drij, rab2_max, rij(3)
229 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts, lattice, pos, &
230 temp_edge_cell_shifts
231 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
232 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: atomic_energy, forces, total_energy
233 REAL(kind=
dp),
DIMENSION(:, :, :),
POINTER :: virial3d
234 REAL(kind=
sp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts_sp, lattice_sp, pos_sp
235 REAL(kind=
sp),
DIMENSION(:, :),
POINTER :: atomic_energy_sp, forces_sp, &
241 TYPE(
torch_tensor_type) :: atom_types_tensor, atomic_energy_tensor, edge_cell_shifts_tensor, &
242 forces_tensor, lattice_tensor, pos_tensor, t_edge_index_tensor, total_energy_tensor, &
245 CALL timeset(routinen, handle)
247 NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp, virial3d)
248 n_atoms =
SIZE(particle_set)
249 ALLOCATE (use_atom(n_atoms))
252 DO ikind = 1,
SIZE(atomic_kind_set)
253 DO jkind = 1,
SIZE(atomic_kind_set)
254 pot => potparm%pot(ikind, jkind)%pot
255 DO i = 1,
SIZE(pot%type)
258 IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
259 particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .true.
264 n_atoms_use = count(use_atom)
268 IF (.NOT.
ASSOCIATED(nequip_data))
THEN
269 ALLOCATE (nequip_data)
271 NULLIFY (nequip_data%use_indices, nequip_data%force)
272 CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
275 IF (
ASSOCIATED(nequip_data%force))
THEN
276 IF (
SIZE(nequip_data%force, 2) /= n_atoms_use)
THEN
277 DEALLOCATE (nequip_data%force, nequip_data%use_indices)
280 IF (.NOT.
ASSOCIATED(nequip_data%force))
THEN
281 ALLOCATE (nequip_data%force(3, n_atoms_use))
282 ALLOCATE (nequip_data%use_indices(n_atoms_use))
286 DO iat = 1, n_atoms_use
287 IF (use_atom(iat))
THEN
288 iat_use = iat_use + 1
289 nequip_data%use_indices(iat_use) = iat
294 ALLOCATE (edge_index(2,
SIZE(glob_loc_list_a)))
295 ALLOCATE (edge_cell_shifts(3,
SIZE(glob_loc_list_a)))
296 DO ilist = 1, nonbonded%nlists
297 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
298 npairs = neighbor_kind_pair%npairs
299 IF (npairs == 0) cycle
300 kind_group_loop_nequip:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
301 istart = neighbor_kind_pair%grp_kind_start(igrp)
302 iend = neighbor_kind_pair%grp_kind_end(igrp)
303 ikind = neighbor_kind_pair%ij_kind(1, igrp)
304 jkind = neighbor_kind_pair%ij_kind(2, igrp)
305 list => neighbor_kind_pair%list
306 cvi = neighbor_kind_pair%cell_vector
307 pot => potparm%pot(ikind, jkind)%pot
308 DO i = 1,
SIZE(pot%type)
310 rab2_max = pot%set(i)%nequip%rcutsq
311 cell_v = matmul(cell%hmat, cvi)
312 pot => potparm%pot(ikind, jkind)%pot
313 nequip => pot%set(i)%nequip
314 npairs = iend - istart + 1
315 IF (npairs /= 0)
THEN
316 ALLOCATE (sort_list(2, npairs), work_list(npairs))
317 sort_list =
list(:, istart:iend)
320 CALL sort(sort_list(1, :), npairs, work_list)
322 work_list(ipair) = sort_list(2, work_list(ipair))
324 sort_list(2, :) = work_list
327 DO ipair = 1, npairs - 1
328 IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
331 junique = sort_list(1, ipair)
333 DO iunique = 1, nunique
335 IF (glob_loc_list_a(ifirst) > atom_a) cycle
336 DO mpair = ifirst,
SIZE(glob_loc_list_a)
337 IF (glob_loc_list_a(mpair) == atom_a)
EXIT
340 DO mpair = ifirst,
SIZE(glob_loc_list_a)
341 IF (glob_loc_list_a(mpair) /= atom_a)
EXIT
345 IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
346 DO WHILE (ipair <= npairs)
347 IF (sort_list(1, ipair) /= junique)
EXIT
348 atom_b = sort_list(2, ipair)
349 rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
350 drij = dot_product(rij, rij)
352 IF (drij <= rab2_max)
THEN
354 edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
355 edge_cell_shifts(:, nedges) = cvi
359 IF (ipair <= npairs) junique = sort_list(1, ipair)
361 DEALLOCATE (sort_list, work_list)
364 END DO kind_group_loop_nequip
367 nequip => pot%set(1)%nequip
369 ALLOCATE (edge_count(para_env%num_pe))
370 ALLOCATE (edge_count_cell(para_env%num_pe))
371 ALLOCATE (displ_cell(para_env%num_pe))
372 ALLOCATE (displ(para_env%num_pe))
374 CALL para_env%allgather(nedges, edge_count)
375 nedges_tot = sum(edge_count)
377 ALLOCATE (temp_edge_index(2, nedges))
378 temp_edge_index(:, :) = edge_index(:, :nedges)
379 DEALLOCATE (edge_index)
380 ALLOCATE (temp_edge_cell_shifts(3, nedges))
381 temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
382 DEALLOCATE (edge_cell_shifts)
384 ALLOCATE (edge_index(2, nedges_tot))
385 ALLOCATE (edge_cell_shifts(3, nedges_tot))
386 ALLOCATE (t_edge_index(nedges_tot, 2))
388 edge_count_cell(:) = edge_count*3
389 edge_count = edge_count*2
392 DO ipair = 2, para_env%num_pe
393 displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
394 displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
397 CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
398 CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)
400 t_edge_index(:, :) = transpose(edge_index)
401 DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)
403 ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
404 lattice(:, :) = cell%hmat/nequip%unit_cell_val
405 lattice_sp(:, :) = real(lattice, kind=
sp)
408 ALLOCATE (pos(3, n_atoms_use),
atom_types(n_atoms_use))
410 DO iat = 1, n_atoms_use
411 IF (.NOT. use_atom(iat)) cycle
412 iat_use = iat_use + 1
414 DO i = 1,
SIZE(nequip%type_names_torch)
415 IF (particle_set(iat)%atomic_kind%element_symbol == nequip%type_names_torch(i))
THEN
420 pos(:, iat) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
424 IF (nequip%do_nequip_sp)
THEN
425 ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
426 pos_sp(:, :) = real(pos(:, :), kind=
sp)
427 edge_cell_shifts_sp(:, :) = real(edge_cell_shifts(:, :), kind=
sp)
456 CALL torch_dict_get(outputs,
"atomic_energy", atomic_energy_tensor)
458 IF (nequip%do_nequip_sp)
THEN
462 pot_nequip = real(total_energy_sp(1, 1), kind=
dp)*nequip%unit_energy_val
463 nequip_data%force(:, :) = real(forces_sp(:, :), kind=
dp)*nequip%unit_forces_val
464 DEALLOCATE (pos_sp, edge_cell_shifts_sp)
469 pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
470 nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
471 DEALLOCATE (pos, edge_cell_shifts)
480 nequip_data%virial(:, :) = reshape(virial3d, (/3, 3/))*nequip%unit_energy_val
490 pot_nequip = pot_nequip/real(para_env%num_pe,
dp)
491 nequip_data%force = nequip_data%force/real(para_env%num_pe,
dp)
492 IF (use_virial) nequip_data%virial(:, :) = nequip_data%virial/real(para_env%num_pe,
dp)
494 CALL timestop(handle)