8 USE iso_c_binding,
ONLY: c_associated, &
22 #include "./base/base_uses.f90"
! Handle to a Torch dictionary living on the C side. It starts as c_null_ptr; torch_dict_create
! requires it to be null and sets it, and torch_dict_release resets it to c_null_ptr.
30 TYPE(C_PTR) :: c_ptr = c_null_ptr
31 END TYPE torch_dict_type
! Handle to a Torch model living on the C side. It starts as c_null_ptr; torch_model_load
! requires it to be null and sets it, and torch_model_release resets it to c_null_ptr.
35 TYPE(C_PTR) :: c_ptr = c_null_ptr
36 END TYPE torch_model_type
! Generic insertion into a Torch dictionary. The specific routine is chosen by the rank (1-3)
! and type (REAL(sp), INTEGER(int_8), REAL(dp)) of the actual `source` argument.
INTERFACE torch_dict_insert
   MODULE PROCEDURE torch_dict_insert_float_1d, torch_dict_insert_float_2d, torch_dict_insert_float_3d
   MODULE PROCEDURE torch_dict_insert_int64_1d, torch_dict_insert_int64_2d, torch_dict_insert_int64_3d
   MODULE PROCEDURE torch_dict_insert_double_1d, torch_dict_insert_double_2d, torch_dict_insert_double_3d
END INTERFACE torch_dict_insert
! Generic retrieval from a Torch dictionary. The specific routine is chosen by the rank (1-3)
! and type (REAL(sp), INTEGER(int_8), REAL(dp)) of the `dest` pointer argument.
INTERFACE torch_dict_get
   MODULE PROCEDURE torch_dict_get_float_1d, torch_dict_get_float_2d, torch_dict_get_float_3d
   MODULE PROCEDURE torch_dict_get_int64_1d, torch_dict_get_int64_2d, torch_dict_get_int64_3d
   MODULE PROCEDURE torch_dict_get_double_1d, torch_dict_get_double_2d, torch_dict_get_double_3d
