19 USE dbcsr_api,
ONLY: &
20 dbcsr_type, dbcsr_release, &
21 dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
22 dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
70 #include "../base/base_uses.f90"
74 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'dbt_methods'
!> Copy tensor_in into tensor_out, with optional index permutation (order),
!> accumulation into the output (summation), cropping of the input (bounds)
!> and move-instead-of-copy semantics (move_data).
!> Thin convenience wrapper: synchronizes the tensor's 2D process-grid
!> communicator before and after the operation, accounts the whole call under
!> the "dbt_total" timer, and delegates all actual work to dbt_copy_expert.
!> NOTE(review): this extract omits some original lines (e.g. the local
!> declaration of `handle` and the END statement); comments describe only the
!> code visible here.
113 SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
! Both tensors are INOUT: tensor_in may be emptied when move_data is requested.
114 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_in, tensor_out
! Index permutation, one entry per dimension of tensor_in.
115 INTEGER,
DIMENSION(ndims_tensor(tensor_in)), &
116 INTENT(IN),
OPTIONAL :: order
! summation: add into tensor_out instead of overwriting; move_data: move data.
117 LOGICAL,
INTENT(IN),
OPTIONAL :: summation, move_data
! Per-dimension (lower, upper) crop bounds applied to tensor_in.
118 INTEGER,
DIMENSION(2, ndims_tensor(tensor_in)), &
119 INTENT(IN),
OPTIONAL :: bounds
! Output unit for logging (forwarded to dbt_copy_expert).
120 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
! Barrier on the 2D process grid so the "dbt_total" timer measures only the copy.
123 CALL tensor_in%pgrid%mp_comm_2d%sync()
124 CALL timeset(
"dbt_total", handle)
! All real work happens in the expert routine; arguments are passed through.
130 CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
131 CALL tensor_in%pgrid%mp_comm_2d%sync()
132 CALL timestop(handle)
!> Worker routine for tensor copy.  Visible logic: optionally crop tensor_in
!> (bounds), test block-size compatibility with tensor_out (taking a possible
!> index permutation into account), remap to block-compatible temporaries when
!> needed, then pick the cheapest copy path:
!>   1) dbt_tas_copy      — distributions compatible at the TAS-matrix level,
!>   2) dbt_copy_nocomm   — tensor distributions compatible (no communication),
!>   3) dbt_reshape       — general redistribution (fallback).
!> Finally the temporaries are torn down.
!> NOTE(review): many original lines (ELSE branches, END IFs, some
!> declarations such as in_tmp_3/out_tmp_1/move_prv/found and the END
!> statement) are not visible in this extract; comments only annotate what is
!> shown.
139 SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
140 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_in, tensor_out
141 INTEGER,
DIMENSION(ndims_tensor(tensor_in)), &
142 INTENT(IN),
OPTIONAL :: order
143 LOGICAL,
INTENT(IN),
OPTIONAL :: summation, move_data
144 INTEGER,
DIMENSION(2, ndims_tensor(tensor_in)), &
145 INTENT(IN),
OPTIONAL :: bounds
146 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
! Temporaries for the staged copy (crop -> remap -> permute); pointers so they
! can alias the originals when no transformation is needed.
148 TYPE(dbt_type),
POINTER :: in_tmp_1, in_tmp_2, &
150 INTEGER :: handle, unit_nr_prv
! 2D index maps of input/output used below for the distribution-compatibility test.
151 INTEGER,
DIMENSION(:),
ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2
! Timer label intentionally matches the public wrapper name 'dbt_copy'.
153 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_copy'
154 LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
155 summation_prv, new_in_1, new_in_2, &
156 new_in_3, new_out_1, block_compatible, &
158 TYPE(array_list) :: blk_sizes_in
160 CALL timeset(routinen, handle)
162 cpassert(tensor_out%valid)
166 IF (
PRESENT(move_data))
THEN
! Default: assume no fast path until proven otherwise.
172 dist_compatible_tas = .false.
173 dist_compatible_tensor = .false.
174 block_compatible = .false.
180 IF (
PRESENT(summation))
THEN
181 summation_prv = summation
183 summation_prv = .false.
! Stage 1: crop the input when bounds are given, else alias it.
186 IF (
PRESENT(bounds))
THEN
188 CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
192 in_tmp_1 => tensor_in
! Stage 2: block-size compatibility check; with a permutation the input block
! sizes must be reordered before comparing against the output.
195 IF (
PRESENT(order))
THEN
196 CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
197 block_compatible =
check_equal(blk_sizes_in, tensor_out%blk_sizes)
199 block_compatible =
check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
! Stage 3: if block sizes disagree, remap both sides to compatible temporaries
! (remap call itself not fully visible here); nodata2 skips copying output data
! unless we are accumulating.
202 IF (.NOT. block_compatible)
THEN
203 ALLOCATE (in_tmp_2, out_tmp_1)
205 nodata2=.NOT. summation_prv, move_data=move_prv)
206 new_in_2 = .true.; new_out_1 = .true.
210 out_tmp_1 => tensor_out
! Stage 4: apply the index permutation (index metadata only).
213 IF (
PRESENT(order))
THEN
215 CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
! Stage 5: decide the copy path.  Without a permutation the TAS-level check
! applies; with one, only full tensor-distribution equality (including the
! 2D index maps) allows the no-communication path.
229 IF (.NOT.
PRESENT(order))
THEN
231 dist_compatible_tas =
check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
232 ELSEIF (
array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2]))
THEN
233 dist_compatible_tensor =
check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
237 IF (dist_compatible_tas)
THEN
238 CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
240 ELSEIF (dist_compatible_tensor)
THEN
241 CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
! Fallback: general reshape with communication.
244 CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
! Teardown: the DEALLOCATEs are presumably guarded by the new_in_*/new_out_1
! flags on lines not visible in this extract — confirm against full source.
249 DEALLOCATE (in_tmp_1)
254 DEALLOCATE (in_tmp_2)
259 DEALLOCATE (in_tmp_3)
263 IF (unit_nr_prv /= 0)
THEN
268 DEALLOCATE (out_tmp_1)
271 CALL timestop(handle)
!> Copy tensor_in into tensor_out without inter-process communication: blocks
!> are fetched and stored locally via dbt_get_block/dbt_put_block.  The name
!> and the lack of any communicator calls suggest both tensors must share the
!> same distribution so every block stays on-process — confirm with callers
!> (dbt_copy_expert only calls this after its distribution-equality check).
!> NOTE(review): the block-iteration loop construct (iterator start/next/stop
!> lines) and the END statement are not visible in this extract.
280 SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
281 TYPE(dbt_type),
INTENT(INOUT) :: tensor_in
282 TYPE(dbt_type),
INTENT(INOUT) :: tensor_out
! If .TRUE., accumulate into existing output blocks instead of overwriting.
283 LOGICAL,
INTENT(IN),
OPTIONAL :: summation
284 TYPE(dbt_iterator_type) :: iter
! N-dimensional block index of the block currently being copied.
285 INTEGER,
DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
286 TYPE(block_nd) :: blk_data
289 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_copy_nocomm'
292 CALL timeset(routinen, handle)
293 cpassert(tensor_out%valid)
! Without summation the output is wiped first so stale blocks cannot survive.
295 IF (
PRESENT(summation))
THEN
296 IF (.NOT. summation)
CALL dbt_clear(tensor_out)
! Pre-reserve the output blocks matching the input occupation.
301 CALL dbt_reserve_blocks(tensor_in, tensor_out)
! Per-block copy: fetch locally, then store (optionally summing).
308 CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
310 CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
316 CALL timestop(handle)
!> Copy a DBCSR matrix into a (rank-2) DBT tensor.  Symmetric matrices are
!> first expanded to full storage via dbcsr_desymmetrize; then each DBCSR
!> block is iterated and stored into the tensor with dbt_put_block.
!> NOTE(review): the SUBROUTINE header line itself (original line ~324) is
!> not visible in this extract, nor are several declarations (handle, tr,
!> found) and END lines; the routine name is taken from routinen below.
325 TYPE(dbcsr_type),
TARGET,
INTENT(IN) :: matrix_in
326 TYPE(dbt_type),
INTENT(INOUT) :: tensor_out
! If .TRUE., accumulate into tensor_out instead of overwriting.
327 LOGICAL,
INTENT(IN),
OPTIONAL :: summation
! Points at matrix_in directly, or at a freshly allocated desymmetrized copy.
328 TYPE(dbcsr_type),
POINTER :: matrix_in_desym
330 INTEGER,
DIMENSION(2) :: ind_2d
! Copy buffer: DBCSR hands out a pointer; an allocatable copy is made before
! handing data to dbt_put_block.
331 REAL(kind=
dp),
ALLOCATABLE,
DIMENSION(:, :) :: block_arr
332 REAL(kind=
dp),
DIMENSION(:, :),
POINTER :: block
333 TYPE(dbcsr_iterator_type) :: iter
337 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_copy_matrix_to_tensor'
339 CALL timeset(routinen, handle)
340 cpassert(tensor_out%valid)
! Symmetric input: expand to a full (desymmetrized) working copy.
344 IF (dbcsr_has_symmetry(matrix_in))
THEN
345 ALLOCATE (matrix_in_desym)
346 CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
348 matrix_in_desym => matrix_in
! Without summation, clear the destination first.
351 IF (
PRESENT(summation))
THEN
352 IF (.NOT. summation)
CALL dbt_clear(tensor_out)
357 CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)
! Walk every occupied DBCSR block and store it into the tensor.
361 CALL dbcsr_iterator_start(iter, matrix_in_desym)
362 DO WHILE (dbcsr_iterator_blocks_left(iter))
363 CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
364 CALL allocate_any(block_arr, source=block)
365 CALL dbt_put_block(tensor_out, ind_2d, shape(block_arr), block_arr, summation=summation)
366 DEALLOCATE (block_arr)
368 CALL dbcsr_iterator_stop(iter)
! Release the desymmetrized copy only if one was created above.
371 IF (dbcsr_has_symmetry(matrix_in))
THEN
372 CALL dbcsr_release(matrix_in_desym)
373 DEALLOCATE (matrix_in_desym)
376 CALL timestop(handle)
!> Copy a (rank-2) DBT tensor into a DBCSR matrix.  For a symmetric output
!> matrix, redundant blocks are skipped (checker_tr) and blocks in the lower
!> triangle (ind_2d(1) > ind_2d(2)) are stored transposed into the upper one.
!> NOTE(review): the SUBROUTINE header line (original ~385), the block
!> iteration construct, and several declarations (handle, found) are not
!> visible in this extract; the routine name is taken from routinen below.
386 TYPE(dbt_type),
INTENT(INOUT) :: tensor_in
387 TYPE(dbcsr_type),
INTENT(INOUT) :: matrix_out
! If .TRUE., accumulate into matrix_out instead of overwriting.
388 LOGICAL,
INTENT(IN),
OPTIONAL :: summation
389 TYPE(dbt_iterator_type) :: iter
391 INTEGER,
DIMENSION(2) :: ind_2d
392 REAL(kind=
dp),
DIMENSION(:, :),
ALLOCATABLE :: block
393 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_copy_tensor_to_matrix'
396 CALL timeset(routinen, handle)
! Clear the matrix unless accumulating (the second clear at original line 401
! is presumably the ELSE branch for an absent `summation` — confirm against
! the full source).
398 IF (
PRESENT(summation))
THEN
399 IF (.NOT. summation)
CALL dbcsr_clear(matrix_out)
401 CALL dbcsr_clear(matrix_out)
404 CALL dbt_reserve_blocks(tensor_in, matrix_out)
! Symmetric output: skip the half of the blocks that symmetry makes redundant.
411 IF (dbcsr_has_symmetry(matrix_out) .AND.
checker_tr(ind_2d(1), ind_2d(2))) cycle
413 CALL dbt_get_block(tensor_in, ind_2d, block, found)
! Lower-triangle blocks go into the upper triangle transposed.
416 IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2))
THEN
417 CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), transpose(block), summation=summation)
419 CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
426 CALL timestop(handle)
!> Public tensor contraction driver: tensor_3 = alpha * contract(tensor_1,
!> tensor_2) + beta * tensor_3 (operands and scalars per the argument list
!> below).  Thin wrapper that synchronizes the 2D process-grid communicator
!> before and after, accounts the call under the "dbt_total" timer, and
!> forwards everything to dbt_contract_expert.
!> NOTE(review): the SUBROUTINE header line (original ~492, carrying the name
!> and the alpha/tensor/beta arguments) is not visible in this extract — the
!> lines below are its continuations; the name is inferred from the delegate
!> call and may need confirming against the full source.
493 contract_1, notcontract_1, &
494 contract_2, notcontract_2, &
496 bounds_1, bounds_2, bounds_3, &
497 optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
498 filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
499 REAL(
dp),
INTENT(IN) :: alpha
500 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_1
501 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_2
502 REAL(
dp),
INTENT(IN) :: beta
! Index lists: contract_* are the summed indices of tensors 1/2; map_* place
! the surviving (notcontract) indices into tensor_3.
503 INTEGER,
DIMENSION(:),
INTENT(IN) :: contract_1
504 INTEGER,
DIMENSION(:),
INTENT(IN) :: contract_2
505 INTEGER,
DIMENSION(:),
INTENT(IN) :: map_1
506 INTEGER,
DIMENSION(:),
INTENT(IN) :: map_2
507 INTEGER,
DIMENSION(:),
INTENT(IN) :: notcontract_1
508 INTEGER,
DIMENSION(:),
INTENT(IN) :: notcontract_2
509 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_3
! Optional (lower, upper) bounds per contracted / free index group.
510 INTEGER,
DIMENSION(2, SIZE(contract_1)), &
511 INTENT(IN),
OPTIONAL :: bounds_1
512 INTEGER,
DIMENSION(2, SIZE(notcontract_1)), &
513 INTENT(IN),
OPTIONAL :: bounds_2
514 INTEGER,
DIMENSION(2, SIZE(notcontract_2)), &
515 INTENT(IN),
OPTIONAL :: bounds_3
516 LOGICAL,
INTENT(IN),
OPTIONAL :: optimize_dist
! Optional outputs: suggested process grids for each operand.
517 TYPE(dbt_pgrid_type),
INTENT(OUT), &
518 POINTER,
OPTIONAL :: pgrid_opt_1
519 TYPE(dbt_pgrid_type),
INTENT(OUT), &
520 POINTER,
OPTIONAL :: pgrid_opt_2
521 TYPE(dbt_pgrid_type),
INTENT(OUT), &
522 POINTER,
OPTIONAL :: pgrid_opt_3
! Filtering threshold for the result; flop count is returned when requested.
523 REAL(kind=
dp),
INTENT(IN),
OPTIONAL :: filter_eps
524 INTEGER(KIND=int_8),
INTENT(OUT),
OPTIONAL :: flop
525 LOGICAL,
INTENT(IN),
OPTIONAL :: move_data
526 LOGICAL,
INTENT(IN),
OPTIONAL :: retain_sparsity
527 INTEGER,
OPTIONAL,
INTENT(IN) :: unit_nr
528 LOGICAL,
INTENT(IN),
OPTIONAL :: log_verbose
! Barrier so the "dbt_total" timer covers only the contraction itself.
532 CALL tensor_1%pgrid%mp_comm_2d%sync()
533 CALL timeset(
"dbt_total", handle)
! Pass-through to the expert routine (keyword arguments for the optionals).
534 CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
535 contract_1, notcontract_1, &
536 contract_2, notcontract_2, &
541 optimize_dist=optimize_dist, &
542 pgrid_opt_1=pgrid_opt_1, &
543 pgrid_opt_2=pgrid_opt_2, &
544 pgrid_opt_3=pgrid_opt_3, &
545 filter_eps=filter_eps, &
547 move_data=move_data, &
548 retain_sparsity=retain_sparsity, &
550 log_verbose=log_verbose)
551 CALL tensor_1%pgrid%mp_comm_2d%sync()
552 CALL timestop(handle)
!> Expert tensor contraction driver.  Visible phases:
!>   1. sanity assertions on the contraction index lists,
!>   2. crop operands to the requested bounds,
!>   3. early exit when either cropped operand has no occupied blocks,
!>   4. assign single-letter index names and align tensor indices with data,
!>   5. pick which two tensors are treated as "large" (select on max_mm_dim)
!>      and reshape all three into matrix-multiplication-compatible form,
!>   6. run the TAS matrix multiplication on the matrix representations,
!>   7. optionally report optimal process grids (pgrid_opt_*),
!>   8. for batched contractions, decide whether to switch process grids,
!>   9. scale/accumulate the result back (beta, summation), filter, copy
!>      contraction storage, tear down temporaries, and optionally change
!>      the 2D process grids of the involved tensors.
!> NOTE(review): this extract drops many original lines (ELSE/END IF/END DO,
!> CASE labels, several declarations and the END statement); comments below
!> only annotate the visible statements.
561 SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
562 contract_1, notcontract_1, &
563 contract_2, notcontract_2, &
565 bounds_1, bounds_2, bounds_3, &
566 optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
567 filter_eps, flop, move_data, retain_sparsity, &
568 nblks_local, unit_nr, log_verbose)
569 REAL(
dp),
INTENT(IN) :: alpha
570 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_1
571 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_2
572 REAL(
dp),
INTENT(IN) :: beta
! Index lists: contract_* are summed; map_* place free indices into tensor_3.
573 INTEGER,
DIMENSION(:),
INTENT(IN) :: contract_1
574 INTEGER,
DIMENSION(:),
INTENT(IN) :: contract_2
575 INTEGER,
DIMENSION(:),
INTENT(IN) :: map_1
576 INTEGER,
DIMENSION(:),
INTENT(IN) :: map_2
577 INTEGER,
DIMENSION(:),
INTENT(IN) :: notcontract_1
578 INTEGER,
DIMENSION(:),
INTENT(IN) :: notcontract_2
579 TYPE(dbt_type),
INTENT(INOUT),
TARGET :: tensor_3
580 INTEGER,
DIMENSION(2, SIZE(contract_1)), &
581 INTENT(IN),
OPTIONAL :: bounds_1
582 INTEGER,
DIMENSION(2, SIZE(notcontract_1)), &
583 INTENT(IN),
OPTIONAL :: bounds_2
584 INTEGER,
DIMENSION(2, SIZE(notcontract_2)), &
585 INTENT(IN),
OPTIONAL :: bounds_3
586 LOGICAL,
INTENT(IN),
OPTIONAL :: optimize_dist
587 TYPE(dbt_pgrid_type),
INTENT(OUT), &
588 POINTER,
OPTIONAL :: pgrid_opt_1
589 TYPE(dbt_pgrid_type),
INTENT(OUT), &
590 POINTER,
OPTIONAL :: pgrid_opt_2
591 TYPE(dbt_pgrid_type),
INTENT(OUT), &
592 POINTER,
OPTIONAL :: pgrid_opt_3
593 REAL(kind=
dp),
INTENT(IN),
OPTIONAL :: filter_eps
594 INTEGER(KIND=int_8),
INTENT(OUT),
OPTIONAL :: flop
595 LOGICAL,
INTENT(IN),
OPTIONAL :: move_data
596 LOGICAL,
INTENT(IN),
OPTIONAL :: retain_sparsity
597 INTEGER,
INTENT(OUT),
OPTIONAL :: nblks_local
598 INTEGER,
OPTIONAL,
INTENT(IN) :: unit_nr
599 LOGICAL,
INTENT(IN),
OPTIONAL :: log_verbose
! Working tensors: *_contr_* are the mm-compatible forms, *_algn_* the
! index-aligned forms, *_crop_* the cropped operands.
601 TYPE(dbt_type),
POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3
602 TYPE(dbt_type),
TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3
603 TYPE(dbt_type),
POINTER :: tensor_crop_1, tensor_crop_2
604 TYPE(dbt_type),
POINTER :: tensor_small, tensor_large
606 LOGICAL :: assert_stmt, tensors_remapped
607 INTEGER :: max_mm_dim, max_tensor, &
608 unit_nr_prv, ref_tensor, handle
609 TYPE(mp_cart_type) :: mp_comm_opt
! *_mod arrays are the index lists rewritten in the aligned index numbering.
610 INTEGER,
DIMENSION(SIZE(contract_1)) :: contract_1_mod
611 INTEGER,
DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod
612 INTEGER,
DIMENSION(SIZE(contract_2)) :: contract_2_mod
613 INTEGER,
DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod
614 INTEGER,
DIMENSION(SIZE(map_1)) :: map_1_mod
615 INTEGER,
DIMENSION(SIZE(map_2)) :: map_2_mod
616 LOGICAL :: trans_1, trans_2, trans_3
617 LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2
618 INTEGER :: ndims1, ndims2, ndims3
619 INTEGER :: occ_1, occ_2
620 INTEGER,
DIMENSION(:),
ALLOCATABLE :: dims1, dims2, dims3
! Timer label intentionally matches the public wrapper name 'dbt_contract'.
622 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_contract'
! Single-letter index names for the log output (einsum-like notation).
623 CHARACTER(LEN=1),
DIMENSION(:),
ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, &
624 indchar2_mod, indchar3_mod
625 CHARACTER(LEN=1),
DIMENSION(15),
SAVE :: alph = &
626 [
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o']
627 INTEGER,
DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
628 INTEGER,
DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
629 LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
630 pgrid_changed_any, do_change_pgrid(2)
631 TYPE(dbt_tas_split_info) :: split_opt, split, split_opt_avg
632 INTEGER,
DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
633 REAL(
dp) :: pdim_ratio, pdim_ratio_opt
635 NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
638 CALL timeset(routinen, handle)
! Phase 1: validity and index-bookkeeping assertions — the contracted index
! lists must pair up and the index lists must exactly cover each tensor.
640 cpassert(tensor_1%valid)
641 cpassert(tensor_2%valid)
642 cpassert(tensor_3%valid)
644 assert_stmt =
SIZE(contract_1) .EQ.
SIZE(contract_2)
645 cpassert(assert_stmt)
647 assert_stmt =
SIZE(map_1) .EQ.
SIZE(notcontract_1)
648 cpassert(assert_stmt)
650 assert_stmt =
SIZE(map_2) .EQ.
SIZE(notcontract_2)
651 cpassert(assert_stmt)
653 assert_stmt =
SIZE(notcontract_1) +
SIZE(contract_1) .EQ.
ndims_tensor(tensor_1)
654 cpassert(assert_stmt)
656 assert_stmt =
SIZE(notcontract_2) +
SIZE(contract_2) .EQ.
ndims_tensor(tensor_2)
657 cpassert(assert_stmt)
659 assert_stmt =
SIZE(map_1) +
SIZE(map_2) .EQ.
ndims_tensor(tensor_3)
660 cpassert(assert_stmt)
! Initialize optional outputs so they are defined on every exit path.
664 IF (
PRESENT(flop)) flop = 0
665 IF (
PRESENT(nblks_local)) nblks_local = 0
667 IF (
PRESENT(move_data))
THEN
668 move_data_1 = move_data
669 move_data_2 = move_data
671 move_data_1 = .false.
672 move_data_2 = .false.
! retain_sparsity keeps the existing sparsity pattern of tensor_3.
676 IF (
PRESENT(retain_sparsity))
THEN
677 IF (retain_sparsity) nodata_3 = .false.
! Phase 2: translate the per-index-group bounds into per-tensor crop bounds.
680 CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
681 contract_1, notcontract_1, &
682 contract_2, notcontract_2, &
683 bounds_t1, bounds_t2, &
684 bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
685 do_crop_1=do_crop_1, do_crop_2=do_crop_2)
688 ALLOCATE (tensor_crop_1)
689 CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
692 tensor_crop_1 => tensor_1
696 ALLOCATE (tensor_crop_2)
697 CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
700 tensor_crop_2 => tensor_2
! Phase 3: global occupation counts; if either operand is empty everywhere,
! clean up and return early.
706 associate(mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
708 CALL mp_comm%max(occ_1)
710 CALL mp_comm%max(occ_2)
713 IF (occ_1 == 0 .OR. occ_2 == 0)
THEN
717 DEALLOCATE (tensor_crop_1)
721 DEALLOCATE (tensor_crop_2)
724 CALL timestop(handle)
! Logging header (unit_nr_prv > 0 means this rank writes).
728 IF (unit_nr_prv /= 0)
THEN
729 IF (unit_nr_prv > 0)
THEN
730 WRITE (unit_nr_prv,
'(A)') repeat(
"-", 80)
731 WRITE (unit_nr_prv,
'(A,1X,A,1X,A,1X,A,1X,A,1X,A)')
"DBT TENSOR CONTRACTION:", &
732 trim(tensor_crop_1%name),
'x', trim(tensor_crop_2%name),
'=', trim(tensor_3%name)
733 WRITE (unit_nr_prv,
'(A)') repeat(
"-", 80)
! Phase 4: assign letters to indices; shared contraction indices reuse the
! same letter, and tensor_3 inherits the letters of the surviving indices.
745 ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
746 ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
747 ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
751 indchar1([notcontract_1, contract_1]) = alph(1:ndims1)
752 indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 +
SIZE(notcontract_2))
753 indchar2(contract_2) = indchar1(contract_1)
754 indchar3(map_1) = indchar1(notcontract_1)
755 indchar3(map_2) = indchar2(notcontract_2)
757 IF (unit_nr_prv /= 0)
CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
758 tensor_crop_2, indchar2, &
759 tensor_3, indchar3, unit_nr_prv)
760 IF (unit_nr_prv > 0)
THEN
761 WRITE (unit_nr_prv,
'(T2,A)')
"aligning tensor index with data"
! Align each tensor's logical index order with its storage order; the *_mod
! index lists are rewritten accordingly.
764 CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
765 tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
767 CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
768 tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
770 CALL align_tensor(tensor_3, map_1, map_2, &
771 tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
773 IF (unit_nr_prv /= 0)
CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
774 tensor_algn_2, indchar2_mod, &
775 tensor_algn_3, indchar3_mod, unit_nr_prv)
777 ALLOCATE (dims1(ndims1))
778 ALLOCATE (dims2(ndims2))
779 ALLOCATE (dims3(ndims3))
! Phase 5: decide the mm decomposition.  max_mm_dim selects the largest of the
! three mm dimensions (m, k, n); max_tensor the largest tensor by total size.
787 max_mm_dim = maxloc([product(int(dims1(notcontract_1),
int_8)), &
788 product(int(dims1(contract_1),
int_8)), &
789 product(int(dims2(notcontract_2),
int_8))], dim=1)
790 max_tensor = maxloc([product(int(dims1,
int_8)), product(int(dims2,
int_8)), product(int(dims3,
int_8))], dim=1)
791 SELECT CASE (max_mm_dim)
! Case: tensors 1 and 3 are treated as large, tensor 2 as small.
793 IF (unit_nr_prv > 0)
THEN
794 WRITE (unit_nr_prv,
'(T2,A)')
"large tensors: 1, 3; small tensor: 2"
795 WRITE (unit_nr_prv,
'(T2,A)')
"sorting contraction indices"
797 CALL index_linked_sort(contract_1_mod, contract_2_mod)
798 CALL index_linked_sort(map_2_mod, notcontract_2_mod)
799 SELECT CASE (max_tensor)
801 CALL index_linked_sort(notcontract_1_mod, map_1_mod)
803 CALL index_linked_sort(map_1_mod, notcontract_1_mod)
805 cpabort(
"should not happen")
! The two large tensors are remapped to a shared mm-compatible distribution;
! the small one is merely reshaped.
808 CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
809 contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
810 trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
811 move_data_1=move_data_1, unit_nr=unit_nr_prv)
813 CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
814 new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
816 SELECT CASE (ref_tensor)
818 tensor_large => tensor_contr_1
820 tensor_large => tensor_contr_3
822 tensor_small => tensor_contr_2
! Case: tensors 1 and 2 large, tensor 3 small.
825 IF (unit_nr_prv > 0)
THEN
826 WRITE (unit_nr_prv,
'(T2,A)')
"large tensors: 1, 2; small tensor: 3"
827 WRITE (unit_nr_prv,
'(T2,A)')
"sorting contraction indices"
830 CALL index_linked_sort(notcontract_1_mod, map_1_mod)
831 CALL index_linked_sort(notcontract_2_mod, map_2_mod)
832 SELECT CASE (max_tensor)
834 CALL index_linked_sort(contract_1_mod, contract_2_mod)
836 CALL index_linked_sort(contract_2_mod, contract_1_mod)
838 cpabort(
"should not happen")
841 CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
842 notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
843 trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
844 move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
! Transpose flag flipped because operand roles are swapped in this layout.
845 trans_1 = .NOT. trans_1
847 CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
848 new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
850 SELECT CASE (ref_tensor)
852 tensor_large => tensor_contr_1
854 tensor_large => tensor_contr_2
856 tensor_small => tensor_contr_3
! Case: tensors 2 and 3 large, tensor 1 small.
859 IF (unit_nr_prv > 0)
THEN
860 WRITE (unit_nr_prv,
'(T2,A)')
"large tensors: 2, 3; small tensor: 1"
861 WRITE (unit_nr_prv,
'(T2,A)')
"sorting contraction indices"
863 CALL index_linked_sort(map_1_mod, notcontract_1_mod)
864 CALL index_linked_sort(contract_2_mod, contract_1_mod)
865 SELECT CASE (max_tensor)
867 CALL index_linked_sort(notcontract_2_mod, map_2_mod)
869 CALL index_linked_sort(map_2_mod, notcontract_2_mod)
871 cpabort(
"should not happen")
874 CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
875 contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
876 trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
877 move_data_1=move_data_2, unit_nr=unit_nr_prv)
879 trans_2 = .NOT. trans_2
880 trans_3 = .NOT. trans_3
882 CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
883 trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
885 SELECT CASE (ref_tensor)
887 tensor_large => tensor_contr_2
889 tensor_large => tensor_contr_3
891 tensor_small => tensor_contr_1
895 IF (unit_nr_prv /= 0)
CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
896 tensor_contr_2, indchar2_mod, &
897 tensor_contr_3, indchar3_mod, unit_nr_prv)
898 IF (unit_nr_prv /= 0)
THEN
! Phase 6: the actual multiplication on the TAS matrix representations (the
! call statement's first line, original ~905, is not visible in this extract).
906 tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
908 tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
909 unit_nr=unit_nr_prv, log_verbose=log_verbose, &
910 split_opt=split_opt, &
911 move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
! Phase 7: report optimal process grids, but only for tensors that were NOT
! remapped (for remapped ones the current grid is already the optimized one).
913 IF (
PRESENT(pgrid_opt_1))
THEN
914 IF (.NOT. new_1)
THEN
915 ALLOCATE (pgrid_opt_1)
916 pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
920 IF (
PRESENT(pgrid_opt_2))
THEN
921 IF (.NOT. new_2)
THEN
922 ALLOCATE (pgrid_opt_2)
923 pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
927 IF (
PRESENT(pgrid_opt_3))
THEN
928 IF (.NOT. new_3)
THEN
929 ALLOCATE (pgrid_opt_3)
930 pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
! Phase 8: batched-contraction bookkeeping — internal pgrid optimization is
! only attempted when no tensor had to be remapped for this contraction.
934 do_batched = tensor_small%matrix_rep%do_batched > 0
936 tensors_remapped = .false.
937 IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .true.
939 IF (tensors_remapped .AND. do_batched)
THEN
940 CALL cp_warn(__location__, &
941 "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
945 do_change_pgrid(:) = .false.
946 IF ((.NOT. tensors_remapped) .AND. do_batched)
THEN
947 associate(storage => tensor_small%contraction_storage)
948 cpassert(storage%static)
950 do_change_pgrid(:) = &
951 update_contraction_storage(storage, split_opt, split)
953 IF (any(do_change_pgrid))
THEN
954 mp_comm_opt =
dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, nint(storage%nsplit_avg))
956 nint(storage%nsplit_avg), own_comm=.true.)
957 pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
! Keep the current grid if its subgroup aspect ratio is already close enough
! to the optimal one (within default_pdims_accept_ratio**2).
962 IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2))
THEN
964 pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
965 pdims_sub = split%mp_comm_group%num_pe_cart
967 pdim_ratio = maxval(real(pdims_sub, dp))/minval(pdims_sub)
968 pdim_ratio_opt = maxval(real(pdims_sub_opt, dp))/minval(pdims_sub_opt)
969 IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2)
THEN
970 do_change_pgrid(1) = .false.
971 CALL dbt_tas_release_info(split_opt_avg)
! Result logging; suppressed while a batched multiplication is still open.
976 IF (unit_nr_prv /= 0)
THEN
978 IF (tensor_contr_3%matrix_rep%do_batched > 0)
THEN
979 IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .false.
982 CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
983 CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
! Phase 9: tensor_3 := beta*tensor_3 + result (the mm result is accumulated
! via a summation copy), then optional filtering.
989 CALL dbt_scale(tensor_algn_3, beta)
990 CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.true., move_data=.true.)
991 IF (
PRESENT(filter_eps))
CALL dbt_filter(tensor_algn_3, filter_eps)
! Preserve contraction storage across the temporary tensors.
997 CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
998 CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
999 CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
1001 IF (unit_nr_prv /= 0)
THEN
1002 IF (new_3 .AND. do_write_3)
CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
1003 IF (new_3 .AND. do_write_3)
CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
! Teardown of all temporaries (DEALLOCATEs presumably guarded by the
! do_crop_*/new_* flags on lines not visible here — confirm in full source).
1006 CALL dbt_destroy(tensor_algn_1)
1007 CALL dbt_destroy(tensor_algn_2)
1008 CALL dbt_destroy(tensor_algn_3)
1011 CALL dbt_destroy(tensor_crop_1)
1012 DEALLOCATE (tensor_crop_1)
1016 CALL dbt_destroy(tensor_crop_2)
1017 DEALLOCATE (tensor_crop_2)
1021 CALL dbt_destroy(tensor_contr_1)
1022 DEALLOCATE (tensor_contr_1)
1025 CALL dbt_destroy(tensor_contr_2)
1026 DEALLOCATE (tensor_contr_2)
1029 CALL dbt_destroy(tensor_contr_3)
1030 DEALLOCATE (tensor_contr_3)
! move_data semantics: the inputs are emptied after the contraction.
1033 IF (
PRESENT(move_data))
THEN
1035 CALL dbt_clear(tensor_1)
1036 CALL dbt_clear(tensor_2)
1040 IF (unit_nr_prv > 0)
THEN
1041 WRITE (unit_nr_prv,
'(A)') repeat(
"-", 80)
1042 WRITE (unit_nr_prv,
'(A)')
"TENSOR CONTRACTION DONE"
1043 WRITE (unit_nr_prv,
'(A)') repeat(
"-", 80)
! Optional pgrid switch for batched contractions: change the 2D process grid
! of the two "large" tensors (per max_mm_dim case) and, if any grid actually
! changed, flush the still-open batched multiplication of the third tensor.
1046 IF (any(do_change_pgrid))
THEN
1047 pgrid_changed_any = .false.
1048 SELECT CASE (max_mm_dim)
1050 IF (
ALLOCATED(tensor_1%contraction_storage) .AND.
ALLOCATED(tensor_3%contraction_storage))
THEN
1051 CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1052 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1053 pgrid_changed=pgrid_changed, &
1054 unit_nr=unit_nr_prv)
1055 IF (pgrid_changed) pgrid_changed_any = .true.
1056 CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1057 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1058 pgrid_changed=pgrid_changed, &
1059 unit_nr=unit_nr_prv)
1060 IF (pgrid_changed) pgrid_changed_any = .true.
1062 IF (pgrid_changed_any)
THEN
1063 IF (tensor_2%matrix_rep%do_batched == 3)
THEN
1066 CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
1070 IF (
ALLOCATED(tensor_1%contraction_storage) .AND.
ALLOCATED(tensor_2%contraction_storage))
THEN
1071 CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1072 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1073 pgrid_changed=pgrid_changed, &
1074 unit_nr=unit_nr_prv)
1075 IF (pgrid_changed) pgrid_changed_any = .true.
1076 CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1077 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1078 pgrid_changed=pgrid_changed, &
1079 unit_nr=unit_nr_prv)
1080 IF (pgrid_changed) pgrid_changed_any = .true.
1082 IF (pgrid_changed_any)
THEN
1083 IF (tensor_3%matrix_rep%do_batched == 3)
THEN
1084 CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
1088 IF (
ALLOCATED(tensor_2%contraction_storage) .AND.
ALLOCATED(tensor_3%contraction_storage))
THEN
1089 CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1090 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1091 pgrid_changed=pgrid_changed, &
1092 unit_nr=unit_nr_prv)
1093 IF (pgrid_changed) pgrid_changed_any = .true.
1094 CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1095 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1096 pgrid_changed=pgrid_changed, &
1097 unit_nr=unit_nr_prv)
1098 IF (pgrid_changed) pgrid_changed_any = .true.
1100 IF (pgrid_changed_any)
THEN
1101 IF (tensor_1%matrix_rep%do_batched == 3)
THEN
1102 CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
1106 CALL dbt_tas_release_info(split_opt_avg)
! Mark the batched state as running on an optimized grid.
1109 IF ((.NOT. tensors_remapped) .AND. do_batched)
THEN
1111 CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.true.)
1112 CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.true.)
1113 CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.true.)
1116 CALL dbt_tas_release_info(split_opt)
1118 CALL timestop(handle)
!> Align a tensor's logical index order with its storage order.
!> dbt_align_index produces the aligned tensor plus the permutation `align`;
!> the contract/notcontract index lists and the per-index name characters are
!> then rewritten into the aligned numbering so callers can keep referring to
!> the same logical indices.
!> NOTE(review): the END statement is not visible in this extract.
1126 SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
1127 tensor_out, contract_out, notcontract_out, indp_in, indp_out)
1128 TYPE(dbt_type),
INTENT(INOUT) :: tensor_in
1129 INTEGER,
DIMENSION(:),
INTENT(IN) :: contract_in, notcontract_in
1130 TYPE(dbt_type),
INTENT(OUT) :: tensor_out
1131 INTEGER,
DIMENSION(SIZE(contract_in)), &
1132 INTENT(OUT) :: contract_out
1133 INTEGER,
DIMENSION(SIZE(notcontract_in)), &
1134 INTENT(OUT) :: notcontract_out
! Single-letter index names, remapped alongside the index lists.
1135 CHARACTER(LEN=1),
DIMENSION(ndims_tensor(tensor_in)),
INTENT(IN) :: indp_in
1136 CHARACTER(LEN=1),
DIMENSION(ndims_tensor(tensor_in)),
INTENT(OUT) :: indp_out
! Permutation returned by dbt_align_index: position i of the input maps to
! align(i) in the aligned tensor.
1137 INTEGER,
DIMENSION(ndims_tensor(tensor_in)) :: align
1139 CALL dbt_align_index(tensor_in, tensor_out, order=align)
! Rewrite the index lists and index names into the aligned numbering.
1140 contract_out = align(contract_in)
1141 notcontract_out = align(notcontract_in)
1142 indp_out(align) = indp_in
1165 SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
1166 ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
1167 nodata1, nodata2, move_data_1, &
1168 move_data_2, optimize_dist, unit_nr)
1169 TYPE(dbt_type),
TARGET,
INTENT(INOUT) :: tensor1
1170 TYPE(dbt_type),
TARGET,
INTENT(INOUT) :: tensor2
1171 TYPE(dbt_type),
POINTER,
INTENT(OUT) :: tensor1_out, tensor2_out
1172 INTEGER,
DIMENSION(:),
INTENT(IN) :: ind1_free, ind2_free
1173 INTEGER,
DIMENSION(:),
INTENT(IN) :: ind1_linked, ind2_linked
1174 LOGICAL,
INTENT(OUT) :: trans1, trans2
1175 LOGICAL,
INTENT(OUT) :: new1, new2
1176 INTEGER,
INTENT(OUT) :: ref_tensor
1177 LOGICAL,
INTENT(IN),
OPTIONAL :: nodata1, nodata2
1178 LOGICAL,
INTENT(INOUT),
OPTIONAL :: move_data_1, move_data_2
1179 LOGICAL,
INTENT(IN),
OPTIONAL :: optimize_dist
1180 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
1181 INTEGER :: compat1, compat1_old, compat2, compat2_old, &
1183 TYPE(mp_cart_type) :: comm_2d
1184 TYPE(array_list) :: dist_list
1185 INTEGER,
DIMENSION(:),
ALLOCATABLE :: mp_dims
1186 TYPE(dbt_distribution_type) :: dist_in
1187 INTEGER(KIND=int_8) :: nblkrows, nblkcols
1188 LOGICAL :: optimize_dist_prv
1189 INTEGER,
DIMENSION(ndims_tensor(tensor1)) :: dims1
1190 INTEGER,
DIMENSION(ndims_tensor(tensor2)) :: dims2
1192 NULLIFY (tensor1_out, tensor2_out)
1194 unit_nr_prv = prep_output_unit(unit_nr)
1196 CALL blk_dims_tensor(tensor1, dims1)
1197 CALL blk_dims_tensor(tensor2, dims2)
1199 IF (product(int(dims1, int_8)) .GE. product(int(dims2, int_8)))
THEN
1205 IF (
PRESENT(optimize_dist))
THEN
1206 optimize_dist_prv = optimize_dist
1208 optimize_dist_prv = .false.
1211 compat1 = compat_map(tensor1%nd_index, ind1_linked)
1212 compat2 = compat_map(tensor2%nd_index, ind2_linked)
1213 compat1_old = compat1
1214 compat2_old = compat2
1216 IF (unit_nr_prv > 0)
THEN
1217 WRITE (unit_nr_prv,
'(T2,A,1X,A,A,1X)', advance=
'no')
"compatibility of", trim(tensor1%name),
":"
1218 SELECT CASE (compat1)
1220 WRITE (unit_nr_prv,
'(A)')
"Not compatible"
1222 WRITE (unit_nr_prv,
'(A)')
"Normal"
1224 WRITE (unit_nr_prv,
'(A)')
"Transposed"
1226 WRITE (unit_nr_prv,
'(T2,A,1X,A,A,1X)', advance=
'no')
"compatibility of", trim(tensor2%name),
":"
1227 SELECT CASE (compat2)
1229 WRITE (unit_nr_prv,
'(A)')
"Not compatible"
1231 WRITE (unit_nr_prv,
'(A)')
"Normal"
1233 WRITE (unit_nr_prv,
'(A)')
"Transposed"
1240 IF (compat1 == 0 .OR. optimize_dist_prv)
THEN
1244 IF (compat2 == 0 .OR. optimize_dist_prv)
THEN
1248 IF (ref_tensor == 1)
THEN
1249 IF (compat1 == 0 .OR. optimize_dist_prv)
THEN
1250 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A)')
"Redistribution of", trim(tensor1%name)
1251 nblkrows = product(int(dims1(ind1_linked), kind=int_8))
1252 nblkcols = product(int(dims1(ind1_free), kind=int_8))
1253 comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
1254 ALLOCATE (tensor1_out)
1255 CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
1256 nodata=nodata1, move_data=move_data_1)
1260 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A)')
"No redistribution of", trim(tensor1%name)
1261 tensor1_out => tensor1
1263 IF (compat2 == 0 .OR. optimize_dist_prv)
THEN
1264 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A,1X,A,1X,A)')
"Redistribution of", &
1265 trim(tensor2%name),
"compatible with", trim(tensor1%name)
1266 dist_in = dbt_distribution(tensor1_out)
1267 dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
1268 IF (compat1 == 1)
THEN
1271 ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1272 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1273 ALLOCATE (tensor2_out)
1274 CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1275 dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
1276 ELSEIF (compat1 == 2)
THEN
1279 ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1280 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1281 ALLOCATE (tensor2_out)
1282 CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1283 dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
1285 cpabort(
"should not happen")
1289 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A)')
"No redistribution of", trim(tensor2%name)
1290 tensor2_out => tensor2
1293 IF (compat2 == 0 .OR. optimize_dist_prv)
THEN
1294 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A)')
"Redistribution of", trim(tensor2%name)
1295 nblkrows = product(int(dims2(ind2_linked), kind=int_8))
1296 nblkcols = product(int(dims2(ind2_free), kind=int_8))
1297 comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
1298 ALLOCATE (tensor2_out)
1299 CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2)
1303 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A)')
"No redistribution of", trim(tensor2%name)
1304 tensor2_out => tensor2
1306 IF (compat1 == 0 .OR. optimize_dist_prv)
THEN
1307 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A,1X,A,1X,A)')
"Redistribution of", trim(tensor1%name), &
1308 "compatible with", trim(tensor2%name)
1309 dist_in = dbt_distribution(tensor2_out)
1310 dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
1311 IF (compat2 == 1)
THEN
1312 ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1313 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1314 ALLOCATE (tensor1_out)
1315 CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1316 dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
1317 ELSEIF (compat2 == 2)
THEN
1318 ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1319 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1320 ALLOCATE (tensor1_out)
1321 CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1322 dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
1324 cpabort(
"should not happen")
1328 IF (unit_nr_prv > 0)
WRITE (unit_nr_prv,
'(T2,A,1X,A)')
"No redistribution of", trim(tensor1%name)
1329 tensor1_out => tensor1
1333 SELECT CASE (compat1)
1339 cpabort(
"should not happen")
1342 SELECT CASE (compat2)
1348 cpabort(
"should not happen")
1351 IF (unit_nr_prv > 0)
THEN
1352 IF (compat1 .NE. compat1_old)
THEN
1353 WRITE (unit_nr_prv,
'(T2,A,1X,A,A,1X)', advance=
'no')
"compatibility of", trim(tensor1_out%name),
":"
1354 SELECT CASE (compat1)
1356 WRITE (unit_nr_prv,
'(A)')
"Not compatible"
1358 WRITE (unit_nr_prv,
'(A)')
"Normal"
1360 WRITE (unit_nr_prv,
'(A)')
"Transposed"
1363 IF (compat2 .NE. compat2_old)
THEN
1364 WRITE (unit_nr_prv,
'(T2,A,1X,A,A,1X)', advance=
'no')
"compatibility of", trim(tensor2_out%name),
":"
1365 SELECT CASE (compat2)
1367 WRITE (unit_nr_prv,
'(A)')
"Not compatible"
1369 WRITE (unit_nr_prv,
'(A)')
"Normal"
1371 WRITE (unit_nr_prv,
'(A)')
"Transposed"
1376 IF (new1 .AND.
PRESENT(move_data_1)) move_data_1 = .true.
1377 IF (new2 .AND.
PRESENT(move_data_2)) move_data_2 = .true.
!> \brief Prepare a tensor for contraction as the "small" operand: verify that its
!>        2d (matrix) index mapping matches the requested row index set ind1 and
!>        column index set ind2, and remap it if not.
!> \param tensor_in  tensor to prepare
!> \param ind1       tensor dimensions that must map to matrix rows
!> \param ind2       tensor dimensions that must map to matrix columns
!> \param tensor_out points either at tensor_in (already compatible) or at a
!>                   newly allocated remapped tensor
!> \param trans      whether the resulting matrix representation is transposed
!>                   w.r.t. the (ind1, ind2) convention
!> \param new        whether tensor_out is a new allocation the caller must free
!> \param nodata     if true, do not copy block data on remap
!> \param move_data  passed through to dbt_remap (move instead of copy)
!> \param unit_nr    output unit for verbose logging (<= 0 disables)
   SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
      TYPE(dbt_type), TARGET, INTENT(INOUT)   :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN)       :: ind1, ind2
      TYPE(dbt_type), POINTER, INTENT(OUT)    :: tensor_out
      LOGICAL, INTENT(OUT)                    :: trans
      LOGICAL, INTENT(OUT)                    :: new
      LOGICAL, INTENT(IN), OPTIONAL           :: nodata, move_data
      INTEGER, INTENT(IN), OPTIONAL           :: unit_nr
      INTEGER                                 :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
      LOGICAL                                 :: nodata_prv

      NULLIFY (tensor_out)
      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .false.
      END IF

      unit_nr_prv = prep_output_unit(unit_nr)

      new = .false.
      ! compat = 1: index set maps to matrix rows, 2: to matrix columns, 0: neither
      compat1 = compat_map(tensor_in%nd_index, ind1)
      compat2 = compat_map(tensor_in%nd_index, ind2)
      compat1_old = compat1; compat2_old = compat2
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_in%name), ":"
         IF (compat1 == 1 .AND. compat2 == 2) THEN
            WRITE (unit_nr_prv, '(A)') "Normal"
         ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
            WRITE (unit_nr_prv, '(A)') "Transposed"
         ELSE
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         END IF
      END IF

      IF (compat1 == 0 .or. compat2 == 0) THEN
         ! index mapping not compatible with the requested contraction index sets:
         ! redistribute into a fresh tensor with the (ind1, ind2) mapping
         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor_in%name)

         ALLOCATE (tensor_out)
         CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
         CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
         compat1 = 1; compat2 = 2
         new = .true.
      ELSE
         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor_in%name)
         tensor_out => tensor_in
      END IF

      IF (compat1 == 1 .AND. compat2 == 2) THEN
         trans = .false.
      ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
         trans = .true.
      ELSE
         cpabort("this should not happen")
      END IF

      IF (unit_nr_prv > 0) THEN
         IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_out%name), ":"
            IF (compat1 == 1 .AND. compat2 == 2) THEN
               WRITE (unit_nr_prv, '(A)') "Normal"
            ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
               WRITE (unit_nr_prv, '(A)') "Transposed"
            ELSE
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            END IF
         END IF
      END IF

   END SUBROUTINE reshape_mm_small
