(git:06f838d)
Loading...
Searching...
No Matches
skala_torch_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
8! **************************************************************************************************
9!> \brief Small CP2K wrapper around the SKALA TorchScript functional protocol.
10! **************************************************************************************************
12 USE kinds, ONLY: default_string_length,&
13 dp
15 USE torch_api, ONLY: &
20#include "./base/base_uses.f90"
21
22 IMPLICIT NONE
23
24 PRIVATE
25
26 CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_torch_api'
27
31
33 PRIVATE
34 INTEGER :: protocol_version = -1
35 CHARACTER(len=default_string_length), ALLOCATABLE, &
36 DIMENSION(:) :: features
37 TYPE(torch_model_type) :: torch_model
39
40CONTAINS
41
! **************************************************************************************************
!> \brief Load a SKALA TorchScript model and its feature metadata.
!> \param model SKALA model wrapper that receives the loaded Torch model, the parsed protocol
!>              version and the list of requested input features
!> \param filename path of the TorchScript model file; its "protocol_version" and "features"
!>                 metadata entries are read as well
!> \note Aborts if the protocol_version metadata cannot be parsed as an integer, if it differs
!>       from the single supported protocol version, or if the features metadata lists no
!>       features. The abort messages include the offending value to ease debugging.
! **************************************************************************************************
   SUBROUTINE skala_torch_model_load(model, filename)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

      ! The only SKALA TorchScript protocol version this wrapper implements.
      INTEGER, PARAMETER                                 :: supported_protocol_version = 2

      CHARACTER(:), ALLOCATABLE                          :: features_json, message, &
                                                            protocol_string
      CHARACTER(len=16)                                  :: expected_str, found_str
      INTEGER                                            :: ios

      CALL torch_model_load(model%torch_model, filename)
      protocol_string = torch_model_read_metadata(filename, "protocol_version")
      features_json = torch_model_read_metadata(filename, "features")

      READ (protocol_string, *, iostat=ios) model%protocol_version
      IF (ios /= 0) THEN
         ! Report the raw metadata text so a malformed model file can be diagnosed.
         message = "Could not parse SKALA TorchScript protocol_version metadata: '"// &
                   TRIM(protocol_string)//"'"
         cpabort(message)
      END IF

      IF (model%protocol_version /= supported_protocol_version) THEN
         ! Tell the user both what the file declares and what this build understands.
         WRITE (found_str, '(I0)') model%protocol_version
         WRITE (expected_str, '(I0)') supported_protocol_version
         message = "Unsupported SKALA TorchScript protocol version "//TRIM(found_str)// &
                   " (expected "//TRIM(expected_str)//")"
         cpabort(message)
      END IF

      ! Fills model%features; parse_feature_list aborts if no feature name is found.
      CALL parse_feature_list(features_json, model%features)

   END SUBROUTINE skala_torch_model_load

! **************************************************************************************************
!> \brief Release a loaded SKALA TorchScript model and reset its cached metadata.
!> \param model SKALA model wrapper; on return its Torch model has been released, its feature
!>              list (if any) deallocated and its protocol version set back to -1
! **************************************************************************************************
   SUBROUTINE skala_torch_model_release(model)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model

      ! Free the underlying TorchScript module first.
      CALL torch_model_release(model%torch_model)

      ! Drop the cached feature names and mark the protocol version as unset.
      IF (ALLOCATED(model%features)) THEN
         DEALLOCATE (model%features)
      END IF
      model%protocol_version = -1

   END SUBROUTINE skala_torch_model_release

