(git:ccc2433)
torch_api.F
Go to the documentation of this file.
1 !--------------------------------------------------------------------------------------------------!
2 ! CP2K: A general program to perform molecular dynamics simulations !
3 ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 ! !
5 ! SPDX-License-Identifier: GPL-2.0-or-later !
6 !--------------------------------------------------------------------------------------------------!
7 MODULE torch_api
8  USE iso_c_binding, ONLY: c_associated, &
9  c_bool, &
10  c_char, &
11  c_float, &
12  c_double, &
13  c_f_pointer, &
14  c_int, &
15  c_null_char, &
16  c_null_ptr, &
17  c_ptr, &
18  c_int64_t
19 
20  USE kinds, ONLY: sp, int_8, dp
21 
22 #include "./base/base_uses.f90"
23 
24  IMPLICIT NONE
25 
26  PRIVATE
27 
! **************************************************************************************************
!> \brief Opaque handle for a Torch dictionary. The only component is a private C pointer
!>        that is null until something assigns it.
! **************************************************************************************************
   TYPE :: torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type
32 
! **************************************************************************************************
!> \brief Opaque handle for a Torch model (presumably managed by the torch_model_* routines
!>        exported by this module, to be confirmed). The only component is a private C pointer
!>        that is null until something assigns it.
! **************************************************************************************************
   TYPE :: torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type
37 
! Generic insert: resolved by element type (REAL(sp), INTEGER(int_8), REAL(dp)) and rank (1 to 3)
! of the source array.
   INTERFACE torch_dict_insert
      MODULE PROCEDURE torch_dict_insert_float_1d, torch_dict_insert_float_2d, &
         torch_dict_insert_float_3d
      MODULE PROCEDURE torch_dict_insert_int64_1d, torch_dict_insert_int64_2d, &
         torch_dict_insert_int64_3d
      MODULE PROCEDURE torch_dict_insert_double_1d, torch_dict_insert_double_2d, &
         torch_dict_insert_double_3d
   END INTERFACE torch_dict_insert
49 
! Generic get: resolved by element type (REAL(sp), INTEGER(int_8), REAL(dp)) and rank (1 to 3)
! of the destination pointer.
   INTERFACE torch_dict_get
      MODULE PROCEDURE torch_dict_get_float_1d, torch_dict_get_float_2d, &
         torch_dict_get_float_3d
      MODULE PROCEDURE torch_dict_get_int64_1d, torch_dict_get_int64_2d, &
         torch_dict_get_int64_3d
      MODULE PROCEDURE torch_dict_get_double_1d, torch_dict_get_double_2d, &
         torch_dict_get_double_3d
   END INTERFACE torch_dict_get