END INTERFACE torch_dict_get
63 PUBLIC :: torch_dict_insert, torch_dict_get
! **************************************************************************************************
!> \brief Inserts a 1-D REAL(sp) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_float
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
76 SUBROUTINE torch_dict_insert_float_1d(dict, key, source)
77 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
78 CHARACTER(len=*),
INTENT(IN) :: key
79 REAL(sp),
CONTIGUOUS,
DIMENSION(:),
INTENT(IN) :: source
81 #if defined(__LIBTORCH)
82 INTEGER(kind=int_8),
DIMENSION(1) :: sizes_c
85 SUBROUTINE torch_c_dict_insert_float (dict, key, ndims, sizes, source) &
86 BIND(C, name="torch_c_dict_insert_float")
87 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
88 TYPE(C_PTR),
VALUE :: dict
89 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
90 INTEGER(kind=C_INT),
VALUE :: ndims
91 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
92 REAL(kind=c_float),
DIMENSION(*) :: source
93 END SUBROUTINE torch_c_dict_insert_float
! The single extent is passed to C as a 64-bit integer.
96 sizes_c(1) =
SIZE(source, 1)
98 cpassert(c_associated(dict%c_ptr))
99 CALL torch_c_dict_insert_float (dict=dict%c_ptr, &
100 key=trim(key)//c_null_char, &
105 cpabort(
"CP2K compiled without the Torch library.")
110 END SUBROUTINE torch_dict_insert_float_1d
! **************************************************************************************************
!> \brief Retrieves a 1-D REAL(sp) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
116 SUBROUTINE torch_dict_get_float_1d(dict, key, dest)
117 TYPE(torch_dict_type),
INTENT(IN) :: dict
118 CHARACTER(len=*),
INTENT(IN) :: key
119 REAL(sp),
DIMENSION(:),
POINTER :: dest
121 #if defined(__LIBTORCH)
122 INTEGER(kind=int_8),
DIMENSION(1) :: sizes_f, sizes_c
123 TYPE(C_PTR) :: dest_c
126 SUBROUTINE torch_c_dict_get_float (dict, key, ndims, sizes, dest) &
127 BIND(C, name="torch_c_dict_get_float")
128 IMPORT :: c_char, c_ptr, c_int, c_int64_t
129 TYPE(C_PTR),
VALUE :: dict
130 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
131 INTEGER(kind=C_INT),
VALUE :: ndims
132 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
134 END SUBROUTINE torch_c_dict_get_float
139 cpassert(c_associated(dict%c_ptr))
140 cpassert(.NOT.
ASSOCIATED(dest))
141 CALL torch_c_dict_get_float (dict=dict%c_ptr, &
142 key=trim(key)//c_null_char, &
147 cpassert(all(sizes_c >= 0))
148 cpassert(c_associated(dest_c))
! Associate dest with the C buffer, using the extent reported by C.
150 sizes_f(1) = sizes_c(1)
151 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
153 cpabort(
"CP2K compiled without the Torch library.")
158 END SUBROUTINE torch_dict_get_float_1d
! **************************************************************************************************
!> \brief Inserts a 1-D INTEGER(int_8) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_int64
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
165 SUBROUTINE torch_dict_insert_int64_1d(dict, key, source)
166 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
167 CHARACTER(len=*),
INTENT(IN) :: key
168 INTEGER(kind=int_8),
CONTIGUOUS,
DIMENSION(:),
INTENT(IN) :: source
170 #if defined(__LIBTORCH)
171 INTEGER(kind=int_8),
DIMENSION(1) :: sizes_c
174 SUBROUTINE torch_c_dict_insert_int64 (dict, key, ndims, sizes, source) &
175 BIND(C, name="torch_c_dict_insert_int64")
176 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
177 TYPE(C_PTR),
VALUE :: dict
178 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
179 INTEGER(kind=C_INT),
VALUE :: ndims
180 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
181 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: source
182 END SUBROUTINE torch_c_dict_insert_int64
! The single extent is passed to C as a 64-bit integer.
185 sizes_c(1) =
SIZE(source, 1)
187 cpassert(c_associated(dict%c_ptr))
188 CALL torch_c_dict_insert_int64 (dict=dict%c_ptr, &
189 key=trim(key)//c_null_char, &
194 cpabort(
"CP2K compiled without the Torch library.")
199 END SUBROUTINE torch_dict_insert_int64_1d
! **************************************************************************************************
!> \brief Retrieves a 1-D INTEGER(int_8) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
205 SUBROUTINE torch_dict_get_int64_1d(dict, key, dest)
206 TYPE(torch_dict_type),
INTENT(IN) :: dict
207 CHARACTER(len=*),
INTENT(IN) :: key
208 INTEGER(kind=int_8),
DIMENSION(:),
POINTER :: dest
210 #if defined(__LIBTORCH)
211 INTEGER(kind=int_8),
DIMENSION(1) :: sizes_f, sizes_c
212 TYPE(C_PTR) :: dest_c
215 SUBROUTINE torch_c_dict_get_int64 (dict, key, ndims, sizes, dest) &
216 BIND(C, name="torch_c_dict_get_int64")
217 IMPORT :: c_char, c_ptr, c_int, c_int64_t
218 TYPE(C_PTR),
VALUE :: dict
219 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
220 INTEGER(kind=C_INT),
VALUE :: ndims
221 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
223 END SUBROUTINE torch_c_dict_get_int64
228 cpassert(c_associated(dict%c_ptr))
229 cpassert(.NOT.
ASSOCIATED(dest))
230 CALL torch_c_dict_get_int64 (dict=dict%c_ptr, &
231 key=trim(key)//c_null_char, &
236 cpassert(all(sizes_c >= 0))
237 cpassert(c_associated(dest_c))
! Associate dest with the C buffer, using the extent reported by C.
239 sizes_f(1) = sizes_c(1)
240 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
242 cpabort(
"CP2K compiled without the Torch library.")
247 END SUBROUTINE torch_dict_get_int64_1d
! **************************************************************************************************
!> \brief Inserts a 1-D REAL(dp) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_double
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
254 SUBROUTINE torch_dict_insert_double_1d(dict, key, source)
255 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
256 CHARACTER(len=*),
INTENT(IN) :: key
257 REAL(dp),
CONTIGUOUS,
DIMENSION(:),
INTENT(IN) :: source
259 #if defined(__LIBTORCH)
260 INTEGER(kind=int_8),
DIMENSION(1) :: sizes_c
263 SUBROUTINE torch_c_dict_insert_double (dict, key, ndims, sizes, source) &
264 BIND(C, name="torch_c_dict_insert_double")
265 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
266 TYPE(C_PTR),
VALUE :: dict
267 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
268 INTEGER(kind=C_INT),
VALUE :: ndims
269 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
270 REAL(kind=c_double),
DIMENSION(*) :: source
271 END SUBROUTINE torch_c_dict_insert_double
! The single extent is passed to C as a 64-bit integer.
274 sizes_c(1) =
SIZE(source, 1)
276 cpassert(c_associated(dict%c_ptr))
277 CALL torch_c_dict_insert_double (dict=dict%c_ptr, &
278 key=trim(key)//c_null_char, &
283 cpabort(
"CP2K compiled without the Torch library.")
288 END SUBROUTINE torch_dict_insert_double_1d
! **************************************************************************************************
!> \brief Retrieves a 1-D REAL(dp) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
294 SUBROUTINE torch_dict_get_double_1d(dict, key, dest)
295 TYPE(torch_dict_type),
INTENT(IN) :: dict
296 CHARACTER(len=*),
INTENT(IN) :: key
297 REAL(dp),
DIMENSION(:),
POINTER :: dest
299 #if defined(__LIBTORCH)
300 INTEGER(kind=int_8),
DIMENSION(1) :: sizes_f, sizes_c
301 TYPE(C_PTR) :: dest_c
304 SUBROUTINE torch_c_dict_get_double (dict, key, ndims, sizes, dest) &
305 BIND(C, name="torch_c_dict_get_double")
306 IMPORT :: c_char, c_ptr, c_int, c_int64_t
307 TYPE(C_PTR),
VALUE :: dict
308 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
309 INTEGER(kind=C_INT),
VALUE :: ndims
310 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
312 END SUBROUTINE torch_c_dict_get_double
317 cpassert(c_associated(dict%c_ptr))
318 cpassert(.NOT.
ASSOCIATED(dest))
319 CALL torch_c_dict_get_double (dict=dict%c_ptr, &
320 key=trim(key)//c_null_char, &
325 cpassert(all(sizes_c >= 0))
326 cpassert(c_associated(dest_c))
! Associate dest with the C buffer, using the extent reported by C.
328 sizes_f(1) = sizes_c(1)
329 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
331 cpabort(
"CP2K compiled without the Torch library.")
336 END SUBROUTINE torch_dict_get_double_1d
! **************************************************************************************************
!> \brief Inserts a 2-D REAL(sp) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_float
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
343 SUBROUTINE torch_dict_insert_float_2d(dict, key, source)
344 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
345 CHARACTER(len=*),
INTENT(IN) :: key
346 REAL(sp),
CONTIGUOUS,
DIMENSION(:, :),
INTENT(IN) :: source
348 #if defined(__LIBTORCH)
349 INTEGER(kind=int_8),
DIMENSION(2) :: sizes_c
352 SUBROUTINE torch_c_dict_insert_float (dict, key, ndims, sizes, source) &
353 BIND(C, name="torch_c_dict_insert_float")
354 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
355 TYPE(C_PTR),
VALUE :: dict
356 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
357 INTEGER(kind=C_INT),
VALUE :: ndims
358 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
359 REAL(kind=c_float),
DIMENSION(*) :: source
360 END SUBROUTINE torch_c_dict_insert_float
! Extents are stored in reverse order (last Fortran dimension first).
! NOTE(review): presumably this lets the column-major buffer be read as a row-major tensor
! of transposed shape without copying -- confirm against the C implementation.
363 sizes_c(1) =
SIZE(source, 2)
364 sizes_c(2) =
SIZE(source, 1)
366 cpassert(c_associated(dict%c_ptr))
367 CALL torch_c_dict_insert_float (dict=dict%c_ptr, &
368 key=trim(key)//c_null_char, &
373 cpabort(
"CP2K compiled without the Torch library.")
378 END SUBROUTINE torch_dict_insert_float_2d
! **************************************************************************************************
!> \brief Retrieves a 2-D REAL(sp) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
384 SUBROUTINE torch_dict_get_float_2d(dict, key, dest)
385 TYPE(torch_dict_type),
INTENT(IN) :: dict
386 CHARACTER(len=*),
INTENT(IN) :: key
387 REAL(sp),
DIMENSION(:, :),
POINTER :: dest
389 #if defined(__LIBTORCH)
390 INTEGER(kind=int_8),
DIMENSION(2) :: sizes_f, sizes_c
391 TYPE(C_PTR) :: dest_c
394 SUBROUTINE torch_c_dict_get_float (dict, key, ndims, sizes, dest) &
395 BIND(C, name="torch_c_dict_get_float")
396 IMPORT :: c_char, c_ptr, c_int, c_int64_t
397 TYPE(C_PTR),
VALUE :: dict
398 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
399 INTEGER(kind=C_INT),
VALUE :: ndims
400 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
402 END SUBROUTINE torch_c_dict_get_float
407 cpassert(c_associated(dict%c_ptr))
408 cpassert(.NOT.
ASSOCIATED(dest))
409 CALL torch_c_dict_get_float (dict=dict%c_ptr, &
410 key=trim(key)//c_null_char, &
415 cpassert(all(sizes_c >= 0))
416 cpassert(c_associated(dest_c))
! Reverse the C extents back into Fortran order, then associate dest with the C buffer.
418 sizes_f(1) = sizes_c(2)
419 sizes_f(2) = sizes_c(1)
420 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
422 cpabort(
"CP2K compiled without the Torch library.")
427 END SUBROUTINE torch_dict_get_float_2d
! **************************************************************************************************
!> \brief Inserts a 2-D INTEGER(int_8) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_int64
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
434 SUBROUTINE torch_dict_insert_int64_2d(dict, key, source)
435 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
436 CHARACTER(len=*),
INTENT(IN) :: key
437 INTEGER(kind=int_8),
CONTIGUOUS,
DIMENSION(:, :),
INTENT(IN) :: source
439 #if defined(__LIBTORCH)
440 INTEGER(kind=int_8),
DIMENSION(2) :: sizes_c
443 SUBROUTINE torch_c_dict_insert_int64 (dict, key, ndims, sizes, source) &
444 BIND(C, name="torch_c_dict_insert_int64")
445 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
446 TYPE(C_PTR),
VALUE :: dict
447 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
448 INTEGER(kind=C_INT),
VALUE :: ndims
449 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
450 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: source
451 END SUBROUTINE torch_c_dict_insert_int64
! Extents are stored in reverse order (last Fortran dimension first).
! NOTE(review): presumably this lets the column-major buffer be read as a row-major tensor
! of transposed shape without copying -- confirm against the C implementation.
454 sizes_c(1) =
SIZE(source, 2)
455 sizes_c(2) =
SIZE(source, 1)
457 cpassert(c_associated(dict%c_ptr))
458 CALL torch_c_dict_insert_int64 (dict=dict%c_ptr, &
459 key=trim(key)//c_null_char, &
464 cpabort(
"CP2K compiled without the Torch library.")
469 END SUBROUTINE torch_dict_insert_int64_2d
! **************************************************************************************************
!> \brief Retrieves a 2-D INTEGER(int_8) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
475 SUBROUTINE torch_dict_get_int64_2d(dict, key, dest)
476 TYPE(torch_dict_type),
INTENT(IN) :: dict
477 CHARACTER(len=*),
INTENT(IN) :: key
478 INTEGER(kind=int_8),
DIMENSION(:, :),
POINTER :: dest
480 #if defined(__LIBTORCH)
481 INTEGER(kind=int_8),
DIMENSION(2) :: sizes_f, sizes_c
482 TYPE(C_PTR) :: dest_c
485 SUBROUTINE torch_c_dict_get_int64 (dict, key, ndims, sizes, dest) &
486 BIND(C, name="torch_c_dict_get_int64")
487 IMPORT :: c_char, c_ptr, c_int, c_int64_t
488 TYPE(C_PTR),
VALUE :: dict
489 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
490 INTEGER(kind=C_INT),
VALUE :: ndims
491 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
493 END SUBROUTINE torch_c_dict_get_int64
498 cpassert(c_associated(dict%c_ptr))
499 cpassert(.NOT.
ASSOCIATED(dest))
500 CALL torch_c_dict_get_int64 (dict=dict%c_ptr, &
501 key=trim(key)//c_null_char, &
506 cpassert(all(sizes_c >= 0))
507 cpassert(c_associated(dest_c))
! Reverse the C extents back into Fortran order, then associate dest with the C buffer.
509 sizes_f(1) = sizes_c(2)
510 sizes_f(2) = sizes_c(1)
511 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
513 cpabort(
"CP2K compiled without the Torch library.")
518 END SUBROUTINE torch_dict_get_int64_2d
! **************************************************************************************************
!> \brief Inserts a 2-D REAL(dp) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_double
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
525 SUBROUTINE torch_dict_insert_double_2d(dict, key, source)
526 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
527 CHARACTER(len=*),
INTENT(IN) :: key
528 REAL(dp),
CONTIGUOUS,
DIMENSION(:, :),
INTENT(IN) :: source
530 #if defined(__LIBTORCH)
531 INTEGER(kind=int_8),
DIMENSION(2) :: sizes_c
534 SUBROUTINE torch_c_dict_insert_double (dict, key, ndims, sizes, source) &
535 BIND(C, name="torch_c_dict_insert_double")
536 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
537 TYPE(C_PTR),
VALUE :: dict
538 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
539 INTEGER(kind=C_INT),
VALUE :: ndims
540 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
541 REAL(kind=c_double),
DIMENSION(*) :: source
542 END SUBROUTINE torch_c_dict_insert_double
! Extents are stored in reverse order (last Fortran dimension first).
! NOTE(review): presumably this lets the column-major buffer be read as a row-major tensor
! of transposed shape without copying -- confirm against the C implementation.
545 sizes_c(1) =
SIZE(source, 2)
546 sizes_c(2) =
SIZE(source, 1)
548 cpassert(c_associated(dict%c_ptr))
549 CALL torch_c_dict_insert_double (dict=dict%c_ptr, &
550 key=trim(key)//c_null_char, &
555 cpabort(
"CP2K compiled without the Torch library.")
560 END SUBROUTINE torch_dict_insert_double_2d
! **************************************************************************************************
!> \brief Retrieves a 2-D REAL(dp) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
566 SUBROUTINE torch_dict_get_double_2d(dict, key, dest)
567 TYPE(torch_dict_type),
INTENT(IN) :: dict
568 CHARACTER(len=*),
INTENT(IN) :: key
569 REAL(dp),
DIMENSION(:, :),
POINTER :: dest
571 #if defined(__LIBTORCH)
572 INTEGER(kind=int_8),
DIMENSION(2) :: sizes_f, sizes_c
573 TYPE(C_PTR) :: dest_c
576 SUBROUTINE torch_c_dict_get_double (dict, key, ndims, sizes, dest) &
577 BIND(C, name="torch_c_dict_get_double")
578 IMPORT :: c_char, c_ptr, c_int, c_int64_t
579 TYPE(C_PTR),
VALUE :: dict
580 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
581 INTEGER(kind=C_INT),
VALUE :: ndims
582 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
584 END SUBROUTINE torch_c_dict_get_double
589 cpassert(c_associated(dict%c_ptr))
590 cpassert(.NOT.
ASSOCIATED(dest))
591 CALL torch_c_dict_get_double (dict=dict%c_ptr, &
592 key=trim(key)//c_null_char, &
597 cpassert(all(sizes_c >= 0))
598 cpassert(c_associated(dest_c))
! Reverse the C extents back into Fortran order, then associate dest with the C buffer.
600 sizes_f(1) = sizes_c(2)
601 sizes_f(2) = sizes_c(1)
602 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
604 cpabort(
"CP2K compiled without the Torch library.")
609 END SUBROUTINE torch_dict_get_double_2d
! **************************************************************************************************
!> \brief Inserts a 3-D REAL(sp) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_float
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
616 SUBROUTINE torch_dict_insert_float_3d(dict, key, source)
617 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
618 CHARACTER(len=*),
INTENT(IN) :: key
619 REAL(sp),
CONTIGUOUS,
DIMENSION(:, :, :),
INTENT(IN) :: source
621 #if defined(__LIBTORCH)
622 INTEGER(kind=int_8),
DIMENSION(3) :: sizes_c
625 SUBROUTINE torch_c_dict_insert_float (dict, key, ndims, sizes, source) &
626 BIND(C, name="torch_c_dict_insert_float")
627 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
628 TYPE(C_PTR),
VALUE :: dict
629 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
630 INTEGER(kind=C_INT),
VALUE :: ndims
631 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
632 REAL(kind=c_float),
DIMENSION(*) :: source
633 END SUBROUTINE torch_c_dict_insert_float
! Extents are stored in reverse order (last Fortran dimension first).
! NOTE(review): presumably this lets the column-major buffer be read as a row-major tensor
! with reversed shape without copying -- confirm against the C implementation.
636 sizes_c(1) =
SIZE(source, 3)
637 sizes_c(2) =
SIZE(source, 2)
638 sizes_c(3) =
SIZE(source, 1)
640 cpassert(c_associated(dict%c_ptr))
641 CALL torch_c_dict_insert_float (dict=dict%c_ptr, &
642 key=trim(key)//c_null_char, &
647 cpabort(
"CP2K compiled without the Torch library.")
652 END SUBROUTINE torch_dict_insert_float_3d
! **************************************************************************************************
!> \brief Retrieves a 3-D REAL(sp) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
658 SUBROUTINE torch_dict_get_float_3d(dict, key, dest)
659 TYPE(torch_dict_type),
INTENT(IN) :: dict
660 CHARACTER(len=*),
INTENT(IN) :: key
661 REAL(sp),
DIMENSION(:, :, :),
POINTER :: dest
663 #if defined(__LIBTORCH)
664 INTEGER(kind=int_8),
DIMENSION(3) :: sizes_f, sizes_c
665 TYPE(C_PTR) :: dest_c
668 SUBROUTINE torch_c_dict_get_float (dict, key, ndims, sizes, dest) &
669 BIND(C, name="torch_c_dict_get_float")
670 IMPORT :: c_char, c_ptr, c_int, c_int64_t
671 TYPE(C_PTR),
VALUE :: dict
672 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
673 INTEGER(kind=C_INT),
VALUE :: ndims
674 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
676 END SUBROUTINE torch_c_dict_get_float
681 cpassert(c_associated(dict%c_ptr))
682 cpassert(.NOT.
ASSOCIATED(dest))
683 CALL torch_c_dict_get_float (dict=dict%c_ptr, &
684 key=trim(key)//c_null_char, &
689 cpassert(all(sizes_c >= 0))
690 cpassert(c_associated(dest_c))
! Reverse the C extents back into Fortran order, then associate dest with the C buffer.
692 sizes_f(1) = sizes_c(3)
693 sizes_f(2) = sizes_c(2)
694 sizes_f(3) = sizes_c(1)
695 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
697 cpabort(
"CP2K compiled without the Torch library.")
702 END SUBROUTINE torch_dict_get_float_3d
! **************************************************************************************************
!> \brief Inserts a 3-D INTEGER(int_8) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_int64
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
709 SUBROUTINE torch_dict_insert_int64_3d(dict, key, source)
710 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
711 CHARACTER(len=*),
INTENT(IN) :: key
712 INTEGER(kind=int_8),
CONTIGUOUS,
DIMENSION(:, :, :),
INTENT(IN) :: source
714 #if defined(__LIBTORCH)
715 INTEGER(kind=int_8),
DIMENSION(3) :: sizes_c
718 SUBROUTINE torch_c_dict_insert_int64 (dict, key, ndims, sizes, source) &
719 BIND(C, name="torch_c_dict_insert_int64")
720 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
721 TYPE(C_PTR),
VALUE :: dict
722 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
723 INTEGER(kind=C_INT),
VALUE :: ndims
724 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
725 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: source
726 END SUBROUTINE torch_c_dict_insert_int64
! Extents are stored in reverse order (last Fortran dimension first).
! NOTE(review): presumably this lets the column-major buffer be read as a row-major tensor
! with reversed shape without copying -- confirm against the C implementation.
729 sizes_c(1) =
SIZE(source, 3)
730 sizes_c(2) =
SIZE(source, 2)
731 sizes_c(3) =
SIZE(source, 1)
733 cpassert(c_associated(dict%c_ptr))
734 CALL torch_c_dict_insert_int64 (dict=dict%c_ptr, &
735 key=trim(key)//c_null_char, &
740 cpabort(
"CP2K compiled without the Torch library.")
745 END SUBROUTINE torch_dict_insert_int64_3d
! **************************************************************************************************
!> \brief Retrieves a 3-D INTEGER(int_8) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
751 SUBROUTINE torch_dict_get_int64_3d(dict, key, dest)
752 TYPE(torch_dict_type),
INTENT(IN) :: dict
753 CHARACTER(len=*),
INTENT(IN) :: key
754 INTEGER(kind=int_8),
DIMENSION(:, :, :),
POINTER :: dest
756 #if defined(__LIBTORCH)
757 INTEGER(kind=int_8),
DIMENSION(3) :: sizes_f, sizes_c
758 TYPE(C_PTR) :: dest_c
761 SUBROUTINE torch_c_dict_get_int64 (dict, key, ndims, sizes, dest) &
762 BIND(C, name="torch_c_dict_get_int64")
763 IMPORT :: c_char, c_ptr, c_int, c_int64_t
764 TYPE(C_PTR),
VALUE :: dict
765 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
766 INTEGER(kind=C_INT),
VALUE :: ndims
767 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
769 END SUBROUTINE torch_c_dict_get_int64
774 cpassert(c_associated(dict%c_ptr))
775 cpassert(.NOT.
ASSOCIATED(dest))
776 CALL torch_c_dict_get_int64 (dict=dict%c_ptr, &
777 key=trim(key)//c_null_char, &
782 cpassert(all(sizes_c >= 0))
783 cpassert(c_associated(dest_c))
! Reverse the C extents back into Fortran order, then associate dest with the C buffer.
785 sizes_f(1) = sizes_c(3)
786 sizes_f(2) = sizes_c(2)
787 sizes_f(3) = sizes_c(1)
788 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
790 cpabort(
"CP2K compiled without the Torch library.")
795 END SUBROUTINE torch_dict_get_int64_3d
! **************************************************************************************************
!> \brief Inserts a 3-D REAL(dp) array into the Torch dictionary `dict` under `key`.
!> \param dict   dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key    entry name; passed to C as TRIM(key) followed by a null terminator
!> \param source contiguous array, passed by reference to torch_c_dict_insert_double
!> \note Without __LIBTORCH this routine aborts via cpabort.
! **************************************************************************************************
802 SUBROUTINE torch_dict_insert_double_3d(dict, key, source)
803 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
804 CHARACTER(len=*),
INTENT(IN) :: key
805 REAL(dp),
CONTIGUOUS,
DIMENSION(:, :, :),
INTENT(IN) :: source
807 #if defined(__LIBTORCH)
808 INTEGER(kind=int_8),
DIMENSION(3) :: sizes_c
811 SUBROUTINE torch_c_dict_insert_double (dict, key, ndims, sizes, source) &
812 BIND(C, name="torch_c_dict_insert_double")
813 IMPORT :: c_char, c_ptr, c_int, c_int64_t, c_float, c_double
814 TYPE(C_PTR),
VALUE :: dict
815 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
816 INTEGER(kind=C_INT),
VALUE :: ndims
817 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
818 REAL(kind=c_double),
DIMENSION(*) :: source
819 END SUBROUTINE torch_c_dict_insert_double
! Extents are stored in reverse order (last Fortran dimension first).
! NOTE(review): presumably this lets the column-major buffer be read as a row-major tensor
! with reversed shape without copying -- confirm against the C implementation.
822 sizes_c(1) =
SIZE(source, 3)
823 sizes_c(2) =
SIZE(source, 2)
824 sizes_c(3) =
SIZE(source, 1)
826 cpassert(c_associated(dict%c_ptr))
827 CALL torch_c_dict_insert_double (dict=dict%c_ptr, &
828 key=trim(key)//c_null_char, &
833 cpabort(
"CP2K compiled without the Torch library.")
838 END SUBROUTINE torch_dict_insert_double_3d
! **************************************************************************************************
!> \brief Retrieves a 3-D REAL(dp) array stored in the Torch dictionary `dict` under `key`.
!> \param dict  dictionary handle; dict%c_ptr must already be associated (asserted below)
!> \param key   entry name; passed to C as TRIM(key) followed by a null terminator
!> \param dest  must be unassociated on entry (asserted below). On return it is associated with
!>              the data buffer reported by C, after the extents and data pointer are checked.
!> \note Without __LIBTORCH this routine aborts via cpabort.
!> NOTE(review): dest aliases memory handed out by the C side. Its owner and lifetime are not
!>               established here -- confirm before deallocating dest or releasing dict.
! **************************************************************************************************
844 SUBROUTINE torch_dict_get_double_3d(dict, key, dest)
845 TYPE(torch_dict_type),
INTENT(IN) :: dict
846 CHARACTER(len=*),
INTENT(IN) :: key
847 REAL(dp),
DIMENSION(:, :, :),
POINTER :: dest
849 #if defined(__LIBTORCH)
850 INTEGER(kind=int_8),
DIMENSION(3) :: sizes_f, sizes_c
851 TYPE(C_PTR) :: dest_c
854 SUBROUTINE torch_c_dict_get_double (dict, key, ndims, sizes, dest) &
855 BIND(C, name="torch_c_dict_get_double")
856 IMPORT :: c_char, c_ptr, c_int, c_int64_t
857 TYPE(C_PTR),
VALUE :: dict
858 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: key
859 INTEGER(kind=C_INT),
VALUE :: ndims
860 INTEGER(kind=C_INT64_T),
DIMENSION(*) :: sizes
862 END SUBROUTINE torch_c_dict_get_double
867 cpassert(c_associated(dict%c_ptr))
868 cpassert(.NOT.
ASSOCIATED(dest))
869 CALL torch_c_dict_get_double (dict=dict%c_ptr, &
870 key=trim(key)//c_null_char, &
875 cpassert(all(sizes_c >= 0))
876 cpassert(c_associated(dest_c))
! Reverse the C extents back into Fortran order, then associate dest with the C buffer.
878 sizes_f(1) = sizes_c(3)
879 sizes_f(2) = sizes_c(2)
880 sizes_f(3) = sizes_c(1)
881 CALL c_f_pointer(dest_c, dest, shape=sizes_f)
883 cpabort(
"CP2K compiled without the Torch library.")
888 END SUBROUTINE torch_dict_get_double_3d
! torch_dict_create: creates an empty Torch dictionary through torch_c_dict_create.
! dict%c_ptr must be unassociated on entry, and it is asserted to be associated afterwards.
! Without __LIBTORCH this routine aborts via cpabort.
896 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
898 #if defined(__LIBTORCH)
900 SUBROUTINE torch_c_dict_create(dict)
BIND(C, name="torch_c_dict_create")
903 END SUBROUTINE torch_c_dict_create
906 cpassert(.NOT. c_associated(dict%c_ptr))
907 CALL torch_c_dict_create(dict=dict%c_ptr)
908 cpassert(c_associated(dict%c_ptr))
910 cpabort(
"CP2K was compiled without Torch library.")
! torch_dict_release: releases the C-side dictionary through torch_c_dict_release.
! dict%c_ptr must be associated on entry (asserted).
! Without __LIBTORCH this routine aborts via cpabort.
920 TYPE(torch_dict_type),
INTENT(INOUT) :: dict
922 #if defined(__LIBTORCH)
924 SUBROUTINE torch_c_dict_release(dict)
BIND(C, name="torch_c_dict_release")
926 TYPE(c_ptr),
VALUE :: dict
927 END SUBROUTINE torch_c_dict_release
930 cpassert(c_associated(dict%c_ptr))
931 CALL torch_c_dict_release(dict=dict%c_ptr)
! Null the handle so that stale use is caught by the c_associated assertions, and so that
! the handle can be passed to torch_dict_create again.
932 dict%c_ptr = c_null_ptr
934 cpabort(
"CP2K was compiled without Torch library.")
! torch_model_load: loads the model stored in `filename` through torch_c_model_load.
! model%c_ptr must be unassociated on entry, and it is asserted to be associated afterwards.
! The file name is passed to C as TRIM(filename) followed by a null terminator.
! Without __LIBTORCH this routine aborts via cpabort.
944 TYPE(torch_model_type),
INTENT(INOUT) :: model
945 CHARACTER(len=*),
INTENT(IN) :: filename
947 #if defined(__LIBTORCH)
949 SUBROUTINE torch_c_model_load(model, filename)
BIND(C, name="torch_c_model_load")
950 IMPORT :: c_ptr, c_char
952 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: filename
953 END SUBROUTINE torch_c_model_load
956 cpassert(.NOT. c_associated(model%c_ptr))
957 CALL torch_c_model_load(model=model%c_ptr, filename=trim(filename)//c_null_char)
958 cpassert(c_associated(model%c_ptr))
960 cpabort(
"CP2K was compiled without Torch library.")
! torch_model_eval: evaluates `model` on the `inputs` dictionary and stores the results in the
! `outputs` dictionary, through torch_c_model_eval.
! All three handles must already be associated (asserted), so `outputs` has to be created
! beforehand, e.g. with torch_dict_create.
! Without __LIBTORCH this routine aborts via cpabort.
971 TYPE(torch_model_type),
INTENT(INOUT) :: model
972 TYPE(torch_dict_type),
INTENT(IN) :: inputs
973 TYPE(torch_dict_type),
INTENT(INOUT) :: outputs
975 #if defined(__LIBTORCH)
977 SUBROUTINE torch_c_model_eval(model, inputs, outputs)
BIND(C, name="torch_c_model_eval")
979 TYPE(c_ptr),
VALUE :: model
980 TYPE(c_ptr),
VALUE :: inputs
981 TYPE(c_ptr),
VALUE :: outputs
982 END SUBROUTINE torch_c_model_eval
985 cpassert(c_associated(model%c_ptr))
986 cpassert(c_associated(inputs%c_ptr))
987 cpassert(c_associated(outputs%c_ptr))
988 CALL torch_c_model_eval(model=model%c_ptr, &
989 inputs=inputs%c_ptr, &
990 outputs=outputs%c_ptr)
992 cpabort(
"CP2K was compiled without Torch library.")
! torch_model_release: releases the C-side model through torch_c_model_release.
! model%c_ptr must be associated on entry (asserted).
! Without __LIBTORCH this routine aborts via cpabort.
1004 TYPE(torch_model_type),
INTENT(INOUT) :: model
1006 #if defined(__LIBTORCH)
1008 SUBROUTINE torch_c_model_release(model)
BIND(C, name="torch_c_model_release")
1010 TYPE(c_ptr),
VALUE :: model
1011 END SUBROUTINE torch_c_model_release
1014 cpassert(c_associated(model%c_ptr))
1015 CALL torch_c_model_release(model=model%c_ptr)
! Null the handle so that the model can be loaded again with torch_model_load.
1016 model%c_ptr = c_null_ptr
1018 cpabort(
"CP2K was compiled without Torch library.")
! torch_model_read_metadata: returns the metadata entry `key` stored in the model file
! `filename`, as a deferred-length string. Both names are passed to C trimmed and
! null-terminated. C returns a pointer to `length` characters followed by a null terminator.
! Without __LIBTORCH this routine aborts via cpabort.
1028 CHARACTER(len=*),
INTENT(IN) :: filename, key
1029 CHARACTER(:),
ALLOCATABLE :: res
1031 #if defined(__LIBTORCH)
1032 CHARACTER(LEN=1, KIND=C_CHAR),
DIMENSION(:), &
1033 POINTER :: content_f
1036 TYPE(c_ptr) :: content_c
1039 SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
1040 BIND(C, name="torch_c_model_read_metadata")
1041 IMPORT :: c_char, c_ptr, c_int
1042 CHARACTER(kind=C_CHAR),
DIMENSION(*) :: filename, key
1043 TYPE(c_ptr) :: content
1044 INTEGER(kind=C_INT) :: length
1045 END SUBROUTINE torch_c_model_read_metadata
1048 content_c = c_null_ptr
1050 CALL torch_c_model_read_metadata(filename=trim(filename)//c_null_char, &
1051 key=trim(key)//c_null_char, &
1052 content=content_c, &
1054 cpassert(c_associated(content_c))
1055 cpassert(length >= 0)
! View the C string (length characters plus the terminating null) as a Fortran character array.
1057 CALL c_f_pointer(content_c, content_f, shape=(/length + 1/))
1058 cpassert(content_f(length + 1) == c_null_char)
! Copy the characters one by one into the result, rejecting embedded null characters.
1060 ALLOCATE (
CHARACTER(LEN=length) :: res)
1062 cpassert(content_f(i) /= c_null_char)
1063 res(i:i) = content_f(i)
! NOTE(review): content_f was associated by c_f_pointer with memory allocated on the C side,
! not by a Fortran ALLOCATE. The Fortran standard does not permit DEALLOCATE on such a pointer.
! This only works if the compiler's deallocator matches the C allocator (e.g. C used malloc
! and the Fortran runtime calls free). Confirm against torch_c_model_read_metadata, and
! consider a dedicated C-side release routine instead.
1066 DEALLOCATE (content_f)
1068 cpabort(
"CP2K was compiled without Torch library.")
! torch_cuda_is_available: queries torch_c_cuda_is_available for whether Torch's CUDA backend
! is usable. Without __LIBTORCH this routine aborts via cpabort.
1082 #if defined(__LIBTORCH)
1084 FUNCTION torch_c_cuda_is_available()
BIND(C, name="torch_c_cuda_is_available")
1086 LOGICAL(C_BOOL) :: torch_c_cuda_is_available
1087 END FUNCTION torch_c_cuda_is_available
! The LOGICAL(C_BOOL) result is converted to the kind of res by assignment.
! NOTE(review): assumes res is a default-kind LOGICAL -- TODO confirm.
1090 res = torch_c_cuda_is_available()
1092 cpabort(
"CP2K was compiled without Torch library.")
! torch_allow_tf32: forwards the allow_tf32 flag to torch_c_allow_tf32.
! Without __LIBTORCH this routine aborts via cpabort; mark_used then marks the argument as used.
1104 LOGICAL,
INTENT(IN) :: allow_tf32
1106 #if defined(__LIBTORCH)
1108 SUBROUTINE torch_c_allow_tf32(allow_tf32)
BIND(C, name="torch_c_allow_tf32")
1110 LOGICAL(C_BOOL),
VALUE :: allow_tf32
1111 END SUBROUTINE torch_c_allow_tf32
! The C routine takes a C_BOOL by value, so convert the default-kind LOGICAL explicitly.
1114 CALL torch_c_allow_tf32(allow_tf32=
LOGICAL(allow_tf32, c_bool))
1116 cpabort(
"CP2K was compiled without Torch library.")
1117 mark_used(allow_tf32)
! torch_model_freeze: freezes a loaded model in place through torch_c_model_freeze.
! model%c_ptr must be associated (asserted). Without __LIBTORCH this routine aborts via cpabort.
1127 TYPE(torch_model_type),
INTENT(INOUT) :: model
1129 #if defined(__LIBTORCH)
1131 SUBROUTINE torch_c_model_freeze(model)
BIND(C, name="torch_c_model_freeze")
1133 TYPE(c_ptr),
VALUE :: model
1134 END SUBROUTINE torch_c_model_freeze
1137 cpassert(c_associated(model%c_ptr))
1138 CALL torch_c_model_freeze(model=model%c_ptr)
1140 cpabort(
"CP2K was compiled without Torch library.")
Defines the basic variable types.
integer, parameter, public int_8
integer, parameter, public dp
integer, parameter, public sp
subroutine, public torch_dict_release(dict)
Releases a Torch dictionary and all its resources.
subroutine, public torch_model_load(model, filename)
Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
subroutine, public torch_dict_create(dict)
Creates an empty Torch dictionary.
subroutine, public torch_model_release(model)
Releases a Torch model and all its resources.
subroutine, public torch_allow_tf32(allow_tf32)
Set whether to allow the use of TF32. Needed due to changes in defaults from pytorch 1....
subroutine, public torch_model_eval(model, inputs, outputs)
Evaluates the given Torch model. (In Torch lingo this operation is called forward())
subroutine, public torch_model_freeze(model)
Freezes the given Torch model: applies generic optimizations that speed up the model. See https://pytorch....
character(:) function, allocatable, public torch_model_read_metadata(filename, key)
Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
logical function, public torch_cuda_is_available()
Returns true iff the Torch CUDA backend is available.