(git:e68414f)
Loading...
Searching...
No Matches
torch_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
8 USE iso_c_binding, ONLY: c_associated, &
9 c_bool, &
10 c_char, &
11 c_float, &
12 c_double, &
13 c_f_pointer, &
14 c_int, &
15 c_null_char, &
16 c_null_ptr, &
17 c_ptr, &
18 c_int32_t, &
19 c_int64_t
20
22
23#include "./base/base_uses.f90"
24
25 IMPLICIT NONE
26
27 PRIVATE
28
30 PRIVATE
31 TYPE(C_PTR) :: c_ptr = c_null_ptr
32 END TYPE torch_tensor_type
33
35 PRIVATE
36 TYPE(C_PTR) :: c_ptr = c_null_ptr
37 END TYPE torch_dict_type
38
40 PRIVATE
41 TYPE(C_PTR) :: c_ptr = c_null_ptr
42 END TYPE torch_model_type
43
45 MODULE PROCEDURE torch_tensor_from_array_int32_1d
46 MODULE PROCEDURE torch_tensor_from_array_float_1d
47 MODULE PROCEDURE torch_tensor_from_array_int64_1d
48 MODULE PROCEDURE torch_tensor_from_array_double_1d
49 MODULE PROCEDURE torch_tensor_from_array_int32_2d
50 MODULE PROCEDURE torch_tensor_from_array_float_2d
51 MODULE PROCEDURE torch_tensor_from_array_int64_2d
52 MODULE PROCEDURE torch_tensor_from_array_double_2d
53 MODULE PROCEDURE torch_tensor_from_array_int32_3d
54 MODULE PROCEDURE torch_tensor_from_array_float_3d
55 MODULE PROCEDURE torch_tensor_from_array_int64_3d
56 MODULE PROCEDURE torch_tensor_from_array_double_3d
57 END INTERFACE torch_tensor_from_array
58
60 MODULE PROCEDURE torch_tensor_data_ptr_int32_1d
61 MODULE PROCEDURE torch_tensor_data_ptr_float_1d
62 MODULE PROCEDURE torch_tensor_data_ptr_int64_1d
63 MODULE PROCEDURE torch_tensor_data_ptr_double_1d
64 MODULE PROCEDURE torch_tensor_data_ptr_int32_2d
65 MODULE PROCEDURE torch_tensor_data_ptr_float_2d
66 MODULE PROCEDURE torch_tensor_data_ptr_int64_2d
67 MODULE PROCEDURE torch_tensor_data_ptr_double_2d
68 MODULE PROCEDURE torch_tensor_data_ptr_int32_3d
69 MODULE PROCEDURE torch_tensor_data_ptr_float_3d
70 MODULE PROCEDURE torch_tensor_data_ptr_int64_3d
71 MODULE PROCEDURE torch_tensor_data_ptr_double_3d
72 END INTERFACE torch_tensor_data_ptr
73
75 MODULE PROCEDURE torch_model_get_attr_string
76 MODULE PROCEDURE torch_model_get_attr_double
77 MODULE PROCEDURE torch_model_get_attr_int64
78 MODULE PROCEDURE torch_model_get_attr_int32
79 MODULE PROCEDURE torch_model_get_attr_strlist
80 END INTERFACE torch_model_get_attr
81
88
89CONTAINS
90
91
92
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int32_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_4), DIMENSION(:), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int32(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int32")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT32_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int32
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query the extent directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int32(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int32_1d
138
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int32_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      INTEGER(kind=int_4), DIMENSION(:), POINTER, INTENT(INOUT) :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int32(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int32")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int32
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int32(tensor=tensor%c_ptr, &
                                         ndims=1, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int32_1d
183
184
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_float), DIMENSION(*)              :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query the extent directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_1d
230
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      REAL(sp), DIMENSION(:), POINTER, INTENT(INOUT)     :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=1, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_1d
275
276
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query the extent directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_1d
322
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      INTEGER(kind=int_8), DIMENSION(:), POINTER, INTENT(INOUT) :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=1, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_1d
367
368
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_double), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query the extent directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=1, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_1d
414
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      REAL(dp), DIMENSION(:), POINTER, INTENT(INOUT)     :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=1, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_1d
459
460
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!>        The tensor sizes are the extents of source in reverse order.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int32_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int32(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int32")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT32_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int32
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query extents directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 2, KIND=int_8) ! C arrays are stored row-major.
      sizes_c(2) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int32(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int32_2d
507
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data,
!>        with the extents of the tensor in reverse order.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int32_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      INTEGER(kind=int_4), DIMENSION(:, :), POINTER, INTENT(INOUT) :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int32(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int32")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int32
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int32(tensor=tensor%c_ptr, &
                                         ndims=2, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(2) ! C arrays are stored row-major.
      sizes_f(2) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int32_2d
553
554
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!>        The tensor sizes are the extents of source in reverse order.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_float), DIMENSION(*)              :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query extents directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 2, KIND=int_8) ! C arrays are stored row-major.
      sizes_c(2) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_2d
601
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data,
!>        with the extents of the tensor in reverse order.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      REAL(sp), DIMENSION(:, :), POINTER, INTENT(INOUT)  :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=2, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(2) ! C arrays are stored row-major.
      sizes_f(2) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_2d
647
648
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!>        The tensor sizes are the extents of source in reverse order.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query extents directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 2, KIND=int_8) ! C arrays are stored row-major.
      sizes_c(2) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_2d
695
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data,
!>        with the extents of the tensor in reverse order.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      INTEGER(kind=int_8), DIMENSION(:, :), POINTER, INTENT(INOUT) :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=2, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(2) ! C arrays are stored row-major.
      sizes_f(2) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_2d
741
742
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle; must not be associated on entry, is associated on return.
!> \param source Allocated array handed to the C layer as the tensor's data.
!>        The tensor sizes are the extents of source in reverse order.
!> \param requires_grad Value passed to the C layer as req_grad (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_double), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! An unallocated source has neither a valid extent nor a valid address to hand over.
      cpassert(ALLOCATED(source))

      ! Query extents directly as int_8, so very large arrays do not overflow a default integer.
      sizes_c(1) = SIZE(source, 2, KIND=int_8) ! C arrays are stored row-major.
      sizes_c(2) = SIZE(source, 1, KIND=int_8) ! C arrays are stored row-major.

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=2, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_2d
789
! **************************************************************************************************
!> \brief Associates a Fortran pointer with the data of a Torch tensor (no copy on the Fortran side).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated on entry; on return it points at the tensor's data,
!>        with the extents of the tensor in reverse order.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      ! INTENT(INOUT), not OUT: the entry check below reads the association status.
      REAL(dp), DIMENSION(:, :), POINTER, INTENT(INOUT)  :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      ! Sentinels: the C layer is expected to overwrite both.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=2, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      sizes_f(1) = sizes_c(2) ! C arrays are stored row-major.
      sizes_f(2) = sizes_c(1) ! C arrays are stored row-major.
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_2d
835
836
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!>        The tensor's dimensions are the reverse of source's because C arrays are stored row-major.
!> \param tensor Handle of the new tensor; must not be associated on entry.
!> \param source Data of the new tensor; must be allocated.
!> \param requires_grad Whether the tensor requires gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int32_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int32(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int32")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            INTEGER(kind=C_INT32_T), DIMENSION(*)      :: source
         END SUBROUTINE torch_c_tensor_from_array_int32
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Querying the shape of an unallocated array is invalid and would hand garbage to Torch.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_c(3:1:-1) = SHAPE(source, KIND=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int32(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int32_3d
884
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied: the pointer aliases
!>        the tensor's memory and is therefore only valid during the tensor's lifetime!
!>        The dimensions are reversed w.r.t. the tensor because C arrays are stored row-major.
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated (nullified) on entry; on exit points to the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int32_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :, :), POINTER   :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int32(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int32")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int32
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      ! ASSOCIATED() requires a defined pointer status, hence callers must nullify data_ptr first.
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int32(tensor=tensor%c_ptr, &
                                         ndims=3, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      ! The C side must have filled in every dimension and returned a valid buffer.
      cpassert(all(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major, hence the dimensions are reversed for Fortran.
      sizes_f(:) = sizes_c(3:1:-1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int32_3d
931
932
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!>        The tensor's dimensions are the reverse of source's because C arrays are stored row-major.
!> \param tensor Handle of the new tensor; must not be associated on entry.
!> \param source Data of the new tensor; must be allocated.
!> \param requires_grad Whether the tensor requires gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_float), DIMENSION(*)           :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Querying the shape of an unallocated array is invalid and would hand garbage to Torch.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_c(3:1:-1) = SHAPE(source, KIND=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_3d
980
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied: the pointer aliases
!>        the tensor's memory and is therefore only valid during the tensor's lifetime!
!>        The dimensions are reversed w.r.t. the tensor because C arrays are stored row-major.
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated (nullified) on entry; on exit points to the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:, :, :), POINTER              :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      ! ASSOCIATED() requires a defined pointer status, hence callers must nullify data_ptr first.
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=3, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      ! The C side must have filled in every dimension and returned a valid buffer.
      cpassert(all(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major, hence the dimensions are reversed for Fortran.
      sizes_f(:) = sizes_c(3:1:-1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_3d
1027
1028
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!>        The tensor's dimensions are the reverse of source's because C arrays are stored row-major.
!> \param tensor Handle of the new tensor; must not be associated on entry.
!> \param source Data of the new tensor; must be allocated.
!> \param requires_grad Whether the tensor requires gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Querying the shape of an unallocated array is invalid and would hand garbage to Torch.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_c(3:1:-1) = SHAPE(source, KIND=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_3d
1076
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied: the pointer aliases
!>        the tensor's memory and is therefore only valid during the tensor's lifetime!
!>        The dimensions are reversed w.r.t. the tensor because C arrays are stored row-major.
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated (nullified) on entry; on exit points to the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :, :), POINTER   :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      ! ASSOCIATED() requires a defined pointer status, hence callers must nullify data_ptr first.
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=3, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      ! The C side must have filled in every dimension and returned a valid buffer.
      cpassert(all(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major, hence the dimensions are reversed for Fortran.
      sizes_f(:) = sizes_c(3:1:-1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_3d
1123
1124
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!>        The tensor's dimensions are the reverse of source's because C arrays are stored row-major.
!> \param tensor Handle of the new tensor; must not be associated on entry.
!> \param source Data of the new tensor; must be allocated.
!> \param requires_grad Whether the tensor requires gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_double), DIMENSION(*)          :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Querying the shape of an unallocated array is invalid and would hand garbage to Torch.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_c(3:1:-1) = SHAPE(source, KIND=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=3, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_3d
1172
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied: the pointer aliases
!>        the tensor's memory and is therefore only valid during the tensor's lifetime!
!>        The dimensions are reversed w.r.t. the tensor because C arrays are stored row-major.
!> \param tensor Tensor whose data is accessed; must be associated.
!> \param data_ptr Must be disassociated (nullified) on entry; on exit points to the tensor's data.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:, :, :), POINTER              :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      ! ASSOCIATED() requires a defined pointer status, hence callers must nullify data_ptr first.
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=3, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)

      ! The C side must have filled in every dimension and returned a valid buffer.
      cpassert(all(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major, hence the dimensions are reversed for Fortran.
      sizes_f(:) = sizes_c(3:1:-1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_3d
1219
1220
! **************************************************************************************************
!> \brief Runs the autograd backward pass starting from a Torch tensor.
!> \param tensor Tensor from which the backward pass starts; must be associated.
!> \param outer_grad Tensor handed to the C backward routine as outer gradient; must be associated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor, outer_grad

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_tensor_backward'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_tensor_backward(t_ptr, grad_ptr) &
            BIND(C, name="torch_c_tensor_backward")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: t_ptr, grad_ptr
         END SUBROUTINE torch_c_tensor_backward
      END INTERFACE

      CALL timeset(routinen, timer_handle)
      ! Both handles have to refer to live tensors.
      cpassert(c_associated(tensor%c_ptr))
      cpassert(c_associated(outer_grad%c_ptr))
      CALL torch_c_tensor_backward(tensor%c_ptr, outer_grad%c_ptr)
      CALL timestop(timer_handle)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(outer_grad)
#endif
   END SUBROUTINE torch_tensor_backward
1253
! **************************************************************************************************
!> \brief Fetches the gradient that autograd computed for a Torch tensor.
!> \param tensor Tensor whose gradient is requested; must be associated.
!> \param grad Receives a new handle to the gradient tensor; must not be associated on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_grad(tensor, grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_grad(src, dst) &
            BIND(C, name="torch_c_tensor_grad")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: src
            TYPE(c_ptr)                                :: dst
         END SUBROUTINE torch_c_tensor_grad
      END INTERFACE

      ! Input must exist, output handle must still be empty.
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. c_associated(grad%c_ptr))
      CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)
      ! The C side has to hand back a valid tensor.
      cpassert(c_associated(grad%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(grad)
#endif
   END SUBROUTINE torch_tensor_grad
1282
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources.
!> \param tensor Tensor to release; must be associated, is reset to a null handle afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_release(tensor)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_release(victim) BIND(C, name="torch_c_tensor_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: victim
         END SUBROUTINE torch_c_tensor_release
      END INTERFACE

      cpassert(c_associated(tensor%c_ptr))
      CALL torch_c_tensor_release(tensor%c_ptr)
      ! Reset the handle, so that a second release trips the assertion above.
      tensor%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(tensor)
#endif
   END SUBROUTINE torch_tensor_release
1306
! **************************************************************************************************
!> \brief Creates a new, empty Torch dictionary.
!> \param dict Receives the handle of the new dictionary; must not be associated on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(new_dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: c_ptr
            TYPE(c_ptr)                                :: new_dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      ! Refuse to overwrite a handle that is still in use.
      cpassert(.NOT. c_associated(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      cpassert(c_associated(dict%c_ptr))
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(dict)
#endif
   END SUBROUTINE torch_dict_create
1330
! **************************************************************************************************
!> \brief Stores a Torch tensor under the given key in a Torch dictionary.
!> \param dict Target dictionary; must be associated.
!> \param key Key of the entry; trailing blanks are ignored.
!> \param tensor Tensor to insert; must be associated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_insert(dict_ptr, c_key, tensor_ptr) &
            BIND(C, name="torch_c_dict_insert")
            IMPORT :: c_char, c_ptr
            TYPE(c_ptr), VALUE                         :: dict_ptr
            CHARACTER(kind=C_CHAR), DIMENSION(*)       :: c_key
            TYPE(c_ptr), VALUE                         :: tensor_ptr
         END SUBROUTINE torch_c_dict_insert
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(c_associated(tensor%c_ptr))
      ! C expects a NUL-terminated string without the Fortran blank padding.
      CALL torch_c_dict_insert(dict%c_ptr, trim(key)//c_null_char, tensor%c_ptr)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(dict)
      mark_used(key)
      mark_used(tensor)
#endif
   END SUBROUTINE torch_dict_insert
1362
! **************************************************************************************************
!> \brief Looks up the Torch tensor stored under the given key in a Torch dictionary.
!> \param dict Dictionary to search; must be associated.
!> \param key Key of the entry; trailing blanks are ignored.
!> \param tensor Receives a handle to the found tensor; must not be associated on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_get(dict_ptr, c_key, tensor_ptr) &
            BIND(C, name="torch_c_dict_get")
            IMPORT :: c_char, c_ptr
            TYPE(c_ptr), VALUE                         :: dict_ptr
            CHARACTER(kind=C_CHAR), DIMENSION(*)       :: c_key
            TYPE(c_ptr)                                :: tensor_ptr
         END SUBROUTINE torch_c_dict_get
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. c_associated(tensor%c_ptr))
      ! C expects a NUL-terminated string without the Fortran blank padding.
      CALL torch_c_dict_get(dict%c_ptr, trim(key)//c_null_char, tensor%c_ptr)
      ! The lookup has to yield a valid tensor handle.
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(dict)
      mark_used(key)
      mark_used(tensor)
#endif
   END SUBROUTINE torch_dict_get
1396
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict Dictionary to release; must be associated, is reset to a null handle afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(victim) BIND(C, name="torch_c_dict_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: victim
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Reset the handle, so that a second release trips the assertion above.
      dict%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(dict)
#endif
   END SUBROUTINE torch_dict_release
1420
! **************************************************************************************************
!> \brief Loads a Torch model from the given "*.pth" file. (In Torch lingo models are called modules)
!> \param model Receives the handle of the loaded model; must not be associated on entry.
!> \param filename Path of the "*.pth" file; trailing blanks are ignored.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_load'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_load(model_ptr, path) BIND(C, name="torch_c_model_load")
            IMPORT :: c_ptr, c_char
            TYPE(c_ptr)                                :: model_ptr
            CHARACTER(kind=C_CHAR), DIMENSION(*)       :: path
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CALL timeset(routinen, timer_handle)
      cpassert(.NOT. c_associated(model%c_ptr))
      ! C expects a NUL-terminated path without the Fortran blank padding.
      CALL torch_c_model_load(model%c_ptr, trim(filename)//c_null_char)
      cpassert(c_associated(model%c_ptr))
      CALL timestop(timer_handle)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
      mark_used(filename)
#endif
   END SUBROUTINE torch_model_load
1452
! **************************************************************************************************
!> \brief Runs the forward pass of the given Torch model.
!> \param model Model to evaluate; must be associated.
!> \param inputs Dictionary of input tensors; must be associated.
!> \param outputs Dictionary handed to the C side for the results; must be associated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_forward(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_forward'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_forward(model_ptr, in_ptr, out_ptr) &
            BIND(C, name="torch_c_model_forward")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: model_ptr, in_ptr, out_ptr
         END SUBROUTINE torch_c_model_forward
      END INTERFACE

      CALL timeset(routinen, timer_handle)
      ! All three handles have to be live objects.
      cpassert(c_associated(model%c_ptr))
      cpassert(c_associated(inputs%c_ptr))
      cpassert(c_associated(outputs%c_ptr))
      CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
      CALL timestop(timer_handle)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
      mark_used(inputs)
      mark_used(outputs)
#endif
   END SUBROUTINE torch_model_forward
1488
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model Model to release; must be associated, is reset to a null handle afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(victim) BIND(C, name="torch_c_model_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: victim
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Reset the handle, so that a second release trips the assertion above.
      model%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
#endif
   END SUBROUTINE torch_model_release
1512
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename Path of the "*.pth" file; trailing blanks are ignored.
!> \param key Name of the metadata entry; trailing blanks are ignored.
!> \return Content of the metadata entry.
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_read_metadata'

      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: content_f
      INTEGER                                            :: handle, i
      INTEGER(kind=C_INT)                                :: length
      TYPE(c_ptr)                                        :: content_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: c_char, c_ptr, c_int
            CHARACTER(kind=C_CHAR), DIMENSION(*)       :: filename, key
            TYPE(c_ptr)                                :: content
            INTEGER(kind=C_INT)                        :: length
         END SUBROUTINE torch_c_model_read_metadata

         ! The C library's free(), used to release the buffer that was allocated on the C side.
         SUBROUTINE libc_free(ptr) BIND(C, name="free")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: ptr
         END SUBROUTINE libc_free
      END INTERFACE

      CALL timeset(routinen, handle)
      content_c = c_null_ptr
      length = -1
      CALL torch_c_model_read_metadata(filename=trim(filename)//c_null_char, &
                                       key=trim(key)//c_null_char, &
                                       content=content_c, &
                                       length=length)
      cpassert(c_associated(content_c))
      cpassert(length >= 0)

      ! The buffer holds length characters followed by a NUL terminator.
      CALL c_f_pointer(content_c, content_f, shape=[length + 1])
      cpassert(content_f(length + 1) == c_null_char)

      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO i = 1, length
         cpassert(content_f(i) /= c_null_char)
         res(i:i) = content_f(i)
      END DO

      ! The buffer was allocated on the C side. A Fortran DEALLOCATE is only valid for memory that was
      ! obtained through ALLOCATE, hence the buffer is handed back to the C library instead.
      ! NOTE(review): assumes torch_c_model_read_metadata allocates the buffer with malloc - confirm.
      NULLIFY (content_f)
      CALL libc_free(content_c)
      CALL timestop(handle)
#else
      res = ""
      mark_used(filename)
      mark_used(key)
      cpabort("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_model_read_metadata
1569
! **************************************************************************************************
!> \brief Reports whether the CUDA backend of Torch can be used.
!> \return .TRUE. iff the Torch CUDA backend is available.
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: answer_c

      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: c_bool
            LOGICAL(C_BOOL)                            :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      answer_c = torch_c_cuda_is_available()
      ! Convert from the interoperable C_BOOL kind to default LOGICAL.
      res = LOGICAL(answer_c)
#else
      cpabort("CP2K was compiled without Torch library.")
      res = .FALSE.
#endif
   END FUNCTION torch_cuda_is_available
1591
! **************************************************************************************************
!> \brief Enables or disables the use of TF32 in Torch.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 Whether TF32 may be used.
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: flag_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(flag) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: c_bool
            LOGICAL(C_BOOL), VALUE                     :: flag
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! Convert default LOGICAL to the interoperable C_BOOL kind before crossing into C.
      flag_c = LOGICAL(allow_tf32, C_BOOL)
      CALL torch_c_allow_tf32(flag_c)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
1615
! **************************************************************************************************
!> \brief Freezes the given Torch model, which applies generic optimizations that speed it up.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model Model to freeze; must be associated.
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_freeze'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_freeze(model_ptr) BIND(C, name="torch_c_model_freeze")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: model_ptr
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      CALL timeset(routinen, timer_handle)
      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
      CALL timestop(timer_handle)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
#endif
   END SUBROUTINE torch_model_freeze
1644
1645
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Model to query; must be associated.
!> \param key Name of the attribute; trailing blanks are ignored.
!> \param dest Receives the value of the attribute.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int64(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER(kind=int_8), INTENT(OUT)                   :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_int64(model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_int64")
            IMPORT :: c_ptr, c_char, c_int64_t
            TYPE(c_ptr), VALUE                         :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)       :: key
            INTEGER(kind=C_INT64_T)                    :: dest
         END SUBROUTINE torch_c_model_get_attr_int64
      END INTERFACE

      ! Like every other model routine, refuse to hand a null model handle to the C side.
      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_get_attr_int64(model=model%c_ptr, &
                                        key=trim(key)//c_null_char, &
                                        dest=dest)
#else
      dest = 0
      mark_used(model)
      mark_used(key)
      cpabort("CP2K compiled without the Torch library.")
#endif
   END SUBROUTINE torch_model_get_attr_int64
! **************************************************************************************************
!> \brief Retrieves a double-precision attribute from a Torch model.
!>        Must be called before torch_model_freeze.
!> \param model Torch model to query; its C handle must be associated.
!> \param key Name of the attribute; trailing blanks are trimmed before it is passed to C.
!> \param dest Receives the attribute value.
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_double(model, key, dest)
   TYPE(torch_model_type), INTENT(IN) :: model
   CHARACTER(len=*), INTENT(IN) :: key
   REAL(dp), INTENT(OUT) :: dest

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_model_get_attr_double(model, key, dest) &
         BIND(C, name="torch_c_model_get_attr_double")
         IMPORT :: c_ptr, c_char, c_double
         TYPE(c_ptr), VALUE :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
         REAL(kind=c_double), INTENT(OUT) :: dest
      END SUBROUTINE torch_c_model_get_attr_double
   END INTERFACE

   ! Do not hand a null model handle to the C side (same check as torch_model_freeze).
   cpassert(c_associated(model%c_ptr))
   ! The C side expects a null-terminated key.
   CALL torch_c_model_get_attr_double(model=model%c_ptr, &
                                      key=TRIM(key)//c_null_char, &
                                      dest=dest)
#else
   dest = 0.0_dp
   mark_used(model)
   mark_used(key)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_model_get_attr_double
! **************************************************************************************************
!> \brief Retrieves a string attribute from a Torch model.
!>        Must be called before torch_model_freeze.
!> \param model Torch model to query; its C handle must be associated.
!> \param key Name of the attribute; trailing blanks are trimmed before it is passed to C.
!> \param dest Receives the attribute value, blank-padded to default_string_length.
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_string(model, key, dest)
   TYPE(torch_model_type), INTENT(IN) :: model
   CHARACTER(len=*), INTENT(IN) :: key
   CHARACTER(LEN=default_string_length), INTENT(OUT) :: dest

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_model_get_attr_string(model, key, dest) &
         BIND(C, name="torch_c_model_get_attr_string")
         IMPORT :: c_ptr, c_char
         TYPE(c_ptr), VALUE :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: dest
      END SUBROUTINE torch_c_model_get_attr_string
   END INTERFACE

   ! Do not hand a null model handle to the C side (same check as torch_model_freeze).
   cpassert(c_associated(model%c_ptr))

   ! dest is INTENT(OUT) and therefore undefined on entry. Blank it before the C call so that
   ! every character not overwritten by the C side is a well-defined blank, matching the
   ! dest(:) = "" pre-fill done by torch_model_get_attr_strlist.
   ! NOTE(review): presumably the C side copies at most default_string_length characters
   ! without a terminating null -- confirm against torch_c_model_get_attr_string.
   dest = ""
   CALL torch_c_model_get_attr_string(model=model%c_ptr, &
                                      key=TRIM(key)//c_null_char, &
                                      dest=dest)
#else
   dest = ""
   mark_used(model)
   mark_used(key)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_model_get_attr_string
1739
! **************************************************************************************************
!> \brief Retrieves a default-kind integer attribute from a Torch model.
!>        Must be called before torch_model_freeze.
!> \param model Torch model to query.
!> \param key Name of the attribute.
!> \param dest Receives the attribute value; asserts that the 64-bit value fits into dest.
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_int32(model, key, dest)
   TYPE(torch_model_type), INTENT(IN) :: model
   CHARACTER(len=*), INTENT(IN) :: key
   INTEGER, INTENT(OUT) :: dest

   INTEGER(kind=int_8) :: limit, temp

   CALL torch_model_get_attr_int64(model, key, temp)

   ! Range check done entirely in int_8 and without ABS. This accepts HUGE(dest) itself and
   ! avoids overflowing ABS for the most negative int_8 value.
   limit = INT(HUGE(dest), kind=int_8)
   cpassert(temp >= -limit .AND. temp <= limit)
   dest = INT(temp)
END SUBROUTINE torch_model_get_attr_int32
1754
! **************************************************************************************************
!> \brief Retrieves a list-of-strings attribute from a Torch model.
!>        Must be called before torch_model_freeze.
!> \param model Torch model to query; its C handle must be associated.
!> \param key Name of the attribute; trailing blanks are trimmed before it is passed to C.
!> \param dest Newly allocated array with one blank-padded entry per list element.
!>             Any previous allocation is released on entry (INTENT(OUT)).
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
   TYPE(torch_model_type), INTENT(IN) :: model
   CHARACTER(len=*), INTENT(IN) :: key
   CHARACTER(LEN=default_string_length), &
      ALLOCATABLE, DIMENSION(:), INTENT(OUT) :: dest

#if defined(__LIBTORCH)
   CHARACTER(LEN=:), ALLOCATABLE :: c_key
   INTEGER(kind=c_int) :: i, num_items

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
         BIND(C, name="torch_c_model_get_attr_list_size")
         IMPORT :: c_ptr, c_char, c_int
         TYPE(c_ptr), VALUE :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
         INTEGER(kind=C_INT) :: size
      END SUBROUTINE torch_c_model_get_attr_list_size
   END INTERFACE

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
         BIND(C, name="torch_c_model_get_attr_strlist")
         IMPORT :: c_ptr, c_char, c_int
         TYPE(c_ptr), VALUE :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
         INTEGER(kind=C_INT), VALUE :: index
         CHARACTER(kind=C_CHAR), DIMENSION(*) :: dest
      END SUBROUTINE torch_c_model_get_attr_strlist
   END INTERFACE

   ! Do not hand a null model handle to the C side (same check as torch_model_freeze).
   cpassert(c_associated(model%c_ptr))

   ! Build the null-terminated key once rather than on every loop iteration.
   c_key = TRIM(key)//c_null_char

   CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
                                         key=c_key, &
                                         size=num_items)
   ALLOCATE (dest(num_items))
   ! Pre-fill with blanks so characters not written by the C side are well defined.
   dest(:) = ""

   DO i = 1, num_items
      ! The C side takes a zero-based list index.
      CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                          key=c_key, &
                                          index=i - 1_c_int, &
                                          dest=dest(i))
   END DO
#else
   ! Leave the INTENT(OUT) argument in a defined (empty) state before aborting.
   ALLOCATE (dest(0))
   mark_used(model)
   mark_used(key)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_model_get_attr_strlist
1811
1812END MODULE torch_api
struct tensor_ tensor
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public int_8
Definition kinds.F:54
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_string_length
Definition kinds.F:57
integer, parameter, public sp
Definition kinds.F:33
integer, parameter, public int_4
Definition kinds.F:51
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition torch_api.F:1402
subroutine, public torch_tensor_backward(tensor, outer_grad)
Runs autograd on a Torch tensor.
Definition torch_api.F:1226
subroutine, public torch_dict_get(dict, key, tensor)
Retrieves a Torch tensor from a Torch dictionary.
Definition torch_api.F:1368
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
Definition torch_api.F:1426
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition torch_api.F:1312
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
Definition torch_api.F:1494
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
Definition torch_api.F:1259
subroutine, public torch_allow_tf32(allow_tf32)
Set whether to allow the use of TF32. Needed due to changes in defaults from pytorch 1....
Definition torch_api.F:1599
subroutine, public torch_model_freeze(model)
Freeze the given Torch model: applies generic optimizations that speed up the model. See https://pytorch....
Definition torch_api.F:1622
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
Definition torch_api.F:1518
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
Definition torch_api.F:1336
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
Definition torch_api.F:1575
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:1288
subroutine, public torch_model_forward(model, inputs, outputs)
Evaluates the given Torch model.
Definition torch_api.F:1458