18 neighbor_kind_pairs_type
21 fist_nonbond_env_type,&
30 pair_potential_pp_type,&
31 pair_potential_single_type
42 #include "./base/base_uses.f90"
49 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'manybody_nequip'
66 TYPE(fist_neighbor_type),
POINTER :: nonbonded
67 TYPE(pair_potential_pp_type),
POINTER :: potparm
68 INTEGER,
DIMENSION(:, :),
POINTER :: glob_loc_list
69 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: glob_cell_v
70 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
71 TYPE(cell_type),
POINTER :: cell
73 CHARACTER(LEN=*),
PARAMETER :: routinen =
'setup_nequip_arrays'
75 INTEGER :: handle, i, iend, igrp, ikind, ilist, &
76 ipair, istart, jkind, nkinds, npairs, &
78 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: work_list, work_list2
79 INTEGER,
DIMENSION(:, :),
POINTER ::
list
80 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: rwork_list
81 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
82 TYPE(neighbor_kind_pairs_type),
POINTER :: neighbor_kind_pair
83 TYPE(pair_potential_single_type),
POINTER :: pot
85 cpassert(.NOT.
ASSOCIATED(glob_loc_list))
86 cpassert(.NOT.
ASSOCIATED(glob_loc_list_a))
87 cpassert(.NOT.
ASSOCIATED(glob_cell_v))
88 CALL timeset(routinen, handle)
90 nkinds =
SIZE(potparm%pot, 1)
91 DO ilist = 1, nonbonded%nlists
92 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
93 npairs = neighbor_kind_pair%npairs
94 IF (npairs == 0) cycle
95 kind_group_loop1:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
96 istart = neighbor_kind_pair%grp_kind_start(igrp)
97 iend = neighbor_kind_pair%grp_kind_end(igrp)
98 ikind = neighbor_kind_pair%ij_kind(1, igrp)
99 jkind = neighbor_kind_pair%ij_kind(2, igrp)
100 pot => potparm%pot(ikind, jkind)%pot
101 npairs = iend - istart + 1
103 DO i = 1,
SIZE(pot%type)
104 IF (pot%type(i) ==
nequip_type) npairs_tot = npairs_tot + npairs
106 END DO kind_group_loop1
108 ALLOCATE (work_list(npairs_tot))
109 ALLOCATE (work_list2(npairs_tot))
110 ALLOCATE (glob_loc_list(2, npairs_tot))
111 ALLOCATE (glob_cell_v(3, npairs_tot))
114 DO ilist = 1, nonbonded%nlists
115 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
116 npairs = neighbor_kind_pair%npairs
117 IF (npairs == 0) cycle
118 kind_group_loop2:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
119 istart = neighbor_kind_pair%grp_kind_start(igrp)
120 iend = neighbor_kind_pair%grp_kind_end(igrp)
121 ikind = neighbor_kind_pair%ij_kind(1, igrp)
122 jkind = neighbor_kind_pair%ij_kind(2, igrp)
123 list => neighbor_kind_pair%list
124 cvi = neighbor_kind_pair%cell_vector
125 pot => potparm%pot(ikind, jkind)%pot
126 npairs = iend - istart + 1
128 cell_v = matmul(cell%hmat, cvi)
129 DO i = 1,
SIZE(pot%type)
133 glob_loc_list(:, npairs_tot + ipair) =
list(:, istart - 1 + ipair)
134 glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
136 npairs_tot = npairs_tot + npairs
139 END DO kind_group_loop2
142 CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
143 DO ipair = 1, npairs_tot
144 work_list2(ipair) = glob_loc_list(2, work_list(ipair))
146 glob_loc_list(2, :) = work_list2
147 DEALLOCATE (work_list2)
148 ALLOCATE (rwork_list(3, npairs_tot))
149 DO ipair = 1, npairs_tot
150 rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
152 glob_cell_v = rwork_list
153 DEALLOCATE (rwork_list)
154 DEALLOCATE (work_list)
155 ALLOCATE (glob_loc_list_a(npairs_tot))
156 glob_loc_list_a = glob_loc_list(1, :)
157 CALL timestop(handle)
170 INTEGER,
DIMENSION(:, :),
POINTER :: glob_loc_list
171 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: glob_cell_v
172 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
174 IF (
ASSOCIATED(glob_loc_list))
THEN
175 DEALLOCATE (glob_loc_list)
177 IF (
ASSOCIATED(glob_loc_list_a))
THEN
178 DEALLOCATE (glob_loc_list_a)
180 IF (
ASSOCIATED(glob_cell_v))
THEN
181 DEALLOCATE (glob_cell_v)
204 potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
205 pot_nequip, fist_nonbond_env, para_env)
207 TYPE(fist_neighbor_type),
POINTER :: nonbonded
208 TYPE(particle_type),
POINTER :: particle_set(:)
209 TYPE(cell_type),
POINTER :: cell
210 TYPE(atomic_kind_type),
POINTER :: atomic_kind_set(:)
211 TYPE(pair_potential_pp_type),
POINTER :: potparm
212 TYPE(nequip_pot_type),
POINTER :: nequip
213 INTEGER,
DIMENSION(:),
POINTER :: glob_loc_list_a
214 TYPE(pos_type),
DIMENSION(:),
POINTER :: r_last_update_pbc
215 REAL(kind=
dp) :: pot_nequip
216 TYPE(fist_nonbond_env_type),
POINTER :: fist_nonbond_env
217 TYPE(mp_para_env_type),
OPTIONAL,
POINTER :: para_env
219 CHARACTER(LEN=*),
PARAMETER :: routinen =
'nequip_energy_store_force_virial'
221 INTEGER :: atom_a, atom_b, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, ilast, ilist, &
222 ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, nedges, nedges_tot, &
223 nloc_size, npairs, nunique
224 INTEGER(kind=int_8),
ALLOCATABLE ::
atom_types(:)
225 INTEGER(kind=int_8),
ALLOCATABLE,
DIMENSION(:, :) :: edge_index, t_edge_index, temp_edge_index
226 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: displ, displ_cell, edge_count, &
227 edge_count_cell, work_list
228 INTEGER,
DIMENSION(:, :),
POINTER ::
list, sort_list
229 LOGICAL,
ALLOCATABLE :: use_atom(:)
230 REAL(kind=
dp) :: drij, lattice(3, 3), rab2_max, rij(3)
231 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts, pos, &
232 temp_edge_cell_shifts
233 REAL(kind=
dp),
DIMENSION(3) :: cell_v, cvi
234 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: atomic_energy, forces, total_energy
235 REAL(kind=
sp) :: lattice_sp(3, 3)
236 REAL(kind=
sp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_cell_shifts_sp, pos_sp
237 REAL(kind=
sp),
DIMENSION(:, :),
POINTER :: atomic_energy_sp, forces_sp, &
239 TYPE(neighbor_kind_pairs_type),
POINTER :: neighbor_kind_pair
240 TYPE(nequip_data_type),
POINTER :: nequip_data
241 TYPE(pair_potential_single_type),
POINTER :: pot
242 TYPE(torch_dict_type) :: inputs, outputs
244 CALL timeset(routinen, handle)
246 NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp)
247 n_atoms =
SIZE(particle_set)
248 ALLOCATE (use_atom(n_atoms))
251 DO ikind = 1,
SIZE(atomic_kind_set)
252 DO jkind = 1,
SIZE(atomic_kind_set)
253 pot => potparm%pot(ikind, jkind)%pot
254 DO i = 1,
SIZE(pot%type)
257 IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
258 particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .true.
263 n_atoms_use = count(use_atom)
267 IF (.NOT.
ASSOCIATED(nequip_data))
THEN
268 ALLOCATE (nequip_data)
270 NULLIFY (nequip_data%use_indices, nequip_data%force)
271 CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
274 IF (
ASSOCIATED(nequip_data%force))
THEN
275 IF (
SIZE(nequip_data%force, 2) /= n_atoms_use)
THEN
276 DEALLOCATE (nequip_data%force, nequip_data%use_indices)
279 IF (.NOT.
ASSOCIATED(nequip_data%force))
THEN
280 ALLOCATE (nequip_data%force(3, n_atoms_use))
281 ALLOCATE (nequip_data%use_indices(n_atoms_use))
285 DO iat = 1, n_atoms_use
286 IF (use_atom(iat))
THEN
287 iat_use = iat_use + 1
288 nequip_data%use_indices(iat_use) = iat
293 ALLOCATE (edge_index(2,
SIZE(glob_loc_list_a)))
294 ALLOCATE (edge_cell_shifts(3,
SIZE(glob_loc_list_a)))
295 DO ilist = 1, nonbonded%nlists
296 neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
297 npairs = neighbor_kind_pair%npairs
298 IF (npairs == 0) cycle
299 kind_group_loop_nequip:
DO igrp = 1, neighbor_kind_pair%ngrp_kind
300 istart = neighbor_kind_pair%grp_kind_start(igrp)
301 iend = neighbor_kind_pair%grp_kind_end(igrp)
302 ikind = neighbor_kind_pair%ij_kind(1, igrp)
303 jkind = neighbor_kind_pair%ij_kind(2, igrp)
304 list => neighbor_kind_pair%list
305 cvi = neighbor_kind_pair%cell_vector
306 pot => potparm%pot(ikind, jkind)%pot
307 DO i = 1,
SIZE(pot%type)
309 rab2_max = pot%set(i)%nequip%rcutsq
310 cell_v = matmul(cell%hmat, cvi)
311 pot => potparm%pot(ikind, jkind)%pot
312 nequip => pot%set(i)%nequip
313 npairs = iend - istart + 1
314 IF (npairs /= 0)
THEN
315 ALLOCATE (sort_list(2, npairs), work_list(npairs))
316 sort_list =
list(:, istart:iend)
319 CALL sort(sort_list(1, :), npairs, work_list)
321 work_list(ipair) = sort_list(2, work_list(ipair))
323 sort_list(2, :) = work_list
326 DO ipair = 1, npairs - 1
327 IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
330 junique = sort_list(1, ipair)
332 DO iunique = 1, nunique
334 IF (glob_loc_list_a(ifirst) > atom_a) cycle
335 DO mpair = ifirst,
SIZE(glob_loc_list_a)
336 IF (glob_loc_list_a(mpair) == atom_a)
EXIT
339 DO mpair = ifirst,
SIZE(glob_loc_list_a)
340 IF (glob_loc_list_a(mpair) /= atom_a)
EXIT
344 IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
345 DO WHILE (ipair <= npairs)
346 IF (sort_list(1, ipair) /= junique)
EXIT
347 atom_b = sort_list(2, ipair)
348 rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
349 drij = dot_product(rij, rij)
351 IF (drij <= rab2_max)
THEN
353 edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
354 edge_cell_shifts(:, nedges) = cvi
358 IF (ipair <= npairs) junique = sort_list(1, ipair)
360 DEALLOCATE (sort_list, work_list)
363 END DO kind_group_loop_nequip
366 nequip => pot%set(1)%nequip
368 ALLOCATE (edge_count(para_env%num_pe))
369 ALLOCATE (edge_count_cell(para_env%num_pe))
370 ALLOCATE (displ_cell(para_env%num_pe))
371 ALLOCATE (displ(para_env%num_pe))
373 CALL para_env%allgather(nedges, edge_count)
374 nedges_tot = sum(edge_count)
376 ALLOCATE (temp_edge_index(2, nedges))
377 temp_edge_index(:, :) = edge_index(:, :nedges)
378 DEALLOCATE (edge_index)
379 ALLOCATE (temp_edge_cell_shifts(3, nedges))
380 temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
381 DEALLOCATE (edge_cell_shifts)
383 ALLOCATE (edge_index(2, nedges_tot))
384 ALLOCATE (edge_cell_shifts(3, nedges_tot))
385 ALLOCATE (t_edge_index(nedges_tot, 2))
387 edge_count_cell(:) = edge_count*3
388 edge_count = edge_count*2
391 DO ipair = 2, para_env%num_pe
392 displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
393 displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
396 CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
397 CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)
399 t_edge_index(:, :) = transpose(edge_index)
400 DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)
402 lattice = cell%hmat/nequip%unit_cell_val
403 lattice_sp = real(lattice, kind=
sp)
406 ALLOCATE (pos(3, n_atoms_use),
atom_types(n_atoms_use))
408 DO iat = 1, n_atoms_use
409 IF (.NOT. use_atom(iat)) cycle
410 iat_use = iat_use + 1
411 atom_types(iat_use) = particle_set(iat)%atomic_kind%kind_number - 1
412 pos(:, iat) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
416 IF (nequip%do_nequip_sp)
THEN
417 ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
418 pos_sp(:, :) = real(pos(:, :), kind=
sp)
419 edge_cell_shifts_sp(:, :) = real(edge_cell_shifts(:, :), kind=
sp)
420 CALL torch_dict_insert(inputs,
"pos", pos_sp)
421 CALL torch_dict_insert(inputs,
"edge_cell_shift", edge_cell_shifts_sp)
422 CALL torch_dict_insert(inputs,
"cell", lattice_sp)
424 CALL torch_dict_insert(inputs,
"pos", pos)
425 CALL torch_dict_insert(inputs,
"edge_cell_shift", edge_cell_shifts)
426 CALL torch_dict_insert(inputs,
"cell", lattice)
429 CALL torch_dict_insert(inputs,
"edge_index", t_edge_index)
430 CALL torch_dict_insert(inputs,
"atom_types",
atom_types)
436 IF (nequip%do_nequip_sp)
THEN
437 CALL torch_dict_get(outputs,
"total_energy", total_energy_sp)
438 CALL torch_dict_get(outputs,
"atomic_energy", atomic_energy_sp)
439 CALL torch_dict_get(outputs,
"forces", forces_sp)
440 pot_nequip = real(total_energy_sp(1, 1), kind=
dp)*nequip%unit_energy_val
441 nequip_data%force(:, :) = real(forces_sp(:, :), kind=
dp)*nequip%unit_forces_val
442 DEALLOCATE (pos_sp, edge_cell_shifts_sp, total_energy_sp, atomic_energy_sp, forces_sp)
444 CALL torch_dict_get(outputs,
"total_energy", total_energy)
445 CALL torch_dict_get(outputs,
"atomic_energy", atomic_energy)
446 CALL torch_dict_get(outputs,
"forces", forces)
447 pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
448 nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
449 DEALLOCATE (pos, edge_cell_shifts, total_energy, atomic_energy, forces)
458 IF (
PRESENT(para_env))
THEN
459 pot_nequip = pot_nequip/real(para_env%num_pe,
dp)
460 nequip_data%force = nequip_data%force/real(para_env%num_pe,
dp)
463 CALL timestop(handle)
475 TYPE(fist_nonbond_env_type),
POINTER :: fist_nonbond_env
476 REAL(kind=
dp),
DIMENSION(:, :),
INTENT(INOUT) :: f_nonbond, pv_nonbond
477 LOGICAL,
INTENT(IN) :: use_virial
479 INTEGER :: iat, iat_use
480 REAL(kind=
dp),
DIMENSION(3, 3) :: virial
481 TYPE(nequip_data_type),
POINTER :: nequip_data
487 pv_nonbond = pv_nonbond + virial
488 cpabort(
"Stress tensor for NequIP not yet implemented")
491 DO iat_use = 1,
SIZE(nequip_data%use_indices)
492 iat = nequip_data%use_indices(iat_use)
493 cpassert(iat >= 1 .AND. iat <=
SIZE(f_nonbond, 2))
494 f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + nequip_data%force(1:3, iat_use)
Define the atom type and its sub types.
Define the atomic kind types and their sub types.
Handles all functions related to the CELL.
Define the neighbor list data types and the corresponding functionality.
subroutine, public fist_nonbond_env_get(fist_nonbond_env, potparm14, potparm, nonbonded, rlist_cut, rlist_lowsq, aup, lup, ei_scale14, vdw_scale14, shift_cutoff, do_electrostatics, r_last_update, r_last_update_pbc, rshell_last_update_pbc, rcore_last_update_pbc, cell_last_update, num_update, last_update, counter, natom_types, long_range_correction, ij_kind_full_fac, eam_data, quip_data, nequip_data, allegro_data, deepmd_data, charges)
gets components of a fist_nonbond_env
subroutine, public fist_nonbond_env_set(fist_nonbond_env, potparm14, potparm, rlist_cut, rlist_lowsq, nonbonded, aup, lup, ei_scale14, vdw_scale14, shift_cutoff, do_electrostatics, r_last_update, r_last_update_pbc, rshell_last_update_pbc, rcore_last_update_pbc, cell_last_update, num_update, last_update, counter, natom_types, long_range_correction, eam_data, quip_data, nequip_data, allegro_data, deepmd_data, charges)
sets a fist_nonbond_env
Defines the basic variable types.
integer, parameter, public int_8
integer, parameter, public dp
integer, parameter, public sp
An array-based list which grows on demand. When the internal array is full, a new array of twice the ...
subroutine, public nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
...
subroutine, public nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, potparm, nequip, glob_loc_list_a, r_last_update_pbc, pot_nequip, fist_nonbond_env, para_env)
...
subroutine, public destroy_nequip_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a)
...
subroutine, public setup_nequip_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, cell)
...
Interface to the message passing library MPI.
integer, parameter, public nequip_type
Define the data structure for the particle information.
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
subroutine, public torch_model_load(model, filename)
Loads a Torch model from a given "*.pth" file. (In Torch lingo models are called modules)
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
subroutine, public torch_model_eval(model, inputs, outputs)
Evaluates the given Torch model. (In Torch lingo this operation is called forward())
subroutine, public torch_model_freeze(model)
Freezes the given Torch model: applies generic optimizations that speed up the model. See https://pytorch....
All kinds of helpful little routines.