(git:1f9fd2c)
Loading...
Searching...
No Matches
torch_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
8 USE iso_c_binding, ONLY: c_associated, &
9 c_bool, &
10 c_char, &
11 c_float, &
12 c_double, &
13 c_f_pointer, &
14 c_int, &
15 c_null_char, &
16 c_null_ptr, &
17 c_ptr, &
18 c_int32_t, &
19 c_int64_t
20
22
23#include "./base/base_uses.f90"
24
25 IMPLICIT NONE
26
27 PRIVATE
28
30 PRIVATE
31 TYPE(C_PTR) :: c_ptr = c_null_ptr
32 END TYPE torch_tensor_type
33
35 PRIVATE
36 TYPE(C_PTR) :: c_ptr = c_null_ptr
37 END TYPE torch_dict_type
38
40 PRIVATE
41 TYPE(C_PTR) :: c_ptr = c_null_ptr
42 END TYPE torch_model_type
43
45 MODULE PROCEDURE torch_tensor_from_array_int32_1d
46 MODULE PROCEDURE torch_tensor_from_array_float_1d
47 MODULE PROCEDURE torch_tensor_from_array_int64_1d
48 MODULE PROCEDURE torch_tensor_from_array_double_1d
49 MODULE PROCEDURE torch_tensor_from_array_int32_2d
50 MODULE PROCEDURE torch_tensor_from_array_float_2d
51 MODULE PROCEDURE torch_tensor_from_array_int64_2d
52 MODULE PROCEDURE torch_tensor_from_array_double_2d
53 MODULE PROCEDURE torch_tensor_from_array_int32_3d
54 MODULE PROCEDURE torch_tensor_from_array_float_3d
55 MODULE PROCEDURE torch_tensor_from_array_int64_3d
56 MODULE PROCEDURE torch_tensor_from_array_double_3d
57 END INTERFACE torch_tensor_from_array
58
60 MODULE PROCEDURE torch_tensor_reset_from_array_double_1d
61 MODULE PROCEDURE torch_tensor_reset_from_array_double_2d
62 MODULE PROCEDURE torch_tensor_reset_from_array_double_3d
64
66 MODULE PROCEDURE torch_tensor_data_ptr_int32_1d
67 MODULE PROCEDURE torch_tensor_data_ptr_float_1d
68 MODULE PROCEDURE torch_tensor_data_ptr_int64_1d
69 MODULE PROCEDURE torch_tensor_data_ptr_double_1d
70 MODULE PROCEDURE torch_tensor_data_ptr_int32_2d
71 MODULE PROCEDURE torch_tensor_data_ptr_float_2d
72 MODULE PROCEDURE torch_tensor_data_ptr_int64_2d
73 MODULE PROCEDURE torch_tensor_data_ptr_double_2d
74 MODULE PROCEDURE torch_tensor_data_ptr_int32_3d
75 MODULE PROCEDURE torch_tensor_data_ptr_float_3d
76 MODULE PROCEDURE torch_tensor_data_ptr_int64_3d
77 MODULE PROCEDURE torch_tensor_data_ptr_double_3d
78 END INTERFACE torch_tensor_data_ptr
79
81 MODULE PROCEDURE torch_model_get_attr_string
82 MODULE PROCEDURE torch_model_get_attr_double
83 MODULE PROCEDURE torch_model_get_attr_int64
84 MODULE PROCEDURE torch_model_get_attr_int32
85 MODULE PROCEDURE torch_model_get_attr_strlist
86 END INTERFACE torch_model_get_attr
87
92 PUBLIC :: torch_tensor_grad
102
103CONTAINS
104
105
106
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int32_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_4), DIMENSION(:), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int32(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int32")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT32_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int32
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! Request an int_8 result: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int32(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int32_1d
152
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int32_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_4), DIMENSION(:), POINTER         :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int32(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int32")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int32
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int32(tensor=tensor%c_ptr, &
                                         ndims=1, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      sizes_f(1) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int32_1d
198
199
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_float), DIMENSION(*)              :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! Request an int_8 result: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_1d
245
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:), POINTER                    :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=1, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      sizes_f(1) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_1d
291
292
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! Request an int_8 result: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_1d
338
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:), POINTER         :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=1, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      sizes_f(1) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_1d
384
385
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_double), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! Request an int_8 result: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=1, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_1d
431
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:), POINTER                    :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=1, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      sizes_f(1) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_1d
477
478
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int32_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int32(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int32")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT32_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int32
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are passed in reverse order.
      ! Request int_8 results: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 2, kind=int_8)
      sizes_c(2) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int32(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int32_2d
525
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int32_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :), POINTER      :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int32(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int32")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int32
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int32(tensor=tensor%c_ptr, &
                                         ndims=2, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int32_2d
572
573
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_float), DIMENSION(*)              :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are passed in reverse order.
      ! Request int_8 results: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 2, kind=int_8)
      sizes_c(2) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_2d
620
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:, :), POINTER                 :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=2, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_2d
667
668
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are passed in reverse order.
      ! Request int_8 results: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 2, kind=int_8)
      sizes_c(2) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_2d
715
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :), POINTER      :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=2, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_2d
762
763
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, must not be associated on entry.
!> \param source Allocated array holding the tensor data.
!> \param requires_grad Whether the tensor should require gradients, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_float, c_double, c_bool
            TYPE(c_ptr)                                   :: tensor
            LOGICAL(kind=C_BOOL), VALUE                   :: req_grad
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            REAL(kind=c_double), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE of an unallocated array is invalid, so check first.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the dimensions are passed in reverse order.
      ! Request int_8 results: a default-integer SIZE overflows for extents above HUGE(0).
      sizes_c(1) = SIZE(source, 2, kind=int_8)
      sizes_c(2) = SIZE(source, 1, kind=int_8)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=2, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_2d
810
! **************************************************************************************************
!> \brief Returns a pointer to the data of a Torch tensor. No data is copied.
!>        The returned pointer is only valid during the lifetime of the tensor!
!>        For zero-sized tensors data_ptr is left unassociated.
!> \param tensor Associated Torch tensor.
!> \param data_ptr Must be unassociated on entry, points into the tensor data on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:, :), POINTER                 :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_char, c_ptr, c_int, c_int32_t, c_int64_t
            TYPE(c_ptr), VALUE                            :: tensor
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(c_ptr)                                   :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      sizes_c(:) = -1 ! Sentinel, expected to be overwritten by the C side.
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=2, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)
      cpassert(ALL(sizes_c >= 0)) ! Every extent must have been filled in.

      ! C arrays are stored row-major, hence the dimensions are reversed.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_2d
