68 INTEGER,
DIMENSION(:, :),
POINTER :: glob_loc_list
69 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: glob_cell_v
70 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
73 CHARACTER(LEN=*),
PARAMETER :: routinen =
'setup_nequip_arrays'
75 INTEGER :: handle, i, iend, igrp, ikind, ilist, &
76 ipair, istart, jkind, nkinds, npairs, &
78 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: work_list, work_list2
79 INTEGER,
DIMENSION(:, :),
POINTER ::
list
80 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: rwork_list
81 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
85 cpassert(.NOT.
ASSOCIATED(glob_loc_list))
86 cpassert(.NOT.
ASSOCIATED(glob_loc_list_a))
87 cpassert(.NOT.
ASSOCIATED(glob_cell_v))
88 CALL timeset(routinen, handle)
90 nkinds =
SIZE(potparm%pot, 1)
91 DO ilist = 1, nonbonded%nlists
92 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
93 npairs = neighbor_kind_pair%npairs
94 IF (npairs == 0) cycle
95 kind_group_loop1:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
96 istart = neighbor_kind_pair%grp_kind_start(igrp)
97 iend = neighbor_kind_pair%grp_kind_end(igrp)
98 ikind = neighbor_kind_pair%ij_kind(1, igrp)
99 jkind = neighbor_kind_pair%ij_kind(2, igrp)
100 pot => potparm%pot(ikind, jkind)%pot
101 npairs = iend - istart + 1
103 DO i = 1,
SIZE(pot%type)
104 IF (pot%type(i) ==
nequip_type) npairs_tot = npairs_tot + npairs
106 END DO kind_group_loop1
108 ALLOCATE (work_list(npairs_tot))
109 ALLOCATE (work_list2(npairs_tot))
110 ALLOCATE (glob_loc_list(2, npairs_tot))
111 ALLOCATE (glob_cell_v(3, npairs_tot))
114 DO ilist = 1, nonbonded%nlists
115 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
116 npairs = neighbor_kind_pair%npairs
117 IF (npairs == 0) cycle
118 kind_group_loop2:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
119 istart = neighbor_kind_pair%grp_kind_start(igrp)
120 iend = neighbor_kind_pair%grp_kind_end(igrp)
121 ikind = neighbor_kind_pair%ij_kind(1, igrp)
122 jkind = neighbor_kind_pair%ij_kind(2, igrp)
123 list => neighbor_kind_pair%list
124 cvi = neighbor_kind_pair%cell_vector
125 pot => potparm%pot(ikind, jkind)%pot
126 npairs = iend - istart + 1
128 cell_v = matmul(cell%hmat, cvi)
129 DO i = 1,
SIZE(pot%type)
133 glob_loc_list(:, npairs_tot + ipair) =
list(:, istart - 1 + ipair)
134 glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
136 npairs_tot = npairs_tot + npairs
139 END DO kind_group_loop2
142 CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
143 DO ipair = 1, npairs_tot
144 work_list2(ipair) = glob_loc_list(2, work_list(ipair))
146 glob_loc_list(2, :) = work_list2
147 DEALLOCATE (work_list2)
148 ALLOCATE (rwork_list(3, npairs_tot))
149 DO ipair = 1, npairs_tot
150 rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
152 glob_cell_v = rwork_list
153 DEALLOCATE (rwork_list)
154 DEALLOCATE (work_list)
155 ALLOCATE (glob_loc_list_a(npairs_tot))
156 glob_loc_list_a = glob_loc_list(1, :)
157 CALL timestop(handle)
204 potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
205 pot_nequip, fist_nonbond_env, para_env)
213 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
214 TYPE(
pos_type),
DIMENSION(:),
POINTER :: r_last_update_pbc
215 REAL(kind=
dp) :: pot_nequip
219 CHARACTER(LEN=*),
PARAMETER :: routinen =
'nequip_energy_store_force_virial'
221 INTEGER :: atom_a, atom_b, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, ilast, ilist, &
222 ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, nedges, nedges_tot, &
223 nloc_size, npairs, nunique
224 INTEGER(kind=int_8),
ALLOCATABLE ::
atom_types(:)
225 INTEGER(kind=int_8),
ALLOCATABLE,
DIMENSION(:, :) :: edge_index, t_edge_index, temp_edge_index
226 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: displ, displ_cell, edge_count, &
227 edge_count_cell, work_list
228 INTEGER,
DIMENSION(:, :),
POINTER ::
list, sort_list
229 LOGICAL,
ALLOCATABLE :: use_atom(:)
230 REAL(kind=
dp) :: drij, lattice(3, 3), rab2_max, rij(3)
231 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts, pos, &
232 temp_edge_cell_shifts
233 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
234 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: atomic_energy, forces, total_energy
235 REAL(kind=
sp) :: lattice_sp(3, 3)
236 REAL(kind=
sp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts_sp, pos_sp
237 REAL(kind=
sp),
DIMENSION(:, :),
POINTER :: atomic_energy_sp, forces_sp, &
244 CALL timeset(routinen, handle)
246 NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp)
247 n_atoms =
SIZE(particle_set)
248 ALLOCATE (use_atom(n_atoms))
251 DO ikind = 1,
SIZE(atomic_kind_set)
252 DO jkind = 1,
SIZE(atomic_kind_set)
253 pot => potparm%pot(ikind, jkind)%pot
254 DO i = 1,
SIZE(pot%type)
257 IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
258 particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .true.
263 n_atoms_use = count(use_atom)
267 IF (.NOT.
ASSOCIATED(nequip_data))
THEN
268 ALLOCATE (nequip_data)
270 NULLIFY (nequip_data%use_indices, nequip_data%force)
271 CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
274 IF (
ASSOCIATED(nequip_data%force))
THEN
275 IF (
SIZE(nequip_data%force, 2) /= n_atoms_use)
THEN
276 DEALLOCATE (nequip_data%force, nequip_data%use_indices)
279 IF (.NOT.
ASSOCIATED(nequip_data%force))
THEN
280 ALLOCATE (nequip_data%force(3, n_atoms_use))
281 ALLOCATE (nequip_data%use_indices(n_atoms_use))
285 DO iat = 1, n_atoms_use
286 IF (use_atom(iat))
THEN
287 iat_use = iat_use + 1
288 nequip_data%use_indices(iat_use) = iat
293 ALLOCATE (edge_index(2,
SIZE(glob_loc_list_a)))
294 ALLOCATE (edge_cell_shifts(3,
SIZE(glob_loc_list_a)))
295 DO ilist = 1, nonbonded%nlists
296 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
297 npairs = neighbor_kind_pair%npairs
298 IF (npairs == 0) cycle
299 kind_group_loop_nequip:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
300 istart = neighbor_kind_pair%grp_kind_start(igrp)
301 iend = neighbor_kind_pair%grp_kind_end(igrp)
302 ikind = neighbor_kind_pair%ij_kind(1, igrp)
303 jkind = neighbor_kind_pair%ij_kind(2, igrp)
304 list => neighbor_kind_pair%list
305 cvi = neighbor_kind_pair%cell_vector
306 pot => potparm%pot(ikind, jkind)%pot
307 DO i = 1,
SIZE(pot%type)
309 rab2_max = pot%set(i)%nequip%rcutsq
310 cell_v = matmul(cell%hmat, cvi)
311 pot => potparm%pot(ikind, jkind)%pot
312 nequip => pot%set(i)%nequip
313 npairs = iend - istart + 1
314 IF (npairs /= 0)
THEN
315 ALLOCATE (sort_list(2, npairs), work_list(npairs))
316 sort_list =
list(:, istart:iend)
319 CALL sort(sort_list(1, :), npairs, work_list)
321 work_list(ipair) = sort_list(2, work_list(ipair))
323 sort_list(2, :) = work_list
326 DO ipair = 1, npairs - 1
327 IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
330 junique = sort_list(1, ipair)
332 DO iunique = 1, nunique
334 IF (glob_loc_list_a(ifirst) > atom_a) cycle
335 DO mpair = ifirst,
SIZE(glob_loc_list_a)
336 IF (glob_loc_list_a(mpair) == atom_a)
EXIT
339 DO mpair = ifirst,
SIZE(glob_loc_list_a)
340 IF (glob_loc_list_a(mpair) /= atom_a)
EXIT
344 IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
345 DO WHILE (ipair <= npairs)
346 IF (sort_list(1, ipair) /= junique)
EXIT
347 atom_b = sort_list(2, ipair)
348 rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
349 drij = dot_product(rij, rij)
351 IF (drij <= rab2_max)
THEN
353 edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
354 edge_cell_shifts(:, nedges) = cvi
358 IF (ipair <= npairs) junique = sort_list(1, ipair)
360 DEALLOCATE (sort_list, work_list)
363 END DO kind_group_loop_nequip
366 nequip => pot%set(1)%nequip
368 ALLOCATE (edge_count(para_env%num_pe))
369 ALLOCATE (edge_count_cell(para_env%num_pe))
370 ALLOCATE (displ_cell(para_env%num_pe))
371 ALLOCATE (displ(para_env%num_pe))
373 CALL para_env%allgather(nedges, edge_count)
374 nedges_tot = sum(edge_count)
376 ALLOCATE (temp_edge_index(2, nedges))
377 temp_edge_index(:, :) = edge_index(:, :nedges)
378 DEALLOCATE (edge_index)
379 ALLOCATE (temp_edge_cell_shifts(3, nedges))
380 temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
381 DEALLOCATE (edge_cell_shifts)
383 ALLOCATE (edge_index(2, nedges_tot))
384 ALLOCATE (edge_cell_shifts(3, nedges_tot))
385 ALLOCATE (t_edge_index(nedges_tot, 2))
387 edge_count_cell(:) = edge_count*3
388 edge_count = edge_count*2
391 DO ipair = 2, para_env%num_pe
392 displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
393 displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
396 CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
397 CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)
399 t_edge_index(:, :) = transpose(edge_index)
400 DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)
402 lattice = cell%hmat/nequip%unit_cell_val
403 lattice_sp = real(lattice, kind=
sp)
406 ALLOCATE (pos(3, n_atoms_use),
atom_types(n_atoms_use))
408 DO iat = 1, n_atoms_use
409 IF (.NOT. use_atom(iat)) cycle
410 iat_use = iat_use + 1
411 atom_types(iat_use) = particle_set(iat)%atomic_kind%kind_number - 1
412 pos(:, iat) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
416 IF (nequip%do_nequip_sp)
THEN
417 ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
418 pos_sp(:, :) = real(pos(:, :), kind=
sp)
419 edge_cell_shifts_sp(:, :) = real(edge_cell_shifts(:, :), kind=
sp)
436 IF (nequip%do_nequip_sp)
THEN
440 pot_nequip = real(total_energy_sp(1, 1), kind=
dp)*nequip%unit_energy_val
441 nequip_data%force(:, :) = real(forces_sp(:, :), kind=
dp)*nequip%unit_forces_val
442 DEALLOCATE (pos_sp, edge_cell_shifts_sp, total_energy_sp, atomic_energy_sp, forces_sp)
447 pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
448 nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
449 DEALLOCATE (pos, edge_cell_shifts, total_energy, atomic_energy, forces)
458 IF (
PRESENT(para_env))
THEN
459 pot_nequip = pot_nequip/real(para_env%num_pe,
dp)
460 nequip_data%force = nequip_data%force/real(para_env%num_pe,
dp)
463 CALL timestop(handle)