!> \brief Update the running statistics of the batched-contraction storage and
!>        decide whether the process grid should be changed.
!> \param storage    contraction storage whose batch counter / split average is updated
!> \param split_opt  split info carrying the optimal split factor for this batch
!> \param split      split info actually in use
!> \return do_change_pgrid(1): subgroup process grid is too anisotropic;
!>         do_change_pgrid(2): split factor drifted too far from the running average
   FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
      TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
      TYPE(dbt_tas_split_info), INTENT(IN)         :: split_opt
      TYPE(dbt_tas_split_info), INTENT(IN)         :: split
      INTEGER, DIMENSION(2)                        :: pdims, pdims_sub
      LOGICAL, DIMENSION(2)                        :: do_change_pgrid
      REAL(kind=dp)                                :: change_criterion, pdims_ratio
      INTEGER                                      :: nsplit_opt, nsplit

      cpassert(ALLOCATED(split_opt%ngroup_opt))
      nsplit_opt = split_opt%ngroup_opt
      nsplit = split%ngroup

      pdims = split%mp_comm%num_pe_cart

      storage%ibatch = storage%ibatch + 1

      ! running average of the optimal split factor over all batches so far
      storage%nsplit_avg = (storage%nsplit_avg*real(storage%ibatch - 1, dp) + real(nsplit_opt, dp)) &
                           /real(storage%ibatch, dp)

      ! NOTE(review): the CASE selectors below were lost in extraction and are
      ! reconstructed as the conventional rowsplit/colsplit tags - confirm.
      SELECT CASE (split_opt%split_rowcol)
      CASE (rowsplit)
         pdims_ratio = real(pdims(1), dp)/pdims(2)
      CASE (colsplit)
         pdims_ratio = real(pdims(2), dp)/pdims(1)
      END SELECT

      do_change_pgrid(:) = .false.

      ! criterion 1: anisotropy of the subgroup process grid
      pdims_sub = split%mp_comm_group%num_pe_cart
      change_criterion = maxval(real(pdims_sub, dp))/minval(pdims_sub)
      IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .true.

      ! criterion 2: deviation of the current split factor from the running average
      change_criterion = max(real(nsplit, dp)/storage%nsplit_avg, real(storage%nsplit_avg, dp)/nsplit)
      IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .true.

   END FUNCTION update_contraction_storage
