20#include "./base/base_uses.f90"
26 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'skala_torch_api'
34 INTEGER :: protocol_version = -1
35 CHARACTER(len=default_string_length),
ALLOCATABLE, &
36 DIMENSION(:) :: features
49 CHARACTER(len=*),
INTENT(IN) :: filename
51 CHARACTER(:),
ALLOCATABLE :: features_json, protocol_string
57 READ (protocol_string, *, iostat=ios) model%protocol_version
58 IF (ios /= 0) cpabort(
"Could not parse SKALA TorchScript protocol_version metadata")
59 IF (model%protocol_version /= 2)
THEN
60 cpabort(
"Unsupported SKALA TorchScript protocol version")
63 CALL parse_feature_list(features_json, model%features)
75 IF (
ALLOCATED(model%features))
DEALLOCATE (model%features)
76 model%protocol_version = -1
88 CHARACTER(len=*),
INTENT(IN) :: feature
89 LOGICAL :: needs_feature
91 CHARACTER(len=default_string_length) :: feature_key, model_feature
94 feature_key = adjustl(feature)
97 needs_feature = .false.
98 IF (.NOT.
ALLOCATED(model%features))
RETURN
100 DO i = 1,
SIZE(model%features)
101 model_feature = adjustl(model%features(i))
103 IF (trim(model_feature) == trim(feature_key))
THEN
104 needs_feature = .true.
118 INTEGER :: protocol_version
120 protocol_version = model%protocol_version
152 REAL(kind=
dp),
INTENT(OUT) :: exc
!> \brief Extract the double-quoted entries of a feature-list string into an array.
!>        The string is scanned twice: a first pass counts the quoted entries so that
!>        the result can be allocated, a second pass copies the text between each
!>        pair of quotes into the result.
!> \param features_json text holding the quoted feature names
!>        (presumably a JSON array of strings read from the model metadata; TODO confirm)
!> \param features on exit, one element per quoted entry; each element has the fixed
!>        length default_string_length, so assignment blank-pads or truncates the name
!> \note Quotes are located with plain index() searches for '"'; escaped quotes inside
!>       a name are not treated specially.
168 SUBROUTINE parse_feature_list(features_json, features)
169 CHARACTER(len=*),
INTENT(IN) :: features_json
170 CHARACTER(len=default_string_length), &
171 ALLOCATABLE,
DIMENSION(:),
INTENT(OUT) :: features
173 INTEGER :: end_pos, feature_count, i, pos, quote1, &
! --- Pass 1: count the quoted entries ---
! NOTE(review): the loop header and the initialisation of pos and feature_count are
! not visible in this excerpt; presumably pos starts at 1 and feature_count at 0.
! quote1 is the offset of the next opening quote relative to pos (0 if none is left).
179 quote1 = index(features_json(pos:),
'"')
180 IF (quote1 == 0)
EXIT
! First character after the opening quote.
181 start_pos = pos + quote1
! quote2 is the offset of the closing quote relative to start_pos.
182 quote2 = index(features_json(start_pos:),
'"')
! An opening quote without a closing partner ends the scan without being counted.
183 IF (quote2 == 0)
EXIT
184 feature_count = feature_count + 1
! Continue scanning just behind the closing quote.
185 pos = start_pos + quote2
! A model without any listed feature is treated as a fatal error.
188 IF (feature_count == 0) cpabort(
"SKALA TorchScript model does not list any features")
189 ALLOCATE (features(feature_count))
! --- Pass 2: copy each quoted entry into the result ---
! NOTE(review): presumably pos is reset to the start of the string before this loop;
! the corresponding line is not visible in this excerpt.
193 DO i = 1, feature_count
194 quote1 = index(features_json(pos:),
'"')
195 start_pos = pos + quote1
196 quote2 = index(features_json(start_pos:),
'"')
! The closing quote sits at start_pos + quote2 - 1, so the name ends one character earlier.
197 end_pos = start_pos + quote2 - 2
198 features(i) = features_json(start_pos:end_pos)
199 pos = start_pos + quote2
202 END SUBROUTINE parse_feature_list
Defines the basic variable types.
integer, parameter, public dp
integer, parameter, public default_string_length
Small CP2K wrapper around the SKALA TorchScript functional protocol.
subroutine, public skala_torch_model_release(model)
Release a loaded SKALA TorchScript model.
logical function, public skala_torch_model_needs_feature(model, feature)
Check whether a loaded SKALA model requests a feature.
subroutine, public skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
Evaluate the weighted SKALA exchange-correlation energy.
integer function, public skala_torch_model_protocol_version(model)
Return the loaded SKALA TorchScript protocol version.
subroutine, public skala_torch_model_get_exc_density(model, inputs, exc_density)
Evaluate the SKALA exchange-correlation energy density.
subroutine, public skala_torch_model_load(model, filename)
Load a SKALA TorchScript model and its feature metadata.
Utilities for string manipulations.
elemental subroutine, public uppercase(string)
Convert all lower case characters in a string to upper case.
real(kind=dp) function, public torch_tensor_item_double(tensor)
Returns a scalar double value from a Torch tensor.
subroutine, public torch_model_load(model, filename)
Loads a Torch model from a given "*.pth" file. (In Torch lingo, models are called modules.)
subroutine, public torch_model_forward_mol_tensor(model, method_name, inputs, output)
Evaluates a TorchScript model method expecting keyword argument "mol".
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
subroutine, public torch_tensor_weighted_sum(values, weights, result)
Returns the weighted sum of two Torch tensors.
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads a metadata entry from a given "*.pth" file. (In Torch lingo, they are called extra files.)
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.