61 
62  PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_release
63  PUBLIC :: torch_dict_insert, torch_dict_get
64  PUBLIC :: torch_model_type, torch_model_load, torch_model_eval, torch_model_release
67 
68 CONTAINS
69 
70 
71 
! **************************************************************************************************
!> \brief Inserts a rank-1 REAL(sp) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_float_1d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(sp), CONTIGUOUS, DIMENSION(:), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_float(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_float")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            REAL(kind=C_FLOAT), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_float
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! Extents in the row-major order used by C, which is trivial for rank 1.
      extents_c = [SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_float(dict%c_ptr, TRIM(key)//c_null_char, 1, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_float_1d
111 
! **************************************************************************************************
!> \brief Retrieves a rank-1 REAL(sp) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_float_1d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(sp), DIMENSION(:), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1) :: extents_c
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_float(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_float")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_float
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_float(dict%c_ptr, TRIM(key)//c_null_char, 1, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! Rank 1: row-major and column-major extents coincide.
      CALL c_f_pointer(data_c, dest, shape=extents_c)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_float_1d
159 
160 
! **************************************************************************************************
!> \brief Inserts a rank-1 INTEGER(int_8) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_int64_1d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      INTEGER(kind=int_8), CONTIGUOUS, DIMENSION(:), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_int64(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_int64
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! Extents in the row-major order used by C, which is trivial for rank 1.
      extents_c = [SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_int64(dict%c_ptr, TRIM(key)//c_null_char, 1, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_int64_1d
200 
! **************************************************************************************************
!> \brief Retrieves a rank-1 INTEGER(int_8) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_int64_1d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      INTEGER(kind=int_8), DIMENSION(:), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1) :: extents_c
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_int64(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_int64
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_int64(dict%c_ptr, TRIM(key)//c_null_char, 1, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! Rank 1: row-major and column-major extents coincide.
      CALL c_f_pointer(data_c, dest, shape=extents_c)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_int64_1d
248 
249 
! **************************************************************************************************
!> \brief Inserts a rank-1 REAL(dp) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_double_1d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(dp), CONTIGUOUS, DIMENSION(:), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_double(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_double")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_double
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            REAL(kind=C_DOUBLE), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_double
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! Extents in the row-major order used by C, which is trivial for rank 1.
      extents_c = [SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_double(dict%c_ptr, TRIM(key)//c_null_char, 1, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_double_1d
289 
! **************************************************************************************************
!> \brief Retrieves a rank-1 REAL(dp) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_double_1d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(dp), DIMENSION(:), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1) :: extents_c
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_double(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_double")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_double
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_double(dict%c_ptr, TRIM(key)//c_null_char, 1, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! Rank 1: row-major and column-major extents coincide.
      CALL c_f_pointer(data_c, dest, shape=extents_c)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_double_1d
337 
338 
! **************************************************************************************************
!> \brief Inserts a rank-2 REAL(sp) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert, its extents are reported to C in reversed order
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_float_2d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(sp), CONTIGUOUS, DIMENSION(:, :), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_float(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_float")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            REAL(kind=C_FLOAT), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_float
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! C arrays are stored row-major, hence the extents are listed last dimension first.
      extents_c = [SIZE(source, 2), SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_float(dict%c_ptr, TRIM(key)//c_null_char, 2, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_float_2d
379 
! **************************************************************************************************
!> \brief Retrieves a rank-2 REAL(sp) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_float_2d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(sp), DIMENSION(:, :), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2) :: extents_c, extents_f
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_float(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_float")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_float
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_float(dict%c_ptr, TRIM(key)//c_null_char, 2, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! C arrays are stored row-major, so the Fortran shape is the reversed list of C extents.
      extents_f(:) = extents_c(2:1:-1)
      CALL c_f_pointer(data_c, dest, shape=extents_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_float_2d
428 
429 
! **************************************************************************************************
!> \brief Inserts a rank-2 INTEGER(int_8) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert, its extents are reported to C in reversed order
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_int64_2d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      INTEGER(kind=int_8), CONTIGUOUS, DIMENSION(:, :), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_int64(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_int64
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! C arrays are stored row-major, hence the extents are listed last dimension first.
      extents_c = [SIZE(source, 2), SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_int64(dict%c_ptr, TRIM(key)//c_null_char, 2, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_int64_2d
470 
! **************************************************************************************************
!> \brief Retrieves a rank-2 INTEGER(int_8) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_int64_2d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      INTEGER(kind=int_8), DIMENSION(:, :), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2) :: extents_c, extents_f
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_int64(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_int64
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_int64(dict%c_ptr, TRIM(key)//c_null_char, 2, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! C arrays are stored row-major, so the Fortran shape is the reversed list of C extents.
      extents_f(:) = extents_c(2:1:-1)
      CALL c_f_pointer(data_c, dest, shape=extents_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_int64_2d
519 
520 
! **************************************************************************************************
!> \brief Inserts a rank-2 REAL(dp) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert, its extents are reported to C in reversed order
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_double_2d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(dp), CONTIGUOUS, DIMENSION(:, :), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_double(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_double")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_double
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            REAL(kind=C_DOUBLE), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_double
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! C arrays are stored row-major, hence the extents are listed last dimension first.
      extents_c = [SIZE(source, 2), SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_double(dict%c_ptr, TRIM(key)//c_null_char, 2, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_double_2d
561 
! **************************************************************************************************
!> \brief Retrieves a rank-2 REAL(dp) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_double_2d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(dp), DIMENSION(:, :), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2) :: extents_c, extents_f
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_double(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_double")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_double
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_double(dict%c_ptr, TRIM(key)//c_null_char, 2, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! C arrays are stored row-major, so the Fortran shape is the reversed list of C extents.
      extents_f(:) = extents_c(2:1:-1)
      CALL c_f_pointer(data_c, dest, shape=extents_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_double_2d
610 
611 
! **************************************************************************************************
!> \brief Inserts a rank-3 REAL(sp) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert, its extents are reported to C in reversed order
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_float_3d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(sp), CONTIGUOUS, DIMENSION(:, :, :), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_float(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_float")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            REAL(kind=C_FLOAT), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_float
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! C arrays are stored row-major, hence the extents are listed last dimension first.
      extents_c = [SIZE(source, 3), SIZE(source, 2), SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_float(dict%c_ptr, TRIM(key)//c_null_char, 3, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_float_3d
653 
! **************************************************************************************************
!> \brief Retrieves a rank-3 REAL(sp) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_float_3d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(sp), DIMENSION(:, :, :), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3) :: extents_c, extents_f
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_float(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_float")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_float
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_float(dict%c_ptr, TRIM(key)//c_null_char, 3, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! C arrays are stored row-major, so the Fortran shape is the reversed list of C extents.
      extents_f(:) = extents_c(3:1:-1)
      CALL c_f_pointer(data_c, dest, shape=extents_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_float_3d
703 
704 
! **************************************************************************************************
!> \brief Inserts a rank-3 INTEGER(int_8) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert, its extents are reported to C in reversed order
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_int64_3d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      INTEGER(kind=int_8), CONTIGUOUS, DIMENSION(:, :, :), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_int64(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_int64
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! C arrays are stored row-major, hence the extents are listed last dimension first.
      extents_c = [SIZE(source, 3), SIZE(source, 2), SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_int64(dict%c_ptr, TRIM(key)//c_null_char, 3, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_int64_3d
746 
! **************************************************************************************************
!> \brief Retrieves a rank-3 INTEGER(int_8) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_int64_3d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      INTEGER(kind=int_8), DIMENSION(:, :, :), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3) :: extents_c, extents_f
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_int64(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_int64
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_int64(dict%c_ptr, TRIM(key)//c_null_char, 3, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! C arrays are stored row-major, so the Fortran shape is the reversed list of C extents.
      extents_f(:) = extents_c(3:1:-1)
      CALL c_f_pointer(data_c, dest, shape=extents_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_int64_3d
796 
797 
! **************************************************************************************************
!> \brief Inserts a rank-3 REAL(dp) array into a Torch dictionary.
!>        The passed array has to outlive the dictionary!
!> \param dict dictionary receiving the entry, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param source contiguous array to insert, its extents are reported to C in reversed order
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_double_3d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(dp), CONTIGUOUS, DIMENSION(:, :, :), INTENT(IN) :: source

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3) :: extents_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_double(dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_double")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_double
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            REAL(kind=C_DOUBLE), DIMENSION(*) :: source
         END SUBROUTINE torch_c_dict_insert_double
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))

      ! C arrays are stored row-major, hence the extents are listed last dimension first.
      extents_c = [SIZE(source, 3), SIZE(source, 2), SIZE(source, 1)]

      ! Arguments: handle, null-terminated key, rank, extents, data.
      CALL torch_c_dict_insert_double(dict%c_ptr, TRIM(key)//c_null_char, 3, extents_c, source)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(source)
#endif
   END SUBROUTINE torch_dict_insert_double_3d
839 
! **************************************************************************************************
!> \brief Retrieves a rank-3 REAL(dp) array from a Torch dictionary.
!>        The returned array has to be deallocated by the caller!
!> \param dict dictionary to query, its C handle has to be associated
!> \param key name of the entry, handed over without trailing blanks and null-terminated
!> \param dest has to be disassociated on entry, afterwards it points to the retrieved data
!> \note Without __LIBTORCH the routine only raises the compiled-without-Torch error.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_double_3d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN) :: dict
      CHARACTER(len=*), INTENT(IN) :: key
      REAL(dp), DIMENSION(:, :, :), POINTER :: dest

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3) :: extents_c, extents_f
      TYPE(C_PTR) :: data_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_double(dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_double")
            IMPORT :: c_char, c_ptr, c_int, c_int64_t
            TYPE(C_PTR), VALUE :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*) :: sizes
            TYPE(C_PTR) :: dest
         END SUBROUTINE torch_c_dict_get_double
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. ASSOCIATED(dest))

      ! Sentinel values, the checks after the call catch outputs left untouched by the C side.
      extents_c(:) = -1
      data_c = c_null_ptr

      ! Arguments: handle, null-terminated key, rank, extents (out), data pointer (out).
      CALL torch_c_dict_get_double(dict%c_ptr, TRIM(key)//c_null_char, 3, extents_c, data_c)

      cpassert(ALL(extents_c >= 0))
      cpassert(c_associated(data_c))

      ! C arrays are stored row-major, so the Fortran shape is the reversed list of C extents.
      extents_f(:) = extents_c(3:1:-1)
      CALL c_f_pointer(data_c, dest, shape=extents_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      ! Presumably keeps the compiler from flagging the dummy arguments as unused.
      mark_used(dict)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_dict_get_double_3d
889 
890 
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict dictionary handle; its C pointer must not be associated on entry and is
!>             checked to be associated after the C call
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(handle) BIND(C, name="torch_c_dict_create")
            IMPORT :: c_ptr
            TYPE(c_ptr)                               :: handle
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      ! Refuse to overwrite a handle that is still in use.
      cpassert(.NOT. c_associated(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      cpassert(c_associated(dict%c_ptr))
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(dict)
#endif
   END SUBROUTINE torch_dict_create
914 
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict dictionary handle; its C pointer must be associated on entry and is reset
!>             to the null pointer on exit
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(handle) BIND(C, name="torch_c_dict_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                        :: handle
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Mark the handle as released.
      dict%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(dict)
#endif
   END SUBROUTINE torch_dict_release
938 
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model model handle; its C pointer must not be associated on entry and is checked
!>              to be associated after the C call
!> \param filename path of the file, passed to the C side as a null-terminated string
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_load(handle, path) BIND(C, name="torch_c_model_load")
            IMPORT :: c_ptr, c_char
            TYPE(c_ptr)                               :: handle
            CHARACTER(kind=C_CHAR), DIMENSION(*)      :: path
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      ! Refuse to overwrite a handle that is still in use.
      cpassert(.NOT. c_associated(model%c_ptr))
      CALL torch_c_model_load(model%c_ptr, TRIM(filename)//c_null_char)
      cpassert(c_associated(model%c_ptr))
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
      mark_used(filename)
#endif
   END SUBROUTINE torch_model_load
965 
! **************************************************************************************************
!> \brief Evaluates the given Torch model. (In Torch lingo this operation is called forward())
!> \param model model handle; its C pointer must be associated
!> \param inputs dictionary handle passed to the C side as inputs; must be associated
!> \param outputs dictionary handle passed to the C side as outputs; must be associated
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_eval(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_eval(model_handle, inputs_handle, outputs_handle) &
            BIND(C, name="torch_c_model_eval")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                        :: model_handle
            TYPE(c_ptr), VALUE                        :: inputs_handle
            TYPE(c_ptr), VALUE                        :: outputs_handle
         END SUBROUTINE torch_c_model_eval
      END INTERFACE

      ! All three handles have to be valid before they are handed to the C side.
      cpassert(c_associated(model%c_ptr))
      cpassert(c_associated(inputs%c_ptr))
      cpassert(c_associated(outputs%c_ptr))

      CALL torch_c_model_eval(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
      mark_used(inputs)
      mark_used(outputs)
#endif
   END SUBROUTINE torch_model_eval
998 
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model model handle; its C pointer must be associated on entry and is reset to
!>              the null pointer on exit
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(handle) BIND(C, name="torch_c_model_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                        :: handle
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Mark the handle as released.
      model%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
#endif
   END SUBROUTINE torch_model_release
1022 
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the file, passed to the C side as a null-terminated string
!> \param key name of the metadata entry, passed to the C side as a null-terminated string
!> \return res content of the entry as a Fortran string without the terminating null character
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: content_f
      INTEGER                                            :: i
      ! Must have kind C_INT because it is passed by reference to the C routine below.
      INTEGER(kind=C_INT)                                :: length
      TYPE(c_ptr)                                        :: content_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: c_char, c_ptr, c_int
            CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
            TYPE(c_ptr)                               :: content
            INTEGER(kind=C_INT)                       :: length
         END SUBROUTINE torch_c_model_read_metadata
      END INTERFACE

      ! Pre-set the outputs to invalid values so that a missing answer is detected below.
      content_c = c_null_ptr
      length = -1
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//c_null_char, &
                                       key=TRIM(key)//c_null_char, &
                                       content=content_c, &
                                       length=length)
      cpassert(c_associated(content_c))
      cpassert(length >= 0)

      ! The buffer is viewed with one extra element, which must be the terminating null character.
      CALL c_f_pointer(content_c, content_f, shape=(/length + 1/))
      cpassert(content_f(length + 1) == c_null_char)

      ! Copy character by character and insist that there is no embedded null character.
      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO i = 1, length
         cpassert(content_f(i) /= c_null_char)
         res(i:i) = content_f(i)
      END DO

      ! NOTE(review): this frees memory that was not created by a Fortran ALLOCATE, which
      ! presumably relies on the C side using an allocator compatible with DEALLOCATE - confirm.
      DEALLOCATE (content_f) ! Was allocated on the C side.
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(filename)
      mark_used(key)
      mark_used(res)
#endif
   END FUNCTION torch_model_read_metadata
1074 
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return res answer of the C routine converted to a default logical
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_is_available() RESULT(available) &
            BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: c_bool
            LOGICAL(C_BOOL)                           :: available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Explicit conversion from the C boolean kind to the default logical kind.
      res = LOGICAL(torch_c_cuda_is_available())
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(res)
#endif
   END FUNCTION torch_cuda_is_available
1096 
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 flag that is forwarded to the C side
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: flag_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(flag) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: c_bool
            LOGICAL(C_BOOL), VALUE                    :: flag
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! Convert from the default logical kind to the C boolean kind before the call.
      flag_c = LOGICAL(allow_tf32, c_bool)
      CALL torch_c_allow_tf32(flag_c)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
1120 
! **************************************************************************************************
!> \brief Freezes the given Torch model: applies generic optimizations that speed up the model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model model handle; its C pointer must be associated
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_freeze(handle) BIND(C, name="torch_c_model_freeze")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                        :: handle
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      ! Only a valid handle may be handed to the C side.
      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
#endif
   END SUBROUTINE torch_model_freeze
1144 
1145 END MODULE torch_api
Defines the basic variable types.
Definition: kinds.F:23
integer, parameter, public int_8
Definition: kinds.F:54
integer, parameter, public dp
Definition: kinds.F:34
integer, parameter, public sp
Definition: kinds.F:33
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition: torch_api.F:920
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
Definition: torch_api.F:944
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition: torch_api.F:896
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
Definition: torch_api.F:1004
subroutine, public torch_allow_tf32(allow_tf32)
Set whether to allow the use of TF32. Needed due to changes in defaults from pytorch 1....
Definition: torch_api.F:1104
subroutine, public torch_model_eval(model, inputs, outputs)
Evaluates the given Torch model. (In Torch lingo this operation is called forward())
Definition: torch_api.F:971
subroutine, public torch_model_freeze(model)
Freeze the given Torch model: applies generic optimization that speed up model. See https://pytorch....
Definition: torch_api.F:1127
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
Definition: torch_api.F:1028
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
Definition: torch_api.F:1080