! **************************************************************************************************
!> \brief Check whether a loaded SKALA model requests a feature.
!> \param model SKALA model wrapper whose feature list (parsed from the "features" metadata)
!>              is searched
!> \param feature feature name to look up; compared case-insensitively, ignoring leading and
!>                trailing blanks
!> \return .TRUE. if the model lists the feature, .FALSE. otherwise (including when no feature
!>         list is currently allocated)
!> \note Both names are held in default_string_length buffers, so characters beyond that
!>       length do not take part in the comparison.
! **************************************************************************************************
   FUNCTION skala_torch_model_needs_feature(model, feature) RESULT(needs_feature)
      TYPE(skala_torch_model_type), INTENT(IN)           :: model
      CHARACTER(len=*), INTENT(IN)                       :: feature
      LOGICAL                                            :: needs_feature

      CHARACTER(len=default_string_length)               :: wanted
      CHARACTER(len=default_string_length), ALLOCATABLE, &
         DIMENSION(:)                                    :: listed

      needs_feature = .FALSE.
      IF (.NOT. ALLOCATED(model%features)) RETURN

      ! Normalise the requested name: left-justify and fold to upper case. Trailing blanks
      ! need no trimming because Fortran character equality blank-pads the shorter operand.
      wanted = ADJUSTL(feature)
      CALL uppercase(wanted)

      ! Normalise a copy of the model's list the same way (ADJUSTL and uppercase are elemental).
      listed = ADJUSTL(model%features)
      CALL uppercase(listed)

      needs_feature = ANY(listed == wanted)

   END FUNCTION skala_torch_model_needs_feature

! **************************************************************************************************
!> \brief Return the loaded SKALA TorchScript protocol version.
!> \param model SKALA model wrapper
!> \return protocol version parsed from the model's "protocol_version" metadata by
!>         skala_torch_model_load; -1 if no model is loaded (the component's default value,
!>         and the value set by skala_torch_model_release)
! **************************************************************************************************
   FUNCTION skala_torch_model_protocol_version(model) RESULT(version)
      TYPE(skala_torch_model_type), INTENT(IN)           :: model
      INTEGER                                            :: version

      version = model%protocol_version

   END FUNCTION skala_torch_model_protocol_version

! **************************************************************************************************
!> \brief Evaluate the SKALA exchange-correlation energy density.
!> \param model SKALA model wrapper holding the loaded TorchScript module
!> \param inputs input dictionary passed to the module's "get_exc_density" method as its "mol"
!>               keyword argument (presumably keyed by the features the model lists - confirm)
!> \param exc_density receives the tensor returned by "get_exc_density"
! **************************************************************************************************
   SUBROUTINE skala_torch_model_get_exc_density(model, inputs, exc_density)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_density

      CALL torch_model_forward_mol_tensor(model=model%torch_model, &
                                          method_name="get_exc_density", &
                                          inputs=inputs, &
                                          output=exc_density)

   END SUBROUTINE skala_torch_model_get_exc_density

! **************************************************************************************************
!> \brief Evaluate the weighted SKALA exchange-correlation energy.
!> \param model SKALA model wrapper used to evaluate the energy density
!> \param inputs input dictionary forwarded to skala_torch_model_get_exc_density
!> \param grid_weights weights combined with the energy density by torch_tensor_weighted_sum
!> \param exc_tensor receives the Torch tensor holding the weighted sum (not released here)
!> \param exc the value of exc_tensor as a double-precision number
! **************************************************************************************************
   SUBROUTINE skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(IN)                :: grid_weights
      TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_tensor
      REAL(kind=dp), INTENT(OUT)                         :: exc

      TYPE(torch_tensor_type)                            :: density

      ! Energy density from the model (presumably one value per grid point - confirm) ...
      CALL skala_torch_model_get_exc_density(model, inputs, density)
      ! ... reduced against the grid weights into the caller-owned exc_tensor ...
      CALL torch_tensor_weighted_sum(values=density, weights=grid_weights, result=exc_tensor)
      ! ... and copied out as a plain double.
      exc = torch_tensor_item_double(tensor=exc_tensor)

      ! The intermediate density tensor belongs to this routine, so free it here.
      CALL torch_tensor_release(tensor=density)

   END SUBROUTINE skala_torch_model_get_exc