!> \brief Check whether an index set is compatible with the matrix mapping of a tensor.
!> \param nd_index   n-d to 2-d index mapping of the tensor
!> \param compat_ind index set to test
!> \return 1 if compat_ind equals the row mapping, 2 if it equals the column
!>         mapping, 0 if it matches neither
   FUNCTION compat_map(nd_index, compat_ind)
      TYPE(nd_to_2d_mapping), INTENT(IN)             :: nd_index
      INTEGER, DIMENSION(:), INTENT(IN)              :: compat_ind
      INTEGER, DIMENSION(ndims_mapping_row(nd_index))    :: map1
      INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
      INTEGER                                        :: compat_map

      CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)

      compat_map = 0

      IF (array_eq_i(map1, compat_ind)) THEN
         compat_map = 1
      ELSEIF (array_eq_i(map2, compat_ind)) THEN
         compat_map = 2
      END IF

   END FUNCTION compat_map
!> \brief Sort ind_ref ascending and apply the same permutation to ind_dep,
!>        keeping the two index lists linked element-by-element.
   SUBROUTINE index_linked_sort(ind_ref, ind_dep)
      INTEGER, DIMENSION(:), INTENT(INOUT)   :: ind_ref, ind_dep
      INTEGER, DIMENSION(SIZE(ind_ref))      :: sort_indices

      ! sort returns the permutation that ordered ind_ref
      CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
      ind_dep(:) = ind_dep(sort_indices)

   END SUBROUTINE index_linked_sort
