(git:ed6f26b)
Loading...
Searching...
No Matches
nequip_unittest.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
9
10 USE cp_files, ONLY: discover_file
11 USE kinds, ONLY: default_path_length,&
12 dp,&
13 int_8,&
14 sp
15 USE mathlib, ONLY: inv_3x3
16 USE physcon, ONLY: angstrom,&
17 evolt
18 USE torch_api, ONLY: &
23#include "./base/base_uses.f90"
24
25 IMPLICIT NONE
26
27 CHARACTER(LEN=default_path_length) :: filename, cutoff_str, nequip_version
28 REAL(dp) :: cutoff
29
30 ! Inputs.
31 INTEGER, PARAMETER :: natoms = 96
32 INTEGER :: iatom, nedges
33 REAL(sp), DIMENSION(:, :), ALLOCATABLE :: pos, cell
34 REAL(dp), DIMENSION(3, 3) :: hinv
35 INTEGER(kind=int_8), DIMENSION(:), ALLOCATABLE :: atom_types
36 INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE :: edge_index
37 REAL(sp), DIMENSION(:, :), ALLOCATABLE:: edge_cell_shift
38
39 ! Torch objects.
40 TYPE(torch_model_type) :: model
41 TYPE(torch_dict_type) :: inputs, outputs
42 TYPE(torch_tensor_type) :: pos_tensor, edge_index_tensor, edge_cell_shift_tensor, cell_tensor, &
43 atom_types_tensor, total_energy_tensor, atomic_energy_tensor, forces_tensor
44
45 ! Outputs.
46 REAL(sp), DIMENSION(:, :), POINTER :: total_energy, atomic_energy, forces
47 NULLIFY (total_energy, atomic_energy, forces)
48
49 ! A box with 32 water molecules.
50 ALLOCATE (pos(3, natoms))
51 pos(:, :) = reshape(real([ &
52 42.8861696_dp, -0.0556816_dp, 38.3291611_dp, &
53 34.2025887_dp, -0.6185484_dp, 37.3655680_dp, &
54 30.0803925_dp, -2.0124176_dp, 36.4807960_dp, &
55 28.7057911_dp, -2.6880392_dp, 36.6020983_dp, &
56 36.2479426_dp, -0.5163484_dp, 34.4923596_dp, &
57 37.6964724_dp, -0.0410872_dp, 35.0140735_dp, &
58 27.7606699_dp, 7.4854206_dp, 33.9276919_dp, &
59 28.8160999_dp, 6.4985777_dp, 34.2163608_dp, &
60 37.1576372_dp, 9.0188280_dp, 31.9265812_dp, &
61 38.6063816_dp, 9.5820079_dp, 32.3435972_dp, &
62 34.3031959_dp, 2.2195014_dp, 45.9880451_dp, &
63 33.2444139_dp, 1.3025332_dp, 46.4698427_dp, &
64 38.7286174_dp, -5.0541897_dp, 26.0743968_dp, &
65 38.3483921_dp, -6.2832846_dp, 26.9867253_dp, &
66 32.8642520_dp, 3.2060632_dp, 30.8971160_dp, &
67 31.2904088_dp, 3.0871834_dp, 30.6273977_dp, &
68 33.7519869_dp, -3.1383262_dp, 39.6727607_dp, &
69 34.6642979_dp, -3.6643859_dp, 38.6466027_dp, &
70 42.7173214_dp, 5.1246883_dp, 32.5883401_dp, &
71 41.5627455_dp, 5.5893544_dp, 33.4174902_dp, &
72 32.4283800_dp, 9.1182520_dp, 30.5477678_dp, &
73 32.6432407_dp, 10.770683_dp, 30.4842778_dp, &
74 31.4848670_dp, 4.6777144_dp, 37.3957194_dp, &
75 32.3171882_dp, -6.2287496_dp, 36.4671864_dp, &
76 26.6621340_dp, 3.1708123_dp, 35.6820146_dp, &
77 26.5271367_dp, 1.6039040_dp, 35.4883482_dp, &
78 32.0238236_dp, 16.918208_dp, 31.6883569_dp, &
79 31.4006579_dp, 7.0315610_dp, 30.2394554_dp, &
80 33.5264253_dp, -3.5594808_dp, 34.2636830_dp, &
81 34.6404855_dp, -3.2653833_dp, 35.4971482_dp, &
82 40.0564375_dp, -0.3054386_dp, 29.8312074_dp, &
83 39.4784464_dp, -1.0948314_dp, 38.3101140_dp, &
84 39.7040761_dp, 1.9584631_dp, 33.3902375_dp, &
85 38.3338570_dp, 2.6967178_dp, 42.9261945_dp, &
86 40.1820455_dp, -7.2199289_dp, 27.6580390_dp, &
87 39.3204431_dp, -8.4564252_dp, 28.1319658_dp, &
88 36.3876963_dp, 8.8117085_dp, 38.3545362_dp, &
89 36.3205637_dp, 9.0063075_dp, 36.7526001_dp, &
90 29.9991583_dp, -5.5637817_dp, 33.9295050_dp, &
91 30.7728545_dp, -5.0385870_dp, 35.1998067_dp, &
92 40.0592517_dp, 6.3305279_dp, 28.2579461_dp, &
93 40.2398360_dp, 5.1745923_dp, 29.2962956_dp, &
94 26.3320911_dp, 2.4393638_dp, 33.5653868_dp, &
95 26.9606971_dp, 1.2711078_dp, 32.5923884_dp, &
96 34.8372697_dp, -0.4722708_dp, 30.3824362_dp, &
97 35.3968813_dp, -1.9268483_dp, 30.3081837_dp, &
98 32.1217607_dp, -0.7333429_dp, 36.5104382_dp, &
99 32.2180843_dp, 7.8454304_dp, 35.6671967_dp, &
100 36.3780998_dp, -4.3048878_dp, 36.4539793_dp, &
101 35.8119275_dp, -3.0013928_dp, 27.0348937_dp, &
102 29.6452491_dp, 1.0652123_dp, 35.7143653_dp, &
103 30.3794654_dp, -0.0668146_dp, 34.9882468_dp, &
104 34.2149336_dp, -1.6559120_dp, 33.8876437_dp, &
105 34.7842435_dp, -1.0252141_dp, 32.5034832_dp, &
106 40.4649954_dp, 1.1467825_dp, 31.3073503_dp, &
107 41.3262469_dp, 0.6550803_dp, 32.4555882_dp, &
108 29.0210859_dp, 3.5038194_dp, 39.9087702_dp, &
109 29.4945426_dp, 3.7276637_dp, 41.3766138_dp, &
110 34.1359664_dp, -6.7533422_dp, 32.3568410_dp, &
111 34.9546570_dp, -5.7704242_dp, 31.4571066_dp, &
112 33.2532356_dp, 1.5268048_dp, 44.0562171_dp, &
113 33.7931669_dp, 0.5014632_dp, 43.0597590_dp, &
114 36.8205409_dp, 2.6214681_dp, 40.6834006_dp, &
115 37.5552706_dp, 1.5649832_dp, 39.7648935_dp, &
116 43.2099087_dp, -0.0628456_dp, 47.2593155_dp, &
117 29.3940583_dp, -2.3133019_dp, 37.1407883_dp, &
118 36.7415708_dp, -0.0838710_dp, 35.2591783_dp, &
119 27.9424776_dp, 6.7622961_dp, 34.5648384_dp, &
120 37.6812656_dp, 9.4216399_dp, 32.6478643_dp, &
121 33.3171290_dp, 2.0951401_dp, 45.8722265_dp, &
122 37.9951355_dp, 4.3611431_dp, 26.5571819_dp, &
123 32.1824670_dp, 2.6611503_dp, 30.4577248_dp, &
124 34.6538012_dp, -3.4374573_dp, 39.5889245_dp, &
125 42.2929833_dp, 5.9471069_dp, 32.8460995_dp, &
126 32.9604690_dp, 9.9050313_dp, 30.1587306_dp, &
127 31.4281886_dp, -5.8338304_dp, 36.6738743_dp, &
128 26.0563730_dp, 2.4973869_dp, 35.3486870_dp, &
129 32.0334927_dp, 17.3252289_dp, 30.8116013_dp, &
130 33.8252182_dp, -2.9520949_dp, 35.0220460_dp, &
131 39.4569981_dp, -0.3072759_dp, 38.9347829_dp, &
132 29.4846708_dp, 2.8692561_dp, 43.0061868_dp, &
133 39.2864184_dp, -7.6206103_dp, 27.6271147_dp, &
134 35.8797502_dp, 8.6515870_dp, 37.5221734_dp, &
135 30.3582543_dp, -4.7607656_dp, 34.3355645_dp, &
136 40.7098956_dp, 5.8331250_dp, 28.7558375_dp, &
137 26.7179083_dp, 2.2415138_dp, 32.6577297_dp, &
138 35.6589256_dp, -0.9968903_dp, 30.5749530_dp, &
139 31.5851602_dp, -1.3121804_dp, 35.9011109_dp, &
140 35.5489386_dp, -3.9056138_dp, 26.8214490_dp, &
141 29.5656616_dp, 0.4681794_dp, 34.9670711_dp, &
142 34.7615128_dp, -0.9569680_dp, 33.4891367_dp, &
143 40.4853406_dp, 0.4023620_dp, 31.9425416_dp, &
144 29.6728289_dp, 4.0134825_dp, 40.4505780_dp, &
145 34.1272286_dp, -5.8796882_dp, 31.8925999_dp, &
146 33.1168884_dp, 1.2338084_dp, 43.1127997_dp, &
147 37.1996993_dp, 2.5049007_dp, 39.7917126_dp], kind=sp), shape=[3, natoms])
148
149 ALLOCATE (cell(3, 3))
150 cell(1, :) = [9.85_sp, 0.0_sp, 0.0_sp]
151 cell(2, :) = [0.0_sp, 9.85_sp, 0.0_sp]
152 cell(3, :) = [0.0_sp, 0.0_sp, 9.85_sp]
153
154 hinv(:, :) = inv_3x3(real(cell, kind=dp))
155
156 ALLOCATE (atom_types(natoms))
157 atom_types(:64) = 0 ! Hydrogen
158 atom_types(65:) = 1 ! Oxygen
159
160 WRITE (*, *) "CUDA is available: ", torch_cuda_is_available()
161
162 filename = discover_file('NequIP/water-deployed-neq060sp.pth')
163 WRITE (*, *) "Loading NequIP model from: "//trim(filename)
164 CALL torch_model_load(model, filename)
165 cutoff_str = torch_model_read_metadata(filename, "r_max")
166 nequip_version = torch_model_read_metadata(filename, "nequip_version")
167 READ (cutoff_str, *) cutoff
168 WRITE (*, *) "Version: ", trim(nequip_version)
169 WRITE (*, *) "Cutoff: ", cutoff
170
171 CALL neighbor_search(nedges)
172 ALLOCATE (edge_index(nedges, 2))
173 ALLOCATE (edge_cell_shift(3, nedges))
174 CALL neighbor_search(nedges, edge_index, edge_cell_shift)
175 WRITE (*, *) "Found", nedges, "neighbor edges between", natoms, "atoms."
176
177 CALL torch_dict_create(inputs)
178 CALL torch_dict_create(outputs)
179
180 CALL torch_tensor_from_array(pos_tensor, pos)
181 CALL torch_dict_insert(inputs, "pos", pos_tensor)
182 CALL torch_tensor_release(pos_tensor)
183
184 CALL torch_tensor_from_array(edge_index_tensor, edge_index)
185 CALL torch_dict_insert(inputs, "edge_index", edge_index_tensor)
186 CALL torch_tensor_release(edge_index_tensor)
187
188 CALL torch_tensor_from_array(edge_cell_shift_tensor, edge_cell_shift)
189 CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shift_tensor)
190 CALL torch_tensor_release(edge_cell_shift_tensor)
191
192 CALL torch_tensor_from_array(cell_tensor, cell)
193 CALL torch_dict_insert(inputs, "cell", cell_tensor)
194 CALL torch_tensor_release(cell_tensor)
195
196 CALL torch_tensor_from_array(atom_types_tensor, atom_types)
197 CALL torch_dict_insert(inputs, "atom_types", atom_types_tensor)
198 CALL torch_tensor_release(atom_types_tensor)
199
200 CALL torch_model_forward(model, inputs, outputs)
201
202 CALL torch_dict_get(outputs, "total_energy", total_energy_tensor)
203 CALL torch_tensor_data_ptr(total_energy_tensor, total_energy)
204
205 CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_tensor)
206 CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy)
207
208 CALL torch_dict_get(outputs, "forces", forces_tensor)
209 CALL torch_tensor_data_ptr(forces_tensor, forces)
210
211 WRITE (*, *) "Total Energy [Hartree] : ", total_energy(1, 1)/evolt
212 WRITE (*, *) "FORCES: [Hartree/Bohr]: "
213 DO iatom = 1, natoms
214 WRITE (*, *) forces(:, iatom)*angstrom/evolt
215 END DO
216 cpassert(abs(-14985.4443_dp - real(total_energy(1, 1), kind=dp)) < 2e-3_dp)
217
218 CALL torch_tensor_release(total_energy_tensor)
219 CALL torch_tensor_release(atomic_energy_tensor)
220 CALL torch_tensor_release(forces_tensor)
221 CALL torch_dict_release(inputs)
222 CALL torch_dict_release(outputs)
223 CALL torch_model_release(model)
224 DEALLOCATE (edge_index, edge_cell_shift, pos, cell, atom_types)
225
226 WRITE (*, *) "NequIP unittest was successfully :-)"
227
228CONTAINS
229
! **************************************************************************************************
!> \brief Naive neighbor search - beware it scales O(N**2).
!>
!>        Every ordered pair (iatom, jatom) with iatom /= jatom whose separation, after
!>        rounding the fractional difference to the nearest lattice vector, is within the
!>        model cutoff yields one edge. The routine is meant to be called twice: first
!>        without the optional arrays to count the edges, then with arrays of that size.
!>
!>        Only the single periodic image selected by that rounding is examined, so the
!>        search is complete only when the cutoff is smaller than half of the smallest
!>        perpendicular cell width.
!> \param nedges number of edges found
!> \param edge_index zero-based atom index pair (iatom-1, jatom-1), one row per edge
!> \param edge_cell_shift integer-valued lattice shift subtracted for each edge
! **************************************************************************************************
   SUBROUTINE neighbor_search(nedges, edge_index, edge_cell_shift)
      INTEGER, INTENT(OUT)                               :: nedges
      INTEGER(kind=int_8), DIMENSION(:, :), &
         INTENT(OUT), OPTIONAL                           :: edge_index
      REAL(sp), DIMENSION(:, :), INTENT(OUT), OPTIONAL   :: edge_cell_shift

      INTEGER                                            :: iatom, jatom
      REAL(dp)                                           :: cutoff_sq
      REAL(dp), DIMENSION(3)                             :: cell_shift, dx, s1, s2, s12

      ! Compare squared distances; the squared cutoff is invariant over all pairs.
      cutoff_sq = cutoff**2
      nedges = 0
      DO iatom = 1, natoms
         ! Fractional coordinates of iatom do not depend on jatom: compute them once per iatom.
         s1 = matmul(hinv, pos(:, iatom))
         DO jatom = 1, natoms
            IF (iatom == jatom) CYCLE
            s2 = matmul(hinv, pos(:, jatom))
            s12 = s1 - s2
            ! Rounding the fractional separation selects the nearest periodic image.
            cell_shift = anint(s12)
            dx = matmul(cell, s12 - cell_shift)
            IF (dot_product(dx, dx) <= cutoff_sq) THEN
               nedges = nedges + 1
               IF (PRESENT(edge_index)) THEN
                  edge_index(nedges, :) = int([iatom - 1, jatom - 1], kind=int_8)
               END IF
               IF (PRESENT(edge_cell_shift)) THEN
                  edge_cell_shift(:, nedges) = real(cell_shift, kind=sp)
               END IF
            END IF
         END DO
      END DO
   END SUBROUTINE neighbor_search