! **************************************************************************************************
!> \brief Parse a TorchScript extra_files JSON list of feature names.
!> \param features_json metadata text, expected to be a JSON list of strings
!> \param features on return, one entry per double-quoted name, in order of appearance
!> \note This is a minimal quote scanner rather than a JSON parser. Each pair of double quotes
!>       delimits one name, escape sequences are not interpreted, an unmatched trailing quote is
!>       ignored, and names longer than default_string_length are truncated on assignment.
!>       Aborts if no quoted name is found.
! **************************************************************************************************
   SUBROUTINE parse_feature_list(features_json, features)
      CHARACTER(len=*), INTENT(IN)                       :: features_json
      CHARACTER(len=default_string_length), &
         ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: features

      INTEGER                                            :: cursor, first, ifeat, last, nfeat
      LOGICAL                                            :: found

      ! Pass 1: count complete "..." pairs so the result is allocated exactly once.
      nfeat = 0
      cursor = 1
      CALL next_quoted_name(cursor, first, last, found)
      DO WHILE (found)
         nfeat = nfeat + 1
         CALL next_quoted_name(cursor, first, last, found)
      END DO

      IF (nfeat == 0) cpabort("SKALA TorchScript model does not list any features")
      ALLOCATE (features(nfeat))
      features = ""

      ! Pass 2: rescan from the start and copy each name; pass 1 guarantees nfeat pairs exist.
      cursor = 1
      DO ifeat = 1, nfeat
         CALL next_quoted_name(cursor, first, last, found)
         features(ifeat) = features_json(first:last)
      END DO

   CONTAINS

! **************************************************************************************************
!> \brief Locate the next double-quoted substring of features_json at or after cursor.
!> \param cursor start of the search; moved just past the closing quote when a pair is found
!> \param first index of the first character inside the quotes
!> \param last index of the last character inside the quotes (first-1 for an empty name)
!> \param found whether a complete opening/closing quote pair was found
! **************************************************************************************************
      SUBROUTINE next_quoted_name(cursor, first, last, found)
         INTEGER, INTENT(INOUT)                          :: cursor
         INTEGER, INTENT(OUT)                            :: first, last
         LOGICAL, INTENT(OUT)                            :: found

         INTEGER                                         :: close_off, open_off

         found = .FALSE.
         first = 0
         last = -1

         open_off = INDEX(features_json(cursor:), '"')
         IF (open_off == 0) RETURN
         first = cursor + open_off

         close_off = INDEX(features_json(first:), '"')
         IF (close_off == 0) RETURN
         last = first + close_off - 2

         cursor = first + close_off
         found = .TRUE.
      END SUBROUTINE next_quoted_name

   END SUBROUTINE parse_feature_list

204END MODULE skala_torch_api
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_string_length
Definition kinds.F:57
Small CP2K wrapper around the SKALA TorchScript functional protocol.
subroutine, public skala_torch_model_release(model)
Release a loaded SKALA TorchScript model.
logical function, public skala_torch_model_needs_feature(model, feature)
Check whether a loaded SKALA model requests a feature.
subroutine, public skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
Evaluate the weighted SKALA exchange-correlation energy.
integer function, public skala_torch_model_protocol_version(model)
Return the loaded SKALA TorchScript protocol version.
subroutine, public skala_torch_model_get_exc_density(model, inputs, exc_density)
Evaluate the SKALA exchange-correlation energy density.
subroutine, public skala_torch_model_load(model, filename)
Load a SKALA TorchScript model and its feature metadata.
Utilities for string manipulations.
elemental subroutine, public uppercase(string)
Convert all lower case characters in a string to upper case.
real(kind=dp) function, public torch_tensor_item_double(tensor)
Returns a scalar double value from a Torch tensor.
Definition torch_api.F:1553
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
Definition torch_api.F:1745
subroutine, public torch_model_forward_mol_tensor(model, method_name, inputs, output)
Evaluates a TorchScript model method expecting keyword argument "mol".
Definition torch_api.F:1812
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
Definition torch_api.F:1856
subroutine, public torch_tensor_weighted_sum(values, weights, result)
Returns the weighted sum of two Torch tensors.
Definition torch_api.F:1522
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
Definition torch_api.F:1880
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:1580