!> \brief Create an optimized n-d process grid for a tensor, derived from the
!>        communicator of the given tall-and-skinny split info.
!> \param tensor          tensor whose mapping and block dimensions guide the grid
!> \param tas_split_info  split info whose communicator is the basis of the new grid
!> \return new process grid holding a reference-counted copy of tas_split_info
   FUNCTION opt_pgrid(tensor, tas_split_info)
      TYPE(dbt_type), INTENT(IN)                        :: tensor
      TYPE(dbt_tas_split_info), INTENT(IN)              :: tas_split_info
      INTEGER, DIMENSION(ndims_matrix_row(tensor))      :: map1
      INTEGER, DIMENSION(ndims_matrix_column(tensor))   :: map2
      TYPE(dbt_pgrid_type)                              :: opt_pgrid
      INTEGER, DIMENSION(ndims_tensor(tensor))          :: dims

      CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
      CALL blk_dims_tensor(tensor, dims)
      opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)

      ! attach a copy of the split info and take a reference on it
      ALLOCATE (opt_pgrid%tas_split_info, source=tas_split_info)
      CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
   END FUNCTION opt_pgrid
!> \brief Copy a tensor into a new tensor with a different 2d index mapping
!>        and/or distribution.
!> \param tensor_in  source tensor
!> \param map1_2d    tensor dimensions mapped to matrix rows in the new tensor
!> \param map2_2d    tensor dimensions mapped to matrix columns in the new tensor
!> \param tensor_out newly created tensor (caller owns it)
!> \param comm_2d    optional 2d communicator the new process grid is built from;
!>                   defaults to the source tensor's communicator
!> \param dist1      optional distribution vectors for the row dimensions
!> \param dist2      optional distribution vectors for the column dimensions
!> \param mp_dims_1  process grid dims for the row side (required with dist1)
!> \param mp_dims_2  process grid dims for the column side (required with dist2)
!> \param name       optional name for the new tensor (defaults to source name)
!> \param nodata     if true, create layout only, do not copy block data
!> \param move_data  passed to dbt_copy_expert (move rather than copy data)
   SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
                        mp_dims_1, mp_dims_2, name, nodata, move_data)
      TYPE(dbt_type), INTENT(INOUT)                 :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN)             :: map1_2d, map2_2d
      TYPE(dbt_type), INTENT(OUT)                   :: tensor_out
      CHARACTER(len=*), INTENT(IN), OPTIONAL        :: name
      LOGICAL, INTENT(IN), OPTIONAL                 :: nodata, move_data
      CLASS(mp_comm_type), INTENT(IN), OPTIONAL     :: comm_2d
      TYPE(array_list), INTENT(IN), OPTIONAL        :: dist1, dist2
      INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL   :: mp_dims_1
      INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL   :: mp_dims_2
      CHARACTER(len=default_string_length)          :: name_tmp
      INTEGER, DIMENSION(:), ALLOCATABLE :: blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4, &
                                            nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4
      TYPE(dbt_distribution_type)                   :: dist
      TYPE(mp_cart_type)                            :: comm_2d_prv
      INTEGER                                       :: handle, i
      INTEGER, DIMENSION(ndims_tensor(tensor_in))   :: pdims, myploc
      CHARACTER(LEN=*), PARAMETER                   :: routineN = 'dbt_remap'
      LOGICAL                                       :: nodata_prv
      TYPE(dbt_pgrid_type)                          :: comm_nd

      CALL timeset(routineN, handle)

      IF (PRESENT(name)) THEN
         name_tmp = name
      ELSE
         name_tmp = tensor_in%name
      END IF
      ! a user-supplied distribution requires matching process grid dimensions
      IF (PRESENT(dist1)) THEN
         cpassert(PRESENT(mp_dims_1))
      END IF
      IF (PRESENT(dist2)) THEN
         cpassert(PRESENT(mp_dims_2))
      END IF

      IF (PRESENT(comm_2d)) THEN
         comm_2d_prv = comm_2d
      ELSE
         comm_2d_prv = tensor_in%pgrid%mp_comm_2d
      END IF

      comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
      CALL mp_environ_pgrid(comm_nd, pdims, myploc)

      IF (ndims_tensor(tensor_in) == 2) THEN
         CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2)
      END IF
      IF (ndims_tensor(tensor_in) == 3) THEN
         CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3)
      END IF
      IF (ndims_tensor(tensor_in) == 4) THEN
         CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
      END IF

      ! for each tensor dimension: take the distribution vector from dist1/dist2
      ! if that dimension occurs in the corresponding mapping, otherwise build a
      ! default distribution from the block sizes
      IF (ndims_tensor(tensor_in) == 2) THEN
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 1)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 1)
               CALL get_ith_array(dist1, i, nd_dist_1)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 1)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 1)
               CALL get_ith_array(dist2, i, nd_dist_1)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_1)) THEN
            ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
            CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
         END IF
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 2)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 2)
               CALL get_ith_array(dist1, i, nd_dist_2)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 2)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 2)
               CALL get_ith_array(dist2, i, nd_dist_2)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_2)) THEN
            ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
            CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
         END IF
         CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
                                          nd_dist_1, nd_dist_2, own_comm=.true.)
      END IF
      IF (ndims_tensor(tensor_in) == 3) THEN
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 1)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 1)
               CALL get_ith_array(dist1, i, nd_dist_1)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 1)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 1)
               CALL get_ith_array(dist2, i, nd_dist_1)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_1)) THEN
            ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
            CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
         END IF
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 2)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 2)
               CALL get_ith_array(dist1, i, nd_dist_2)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 2)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 2)
               CALL get_ith_array(dist2, i, nd_dist_2)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_2)) THEN
            ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
            CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
         END IF
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 3)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 3)
               CALL get_ith_array(dist1, i, nd_dist_3)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 3)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 3)
               CALL get_ith_array(dist2, i, nd_dist_3)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_3)) THEN
            ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
            CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
         END IF
         CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
                                          nd_dist_1, nd_dist_2, nd_dist_3, own_comm=.true.)
      END IF
      IF (ndims_tensor(tensor_in) == 4) THEN
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 1)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 1)
               CALL get_ith_array(dist1, i, nd_dist_1)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 1)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 1)
               CALL get_ith_array(dist2, i, nd_dist_1)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_1)) THEN
            ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
            CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
         END IF
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 2)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 2)
               CALL get_ith_array(dist1, i, nd_dist_2)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 2)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 2)
               CALL get_ith_array(dist2, i, nd_dist_2)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_2)) THEN
            ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
            CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
         END IF
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 3)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 3)
               CALL get_ith_array(dist1, i, nd_dist_3)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 3)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 3)
               CALL get_ith_array(dist2, i, nd_dist_3)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_3)) THEN
            ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
            CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
         END IF
         IF (PRESENT(dist1)) THEN
            IF (any(map1_2d == 4)) THEN
               i = minloc(map1_2d, dim=1, mask=map1_2d == 4)
               CALL get_ith_array(dist1, i, nd_dist_4)
            END IF
         END IF
         IF (PRESENT(dist2)) THEN
            IF (any(map2_2d == 4)) THEN
               i = minloc(map2_2d, dim=1, mask=map2_2d == 4)
               CALL get_ith_array(dist2, i, nd_dist_4)
            END IF
         END IF
         IF (.NOT. ALLOCATED(nd_dist_4)) THEN
            ALLOCATE (nd_dist_4(SIZE(blk_sizes_4)))
            CALL dbt_default_distvec(SIZE(blk_sizes_4), pdims(4), blk_sizes_4, nd_dist_4)
         END IF
         CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
                                          nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4, own_comm=.true.)
      END IF

      IF (ndims_tensor(tensor_in) == 2) THEN
         CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
                         blk_sizes_1, blk_sizes_2)
      END IF
      IF (ndims_tensor(tensor_in) == 3) THEN
         CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
                         blk_sizes_1, blk_sizes_2, blk_sizes_3)
      END IF
      IF (ndims_tensor(tensor_in) == 4) THEN
         CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
                         blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
      END IF

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .false.
      END IF

      IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
      CALL dbt_distribution_destroy(dist)

      CALL timestop(handle)
   END SUBROUTINE dbt_remap
