(git:d18deda)
Loading...
Searching...
No Matches
torch_api.F
Go to the documentation of this file.
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
8 USE iso_c_binding, ONLY: c_associated, &
9 c_bool, &
10 c_char, &
11 c_float, &
12 c_double, &
13 c_f_pointer, &
14 c_int, &
15 c_null_char, &
16 c_null_ptr, &
17 c_ptr, &
18 c_int64_t
19
21
22#include "./base/base_uses.f90"
23
24 IMPLICIT NONE
25
26 PRIVATE
27
29 PRIVATE
30 TYPE(C_PTR) :: c_ptr = c_null_ptr
31 END TYPE torch_tensor_type
32
34 PRIVATE
35 TYPE(C_PTR) :: c_ptr = c_null_ptr
36 END TYPE torch_dict_type
37
39 PRIVATE
40 TYPE(C_PTR) :: c_ptr = c_null_ptr
41 END TYPE torch_model_type
42
44 MODULE PROCEDURE torch_tensor_from_array_float_1d
45 MODULE PROCEDURE torch_tensor_from_array_int64_1d
46 MODULE PROCEDURE torch_tensor_from_array_double_1d
47 MODULE PROCEDURE torch_tensor_from_array_float_2d
48 MODULE PROCEDURE torch_tensor_from_array_int64_2d
49 MODULE PROCEDURE torch_tensor_from_array_double_2d
50 MODULE PROCEDURE torch_tensor_from_array_float_3d
51 MODULE PROCEDURE torch_tensor_from_array_int64_3d
52 MODULE PROCEDURE torch_tensor_from_array_double_3d
53 END INTERFACE torch_tensor_from_array
54
56 MODULE PROCEDURE torch_tensor_data_ptr_float_1d
57 MODULE PROCEDURE torch_tensor_data_ptr_int64_1d
58 MODULE PROCEDURE torch_tensor_data_ptr_double_1d
59 MODULE PROCEDURE torch_tensor_data_ptr_float_2d
60 MODULE PROCEDURE torch_tensor_data_ptr_int64_2d
61 MODULE PROCEDURE torch_tensor_data_ptr_double_2d
62 MODULE PROCEDURE torch_tensor_data_ptr_float_3d
63 MODULE PROCEDURE torch_tensor_data_ptr_int64_3d
64 MODULE PROCEDURE torch_tensor_data_ptr_double_3d
65 END INTERFACE torch_tensor_data_ptr
66
68 MODULE PROCEDURE torch_model_get_attr_string
69 MODULE PROCEDURE torch_model_get_attr_double
70 MODULE PROCEDURE torch_model_get_attr_int64
71 MODULE PROCEDURE torch_model_get_attr_int32
72 MODULE PROCEDURE torch_model_get_attr_strlist
73 END INTERFACE torch_model_get_attr
74
81
82CONTAINS
83
84
85
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(1)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=c_float), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! C arrays are stored row-major; for rank 1 no reordering is needed.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1_C_INT, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_1d
131
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:), POINTER                    :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(1)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=1_C_INT, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major; for rank 1 no reordering is needed.
      sizes_f(1) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_1d
176
177
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(1)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! C arrays are stored row-major; for rank 1 no reordering is needed.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=1_C_INT, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_1d
223
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:), POINTER         :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(1)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=1_C_INT, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major; for rank 1 no reordering is needed.
      sizes_f(1) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_1d
268
269
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_1d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:), ALLOCATABLE, INTENT(IN)    :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(1)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=c_double), DIMENSION(*)            :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! C arrays are stored row-major; for rank 1 no reordering is needed.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=1_C_INT, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_1d
315
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_1d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:), POINTER                    :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(1)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=1_C_INT, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C arrays are stored row-major; for rank 1 no reordering is needed.
      sizes_f(1) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_1d
360
361
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library; the tensor gets its
!>        dimensions in reversed order
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(2)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=c_float), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Fortran is column-major, C is row-major: a Fortran (n1, n2) array has the memory
      ! layout of a C [n2][n1] array, hence the reversed order.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 2, KIND=C_INT64_T)
      sizes_c(2) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2_C_INT, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_2d
