(git:ed6f26b)
Loading...
Searching...
No Matches
pao_model.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
8! **************************************************************************************************
9!> \brief Module for equivariant PAO-ML based on PyTorch.
10!> \author Ole Schuett
11! **************************************************************************************************
13 USE omp_lib, ONLY: omp_init_lock,&
14 omp_set_lock,&
15 omp_unset_lock
19 USE cell_types, ONLY: cell_type,&
20 pbc
21 USE cp_dbcsr_api, ONLY: dbcsr_get_info,&
28 USE kinds, ONLY: default_path_length,&
30 dp,&
31 sp
33 USE pao_types, ONLY: pao_env_type,&
36 USE physcon, ONLY: angstrom
39 USE qs_kind_types, ONLY: get_qs_kind,&
41 USE torch_api, ONLY: &
46 USE util, ONLY: sort
47#include "./base/base_uses.f90"
48
49 IMPLICIT NONE
50
51 PRIVATE
52
54
55CONTAINS
56
57! **************************************************************************************************
58!> \brief Loads a PAO-ML model.
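59!> \param pao PAO environment; pao%iw is the output unit used for log messages
60!> \param qs_env Quickstep environment providing qs_kind_set and atomic_kind_set
61!> \param ikind index of the kind (in qs_kind_set / atomic_kind_set) the model is loaded for
62!> \param pao_model_file path of the PyTorch model file to load
63!> \param model PAO-ML model that is populated from the file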
64! **************************************************************************************************
SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
   ! Loads the PyTorch model stored in pao_model_file into model, reads its
   ! attributes, maps the model's feature kind names onto the kinds of the
   ! current system and aborts if the model does not match kind ikind.
   !
   ! pao            : PAO environment; only pao%iw (output unit) is used here.
   ! qs_env         : source of qs_kind_set and atomic_kind_set.
   ! ikind          : index of the kind the model has to be compatible with.
   ! pao_model_file : path handed to torch_model_load.
   ! model          : populated on return, including an initialised OpenMP lock.
   TYPE(pao_env_type), INTENT(IN)                     :: pao
   TYPE(qs_environment_type), INTENT(IN)              :: qs_env
   INTEGER, INTENT(IN)                                :: ikind
   CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
   TYPE(pao_model_type), INTENT(OUT)                  :: model

   CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'

   CHARACTER(LEN=default_string_length)               :: kind_name
   CHARACTER(LEN=default_string_length), &
      ALLOCATABLE, DIMENSION(:)                       :: feature_kind_names
   INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
   TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
   TYPE(gto_basis_set_type), POINTER                  :: basis_set
   TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

   CALL timeset(routineN, handle)
   CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)

   IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
   CALL torch_model_load(model%torch_model, pao_model_file)

   ! Read the version first and reject unsupported files right away, before
   ! any of the remaining attributes is interpreted.
   CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
   IF (model%version /= 1) &
      CPABORT("Model version not supported.")

   ! Read the remaining model attributes.
   CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
   CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
   CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
   CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
   CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
   CALL torch_model_get_attr(model%torch_model, "num_neighbors", model%num_neighbors)
   CALL torch_model_get_attr(model%torch_model, "cutoff", model%cutoff)
   CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)

   ! Freeze model after all attributes have been read.
   ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
   ! https://github.com/pytorch/pytorch/issues/96726
   ! CALL torch_model_freeze(model%torch_model)

   ! For each feature kind name look up the kind number of the matching atomic kind.
   ! Entries that stay at -1 belong to kinds the model knows but the system lacks.
   ! The kind_number (not the loop index) is stored because that is the quantity
   ! the feature kinds are compared against below and during prediction.
   ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
   model%feature_kinds(:) = -1
   DO jkind = 1, SIZE(feature_kind_names)
      DO kkind = 1, SIZE(atomic_kind_set)
         IF (TRIM(atomic_kind_set(kkind)%name) == TRIM(feature_kind_names(jkind))) THEN
            model%feature_kinds(jkind) = atomic_kind_set(kkind)%kind_number
         END IF
      END DO
      IF (model%feature_kinds(jkind) < 0) THEN
         IF (pao%iw > 0) &
            WRITE (pao%iw, '(A)') " PAO| ML-model supports feature kind '"// &
            TRIM(feature_kind_names(jkind))//"' that is not present in subsys."
      END IF
   END DO

   ! Report kinds of the system that the model has no feature for (informational only).
   DO jkind = 1, SIZE(atomic_kind_set)
      IF (ALL(model%feature_kinds /= atomic_kind_set(jkind)%kind_number)) THEN
         IF (pao%iw > 0) &
            WRITE (pao%iw, '(A)') " PAO| ML-Model lacks feature kind '"// &
            TRIM(atomic_kind_set(jkind)%name)//"' that is present in subsys."
      END IF
   END DO

   ! Check compatibility of the model with kind ikind.
   NULLIFY (basis_set)
   CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
   CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
   CPASSERT(ASSOCIATED(basis_set)) ! basis_set is dereferenced by the checks below
   IF (TRIM(model%kind_name) .NE. TRIM(kind_name)) &
      CPABORT("Kind name does not match.")
   IF (model%atomic_number /= z) &
      CPABORT("Atomic number does not match.")
   IF (TRIM(model%prim_basis_name) .NE. TRIM(basis_set%name)) &
      CPABORT("Primary basis set name does not match.")
   IF (model%prim_basis_size /= basis_set%nsgf) &
      CPABORT("Primary basis set size does not match.")
   IF (model%pao_basis_size /= pao_basis_size) &
      CPABORT("PAO basis size does not match.")

   ! The lock is taken by predict_single_atom around every use of the model.
   CALL omp_init_lock(model%lock)
   CALL timestop(handle)

END SUBROUTINE pao_model_load
149
150! **************************************************************************************************
151!> \brief Fills pao%matrix_X based on machine learning predictions
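152!> \param pao PAO environment holding matrix_X and the per-kind ML models
153!> \param qs_env Quickstep environment passed on to the per-atom prediction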
154! **************************************************************************************************
SUBROUTINE pao_model_predict(pao, qs_env)
   ! Visits every block of pao%matrix_X inside an OpenMP parallel region and
   ! lets predict_single_atom overwrite it with the model prediction.
   !
   ! pao    : PAO environment; pao%matrix_X is written.
   ! qs_env : Quickstep environment, passed through to predict_single_atom.
   TYPE(pao_env_type), POINTER                        :: pao
   TYPE(qs_environment_type), POINTER                 :: qs_env

   CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'

   INTEGER                                            :: acol, arow, handle, iatom
   REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
   TYPE(dbcsr_iterator_type)                          :: iter

   CALL timeset(routineN, handle)

   ! Every thread runs its own iterator.
   ! NOTE(review): presumably the iterator hands each block to one thread only - confirm.
!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(iter,arow,acol,iatom,block_X)
   CALL dbcsr_iterator_start(iter, pao%matrix_X)
   DO WHILE (dbcsr_iterator_blocks_left(iter))
      CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
      IF (SIZE(block_X) == 0) CYCLE ! pao disabled for iatom
      ! Only diagonal blocks are expected, the block row is the atom index.
      iatom = arow
      CPASSERT(arow == acol)
      CALL predict_single_atom(pao, qs_env, iatom, block_X=block_X)
   END DO
   CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

   CALL timestop(handle)

END SUBROUTINE pao_model_predict
181
182! **************************************************************************************************
183!> \brief Calculate forces contributed by machine learning
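184!> \param pao PAO environment holding the per-kind ML models
185!> \param qs_env Quickstep environment passed on to the per-atom evaluation
186!> \param matrix_G matrix whose diagonal blocks are back-propagated through the models as outer gradients
187!> \param forces array indexed (atom, xyz) to which the ML contributions are added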
188! **************************************************************************************************
SUBROUTINE pao_model_forces(pao, qs_env, matrix_G, forces)
   ! Visits every block of matrix_G inside an OpenMP parallel region and lets
   ! predict_single_atom back-propagate it through the model, which adds the
   ! resulting contributions to forces.
   !
   ! pao      : PAO environment holding the per-kind models.
   ! qs_env   : Quickstep environment, passed through to predict_single_atom.
   ! matrix_G : its blocks are used as outer gradients for the back-propagation.
   ! forces   : shared between all threads and updated in place.
   TYPE(pao_env_type), POINTER                        :: pao
   TYPE(qs_environment_type), POINTER                 :: qs_env
   TYPE(dbcsr_type)                                   :: matrix_G
   REAL(dp), DIMENSION(:, :), INTENT(INOUT)           :: forces

   CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_forces'

   INTEGER                                            :: acol, arow, handle, iatom
   REAL(dp), DIMENSION(:, :), POINTER                 :: block_G
   TYPE(dbcsr_iterator_type)                          :: iter

   CALL timeset(routineN, handle)

   ! Every thread runs its own iterator; forces is SHARED, so any update of it
   ! further down the call chain has to be protected against concurrent access.
!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,matrix_G,forces) PRIVATE(iter,arow,acol,iatom,block_G)
   CALL dbcsr_iterator_start(iter, matrix_G)
   DO WHILE (dbcsr_iterator_blocks_left(iter))
      CALL dbcsr_iterator_next_block(iter, arow, acol, block_G)
      ! Only diagonal blocks are expected, the block row is the atom index.
      iatom = arow
      CPASSERT(arow == acol)
      IF (SIZE(block_G) == 0) CYCLE ! pao disabled for iatom
      CALL predict_single_atom(pao, qs_env, iatom, block_G=block_G, forces=forces)
   END DO
   CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

   CALL timestop(handle)

END SUBROUTINE pao_model_forces
217
218! **************************************************************************************************
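219!> \brief Evaluates the ML model of a single atom: predicts its block_X and/or back-propagates block_G into forces.
220!> \param pao PAO environment holding matrix_Y (used for the block sizes) and the per-kind ML models
221!> \param qs_env Quickstep environment providing the cell and the particle set
222!> \param iatom index of the central atom
223!> \param block_X if present, receives the predicted block reshaped to (n*m, 1)
224!> \param block_G if present, outer gradient for the predicted block; triggers the back-propagation
225!> \param forces array indexed (atom, xyz) that accumulates the result; must be present whenever block_G is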
226! **************************************************************************************************
SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
   ! Evaluates the PAO-ML model belonging to the kind of atom iatom.
   !
   ! The model inputs are the relative positions and the one-hot kind features
   ! of the closest model%num_neighbors atoms (zero padded if the system has
   ! fewer atoms). Depending on the optional arguments the routine then
   !   - copies the predicted block into block_X, and/or
   !   - back-propagates block_G through the model and adds the gradient with
   !     respect to the relative positions to forces.
   !
   ! pao     : PAO environment; matrix_Y supplies the block sizes, models the model.
   ! qs_env  : source of the cell and the particle set.
   ! iatom   : index of the central atom.
   ! block_X : receives the prediction, reshaped from (n, m) to (n*m, 1).
   ! block_G : outer gradient, reshaped to (n, m) before the backward pass.
   ! forces  : indexed (atom, xyz); required whenever block_G is present.
   !
   ! May be called concurrently from an OpenMP parallel region: the model is
   ! protected by its own lock, the shared forces array by a critical section.
   TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
   TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
   INTEGER, INTENT(IN)                                :: iatom
   REAL(dp), DIMENSION(:, :), INTENT(OUT), OPTIONAL   :: block_X
   REAL(dp), DIMENSION(:, :), INTENT(IN), OPTIONAL    :: block_G
   REAL(dp), DIMENSION(:, :), INTENT(INOUT), OPTIONAL :: forces

   INTEGER                                            :: ikind, jatom, jkind, jneighbor, m, n, &
                                                         natoms, nneighbors
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbors_index
   INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
   REAL(dp), DIMENSION(3)                             :: ri, rij, rj
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: neighbors_distance
   REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: features, outer_grad, relpos
   REAL(sp), DIMENSION(:, :), POINTER                 :: predicted_xblock, relpos_grad
   TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
   TYPE(cell_type), POINTER                           :: cell
   TYPE(mp_para_env_type), POINTER                    :: para_env
   TYPE(pao_model_type), POINTER                      :: model
   TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
   TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
   TYPE(torch_dict_type)                              :: model_inputs, model_outputs
   TYPE(torch_tensor_type)                            :: features_tensor, outer_grad_tensor, &
                                                         predicted_xblock_tensor, &
                                                         relpos_grad_tensor, relpos_tensor

   ! The backward pass writes into forces, so the two arguments must come together.
   IF (PRESENT(block_G)) THEN
      CPASSERT(PRESENT(forces))
   END IF

   CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
   n = blk_sizes_pri(iatom) ! size of primary basis
   m = blk_sizes_pao(iatom) ! size of pao basis

   CALL get_qs_env(qs_env, &
                   para_env=para_env, &
                   cell=cell, &
                   particle_set=particle_set, &
                   atomic_kind_set=atomic_kind_set, &
                   qs_kind_set=qs_kind_set, &
                   natom=natoms)

   CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
   model => pao%models(ikind)
   CPASSERT(model%version > 0)
   CALL omp_set_lock(model%lock) ! TODO: might not be needed for inference.

   ! Number of real (non-padded) neighbors; the central atom itself is excluded.
   nneighbors = MIN(model%num_neighbors, natoms - 1)

   ! Find neighbors.
   ! TODO: this is a quadratic algorithm, use a neighbor-list instead
   ALLOCATE (neighbors_distance(natoms), neighbors_index(natoms))
   ri = particle_set(iatom)%r
   DO jatom = 1, natoms
      rj = particle_set(jatom)%r
      rij = pbc(ri, rj, cell)
      neighbors_distance(jatom) = DOT_PRODUCT(rij, rij) ! using squared distances for performance
   END DO
   ! neighbors_index receives the atom index that belongs to each sorted distance.
   CALL sort(neighbors_distance, natoms, neighbors_index)
   CPASSERT(neighbors_index(1) == iatom) ! central atom should be closest to itself

   ! Compute neighbors relative positions, scaled by the angstrom factor from physcon.
   ! Columns beyond nneighbors stay zero (padding).
   ALLOCATE (relpos(3, model%num_neighbors))
   relpos(:, :) = 0.0_sp
   DO jneighbor = 1, nneighbors
      jatom = neighbors_index(jneighbor + 1) ! skipping central atom
      rj = particle_set(jatom)%r
      rij = pbc(ri, rj, cell)
      relpos(:, jneighbor) = REAL(angstrom*rij, kind=sp)
   END DO

   ! Compute neighbors features: a 1 in the row whose feature kind matches the
   ! neighbor's kind number, 0 elsewhere (also for padded columns).
   ALLOCATE (features(SIZE(model%feature_kinds), model%num_neighbors))
   features(:, :) = 0.0_sp
   DO jneighbor = 1, nneighbors
      jatom = neighbors_index(jneighbor + 1) ! skipping central atom
      jkind = particle_set(jatom)%atomic_kind%kind_number
      WHERE (model%feature_kinds == jkind) features(:, jneighbor) = 1.0_sp
   END DO

   ! Inference. Gradients with respect to relpos are only tracked if forces are requested.
   CALL torch_dict_create(model_inputs)

   CALL torch_tensor_from_array(relpos_tensor, relpos, requires_grad=PRESENT(block_G))
   CALL torch_dict_insert(model_inputs, "neighbors_relpos", relpos_tensor)
   CALL torch_tensor_from_array(features_tensor, features)
   CALL torch_dict_insert(model_inputs, "neighbors_features", features_tensor)

   CALL torch_dict_create(model_outputs)
   CALL torch_model_forward(model%torch_model, model_inputs, model_outputs)

   ! Copy predicted XBlock.
   NULLIFY (predicted_xblock)
   CALL torch_dict_get(model_outputs, "xblock", predicted_xblock_tensor)
   CALL torch_tensor_data_ptr(predicted_xblock_tensor, predicted_xblock)
   CPASSERT(SIZE(predicted_xblock, 1) == n .AND. SIZE(predicted_xblock, 2) == m)
   IF (PRESENT(block_X)) THEN
      block_X = RESHAPE(predicted_xblock, [n*m, 1])
   END IF

   ! TURNING POINT (if calc forces) ------------------------------------------
   IF (PRESENT(block_G)) THEN
      ALLOCATE (outer_grad(n, m))
      outer_grad(:, :) = REAL(RESHAPE(block_G, [n, m]), kind=sp)
      CALL torch_tensor_from_array(outer_grad_tensor, outer_grad)
      CALL torch_tensor_backward(predicted_xblock_tensor, outer_grad_tensor)
      CALL torch_tensor_grad(relpos_tensor, relpos_grad_tensor)
      NULLIFY (relpos_grad)
      CALL torch_tensor_data_ptr(relpos_grad_tensor, relpos_grad)
      CPASSERT(SIZE(relpos_grad, 1) == 3 .AND. SIZE(relpos_grad, 2) == model%num_neighbors)
      ! forces is shared between the threads of the caller and model%lock only
      ! serialises atoms of the same kind. Threads working on different kinds
      ! can hit the same row of forces, hence the global critical section.
!$OMP CRITICAL(pao_model_forces_accumulate)
      DO jneighbor = 1, nneighbors
         jatom = neighbors_index(jneighbor + 1) ! skipping central atom
         forces(iatom, :) = forces(iatom, :) + relpos_grad(:, jneighbor)*angstrom
         forces(jatom, :) = forces(jatom, :) - relpos_grad(:, jneighbor)*angstrom
      END DO
!$OMP END CRITICAL(pao_model_forces_accumulate)
      CALL torch_tensor_release(outer_grad_tensor)
      CALL torch_tensor_release(relpos_grad_tensor)
      DEALLOCATE (outer_grad)
   END IF

   ! Clean up.
   CALL torch_tensor_release(relpos_tensor)
   CALL torch_tensor_release(features_tensor)
   CALL torch_tensor_release(predicted_xblock_tensor)
   CALL torch_dict_release(model_inputs)
   CALL torch_dict_release(model_outputs)
   DEALLOCATE (neighbors_distance, neighbors_index, relpos, features)
   CALL omp_unset_lock(model%lock)

END SUBROUTINE predict_single_atom
349
350END MODULE pao_model
Define the atomic kind types and their sub types.
subroutine, public get_atomic_kind(atomic_kind, fist_potential, element_symbol, name, mass, kind_number, natom, atom_list, rcov, rvdw, z, qeff, apol, cpol, mm_radius, shell, shell_active, damping)
Get attributes of an atomic kind.
Handles all functions related to the CELL.
Definition cell_types.F:15
subroutine, public dbcsr_iterator_next_block(iterator, row, column, block, block_number_argument_has_been_removed, row_size, col_size, row_offset, col_offset)
...
logical function, public dbcsr_iterator_blocks_left(iterator)
...
subroutine, public dbcsr_iterator_stop(iterator)
...
subroutine, public dbcsr_get_info(matrix, nblkrows_total, nblkcols_total, nfullrows_total, nfullcols_total, nblkrows_local, nblkcols_local, nfullrows_local, nfullcols_local, my_prow, my_pcol, local_rows, local_cols, proc_row_dist, proc_col_dist, row_blk_size, col_blk_size, row_blk_offset, col_blk_offset, distribution, name, matrix_type, group)
...
subroutine, public dbcsr_iterator_start(iterator, matrix, shared, dynamic, dynamic_byrows)
...
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_string_length
Definition kinds.F:57
integer, parameter, public default_path_length
Definition kinds.F:58
integer, parameter, public sp
Definition kinds.F:33
Interface to the message passing library MPI.
Module for equivariant PAO-ML based on PyTorch.
Definition pao_model.F:12
subroutine, public pao_model_predict(pao, qs_env)
Fills pao%matrix_X based on machine learning predictions.
Definition pao_model.F:156
subroutine, public pao_model_forces(pao, qs_env, matrix_g, forces)
Calculate forces contributed by machine learning.
Definition pao_model.F:190
subroutine, public pao_model_load(pao, qs_env, ikind, pao_model_file, model)
Loads a PAO-ML model.
Definition pao_model.F:66
Types used by the PAO machinery.
Definition pao_types.F:12
Define the data structure for the particle information.
Definition of physical constants:
Definition physcon.F:68
real(kind=dp), parameter, public angstrom
Definition physcon.F:144
subroutine, public get_qs_env(qs_env, atomic_kind_set, qs_kind_set, cell, super_cell, cell_ref, use_ref_cell, kpoints, dft_control, mos, sab_orb, sab_all, qmmm, qmmm_periodic, sac_ae, sac_ppl, sac_lri, sap_ppnl, sab_vdw, sab_scp, sap_oce, sab_lrc, sab_se, sab_xtbe, sab_tbe, sab_core, sab_xb, sab_xtb_pp, sab_xtb_nonbond, sab_almo, sab_kp, sab_kp_nosym, particle_set, energy, force, matrix_h, matrix_h_im, matrix_ks, matrix_ks_im, matrix_vxc, run_rtp, rtp, matrix_h_kp, matrix_h_im_kp, matrix_ks_kp, matrix_ks_im_kp, matrix_vxc_kp, kinetic_kp, matrix_s_kp, matrix_w_kp, matrix_s_ri_aux_kp, matrix_s, matrix_s_ri_aux, matrix_w, matrix_p_mp2, matrix_p_mp2_admm, rho, rho_xc, pw_env, ewald_env, ewald_pw, active_space, mpools, input, para_env, blacs_env, scf_control, rel_control, kinetic, qs_charges, vppl, rho_core, rho_nlcc, rho_nlcc_g, ks_env, ks_qmmm_env, wf_history, scf_env, local_particles, local_molecules, distribution_2d, dbcsr_dist, molecule_kind_set, molecule_set, subsys, cp_subsys, oce, local_rho_set, rho_atom_set, task_list, task_list_soft, rho0_atom_set, rho0_mpole, rhoz_set, ecoul_1c, rho0_s_rs, rho0_s_gs, do_kpoints, has_unit_metric, requires_mo_derivs, mo_derivs, mo_loc_history, nkind, natom, nelectron_total, nelectron_spin, efield, neighbor_list_id, linres_control, xas_env, virial, cp_ddapc_env, cp_ddapc_ewald, outer_scf_history, outer_scf_ihistory, x_data, et_coupling, dftb_potential, results, se_taper, se_store_int_env, se_nddo_mpole, se_nonbond_env, admm_env, lri_env, lri_density, exstate_env, ec_env, harris_env, dispersion_env, gcp_env, vee, rho_external, external_vxc, mask, mp2_env, bs_env, kg_env, wanniercentres, atprop, ls_scf_env, do_transport, transport_env, v_hartree_rspace, s_mstruct_changed, rho_changed, potential_changed, forces_up_to_date, mscfg_env, almo_scf_env, gradient_history, variable_history, embed_pot, spin_embed_pot, polar_env, mos_last_converged, eeq, rhs)
Get the QUICKSTEP environment.
Define the quickstep kind type and their sub types.
subroutine, public get_qs_kind(qs_kind, basis_set, basis_type, ncgf, nsgf, all_potential, tnadd_potential, gth_potential, sgp_potential, upf_potential, se_parameter, dftb_parameter, xtb_parameter, dftb3_param, zatom, zeff, elec_conf, mao, lmax_dftb, alpha_core_charge, ccore_charge, core_charge, core_charge_radius, paw_proj_set, paw_atom, hard_radius, hard0_radius, max_rad_local, covalent_radius, vdw_radius, gpw_type_forced, harmonics, max_iso_not0, max_s_harm, grid_atom, ngrid_ang, ngrid_rad, lmax_rho0, dft_plus_u_atom, l_of_dft_plus_u, n_of_dft_plus_u, u_minus_j, u_of_dft_plus_u, j_of_dft_plus_u, alpha_of_dft_plus_u, beta_of_dft_plus_u, j0_of_dft_plus_u, occupation_of_dft_plus_u, dispersion, bs_occupation, magnetization, no_optimize, addel, laddel, naddel, orbitals, max_scf, eps_scf, smear, u_ramping, u_minus_j_target, eps_u_ramping, init_u_ramping_each_scf, reltmat, ghost, floating, name, element_symbol, pao_basis_size, pao_model_file, pao_potentials, pao_descriptors, nelec)
Get attributes of a quickstep kind.
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition torch_api.F:1113
subroutine, public torch_tensor_backward(tensor, outer_grad)
Runs autograd on a Torch tensor.
Definition torch_api.F:937
subroutine, public torch_dict_get(dict, key, tensor)
Retrieves a Torch tensor from a Torch dictionary.
Definition torch_api.F:1079
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
Definition torch_api.F:1137
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition torch_api.F:1023
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
Definition torch_api.F:970
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
Definition torch_api.F:1047
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:999
subroutine, public torch_model_forward(model, inputs, outputs)
Evaluates the given Torch model.
Definition torch_api.F:1169
All kind of helpful little routines.
Definition util.F:14
Provides all information about an atomic kind.
Type defining parameters related to the simulation cell.
Definition cell_types.F:55
Stores all the information relevant to an MPI environment.
PAO-ML model for a single atomic kind.
Definition pao_types.F:60
Provides all information about a quickstep kind.