857
858
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Must not be associated on entry; associated on return.
!> \param source ...        Allocated array; its dimensions are handed to Torch in reversed order
!>                          (Fortran column-major -> C row-major).
!> \param requires_grad ... Optional, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int32_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int32(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int32")
            IMPORT :: c_ptr, c_int, c_int32_t, c_int64_t, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            INTEGER(kind=C_INT32_T), DIMENSION(*)      :: source
         END SUBROUTINE torch_c_tensor_from_array_int32
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_c(1) = SIZE(source, 3)
      sizes_c(2) = SIZE(source, 2)
      sizes_c(3) = SIZE(source, 1)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int32(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int32_3d
906
! **************************************************************************************************
!> \brief Returns a Fortran pointer that aliases the data of a Torch tensor (no data is copied).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor ...   Torch tensor, must be associated.
!> \param data_ptr ... Must be disassociated on entry. On return it points to the tensor's data
!>                     with the dimensions reversed (C row-major -> Fortran column-major).
!>                     It stays disassociated if the tensor has a zero-sized dimension.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int32_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_4), DIMENSION(:, :, :), POINTER   :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int32(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int32")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int32
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int32(tensor=tensor%c_ptr, &
                                         ndims=3, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      ! The -1 sentinel must have been overwritten with the actual extents by the C side,
      ! otherwise negative extents would be handed to c_f_pointer below.
      cpassert(ALL(sizes_c >= 0))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_f(:) = sizes_c(3:1:-1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int32_3d
954
955
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Must not be associated on entry; associated on return.
!> \param source ...        Allocated array; its dimensions are handed to Torch in reversed order
!>                          (Fortran column-major -> C row-major).
!> \param requires_grad ... Optional, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_float), DIMENSION(*)           :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_c(1) = SIZE(source, 3)
      sizes_c(2) = SIZE(source, 2)
      sizes_c(3) = SIZE(source, 1)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_3d
1003
! **************************************************************************************************
!> \brief Returns a Fortran pointer that aliases the data of a Torch tensor (no data is copied).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor ...   Torch tensor, must be associated.
!> \param data_ptr ... Must be disassociated on entry. On return it points to the tensor's data
!>                     with the dimensions reversed (C row-major -> Fortran column-major).
!>                     It stays disassociated if the tensor has a zero-sized dimension.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:, :, :), POINTER              :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=3, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      ! The -1 sentinel must have been overwritten with the actual extents by the C side,
      ! otherwise negative extents would be handed to c_f_pointer below.
      cpassert(ALL(sizes_c >= 0))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_f(:) = sizes_c(3:1:-1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_3d
1051
1052
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Must not be associated on entry; associated on return.
!> \param source ...        Allocated array; its dimensions are handed to Torch in reversed order
!>                          (Fortran column-major -> C row-major).
!> \param requires_grad ... Optional, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_c(1) = SIZE(source, 3)
      sizes_c(2) = SIZE(source, 2)
      sizes_c(3) = SIZE(source, 1)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_3d
1100
! **************************************************************************************************
!> \brief Returns a Fortran pointer that aliases the data of a Torch tensor (no data is copied).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor ...   Torch tensor, must be associated.
!> \param data_ptr ... Must be disassociated on entry. On return it points to the tensor's data
!>                     with the dimensions reversed (C row-major -> Fortran column-major).
!>                     It stays disassociated if the tensor has a zero-sized dimension.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :, :), POINTER   :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=3, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      ! The -1 sentinel must have been overwritten with the actual extents by the C side,
      ! otherwise negative extents would be handed to c_f_pointer below.
      cpassert(ALL(sizes_c >= 0))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_f(:) = sizes_c(3:1:-1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_3d
1148
1149
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Must not be associated on entry; associated on return.
!> \param source ...        Allocated array; its dimensions are handed to Torch in reversed order
!>                          (Fortran column-major -> C row-major).
!> \param requires_grad ... Optional, defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_double), DIMENSION(*)          :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_c(1) = SIZE(source, 3)
      sizes_c(2) = SIZE(source, 2)
      sizes_c(3) = SIZE(source, 1)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=3, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_3d
1197
! **************************************************************************************************
!> \brief Returns a Fortran pointer that aliases the data of a Torch tensor (no data is copied).
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor ...   Torch tensor, must be associated.
!> \param data_ptr ... Must be disassociated on entry. On return it points to the tensor's data
!>                     with the dimensions reversed (C row-major -> Fortran column-major).
!>                     It stays disassociated if the tensor has a zero-sized dimension.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:, :, :), POINTER              :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            TYPE(c_ptr)                                :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=3, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)

      ! The -1 sentinel must have been overwritten with the actual extents by the C side,
      ! otherwise negative extents would be handed to c_f_pointer below.
      cpassert(ALL(sizes_c >= 0))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_f(:) = sizes_c(3:1:-1)

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         cpassert(c_associated(data_ptr_c))
         CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_3d