266
267END PROGRAM nequip_unittest
Define the atom type and its sub types.
Definition atom_types.F:15
Utility routines to open and close files. Tracking of preconnections.
Definition cp_files.F:16
character(len=default_path_length) function, public discover_file(file_name)
Checks various locations for a file name.
Definition cp_files.F:518
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public int_8
Definition kinds.F:54
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_path_length
Definition kinds.F:58
integer, parameter, public sp
Definition kinds.F:33
Collection of simple mathematical functions and subroutines.
Definition mathlib.F:15
pure real(kind=dp) function, dimension(3, 3), public inv_3x3(a)
Returns the inverse of the 3 x 3 matrix a.
Definition mathlib.F:516
Definition of physical constants:
Definition physcon.F:68
real(kind=dp), parameter, public evolt
Definition physcon.F:183
real(kind=dp), parameter, public angstrom
Definition physcon.F:144
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition torch_api.F:1113
subroutine, public torch_dict_get(dict, key, tensor)
Retrieves a Torch tensor from a Torch dictionary.
Definition torch_api.F:1079
subroutine, public torch_model_load(model, filename)
Loads a Torch model from the given "*.pth" file (in Torch lingo, models are called modules).
Definition torch_api.F:1137
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition torch_api.F:1023
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
Definition torch_api.F:1205
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads a metadata entry from the given "*.pth" file (in Torch lingo, these are called extra files).
Definition torch_api.F:1229
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
Definition torch_api.F:1047
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
Definition torch_api.F:1286
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:999
subroutine, public torch_model_forward(model, inputs, outputs)
Evaluates the given Torch model.
Definition torch_api.F:1169
program nequip_unittest
subroutine neighbor_search(nedges, edge_index, edge_cell_shift)
Naive neighbor search - beware it scales O(N**2).