!> \brief Align the tensor index with its matrix index: permute the dimensions
!>        such that the mapped (matrix) order becomes the natural index order.
!> \param tensor_in  source tensor
!> \param tensor_out permuted view of the source tensor
!> \param order      if present, receives the permutation that was applied
   SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
      TYPE(dbt_type), INTENT(INOUT)                        :: tensor_in
      TYPE(dbt_type), INTENT(OUT)                          :: tensor_out
      INTEGER, DIMENSION(ndims_matrix_row(tensor_in))      :: map1_2d
      INTEGER, DIMENSION(ndims_matrix_column(tensor_in))   :: map2_2d
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(OUT), OPTIONAL                             :: order
      INTEGER, DIMENSION(ndims_tensor(tensor_in))          :: order_prv
      CHARACTER(LEN=*), PARAMETER                          :: routineN = 'dbt_align_index'
      INTEGER                                              :: handle

      CALL timeset(routineN, handle)

      CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
      ! permutation that undoes the concatenated row/column mapping
      order_prv = dbt_inverse_order([map1_2d, map2_2d])
      CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)

      IF (PRESENT(order)) order = order_prv

      CALL timestop(handle)
   END SUBROUTINE dbt_align_index
!> \brief Create a tensor that is a permuted-index view of tensor_in: the matrix
!>        representation is shared (no data copy), only the index metadata is
!>        permuted according to order.
!> \param tensor_in  source tensor (reference count is increased)
!> \param tensor_out permuted view; does not own the matrix representation
!> \param order      permutation of the tensor dimensions
   SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
      TYPE(dbt_type), INTENT(INOUT)               :: tensor_in
      TYPE(dbt_type), INTENT(OUT)                 :: tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN)                               :: order

      TYPE(nd_to_2d_mapping)                      :: nd_index_blk_rs, nd_index_rs
      CHARACTER(LEN=*), PARAMETER                 :: routineN = 'dbt_permute_index'
      INTEGER                                     :: handle
      INTEGER                                     :: ndims

      CALL timeset(routineN, handle)

      ndims = ndims_tensor(tensor_in)

      CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
      CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
      CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)

      ! share the matrix representation; the view must not destroy it
      tensor_out%matrix_rep => tensor_in%matrix_rep
      tensor_out%owns_matrix = .false.

      tensor_out%nd_index = nd_index_rs
      tensor_out%nd_index_blk = nd_index_blk_rs
      tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
      IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
         ALLOCATE (tensor_out%pgrid%tas_split_info, source=tensor_in%pgrid%tas_split_info)
      END IF
      tensor_out%refcount => tensor_in%refcount
      CALL dbt_hold(tensor_out)

      CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
      CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
      CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
      CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
      ALLOCATE (tensor_out%nblks_local(ndims))
      ALLOCATE (tensor_out%nfull_local(ndims))
      tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
      tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
      tensor_out%name = tensor_in%name
      tensor_out%valid = .true.

      IF (ALLOCATED(tensor_in%contraction_storage)) THEN
         ALLOCATE (tensor_out%contraction_storage, source=tensor_in%contraction_storage)
         CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
         CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
      END IF

      CALL timestop(handle)
   END SUBROUTINE dbt_permute_index