408
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data with
!>        the tensor's dimensions in reversed order
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:, :), POINTER                 :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(2)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=2_C_INT, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C [n2][n1] row-major memory is a Fortran (n1, n2) column-major array.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_2d
454
455
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library; the tensor gets its
!>        dimensions in reversed order
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(2)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Fortran is column-major, C is row-major: a Fortran (n1, n2) array has the memory
      ! layout of a C [n2][n1] array, hence the reversed order.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 2, KIND=C_INT64_T)
      sizes_c(2) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=2_C_INT, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_2d
502
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data with
!>        the tensor's dimensions in reversed order
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :), POINTER      :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(2)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=2_C_INT, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C [n2][n1] row-major memory is a Fortran (n1, n2) column-major array.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_2d
548
549
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library; the tensor gets its
!>        dimensions in reversed order
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_2d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(2)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=c_double), DIMENSION(*)            :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Fortran is column-major, C is row-major: a Fortran (n1, n2) array has the memory
      ! layout of a C [n2][n1] array, hence the reversed order.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 2, KIND=C_INT64_T)
      sizes_c(2) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=2_C_INT, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_2d
596
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data with
!>        the tensor's dimensions in reversed order
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_2d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:, :), POINTER                 :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(2)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_double(tensor=tensor%c_ptr, &
                                          ndims=2_C_INT, &
                                          sizes=sizes_c, &
                                          data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C [n2][n1] row-major memory is a Fortran (n1, n2) column-major array.
      sizes_f(1) = sizes_c(2)
      sizes_f(2) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_2d
642
643
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library; the tensor gets its
!>        dimensions in reversed order
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_float_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(sp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(3)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_float(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_float")
            IMPORT :: c_ptr, c_int, c_int64_t, c_float, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=c_float), DIMENSION(*)             :: source
         END SUBROUTINE torch_c_tensor_from_array_float
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Fortran is column-major, C is row-major: a Fortran (n1, n2, n3) array has the memory
      ! layout of a C [n3][n2][n1] array, hence the reversed order.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 3, KIND=C_INT64_T)
      sizes_c(2) = SIZE(source, 2, KIND=C_INT64_T)
      sizes_c(3) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_float(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3_C_INT, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_float_3d
691
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data with
!>        the tensor's dimensions in reversed order
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_float_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(sp), DIMENSION(:, :, :), POINTER              :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(3)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_float(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_float")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_float
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_float(tensor=tensor%c_ptr, &
                                         ndims=3_C_INT, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C [n3][n2][n1] row-major memory is a Fortran (n1, n2, n3) column-major array.
      sizes_f(1) = sizes_c(3)
      sizes_f(2) = sizes_c(2)
      sizes_f(3) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_float_3d
738
739
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library; the tensor gets its
!>        dimensions in reversed order
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_int64_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(3)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_int64(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_int64")
            IMPORT :: c_ptr, c_int, c_int64_t, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: source
         END SUBROUTINE torch_c_tensor_from_array_int64
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Fortran is column-major, C is row-major: a Fortran (n1, n2, n3) array has the memory
      ! layout of a C [n3][n2][n1] array, hence the reversed order.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 3, KIND=C_INT64_T)
      sizes_c(2) = SIZE(source, 2, KIND=C_INT64_T)
      sizes_c(3) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_int64(tensor=tensor%c_ptr, &
                                           req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                           ndims=3_C_INT, &
                                           sizes=sizes_c, &
                                           source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_int64_3d
787
! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor.
!>        No copy is made on the Fortran side: data_ptr is associated via c_f_pointer with the
!>        address returned by the C library. The returned pointer is only valid during the
!>        tensor's lifetime!
!> \param tensor tensor handle; must be associated
!> \param data_ptr must not be associated on entry; on return points to the tensor data with
!>        the tensor's dimensions in reversed order
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_int64_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER(kind=int_8), DIMENSION(:, :, :), POINTER   :: data_ptr

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(3)              :: sizes_f, sizes_c
      TYPE(c_ptr)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_int64(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_int64")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(c_ptr)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_int64
      END INTERFACE

      ! Sentinel values, so the assertions below catch outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = c_null_ptr
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_int64(tensor=tensor%c_ptr, &
                                         ndims=3_C_INT, &
                                         sizes=sizes_c, &
                                         data_ptr=data_ptr_c)

      cpassert(ALL(sizes_c >= 0))
      cpassert(c_associated(data_ptr_c))

      ! C [n3][n2][n1] row-major memory is a Fortran (n1, n2, n3) column-major array.
      sizes_f(1) = sizes_c(3)
      sizes_f(2) = sizes_c(2)
      sizes_f(3) = sizes_c(1)
      CALL c_f_pointer(data_ptr_c, data_ptr, shape=sizes_f)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_int64_3d
834
835
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle of the new tensor; must not be associated on entry
!> \param source array whose memory is passed to the C library; the tensor gets its
!>        dimensions in reversed order
!> \param requires_grad forwarded to the C library as req_grad; defaults to .FALSE.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_double_3d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      REAL(dp), DIMENSION(:, :, :), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      ! Same kind as the C dummy argument "sizes" below.
      INTEGER(kind=C_INT64_T), DIMENSION(3)              :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_double")
            IMPORT :: c_ptr, c_int, c_int64_t, c_double, c_bool
            TYPE(c_ptr)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=c_double), DIMENSION(*)            :: source
         END SUBROUTINE torch_c_tensor_from_array_double
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Fortran is column-major, C is row-major: a Fortran (n1, n2, n3) array has the memory
      ! layout of a C [n3][n2][n1] array, hence the reversed order.
      ! KIND= avoids overflowing the default-integer result of SIZE for huge extents.
      sizes_c(1) = SIZE(source, 3, KIND=C_INT64_T)
      sizes_c(2) = SIZE(source, 2, KIND=C_INT64_T)
      sizes_c(3) = SIZE(source, 1, KIND=C_INT64_T)

      cpassert(.NOT. c_associated(tensor%c_ptr))
      CALL torch_c_tensor_from_array_double(tensor=tensor%c_ptr, &
                                            req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                            ndims=3_C_INT, &
                                            sizes=sizes_c, &
                                            source=source)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(source)
      mark_used(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_double_3d
883
! **************************************************************************************************
!> \brief Returns a pointer onto the data of a 3D double precision Torch tensor. No data is copied,
!>        hence the returned pointer is only valid during the tensor's lifetime!
!> \param tensor Torch tensor to access, must be initialized
!> \param data_ptr pointer onto the tensor's data, must not be associated on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_double_3d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(dp), DIMENSION(:, :, :), POINTER              :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(3)                  :: c_shape, f_shape
      TYPE(c_ptr)                                        :: raw_data

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_double(tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_double")
            IMPORT :: c_ptr, c_int, c_int64_t
            TYPE(c_ptr), VALUE                             :: tensor
            INTEGER(kind=C_INT), VALUE                     :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)          :: sizes
            TYPE(c_ptr)                                    :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_double
      END INTERFACE

      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. ASSOCIATED(data_ptr))

      ! Sentinel values, so that the checks after the call detect outputs the C side did not set.
      c_shape(:) = -1
      raw_data = c_null_ptr
      CALL torch_c_tensor_data_ptr_double(tensor%c_ptr, 3, c_shape, raw_data)
      cpassert(ALL(c_shape >= 0))
      cpassert(c_associated(raw_data))

      ! Torch reports the sizes in row-major (C) order, Fortran arrays are column-major.
      f_shape(:) = c_shape(3:1:-1)
      CALL c_f_pointer(raw_data, data_ptr, shape=f_shape)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_double_3d
930
931
! **************************************************************************************************
!> \brief Runs autograd on a Torch tensor.
!> \param tensor tensor to run the backward pass on, must be initialized
!> \param outer_grad outer gradient handed to the backward pass, must be initialized
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor, outer_grad

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_tensor_backward'
      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
            BIND(C, name="torch_c_tensor_backward")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: tensor, outer_grad
         END SUBROUTINE torch_c_tensor_backward
      END INTERFACE

      CALL timeset(routinen, timer)
      cpassert(c_associated(tensor%c_ptr))
      cpassert(c_associated(outer_grad%c_ptr))
      CALL torch_c_tensor_backward(tensor%c_ptr, outer_grad%c_ptr)
      CALL timestop(timer)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(outer_grad)
#endif
   END SUBROUTINE torch_tensor_backward
964
! **************************************************************************************************
!> \brief Returns the gradient of a Torch tensor which was computed by autograd.
!> \param tensor tensor whose gradient is requested, must be initialized
!> \param grad receives a new tensor handle for the gradient, must be uninitialized on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_grad(tensor, grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_grad(tensor, grad) &
            BIND(C, name="torch_c_tensor_grad")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: tensor
            TYPE(c_ptr)                                    :: grad
         END SUBROUTINE torch_c_tensor_grad
      END INTERFACE

      ! Preconditions: a valid source tensor and an unused destination handle.
      cpassert(c_associated(tensor%c_ptr))
      cpassert(.NOT. c_associated(grad%c_ptr))

      CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)

      ! Postcondition: the C side filled in the destination handle.
      cpassert(c_associated(grad%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(tensor)
      mark_used(grad)
#endif
   END SUBROUTINE torch_tensor_grad
993
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources.
!> \param tensor tensor to release, must be initialized; its handle is reset to null afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_release(tensor)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: tensor
         END SUBROUTINE torch_c_tensor_release
      END INTERFACE

      cpassert(c_associated(tensor%c_ptr))
      CALL torch_c_tensor_release(tensor%c_ptr)
      ! Reset the handle so that the type can be reused and double releases are caught.
      tensor%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(tensor)
#endif
   END SUBROUTINE torch_tensor_release
1017
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict receives the new dictionary handle, must be uninitialized on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: c_ptr
            TYPE(c_ptr)                                    :: dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      cpassert(.NOT. c_associated(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      cpassert(c_associated(dict%c_ptr))
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(dict)
#endif
   END SUBROUTINE torch_dict_create
1041
! **************************************************************************************************
!> \brief Inserts a Torch tensor into a Torch dictionary.
!> \param dict dictionary to insert into, must be initialized
!> \param key name under which the tensor is stored (trailing blanks are ignored)
!> \param tensor tensor to insert, must be initialized
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
            BIND(C, name="torch_c_dict_insert")
            IMPORT :: c_char, c_ptr
            TYPE(c_ptr), VALUE                             :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            TYPE(c_ptr), VALUE                             :: tensor
         END SUBROUTINE torch_c_dict_insert
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(c_associated(tensor%c_ptr))
      ! The C side expects a null-terminated key.
      CALL torch_c_dict_insert(dict%c_ptr, TRIM(key)//c_null_char, tensor%c_ptr)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(dict)
      mark_used(key)
      mark_used(tensor)
#endif
   END SUBROUTINE torch_dict_insert
1073
! **************************************************************************************************
!> \brief Retrieves a Torch tensor from a Torch dictionary.
!> \param dict dictionary to look up, must be initialized
!> \param key name of the requested tensor (trailing blanks are ignored)
!> \param tensor receives the tensor handle, must be uninitialized on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_get(dict, key, tensor) &
            BIND(C, name="torch_c_dict_get")
            IMPORT :: c_char, c_ptr
            TYPE(c_ptr), VALUE                             :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            TYPE(c_ptr)                                    :: tensor
         END SUBROUTINE torch_c_dict_get
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      cpassert(.NOT. c_associated(tensor%c_ptr))
      ! The C side expects a null-terminated key.
      CALL torch_c_dict_get(dict%c_ptr, TRIM(key)//c_null_char, tensor%c_ptr)
      cpassert(c_associated(tensor%c_ptr))
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(dict)
      mark_used(key)
      mark_used(tensor)
#endif
   END SUBROUTINE torch_dict_get
1107
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict dictionary to release, must be initialized; its handle is reset to null afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: dict
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      cpassert(c_associated(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Reset the handle so that the type can be reused and double releases are caught.
      dict%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(dict)
#endif
   END SUBROUTINE torch_dict_release
1131
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model receives the model handle, must be uninitialized on entry
!> \param filename path of the "*.pth" file (trailing blanks are ignored)
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_load'
      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
            IMPORT :: c_ptr, c_char
            TYPE(c_ptr)                                    :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: filename
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CALL timeset(routinen, timer)
      cpassert(.NOT. c_associated(model%c_ptr))
      ! The C side expects a null-terminated file name.
      CALL torch_c_model_load(model%c_ptr, TRIM(filename)//c_null_char)
      cpassert(c_associated(model%c_ptr))
      CALL timestop(timer)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
      mark_used(filename)
#endif
   END SUBROUTINE torch_model_load
1163
! **************************************************************************************************
!> \brief Evaluates the given Torch model.
!> \param model loaded Torch model
!> \param inputs dictionary with the input tensors, must be initialized
!> \param outputs dictionary receiving the results, must be created beforehand
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_forward(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_forward'
      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: model, inputs, outputs
         END SUBROUTINE torch_c_model_forward
      END INTERFACE

      CALL timeset(routinen, timer)
      cpassert(c_associated(model%c_ptr))
      cpassert(c_associated(inputs%c_ptr))
      cpassert(c_associated(outputs%c_ptr))
      CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
      CALL timestop(timer)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
      mark_used(inputs)
      mark_used(outputs)
#endif
   END SUBROUTINE torch_model_forward
1199
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model model to release, must be initialized; its handle is reset to null afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: model
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Reset the handle so that the type can be reused and double releases are caught.
      model%c_ptr = c_null_ptr
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
#endif
   END SUBROUTINE torch_model_release
1223
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the "*.pth" file (trailing blanks are ignored)
!> \param key name of the metadata entry (trailing blanks are ignored)
!> \return content of the metadata entry
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_read_metadata'
      INTEGER                                            :: handle

      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: content_f
      INTEGER                                            :: i
      INTEGER(kind=C_INT)                                :: length
      TYPE(c_ptr)                                        :: content_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: c_char, c_ptr, c_int
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: filename, key
            TYPE(c_ptr)                                    :: content
            INTEGER(kind=C_INT)                            :: length
         END SUBROUTINE torch_c_model_read_metadata
         ! The content buffer is allocated on the C side. Fortran's DEALLOCATE is only valid for
         ! memory obtained via ALLOCATE, hence the buffer is handed back to C's free().
         ! NOTE: assumes the C side allocates the buffer with malloc().
         SUBROUTINE c_free(ptr) BIND(C, name="free")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: ptr
         END SUBROUTINE c_free
      END INTERFACE

      CALL timeset(routinen, handle)
      ! Sentinel values, so that the checks after the call detect outputs the C side did not set.
      content_c = c_null_ptr
      length = -1
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//c_null_char, &
                                       key=TRIM(key)//c_null_char, &
                                       content=content_c, &
                                       length=length)
      cpassert(c_associated(content_c))
      cpassert(length >= 0)

      ! The buffer holds a null-terminated C string, i.e. length + 1 characters.
      CALL c_f_pointer(content_c, content_f, shape=[length + 1])
      cpassert(content_f(length + 1) == c_null_char)

      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO i = 1, length
         cpassert(content_f(i) /= c_null_char)
         res(i:i) = content_f(i)
      END DO

      ! Release the C-allocated buffer with the matching C deallocator.
      NULLIFY (content_f)
      CALL c_free(content_c)
      content_c = c_null_ptr
      CALL timestop(handle)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(filename)
      mark_used(key)
      mark_used(res)
#endif
   END FUNCTION torch_model_read_metadata
1280
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return whether the Torch CUDA backend is available
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: c_bool
            LOGICAL(C_BOOL)                                :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Convert the C_BOOL answer into a default LOGICAL.
      res = LOGICAL(torch_c_cuda_is_available())
#else
      cpabort("CP2K was compiled without Torch library.")
      res = .FALSE.
#endif
   END FUNCTION torch_cuda_is_available
1302
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 whether TF32 may be used
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: c_allow

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: c_bool
            LOGICAL(C_BOOL), VALUE                         :: allow_tf32
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      c_allow = LOGICAL(allow_tf32, C_BOOL)
      CALL torch_c_allow_tf32(c_allow)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
1326
! **************************************************************************************************
!> \brief Freezes the given Torch model, applying generic optimizations that speed it up.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model loaded Torch model
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routinen = 'torch_model_freeze'
      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
            IMPORT :: c_ptr
            TYPE(c_ptr), VALUE                             :: model
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      CALL timeset(routinen, timer)
      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
      CALL timestop(timer)
#else
      cpabort("CP2K was compiled without Torch library.")
      mark_used(model)
#endif
   END SUBROUTINE torch_model_freeze
1355
1356
! **************************************************************************************************
!> \brief Retrieves an int64 attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Torch model, must be initialized
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest value of the attribute
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int64(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER(kind=int_8), INTENT(OUT)                   :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_int64(model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_int64")
            IMPORT :: c_ptr, c_char, c_int64_t
            TYPE(c_ptr), VALUE                             :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            INTEGER(kind=C_INT64_T)                        :: dest
         END SUBROUTINE torch_c_model_get_attr_int64
      END INTERFACE

      ! Fail with a clear assertion instead of handing a null model to the C side.
      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_get_attr_int64(model=model%c_ptr, &
                                        key=TRIM(key)//c_null_char, &
                                        dest=dest)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(model)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_model_get_attr_int64
! **************************************************************************************************
!> \brief Retrieves a double precision attribute from a Torch model.
!>        Must be called before torch_model_freeze.
!> \param model Torch model, must be initialized
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest value of the attribute
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_double(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      REAL(dp), INTENT(OUT)                              :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_double(model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_double")
            IMPORT :: c_ptr, c_char, c_double
            TYPE(c_ptr), VALUE                             :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            REAL(kind=c_double)                            :: dest
         END SUBROUTINE torch_c_model_get_attr_double
      END INTERFACE

      ! Fail with a clear assertion instead of handing a null model to the C side.
      cpassert(c_associated(model%c_ptr))
      CALL torch_c_model_get_attr_double(model=model%c_ptr, &
                                         key=TRIM(key)//c_null_char, &
                                         dest=dest)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(model)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_model_get_attr_double
! **************************************************************************************************
!> \brief Retrieves a string attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Torch model, must be initialized
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest value of the attribute, blank padded
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_string(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      CHARACTER(LEN=default_string_length), INTENT(OUT)  :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_string(model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_string")
            IMPORT :: c_ptr, c_char
            TYPE(c_ptr), VALUE                             :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: dest
         END SUBROUTINE torch_c_model_get_attr_string
      END INTERFACE

      ! Fail with a clear assertion instead of handing a null model to the C side.
      cpassert(c_associated(model%c_ptr))

      ! dest is INTENT(OUT), so its content is undefined on entry. The C side receives no buffer
      ! length and is not guaranteed to write every character (presumably it assumes the
      ! attribute fits into default_string_length characters - TODO confirm). Blank the buffer
      ! first, so that no undefined characters trail the copied string.
      dest = ""
      CALL torch_c_model_get_attr_string(model=model%c_ptr, &
                                         key=TRIM(key)//c_null_char, &
                                         dest=dest)
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(model)
      mark_used(key)
      mark_used(dest)
#endif
   END SUBROUTINE torch_model_get_attr_string
1450
! **************************************************************************************************
!> \brief Retrieves an integer attribute from a Torch model. Must be called before torch_model_freeze.
!>        The attribute is read as int64 and must fit into a default INTEGER.
!> \param model Torch model, must be initialized
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest value of the attribute
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int32(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER, INTENT(OUT)                               :: dest

      INTEGER(kind=int_8)                                :: limit, temp

      CALL torch_model_get_attr_int64(model, key, temp)

      ! Symmetric range check. It avoids ABS(temp), which overflows for the most negative int64
      ! value, and it accepts +/-HUGE(dest), which fit into a default INTEGER.
      limit = INT(HUGE(dest), kind=int_8)
      cpassert(-limit <= temp .AND. temp <= limit)
      dest = INT(temp)
   END SUBROUTINE torch_model_get_attr_int32
1465
! **************************************************************************************************
!> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Torch model, must be initialized
!> \param key name of the attribute (trailing blanks are ignored)
!> \param dest list entries, blank padded; must not be allocated on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:)                       :: dest

#if defined(__LIBTORCH)

      ! C_INT kind, matching the C interfaces below (index is passed by VALUE).
      INTEGER(kind=C_INT)                                :: num_items, i

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
            BIND(C, name="torch_c_model_get_attr_list_size")
            IMPORT :: c_ptr, c_char, c_int
            TYPE(c_ptr), VALUE                             :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            INTEGER(kind=C_INT)                            :: size
         END SUBROUTINE torch_c_model_get_attr_list_size
      END INTERFACE

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
            BIND(C, name="torch_c_model_get_attr_strlist")
            IMPORT :: c_ptr, c_char, c_int
            TYPE(c_ptr), VALUE                             :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: key
            INTEGER(kind=C_INT), VALUE                     :: index
            CHARACTER(kind=C_CHAR), DIMENSION(*)           :: dest
         END SUBROUTINE torch_c_model_get_attr_strlist
      END INTERFACE

      cpassert(c_associated(model%c_ptr))
      cpassert(.NOT. ALLOCATED(dest))

      ! Sentinel value, so that the check after the call detects a size the C side did not set.
      num_items = -1
      CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
                                            key=TRIM(key)//c_null_char, &
                                            size=num_items)
      cpassert(num_items >= 0)

      ALLOCATE (dest(num_items))
      ! Blank the entries first: the C side may not write every character of each entry.
      dest(:) = ""

      DO i = 1, num_items
         ! The C side uses zero-based indexing.
         CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                             key=TRIM(key)//c_null_char, &
                                             index=i - 1, &
                                             dest=dest(i))
      END DO
#else
      cpabort("CP2K compiled without the Torch library.")
      mark_used(model)
      mark_used(key)
      mark_used(dest)
#endif

   END SUBROUTINE torch_model_get_attr_strlist
1522
1523END MODULE torch_api
struct tensor_ tensor
Defines the basic variable types.
Definition kinds.F:23
integer, parameter, public int_8
Definition kinds.F:54
integer, parameter, public dp
Definition kinds.F:34
integer, parameter, public default_string_length
Definition kinds.F:57
integer, parameter, public sp
Definition kinds.F:33
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
Definition torch_api.F:1113
subroutine, public torch_tensor_backward(tensor, outer_grad)
Runs autograd on a Torch tensor.
Definition torch_api.F:937
subroutine, public torch_dict_get(dict, key, tensor)
Retrieves a Torch tensor from a Torch dictionary.
Definition torch_api.F:1079
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
Definition torch_api.F:1137
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
Definition torch_api.F:1023
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
Definition torch_api.F:1205
subroutine, public torch_tensor_grad(tensor, grad)
Returns the gradient of a Torch tensor which was computed by autograd.
Definition torch_api.F:970
subroutine, public torch_allow_tf32(allow_tf32)
Set whether to allow the use of TF32. Needed due to changes in defaults from pytorch 1....
Definition torch_api.F:1310
subroutine, public torch_model_freeze(model)
Freezes the given Torch model: applies generic optimizations that speed up the model. See https://pytorch....
Definition torch_api.F:1333
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
Definition torch_api.F:1229
subroutine, public torch_dict_insert(dict, key, tensor)
Inserts a Torch tensor into a Torch dictionary.
Definition torch_api.F:1047
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.
Definition torch_api.F:1286
subroutine, public torch_tensor_release(tensor)
Releases a Torch tensor and all its resources.
Definition torch_api.F:999
subroutine, public torch_model_forward(model, inputs, outputs)
Evaluates the given Torch model.
Definition torch_api.F:1169