1245
1246
1247
! **************************************************************************************************
!> \brief Reuses or creates a device leaf tensor and copies data into it.
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Handle passed by reference to the C side; it may or may not be
!>                          associated on entry (no precondition is checked) and is associated
!>                          on return.
!> \param source ...        Allocated array providing the data.
!> \param requires_grad ... Optional, defaults to .FALSE.
! **************************************************************************************************
   SUBROUTINE torch_tensor_reset_from_array_double_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(1)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_reset_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_double), DIMENSION(*)          :: source
         END SUBROUTINE torch_c_tensor_reset_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      sizes_c(1) = SIZE(source, 1)

      CALL torch_c_tensor_reset_from_array_double(tensor=tensor%c_ptr, &
                                                  req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                                  ndims=1, &
                                                  sizes=sizes_c, &
                                                  source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_reset_from_array_double_1d
1291
1292
! **************************************************************************************************
!> \brief Reuses or creates a device leaf tensor and copies data into it.
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Handle passed by reference to the C side; it may or may not be
!>                          associated on entry (no precondition is checked) and is associated
!>                          on return.
!> \param source ...        Allocated array; its dimensions are handed to Torch in reversed order
!>                          (Fortran column-major -> C row-major).
!> \param requires_grad ... Optional, defaults to .FALSE.
! **************************************************************************************************
   SUBROUTINE torch_tensor_reset_from_array_double_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(2)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_reset_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_double), DIMENSION(*)          :: source
         END SUBROUTINE torch_c_tensor_reset_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_c(1) = SIZE(source, 2)
      sizes_c(2) = SIZE(source, 1)

      CALL torch_c_tensor_reset_from_array_double(tensor=tensor%c_ptr, &
                                                  req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                                  ndims=2, &
                                                  sizes=sizes_c, &
                                                  source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_reset_from_array_double_2d
1337
1338
! **************************************************************************************************
!> \brief Reuses or creates a device leaf tensor and copies data into it.
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor ...        Handle passed by reference to the C side; it may or may not be
!>                          associated on entry (no precondition is checked) and is associated
!>                          on return.
!> \param source ...        Allocated array; its dimensions are handed to Torch in reversed order
!>                          (Fortran column-major -> C row-major).
!> \param requires_grad ... Optional, defaults to .FALSE.
! **************************************************************************************************
   SUBROUTINE torch_tensor_reset_from_array_double_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_reset_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
            INTEGER(kind=C_INT), VALUE                 :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)      :: sizes
            REAL(kind=c_double), DIMENSION(*)          :: source
         END SUBROUTINE torch_c_tensor_reset_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! SIZE() of an unallocated array is undefined, so reject it explicitly.
      cpassert(ALLOCATED(source))

      ! C arrays are stored row-major, hence the order of the dimensions is reversed.
      sizes_c(1) = SIZE(source, 3)
      sizes_c(2) = SIZE(source, 2)
      sizes_c(3) = SIZE(source, 1)

      CALL torch_c_tensor_reset_from_array_double(tensor=tensor%c_ptr, &
                                                  req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                                  ndims=3, &
                                                  sizes=sizes_c, &
                                                  source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_reset_from_array_double_3d
1384
1385
! **************************************************************************************************
!> \brief Creates an expanded tensor view along one singleton dimension.
!> \param tensor ... Source tensor, must be associated.
!> \param dim ...    Index of the dimension to expand; must be >= 0
!>                   (presumably zero-based as in Torch - confirm against the C side).
!> \param extent ... New extent of that dimension; must be >= 0.
!> \param result ... Must not be associated on entry; associated with the new view on return.
! **************************************************************************************************
   SUBROUTINE torch_tensor_expand_dim(tensor, dim, extent, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER, INTENT(IN)                                :: dim, extent
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTEGER(kind=c_int64_t)                            :: dim_c, extent_c

      INTERFACE
         SUBROUTINE torch_c_tensor_expand_dim(tensor, dim, extent, result) &
            BIND(C, name="torch_c_tensor_expand_dim")
            IMPORT :: c_int64_t, c_ptr
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT64_T), VALUE             :: dim, extent
            TYPE(c_ptr)                                :: result
         END SUBROUTINE torch_c_tensor_expand_dim
      END INTERFACE

      ! Validate handles and arguments before crossing into C.
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. c_associated(result%c_ptr))
      cpassert(dim >= 0)
      cpassert(extent >= 0)

      ! Widen to the 64-bit integers expected by the C interface.
      dim_c = INT(dim, kind=c_int64_t)
      extent_c = INT(extent, kind=c_int64_t)

      CALL torch_c_tensor_expand_dim(tensor%c_ptr, dim_c, extent_c, result%c_ptr)
      cpassert(c_associated(result%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(dim)
      mark_used(extent)
      mark_used(result)
#endif
   END SUBROUTINE torch_tensor_expand_dim
1422
! **************************************************************************************************
!> \brief Creates a view of a contiguous tensor slice.
!> \param tensor ...      Source tensor, must be associated.
!> \param dim ...         Dimension along which to slice; must be >= 0.
!> \param start_index ... First index of the slice; must be >= 0
!>                        (presumably zero-based as in Torch - confirm against the C side).
!> \param length ...      Number of elements in the slice; must be >= 0.
!> \param result ...      Must not be associated on entry; associated with the view on return.
! **************************************************************************************************
   SUBROUTINE torch_tensor_narrow(tensor, dim, start_index, length, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER, INTENT(IN)                                :: dim, start_index, length
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTEGER(kind=c_int64_t)                            :: dim_c, start_c, length_c

      INTERFACE
         SUBROUTINE torch_c_tensor_narrow(tensor, dim, start_index, length, result) &
            BIND(C, name="torch_c_tensor_narrow")
            IMPORT :: c_int64_t, c_ptr
            TYPE(c_ptr), VALUE                         :: tensor
            INTEGER(kind=C_INT64_T), VALUE             :: dim, start_index, length
            TYPE(c_ptr)                                :: result
         END SUBROUTINE torch_c_tensor_narrow
      END INTERFACE

      ! Validate handles and arguments before crossing into C.
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. c_associated(result%c_ptr))
      cpassert(dim >= 0)
      cpassert(start_index >= 0)
      cpassert(length >= 0)

      ! Widen to the 64-bit integers expected by the C interface.
      dim_c = INT(dim, kind=c_int64_t)
      start_c = INT(start_index, kind=c_int64_t)
      length_c = INT(length, kind=c_int64_t)

      CALL torch_c_tensor_narrow(tensor%c_ptr, dim_c, start_c, length_c, result%c_ptr)
      cpassert(c_associated(result%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(dim)
      mark_used(start_index)
      mark_used(length)
      mark_used(result)
#endif
   END SUBROUTINE torch_tensor_narrow
1462
! **************************************************************************************************
!> \brief Runs autograd on a Torch tensor.
!> \param tensor ...     Tensor to back-propagate from, must be associated.
!> \param outer_grad ... Gradient handed to the C side together with tensor, must be associated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(IN)                :: outer_grad

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_tensor_backward'
      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
            BIND(C, name="torch_c_tensor_backward")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: tensor
            TYPE(c_ptr), VALUE                         :: outer_grad
         END SUBROUTINE torch_c_tensor_backward
      END INTERFACE

      ! The whole backward pass, including the handle checks, is timed.
      CALL timeset(routineN, timer_handle)

      cpassert(c_associated(tensor%c_ptr))
      cpassert(c_associated(outer_grad%c_ptr))
      CALL torch_c_tensor_backward(tensor%c_ptr, outer_grad%c_ptr)

      CALL timestop(timer_handle)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(outer_grad)
#endif
   END SUBROUTINE torch_tensor_backward
1495
! **************************************************************************************************
!> \brief Runs autograd on a scalar Torch tensor.
!>        Unlike torch_tensor_backward, no outer gradient is passed to the C side.
!> \param tensor ... Tensor to back-propagate from, must be associated.
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward_scalar(tensor)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_backward_scalar(tensor) &
            BIND(C, name="torch_c_tensor_backward_scalar")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: tensor
         END SUBROUTINE torch_c_tensor_backward_scalar
      END INTERFACE

      cpassert(c_associated(tensor%c_ptr))
      CALL torch_c_tensor_backward_scalar(tensor=tensor%c_ptr)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
#endif
   END SUBROUTINE torch_tensor_backward_scalar
1518
! **************************************************************************************************
!> \brief Moves a tensor to the active Torch device and makes it an autograd leaf.
!> \param tensor ...        Must be associated. The handle is passed by reference, so the C side
!>                          may replace it; it is still associated on return.
!> \param requires_grad ... Whether the resulting leaf should require gradients.
! **************************************************************************************************
   SUBROUTINE torch_tensor_to_device_leaf(tensor, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      LOGICAL, INTENT(IN)                                :: requires_grad

#if defined(__LIBTORCH)
      LOGICAL(kind=c_bool)                               :: req_grad_c

      INTERFACE
         SUBROUTINE torch_c_tensor_to_device_leaf(tensor, req_grad) &
            BIND(C, name="torch_c_tensor_to_device_leaf")
            IMPORT :: c_bool, c_ptr
            TYPE(c_ptr)                                :: tensor
            LOGICAL(kind=C_BOOL), VALUE                :: req_grad
         END SUBROUTINE torch_c_tensor_to_device_leaf
      END INTERFACE

      req_grad_c = LOGICAL(requires_grad, kind=c_bool)

      cpassert(c_associated(tensor%c_ptr))
      CALL torch_c_tensor_to_device_leaf(tensor%c_ptr, req_grad_c)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_to_device_leaf
1546
! **************************************************************************************************
!> \brief Select whether Torch wrappers should use CUDA when available.
!> \param use_cuda ... Flag forwarded to the C side. Unlike the other wrappers, this routine does
!>                     not abort when CP2K is built without Torch; the flag is simply ignored.
! **************************************************************************************************
   SUBROUTINE torch_use_cuda(use_cuda)
      LOGICAL, INTENT(IN)                                :: use_cuda

#if defined(__LIBTORCH)
      LOGICAL(kind=c_bool)                               :: use_cuda_c

      INTERFACE
         SUBROUTINE torch_c_use_cuda(use_cuda) BIND(C, name="torch_c_use_cuda")
            IMPORT :: c_bool
            LOGICAL(kind=C_BOOL), VALUE                :: use_cuda
         END SUBROUTINE torch_c_use_cuda
      END INTERFACE

      use_cuda_c = LOGICAL(use_cuda, kind=c_bool)
      CALL torch_c_use_cuda(use_cuda_c)
#else
      mark_used(use_cuda)
#endif
   END SUBROUTINE torch_use_cuda
1566
! **************************************************************************************************
!> \brief Returns the gradient of a Torch tensor which was computed by autograd.
!> \param tensor ... Tensor whose gradient is requested, must be associated.
!> \param grad ...   Must not be associated on entry; associated with the gradient on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_grad(tensor, grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_grad(tensor, grad) &
            BIND(C, name="torch_c_tensor_grad")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: tensor
            TYPE(c_ptr)                                :: grad
         END SUBROUTINE torch_c_tensor_grad
      END INTERFACE

      ! Preconditions: valid input handle, empty output handle.
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. c_associated(grad%c_ptr))

      CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)

      ! Postcondition: the C side filled in the output handle.
      cpassert(c_associated(grad%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(grad)
#endif
   END SUBROUTINE torch_tensor_grad
1595
! **************************************************************************************************
!> \brief Returns the weighted sum of two Torch tensors.
!>        The exact reduction is implemented by torch_c_tensor_weighted_sum
!>        (presumably a sum over values*weights - confirm against the C side).
!> \param values ...  Must be associated.
!> \param weights ... Must be associated.
!> \param result ...  Must not be associated on entry; associated with the sum on return.
! **************************************************************************************************
   SUBROUTINE torch_tensor_weighted_sum(values, weights, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: values, weights
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_weighted_sum(values, weights, result) &
            BIND(C, name="torch_c_tensor_weighted_sum")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                         :: values
            TYPE(c_ptr), VALUE                         :: weights
            TYPE(c_ptr)                                :: result
         END SUBROUTINE torch_c_tensor_weighted_sum
      END INTERFACE

      ! Preconditions: both inputs valid, output handle empty.
      cpassert(c_associated(values%c_ptr))
      cpassert(c_associated(weights%c_ptr))
      cpassert(.NOT. c_associated(result%c_ptr))

      CALL torch_c_tensor_weighted_sum(values%c_ptr, weights%c_ptr, result%c_ptr)

      ! Postcondition: the C side filled in the output handle.
      cpassert(c_associated(result%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(values)
      mark_used(weights)
      mark_used(result)
#endif
   END SUBROUTINE torch_tensor_weighted_sum
1626
! **************************************************************************************************
!> \brief Returns a scalar double value from a Torch tensor.
!> \param tensor allocated Torch tensor to read from
!> \return the value reported by torch_c_tensor_item_double
!> \note Presumably the tensor must hold exactly one element (Torch item() semantics);
!>       confirm against the C implementation.
! **************************************************************************************************
FUNCTION torch_tensor_item_double(tensor) RESULT(value)
   TYPE(torch_tensor_type), INTENT(IN)                :: tensor
   REAL(kind=dp)                                      :: value

#if defined(__LIBTORCH)
   INTERFACE
      FUNCTION torch_c_tensor_item_double(t) RESULT(scalar) &
         BIND(C, name="torch_c_tensor_item_double")
         IMPORT :: c_double, c_ptr
         TYPE(c_ptr), VALUE                             :: t
         REAL(kind=c_double)                            :: scalar
      END FUNCTION torch_c_tensor_item_double
   END INTERFACE

   cpassert(c_associated(tensor%c_ptr))
   value = REAL(torch_c_tensor_item_double(tensor%c_ptr), kind=dp)
#else
   mark_used(tensor)
   value = 0.0_dp
   cpabort("CP2K compiled without the Torch library.")
#endif
END FUNCTION torch_tensor_item_double
1652
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources.
!> \param tensor allocated tensor; its handle is reset to C_NULL_PTR afterwards
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_tensor_release(tensor)
   TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_tensor_release(t) BIND(C, name="torch_c_tensor_release")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: t
      END SUBROUTINE torch_c_tensor_release
   END INTERFACE

   cpassert(c_associated(tensor%c_ptr))
   CALL torch_c_tensor_release(tensor%c_ptr)
   ! Forget the freed handle, so that a second release trips the assertion above.
   tensor%c_ptr = c_null_ptr
#else
   mark_used(tensor)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_tensor_release
1676
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict must be empty on entry; receives the handle of the new dictionary
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_create(dict)
   TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_dict_create(d) BIND(C, name="torch_c_dict_create")
         IMPORT :: c_ptr
         TYPE(c_ptr)                                    :: d
      END SUBROUTINE torch_c_dict_create
   END INTERFACE

   ! Creating into an occupied handle is a caller error.
   cpassert(.NOT. c_associated(dict%c_ptr))
   CALL torch_c_dict_create(dict%c_ptr)
   cpassert(c_associated(dict%c_ptr))
#else
   mark_used(dict)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_dict_create
1700
! **************************************************************************************************
!> \brief Clones a Torch dictionary.
!> \param source allocated dictionary to copy from
!> \param target must be empty on entry; receives the handle of the clone
!> \note Whether the contained tensors are deep-copied or shared is decided by
!>       torch_c_dict_clone (NOTE(review): confirm against the C implementation).
! **************************************************************************************************
SUBROUTINE torch_dict_clone(source, target)
   TYPE(torch_dict_type), INTENT(IN)                  :: source
   TYPE(torch_dict_type), INTENT(INOUT)               :: target

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_dict_clone(src, dst) BIND(C, name="torch_c_dict_clone")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: src
         TYPE(c_ptr)                                    :: dst
      END SUBROUTINE torch_c_dict_clone
   END INTERFACE

   cpassert(c_associated(source%c_ptr))
   cpassert(.NOT. c_associated(target%c_ptr))

   CALL torch_c_dict_clone(source%c_ptr, target%c_ptr)

   cpassert(c_associated(target%c_ptr))
#else
   mark_used(source)
   mark_used(target)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_dict_clone
1727
! **************************************************************************************************
!> \brief Inserts a Torch tensor into a Torch dictionary.
!> \param dict   allocated dictionary to insert into
!> \param key    entry name; trailing blanks are trimmed and a null terminator is appended
!> \param tensor allocated tensor to store under key
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_insert(dict, key, tensor)
   TYPE(torch_dict_type), INTENT(INOUT)               :: dict
   CHARACTER(len=*), INTENT(IN)                       :: key
   TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_dict_insert(d, name, t) BIND(C, name="torch_c_dict_insert")
         IMPORT :: c_char, c_ptr
         TYPE(c_ptr), VALUE                             :: d
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: name
         TYPE(c_ptr), VALUE                             :: t
      END SUBROUTINE torch_c_dict_insert
   END INTERFACE

   cpassert(c_associated(dict%c_ptr))
   cpassert(c_associated(tensor%c_ptr))

   ! C expects a null-terminated string, so the Fortran blank padding is dropped first.
   CALL torch_c_dict_insert(dict%c_ptr, trim(key)//c_null_char, tensor%c_ptr)
#else
   mark_used(dict)
   mark_used(key)
   mark_used(tensor)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_dict_insert
1759
! **************************************************************************************************
!> \brief Retrieves a Torch tensor from a Torch dictionary.
!> \param dict   allocated dictionary to read from
!> \param key    entry name; trailing blanks are trimmed and a null terminator is appended
!> \param tensor must be empty on entry; receives a handle for the stored tensor
!>               (presumably owned by the caller and to be freed with torch_tensor_release)
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_get(dict, key, tensor)
   TYPE(torch_dict_type), INTENT(IN)                  :: dict
   CHARACTER(len=*), INTENT(IN)                       :: key
   TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_dict_get(d, name, t) BIND(C, name="torch_c_dict_get")
         IMPORT :: c_char, c_ptr
         TYPE(c_ptr), VALUE                             :: d
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: name
         TYPE(c_ptr)                                    :: t
      END SUBROUTINE torch_c_dict_get
   END INTERFACE

   cpassert(c_associated(dict%c_ptr))
   cpassert(.NOT. c_associated(tensor%c_ptr))

   CALL torch_c_dict_get(dict%c_ptr, trim(key)//c_null_char, tensor%c_ptr)

   ! A missing key must not silently yield a null handle.
   cpassert(c_associated(tensor%c_ptr))
#else
   mark_used(dict)
   mark_used(key)
   mark_used(tensor)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_dict_get
1793
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict allocated dictionary; its handle is reset to C_NULL_PTR afterwards
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_dict_release(dict)
   TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_dict_release(d) BIND(C, name="torch_c_dict_release")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: d
      END SUBROUTINE torch_c_dict_release
   END INTERFACE

   cpassert(c_associated(dict%c_ptr))
   CALL torch_c_dict_release(dict%c_ptr)
   ! Forget the freed handle, so that a double release is caught by the assertion above.
   dict%c_ptr = c_null_ptr
#else
   mark_used(dict)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_dict_release
1817
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model    must be empty on entry; receives the handle of the loaded model
!> \param filename path of the model file; trailing blanks are trimmed
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_load(model, filename)
   TYPE(torch_model_type), INTENT(INOUT)              :: model
   CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
   CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_load'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_load(m, path) BIND(C, name="torch_c_model_load")
         IMPORT :: c_ptr, c_char
         TYPE(c_ptr)                                    :: m
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: path
      END SUBROUTINE torch_c_model_load
   END INTERFACE

   CALL timeset(routinen, timer_handle)

   cpassert(.NOT. c_associated(model%c_ptr))
   CALL torch_c_model_load(model%c_ptr, trim(filename)//c_null_char)
   cpassert(c_associated(model%c_ptr))

   CALL timestop(timer_handle)
#else
   mark_used(model)
   mark_used(filename)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_model_load
1849
! **************************************************************************************************
!> \brief Evaluates the given Torch model.
!> \param model   loaded Torch model
!> \param inputs  allocated dictionary with the model inputs
!> \param outputs already created dictionary, filled by the C side
!>                (the handle is passed by value, so it is not replaced)
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_forward(model, inputs, outputs)
   TYPE(torch_model_type), INTENT(INOUT)              :: model
   TYPE(torch_dict_type), INTENT(IN)                  :: inputs
   TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
   CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_forward'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_forward(m, din, dout) BIND(C, name="torch_c_model_forward")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: m, din, dout
      END SUBROUTINE torch_c_model_forward
   END INTERFACE

   CALL timeset(routinen, timer_handle)

   ! All three handles must be live; outputs has to be created by the caller beforehand.
   cpassert(c_associated(model%c_ptr))
   cpassert(c_associated(inputs%c_ptr))
   cpassert(c_associated(outputs%c_ptr))

   CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)

   CALL timestop(timer_handle)
#else
   mark_used(model)
   mark_used(inputs)
   mark_used(outputs)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_model_forward
1885
! **************************************************************************************************
!> \brief Evaluates a TorchScript model method expecting keyword argument "mol".
!> \param model       loaded Torch model
!> \param method_name name of the method to invoke; trailing blanks are trimmed
!> \param inputs      allocated dictionary handed to the method
!> \param output      must be empty on entry; receives the resulting tensor
! **************************************************************************************************
SUBROUTINE torch_model_forward_mol_tensor(model, method_name, inputs, output)
   TYPE(torch_model_type), INTENT(INOUT)              :: model
   CHARACTER(len=*), INTENT(IN)                       :: method_name
   TYPE(torch_dict_type), INTENT(IN)                  :: inputs
   TYPE(torch_tensor_type), INTENT(INOUT)             :: output

#if defined(__LIBTORCH)
   CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_forward_mol_tensor'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_forward_mol_tensor(m, meth, din, tout) &
         BIND(C, name="torch_c_model_forward_mol_tensor")
         IMPORT :: c_char, c_ptr
         TYPE(c_ptr), VALUE                             :: m
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: meth
         TYPE(c_ptr), VALUE                             :: din
         TYPE(c_ptr)                                    :: tout
      END SUBROUTINE torch_c_model_forward_mol_tensor
   END INTERFACE

   CALL timeset(routinen, timer_handle)

   cpassert(c_associated(model%c_ptr))
   cpassert(c_associated(inputs%c_ptr))
   cpassert(.NOT. c_associated(output%c_ptr))

   CALL torch_c_model_forward_mol_tensor(model%c_ptr, trim(method_name)//c_null_char, &
                                         inputs%c_ptr, output%c_ptr)

   cpassert(c_associated(output%c_ptr))

   CALL timestop(timer_handle)
#else
   mark_used(model)
   mark_used(method_name)
   mark_used(inputs)
   mark_used(output)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_model_forward_mol_tensor
1928
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model loaded model; its handle is reset to C_NULL_PTR afterwards
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_release(model)
   TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_model_release(m) BIND(C, name="torch_c_model_release")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: m
      END SUBROUTINE torch_c_model_release
   END INTERFACE

   cpassert(c_associated(model%c_ptr))
   CALL torch_c_model_release(model%c_ptr)
   ! Forget the freed handle, so that a double release is caught by the assertion above.
   model%c_ptr = c_null_ptr
#else
   mark_used(model)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_model_release
1952
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the model file; trailing blanks are trimmed
!> \param key      name of the metadata entry; trailing blanks are trimmed
!> \return         the entry's content as an allocatable string
!> \note The C side returns a C-allocated buffer plus its length; c_string_to_allocatable
!>       copies it into res and frees the buffer.
!> \author Ole Schuett
! **************************************************************************************************
FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
   CHARACTER(len=*), INTENT(IN)                       :: filename, key
   CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
   CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_read_metadata'

   INTEGER                                            :: nchars, timer_handle
   TYPE(c_ptr)                                        :: buffer

   INTERFACE
      SUBROUTINE torch_c_model_read_metadata(fname, ekey, content, length) &
         BIND(C, name="torch_c_model_read_metadata")
         IMPORT :: c_char, c_ptr, c_int
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: fname, ekey
         TYPE(c_ptr)                                    :: content
         INTEGER(kind=C_INT)                            :: length
      END SUBROUTINE torch_c_model_read_metadata
   END INTERFACE

   CALL timeset(routinen, timer_handle)

   ! Sentinels: c_string_to_allocatable asserts that both were overwritten by C.
   buffer = c_null_ptr
   nchars = -1
   CALL torch_c_model_read_metadata(trim(filename)//c_null_char, trim(key)//c_null_char, &
                                    buffer, nchars)
   CALL c_string_to_allocatable(buffer, nchars, res)

   CALL timestop(timer_handle)
#else
   mark_used(filename)
   mark_used(key)
   res = ""
   cpabort("CP2K was compiled without Torch library.")
#endif
END FUNCTION torch_model_read_metadata
1994
! **************************************************************************************************
!> \brief Move a C-allocated null-terminated string into an allocatable Fortran string.
!> \param content_c pointer to length+1 characters, the last one being the null terminator;
!>                  the buffer is freed via torch_c_free_string and the pointer is nulled
!> \param length    number of characters before the terminator (must be >= 0)
!> \param res       receives a copy of the first length characters
! **************************************************************************************************
SUBROUTINE c_string_to_allocatable(content_c, length, res)
   TYPE(c_ptr), INTENT(INOUT)                         :: content_c
   INTEGER, INTENT(IN)                                :: length
   CHARACTER(:), ALLOCATABLE, INTENT(OUT)             :: res

#if defined(__LIBTORCH)
   CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
      POINTER                                         :: chars
   INTEGER                                            :: pos

   INTERFACE
      SUBROUTINE torch_c_free_string(buf) BIND(C, name="torch_c_free_string")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: buf
      END SUBROUTINE torch_c_free_string
   END INTERFACE

   cpassert(c_associated(content_c))
   cpassert(length >= 0)

   ! View the C buffer including its terminator.
   CALL c_f_pointer(content_c, chars, shape=[length + 1])
   cpassert(chars(length + 1) == c_null_char)
   ! The payload itself must not contain embedded terminators.
   cpassert(ALL(chars(1:length) /= c_null_char))

   ALLOCATE (CHARACTER(LEN=length) :: res)
   DO pos = 1, length
      res(pos:pos) = chars(pos)
   END DO

   ! Drop the Fortran view before the buffer is handed back to C for deallocation.
   NULLIFY (chars)
   CALL torch_c_free_string(content_c)
   content_c = c_null_ptr
#else
   mark_used(content_c)
   mark_used(length)
   res = ""
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE c_string_to_allocatable
2038
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return the C_BOOL answer of torch_c_cuda_is_available, converted to a default LOGICAL
!> \author Ole Schuett
! **************************************************************************************************
FUNCTION torch_cuda_is_available() RESULT(res)
   LOGICAL                                            :: res

#if defined(__LIBTORCH)
   INTERFACE
      FUNCTION torch_c_cuda_is_available() RESULT(avail) BIND(C, name="torch_c_cuda_is_available")
         IMPORT :: c_bool
         LOGICAL(C_BOOL)                                :: avail
      END FUNCTION torch_c_cuda_is_available
   END INTERFACE

   res = LOGICAL(torch_c_cuda_is_available())
#else
   cpabort("CP2K was compiled without Torch library.")
   res = .FALSE.
#endif
END FUNCTION torch_cuda_is_available
2060
! **************************************************************************************************
!> \brief Return the number of CUDA devices visible to Torch.
!> \return the C_INT count from torch_c_cuda_device_count as a default INTEGER
! **************************************************************************************************
FUNCTION torch_cuda_device_count() RESULT(count)
   INTEGER                                            :: count

#if defined(__LIBTORCH)
   INTERFACE
      FUNCTION torch_c_cuda_device_count() RESULT(ndev) BIND(C, name="torch_c_cuda_device_count")
         IMPORT :: c_int
         INTEGER(C_INT)                                 :: ndev
      END FUNCTION torch_c_cuda_device_count
   END INTERFACE

   count = INT(torch_c_cuda_device_count())
#else
   cpabort("CP2K was compiled without Torch library.")
   count = 0
#endif
END FUNCTION torch_cuda_device_count
2081
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 flag forwarded (as C_BOOL) to torch_c_allow_tf32
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE torch_allow_tf32(allow_tf32)
   LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
   LOGICAL(C_BOOL)                                    :: flag_c

   INTERFACE
      SUBROUTINE torch_c_allow_tf32(flag) BIND(C, name="torch_c_allow_tf32")
         IMPORT :: c_bool
         LOGICAL(C_BOOL), VALUE                         :: flag
      END SUBROUTINE torch_c_allow_tf32
   END INTERFACE

   ! Convert the default LOGICAL to the interoperable kind before crossing into C.
   flag_c = allow_tf32
   CALL torch_c_allow_tf32(flag_c)
#else
   mark_used(allow_tf32)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_allow_tf32
2105
! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimizations that speed up the model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model loaded Torch model, frozen in place (the handle is passed by value)
!> \note Model attributes must be read with the torch_model_get_attr_* routines before freezing.
!> \author Gabriele Tocci
! **************************************************************************************************
SUBROUTINE torch_model_freeze(model)
   TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
   CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_freeze'

   INTEGER                                            :: timer_handle

   INTERFACE
      SUBROUTINE torch_c_model_freeze(m) BIND(C, name="torch_c_model_freeze")
         IMPORT :: c_ptr
         TYPE(c_ptr), VALUE                             :: m
      END SUBROUTINE torch_c_model_freeze
   END INTERFACE

   CALL timeset(routinen, timer_handle)
   cpassert(c_associated(model%c_ptr))
   CALL torch_c_model_freeze(model%c_ptr)
   CALL timestop(timer_handle)
#else
   mark_used(model)
   cpabort("CP2K was compiled without Torch library.")
#endif
END SUBROUTINE torch_model_freeze
2134
2135
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model (asserted to be allocated)
!> \param key   attribute name; trailing blanks are trimmed and a null terminator is appended
!> \param dest  receives the attribute value as a 64-bit integer
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_int64(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   INTEGER(kind=int_8), INTENT(OUT)                   :: dest

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_model_get_attr_int64(model, key, dest) &
         BIND(C, name="torch_c_model_get_attr_int64")
         IMPORT :: c_ptr, c_char, c_int64_t
         TYPE(c_ptr), VALUE                             :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
         INTEGER(kind=C_INT64_T)                        :: dest
      END SUBROUTINE torch_c_model_get_attr_int64
   END INTERFACE

   ! As in torch_model_forward / torch_model_freeze: reject an unloaded model here
   ! instead of handing a null pointer to the C side.
   cpassert(c_associated(model%c_ptr))
   CALL torch_c_model_get_attr_int64(model=model%c_ptr, &
                                     key=trim(key)//c_null_char, &
                                     dest=dest)
#else
   dest = 0
   mark_used(model)
   mark_used(key)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_model_get_attr_int64
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model (asserted to be allocated)
!> \param key   attribute name; trailing blanks are trimmed and a null terminator is appended
!> \param dest  receives the attribute value as a double-precision real
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_double(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   REAL(dp), INTENT(OUT)                              :: dest

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_model_get_attr_double(model, key, dest) &
         BIND(C, name="torch_c_model_get_attr_double")
         IMPORT :: c_ptr, c_char, c_double
         TYPE(c_ptr), VALUE                             :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
         REAL(kind=c_double)                            :: dest
      END SUBROUTINE torch_c_model_get_attr_double
   END INTERFACE

   ! As in torch_model_forward / torch_model_freeze: reject an unloaded model here
   ! instead of handing a null pointer to the C side.
   cpassert(c_associated(model%c_ptr))
   CALL torch_c_model_get_attr_double(model=model%c_ptr, &
                                      key=trim(key)//c_null_char, &
                                      dest=dest)
#else
   dest = 0.0_dp
   mark_used(model)
   mark_used(key)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_model_get_attr_double
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model (asserted to be allocated)
!> \param key   attribute name; trailing blanks are trimmed and a null terminator is appended
!> \param dest  receives the attribute string, blank-padded to default_string_length
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_string(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   CHARACTER(LEN=default_string_length), INTENT(OUT)  :: dest

#if defined(__LIBTORCH)
   INTERFACE
      SUBROUTINE torch_c_model_get_attr_string(model, key, dest) &
         BIND(C, name="torch_c_model_get_attr_string")
         IMPORT :: c_ptr, c_char
         TYPE(c_ptr), VALUE                             :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: dest
      END SUBROUTINE torch_c_model_get_attr_string
   END INTERFACE

   cpassert(c_associated(model%c_ptr))

   ! dest is INTENT(OUT), i.e. undefined on entry, and the C routine is only given the
   ! buffer (not its length). Blank-fill it first so that any characters the C side does
   ! not write are well-defined padding, as torch_model_get_attr_strlist does.
   ! NOTE(review): the C side must not write more than default_string_length characters.
   dest = ""
   CALL torch_c_model_get_attr_string(model=model%c_ptr, &
                                      key=trim(key)//c_null_char, &
                                      dest=dest)
#else
   dest = ""
   mark_used(model)
   mark_used(key)
   cpabort("CP2K compiled without the Torch library.")
#endif
END SUBROUTINE torch_model_get_attr_string
2229
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model
!> \param key   attribute name
!> \param dest  receives the attribute value; it must fit into a default INTEGER
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_int32(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   INTEGER, INTENT(OUT)                               :: dest

   INTEGER(kind=int_8)                                :: limit, temp

   CALL torch_model_get_attr_int64(model, key, temp)

   ! Explicit bounds check: both +/-HUGE(dest) are valid default integers, and unlike
   ! ABS(temp) this cannot overflow for the most negative 64-bit value.
   limit = INT(HUGE(dest), kind=int_8)
   cpassert(temp >= -limit .AND. temp <= limit)
   dest = INT(temp)
END SUBROUTINE torch_model_get_attr_int32
2244
! **************************************************************************************************
!> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model (asserted to be allocated)
!> \param key   attribute name; trailing blanks are trimmed and a null terminator is appended
!> \param dest  (re)allocated to the list length and filled with one blank-padded entry per
!>              element; any previous allocation is released on entry (INTENT(OUT))
!> \author Ole Schuett
! **************************************************************************************************
SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
   TYPE(torch_model_type), INTENT(IN)                 :: model
   CHARACTER(len=*), INTENT(IN)                       :: key
   CHARACTER(LEN=default_string_length), &
      ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: dest

#if defined(__LIBTORCH)
   CHARACTER(len=:), ALLOCATABLE                      :: key_c
   INTEGER                                            :: i, num_items

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
         BIND(C, name="torch_c_model_get_attr_list_size")
         IMPORT :: c_ptr, c_char, c_int
         TYPE(c_ptr), VALUE                             :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
         INTEGER(kind=C_INT)                            :: size
      END SUBROUTINE torch_c_model_get_attr_list_size
   END INTERFACE

   INTERFACE
      SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
         BIND(C, name="torch_c_model_get_attr_strlist")
         IMPORT :: c_ptr, c_char, c_int
         TYPE(c_ptr), VALUE                             :: model
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
         INTEGER(kind=C_INT), VALUE                     :: index
         CHARACTER(kind=C_CHAR), DIMENSION(*)           :: dest
      END SUBROUTINE torch_c_model_get_attr_strlist
   END INTERFACE

   cpassert(c_associated(model%c_ptr))

   ! Null-terminated key, built once and reused for every call below.
   key_c = trim(key)//c_null_char

   CALL torch_c_model_get_attr_list_size(model=model%c_ptr, key=key_c, size=num_items)
   ALLOCATE (dest(num_items))
   ! Blank-fill first: the C routine only gets the buffer, not its length, so characters
   ! it presumably leaves unwritten must already be well-defined padding.
   dest(:) = ""

   DO i = 1, num_items
      ! The C side uses zero-based indices.
      CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                          key=key_c, &
                                          index=i - 1, &
                                          dest=dest(i))
   END DO
#else
   cpabort("CP2K compiled without the Torch library.")
   mark_used(model)
   mark_used(key)
   mark_used(dest)
#endif

END SUBROUTINE torch_model_get_attr_strlist
2301
2302END MODULE torch_api
struct tensor_ tensor
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public int_8
Definition kinds.F:54
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_string_length
Definition kinds.F:57
integer, parameter, public sp
Definition kinds.F:33
integer, parameter, public int_4
Definition kinds.F:51
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition torch_api.F:1799
subroutine, public torch_tensor_backward(tensor, outer_grad)
Runs autograd on a Torch tensor.
Definition torch_api.F:1468
subroutine, public torch_use_cuda(use_cuda)
Select whether Torch wrappers should use CUDA when available.
Definition torch_api.F:1551
subroutine, public torch_dict_get(dict, key, tensor)
Retrieves a Torch tensor from a Torch dictionary.
Definition torch_api.F:1765
real(kind=dp) function, public torch_tensor_item_double(tensor)
Returns a scalar double value from a Torch tensor.
Definition torch_api.F:1631
subroutine, public torch_tensor_backward_scalar(tensor)
Runs autograd on a scalar Torch tensor.
Definition torch_api.F:1500
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
Definition torch_api.F:1823
subroutine, public torch_tensor_narrow(tensor, dim, start_index, length, result)
Creates a view of a contiguous tensor slice.
Definition torch_api.F:1427
subroutine, public torch_tensor_to_device_leaf(tensor, requires_grad)
Moves a tensor to the active Torch device and makes it an autograd leaf.
Definition torch_api.F:1523
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition torch_api.F:1682
subroutine, public torch_model_forward_mol_tensor(model, method_name, inputs, output)
Evaluates a TorchScript model method expecting keyword argument "mol".
Definition torch_api.F:1890
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
Definition torch_api.F:1934
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
Definition torch_api.F:1572
subroutine, public torch_allow_tf32(allow_tf32)
Set whether to allow the use of TF32. Needed due to changes in defaults from pytorch 1....
Definition torch_api.F:2089
subroutine, public torch_tensor_weighted_sum(values, weights, result)
Returns the weighted sum of two Torch tensors.
Definition torch_api.F:1600
subroutine, public torch_model_freeze(model)
Freeze the given Torch model: applies generic optimizations that speed up the model. See https://pytorch....
Definition torch_api.F:2112
integer function, public torch_cuda_device_count()
Return the number of CUDA devices visible to Torch.
Definition torch_api.F:2065
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
Definition torch_api.F:1958
subroutine, public torch_tensor_expand_dim(tensor, dim, extent, result)
Creates an expanded tensor view along one singleton dimension.
Definition torch_api.F:1390
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
Definition torch_api.F:1733
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
Definition torch_api.F:2044
subroutine, public torch_dict_clone(source, target)
Clones a Torch dictionary.
Definition torch_api.F:1705
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:1658
subroutine, public torch_model_forward(model, inputs, outputs)
Evaluates the given Torch model.
Definition torch_api.F:1855