!> \brief Map contraction bounds (given per contraction index class) onto full
!>        per-dimension bounds of the two tensors.
!> \param tensor_1/tensor_2         the two contraction operands
!> \param contract_1/notcontract_1  contracted / uncontracted dims of tensor 1
!> \param contract_2/notcontract_2  contracted / uncontracted dims of tensor 2
!> \param bounds_t1/bounds_t2       resulting per-dimension bounds (1:2, dim)
!> \param bounds_1   bounds on the contracted (summation) indices
!> \param bounds_2   bounds on the free indices of tensor 1
!> \param bounds_3   bounds on the free indices of tensor 2
!> \param do_crop_1/do_crop_2  whether the respective tensor needs cropping
   SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
                                        contract_1, notcontract_1, &
                                        contract_2, notcontract_2, &
                                        bounds_t1, bounds_t2, &
                                        bounds_1, bounds_2, bounds_3, &
                                        do_crop_1, do_crop_2)

      TYPE(dbt_type), INTENT(IN)                   :: tensor_1, tensor_2
      INTEGER, DIMENSION(:), INTENT(IN)            :: contract_1, contract_2, &
                                                      notcontract_1, notcontract_2
      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
         INTENT(OUT)                               :: bounds_t1
      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
         INTENT(OUT)                               :: bounds_t2
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL                      :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL                      :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL                      :: bounds_3
      LOGICAL, INTENT(OUT), OPTIONAL               :: do_crop_1, do_crop_2
      LOGICAL, DIMENSION(2)                        :: do_crop

      do_crop = .false.

      ! default bounds: the full index range of each tensor
      bounds_t1(1, :) = 1
      CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))

      bounds_t2(1, :) = 1
      CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))

      IF (PRESENT(bounds_1)) THEN
         ! summation bounds restrict the contracted dims of both tensors
         bounds_t1(:, contract_1) = bounds_1
         do_crop(1) = .true.
         bounds_t2(:, contract_2) = bounds_1
         do_crop(2) = .true.
      END IF

      IF (PRESENT(bounds_2)) THEN
         bounds_t1(:, notcontract_1) = bounds_2
         do_crop(1) = .true.
      END IF

      IF (PRESENT(bounds_3)) THEN
         bounds_t2(:, notcontract_2) = bounds_3
         do_crop(2) = .true.
      END IF

      IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
      IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)

   END SUBROUTINE dbt_map_bounds_to_tensors
!> \brief Print the index naming of a tensor contraction in Einstein-like
!>        notation, both as tensor indices and as mapped matrix indices.
!> \param tensor_1..3   the three tensors of the contraction (1 x 2 = 3)
!> \param indchar1..3   one-character labels per tensor dimension
!> \param unit_nr       output unit (only rank with unit > 0 prints)
   SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
      TYPE(dbt_type), INTENT(IN)  :: tensor_1, tensor_2, tensor_3
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
      INTEGER, INTENT(IN)         :: unit_nr
      INTEGER, DIMENSION(ndims_matrix_row(tensor_1))    :: map11
      INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
      INTEGER, DIMENSION(ndims_matrix_row(tensor_2))    :: map21
      INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
      INTEGER, DIMENSION(ndims_matrix_row(tensor_3))    :: map31
      INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
      INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (unit_nr_prv /= 0) THEN
         CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
         CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
         CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
      END IF

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
         ! tensor index line: (abc) x (cde) = (abde)
         WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
         DO ichar1 = 1, SIZE(indchar1)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
         END DO
         WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
         DO ichar2 = 1, SIZE(indchar2)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
         END DO
         WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
         DO ichar3 = 1, SIZE(indchar3)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
         END DO
         WRITE (unit_nr_prv, '(A)') ")"

         ! matrix index line: row and column parts separated by "|"
         WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
         DO ichar1 = 1, SIZE(map11)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
         END DO
         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
         DO ichar1 = 1, SIZE(map12)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
         END DO
         WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
         DO ichar2 = 1, SIZE(map21)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
         END DO
         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
         DO ichar2 = 1, SIZE(map22)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
         END DO
         WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
         DO ichar3 = 1, SIZE(map31)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
         END DO
         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
         DO ichar3 = 1, SIZE(map32)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
         END DO
         WRITE (unit_nr_prv, '(A)') ")"
      END IF

   END SUBROUTINE dbt_print_contraction_index
