dbt_methods.F
!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                               !
!   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                    !
!                                                                                                   !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                       !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief DBT tensor framework for block-sparse tensor contraction.
!>        Representation of n-rank tensors as DBT tall-and-skinny matrices.
!>        Support for arbitrary redistribution between different representations.
!>        Support for arbitrary tensor contractions.
!> \todo implement checks and error messages
!> \author Patrick Seewald
! **************************************************************************************************

MODULE dbt_methods

   USE dbcsr_api, ONLY: &
      dbcsr_type, dbcsr_release, &
      dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
      dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
   USE dbt_allocate_wrap, ONLY: &
      allocate_any
   USE dbt_array_list_methods, ONLY: &
      array_list, array_sublist, array_eq_i, check_equal, reorder_arrays
   USE dbm_api, ONLY: &
      dbm_clear
   USE dbt_tas_types, ONLY: &
      dbt_tas_split_info
   USE dbt_tas_base, ONLY: &
      dbt_tas_copy, dbt_tas_info
   USE dbt_tas_mm, ONLY: &
      dbt_tas_multiply, dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
   USE dbt_block, ONLY: &
      dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
      dbt_iterator_blocks_left, dbt_iterator_next_block, dbt_iterator_stop, &
      ndims_iterator, dbt_reserve_blocks, block_nd, destroy_block, checker_tr
   USE dbt_index, ONLY: &
      dbt_get_mapping_info, ndims_mapping_row, ndims_mapping_column
   USE dbt_types, ONLY: &
      dbt_create, dbt_type, ndims_tensor, dims_tensor, &
      dbt_distribution_type, dbt_distribution, dbt_nd_mp_comm, dbt_destroy, &
      blk_dims_tensor, dbt_pgrid_type, dbt_clear, dbt_filter, dbt_scale, &
      dbt_get_num_blocks, ndims_matrix_row, ndims_matrix_column, dbt_copy_contraction_storage, &
      dbt_max_nblks_local, dbt_default_distvec, dbt_contraction_storage, dbt_nblks_total
   USE kinds, ONLY: &
      dp, int_8
   USE message_passing, ONLY: &
      mp_cart_type, mp_comm_type
   USE util, ONLY: &
      sort
   USE dbt_reshape_ops, ONLY: &
      dbt_reshape
   USE dbt_tas_split, ONLY: &
      dbt_tas_mp_comm, dbt_tas_create_split, dbt_tas_release_info, default_pdims_accept_ratio
   USE dbt_split, ONLY: &
      dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
   USE dbt_io, ONLY: &
      dbt_write_tensor_info, dbt_write_tensor_dist, dbt_print_contraction_index, prep_output_unit

#include "../base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'

   PUBLIC :: &
      dbt_contract, &
      dbt_copy, &
      dbt_copy_matrix_to_tensor, &
      dbt_copy_tensor_to_matrix, &
      dbt_get_block, &
      dbt_iterator_type, &
      dbt_put_block, &
      dbt_reserve_blocks

CONTAINS

! **************************************************************************************************
!> \brief Copy tensor data.
!>        Redistributes tensor data according to the distributions of target and source tensor.
!>        Permutes the tensor index according to the `order` argument (if present).
!>        Source and target tensor formats are arbitrary as long as the following requirements are met:
!>        * source and target tensors have the same rank and the same sizes in each dimension in terms
!>          of tensor elements (block sizes don't need to be the same).
!>          If the `order` argument is present, sizes must match after index permutation.
!>        OR
!>        * the target tensor is not yet created; in this case an exact copy of the source tensor is
!>          returned.
!> \param tensor_in Source
!> \param tensor_out Target
!> \param order Permutation of the target tensor index.
!>              Same convention as the ORDER argument of the RESHAPE intrinsic.
!> \param bounds crop tensor data: start and end index for each tensor dimension
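!>
!> A minimal usage sketch (hypothetical, not part of this file): t_in and t_out are rank-3
!> tensors created elsewhere with element sizes that match after permutation, and my_bounds
!> is an INTEGER array of shape (2, 3) holding start/end indices per dimension:
!>
!>   CALL dbt_copy(t_in, t_out, order=[2, 3, 1])     ! redistribute with index permutation
!>   CALL dbt_copy(t_in, t_out, bounds=my_bounds)    ! copy only a cropped index range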
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: order
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: bounds
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER :: handle

      CALL tensor_in%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)

      ! make sure that it is safe to use dbt_copy during a batched contraction
      CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.true.)
      CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.true.)

      CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      CALL tensor_in%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)
   END SUBROUTINE

! **************************************************************************************************
!> \brief Expert routine for copying a tensor. For internal use only.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: order
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: bounds
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr

      TYPE(dbt_type), POINTER :: in_tmp_1, in_tmp_2, &
                                 in_tmp_3, out_tmp_1
      INTEGER :: handle, unit_nr_prv
      INTEGER, DIMENSION(:), ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
      LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
                 summation_prv, new_in_1, new_in_2, &
                 new_in_3, new_out_1, block_compatible, &
                 move_prv
      TYPE(array_list) :: blk_sizes_in

      CALL timeset(routineN, handle)

      CPASSERT(tensor_out%valid)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .false.
      END IF

      dist_compatible_tas = .false.
      dist_compatible_tensor = .false.
      block_compatible = .false.
      new_in_1 = .false.
      new_in_2 = .false.
      new_in_3 = .false.
      new_out_1 = .false.

      IF (PRESENT(summation)) THEN
         summation_prv = summation
      ELSE
         summation_prv = .false.
      END IF

      IF (PRESENT(bounds)) THEN
         ALLOCATE (in_tmp_1)
         CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
         new_in_1 = .true.
         move_prv = .true.
      ELSE
         in_tmp_1 => tensor_in
      END IF

      IF (PRESENT(order)) THEN
         CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
         block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
      ELSE
         block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
      END IF

      IF (.NOT. block_compatible) THEN
         ALLOCATE (in_tmp_2, out_tmp_1)
         CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
                                         nodata2=.NOT. summation_prv, move_data=move_prv)
         new_in_2 = .true.; new_out_1 = .true.
         move_prv = .true.
      ELSE
         in_tmp_2 => in_tmp_1
         out_tmp_1 => tensor_out
      END IF

      IF (PRESENT(order)) THEN
         ALLOCATE (in_tmp_3)
         CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
         new_in_3 = .true.
      ELSE
         in_tmp_3 => in_tmp_2
      END IF

      ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
      ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
      CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)

      ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
      ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
      CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)

      IF (.NOT. PRESENT(order)) THEN
         IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
            dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
            dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         END IF
      END IF

      IF (dist_compatible_tas) THEN
         CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSEIF (dist_compatible_tensor) THEN
         CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSE
         CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
      END IF

      IF (new_in_1) THEN
         CALL dbt_destroy(in_tmp_1)
         DEALLOCATE (in_tmp_1)
      END IF

      IF (new_in_2) THEN
         CALL dbt_destroy(in_tmp_2)
         DEALLOCATE (in_tmp_2)
      END IF

      IF (new_in_3) THEN
         CALL dbt_destroy(in_tmp_3)
         DEALLOCATE (in_tmp_3)
      END IF

      IF (new_out_1) THEN
         IF (unit_nr_prv /= 0) THEN
            CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
         END IF
         CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
         CALL dbt_destroy(out_tmp_1)
         DEALLOCATE (out_tmp_1)
      END IF

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief copy without communication, requires that both tensors have the same process grid
!>        and distribution
!> \param summation Whether to sum the tensors (b = a + b)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbt_iterator_type) :: iter
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
      TYPE(block_nd) :: blk_data
      LOGICAL :: found

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
      INTEGER :: handle

      CALL timeset(routineN, handle)
      CPASSERT(tensor_out%valid)

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbt_clear(tensor_out)
      ELSE
         CALL dbt_clear(tensor_out)
      END IF

      CALL dbt_reserve_blocks(tensor_in, tensor_out)

!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
!$OMP PRIVATE(iter,ind_nd,blk_data,found)
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_nd)
         CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
         CPASSERT(found)
         CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
         CALL destroy_block(blk_data)
      END DO
      CALL dbt_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)
   END SUBROUTINE

! **************************************************************************************************
!> \brief copy matrix to tensor
!> \param summation tensor_out = tensor_out + matrix_in
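!>
!> A minimal usage sketch (hypothetical names): m is a DBCSR matrix and t2 a matching rank-2
!> tensor created elsewhere; dbt_copy_tensor_to_matrix below provides the reverse operation:
!>
!>   CALL dbt_copy_matrix_to_tensor(m, t2)
!>   CALL dbt_copy_tensor_to_matrix(t2, m, summation=.false.)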
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
      TYPE(dbcsr_type), TARGET, INTENT(IN) :: matrix_in
      TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbcsr_type), POINTER :: matrix_in_desym

      INTEGER, DIMENSION(2) :: ind_2d
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: block_arr
      REAL(kind=dp), DIMENSION(:, :), POINTER :: block
      TYPE(dbcsr_iterator_type) :: iter
      LOGICAL :: tr

      INTEGER :: handle
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'

      CALL timeset(routineN, handle)
      CPASSERT(tensor_out%valid)

      NULLIFY (block)

      IF (dbcsr_has_symmetry(matrix_in)) THEN
         ALLOCATE (matrix_in_desym)
         CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
      ELSE
         matrix_in_desym => matrix_in
      END IF

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbt_clear(tensor_out)
      ELSE
         CALL dbt_clear(tensor_out)
      END IF

      CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)

!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
!$OMP PRIVATE(iter,ind_2d,block,tr,block_arr)
      CALL dbcsr_iterator_start(iter, matrix_in_desym)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
         CALL allocate_any(block_arr, source=block)
         CALL dbt_put_block(tensor_out, ind_2d, shape(block_arr), block_arr, summation=summation)
         DEALLOCATE (block_arr)
      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      IF (dbcsr_has_symmetry(matrix_in)) THEN
         CALL dbcsr_release(matrix_in_desym)
         DEALLOCATE (matrix_in_desym)
      END IF

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief copy tensor to matrix
!> \param summation matrix_out = matrix_out + tensor_in
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbt_iterator_type) :: iter
      INTEGER :: handle
      INTEGER, DIMENSION(2) :: ind_2d
      REAL(kind=dp), DIMENSION(:, :), ALLOCATABLE :: block
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
      LOGICAL :: found

      CALL timeset(routineN, handle)

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
      ELSE
         CALL dbcsr_clear(matrix_out)
      END IF

      CALL dbt_reserve_blocks(tensor_in, matrix_out)

!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
!$OMP PRIVATE(iter,ind_2d,block,found)
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_2d)
         IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE

         CALL dbt_get_block(tensor_in, ind_2d, block, found)
         CPASSERT(found)

         IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
            CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), transpose(block), summation=summation)
         ELSE
            CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
         END IF
         DEALLOCATE (block)
      END DO
      CALL dbt_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief Contract tensors by multiplying matrix representations.
!>        tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
!>                                        * tensor_2(contract_2, notcontract_2)
!>                                + beta * tensor_3(map_1, map_2)
!>
!> \note
!>      note 1: block sizes of the corresponding indices need to be the same in all tensors.
!>
!>      note 2: for best performance the tensors should have been created in matrix layouts
!>      compatible with the contraction, e.g. tensor_1 should have been created with either
!>      map1_2d == contract_1 and map2_2d == notcontract_1, or map1_2d == notcontract_1 and
!>      map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
!>      tensor_3 and map_1 / map_2).
!>      Furthermore, the two largest tensors involved in the contraction should both map to either
!>      tall or short matrices: the largest matrix dimension should be "on the same side"
!>      and should have identical distribution (which is always the case if the distributions were
!>      obtained with dbt_default_distvec).
!>
!>      note 3: if the same tensor occurs in multiple contractions, a different tensor object should
!>      be created for each contraction and the data should be copied between the tensors by use of
!>      dbt_copy. If the same tensor object is used in multiple contractions,
!>      matrix layouts are not compatible for all contractions (see note 2).
!>
!>      note 4: automatic optimizations are enabled by using the feature of batched contraction, see
!>      dbt_batched_contract_init, dbt_batched_contract_finalize.
!>      The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
!>
!> \param tensor_1 first tensor (in)
!> \param tensor_2 second tensor (in)
!> \param contract_1 indices of tensor_1 to contract
!> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
!> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
!> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
!> \param notcontract_1 indices of tensor_1 not to contract
!> \param notcontract_2 indices of tensor_2 not to contract
!> \param tensor_3 contracted tensor (out)
!> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
!>        start and end index of an index range over which to contract.
!>        For use in batched contraction.
!> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
!>        For use in batched contraction.
!> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
!>        For use in batched contraction.
!> \param optimize_dist Whether the distribution should be optimized internally. In the current
!>        implementation this guarantees optimal parameters only for dense matrices.
!> \param pgrid_opt_1 Optionally return the optimal process grid for tensor_1.
!>        This can be used to choose optimal process grids for subsequent tensor
!>        contractions with tensors of similar shape and sparsity. Under some conditions,
!>        pgrid_opt_1 cannot be returned; in this case the pointer is not associated.
!> \param pgrid_opt_2 Optionally return the optimal process grid for tensor_2.
!> \param pgrid_opt_3 Optionally return the optimal process grid for tensor_3.
!> \param filter_eps As in DBM mm
!> \param flop As in DBM mm
!> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
!> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
!> \param unit_nr output unit for logging:
!>        set it to -1 on ranks that should not write (and any valid unit number on
!>        ranks that should write output); if 0 on ALL ranks, no output is written
!> \param log_verbose verbose logging (for testing only)
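!>
!> A minimal usage sketch (hypothetical tensors t1, t2, t3 created elsewhere): contract index 2
!> of the rank-2 tensor t1 with index 1 of the rank-2 tensor t2, i.e.
!> t3(a,c) = sum_b t1(a,b) * t2(b,c):
!>
!>   CALL dbt_contract(alpha=1.0_dp, tensor_1=t1, tensor_2=t2, &
!>                     beta=0.0_dp, tensor_3=t3, &
!>                     contract_1=[2], notcontract_1=[1], &
!>                     contract_2=[1], notcontract_2=[2], &
!>                     map_1=[1], map_2=[2], filter_eps=1.0E-9_dp)
!>
!> For batched contraction (note 4), such calls are bracketed by dbt_batched_contract_init and
!> dbt_batched_contract_finalize on the tensors involved, with bounds_1/2/3 selecting the index
!> ranges of each batch.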
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
                           contract_1, notcontract_1, &
                           contract_2, notcontract_2, &
                           map_1, map_2, &
                           bounds_1, bounds_2, bounds_3, &
                           optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
                           filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
      REAL(dp), INTENT(IN) :: alpha
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
      REAL(dp), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_1
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_2
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_3
      REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
      LOGICAL, INTENT(IN), OPTIONAL :: move_data
      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL :: log_verbose

      INTEGER :: handle

      CALL tensor_1%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)
      CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
                               contract_1, notcontract_1, &
                               contract_2, notcontract_2, &
                               map_1, map_2, &
                               bounds_1=bounds_1, &
                               bounds_2=bounds_2, &
                               bounds_3=bounds_3, &
                               optimize_dist=optimize_dist, &
                               pgrid_opt_1=pgrid_opt_1, &
                               pgrid_opt_2=pgrid_opt_2, &
                               pgrid_opt_3=pgrid_opt_3, &
                               filter_eps=filter_eps, &
                               flop=flop, &
                               move_data=move_data, &
                               retain_sparsity=retain_sparsity, &
                               unit_nr=unit_nr, &
                               log_verbose=log_verbose)
      CALL tensor_1%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief Expert routine for tensor contraction. For internal use only.
!> \param nblks_local number of local blocks on this MPI rank
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
                                  contract_1, notcontract_1, &
                                  contract_2, notcontract_2, &
                                  map_1, map_2, &
                                  bounds_1, bounds_2, bounds_3, &
                                  optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
                                  filter_eps, flop, move_data, retain_sparsity, &
                                  nblks_local, unit_nr, log_verbose)
      REAL(dp), INTENT(IN) :: alpha
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
      REAL(dp), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_1
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_2
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_3
      REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
      LOGICAL, INTENT(IN), OPTIONAL :: move_data
      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      INTEGER, INTENT(OUT), OPTIONAL :: nblks_local
      INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL :: log_verbose

      TYPE(dbt_type), POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3
      TYPE(dbt_type), TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3
      TYPE(dbt_type), POINTER :: tensor_crop_1, tensor_crop_2
      TYPE(dbt_type), POINTER :: tensor_small, tensor_large

      LOGICAL :: assert_stmt, tensors_remapped
      INTEGER :: max_mm_dim, max_tensor, &
                 unit_nr_prv, ref_tensor, handle
      TYPE(mp_cart_type) :: mp_comm_opt
      INTEGER, DIMENSION(SIZE(contract_1)) :: contract_1_mod
      INTEGER, DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod
      INTEGER, DIMENSION(SIZE(contract_2)) :: contract_2_mod
      INTEGER, DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod
      INTEGER, DIMENSION(SIZE(map_1)) :: map_1_mod
      INTEGER, DIMENSION(SIZE(map_2)) :: map_2_mod
      LOGICAL :: trans_1, trans_2, trans_3
      LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2
      INTEGER :: ndims1, ndims2, ndims3
      INTEGER :: occ_1, occ_2
      INTEGER, DIMENSION(:), ALLOCATABLE :: dims1, dims2, dims3

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
      CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, &
                                                     indchar2_mod, indchar3_mod
      CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
                                               ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
      LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
                 pgrid_changed_any, do_change_pgrid(2)
      TYPE(dbt_tas_split_info) :: split_opt, split, split_opt_avg
      INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
      REAL(dp) :: pdim_ratio, pdim_ratio_opt

      NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
               tensor_small)

      CALL timeset(routineN, handle)

      CPASSERT(tensor_1%valid)
      CPASSERT(tensor_2%valid)
      CPASSERT(tensor_3%valid)

      assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
      CPASSERT(assert_stmt)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(flop)) flop = 0
      IF (PRESENT(nblks_local)) nblks_local = 0

      IF (PRESENT(move_data)) THEN
         move_data_1 = move_data
         move_data_2 = move_data
      ELSE
         move_data_1 = .false.
         move_data_2 = .false.
      END IF

      nodata_3 = .true.
      IF (PRESENT(retain_sparsity)) THEN
         IF (retain_sparsity) nodata_3 = .false.
      END IF

      CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
                                     contract_1, notcontract_1, &
                                     contract_2, notcontract_2, &
                                     bounds_t1, bounds_t2, &
                                     bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
                                     do_crop_1=do_crop_1, do_crop_2=do_crop_2)

      IF (do_crop_1) THEN
         ALLOCATE (tensor_crop_1)
         CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
         move_data_1 = .true.
      ELSE
         tensor_crop_1 => tensor_1
      END IF

      IF (do_crop_2) THEN
         ALLOCATE (tensor_crop_2)
         CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
         move_data_2 = .true.
      ELSE
         tensor_crop_2 => tensor_2
      END IF

      ! Shortcut for empty tensors. This is needed to avoid unnecessary work in case the user
      ! contracts different portions of a tensor consecutively to save memory.
      ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
         occ_1 = dbt_get_num_blocks(tensor_crop_1)
         CALL mp_comm%max(occ_1)
         occ_2 = dbt_get_num_blocks(tensor_crop_2)
         CALL mp_comm%max(occ_2)
      END ASSOCIATE

      IF (occ_1 == 0 .OR. occ_2 == 0) THEN
         CALL dbt_scale(tensor_3, beta)
         IF (do_crop_1) THEN
            CALL dbt_destroy(tensor_crop_1)
            DEALLOCATE (tensor_crop_1)
         END IF
         IF (do_crop_2) THEN
            CALL dbt_destroy(tensor_crop_2)
            DEALLOCATE (tensor_crop_2)
         END IF

         CALL timestop(handle)
         RETURN
      END IF

      IF (unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(A)') repeat("-", 80)
            WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
               trim(tensor_crop_1%name), 'x', trim(tensor_crop_2%name), '=', trim(tensor_3%name)
            WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         END IF
         CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
         CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
         CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
         CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
      END IF

      ! align tensor index with data, tensor data is not modified
      ndims1 = ndims_tensor(tensor_crop_1)
      ndims2 = ndims_tensor(tensor_crop_2)
      ndims3 = ndims_tensor(tensor_3)
      ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
      ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
      ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))

      ! label tensor indices with letters

      indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arbitrary choice
      indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arbitrary choice
      indchar2(contract_2) = indchar1(contract_1)
      indchar3(map_1) = indchar1(notcontract_1)
      indchar3(map_2) = indchar2(notcontract_2)

      IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
                                                             tensor_crop_2, indchar2, &
                                                             tensor_3, indchar3, unit_nr_prv)
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
      END IF

      CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
                        tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)

      CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
                        tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)

      CALL align_tensor(tensor_3, map_1, map_2, &
                        tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)

      IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
                                                             tensor_algn_2, indchar2_mod, &
                                                             tensor_algn_3, indchar3_mod, unit_nr_prv)

      ALLOCATE (dims1(ndims1))
      ALLOCATE (dims2(ndims2))
      ALLOCATE (dims3(ndims3))

      ! Ideally we should consider block sizes and occupancy to measure tensor sizes, but the
      ! current solution should work for most cases and is more elegant. Note that we cannot
      ! easily consider occupancy since it is unknown for the result tensor.
      CALL blk_dims_tensor(tensor_crop_1, dims1)
      CALL blk_dims_tensor(tensor_crop_2, dims2)
      CALL blk_dims_tensor(tensor_3, dims3)

      max_mm_dim = maxloc([product(int(dims1(notcontract_1), int_8)), &
                           product(int(dims1(contract_1), int_8)), &
                           product(int(dims2(notcontract_2), int_8))], dim=1)
      max_tensor = maxloc([product(int(dims1, int_8)), product(int(dims2, int_8)), product(int(dims3, int_8))], dim=1)
      SELECT CASE (max_mm_dim)
      CASE (1)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
         END IF
         CALL index_linked_sort(contract_1_mod, contract_2_mod)
         CALL index_linked_sort(map_2_mod, notcontract_2_mod)
         SELECT CASE (max_tensor)
         CASE (1)
            CALL index_linked_sort(notcontract_1_mod, map_1_mod)
         CASE (3)
            CALL index_linked_sort(map_1_mod, notcontract_1_mod)
         CASE DEFAULT
            CPABORT("should not happen")
         END SELECT

         CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
                                    contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
                                    trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
                                    move_data_1=move_data_1, unit_nr=unit_nr_prv)

         CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
                               new_2, move_data=move_data_2, unit_nr=unit_nr_prv)

         SELECT CASE (ref_tensor)
         CASE (1)
            tensor_large => tensor_contr_1
         CASE (2)
            tensor_large => tensor_contr_3
         END SELECT
         tensor_small => tensor_contr_2

      CASE (2)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
         END IF

         CALL index_linked_sort(notcontract_1_mod, map_1_mod)
         CALL index_linked_sort(notcontract_2_mod, map_2_mod)
         SELECT CASE (max_tensor)
         CASE (1)
            CALL index_linked_sort(contract_1_mod, contract_2_mod)
         CASE (2)
            CALL index_linked_sort(contract_2_mod, contract_1_mod)
         CASE DEFAULT
            CPABORT("should not happen")
         END SELECT

         CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
                                    notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
                                    trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
                                    move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
         trans_1 = .NOT. trans_1

         CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
                               new_3, nodata=nodata_3, unit_nr=unit_nr_prv)

         SELECT CASE (ref_tensor)
         CASE (1)
            tensor_large => tensor_contr_1
         CASE (2)
            tensor_large => tensor_contr_2
         END SELECT
         tensor_small => tensor_contr_3

      CASE (3)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
         END IF
         CALL index_linked_sort(map_1_mod, notcontract_1_mod)
         CALL index_linked_sort(contract_2_mod, contract_1_mod)
         SELECT CASE (max_tensor)
         CASE (2)
            CALL index_linked_sort(notcontract_2_mod, map_2_mod)
         CASE (3)
            CALL index_linked_sort(map_2_mod, notcontract_2_mod)
         CASE DEFAULT
            CPABORT("should not happen")
         END SELECT

         CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
                                    contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
                                    trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
                                    move_data_1=move_data_2, unit_nr=unit_nr_prv)

         trans_2 = .NOT. trans_2
         trans_3 = .NOT. trans_3

         CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
                               trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)

         SELECT CASE (ref_tensor)
         CASE (1)
            tensor_large => tensor_contr_2
         CASE (2)
            tensor_large => tensor_contr_3
         END SELECT
         tensor_small => tensor_contr_1

      END SELECT

      IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
                                                             tensor_contr_2, indchar2_mod, &
                                                             tensor_contr_3, indchar3_mod, unit_nr_prv)
      IF (unit_nr_prv /= 0) THEN
         IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
         IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
         IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
         IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
      END IF

      CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
                            tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
                            beta, &
                            tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
                            unit_nr=unit_nr_prv, log_verbose=log_verbose, &
                            split_opt=split_opt, &
                            move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)

      IF (PRESENT(pgrid_opt_1)) THEN
         IF (.NOT. new_1) THEN
            ALLOCATE (pgrid_opt_1)
            pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
         END IF
      END IF

      IF (PRESENT(pgrid_opt_2)) THEN
         IF (.NOT. new_2) THEN
            ALLOCATE (pgrid_opt_2)
            pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
         END IF
      END IF

      IF (PRESENT(pgrid_opt_3)) THEN
         IF (.NOT. new_3) THEN
            ALLOCATE (pgrid_opt_3)
            pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
         END IF
      END IF

      do_batched = tensor_small%matrix_rep%do_batched > 0

      tensors_remapped = .false.
      IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .true.

      IF (tensors_remapped .AND. do_batched) THEN
         CALL cp_warn(__LOCATION__, &
                      "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
      END IF

      ! optimize process grid during batched contraction
      do_change_pgrid(:) = .false.
      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
         ASSOCIATE (storage => tensor_small%contraction_storage)
            CPASSERT(storage%static)
            split = dbt_tas_info(tensor_large%matrix_rep)
            do_change_pgrid(:) = &
               update_contraction_storage(storage, split_opt, split)

            IF (any(do_change_pgrid)) THEN
               mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, nint(storage%nsplit_avg))
               CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
                                         nint(storage%nsplit_avg), own_comm=.true.)
               pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
            END IF

         END ASSOCIATE

         IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
            ! check if the new grid has a better subgrid; if not, there is no need to change the process grid
            pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
            pdims_sub = split%mp_comm_group%num_pe_cart

            pdim_ratio = maxval(real(pdims_sub, dp))/minval(pdims_sub)
            pdim_ratio_opt = maxval(real(pdims_sub_opt, dp))/minval(pdims_sub_opt)
            IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
               do_change_pgrid(1) = .false.
               CALL dbt_tas_release_info(split_opt_avg)
            END IF
         END IF
      END IF

      IF (unit_nr_prv /= 0) THEN
         do_write_3 = .true.
         IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
            IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .false.
         END IF
         IF (do_write_3) THEN
            CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
            CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
         END IF
      END IF

      IF (new_3) THEN
         ! need to redistribute if we created a new tensor for tensor 3
         CALL dbt_scale(tensor_algn_3, beta)
         CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.true., move_data=.true.)
         IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
         ! tensor_3 automatically has the correct data because tensor_algn_3 contains a matrix
         ! pointer to the data of tensor_3
      END IF

      ! transfer contraction storage
      CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
      CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
      CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)

      IF (unit_nr_prv /= 0) THEN
         IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
         IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
      END IF

      CALL dbt_destroy(tensor_algn_1)
      CALL dbt_destroy(tensor_algn_2)
      CALL dbt_destroy(tensor_algn_3)

      IF (do_crop_1) THEN
         CALL dbt_destroy(tensor_crop_1)
         DEALLOCATE (tensor_crop_1)
      END IF

      IF (do_crop_2) THEN
         CALL dbt_destroy(tensor_crop_2)
         DEALLOCATE (tensor_crop_2)
      END IF

      IF (new_1) THEN
         CALL dbt_destroy(tensor_contr_1)
         DEALLOCATE (tensor_contr_1)
      END IF
      IF (new_2) THEN
         CALL dbt_destroy(tensor_contr_2)
         DEALLOCATE (tensor_contr_2)
      END IF
      IF (new_3) THEN
         CALL dbt_destroy(tensor_contr_3)
         DEALLOCATE (tensor_contr_3)
      END IF

      IF (PRESENT(move_data)) THEN
         IF (move_data) THEN
            CALL dbt_clear(tensor_1)
            CALL dbt_clear(tensor_2)
         END IF
      END IF

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
      END IF

      IF (any(do_change_pgrid)) THEN
         pgrid_changed_any = .false.
         SELECT CASE (max_mm_dim)
         CASE (1)
            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
               CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
               CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
            END IF
            IF (pgrid_changed_any) THEN
               IF (tensor_2%matrix_rep%do_batched == 3) THEN
                  ! set flag that the process grid has been optimized, to make sure that no grid
                  ! optimizations are done in the TAS multiply algorithm
                  CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
               END IF
            END IF
         CASE (2)
            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
               CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
               CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
            END IF
            IF (pgrid_changed_any) THEN
               IF (tensor_3%matrix_rep%do_batched == 3) THEN
                  CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
               END IF
            END IF
         CASE (3)
            IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
               CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
               CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
            END IF
            IF (pgrid_changed_any) THEN
               IF (tensor_1%matrix_rep%do_batched == 3) THEN
                  CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
               END IF
            END IF
         END SELECT
         CALL dbt_tas_release_info(split_opt_avg)
      END IF

      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
         ! freeze TAS process grids if tensor grids were optimized
         CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.true.)
         CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.true.)
         CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.true.)
      END IF

      CALL dbt_tas_release_info(split_opt)

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief align tensor index with data
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
                           tensor_out, contract_out, notcontract_out, indp_in, indp_out)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_in, notcontract_in
      TYPE(dbt_type), INTENT(OUT) :: tensor_out
      INTEGER, DIMENSION(SIZE(contract_in)), &
         INTENT(OUT) :: contract_out
      INTEGER, DIMENSION(SIZE(notcontract_in)), &
         INTENT(OUT) :: notcontract_out
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align

      CALL dbt_align_index(tensor_in, tensor_out, order=align)
      contract_out = align(contract_in)
      notcontract_out = align(notcontract_in)
      indp_out(align) = indp_in

   END SUBROUTINE

! **************************************************************************************************
!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
!>        matrix multiplication. This routine reshapes the two largest of the three tensors.
!>        Redistribution is avoided if the tensors are already in a consistent layout.
!> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
!> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
!> \param ind2_free indices of tensor 2 that are "free" (not linked to any index of tensor 1)
!> \param ind2_linked indices of tensor 2 that are linked to indices of tensor 1
!>        (1:1 correspondence with ind1_linked)
!> \param trans1 transpose flag of matrix rep. of tensor 1
!> \param trans2 transpose flag of matrix rep. of tensor 2
!> \param new1 whether a new tensor 1 was created
!> \param new2 whether a new tensor 2 was created
!> \param nodata1 don't copy data of tensor 1
!> \param nodata2 don't copy data of tensor 2
!> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
!> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
!> \param optimize_dist experimental: optimize distribution
!> \param unit_nr output unit
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
                                    ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
                                    nodata1, nodata2, move_data_1, &
                                    move_data_2, optimize_dist, unit_nr)
      TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor1
      TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor2
      TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor1_out, tensor2_out
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1_free, ind2_free
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1_linked, ind2_linked
      LOGICAL, INTENT(OUT) :: trans1, trans2
      LOGICAL, INTENT(OUT) :: new1, new2
      INTEGER, INTENT(OUT) :: ref_tensor
      LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
      LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER :: compat1, compat1_old, compat2, compat2_old, &
                 unit_nr_prv
      TYPE(mp_cart_type) :: comm_2d
      TYPE(array_list) :: dist_list
      INTEGER, DIMENSION(:), ALLOCATABLE :: mp_dims
      TYPE(dbt_distribution_type) :: dist_in
      INTEGER(KIND=int_8) :: nblkrows, nblkcols
      LOGICAL :: optimize_dist_prv
      INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
      INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2

      NULLIFY (tensor1_out, tensor2_out)

      unit_nr_prv = prep_output_unit(unit_nr)

      CALL blk_dims_tensor(tensor1, dims1)
      CALL blk_dims_tensor(tensor2, dims2)

      IF (product(int(dims1, int_8)) .GE. product(int(dims2, int_8))) THEN
         ref_tensor = 1
      ELSE
         ref_tensor = 2
      END IF

      IF (PRESENT(optimize_dist)) THEN
         optimize_dist_prv = optimize_dist
      ELSE
         optimize_dist_prv = .false.
      END IF

      compat1 = compat_map(tensor1%nd_index, ind1_linked)
      compat2 = compat_map(tensor2%nd_index, ind2_linked)
      compat1_old = compat1
      compat2_old = compat2

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor1%name), ":"
         SELECT CASE (compat1)
         CASE (0)
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         CASE (1)
            WRITE (unit_nr_prv, '(A)') "Normal"
         CASE (2)
            WRITE (unit_nr_prv, '(A)') "Transposed"
         END SELECT
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor2%name), ":"
         SELECT CASE (compat2)
         CASE (0)
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         CASE (1)
            WRITE (unit_nr_prv, '(A)') "Normal"
         CASE (2)
            WRITE (unit_nr_prv, '(A)') "Transposed"
         END SELECT
      END IF

      new1 = .false.
      new2 = .false.

      IF (compat1 == 0 .OR. optimize_dist_prv) THEN
         new1 = .true.
      END IF

      IF (compat2 == 0 .OR. optimize_dist_prv) THEN
         new2 = .true.
      END IF

      IF (ref_tensor == 1) THEN ! tensor 1 is the reference and tensor 2 is reshaped compatible with tensor 1
         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor1%name)
            nblkrows = product(int(dims1(ind1_linked), kind=int_8))
            nblkcols = product(int(dims1(ind1_free), kind=int_8))
            comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
            ALLOCATE (tensor1_out)
            CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
                           nodata=nodata1, move_data=move_data_1)
            CALL comm_2d%free()
            compat1 = 1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor1%name)
            tensor1_out => tensor1
         END IF
         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
               trim(tensor2%name), "compatible with", trim(tensor1%name)
            dist_in = dbt_distribution(tensor1_out)
            dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
            IF (compat1 == 1) THEN ! linked index is the first 2d dimension
               ! get the distribution of the linked index, tensor 2 must adopt this distribution
               ! get the grid dimensions of the linked index
               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
               ALLOCATE (tensor2_out)
               CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
            ELSEIF (compat1 == 2) THEN ! linked index is the second 2d dimension
               ! get the distribution of the linked index, tensor 2 must adopt this distribution
               ! get the grid dimensions of the linked index
               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
               ALLOCATE (tensor2_out)
               CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
            ELSE
               CPABORT("should not happen")
            END IF
            compat2 = compat1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor2%name)
            tensor2_out => tensor2
         END IF
      ELSE ! tensor 2 is the reference and tensor 1 is reshaped compatible with tensor 2
         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor2%name)
            nblkrows = product(int(dims2(ind2_linked), kind=int_8))
            nblkcols = product(int(dims2(ind2_free), kind=int_8))
            comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
            ALLOCATE (tensor2_out)
            CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=comm_2d, &
                           nodata=nodata2, move_data=move_data_2)
            CALL comm_2d%free()
            compat2 = 1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor2%name)
            tensor2_out => tensor2
         END IF
         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", trim(tensor1%name), &
               "compatible with", trim(tensor2%name)
            dist_in = dbt_distribution(tensor2_out)
            dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
            IF (compat2 == 1) THEN
               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
               ALLOCATE (tensor1_out)
               CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
            ELSEIF (compat2 == 2) THEN
               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
               ALLOCATE (tensor1_out)
               CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
            ELSE
               CPABORT("should not happen")
            END IF
            compat1 = compat2
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor1%name)
            tensor1_out => tensor1
         END IF
      END IF

      SELECT CASE (compat1)
      CASE (1)
         trans1 = .false.
      CASE (2)
         trans1 = .true.
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT

      SELECT CASE (compat2)
      CASE (1)
         trans2 = .false.
      CASE (2)
         trans2 = .true.
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT

      IF (unit_nr_prv > 0) THEN
         IF (compat1 .NE. compat1_old) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor1_out%name), ":"
            SELECT CASE (compat1)
            CASE (0)
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            CASE (1)
               WRITE (unit_nr_prv, '(A)') "Normal"
            CASE (2)
               WRITE (unit_nr_prv, '(A)') "Transposed"
            END SELECT
         END IF
         IF (compat2 .NE. compat2_old) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor2_out%name), ":"
            SELECT CASE (compat2)
            CASE (0)
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            CASE (1)
               WRITE (unit_nr_prv, '(A)') "Normal"
            CASE (2)
               WRITE (unit_nr_prv, '(A)') "Transposed"
            END SELECT
         END IF
      END IF

      IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .true.
      IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .true.

   END SUBROUTINE

1381 ! **************************************************************************************************
1382 !> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1383 !> matrix multiplication. This routine reshapes the smallest of the three tensors.
1384 !> \param ind1 index that should be mapped to first matrix dimension
1385 !> \param ind2 index that should be mapped to second matrix dimension
1386 !> \param trans transpose flag of matrix rep.
1387 !> \param new whether a new tensor was created for tensor_out
1388 !> \param nodata don't copy tensor data
1389 !> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
1390 !> \param unit_nr output unit
1391 !> \author Patrick Seewald
1392 ! **************************************************************************************************
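! Illustrative example: if tensor_in has the matrix mapping (1) x (2,3), i.e.
! map1_2d = [1] and map2_2d = [2, 3], and the requested layout is ind1 = [2, 3],
! ind2 = [1], then compat_map yields compat1 = 2 and compat2 = 1: no
! redistribution is needed and the existing matrix representation is used with
! trans = .true.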
1393  SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
1394  TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor_in
1395  INTEGER, DIMENSION(:), INTENT(IN) :: ind1, ind2
1396  TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor_out
1397  LOGICAL, INTENT(OUT) :: trans
1398  LOGICAL, INTENT(OUT) :: new
1399  LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1400  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1401  INTEGER :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
1402  LOGICAL :: nodata_prv
1403 
1404  NULLIFY (tensor_out)
1405  IF (PRESENT(nodata)) THEN
1406  nodata_prv = nodata
1407  ELSE
1408  nodata_prv = .false.
1409  END IF
1410 
1411  unit_nr_prv = prep_output_unit(unit_nr)
1412 
1413  new = .false.
1414  compat1 = compat_map(tensor_in%nd_index, ind1)
1415  compat2 = compat_map(tensor_in%nd_index, ind2)
1416  compat1_old = compat1; compat2_old = compat2
1417  IF (unit_nr_prv > 0) THEN
1418  WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_in%name), ":"
1419  IF (compat1 == 1 .AND. compat2 == 2) THEN
1420  WRITE (unit_nr_prv, '(A)') "Normal"
1421  ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1422  WRITE (unit_nr_prv, '(A)') "Transposed"
1423  ELSE
1424  WRITE (unit_nr_prv, '(A)') "Not compatible"
1425  END IF
1426  END IF
1427  IF (compat1 == 0 .OR. compat2 == 0) THEN ! index mapping not compatible with contract index
1428 
1429  IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor_in%name)
1430 
1431  ALLOCATE (tensor_out)
1432  CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
1433  CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
1434  compat1 = 1
1435  compat2 = 2
1436  new = .true.
1437  ELSE
1438  IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor_in%name)
1439  tensor_out => tensor_in
1440  END IF
1441 
1442  IF (compat1 == 1 .AND. compat2 == 2) THEN
1443  trans = .false.
1444  ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1445  trans = .true.
1446  ELSE
1447  cpabort("this should not happen")
1448  END IF
1449 
1450  IF (unit_nr_prv > 0) THEN
1451  IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
1452  WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_out%name), ":"
1453  IF (compat1 == 1 .AND. compat2 == 2) THEN
1454  WRITE (unit_nr_prv, '(A)') "Normal"
1455  ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1456  WRITE (unit_nr_prv, '(A)') "Transposed"
1457  ELSE
1458  WRITE (unit_nr_prv, '(A)') "Not compatible"
1459  END IF
1460  END IF
1461  END IF
1462 
1463  END SUBROUTINE
1464 
1465 ! **************************************************************************************************
1466 !> \brief update contraction storage that keeps track of process grids during a batched contraction
1467 !> and decide whether the tensor process grid needs to be optimized
!> \param storage contraction storage of the tensor
1468 !> \param split_opt optimized TAS process grid
1469 !> \param split current TAS process grid
1470 !> \author Patrick Seewald
1471 ! **************************************************************************************************
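! Worked example (threshold values assumed for illustration): for a subgroup grid
! pdims_sub = [6, 4], change_criterion = 6/4 = 1.5; assuming
! default_pdims_accept_ratio = 1.2, the threshold is 1.2**2 = 1.44 and
! do_change_pgrid(1) is set. For the split factor: with running average
! nsplit_avg = 2.0 and current nsplit = 5, change_criterion = max(5/2, 2/5) = 2.5,
! and do_change_pgrid(2) is set whenever this exceeds default_nsplit_accept_ratio.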
1472  FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
1473  TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
1474  TYPE(dbt_tas_split_info), INTENT(IN) :: split_opt
1475  TYPE(dbt_tas_split_info), INTENT(IN) :: split
1476  INTEGER, DIMENSION(2) :: pdims, pdims_sub
1477  LOGICAL, DIMENSION(2) :: do_change_pgrid
1478  REAL(kind=dp) :: change_criterion, pdims_ratio
1479  INTEGER :: nsplit_opt, nsplit
1480 
1481  cpassert(ALLOCATED(split_opt%ngroup_opt))
1482  nsplit_opt = split_opt%ngroup_opt
1483  nsplit = split%ngroup
1484 
1485  pdims = split%mp_comm%num_pe_cart
1486 
1487  storage%ibatch = storage%ibatch + 1
1488 
1489  storage%nsplit_avg = (storage%nsplit_avg*real(storage%ibatch - 1, dp) + real(nsplit_opt, dp)) &
1490  /real(storage%ibatch, dp)
1491 
1492  SELECT CASE (split_opt%split_rowcol)
1493  CASE (rowsplit)
1494  pdims_ratio = real(pdims(1), dp)/pdims(2)
1495  CASE (colsplit)
1496  pdims_ratio = real(pdims(2), dp)/pdims(1)
1497  END SELECT
1498 
1499  do_change_pgrid(:) = .false.
1500 
1501  ! check for process grid dimensions
1502  pdims_sub = split%mp_comm_group%num_pe_cart
1503  change_criterion = maxval(real(pdims_sub, dp))/minval(pdims_sub)
1504  IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .true.
1505 
1506  ! check for split factor
1507  change_criterion = max(real(nsplit, dp)/storage%nsplit_avg, real(storage%nsplit_avg, dp)/nsplit)
1508  IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .true.
1509 
1510  END FUNCTION
1511 
1512 ! **************************************************************************************************
1513 !> \brief Check if 2d index is compatible with tensor index
1514 !> \author Patrick Seewald
1515 ! **************************************************************************************************
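! Illustrative example: for a rank-3 tensor with matrix mapping (1,2) x (3), i.e.
! map1_2d = [1, 2] and map2_2d = [3]:
!   compat_map(nd_index, [1, 2]) = 1 (matches the row mapping)
!   compat_map(nd_index, [3])    = 2 (matches the column mapping)
!   compat_map(nd_index, [2, 3]) = 0 (matches neither)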
1516  FUNCTION compat_map(nd_index, compat_ind)
1517  TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
1518  INTEGER, DIMENSION(:), INTENT(IN) :: compat_ind
1519  INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
1520  INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
1521  INTEGER :: compat_map
1522 
1523  CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
1524 
1525  compat_map = 0
1526  IF (array_eq_i(map1, compat_ind)) THEN
1527  compat_map = 1
1528  ELSEIF (array_eq_i(map2, compat_ind)) THEN
1529  compat_map = 2
1530  END IF
1531 
1532  END FUNCTION
1533 
1534 ! **************************************************************************************************
1535 !> \brief Sort the reference index ind_ref in ascending order and apply the same permutation to the linked index ind_dep
1536 !> \author Patrick Seewald
1537 ! **************************************************************************************************
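! Illustrative example: ind_ref = [3, 1, 2], ind_dep = [30, 10, 20] becomes
! ind_ref = [1, 2, 3], ind_dep = [10, 20, 30]: ind_dep is reordered by the same
! permutation that sorts ind_ref.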
1538  SUBROUTINE index_linked_sort(ind_ref, ind_dep)
1539  INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
1540  INTEGER, DIMENSION(SIZE(ind_ref)) :: sort_indices
1541 
1542  CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
1543  ind_dep(:) = ind_dep(sort_indices)
1544 
1545  END SUBROUTINE
1546 
1547 ! **************************************************************************************************
1548 !> \brief Compute an optimized process grid for a tensor, consistent with the given TAS split info
1549 !> \author Patrick Seewald
1550 ! **************************************************************************************************
1551  FUNCTION opt_pgrid(tensor, tas_split_info)
1552  TYPE(dbt_type), INTENT(IN) :: tensor
1553  TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
1554  INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
1555  INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
1556  TYPE(dbt_pgrid_type) :: opt_pgrid
1557  INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
1558 
1559  CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
1560  CALL blk_dims_tensor(tensor, dims)
1561  opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
1562 
1563  ALLOCATE (opt_pgrid%tas_split_info, source=tas_split_info)
1564  CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
1565  END FUNCTION
1566 
1567 ! **************************************************************************************************
1568 !> \brief Copy tensor to tensor with modified index mapping
1569 !> \param map1_2d new index mapping
1570 !> \param map2_2d new index mapping
1571 !> \author Patrick Seewald
1572 ! **************************************************************************************************
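! Usage sketch (illustrative; t_in is an assumed rank-3 tensor and t_out a plain
! TYPE(dbt_type) variable): remap t_in to the matrix layout (1,2) x (3), creating
! t_out with a default distribution derived from the current 2d communicator:
!   CALL dbt_remap(t_in, [1, 2], [3], t_out)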
1573  SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
1574  mp_dims_1, mp_dims_2, name, nodata, move_data)
1575  TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1576  INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
1577  TYPE(dbt_type), INTENT(OUT) :: tensor_out
1578  CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
1579  LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1580  CLASS(mp_comm_type), INTENT(IN), OPTIONAL :: comm_2d
1581  TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
1582  INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
1583  INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
1584  CHARACTER(len=default_string_length) :: name_tmp
1585  INTEGER, DIMENSION(:), ALLOCATABLE :: blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4, &
1586  nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4
1587  TYPE(dbt_distribution_type) :: dist
1588  TYPE(mp_cart_type) :: comm_2d_prv
1589  INTEGER :: handle, i
1590  INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
1591  CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_remap'
1592  LOGICAL :: nodata_prv
1593  TYPE(dbt_pgrid_type) :: comm_nd
1594 
1595  CALL timeset(routinen, handle)
1596 
1597  IF (PRESENT(name)) THEN
1598  name_tmp = name
1599  ELSE
1600  name_tmp = tensor_in%name
1601  END IF
1602  IF (PRESENT(dist1)) THEN
1603  cpassert(PRESENT(mp_dims_1))
1604  END IF
1605 
1606  IF (PRESENT(dist2)) THEN
1607  cpassert(PRESENT(mp_dims_2))
1608  END IF
1609 
1610  IF (PRESENT(comm_2d)) THEN
1611  comm_2d_prv = comm_2d
1612  ELSE
1613  comm_2d_prv = tensor_in%pgrid%mp_comm_2d
1614  END IF
1615 
1616  comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
1617  CALL mp_environ_pgrid(comm_nd, pdims, myploc)
1618 
1619  IF (ndims_tensor(tensor_in) == 2) THEN
1620  CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2)
1621  END IF
1622  IF (ndims_tensor(tensor_in) == 3) THEN
1623  CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3)
1624  END IF
1625  IF (ndims_tensor(tensor_in) == 4) THEN
1626  CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
1627  END IF
1628 
1629  IF (ndims_tensor(tensor_in) == 2) THEN
1630  IF (PRESENT(dist1)) THEN
1631  IF (any(map1_2d == 1)) THEN
1632  i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1633  CALL get_ith_array(dist1, i, nd_dist_1)
1634  END IF
1635  END IF
1636 
1637  IF (PRESENT(dist2)) THEN
1638  IF (any(map2_2d == 1)) THEN
1639  i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1640  CALL get_ith_array(dist2, i, nd_dist_1)
1641  END IF
1642  END IF
1643 
1644  IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1645  ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1646  CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1647  END IF
1648  IF (PRESENT(dist1)) THEN
1649  IF (any(map1_2d == 2)) THEN
1650  i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1651  CALL get_ith_array(dist1, i, nd_dist_2)
1652  END IF
1653  END IF
1654 
1655  IF (PRESENT(dist2)) THEN
1656  IF (any(map2_2d == 2)) THEN
1657  i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1658  CALL get_ith_array(dist2, i, nd_dist_2)
1659  END IF
1660  END IF
1661 
1662  IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1663  ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1664  CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1665  END IF
1666  CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1667  nd_dist_1, nd_dist_2, own_comm=.true.)
1668  END IF
1669  IF (ndims_tensor(tensor_in) == 3) THEN
1670  IF (PRESENT(dist1)) THEN
1671  IF (any(map1_2d == 1)) THEN
1672  i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1673  CALL get_ith_array(dist1, i, nd_dist_1)
1674  END IF
1675  END IF
1676 
1677  IF (PRESENT(dist2)) THEN
1678  IF (any(map2_2d == 1)) THEN
1679  i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1680  CALL get_ith_array(dist2, i, nd_dist_1)
1681  END IF
1682  END IF
1683 
1684  IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1685  ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1686  CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1687  END IF
1688  IF (PRESENT(dist1)) THEN
1689  IF (any(map1_2d == 2)) THEN
1690  i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1691  CALL get_ith_array(dist1, i, nd_dist_2)
1692  END IF
1693  END IF
1694 
1695  IF (PRESENT(dist2)) THEN
1696  IF (any(map2_2d == 2)) THEN
1697  i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1698  CALL get_ith_array(dist2, i, nd_dist_2)
1699  END IF
1700  END IF
1701 
1702  IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1703  ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1704  CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1705  END IF
1706  IF (PRESENT(dist1)) THEN
1707  IF (any(map1_2d == 3)) THEN
1708  i = minloc(map1_2d, dim=1, mask=map1_2d == 3) ! i is location of idim in map1_2d
1709  CALL get_ith_array(dist1, i, nd_dist_3)
1710  END IF
1711  END IF
1712 
1713  IF (PRESENT(dist2)) THEN
1714  IF (any(map2_2d == 3)) THEN
1715  i = minloc(map2_2d, dim=1, mask=map2_2d == 3) ! i is location of idim in map2_2d
1716  CALL get_ith_array(dist2, i, nd_dist_3)
1717  END IF
1718  END IF
1719 
1720  IF (.NOT. ALLOCATED(nd_dist_3)) THEN
1721  ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
1722  CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
1723  END IF
1724  CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1725  nd_dist_1, nd_dist_2, nd_dist_3, own_comm=.true.)
1726  END IF
1727  IF (ndims_tensor(tensor_in) == 4) THEN
1728  IF (PRESENT(dist1)) THEN
1729  IF (any(map1_2d == 1)) THEN
1730  i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1731  CALL get_ith_array(dist1, i, nd_dist_1)
1732  END IF
1733  END IF
1734 
1735  IF (PRESENT(dist2)) THEN
1736  IF (any(map2_2d == 1)) THEN
1737  i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1738  CALL get_ith_array(dist2, i, nd_dist_1)
1739  END IF
1740  END IF
1741 
1742  IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1743  ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1744  CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1745  END IF
1746  IF (PRESENT(dist1)) THEN
1747  IF (any(map1_2d == 2)) THEN
1748  i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1749  CALL get_ith_array(dist1, i, nd_dist_2)
1750  END IF
1751  END IF
1752 
1753  IF (PRESENT(dist2)) THEN
1754  IF (any(map2_2d == 2)) THEN
1755  i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1756  CALL get_ith_array(dist2, i, nd_dist_2)
1757  END IF
1758  END IF
1759 
1760  IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1761  ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1762  CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1763  END IF
1764  IF (PRESENT(dist1)) THEN
1765  IF (any(map1_2d == 3)) THEN
1766  i = minloc(map1_2d, dim=1, mask=map1_2d == 3) ! i is location of idim in map1_2d
1767  CALL get_ith_array(dist1, i, nd_dist_3)
1768  END IF
1769  END IF
1770 
1771  IF (PRESENT(dist2)) THEN
1772  IF (any(map2_2d == 3)) THEN
1773  i = minloc(map2_2d, dim=1, mask=map2_2d == 3) ! i is location of idim in map2_2d
1774  CALL get_ith_array(dist2, i, nd_dist_3)
1775  END IF
1776  END IF
1777 
1778  IF (.NOT. ALLOCATED(nd_dist_3)) THEN
1779  ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
1780  CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
1781  END IF
1782  IF (PRESENT(dist1)) THEN
1783  IF (any(map1_2d == 4)) THEN
1784  i = minloc(map1_2d, dim=1, mask=map1_2d == 4) ! i is location of idim in map1_2d
1785  CALL get_ith_array(dist1, i, nd_dist_4)
1786  END IF
1787  END IF
1788 
1789  IF (PRESENT(dist2)) THEN
1790  IF (any(map2_2d == 4)) THEN
1791  i = minloc(map2_2d, dim=1, mask=map2_2d == 4) ! i is location of idim in map2_2d
1792  CALL get_ith_array(dist2, i, nd_dist_4)
1793  END IF
1794  END IF
1795 
1796  IF (.NOT. ALLOCATED(nd_dist_4)) THEN
1797  ALLOCATE (nd_dist_4(SIZE(blk_sizes_4)))
1798  CALL dbt_default_distvec(SIZE(blk_sizes_4), pdims(4), blk_sizes_4, nd_dist_4)
1799  END IF
1800  CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1801  nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4, own_comm=.true.)
1802  END IF
1803 
1804  IF (ndims_tensor(tensor_in) == 2) THEN
1805  CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1806  blk_sizes_1, blk_sizes_2)
1807  END IF
1808  IF (ndims_tensor(tensor_in) == 3) THEN
1809  CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1810  blk_sizes_1, blk_sizes_2, blk_sizes_3)
1811  END IF
1812  IF (ndims_tensor(tensor_in) == 4) THEN
1813  CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1814  blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
1815  END IF
1816 
1817  IF (PRESENT(nodata)) THEN
1818  nodata_prv = nodata
1819  ELSE
1820  nodata_prv = .false.
1821  END IF
1822 
1823  IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
1824  CALL dbt_distribution_destroy(dist)
1825 
1826  CALL timestop(handle)
1827  END SUBROUTINE
1828 
1829 ! **************************************************************************************************
1830 !> \brief Align index with data
1831 !> \param order permutation resulting from alignment
1832 !> \author Patrick Seewald
1833 ! **************************************************************************************************
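! Illustrative example: for a rank-3 tensor with matrix mapping map1_2d = [3, 1]
! and map2_2d = [2], the alignment permutation is
! order_prv = dbt_inverse_order([3, 1, 2]) = [2, 3, 1], i.e. the permuted tensor
! has its indices in matrix (storage) order.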
1834  SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
1835  TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1836  TYPE(dbt_type), INTENT(OUT) :: tensor_out
1837  INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
1838  INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
1839  INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1840  INTENT(OUT), OPTIONAL :: order
1841  INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: order_prv
1842  CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_align_index'
1843  INTEGER :: handle
1844 
1845  CALL timeset(routinen, handle)
1846 
1847  CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
1848  order_prv = dbt_inverse_order([map1_2d, map2_2d])
1849  CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
1850 
1851  IF (PRESENT(order)) order = order_prv
1852 
1853  CALL timestop(handle)
1854  END SUBROUTINE
1855 
1856 ! **************************************************************************************************
1857 !> \brief Create new tensor by reordering index; the block data is shared with the input tensor (shallow copy)
1858 !> \author Patrick Seewald
1859 ! **************************************************************************************************
1860  SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
1861  TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1862  TYPE(dbt_type), INTENT(OUT) :: tensor_out
1863  INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1864  INTENT(IN) :: order
1865 
1866  TYPE(nd_to_2d_mapping) :: nd_index_blk_rs, nd_index_rs
1867  CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_permute_index'
1868  INTEGER :: handle
1869  INTEGER :: ndims
1870 
1871  CALL timeset(routinen, handle)
1872 
1873  ndims = ndims_tensor(tensor_in)
1874 
1875  CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
1876  CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
1877  CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
1878 
1879  tensor_out%matrix_rep => tensor_in%matrix_rep
1880  tensor_out%owns_matrix = .false.
1881 
1882  tensor_out%nd_index = nd_index_rs
1883  tensor_out%nd_index_blk = nd_index_blk_rs
1884  tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
1885  IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
1886  ALLOCATE (tensor_out%pgrid%tas_split_info, source=tensor_in%pgrid%tas_split_info)
1887  END IF
1888  tensor_out%refcount => tensor_in%refcount
1889  CALL dbt_hold(tensor_out)
1890 
1891  CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
1892  CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
1893  CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
1894  CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
1895  ALLOCATE (tensor_out%nblks_local(ndims))
1896  ALLOCATE (tensor_out%nfull_local(ndims))
1897  tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
1898  tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
1899  tensor_out%name = tensor_in%name
1900  tensor_out%valid = .true.
1901 
1902  IF (ALLOCATED(tensor_in%contraction_storage)) THEN
1903  ALLOCATE (tensor_out%contraction_storage, source=tensor_in%contraction_storage)
1904  CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
1905  CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
1906  END IF
1907 
1908  CALL timestop(handle)
1909  END SUBROUTINE
1910 
1911 ! **************************************************************************************************
1912 !> \brief Map contraction bounds to bounds referring to tensor indices
1913 !> see dbt_contract for documentation of the dummy arguments
1914 !> \param bounds_t1 bounds mapped to tensor_1
1915 !> \param bounds_t2 bounds mapped to tensor_2
1916 !> \param do_crop_1 whether tensor 1 should be cropped
1917 !> \param do_crop_2 whether tensor 2 should be cropped
1918 !> \author Patrick Seewald
1919 ! **************************************************************************************************
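! Illustrative example: for tensor_1 of rank 3 with contract_1 = [1, 2],
! notcontract_1 = [3] and bounds_1 = reshape([1, 10, 1, 20], [2, 2]):
!   bounds_t1(:, 1) = [1, 10], bounds_t1(:, 2) = [1, 20],
!   bounds_t1(:, 3) = full range (no bound given),
! and the same bounds_1 is applied to the contract_2 indices of tensor_2.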
1920  SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
1921  contract_1, notcontract_1, &
1922  contract_2, notcontract_2, &
1923  bounds_t1, bounds_t2, &
1924  bounds_1, bounds_2, bounds_3, &
1925  do_crop_1, do_crop_2)
1926 
1927  TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2
1928  INTEGER, DIMENSION(:), INTENT(IN) :: contract_1, contract_2, &
1929  notcontract_1, notcontract_2
1930  INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
1931  INTENT(OUT) :: bounds_t1
1932  INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
1933  INTENT(OUT) :: bounds_t2
1934  INTEGER, DIMENSION(2, SIZE(contract_1)), &
1935  INTENT(IN), OPTIONAL :: bounds_1
1936  INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
1937  INTENT(IN), OPTIONAL :: bounds_2
1938  INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
1939  INTENT(IN), OPTIONAL :: bounds_3
1940  LOGICAL, INTENT(OUT), OPTIONAL :: do_crop_1, do_crop_2
1941  LOGICAL, DIMENSION(2) :: do_crop
1942 
1943  do_crop = .false.
1944 
1945  bounds_t1(1, :) = 1
1946  CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
1947 
1948  bounds_t2(1, :) = 1
1949  CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
1950 
1951  IF (PRESENT(bounds_1)) THEN
1952  bounds_t1(:, contract_1) = bounds_1
1953  do_crop(1) = .true.
1954  bounds_t2(:, contract_2) = bounds_1
1955  do_crop(2) = .true.
1956  END IF
1957 
1958  IF (PRESENT(bounds_2)) THEN
1959  bounds_t1(:, notcontract_1) = bounds_2
1960  do_crop(1) = .true.
1961  END IF
1962 
1963  IF (PRESENT(bounds_3)) THEN
1964  bounds_t2(:, notcontract_2) = bounds_3
1965  do_crop(2) = .true.
1966  END IF
1967 
1968  IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
1969  IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
1970 
1971  END SUBROUTINE
1972 
1973 ! **************************************************************************************************
1974 !> \brief print tensor contraction indices in a human-readable way
1975 !> \param indchar1 characters printed for index of tensor 1
1976 !> \param indchar2 characters printed for index of tensor 2
1977 !> \param indchar3 characters printed for index of tensor 3
1978 !> \param unit_nr output unit
1979 !> \author Patrick Seewald
1980 ! **************************************************************************************************
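! Example output (illustrative), for a contraction (ij|k) x (jk|l) = (i|l):
!   INDEX INFO
!          tensor index: (ijk) x (jkl) = (il)
!          matrix index: (ij|k) x (jk|l) = (i|l)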
1981  SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
1982  TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
1983  CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
1984  CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
1985  CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
1986  INTEGER, INTENT(IN) :: unit_nr
1987  INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
1988  INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
1989  INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
1990  INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
1991  INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
1992  INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
1993  INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
1994 
1995  unit_nr_prv = prep_output_unit(unit_nr)
1996 
1997  IF (unit_nr_prv /= 0) THEN
1998  CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
1999  CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
2000  CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
2001  END IF
2002 
2003  IF (unit_nr_prv > 0) THEN
2004  WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
2005  WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
2006  DO ichar1 = 1, SIZE(indchar1)
2007  WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
2008  END DO
2009  WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
2010  DO ichar2 = 1, SIZE(indchar2)
2011  WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
2012  END DO
2013  WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
2014  DO ichar3 = 1, SIZE(indchar3)
2015  WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
2016  END DO
2017  WRITE (unit_nr_prv, '(A)') ")"
2018 
2019  WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
2020  DO ichar1 = 1, SIZE(map11)
2021  WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
2022  END DO
2023  WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2024  DO ichar1 = 1, SIZE(map12)
2025  WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
2026  END DO
2027  WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
2028  DO ichar2 = 1, SIZE(map21)
2029  WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
2030  END DO
2031  WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2032  DO ichar2 = 1, SIZE(map22)
2033  WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
2034  END DO
2035  WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
2036  DO ichar3 = 1, SIZE(map31)
2037  WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
2038  END DO
2039  WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2040  DO ichar3 = 1, SIZE(map32)
2041  WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
2042  END DO
2043  WRITE (unit_nr_prv, '(A)') ")"
2044  END IF
2045 
2046  END SUBROUTINE
2047 
2048 ! **************************************************************************************************
2049 !> \brief Initialize batched contraction for this tensor.
2050 !>
2051 !> Explanation: A batched contraction is a contraction performed in several consecutive steps
2052 !> by specification of bounds in dbt_contract. This can be used to reduce the memory
2053 !> footprint by a large factor. The routines dbt_batched_contract_init and
2054 !> dbt_batched_contract_finalize should be called to define the scope of a batched
2055 !> contraction as this enables important optimizations (adapting communication scheme to
2056 !> batches and adapting process grid to multiplication algorithm). The routines
2057 !> dbt_batched_contract_init and dbt_batched_contract_finalize must be
2058 !> called before the first and after the last contraction step on all 3 tensors.
2059 !>
2060 !> Requirements:
2061 !> - the tensors are in a compatible matrix layout (see documentation of
2062 !> `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
2063 !> disabled and a warning is issued.
2064 !> - within the scope of a batched contraction, it is not allowed to access or change tensor
2065 !> data except by calling the routines dbt_contract & dbt_copy.
2066 !> - the bounds affecting indices of the smallest tensor must not change in the course of a
2067 !> batched contraction (todo: get rid of this requirement).
2068 !>
2069 !> Side effects:
2070 !> - the parallel layout (process grid and distribution) of all tensors may change. In order
2071 !> to disable the process grid optimization including this side effect, call this routine
2072 !> only on the smallest of the 3 tensors.
2073 !>
2074 !> \note
2075 !> Note 1: for an example of batched contraction see `examples/dbt_example.F`.
2076 !> (todo: the example is outdated and should be updated).
2077 !>
2078 !> Note 2: this feature is also meaningful if the contraction consists of one batch only
2079 !> but multiple contractions involving the same 3 tensors are performed
2080 !> (batched_contract_init and batched_contract_finalize must then be called before/after each
2081 !> contraction call). The process grid is then optimized after the first contraction
2082 !> and subsequent contractions may profit from this optimization.
2083 !>
2084 !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2085 !> a new range. Its size should be the number of ranges plus one, with the last
2086 !> element being one past the block index of the last block in the last range.
2087 !> For internal load balancing optimizations, optionally specify the index
2088 !> ranges of batched contraction.
2089 !> \author Patrick Seewald
2090 ! **************************************************************************************************
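! Usage sketch (illustrative only; tensors t1, t2, t3, the index lists and the
! per-batch element bounds bounds_batch are assumed to be set up elsewhere).
! A dimension of t3 with 100 blocks split into two ranges of 50 blocks each uses
! batch_range = [1, 51, 101]:
!   CALL dbt_batched_contract_init(t1)
!   CALL dbt_batched_contract_init(t2)
!   CALL dbt_batched_contract_init(t3, batch_range_1=[1, 51, 101])
!   DO ibatch = 1, 2
!      CALL dbt_contract(1.0_dp, t1, t2, 0.0_dp, t3, &
!                        contract_1=[2], notcontract_1=[1], &
!                        contract_2=[1], notcontract_2=[2], &
!                        map_1=[1], map_2=[2], &
!                        bounds_2=bounds_batch(:, :, ibatch))
!   END DO
!   CALL dbt_batched_contract_finalize(t1)
!   CALL dbt_batched_contract_finalize(t2)
!   CALL dbt_batched_contract_finalize(t3)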
2091  SUBROUTINE dbt_batched_contract_init(tensor, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
2092  TYPE(dbt_type), INTENT(INOUT) :: tensor
2093  INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2094  INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
2095  INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_prv_1, batch_range_prv_2, batch_range_prv_3,&
2096  & batch_range_prv_4
2097  LOGICAL :: static_range
2098 
2099  CALL dbt_get_info(tensor, nblks_total=tdims)
2100 
2101  static_range = .true.
2102  IF (ndims_tensor(tensor) >= 1) THEN
2103  IF (PRESENT(batch_range_1)) THEN
2104  ALLOCATE (batch_range_prv_1, source=batch_range_1)
2105  static_range = .false.
2106  ELSE
2107  ALLOCATE (batch_range_prv_1(2))
2108  batch_range_prv_1(1) = 1
2109  batch_range_prv_1(2) = tdims(1) + 1
2110  END IF
2111  END IF
2112  IF (ndims_tensor(tensor) >= 2) THEN
2113  IF (PRESENT(batch_range_2)) THEN
2114  ALLOCATE (batch_range_prv_2, source=batch_range_2)
2115  static_range = .false.
2116  ELSE
2117  ALLOCATE (batch_range_prv_2(2))
2118  batch_range_prv_2(1) = 1
2119  batch_range_prv_2(2) = tdims(2) + 1
2120  END IF
2121  END IF
2122  IF (ndims_tensor(tensor) >= 3) THEN
2123  IF (PRESENT(batch_range_3)) THEN
2124  ALLOCATE (batch_range_prv_3, source=batch_range_3)
2125  static_range = .false.
2126  ELSE
2127  ALLOCATE (batch_range_prv_3(2))
2128  batch_range_prv_3(1) = 1
2129  batch_range_prv_3(2) = tdims(3) + 1
2130  END IF
2131  END IF
2132  IF (ndims_tensor(tensor) >= 4) THEN
2133  IF (PRESENT(batch_range_4)) THEN
2134  ALLOCATE (batch_range_prv_4, source=batch_range_4)
2135  static_range = .false.
2136  ELSE
2137  ALLOCATE (batch_range_prv_4(2))
2138  batch_range_prv_4(1) = 1
2139  batch_range_prv_4(2) = tdims(4) + 1
2140  END IF
2141  END IF
2142 
2143  ALLOCATE (tensor%contraction_storage)
2144  tensor%contraction_storage%static = static_range
2145  IF (static_range) THEN
2146  CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
2147  END IF
2148  tensor%contraction_storage%nsplit_avg = 0.0_dp
2149  tensor%contraction_storage%ibatch = 0
2150 
2151  IF (ndims_tensor(tensor) == 1) THEN
2152  CALL create_array_list(tensor%contraction_storage%batch_ranges, 1, &
2153  batch_range_prv_1)
2154  END IF
2155  IF (ndims_tensor(tensor) == 2) THEN
2156  CALL create_array_list(tensor%contraction_storage%batch_ranges, 2, &
2157  batch_range_prv_1, batch_range_prv_2)
2158  END IF
2159  IF (ndims_tensor(tensor) == 3) THEN
2160  CALL create_array_list(tensor%contraction_storage%batch_ranges, 3, &
2161  batch_range_prv_1, batch_range_prv_2, batch_range_prv_3)
2162  END IF
2163  IF (ndims_tensor(tensor) == 4) THEN
2164  CALL create_array_list(tensor%contraction_storage%batch_ranges, 4, &
2165  batch_range_prv_1, batch_range_prv_2, batch_range_prv_3, batch_range_prv_4)
2166  END IF
2167 
2168  END SUBROUTINE
2169 
2170 ! **************************************************************************************************
2171 !> \brief finalize batched contraction. This performs all communication that has been postponed in
2172 !> the contraction calls.
2173 !> \author Patrick Seewald
2174 ! **************************************************************************************************
2175  SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
2176  TYPE(dbt_type), INTENT(INOUT) :: tensor
2177  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2178  LOGICAL :: do_write
2179  INTEGER :: unit_nr_prv, handle
2180 
2181  CALL tensor%pgrid%mp_comm_2d%sync()
2182  CALL timeset("dbt_total", handle)
2183  unit_nr_prv = prep_output_unit(unit_nr)
2184 
2185  do_write = .false.
2186 
2187  IF (tensor%contraction_storage%static) THEN
2188  IF (tensor%matrix_rep%do_batched > 0) THEN
2189  IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .true.
2190  END IF
2191  CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
2192  END IF
2193 
2194  IF (do_write .AND. unit_nr_prv /= 0) THEN
2195  IF (unit_nr_prv > 0) THEN
2196  WRITE (unit_nr_prv, "(T2,A)") &
2197  "FINALIZING BATCHED PROCESSING OF MATMUL"
2198  END IF
2199  CALL dbt_write_tensor_info(tensor, unit_nr_prv)
2200  CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
2201  END IF
2202 
2203  CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
2204  DEALLOCATE (tensor%contraction_storage)
2205  CALL tensor%pgrid%mp_comm_2d%sync()
2206  CALL timestop(handle)
2207 
2208  END SUBROUTINE
2209 
2210 ! **************************************************************************************************
2211 !> \brief change the process grid of a tensor
2212 !> \param nodata optionally don't copy the tensor data (the tensor is then empty on return)
2213 !> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2214 !> a new range. Its size should be the number of ranges plus one, with the last
2215 !> element being one past the block index of the last block in the last range.
2216 !> For internal load balancing optimizations, optionally specify the index
2217 !> ranges of batched contraction.
2218 !> \author Patrick Seewald
2219 ! **************************************************************************************************
2220  SUBROUTINE dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
2221  nodata, pgrid_changed, unit_nr)
2222  TYPE(dbt_type), INTENT(INOUT) :: tensor
2223  TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid
2224  INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2226  LOGICAL, INTENT(IN), OPTIONAL :: nodata
2227  LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2228  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2229  CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_change_pgrid'
2230  CHARACTER(default_string_length) :: name
2231  INTEGER :: handle
2232  INTEGER, ALLOCATABLE, DIMENSION(:) :: bs_1, bs_2, bs_3, bs_4, &
2233  dist_1, dist_2, dist_3, dist_4
2234  INTEGER, DIMENSION(ndims_tensor(tensor)) :: pcoord, pcoord_ref, pdims, pdims_ref, &
2235  tdims
2236  TYPE(dbt_type) :: t_tmp
2237  TYPE(dbt_distribution_type) :: dist
2238  INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2239  INTEGER, &
2240  DIMENSION(ndims_matrix_column(tensor)) :: map2
2241  LOGICAL, DIMENSION(ndims_tensor(tensor)) :: mem_aware
2242  INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
2243  INTEGER :: ind1, ind2, batch_size, ibatch
2244 
2245  IF (PRESENT(pgrid_changed)) pgrid_changed = .false.
2246  CALL mp_environ_pgrid(pgrid, pdims, pcoord)
2247  CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)
2248 
2249  IF (all(pdims == pdims_ref)) THEN
2250  IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
2251  IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
2252  RETURN
2253  END IF
2254  END IF
2255  END IF
2256 
2257  CALL timeset(routinen, handle)
2258 
2259  IF (ndims_tensor(tensor) >= 1) THEN
2260  mem_aware(1) = PRESENT(batch_range_1)
2261  IF (mem_aware(1)) nbatch(1) = SIZE(batch_range_1) - 1
2262  END IF
2263  IF (ndims_tensor(tensor) >= 2) THEN
2264  mem_aware(2) = PRESENT(batch_range_2)
2265  IF (mem_aware(2)) nbatch(2) = SIZE(batch_range_2) - 1
2266  END IF
2267  IF (ndims_tensor(tensor) >= 3) THEN
2268  mem_aware(3) = PRESENT(batch_range_3)
2269  IF (mem_aware(3)) nbatch(3) = SIZE(batch_range_3) - 1
2270  END IF
2271  IF (ndims_tensor(tensor) >= 4) THEN
2272  mem_aware(4) = PRESENT(batch_range_4)
2273  IF (mem_aware(4)) nbatch(4) = SIZE(batch_range_4) - 1
2274  END IF
2275 
2276  CALL dbt_get_info(tensor, nblks_total=tdims, name=name)
2277 
2278  IF (ndims_tensor(tensor) >= 1) THEN
2279  ALLOCATE (bs_1(dbt_nblks_total(tensor, 1)))
2280  CALL get_ith_array(tensor%blk_sizes, 1, bs_1)
2281  ALLOCATE (dist_1(tdims(1)))
2282  dist_1 = 0
2283  IF (mem_aware(1)) THEN
2284  DO ibatch = 1, nbatch(1)
2285  ind1 = batch_range_1(ibatch)
2286  ind2 = batch_range_1(ibatch + 1) - 1
2287  batch_size = ind2 - ind1 + 1
2288  CALL dbt_default_distvec(batch_size, pdims(1), &
2289  bs_1(ind1:ind2), dist_1(ind1:ind2))
2290  END DO
2291  ELSE
2292  CALL dbt_default_distvec(tdims(1), pdims(1), bs_1, dist_1)
2293  END IF
2294  END IF
2295  IF (ndims_tensor(tensor) >= 2) THEN
2296  ALLOCATE (bs_2(dbt_nblks_total(tensor, 2)))
2297  CALL get_ith_array(tensor%blk_sizes, 2, bs_2)
2298  ALLOCATE (dist_2(tdims(2)))
2299  dist_2 = 0
2300  IF (mem_aware(2)) THEN
2301  DO ibatch = 1, nbatch(2)
2302  ind1 = batch_range_2(ibatch)
2303  ind2 = batch_range_2(ibatch + 1) - 1
2304  batch_size = ind2 - ind1 + 1
2305  CALL dbt_default_distvec(batch_size, pdims(2), &
2306  bs_2(ind1:ind2), dist_2(ind1:ind2))
2307  END DO
2308  ELSE
2309  CALL dbt_default_distvec(tdims(2), pdims(2), bs_2, dist_2)
2310  END IF
2311  END IF
2312  IF (ndims_tensor(tensor) >= 3) THEN
2313  ALLOCATE (bs_3(dbt_nblks_total(tensor, 3)))
2314  CALL get_ith_array(tensor%blk_sizes, 3, bs_3)
2315  ALLOCATE (dist_3(tdims(3)))
2316  dist_3 = 0
2317  IF (mem_aware(3)) THEN
2318  DO ibatch = 1, nbatch(3)
2319  ind1 = batch_range_3(ibatch)
2320  ind2 = batch_range_3(ibatch + 1) - 1
2321  batch_size = ind2 - ind1 + 1
2322  CALL dbt_default_distvec(batch_size, pdims(3), &
2323  bs_3(ind1:ind2), dist_3(ind1:ind2))
2324  END DO
2325  ELSE
2326  CALL dbt_default_distvec(tdims(3), pdims(3), bs_3, dist_3)
2327  END IF
2328  END IF
2329  IF (ndims_tensor(tensor) >= 4) THEN
2330  ALLOCATE (bs_4(dbt_nblks_total(tensor, 4)))
2331  CALL get_ith_array(tensor%blk_sizes, 4, bs_4)
2332  ALLOCATE (dist_4(tdims(4)))
2333  dist_4 = 0
2334  IF (mem_aware(4)) THEN
2335  DO ibatch = 1, nbatch(4)
2336  ind1 = batch_range_4(ibatch)
2337  ind2 = batch_range_4(ibatch + 1) - 1
2338  batch_size = ind2 - ind1 + 1
2339  CALL dbt_default_distvec(batch_size, pdims(4), &
2340  bs_4(ind1:ind2), dist_4(ind1:ind2))
2341  END DO
2342  ELSE
2343  CALL dbt_default_distvec(tdims(4), pdims(4), bs_4, dist_4)
2344  END IF
2345  END IF
2346 
2347  CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
2348  IF (ndims_tensor(tensor) == 2) THEN
2349  CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2)
2350  CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2)
2351  END IF
2352  IF (ndims_tensor(tensor) == 3) THEN
2353  CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3)
2354  CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3)
2355  END IF
2356  IF (ndims_tensor(tensor) == 4) THEN
2357  CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3, dist_4)
2358  CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3, bs_4)
2359  END IF
2360  CALL dbt_distribution_destroy(dist)
2361 
2362  IF (PRESENT(nodata)) THEN
2363  IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
2364  ELSE
2365  CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
2366  END IF
2367 
2368  CALL dbt_copy_contraction_storage(tensor, t_tmp)
2369 
2370  CALL dbt_destroy(tensor)
2371  tensor = t_tmp
2372 
2373  IF (PRESENT(unit_nr)) THEN
2374  IF (unit_nr > 0) THEN
2375  WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", trim(tensor%name)
2376  WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
2377  CALL dbt_write_split_info(pgrid, unit_nr)
2378  END IF
2379  END IF
2380 
2381  IF (PRESENT(pgrid_changed)) pgrid_changed = .true.
2382 
2383  CALL timestop(handle)
2384  END SUBROUTINE
2385 
2386 ! **************************************************************************************************
2387 !> \brief map tensor to a new 2d process grid for the matrix representation.
2388 !> \author Patrick Seewald
2389 ! **************************************************************************************************
2390  SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
2391  TYPE(dbt_type), INTENT(INOUT) :: tensor
2392  TYPE(mp_cart_type), INTENT(IN) :: mp_comm
2393  INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
2394  LOGICAL, INTENT(IN), OPTIONAL :: nodata
2395  INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
2396  LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2397  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2398  INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2399  INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
2400  INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
2401  TYPE(dbt_pgrid_type) :: pgrid
2402  INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2403  INTEGER, DIMENSION(:), ALLOCATABLE :: array
2404  INTEGER :: idim
2405 
2406  CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
2407  CALL blk_dims_tensor(tensor, dims)
2408 
2409  IF (ALLOCATED(tensor%contraction_storage)) THEN
2410  ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
2411  nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
2412  ! for good load balancing, the process grid dimensions should be adapted to the
2413  ! tensor dimensions. For batched contraction, the tensor dimensions should be divided by
2414  ! the number of batches (number of index ranges).
2415  DO idim = 1, ndims_tensor(tensor)
2416  CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
2417  dims(idim) = array(nbatches(idim) + 1) - array(1)
2418  DEALLOCATE (array)
2419  dims(idim) = dims(idim)/nbatches(idim)
2420  IF (dims(idim) <= 0) dims(idim) = 1
2421  END DO
2422  END ASSOCIATE
2423  END IF
2424 
2425  pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
2426  IF (ALLOCATED(tensor%contraction_storage)) THEN
2427  IF (ndims_tensor(tensor) == 1) THEN
2428  CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1)
2429  CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, &
2430  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2431  END IF
2432  IF (ndims_tensor(tensor) == 2) THEN
2433  CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2)
2434  CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, &
2435  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2436  END IF
2437  IF (ndims_tensor(tensor) == 3) THEN
2438  CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3)
2439  CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, &
2440  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2441  END IF
2442  IF (ndims_tensor(tensor) == 4) THEN
2443  CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
2444  CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
2445  nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2446  END IF
2447  ELSE
2448  CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2449  END IF
2450  CALL dbt_pgrid_destroy(pgrid)
2451 
2452  END SUBROUTINE
2453 
2454 END MODULE