13 USE omp_lib,
ONLY: omp_init_lock,&
46#include "./base/base_uses.f90"
67 INTEGER,
INTENT(IN) :: ikind
68 CHARACTER(LEN=default_path_length),
INTENT(IN) :: pao_model_file
71 CHARACTER(len=*),
PARAMETER :: routinen =
'pao_model_load'
73 CHARACTER(LEN=default_string_length) :: kind_name
74 CHARACTER(LEN=default_string_length), &
75 ALLOCATABLE,
DIMENSION(:) :: model_kind_names
76 INTEGER :: handle, jkind, kkind, pao_basis_size, z
77 REAL(
dp) :: cutoff_angstrom
82 CALL timeset(routinen, handle)
83 CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)
85 IF (pao%iw > 0)
WRITE (pao%iw,
'(A)')
" PAO| Loading PyTorch model from: "//trim(pao_model_file)
98 model%cutoff = cutoff_angstrom/
angstrom
106 ALLOCATE (model%kinds_mapping(
SIZE(atomic_kind_set)))
107 model%kinds_mapping(:) = -1
108 DO jkind = 1,
SIZE(atomic_kind_set)
109 DO kkind = 1,
SIZE(model_kind_names)
110 IF (trim(atomic_kind_set(jkind)%name) == trim(model_kind_names(kkind)))
THEN
111 model%kinds_mapping(jkind) = kkind - 1
115 IF (model%kinds_mapping(jkind) < 0)
THEN
116 CALL cp_abort(__location__,
"PAO-ML model lacks kind '"//trim(atomic_kind_set(jkind)%name)//
"' .")
121 CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
123 IF (model%version /= 2) &
124 cpabort(
"Model version not supported.")
125 IF (trim(model%kind_name) .NE. trim(kind_name)) &
126 cpabort(
"Kind name does not match.")
127 IF (model%atomic_number /= z) &
128 cpabort(
"Atomic number does not match.")
129 IF (trim(model%prim_basis_name) .NE. trim(basis_set%name)) &
130 cpabort(
"Primary basis set name does not match.")
131 IF (model%prim_basis_size /= basis_set%nsgf) &
132 cpabort(
"Primary basis set size does not match.")
133 IF (model%pao_basis_size /= pao_basis_size) &
134 cpabort(
"PAO basis size does not match.")
136 CALL omp_init_lock(model%lock)
137 CALL timestop(handle)
150 CHARACTER(len=*),
PARAMETER :: routinen =
'pao_model_predict'
152 INTEGER :: acol, arow, handle, iatom
153 REAL(
dp),
DIMENSION(:, :),
POINTER :: block_x
156 CALL timeset(routinen, handle)
162 IF (
SIZE(block_x) == 0) cycle
163 iatom = arow; cpassert(arow == acol)
164 CALL predict_single_atom(pao, qs_env, iatom, block_x=block_x)
169 CALL timestop(handle)
184 REAL(
dp),
DIMENSION(:, :),
INTENT(INOUT) :: forces
186 CHARACTER(len=*),
PARAMETER :: routinen =
'pao_model_forces'
188 INTEGER :: acol, arow, handle, iatom
189 REAL(
dp),
DIMENSION(:, :),
POINTER :: block_g
192 CALL timeset(routinen, handle)
198 iatom = arow; cpassert(arow == acol)
199 IF (
SIZE(block_g) == 0) cycle
200 CALL predict_single_atom(pao, qs_env, iatom, block_g=block_g, forces=forces)
205 CALL timestop(handle)
!> \brief Evaluates the PAO-ML model belonging to the kind of atom iatom on that atom's
!>        local environment.
!>        Neighbors are gathered from the 27 periodic images stored in cell_shifts. Atoms
!>        closer to iatom than model%num_layers*model%cutoff are kept; coincident
!>        positions are excluded. Directed edges connect every ordered pair of distinct
!>        neighbors closer than model%cutoff.
!> \param pao PAO environment: provides pao%matrix_Y block sizes and pao%models(ikind)
!> \param qs_env QS environment (particle, kind and cell data)
!> \param iatom index of the central atom
!> \param block_X if present, receives the predicted xblock reshaped to [n*m, 1]
!> \param block_G if present, reshaped to [n, m, 1], converted to single precision and
!>        presumably used as the outer gradient for autograd. The resulting edge-vector
!>        gradients are accumulated into forces.
!> \param forces per-atom forces indexed as forces(atom, :); updated in the block_G branch.
!>        NOTE(review): no PRESENT(forces) guard is visible, so callers passing block_G
!>        presumably must also pass forces.
218 SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
221 INTEGER,
INTENT(IN) :: iatom
222 REAL(
dp),
DIMENSION(:, :),
OPTIONAL :: block_x, block_g, forces
224 INTEGER :: i, iedge, ikind, j, jatom, jcell, jkind, &
225 jneighbor, k, katom, kneighbor, m, n, &
226 natoms, num_edges, num_neighbors
227 INTEGER(kind=int_8),
ALLOCATABLE,
DIMENSION(:) :: neighbor_atom_types
228 INTEGER(kind=int_8),
ALLOCATABLE,
DIMENSION(:, :) :: central_edge_index, edge_index
229 INTEGER,
ALLOCATABLE,
DIMENSION(:) :: neighbor_atom_index
230 INTEGER,
DIMENSION(:),
POINTER :: blk_sizes_pao, blk_sizes_pri
231 REAL(
dp),
DIMENSION(3) :: ri, rj, rjk
232 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: cell_shifts, neighbor_pos
233 REAL(
sp),
ALLOCATABLE,
DIMENSION(:, :) :: edge_vectors
234 REAL(
sp),
ALLOCATABLE,
DIMENSION(:, :, :) :: outer_grad
235 REAL(
sp),
DIMENSION(:, :),
POINTER :: edge_vectors_grad
236 REAL(
sp),
DIMENSION(:, :, :),
POINTER :: predicted_xblock
242 TYPE(
qs_kind_type),
DIMENSION(:),
POINTER :: qs_kind_set
244 TYPE(
torch_tensor_type) :: atom_types_tensor, central_edge_index_tensor, edge_index_tensor, &
245 edge_vectors_grad_tensor, edge_vectors_tensor, outer_grad_tensor, predicted_xblock_tensor
! Block shape of atom iatom in pao%matrix_Y: n rows (primary basis), m columns (PAO basis).
247 CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
248 n = blk_sizes_pri(iatom)
249 m = blk_sizes_pao(iatom)
254 particle_set=particle_set, &
255 atomic_kind_set=atomic_kind_set, &
256 qs_kind_set=qs_kind_set, &
259 CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
260 ri = particle_set(iatom)%r
! Use the model of this atom's kind. Its OpenMP lock is held until the end of the routine.
261 model => pao%models(ikind)
262 cpassert(model%version > 0)
263 CALL omp_set_lock(model%lock)
! Translation vectors of 27 periodic images, built as i*a1 + j*a2 + k*a3 from the cell matrix.
! NOTE(review): only these 27 images are searched. If model%num_layers*model%cutoff exceeds
! the cell dimensions, farther images would be missed; confirm the cell is large enough.
268 ALLOCATE (cell_shifts(27, 3))
274 cell_shifts(jcell, :) = i*cell%hmat(:, 1) + j*cell%hmat(:, 2) + k*cell%hmat(:, 3)
! First pass: count image atoms inside the receptive field num_layers*cutoff around ri.
! Positions identical to ri (the central atom itself) are skipped.
284 rj = particle_set(jatom)%r + cell_shifts(jcell, :)
285 IF (norm2(rj - ri) < model%num_layers*model%cutoff .AND. any(rj /= ri))
THEN
286 num_neighbors = num_neighbors + 1
! Neighbor slot 1 is the central atom: its position, model type index and atom index.
292 ALLOCATE (neighbor_pos(num_neighbors, 3), neighbor_atom_types(num_neighbors), neighbor_atom_index(num_neighbors))
294 neighbor_pos(1, :) = ri
295 neighbor_atom_types(1) = model%kinds_mapping(ikind)
296 neighbor_atom_index(1) = iatom
! Second pass: fill the remaining slots with image position, kind mapped to the model's
! type index, and original atom index (used later to scatter forces).
299 rj = particle_set(jatom)%r + cell_shifts(jcell, :)
300 jkind = particle_set(jatom)%atomic_kind%kind_number
301 IF (norm2(rj - ri) < model%num_layers*model%cutoff .AND. any(rj /= ri))
THEN
302 num_neighbors = num_neighbors + 1
303 neighbor_pos(num_neighbors, :) = rj
304 neighbor_atom_types(num_neighbors) = model%kinds_mapping(jkind)
305 neighbor_atom_index(num_neighbors) = jatom
! Count directed edges j -> k between distinct neighbors closer than model%cutoff.
313 DO jneighbor = 1, num_neighbors
314 DO kneighbor = 1, num_neighbors
315 rjk = neighbor_pos(kneighbor, :) - neighbor_pos(jneighbor, :)
316 IF (norm2(rjk) < model%cutoff .AND. jneighbor /= kneighbor)
THEN
317 num_edges = num_edges + 1
! Fill the edges. Node indices are stored 0-based, and each edge vector r_k - r_j is
! scaled by `angstrom` (presumably to the model's length unit) and stored in single precision.
323 ALLOCATE (edge_index(num_edges, 2), edge_vectors(3, num_edges))
325 DO jneighbor = 1, num_neighbors
326 DO kneighbor = 1, num_neighbors
327 rjk = neighbor_pos(kneighbor, :) - neighbor_pos(jneighbor, :)
328 IF (norm2(rjk) < model%cutoff .AND. jneighbor /= kneighbor)
THEN
329 num_edges = num_edges + 1
330 edge_index(num_edges, :) = [jneighbor - 1, kneighbor - 1]
331 edge_vectors(:, num_edges) = real(rjk*
angstrom, kind=
sp)
! A single central edge (0, 0), i.e. node 0 = the central atom.
! Presumably this selects the atom whose xblock is predicted; TODO confirm against the model.
336 ALLOCATE (central_edge_index(1, 2))
337 central_edge_index(:, :) = 0
352 CALL torch_dict_insert(model_inputs,
"central_edge_index", central_edge_index_tensor)
! Fetch the model output "xblock" and check that it has shape (n, m, 1).
358 NULLIFY (predicted_xblock)
359 CALL torch_dict_get(model_outputs,
"xblock", predicted_xblock_tensor)
361 cpassert(
SIZE(predicted_xblock, 1) == n)
362 cpassert(
SIZE(predicted_xblock, 2) == m)
363 cpassert(
SIZE(predicted_xblock, 3) == 1)
! Prediction path: store the block as an [n*m, 1] column. The sp -> dp conversion happens on assignment.
364 IF (
PRESENT(block_x))
THEN
365 block_x = reshape(predicted_xblock, [n*m, 1])
! Force path: block_g (converted to sp) becomes the [n, m, 1] outer gradient of the prediction.
369 IF (
PRESENT(block_g))
THEN
370 ALLOCATE (outer_grad(n, m, 1))
371 outer_grad(:, :, :) = real(reshape(block_g, [n, m, 1]), kind=
sp)
375 NULLIFY (edge_vectors_grad)
377 cpassert(
SIZE(edge_vectors_grad, 1) == 3 .AND.
SIZE(edge_vectors_grad, 2) == num_edges)
! Scatter each edge gradient onto the edge's two endpoint atoms (0-based -> 1-based indices).
! The edge vector is r_k - r_j, so the gradient is added to atom j and subtracted from atom k.
! It is scaled by the same `angstrom` factor that was applied to the edge vectors.
378 DO iedge = 1, num_edges
379 jneighbor = int(edge_index(iedge, 1) + 1)
380 kneighbor = int(edge_index(iedge, 2) + 1)
381 jatom = neighbor_atom_index(jneighbor)
382 katom = neighbor_atom_index(kneighbor)
383 forces(jatom, :) = forces(jatom, :) + edge_vectors_grad(:, iedge)*
angstrom
384 forces(katom, :) = forces(katom, :) - edge_vectors_grad(:, iedge)*
angstrom
! Release the per-kind model lock taken above.
398 CALL omp_unset_lock(model%lock)
400 END SUBROUTINE predict_single_atom
Define the atomic kind types and their subtypes.
subroutine, public get_atomic_kind(atomic_kind, fist_potential, element_symbol, name, mass, kind_number, natom, atom_list, rcov, rvdw, z, qeff, apol, cpol, mm_radius, shell, shell_active, damping)
Get attributes of an atomic kind.
Handles all functions related to the CELL.
subroutine, public dbcsr_iterator_next_block(iterator, row, column, block, block_number_argument_has_been_removed, row_size, col_size, row_offset, col_offset)
...
logical function, public dbcsr_iterator_blocks_left(iterator)
...
subroutine, public dbcsr_iterator_stop(iterator)
...
subroutine, public dbcsr_get_info(matrix, nblkrows_total, nblkcols_total, nfullrows_total, nfullcols_total, nblkrows_local, nblkcols_local, nfullrows_local, nfullcols_local, my_prow, my_pcol, local_rows, local_cols, proc_row_dist, proc_col_dist, row_blk_size, col_blk_size, row_blk_offset, col_blk_offset, distribution, name, matrix_type, group)
...
subroutine, public dbcsr_iterator_start(iterator, matrix, shared, dynamic, dynamic_byrows)
...
Defines the basic variable types.
integer, parameter, public int_8
integer, parameter, public dp
integer, parameter, public default_string_length
integer, parameter, public default_path_length
integer, parameter, public sp
Interface to the message passing library MPI.
Module for equivariant PAO-ML based on PyTorch.
subroutine, public pao_model_predict(pao, qs_env)
Fills pao%matrix_X based on machine-learning predictions.
subroutine, public pao_model_forces(pao, qs_env, matrix_g, forces)
Calculate forces contributed by machine learning.
subroutine, public pao_model_load(pao, qs_env, ikind, pao_model_file, model)
Loads a PAO-ML model.
Types used by the PAO machinery.
Define the data structure for the particle information.
Definition of physical constants:
real(kind=dp), parameter, public angstrom
subroutine, public get_qs_env(qs_env, atomic_kind_set, qs_kind_set, cell, super_cell, cell_ref, use_ref_cell, kpoints, dft_control, mos, sab_orb, sab_all, qmmm, qmmm_periodic, sac_ae, sac_ppl, sac_lri, sap_ppnl, sab_vdw, sab_scp, sap_oce, sab_lrc, sab_se, sab_xtbe, sab_tbe, sab_core, sab_xb, sab_xtb_pp, sab_xtb_nonbond, sab_almo, sab_kp, sab_kp_nosym, sab_cneo, particle_set, energy, force, matrix_h, matrix_h_im, matrix_ks, matrix_ks_im, matrix_vxc, run_rtp, rtp, matrix_h_kp, matrix_h_im_kp, matrix_ks_kp, matrix_ks_im_kp, matrix_vxc_kp, kinetic_kp, matrix_s_kp, matrix_w_kp, matrix_s_ri_aux_kp, matrix_s, matrix_s_ri_aux, matrix_w, matrix_p_mp2, matrix_p_mp2_admm, rho, rho_xc, pw_env, ewald_env, ewald_pw, active_space, mpools, input, para_env, blacs_env, scf_control, rel_control, kinetic, qs_charges, vppl, rho_core, rho_nlcc, rho_nlcc_g, ks_env, ks_qmmm_env, wf_history, scf_env, local_particles, local_molecules, distribution_2d, dbcsr_dist, molecule_kind_set, molecule_set, subsys, cp_subsys, oce, local_rho_set, rho_atom_set, task_list, task_list_soft, rho0_atom_set, rho0_mpole, rhoz_set, rhoz_cneo_set, ecoul_1c, rho0_s_rs, rho0_s_gs, rhoz_cneo_s_rs, rhoz_cneo_s_gs, do_kpoints, has_unit_metric, requires_mo_derivs, mo_derivs, mo_loc_history, nkind, natom, nelectron_total, nelectron_spin, efield, neighbor_list_id, linres_control, xas_env, virial, cp_ddapc_env, cp_ddapc_ewald, outer_scf_history, outer_scf_ihistory, x_data, et_coupling, dftb_potential, results, se_taper, se_store_int_env, se_nddo_mpole, se_nonbond_env, admm_env, lri_env, lri_density, exstate_env, ec_env, harris_env, dispersion_env, gcp_env, vee, rho_external, external_vxc, mask, mp2_env, bs_env, kg_env, wanniercentres, atprop, ls_scf_env, do_transport, transport_env, v_hartree_rspace, s_mstruct_changed, rho_changed, potential_changed, forces_up_to_date, mscfg_env, almo_scf_env, gradient_history, variable_history, embed_pot, spin_embed_pot, polar_env, mos_last_converged, eeq, rhs, do_rixs, tb_tblite)
Get the QUICKSTEP environment.
Define the quickstep kind type and its subtypes.
subroutine, public get_qs_kind(qs_kind, basis_set, basis_type, ncgf, nsgf, all_potential, tnadd_potential, gth_potential, sgp_potential, upf_potential, cneo_potential, se_parameter, dftb_parameter, xtb_parameter, dftb3_param, zatom, zeff, elec_conf, mao, lmax_dftb, alpha_core_charge, ccore_charge, core_charge, core_charge_radius, paw_proj_set, paw_atom, hard_radius, hard0_radius, max_rad_local, covalent_radius, vdw_radius, gpw_type_forced, harmonics, max_iso_not0, max_s_harm, grid_atom, ngrid_ang, ngrid_rad, lmax_rho0, dft_plus_u_atom, l_of_dft_plus_u, n_of_dft_plus_u, u_minus_j, u_of_dft_plus_u, j_of_dft_plus_u, alpha_of_dft_plus_u, beta_of_dft_plus_u, j0_of_dft_plus_u, occupation_of_dft_plus_u, dispersion, bs_occupation, magnetization, no_optimize, addel, laddel, naddel, orbitals, max_scf, eps_scf, smear, u_ramping, u_minus_j_target, eps_u_ramping, init_u_ramping_each_scf, reltmat, ghost, floating, name, element_symbol, pao_basis_size, pao_model_file, pao_potentials, pao_descriptors, nelec)
Get attributes of a quickstep kind.
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
subroutine, public torch_tensor_backward(tensor, outer_grad)
Runs autograd on a Torch tensor.
subroutine, public torch_dict_get(dict, key, tensor)
Retrieves a Torch tensor from a Torch dictionary.
subroutine, public torch_model_load(model, filename)
Loads a Torch model from the given "*.pth" file (in Torch terminology, models are called modules).
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
subroutine, public torch_model_forward(model, inputs, outputs)
Evaluates the given Torch model.
Provides all information about an atomic kind.
Type defining parameters related to the simulation cell.
Stores all the information relevant to an MPI environment.
PAO-ML model for a single atomic kind.
Provides all information about a quickstep kind.