!> \brief Initialize batched contraction for a tensor: record the batch ranges
!>        per dimension (defaulting to one batch spanning all blocks) and set up
!>        the contraction storage.
!> NOTE(review): the SUBROUTINE header line was lost in extraction; the name
!> dbt_batched_contract_init is reconstructed from the storage/finalize calls -
!> confirm against the original source.
   SUBROUTINE dbt_batched_contract_init(tensor, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
      TYPE(dbt_type), INTENT(INOUT)               :: tensor
      ! batch_range_i: block index boundaries of the batches along dimension i
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
      INTEGER, DIMENSION(ndims_tensor(tensor))    :: tdims
      INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_prv_1, batch_range_prv_2, batch_range_prv_3, &
                                            batch_range_prv_4
      LOGICAL :: static_range

      CALL dbt_get_info(tensor, nblks_total=tdims)

      ! static_range: no user-provided batches, a single batch covers everything
      static_range = .true.
      IF (ndims_tensor(tensor) >= 1) THEN
         IF (PRESENT(batch_range_1)) THEN
            ALLOCATE (batch_range_prv_1, source=batch_range_1)
            static_range = .false.
         ELSE
            ALLOCATE (batch_range_prv_1(2))
            batch_range_prv_1(1) = 1
            batch_range_prv_1(2) = tdims(1) + 1
         END IF
      END IF
      IF (ndims_tensor(tensor) >= 2) THEN
         IF (PRESENT(batch_range_2)) THEN
            ALLOCATE (batch_range_prv_2, source=batch_range_2)
            static_range = .false.
         ELSE
            ALLOCATE (batch_range_prv_2(2))
            batch_range_prv_2(1) = 1
            batch_range_prv_2(2) = tdims(2) + 1
         END IF
      END IF
      IF (ndims_tensor(tensor) >= 3) THEN
         IF (PRESENT(batch_range_3)) THEN
            ALLOCATE (batch_range_prv_3, source=batch_range_3)
            static_range = .false.
         ELSE
            ALLOCATE (batch_range_prv_3(2))
            batch_range_prv_3(1) = 1
            batch_range_prv_3(2) = tdims(3) + 1
         END IF
      END IF
      IF (ndims_tensor(tensor) >= 4) THEN
         IF (PRESENT(batch_range_4)) THEN
            ALLOCATE (batch_range_prv_4, source=batch_range_4)
            static_range = .false.
         ELSE
            ALLOCATE (batch_range_prv_4(2))
            batch_range_prv_4(1) = 1
            batch_range_prv_4(2) = tdims(4) + 1
         END IF
      END IF

      ALLOCATE (tensor%contraction_storage)
      tensor%contraction_storage%static = static_range
      IF (static_range) THEN
         CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
      END IF
      tensor%contraction_storage%nsplit_avg = 0.0_dp
      tensor%contraction_storage%ibatch = 0

      IF (ndims_tensor(tensor) == 1) THEN
         CALL create_array_list(tensor%contraction_storage%batch_ranges, 1, &
                                batch_range_prv_1)
      END IF
      IF (ndims_tensor(tensor) == 2) THEN
         CALL create_array_list(tensor%contraction_storage%batch_ranges, 2, &
                                batch_range_prv_1, batch_range_prv_2)
      END IF
      IF (ndims_tensor(tensor) == 3) THEN
         CALL create_array_list(tensor%contraction_storage%batch_ranges, 3, &
                                batch_range_prv_1, batch_range_prv_2, batch_range_prv_3)
      END IF
      IF (ndims_tensor(tensor) == 4) THEN
         CALL create_array_list(tensor%contraction_storage%batch_ranges, 4, &
                                batch_range_prv_1, batch_range_prv_2, batch_range_prv_3, batch_range_prv_4)
      END IF

   END SUBROUTINE dbt_batched_contract_init
!> \brief Finalize batched contraction for a tensor: finish the batched matmul
!>        (static case), optionally report, and release the contraction storage.
!> NOTE(review): the SUBROUTINE header line was lost in extraction; name
!> reconstructed as dbt_batched_contract_finalize - confirm against original.
   SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)  :: tensor
      INTEGER, INTENT(IN), OPTIONAL  :: unit_nr
      LOGICAL                        :: do_write
      INTEGER                        :: unit_nr_prv, handle

      CALL tensor%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)
      unit_nr_prv = prep_output_unit(unit_nr)

      do_write = .false.

      IF (tensor%contraction_storage%static) THEN
         IF (tensor%matrix_rep%do_batched > 0) THEN
            IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .true.
         END IF
         CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
      END IF

      IF (do_write .AND. unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "FINALIZING BATCHED PROCESSING OF MATMUL"
         END IF
         CALL dbt_write_tensor_info(tensor, unit_nr_prv)
         CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
      END IF

      CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
      DEALLOCATE (tensor%contraction_storage)
      CALL tensor%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)

   END SUBROUTINE dbt_batched_contract_finalize
!> \brief Change the process grid of a tensor, redistributing its data onto the
!>        new grid. With batch ranges, the default distribution is built per
!>        batch so blocks of one batch stay load-balanced.
!> \param tensor        tensor whose process grid is replaced (in place)
!> \param pgrid         the new process grid
!> \param batch_range_i block index boundaries of the batches along dimension i
!> \param nodata        if true, do not move block data
!> \param pgrid_changed set to true if the grid was actually changed
!> \param unit_nr       output unit for reporting (> 0 prints)
   SUBROUTINE dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
                               nodata, pgrid_changed, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)               :: tensor
      TYPE(dbt_pgrid_type), INTENT(IN)            :: pgrid
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
      LOGICAL, INTENT(IN), OPTIONAL               :: nodata
      LOGICAL, INTENT(OUT), OPTIONAL              :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
      CHARACTER(LEN=*), PARAMETER                 :: routineN = 'dbt_change_pgrid'
      CHARACTER(default_string_length)            :: name
      INTEGER                                     :: handle
      INTEGER, ALLOCATABLE, DIMENSION(:) :: bs_1, bs_2, bs_3, bs_4, &
                                            dist_1, dist_2, dist_3, dist_4
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: pcoord, pcoord_ref, pdims, pdims_ref, &
                                                  tdims
      TYPE(dbt_type)                              :: t_tmp
      TYPE(dbt_distribution_type)                 :: dist
      INTEGER, DIMENSION(ndims_matrix_row(tensor))    :: map1
      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
      LOGICAL, DIMENSION(ndims_tensor(tensor))    :: mem_aware
      INTEGER, DIMENSION(ndims_tensor(tensor))    :: nbatch
      INTEGER :: ind1, ind2, batch_size, ibatch

      IF (PRESENT(pgrid_changed)) pgrid_changed = .false.
      CALL mp_environ_pgrid(pgrid, pdims, pcoord)
      CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)

      ! nothing to do if the new grid is equivalent to the current one
      IF (all(pdims == pdims_ref)) THEN
         IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
            IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
               RETURN
            END IF
         END IF
      END IF

      CALL timeset(routineN, handle)

      IF (ndims_tensor(tensor) >= 1) THEN
         mem_aware(1) = PRESENT(batch_range_1)
         IF (mem_aware(1)) nbatch(1) = SIZE(batch_range_1) - 1
      END IF
      IF (ndims_tensor(tensor) >= 2) THEN
         mem_aware(2) = PRESENT(batch_range_2)
         IF (mem_aware(2)) nbatch(2) = SIZE(batch_range_2) - 1
      END IF
      IF (ndims_tensor(tensor) >= 3) THEN
         mem_aware(3) = PRESENT(batch_range_3)
         IF (mem_aware(3)) nbatch(3) = SIZE(batch_range_3) - 1
      END IF
      IF (ndims_tensor(tensor) >= 4) THEN
         mem_aware(4) = PRESENT(batch_range_4)
         IF (mem_aware(4)) nbatch(4) = SIZE(batch_range_4) - 1
      END IF

      CALL dbt_get_info(tensor, nblks_total=tdims, name=name)

      ! per dimension: block sizes and a default distribution, built batch-wise
      ! when batch ranges are provided
      IF (ndims_tensor(tensor) >= 1) THEN
         ALLOCATE (bs_1(dbt_nblks_total(tensor, 1)))
         CALL get_ith_array(tensor%blk_sizes, 1, bs_1)
         ALLOCATE (dist_1(tdims(1)))
         dist_1 = 0
         IF (mem_aware(1)) THEN
            DO ibatch = 1, nbatch(1)
               ind1 = batch_range_1(ibatch)
               ind2 = batch_range_1(ibatch + 1) - 1
               batch_size = ind2 - ind1 + 1
               CALL dbt_default_distvec(batch_size, pdims(1), &
                                        bs_1(ind1:ind2), dist_1(ind1:ind2))
            END DO
         ELSE
            CALL dbt_default_distvec(tdims(1), pdims(1), bs_1, dist_1)
         END IF
      END IF
      IF (ndims_tensor(tensor) >= 2) THEN
         ALLOCATE (bs_2(dbt_nblks_total(tensor, 2)))
         CALL get_ith_array(tensor%blk_sizes, 2, bs_2)
         ALLOCATE (dist_2(tdims(2)))
         dist_2 = 0
         IF (mem_aware(2)) THEN
            DO ibatch = 1, nbatch(2)
               ind1 = batch_range_2(ibatch)
               ind2 = batch_range_2(ibatch + 1) - 1
               batch_size = ind2 - ind1 + 1
               CALL dbt_default_distvec(batch_size, pdims(2), &
                                        bs_2(ind1:ind2), dist_2(ind1:ind2))
            END DO
         ELSE
            CALL dbt_default_distvec(tdims(2), pdims(2), bs_2, dist_2)
         END IF
      END IF
      IF (ndims_tensor(tensor) >= 3) THEN
         ALLOCATE (bs_3(dbt_nblks_total(tensor, 3)))
         CALL get_ith_array(tensor%blk_sizes, 3, bs_3)
         ALLOCATE (dist_3(tdims(3)))
         dist_3 = 0
         IF (mem_aware(3)) THEN
            DO ibatch = 1, nbatch(3)
               ind1 = batch_range_3(ibatch)
               ind2 = batch_range_3(ibatch + 1) - 1
               batch_size = ind2 - ind1 + 1
               CALL dbt_default_distvec(batch_size, pdims(3), &
                                        bs_3(ind1:ind2), dist_3(ind1:ind2))
            END DO
         ELSE
            CALL dbt_default_distvec(tdims(3), pdims(3), bs_3, dist_3)
         END IF
      END IF
      IF (ndims_tensor(tensor) >= 4) THEN
         ALLOCATE (bs_4(dbt_nblks_total(tensor, 4)))
         CALL get_ith_array(tensor%blk_sizes, 4, bs_4)
         ALLOCATE (dist_4(tdims(4)))
         dist_4 = 0
         IF (mem_aware(4)) THEN
            DO ibatch = 1, nbatch(4)
               ind1 = batch_range_4(ibatch)
               ind2 = batch_range_4(ibatch + 1) - 1
               batch_size = ind2 - ind1 + 1
               CALL dbt_default_distvec(batch_size, pdims(4), &
                                        bs_4(ind1:ind2), dist_4(ind1:ind2))
            END DO
         ELSE
            CALL dbt_default_distvec(tdims(4), pdims(4), bs_4, dist_4)
         END IF
      END IF

      CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
      IF (ndims_tensor(tensor) == 2) THEN
         CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2)
         CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2)
      END IF
      IF (ndims_tensor(tensor) == 3) THEN
         CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3)
         CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3)
      END IF
      IF (ndims_tensor(tensor) == 4) THEN
         CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3, dist_4)
         CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3, bs_4)
      END IF
      CALL dbt_distribution_destroy(dist)

      IF (PRESENT(nodata)) THEN
         IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
      ELSE
         CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
      END IF

      CALL dbt_copy_contraction_storage(tensor, t_tmp)

      ! NOTE(review): the replace-in-place sequence below was lost in extraction
      ! and is reconstructed (destroy old tensor, adopt the temporary) - confirm.
      CALL dbt_destroy(tensor)
      tensor = t_tmp
      CALL dbt_destroy(t_tmp)

      IF (PRESENT(unit_nr)) THEN
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", trim(tensor%name)
            WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
            CALL dbt_write_split_info(pgrid, unit_nr)
         END IF
      END IF

      IF (PRESENT(pgrid_changed)) pgrid_changed = .true.

      CALL timestop(handle)
   END SUBROUTINE dbt_change_pgrid
!> \brief Build an n-d process grid from a 2d communicator and apply it to the
!>        tensor via dbt_change_pgrid, reusing stored batch ranges if present.
!> \param tensor        tensor whose grid is changed
!> \param mp_comm       2d Cartesian communicator the grid is derived from
!> \param pdims         optional 2d process grid dimensions
!> \param nodata        if true, do not move block data
!> \param nsplit        optional split factor for the tall-and-skinny layout
!> \param dimsplit      optional dimension to split
!> \param pgrid_changed set to true if the grid was actually changed
!> \param unit_nr       output unit for reporting
   SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
      TYPE(dbt_type), INTENT(INOUT)      :: tensor
      TYPE(mp_cart_type), INTENT(IN)     :: mp_comm
      INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
      LOGICAL, INTENT(IN), OPTIONAL      :: nodata
      INTEGER, INTENT(IN), OPTIONAL      :: nsplit, dimsplit
      LOGICAL, INTENT(OUT), OPTIONAL     :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL      :: unit_nr
      INTEGER, DIMENSION(ndims_matrix_row(tensor))    :: map1
      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
      INTEGER, DIMENSION(ndims_tensor(tensor))        :: dims, nbatches
      TYPE(dbt_pgrid_type)               :: pgrid
      INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
      INTEGER, DIMENSION(:), ALLOCATABLE :: array
      INTEGER                            :: idim

      CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
      CALL blk_dims_tensor(tensor, dims)

      IF (ALLOCATED(tensor%contraction_storage)) THEN
         ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
            nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
            ! use the average batch extent per dimension to size the grid
            DO idim = 1, ndims_tensor(tensor)
               CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
               dims(idim) = array(nbatches(idim) + 1) - array(1)
               DEALLOCATE (array)
               dims(idim) = dims(idim)/nbatches(idim)
               IF (dims(idim) <= 0) dims(idim) = 1
            END DO
         END ASSOCIATE
      END IF

      pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
      IF (ALLOCATED(tensor%contraction_storage)) THEN
         ! forward the stored batch ranges so the new grid keeps batch locality
         IF (ndims_tensor(tensor) == 1) THEN
            CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1)
            CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, &
                                  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
         END IF
         IF (ndims_tensor(tensor) == 2) THEN
            CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2)
            CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, &
                                  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
         END IF
         IF (ndims_tensor(tensor) == 3) THEN
            CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3)
            CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, &
                                  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
         END IF
         IF (ndims_tensor(tensor) == 4) THEN
            CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
            CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
                                  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
         END IF
      ELSE
         CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
      END IF
      CALL dbt_pgrid_destroy(pgrid)

   END SUBROUTINE dbt_change_pgrid_2d
subroutine, public dbm_clear(matrix)
Remove all blocks from given matrix, but does not release the underlying memory.
Wrapper for allocating, copying and reshaping arrays.
Representation of arbitrary number of 1d integer arrays with arbitrary sizes. This is needed for gene...
pure logical function, public array_eq_i(arr1, arr2)
check whether two arrays are equal
subroutine, public get_arrays(list, data_1, data_2, data_3, data_4, i_selected)
Get all arrays contained in list.
subroutine, public create_array_list(list, ndata, data_1, data_2, data_3, data_4)
collects any number of arrays of different sizes into a single array (list%col_data),...
subroutine, public destroy_array_list(list)
destroy array list.
integer function, dimension(:), allocatable, public sizes_of_arrays(list)
sizes of arrays stored in list
subroutine, public get_ith_array(list, i, array_size, array)
get ith array
type(array_list) function, public array_sublist(list, i_selected)
extract a subset of arrays
subroutine, public reorder_arrays(list_in, list_out, order)
reorder array list.
logical function, public check_equal(list1, list2)
check whether two array lists are equal
Methods to operate on n-dimensional tensor blocks.
elemental logical function, public checker_tr(row, column)
Determines whether a transpose must be applied.
logical function, public dbt_iterator_blocks_left(iterator)
Generalization of block_iterator_blocks_left for tensors.
subroutine, public destroy_block(block)
pure integer function, public ndims_iterator(iterator)
Number of dimensions.
subroutine, public dbt_iterator_stop(iterator)
Generalization of block_iterator_stop for tensors.
subroutine, public dbt_iterator_start(iterator, tensor)
Generalization of block_iterator_start for tensors.
subroutine, public dbt_iterator_next_block(iterator, ind_nd, blk_size, blk_offset)
iterate over nd blocks of an nd rank tensor, index only (blocks must be retrieved by calling dbt_get_...
tensor index and mapping to DBM index
pure integer function, public ndims_mapping_row(map)
how many tensor dimensions are mapped to matrix row
pure integer function, dimension(size(order)), public dbt_inverse_order(order)
Invert order.
pure integer function, public ndims_mapping(map)
subroutine, public permute_index(map_in, map_out, order)
reorder tensor index (no data)
pure subroutine, public dbt_get_mapping_info(map, ndim_nd, ndim1_2d, ndim2_2d, dims_2d_i8, dims_2d, dims_nd, dims1_2d, dims2_2d, map1_2d, map2_2d, map_nd, base, col_major)
get mapping info
pure integer function, public ndims_mapping_column(map)
how many tensor dimensions are mapped to matrix column
pure integer function, dimension(map%ndim_nd), public get_nd_indices_tensor(map, ind_in)
transform 2d index to nd index, using info from index mapping.
DBT tensor Input / Output.
subroutine, public dbt_write_tensor_info(tensor, unit_nr, full_info)
Write tensor global info: block dimensions, full dimensions and process grid dimensions.
subroutine, public dbt_write_tensor_dist(tensor, unit_nr)
Write info on tensor distribution & load balance.
subroutine, public dbt_write_split_info(pgrid, unit_nr)
integer function, public prep_output_unit(unit_nr)
DBT tensor framework for block-sparse tensor contraction. Representation of n-rank tensors as DBT tal...
subroutine, public dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
Copy tensor data. Redistributes tensor data according to distributions of target and source tensor....
subroutine, public dbt_batched_contract_finalize(tensor, unit_nr)
finalize batched contraction. This performs all communication that has been postponed in the contract...
subroutine, public dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
copy matrix to tensor.
subroutine, public dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, contract_1, notcontract_1, contract_2, notcontract_2, map_1, map_2, bounds_1, bounds_2, bounds_3, optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
Contract tensors by multiplying matrix representations. tensor_3(map_1, map_2) := alpha * tensor_1(no...
subroutine, public dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
copy tensor to matrix
subroutine, public dbt_batched_contract_init(tensor, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
Initialize batched contraction for this tensor.
Routines to reshape / redistribute tensors.
subroutine, public dbt_reshape(tensor_in, tensor_out, summation, move_data)
copy data (involves reshape) tensor_out = tensor_out + tensor_in move_data memory optimization: trans...
Routines to split blocks and to convert between tensors with different block sizes.
subroutine, public dbt_split_copyback(tensor_split_in, tensor_out, summation)
Copy tensor with split blocks to tensor with original block sizes.
subroutine, public dbt_make_compatible_blocks(tensor1, tensor2, tensor1_split, tensor2_split, order, nodata1, nodata2, move_data)
split two tensors with same total sizes but different block sizes such that they have equal block siz...
subroutine, public dbt_crop(tensor_in, tensor_out, bounds, move_data)
Tall-and-skinny matrices: base routines similar to DBM API, mostly wrappers around existing DBM routi...
subroutine, public dbt_tas_get_info(matrix, nblkrows_total, nblkcols_total, local_rows, local_cols, proc_row_dist, proc_col_dist, row_blk_size, col_blk_size, distribution, name)
...
subroutine, public dbt_tas_copy(matrix_b, matrix_a, summation)
Copy matrix_a to matrix_b.
subroutine, public dbt_tas_finalize(matrix)
...
type(dbt_tas_split_info) function, public dbt_tas_info(matrix)
get info on mpi grid splitting
Matrix multiplication for tall-and-skinny matrices. This uses the k-split (non-recursive) CARMA algor...
subroutine, public dbt_tas_set_batched_state(matrix, state, opt_grid)
set state flags during batched multiplication
subroutine, public dbt_tas_batched_mm_init(matrix)
...
subroutine, public dbt_tas_batched_mm_finalize(matrix)
...
recursive subroutine, public dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, optimize_dist, split_opt, filter_eps, flop, move_data_a, move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical to arguments...
subroutine, public dbt_tas_batched_mm_complete(matrix, warn)
...
methods to split tall-and-skinny matrices along longest dimension. Basically, we are splitting proces...
subroutine, public dbt_tas_release_info(split_info)
...
integer, parameter, public rowsplit
integer, parameter, public colsplit
subroutine, public dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit, own_comm, opt_nsplit)
Split Cartesian process grid using a default split heuristic.
type(mp_cart_type) function, public dbt_tas_mp_comm(mp_comm, split_rowcol, nsplit)
Create default cartesian process grid that is consistent with default split heuristic of dbt_tas_crea...
real(dp), parameter, public default_nsplit_accept_ratio
real(dp), parameter, public default_pdims_accept_ratio
subroutine, public dbt_tas_info_hold(split_info)
...
DBT tall-and-skinny base types. Mostly wrappers around existing DBM routines.
DBT tensor framework for block-sparse tensor contraction: Types and create/destroy routines.
subroutine, public dbt_pgrid_destroy(pgrid, keep_comm)
destroy process grid
subroutine, public dbt_distribution_new(dist, pgrid, nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4)
Create a tensor distribution.
subroutine, public blk_dims_tensor(tensor, dims)
tensor block dimensions
subroutine, public dims_tensor(tensor, dims)
tensor dimensions
subroutine, public dbt_copy_contraction_storage(tensor_in, tensor_out)
type(dbt_pgrid_type) function, public dbt_nd_mp_comm(comm_2d, map1_2d, map2_2d, dims_nd, dims1_nd, dims2_nd, pdims_2d, tdims, nsplit, dimsplit)
Create a default nd process topology that is consistent with a given 2d topology. Purpose: a nd tenso...
subroutine, public dbt_destroy(tensor)
Destroy a tensor.
pure integer function, public dbt_max_nblks_local(tensor)
returns an estimate of maximum number of local blocks in tensor (irrespective of the actual number of...
subroutine, public dbt_get_info(tensor, nblks_total, nfull_total, nblks_local, nfull_local, pdims, my_ploc, blks_local_1, blks_local_2, blks_local_3, blks_local_4, proc_dist_1, proc_dist_2, proc_dist_3, proc_dist_4, blk_size_1, blk_size_2, blk_size_3, blk_size_4, blk_offset_1, blk_offset_2, blk_offset_3, blk_offset_4, distribution, name)
As block_get_info but for tensors.
subroutine, public dbt_distribution_new_expert(dist, pgrid, map1_2d, map2_2d, nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4, own_comm)
Create a tensor distribution.
type(dbt_distribution_type) function, public dbt_distribution(tensor)
get distribution from tensor
pure integer function, public ndims_tensor(tensor)
tensor rank
pure integer function, public dbt_nblks_total(tensor, idim)
total numbers of blocks along dimension idim
pure integer function, public dbt_get_num_blocks(tensor)
As block_get_num_blocks: get number of local blocks.
subroutine, public dbt_default_distvec(nblk, nproc, blk_size, dist)
get a load-balanced and randomized distribution along one tensor dimension
subroutine, public dbt_hold(tensor)
reference counting for tensors (only needed for communicator handle that must be freed when no longer...
subroutine, public dbt_clear(tensor)
Clear tensor (s.t. it does not contain any blocks)
subroutine, public dbt_finalize(tensor)
Finalize tensor, as block_finalize. This should be taken care of internally in DBT tensors,...
subroutine, public mp_environ_pgrid(pgrid, dims, task_coor)
as mp_environ but for special pgrid type
subroutine, public dbt_get_stored_coordinates(tensor, ind_nd, processor)
Generalization of block_get_stored_coordinates for tensors.
integer(kind=int_8) function, public dbt_get_num_blocks_total(tensor)
Get total number of blocks.
pure integer(int_8) function, public ndims_matrix_row(tensor)
how many tensor dimensions are mapped to matrix row
pure integer(int_8) function, public ndims_matrix_column(tensor)
how many tensor dimensions are mapped to matrix column
subroutine, public dbt_filter(tensor, eps)
As block_filter.
subroutine, public dbt_distribution_destroy(dist)
Destroy tensor distribution.
subroutine, public dbt_scale(tensor, alpha)
as block_scale
Defines the basic variable types.
integer, parameter, public int_8
integer, parameter, public dp
integer, parameter, public default_string_length
Interface to the message passing library MPI.
All kind of helpful little routines.