!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                               !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                     !
!                                                                                                    !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                        !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief DBT tensor framework for block-sparse tensor contraction.
!>        Representation of n-rank tensors as DBT tall-and-skinny matrices.
!>        Support for arbitrary redistribution between different representations.
!>        Support for arbitrary tensor contractions.
!> \todo implement checks and error messages
!> \author Patrick Seewald
! **************************************************************************************************
MODULE dbt_methods

   USE cp_dbcsr_api, ONLY: &
   USE dbt_allocate_wrap, ONLY: &
   USE dbt_array_list_methods, ONLY: &
   USE dbm_api, ONLY: &
   USE dbt_tas_types, ONLY: &
   USE dbt_tas_base, ONLY: &
   USE dbt_tas_mm, ONLY: &
   USE dbt_block, ONLY: &
   USE dbt_index, ONLY: &
   USE dbt_types, ONLY: &
   USE kinds, ONLY: &
   USE message_passing, ONLY: &
   USE util, ONLY: &
      sort
   USE dbt_reshape_ops, ONLY: &
   USE dbt_tas_split, ONLY: &
   USE dbt_split, ONLY: &
   USE dbt_io, ONLY: &

#include "../base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'

   PUBLIC :: &
      dbt_copy, &

CONTAINS

! **************************************************************************************************
!> \brief Copy tensor data.
!>        Redistributes tensor data according to the distributions of target and source tensor.
!>        Permutes tensor index according to the `order` argument (if present).
!>        Source and target tensor formats are arbitrary as long as the following requirements are met:
!>        * source and target tensors have the same rank and the same sizes in each dimension in terms
!>          of tensor elements (block sizes don't need to be the same).
!>          If the `order` argument is present, sizes must match after index permutation.
!>        OR
!>        * target tensor is not yet created; in this case an exact copy of the source tensor is returned.
!> \param tensor_in Source
!> \param tensor_out Target
!> \param order Permutation of target tensor index.
!>              Exact same convention as the order argument of the RESHAPE intrinsic.
!> \param summation whether to sum source data into the target tensor (default: overwrite)
!> \param bounds crop tensor data: start and end index for each tensor dimension
!> \param move_data memory optimization: transfer data such that tensor_in is empty on return
!> \param unit_nr output unit for logging
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: order
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: bounds
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER :: handle

      CALL tensor_in%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)

      ! make sure that it is safe to use dbt_copy during a batched contraction
      CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.true.)
      CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.true.)

      CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      CALL tensor_in%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)
   END SUBROUTINE
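
   ! Usage sketch, assuming tensor_a and tensor_b are valid rank-3 dbt_type tensors whose
   ! element-wise sizes match after permutation of the first two indices:
   !
   !    CALL dbt_copy(tensor_a, tensor_b, order=[2, 1, 3])
   !
   ! copies tensor_a into tensor_b with the first two indices permuted (same convention as the
   ! ORDER argument of the RESHAPE intrinsic); an optional bounds argument, e.g.
   ! bounds=RESHAPE([1, 10, 1, 20, 1, 30], [2, 3]), additionally crops the copied index range.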

! **************************************************************************************************
!> \brief expert routine for copying a tensor. For internal use only.
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: order
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL :: bounds
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr

      TYPE(dbt_type), POINTER :: in_tmp_1, in_tmp_2, &
                                 in_tmp_3, out_tmp_1
      INTEGER :: handle, unit_nr_prv
      INTEGER, DIMENSION(:), ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
      LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
                 summation_prv, new_in_1, new_in_2, &
                 new_in_3, new_out_1, block_compatible, &
                 move_prv
      TYPE(array_list) :: blk_sizes_in

      CALL timeset(routineN, handle)

      CPASSERT(tensor_out%valid)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .false.
      END IF

      dist_compatible_tas = .false.
      dist_compatible_tensor = .false.
      block_compatible = .false.
      new_in_1 = .false.
      new_in_2 = .false.
      new_in_3 = .false.
      new_out_1 = .false.

      IF (PRESENT(summation)) THEN
         summation_prv = summation
      ELSE
         summation_prv = .false.
      END IF

      IF (PRESENT(bounds)) THEN
         ALLOCATE (in_tmp_1)
         CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
         new_in_1 = .true.
         move_prv = .true.
      ELSE
         in_tmp_1 => tensor_in
      END IF

      IF (PRESENT(order)) THEN
         CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
         block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
      ELSE
         block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
      END IF

      IF (.NOT. block_compatible) THEN
         ALLOCATE (in_tmp_2, out_tmp_1)
         CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
                                         nodata2=.NOT. summation_prv, move_data=move_prv)
         new_in_2 = .true.; new_out_1 = .true.
         move_prv = .true.
      ELSE
         in_tmp_2 => in_tmp_1
         out_tmp_1 => tensor_out
      END IF

      IF (PRESENT(order)) THEN
         ALLOCATE (in_tmp_3)
         CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
         new_in_3 = .true.
      ELSE
         in_tmp_3 => in_tmp_2
      END IF

      ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
      ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
      CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)

      ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
      ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
      CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)

      IF (.NOT. PRESENT(order)) THEN
         IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
            dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
            dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         END IF
      END IF

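      ! three copy paths: if the TAS matrix representations share their 2d mapping and
      ! distribution, copy directly on the matrix level; if only the tensor-level mapping and
      ! distribution agree, copy block-wise without communication; otherwise fall back to a
      ! general reshape with redistribution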
      IF (dist_compatible_tas) THEN
         CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSEIF (dist_compatible_tensor) THEN
         CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
         IF (move_prv) CALL dbt_clear(in_tmp_3)
      ELSE
         CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
      END IF

      IF (new_in_1) THEN
         CALL dbt_destroy(in_tmp_1)
         DEALLOCATE (in_tmp_1)
      END IF

      IF (new_in_2) THEN
         CALL dbt_destroy(in_tmp_2)
         DEALLOCATE (in_tmp_2)
      END IF

      IF (new_in_3) THEN
         CALL dbt_destroy(in_tmp_3)
         DEALLOCATE (in_tmp_3)
      END IF

      IF (new_out_1) THEN
         IF (unit_nr_prv /= 0) THEN
            CALL dbt_write_tensor_dist(out_tmp_1, unit_nr)
         END IF
         CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
         CALL dbt_destroy(out_tmp_1)
         DEALLOCATE (out_tmp_1)
      END IF

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief copy without communication, requires that both tensors have the same process grid and
!>        distribution
!> \param summation Whether to sum matrices b = a + b
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbt_iterator_type) :: iter
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
      TYPE(block_nd) :: blk_data
      LOGICAL :: found

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
      INTEGER :: handle

      CALL timeset(routineN, handle)
      CPASSERT(tensor_out%valid)

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbt_clear(tensor_out)
      ELSE
         CALL dbt_clear(tensor_out)
      END IF

      CALL dbt_reserve_blocks(tensor_in, tensor_out)

!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
!$OMP PRIVATE(iter,ind_nd,blk_data,found)
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_nd)
         CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
         CPASSERT(found)
         CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
         CALL destroy_block(blk_data)
      END DO
      CALL dbt_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)
   END SUBROUTINE

! **************************************************************************************************
!> \brief copy matrix to tensor.
!> \param summation tensor_out = tensor_out + matrix_in
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
      TYPE(dbcsr_type), TARGET, INTENT(IN) :: matrix_in
      TYPE(dbt_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbcsr_type), POINTER :: matrix_in_desym

      INTEGER, DIMENSION(2) :: ind_2d
      REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: block_arr
      REAL(kind=dp), DIMENSION(:, :), POINTER :: block
      TYPE(dbcsr_iterator_type) :: iter

      INTEGER :: handle
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'

      CALL timeset(routineN, handle)
      CPASSERT(tensor_out%valid)

      NULLIFY (block)

      IF (dbcsr_has_symmetry(matrix_in)) THEN
         ALLOCATE (matrix_in_desym)
         CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
      ELSE
         matrix_in_desym => matrix_in
      END IF

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbt_clear(tensor_out)
      ELSE
         CALL dbt_clear(tensor_out)
      END IF

      CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)

!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
!$OMP PRIVATE(iter,ind_2d,block,block_arr)
      CALL dbcsr_iterator_start(iter, matrix_in_desym)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block)
         CALL allocate_any(block_arr, source=block)
         CALL dbt_put_block(tensor_out, ind_2d, shape(block_arr), block_arr, summation=summation)
         DEALLOCATE (block_arr)
      END DO
      CALL dbcsr_iterator_stop(iter)
!$OMP END PARALLEL

      IF (dbcsr_has_symmetry(matrix_in)) THEN
         CALL dbcsr_release(matrix_in_desym)
         DEALLOCATE (matrix_in_desym)
      END IF

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief copy tensor to matrix
!> \param summation matrix_out = matrix_out + tensor_in
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
      TYPE(dbt_iterator_type) :: iter
      INTEGER :: handle
      INTEGER, DIMENSION(2) :: ind_2d
      REAL(kind=dp), DIMENSION(:, :), ALLOCATABLE :: block
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
      LOGICAL :: found

      CALL timeset(routineN, handle)

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
      ELSE
         CALL dbcsr_clear(matrix_out)
      END IF

      CALL dbt_reserve_blocks(tensor_in, matrix_out)

!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
!$OMP PRIVATE(iter,ind_2d,block,found)
      CALL dbt_iterator_start(iter, tensor_in)
      DO WHILE (dbt_iterator_blocks_left(iter))
         CALL dbt_iterator_next_block(iter, ind_2d)
         IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE

         CALL dbt_get_block(tensor_in, ind_2d, block, found)
         CPASSERT(found)

         IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
            CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), transpose(block), summation=summation)
         ELSE
            CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
         END IF
         DEALLOCATE (block)
      END DO
      CALL dbt_iterator_stop(iter)
!$OMP END PARALLEL

      CALL timestop(handle)

   END SUBROUTINE
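
   ! Usage sketch, assuming matrix_a is a valid dbcsr_type matrix and tensor_a is a valid rank-2
   ! dbt_type tensor with matching row/column block sizes: the two routines above convert between
   ! the matrix and tensor representations, e.g.
   !
   !    CALL dbt_copy_matrix_to_tensor(matrix_a, tensor_a)
   !    CALL dbt_copy_tensor_to_matrix(tensor_a, matrix_a, summation=.false.)
   !
   ! Symmetric DBCSR matrices are desymmetrized on the way in; on the way out only one triangle
   ! of a symmetric matrix is written (off-triangle blocks are transposed into place).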

! **************************************************************************************************
!> \brief Contract tensors by multiplying matrix representations.
!>        tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
!>                                        * tensor_2(contract_2, notcontract_2)
!>                                + beta * tensor_3(map_1, map_2)
!>
!> \note
!>      note 1: block sizes of the corresponding indices need to be the same in all tensors.
!>
!>      note 2: for best performance the tensors should have been created in matrix layouts
!>      compatible with the contraction, e.g. tensor_1 should have been created with either
!>      map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
!>      map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
!>      tensor_3 and map_1 / map_2).
!>      Furthermore, the two largest tensors involved in the contraction should both map to either
!>      tall or short matrices: the largest matrix dimension should be "on the same side"
!>      and should have identical distribution (which is always the case if the distributions were
!>      obtained with dbt_default_distvec).
!>
!>      note 3: if the same tensor occurs in multiple contractions, a different tensor object should
!>      be created for each contraction and the data should be copied between the tensors by use of
!>      dbt_copy. If the same tensor object is used in multiple contractions, the matrix layouts
!>      cannot be compatible with all contractions (see note 2).
!>
!>      note 4: automatic optimizations are enabled by using the feature of batched contraction, see
!>      dbt_batched_contract_init, dbt_batched_contract_finalize.
!>      The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
!>
!> \param tensor_1 first tensor (in)
!> \param tensor_2 second tensor (in)
!> \param contract_1 indices of tensor_1 to contract
!> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
!> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
!> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
!> \param notcontract_1 indices of tensor_1 not to contract
!> \param notcontract_2 indices of tensor_2 not to contract
!> \param tensor_3 contracted tensor (out)
!> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
!>                 start and end index of an index range over which to contract.
!>                 For use in batched contraction.
!> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
!>                 For use in batched contraction.
!> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
!>                 For use in batched contraction.
!> \param optimize_dist Whether the distribution should be optimized internally. In the current
!>                      implementation this guarantees optimal parameters only for dense matrices.
!> \param pgrid_opt_1 Optionally return the optimal process grid for tensor_1.
!>                    This can be used to choose optimal process grids for subsequent tensor
!>                    contractions with tensors of similar shape and sparsity. Under some
!>                    conditions, pgrid_opt_1 cannot be returned; in this case the pointer is
!>                    not associated.
!> \param pgrid_opt_2 Optionally return the optimal process grid for tensor_2.
!> \param pgrid_opt_3 Optionally return the optimal process grid for tensor_3.
!> \param filter_eps As in DBM mm
!> \param flop As in DBM mm
!> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
!> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
!> \param unit_nr output unit for logging.
!>                Set it to -1 on ranks that should not write (and to any valid unit number on
!>                ranks that should write output). If 0 on ALL ranks, no output is written.
!> \param log_verbose verbose logging (for testing only)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
                           contract_1, notcontract_1, &
                           contract_2, notcontract_2, &
                           map_1, map_2, &
                           bounds_1, bounds_2, bounds_3, &
                           optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
                           filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
      REAL(dp), INTENT(IN) :: alpha
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
      REAL(dp), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_1
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_2
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_3
      REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
      LOGICAL, INTENT(IN), OPTIONAL :: move_data
      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL :: log_verbose

      INTEGER :: handle

      CALL tensor_1%pgrid%mp_comm_2d%sync()
      CALL timeset("dbt_total", handle)
      CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
                               contract_1, notcontract_1, &
                               contract_2, notcontract_2, &
                               map_1, map_2, &
                               bounds_1=bounds_1, &
                               bounds_2=bounds_2, &
                               bounds_3=bounds_3, &
                               optimize_dist=optimize_dist, &
                               pgrid_opt_1=pgrid_opt_1, &
                               pgrid_opt_2=pgrid_opt_2, &
                               pgrid_opt_3=pgrid_opt_3, &
                               filter_eps=filter_eps, &
                               flop=flop, &
                               move_data=move_data, &
                               retain_sparsity=retain_sparsity, &
                               unit_nr=unit_nr, &
                               log_verbose=log_verbose)
      CALL tensor_1%pgrid%mp_comm_2d%sync()
      CALL timestop(handle)

   END SUBROUTINE
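
   ! Worked index example, assuming t1, t2, t3 are valid dbt_type tensors: to evaluate
   ! t3(i,j,l) = SUM_k t1(i,j,k) * t2(k,l), index k of t1 (position 3) is contracted with
   ! index k of t2 (position 1), indices (i,j) of t3 come from t1 and index l comes from t2:
   !
   !    CALL dbt_contract(alpha=1.0_dp, tensor_1=t1, tensor_2=t2, beta=0.0_dp, tensor_3=t3, &
   !                      contract_1=[3], notcontract_1=[1, 2], &
   !                      contract_2=[1], notcontract_2=[2], &
   !                      map_1=[1, 2], map_2=[3])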

! **************************************************************************************************
!> \brief expert routine for tensor contraction. For internal use only.
!> \param nblks_local number of local blocks on this MPI rank
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
                                  contract_1, notcontract_1, &
                                  contract_2, notcontract_2, &
                                  map_1, map_2, &
                                  bounds_1, bounds_2, bounds_3, &
                                  optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
                                  filter_eps, flop, move_data, retain_sparsity, &
                                  nblks_local, unit_nr, log_verbose)
      REAL(dp), INTENT(IN) :: alpha
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
      REAL(dp), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
      TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_1
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_2
      TYPE(dbt_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_3
      REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
      LOGICAL, INTENT(IN), OPTIONAL :: move_data
      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      INTEGER, INTENT(OUT), OPTIONAL :: nblks_local
      INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
      LOGICAL, INTENT(IN), OPTIONAL :: log_verbose

      TYPE(dbt_type), POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3
      TYPE(dbt_type), TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3
      TYPE(dbt_type), POINTER :: tensor_crop_1, tensor_crop_2
      TYPE(dbt_type), POINTER :: tensor_small, tensor_large

      LOGICAL :: assert_stmt, tensors_remapped
      INTEGER :: max_mm_dim, max_tensor, &
                 unit_nr_prv, ref_tensor, handle
      TYPE(mp_cart_type) :: mp_comm_opt
      INTEGER, DIMENSION(SIZE(contract_1)) :: contract_1_mod
      INTEGER, DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod
      INTEGER, DIMENSION(SIZE(contract_2)) :: contract_2_mod
      INTEGER, DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod
      INTEGER, DIMENSION(SIZE(map_1)) :: map_1_mod
      INTEGER, DIMENSION(SIZE(map_2)) :: map_2_mod
      LOGICAL :: trans_1, trans_2, trans_3
      LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2
      INTEGER :: ndims1, ndims2, ndims3
      INTEGER :: occ_1, occ_2
      INTEGER, DIMENSION(:), ALLOCATABLE :: dims1, dims2, dims3

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
      CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, &
                                                     indchar2_mod, indchar3_mod
      CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
         ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
      LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
                 pgrid_changed_any, do_change_pgrid(2)
      TYPE(dbt_tas_split_info) :: split_opt, split, split_opt_avg
      INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
      REAL(dp) :: pdim_ratio, pdim_ratio_opt

      NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
               tensor_small)

      CALL timeset(routineN, handle)

      CPASSERT(tensor_1%valid)
      CPASSERT(tensor_2%valid)
      CPASSERT(tensor_3%valid)

      assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
      CPASSERT(assert_stmt)

      assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
      CPASSERT(assert_stmt)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(flop)) flop = 0
      IF (PRESENT(nblks_local)) nblks_local = 0

      IF (PRESENT(move_data)) THEN
         move_data_1 = move_data
         move_data_2 = move_data
      ELSE
         move_data_1 = .false.
         move_data_2 = .false.
      END IF

      nodata_3 = .true.
      IF (PRESENT(retain_sparsity)) THEN
         IF (retain_sparsity) nodata_3 = .false.
      END IF

      CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
                                     contract_1, notcontract_1, &
                                     contract_2, notcontract_2, &
                                     bounds_t1, bounds_t2, &
                                     bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
                                     do_crop_1=do_crop_1, do_crop_2=do_crop_2)

      IF (do_crop_1) THEN
         ALLOCATE (tensor_crop_1)
         CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
         move_data_1 = .true.
      ELSE
         tensor_crop_1 => tensor_1
      END IF

      IF (do_crop_2) THEN
         ALLOCATE (tensor_crop_2)
         CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
         move_data_2 = .true.
      ELSE
         tensor_crop_2 => tensor_2
      END IF

      ! shortcut for empty tensors
      ! this is needed to avoid unnecessary work in case the user contracts different portions of
      ! a tensor consecutively to save memory
      ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
         occ_1 = dbt_get_num_blocks(tensor_crop_1)
         CALL mp_comm%max(occ_1)
         occ_2 = dbt_get_num_blocks(tensor_crop_2)
         CALL mp_comm%max(occ_2)
      END ASSOCIATE

      IF (occ_1 == 0 .OR. occ_2 == 0) THEN
         CALL dbt_scale(tensor_3, beta)
         IF (do_crop_1) THEN
            CALL dbt_destroy(tensor_crop_1)
            DEALLOCATE (tensor_crop_1)
         END IF
         IF (do_crop_2) THEN
            CALL dbt_destroy(tensor_crop_2)
            DEALLOCATE (tensor_crop_2)
         END IF

         CALL timestop(handle)
         RETURN
      END IF

      IF (unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(A)') repeat("-", 80)
            WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
               trim(tensor_crop_1%name), 'x', trim(tensor_crop_2%name), '=', trim(tensor_3%name)
            WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         END IF
         CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
         CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
         CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
         CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
      END IF

      ! align tensor index with data, tensor data is not modified
      ndims1 = ndims_tensor(tensor_crop_1)
      ndims2 = ndims_tensor(tensor_crop_2)
      ndims3 = ndims_tensor(tensor_3)
      ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
      ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
      ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))

      ! labeling tensor index with letters

      indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
      indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
      indchar2(contract_2) = indchar1(contract_1)
      indchar3(map_1) = indchar1(notcontract_1)
      indchar3(map_2) = indchar2(notcontract_2)
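      ! e.g. for t3(i,j,l) = t1(i,j,k) * t2(k,l) the labels would be indchar1 = [a, b, c],
      ! indchar2 = [c, d] (the contracted index shares the letter of tensor 1) and
      ! indchar3 = [a, b, d]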

      IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
                                                             tensor_crop_2, indchar2, &
                                                             tensor_3, indchar3, unit_nr_prv)
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
      END IF

      CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
                        tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)

      CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
                        tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)

      CALL align_tensor(tensor_3, map_1, map_2, &
                        tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)

      IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
                                                             tensor_algn_2, indchar2_mod, &
                                                             tensor_algn_3, indchar3_mod, unit_nr_prv)

      ALLOCATE (dims1(ndims1))
      ALLOCATE (dims2(ndims2))
      ALLOCATE (dims3(ndims3))

      ! ideally we should consider block sizes and occupancy to measure tensor sizes, but the
      ! current solution should work for most cases and is more elegant. Note that we cannot
      ! easily consider occupancy since it is unknown for the result tensor
      CALL blk_dims_tensor(tensor_crop_1, dims1)
      CALL blk_dims_tensor(tensor_crop_2, dims2)
      CALL blk_dims_tensor(tensor_3, dims3)

      max_mm_dim = maxloc([product(int(dims1(notcontract_1), int_8)), &
                           product(int(dims1(contract_1), int_8)), &
                           product(int(dims2(notcontract_2), int_8))], dim=1)
      max_tensor = maxloc([product(int(dims1, int_8)), product(int(dims2, int_8)), product(int(dims3, int_8))], dim=1)
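
      ! max_mm_dim indicates which of the three matrix dimensions of the multiplication is the
      ! largest: 1 = non-contracted dims of tensor_1 (m), 2 = the contracted dims (k),
      ! 3 = non-contracted dims of tensor_2 (n); max_tensor indicates the largest of the tensors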
      SELECT CASE (max_mm_dim)
      CASE (1)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
         END IF
         CALL index_linked_sort(contract_1_mod, contract_2_mod)
         CALL index_linked_sort(map_2_mod, notcontract_2_mod)
         SELECT CASE (max_tensor)
         CASE (1)
            CALL index_linked_sort(notcontract_1_mod, map_1_mod)
         CASE (3)
            CALL index_linked_sort(map_1_mod, notcontract_1_mod)
         CASE DEFAULT
            CPABORT("should not happen")
         END SELECT

         CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
                                    contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
                                    trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
                                    move_data_1=move_data_1, unit_nr=unit_nr_prv)

         CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
                               new_2, move_data=move_data_2, unit_nr=unit_nr_prv)

         SELECT CASE (ref_tensor)
         CASE (1)
            tensor_large => tensor_contr_1
         CASE (2)
            tensor_large => tensor_contr_3
         END SELECT
         tensor_small => tensor_contr_2

      CASE (2)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
         END IF

         CALL index_linked_sort(notcontract_1_mod, map_1_mod)
         CALL index_linked_sort(notcontract_2_mod, map_2_mod)
         SELECT CASE (max_tensor)
         CASE (1)
            CALL index_linked_sort(contract_1_mod, contract_2_mod)
         CASE (2)
            CALL index_linked_sort(contract_2_mod, contract_1_mod)
         CASE DEFAULT
            CPABORT("should not happen")
         END SELECT

         CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
                                    notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
                                    trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
                                    move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
         trans_1 = .NOT. trans_1

         CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
                               new_3, nodata=nodata_3, unit_nr=unit_nr_prv)

         SELECT CASE (ref_tensor)
         CASE (1)
            tensor_large => tensor_contr_1
         CASE (2)
            tensor_large => tensor_contr_2
         END SELECT
         tensor_small => tensor_contr_3

      CASE (3)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
         END IF
         CALL index_linked_sort(map_1_mod, notcontract_1_mod)
         CALL index_linked_sort(contract_2_mod, contract_1_mod)
         SELECT CASE (max_tensor)
         CASE (2)
            CALL index_linked_sort(notcontract_2_mod, map_2_mod)
         CASE (3)
            CALL index_linked_sort(map_2_mod, notcontract_2_mod)
         CASE DEFAULT
            CPABORT("should not happen")
         END SELECT

         CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
                                    contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
                                    trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
                                    move_data_1=move_data_2, unit_nr=unit_nr_prv)

         trans_2 = .NOT. trans_2
         trans_3 = .NOT. trans_3

         CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
                               trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)

         SELECT CASE (ref_tensor)
         CASE (1)
            tensor_large => tensor_contr_2
         CASE (2)
            tensor_large => tensor_contr_3
         END SELECT
         tensor_small => tensor_contr_1

      END SELECT

      IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
                                                             tensor_contr_2, indchar2_mod, &
                                                             tensor_contr_3, indchar3_mod, unit_nr_prv)
      IF (unit_nr_prv /= 0) THEN
         IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
         IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
         IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
         IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
      END IF

      CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
                            tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
                            beta, &
                            tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
                            unit_nr=unit_nr_prv, log_verbose=log_verbose, &
                            split_opt=split_opt, &
                            move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)

      IF (PRESENT(pgrid_opt_1)) THEN
         IF (.NOT. new_1) THEN
            ALLOCATE (pgrid_opt_1)
            pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
         END IF
      END IF

      IF (PRESENT(pgrid_opt_2)) THEN
         IF (.NOT. new_2) THEN
            ALLOCATE (pgrid_opt_2)
            pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
         END IF
      END IF

      IF (PRESENT(pgrid_opt_3)) THEN
         IF (.NOT. new_3) THEN
            ALLOCATE (pgrid_opt_3)
            pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
         END IF
      END IF

      do_batched = tensor_small%matrix_rep%do_batched > 0

      tensors_remapped = .false.
      IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .true.

      IF (tensors_remapped .AND. do_batched) THEN
         CALL cp_warn(__LOCATION__, &
                      "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
      END IF

      ! optimize process grid during batched contraction
      do_change_pgrid(:) = .false.
      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
         ASSOCIATE (storage => tensor_small%contraction_storage)
            CPASSERT(storage%static)
            split = dbt_tas_info(tensor_large%matrix_rep)
            do_change_pgrid(:) = &
               update_contraction_storage(storage, split_opt, split)

            IF (any(do_change_pgrid)) THEN
               mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, nint(storage%nsplit_avg))
               CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
                                         nint(storage%nsplit_avg), own_comm=.true.)
               pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
            END IF

         END ASSOCIATE

         IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
            ! check if the new grid has a better subgrid; if not, there is no need to change the process grid
            pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
            pdims_sub = split%mp_comm_group%num_pe_cart

            pdim_ratio = maxval(real(pdims_sub, dp))/minval(pdims_sub)
            pdim_ratio_opt = maxval(real(pdims_sub_opt, dp))/minval(pdims_sub_opt)
            IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
               do_change_pgrid(1) = .false.
               CALL dbt_tas_release_info(split_opt_avg)
            END IF
         END IF
      END IF

      IF (unit_nr_prv /= 0) THEN
         do_write_3 = .true.
         IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
            IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .false.
         END IF
         IF (do_write_3) THEN
            CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
            CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
         END IF
      END IF

      IF (new_3) THEN
         ! need to redistribute if we created a new tensor for tensor 3
         CALL dbt_scale(tensor_algn_3, beta)
         CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.true., move_data=.true.)
         IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
         ! tensor_3 automatically has the correct data because tensor_algn_3 contains a matrix
         ! pointer to the data of tensor_3
      END IF

      ! transfer contraction storage
      CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
      CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
      CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)

      IF (unit_nr_prv /= 0) THEN
         IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
         IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
      END IF

      CALL dbt_destroy(tensor_algn_1)
      CALL dbt_destroy(tensor_algn_2)
      CALL dbt_destroy(tensor_algn_3)

      IF (do_crop_1) THEN
         CALL dbt_destroy(tensor_crop_1)
         DEALLOCATE (tensor_crop_1)
      END IF

      IF (do_crop_2) THEN
         CALL dbt_destroy(tensor_crop_2)
         DEALLOCATE (tensor_crop_2)
      END IF

      IF (new_1) THEN
         CALL dbt_destroy(tensor_contr_1)
         DEALLOCATE (tensor_contr_1)
      END IF
      IF (new_2) THEN
         CALL dbt_destroy(tensor_contr_2)
         DEALLOCATE (tensor_contr_2)
      END IF
      IF (new_3) THEN
         CALL dbt_destroy(tensor_contr_3)
         DEALLOCATE (tensor_contr_3)
      END IF

      IF (PRESENT(move_data)) THEN
         IF (move_data) THEN
            CALL dbt_clear(tensor_1)
            CALL dbt_clear(tensor_2)
         END IF
      END IF

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
      END IF

      IF (any(do_change_pgrid)) THEN
         pgrid_changed_any = .false.
         SELECT CASE (max_mm_dim)
         CASE (1)
            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
               CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
               CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
            END IF
            IF (pgrid_changed_any) THEN
               IF (tensor_2%matrix_rep%do_batched == 3) THEN
                  ! set a flag that the process grid has been optimized, to make sure that no grid
                  ! optimizations are done in the TAS multiply algorithm
                  CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
               END IF
            END IF
         CASE (2)
            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
               CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
               CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
            END IF
            IF (pgrid_changed_any) THEN
               IF (tensor_3%matrix_rep%do_batched == 3) THEN
                  CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
               END IF
            END IF
         CASE (3)
            IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
               CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
               CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                        nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                        pgrid_changed=pgrid_changed, &
                                        unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .true.
            END IF
            IF (pgrid_changed_any) THEN
               IF (tensor_1%matrix_rep%do_batched == 3) THEN
                  CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
               END IF
            END IF
         END SELECT
         CALL dbt_tas_release_info(split_opt_avg)
      END IF

      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
         ! freeze TAS process grids if tensor grids were optimized
         CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.true.)
         CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.true.)
         CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.true.)
      END IF

      CALL dbt_tas_release_info(split_opt)

      CALL timestop(handle)

   END SUBROUTINE

! **************************************************************************************************
!> \brief align tensor index with data
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
                           tensor_out, contract_out, notcontract_out, indp_in, indp_out)
      TYPE(dbt_type), INTENT(INOUT) :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_in, notcontract_in
      TYPE(dbt_type), INTENT(OUT) :: tensor_out
      INTEGER, DIMENSION(SIZE(contract_in)), &
         INTENT(OUT) :: contract_out
      INTEGER, DIMENSION(SIZE(notcontract_in)), &
         INTENT(OUT) :: notcontract_out
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align

      CALL dbt_align_index(tensor_in, tensor_out, order=align)
      contract_out = align(contract_in)
      notcontract_out = align(notcontract_in)
      indp_out(align) = indp_in
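      ! e.g. align = [3, 1, 2] with indp_in = ['a', 'b', 'c'] yields indp_out = ['b', 'c', 'a'],
      ! since index i of the input becomes index align(i) of the aligned tensor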

   END SUBROUTINE

! **************************************************************************************************
!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
!>        matrix multiplication. This routine reshapes the two largest of the three tensors.
!>        Redistribution is avoided if the tensors are already in a consistent layout.
!> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
!> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
!>                    (1:1 correspondence with ind2_linked)
!> \param trans1 transpose flag of matrix rep. tensor 1
!> \param trans2 transpose flag of matrix rep. tensor 2
!> \param new1 whether a new tensor 1 was created
!> \param new2 whether a new tensor 2 was created
!> \param ref_tensor on return, which of the two tensors (1 or 2) is the larger one whose layout
!>                   serves as the reference
!> \param nodata1 don't copy data of tensor 1
!> \param nodata2 don't copy data of tensor 2
!> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
!> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
!> \param optimize_dist experimental: optimize distribution
!> \param unit_nr output unit
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
                                    ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
                                    nodata1, nodata2, move_data_1, &
                                    move_data_2, optimize_dist, unit_nr)
      TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor1
      TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor2
      TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor1_out, tensor2_out
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1_free, ind2_free
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1_linked, ind2_linked
      LOGICAL, INTENT(OUT) :: trans1, trans2
      LOGICAL, INTENT(OUT) :: new1, new2
      INTEGER, INTENT(OUT) :: ref_tensor
      LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
      LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER :: compat1, compat1_old, compat2, compat2_old, &
                 unit_nr_prv
      TYPE(mp_cart_type) :: comm_2d
      TYPE(array_list) :: dist_list
      INTEGER, DIMENSION(:), ALLOCATABLE :: mp_dims
      TYPE(dbt_distribution_type) :: dist_in
      INTEGER(KIND=int_8) :: nblkrows, nblkcols
      LOGICAL :: optimize_dist_prv
      INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
      INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2

      NULLIFY (tensor1_out, tensor2_out)

      unit_nr_prv = prep_output_unit(unit_nr)

      CALL blk_dims_tensor(tensor1, dims1)
      CALL blk_dims_tensor(tensor2, dims2)

      IF (product(int(dims1, int_8)) .GE. product(int(dims2, int_8))) THEN
         ref_tensor = 1
      ELSE
         ref_tensor = 2
      END IF

      IF (PRESENT(optimize_dist)) THEN
         optimize_dist_prv = optimize_dist
      ELSE
         optimize_dist_prv = .false.
      END IF

      compat1 = compat_map(tensor1%nd_index, ind1_linked)
      compat2 = compat_map(tensor2%nd_index, ind2_linked)
      compat1_old = compat1
      compat2_old = compat2

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor1%name), ":"
         SELECT CASE (compat1)
         CASE (0)
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         CASE (1)
            WRITE (unit_nr_prv, '(A)') "Normal"
         CASE (2)
            WRITE (unit_nr_prv, '(A)') "Transposed"
         END SELECT
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor2%name), ":"
         SELECT CASE (compat2)
         CASE (0)
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         CASE (1)
            WRITE (unit_nr_prv, '(A)') "Normal"
         CASE (2)
            WRITE (unit_nr_prv, '(A)') "Transposed"
         END SELECT
      END IF

      new1 = .false.
      new2 = .false.

      IF (compat1 == 0 .OR. optimize_dist_prv) THEN
         new1 = .true.
      END IF

      IF (compat2 == 0 .OR. optimize_dist_prv) THEN
         new2 = .true.
      END IF

      IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor1%name)
            nblkrows = product(int(dims1(ind1_linked), kind=int_8))
            nblkcols = product(int(dims1(ind1_free), kind=int_8))
            comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
            ALLOCATE (tensor1_out)
            CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
                           nodata=nodata1, move_data=move_data_1)
            CALL comm_2d%free()
            compat1 = 1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor1%name)
            tensor1_out => tensor1
         END IF
         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
               trim(tensor2%name), "compatible with", trim(tensor1%name)
            dist_in = dbt_distribution(tensor1_out)
            dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
            IF (compat1 == 1) THEN ! linked index is first 2d dimension
               ! get distribution of linked index, tensor 2 must adopt this distribution
               ! get grid dimensions of linked index
               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
               ALLOCATE (tensor2_out)
               CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
            ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
               ! get distribution of linked index, tensor 2 must adopt this distribution
               ! get grid dimensions of linked index
               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
               ALLOCATE (tensor2_out)
               CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
            ELSE
               CPABORT("should not happen")
            END IF
            compat2 = compat1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor2%name)
            tensor2_out => tensor2
         END IF
      ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor2%name)
            nblkrows = product(int(dims2(ind2_linked), kind=int_8))
            nblkcols = product(int(dims2(ind2_free), kind=int_8))
            comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
            ALLOCATE (tensor2_out)
            CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=comm_2d, &
                           nodata=nodata2, move_data=move_data_2)
            CALL comm_2d%free()
            compat2 = 1
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor2%name)
            tensor2_out => tensor2
         END IF
         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", trim(tensor1%name), &
               "compatible with", trim(tensor2%name)
            dist_in = dbt_distribution(tensor2_out)
            dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
            IF (compat2 == 1) THEN
               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
               ALLOCATE (tensor1_out)
               CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
            ELSEIF (compat2 == 2) THEN
               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
               CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
               ALLOCATE (tensor1_out)
               CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
                              dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
            ELSE
               CPABORT("should not happen")
            END IF
            compat1 = compat2
         ELSE
            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor1%name)
            tensor1_out => tensor1
         END IF
      END IF

      SELECT CASE (compat1)
      CASE (1)
         trans1 = .false.
      CASE (2)
         trans1 = .true.
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT

      SELECT CASE (compat2)
      CASE (1)
         trans2 = .false.
      CASE (2)
         trans2 = .true.
      CASE DEFAULT
         CPABORT("should not happen")
      END SELECT

      IF (unit_nr_prv > 0) THEN
         IF (compat1 .NE. compat1_old) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor1_out%name), ":"
            SELECT CASE (compat1)
            CASE (0)
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            CASE (1)
               WRITE (unit_nr_prv, '(A)') "Normal"
            CASE (2)
               WRITE (unit_nr_prv, '(A)') "Transposed"
            END SELECT
         END IF
         IF (compat2 .NE. compat2_old) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor2_out%name), ":"
            SELECT CASE (compat2)
            CASE (0)
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            CASE (1)
               WRITE (unit_nr_prv, '(A)') "Normal"
            CASE (2)
               WRITE (unit_nr_prv, '(A)') "Transposed"
            END SELECT
         END IF
      END IF

      IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .true.
      IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .true.

   END SUBROUTINE

! **************************************************************************************************
!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
!>        matrix multiplication. This routine reshapes the smallest of the three tensors.
!> \param ind1 index that should be mapped to the first matrix dimension
!> \param ind2 index that should be mapped to the second matrix dimension
!> \param trans transpose flag of matrix rep.
!> \param new whether a new tensor was created for tensor_out
!> \param nodata don't copy tensor data
!> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
!> \param unit_nr output unit
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
      TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1, ind2
      TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor_out
      LOGICAL, INTENT(OUT) :: trans
      LOGICAL, INTENT(OUT) :: new
      LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
      LOGICAL :: nodata_prv

      NULLIFY (tensor_out)
      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .false.
      END IF

      unit_nr_prv = prep_output_unit(unit_nr)

      new = .false.
      compat1 = compat_map(tensor_in%nd_index, ind1)
      compat2 = compat_map(tensor_in%nd_index, ind2)
      compat1_old = compat1; compat2_old = compat2
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_in%name), ":"
         IF (compat1 == 1 .AND. compat2 == 2) THEN
            WRITE (unit_nr_prv, '(A)') "Normal"
         ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
            WRITE (unit_nr_prv, '(A)') "Transposed"
         ELSE
            WRITE (unit_nr_prv, '(A)') "Not compatible"
         END IF
      END IF
      IF (compat1 == 0 .OR. compat2 == 0) THEN ! index mapping not compatible with contract index

         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor_in%name)

         ALLOCATE (tensor_out)
         CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
         CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
         compat1 = 1
         compat2 = 2
         new = .true.
      ELSE
         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor_in%name)
         tensor_out => tensor_in
      END IF

      IF (compat1 == 1 .AND. compat2 == 2) THEN
         trans = .false.
      ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
         trans = .true.
      ELSE
         CPABORT("this should not happen")
      END IF

      IF (unit_nr_prv > 0) THEN
         IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_out%name), ":"
            IF (compat1 == 1 .AND. compat2 == 2) THEN
               WRITE (unit_nr_prv, '(A)') "Normal"
            ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
               WRITE (unit_nr_prv, '(A)') "Transposed"
            ELSE
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            END IF
         END IF
      END IF

   END SUBROUTINE
1463
1464! **************************************************************************************************
1465!> \brief update contraction storage that keeps track of process grids during a batched contraction
1466!> and decide whether the tensor process grid needs to be optimized
!> \param storage contraction storage that tracks the average split factor
1467!> \param split_opt optimized TAS process grid
1468!> \param split current TAS process grid
!> \return whether the process grid should be changed (entry 1: process grid
!>         dimensions deviate too much from the optimum, entry 2: split factor
!>         deviates too much from its running average)
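!>
!> \note nsplit_avg is updated as a running mean over the batches seen so far,
!>       i.e. for batch n: nsplit_avg = (nsplit_avg_old*(n-1) + nsplit_opt)/n;
!>       e.g. optimal split factors 4, 2, 2 over three batches give
!>       nsplit_avg = 8/3.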
1469!> \author Patrick Seewald
1470! **************************************************************************************************
1471 FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
1472 TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
1473 TYPE(dbt_tas_split_info), INTENT(IN) :: split_opt
1474 TYPE(dbt_tas_split_info), INTENT(IN) :: split
1475 INTEGER, DIMENSION(2) :: pdims, pdims_sub
1476 LOGICAL, DIMENSION(2) :: do_change_pgrid
1477 REAL(kind=dp) :: change_criterion, pdims_ratio
1478 INTEGER :: nsplit_opt, nsplit
1479
1480 cpassert(ALLOCATED(split_opt%ngroup_opt))
1481 nsplit_opt = split_opt%ngroup_opt
1482 nsplit = split%ngroup
1483
1484 pdims = split%mp_comm%num_pe_cart
1485
1486 storage%ibatch = storage%ibatch + 1
1487
1488 storage%nsplit_avg = (storage%nsplit_avg*real(storage%ibatch - 1, dp) + real(nsplit_opt, dp)) &
1489 /real(storage%ibatch, dp)
1490
1491 SELECT CASE (split_opt%split_rowcol)
1492 CASE (rowsplit)
1493 pdims_ratio = real(pdims(1), dp)/pdims(2)
1494 CASE (colsplit)
1495 pdims_ratio = real(pdims(2), dp)/pdims(1)
1496 END SELECT
1497
1498 do_change_pgrid(:) = .false.
1499
1500 ! check for process grid dimensions
1501 pdims_sub = split%mp_comm_group%num_pe_cart
1502 change_criterion = maxval(real(pdims_sub, dp))/minval(pdims_sub)
1503 IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .true.
1504
1505 ! check for split factor
1506 change_criterion = max(real(nsplit, dp)/storage%nsplit_avg, real(storage%nsplit_avg, dp)/nsplit)
1507 IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .true.
1508
1509 END FUNCTION
1510
1511! **************************************************************************************************
1512!> \brief Check if 2d index is compatible with tensor index.
!>        Returns 1 if compat_ind matches the tensor dimensions mapped to matrix
!>        rows, 2 if it matches the dimensions mapped to matrix columns and 0
!>        otherwise (e.g. for a mapping (1|2,3): [1] -> 1, [2,3] -> 2, [2] -> 0).
1513!> \author Patrick Seewald
1514! **************************************************************************************************
1515 FUNCTION compat_map(nd_index, compat_ind)
1516 TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
1517 INTEGER, DIMENSION(:), INTENT(IN) :: compat_ind
1518 INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
1519 INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
1520 INTEGER :: compat_map
1521
1522 CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
1523
1524 compat_map = 0
1525 IF (array_eq_i(map1, compat_ind)) THEN
1526 compat_map = 1
1527 ELSEIF (array_eq_i(map2, compat_ind)) THEN
1528 compat_map = 2
1529 END IF
1530
1531 END FUNCTION
1532
1533! **************************************************************************************************
1534!> \brief Sort ind_ref in ascending order and apply the same permutation to ind_dep
!>        (e.g. ind_ref = [3,1,2], ind_dep = [30,10,20] becomes
!>        ind_ref = [1,2,3], ind_dep = [10,20,30]).
1535!> \author Patrick Seewald
1536! **************************************************************************************************
1537 SUBROUTINE index_linked_sort(ind_ref, ind_dep)
1538 INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
1539 INTEGER, DIMENSION(SIZE(ind_ref)) :: sort_indices
1540
1541 CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
1542 ind_dep(:) = ind_dep(sort_indices)
1543
1544 END SUBROUTINE
1545
1546! **************************************************************************************************
1547!> \brief Derive an optimized nd process grid for a tensor from the given TAS
!>        process grid split (the split info is attached to the new grid).
1548!> \author Patrick Seewald
1549! **************************************************************************************************
1550 FUNCTION opt_pgrid(tensor, tas_split_info)
1551 TYPE(dbt_type), INTENT(IN) :: tensor
1552 TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
1553 INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
1554 INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
1555 TYPE(dbt_pgrid_type) :: opt_pgrid
1556 INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
1557
1558 CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
1559 CALL blk_dims_tensor(tensor, dims)
1560 opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
1561
1562 ALLOCATE (opt_pgrid%tas_split_info, source=tas_split_info)
1563 CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
1564 END FUNCTION
1565
1566! **************************************************************************************************
1567!> \brief Copy tensor to tensor with modified index mapping
1568!> \param map1_2d new index mapping: tensor dimensions mapped to matrix rows
1569!> \param map2_2d new index mapping: tensor dimensions mapped to matrix columns
!> \param comm_2d optional 2d communicator from which the new process grid is derived
!>        (defaults to the process grid of tensor_in)
!> \param dist1 optional distributions for the tensor dimensions in map1_2d
!> \param dist2 optional distributions for the tensor dimensions in map2_2d
!> \param mp_dims_1 process grid dimensions for map1_2d (required together with dist1)
!> \param mp_dims_2 process grid dimensions for map2_2d (required together with dist2)
!> \param name name of the new tensor (defaults to the name of tensor_in)
!> \param nodata don't copy tensor data
!> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
1570!> \author Patrick Seewald
1571! **************************************************************************************************
1572 SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
1573 mp_dims_1, mp_dims_2, name, nodata, move_data)
1574 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1575 INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
1576 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1577 CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
1578 LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1579 CLASS(mp_comm_type), INTENT(IN), OPTIONAL :: comm_2d
1580 TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
1581 INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
1582 INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
1583 CHARACTER(len=default_string_length) :: name_tmp
1584 INTEGER, DIMENSION(:), ALLOCATABLE :: blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4, &
1585 nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4
1586 TYPE(dbt_distribution_type) :: dist
1587 TYPE(mp_cart_type) :: comm_2d_prv
1588 INTEGER :: handle, i
1589 INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
1590 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_remap'
1591 LOGICAL :: nodata_prv
1592 TYPE(dbt_pgrid_type) :: comm_nd
1593
1594 CALL timeset(routinen, handle)
1595
1596 IF (PRESENT(name)) THEN
1597 name_tmp = name
1598 ELSE
1599 name_tmp = tensor_in%name
1600 END IF
1601 IF (PRESENT(dist1)) THEN
1602 cpassert(PRESENT(mp_dims_1))
1603 END IF
1604
1605 IF (PRESENT(dist2)) THEN
1606 cpassert(PRESENT(mp_dims_2))
1607 END IF
1608
1609 IF (PRESENT(comm_2d)) THEN
1610 comm_2d_prv = comm_2d
1611 ELSE
1612 comm_2d_prv = tensor_in%pgrid%mp_comm_2d
1613 END IF
1614
1615 comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
1616 CALL mp_environ_pgrid(comm_nd, pdims, myploc)
1617
1618 IF (ndims_tensor(tensor_in) == 2) THEN
1619 CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2)
1620 END IF
1621 IF (ndims_tensor(tensor_in) == 3) THEN
1622 CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3)
1623 END IF
1624 IF (ndims_tensor(tensor_in) == 4) THEN
1625 CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
1626 END IF
1627
1628 IF (ndims_tensor(tensor_in) == 2) THEN
1629 IF (PRESENT(dist1)) THEN
1630 IF (any(map1_2d == 1)) THEN
1631 i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1632 CALL get_ith_array(dist1, i, nd_dist_1)
1633 END IF
1634 END IF
1635
1636 IF (PRESENT(dist2)) THEN
1637 IF (any(map2_2d == 1)) THEN
1638 i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1639 CALL get_ith_array(dist2, i, nd_dist_1)
1640 END IF
1641 END IF
1642
1643 IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1644 ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1645 CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1646 END IF
1647 IF (PRESENT(dist1)) THEN
1648 IF (any(map1_2d == 2)) THEN
1649 i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1650 CALL get_ith_array(dist1, i, nd_dist_2)
1651 END IF
1652 END IF
1653
1654 IF (PRESENT(dist2)) THEN
1655 IF (any(map2_2d == 2)) THEN
1656 i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1657 CALL get_ith_array(dist2, i, nd_dist_2)
1658 END IF
1659 END IF
1660
1661 IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1662 ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1663 CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1664 END IF
1665 CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1666 nd_dist_1, nd_dist_2, own_comm=.true.)
1667 END IF
1668 IF (ndims_tensor(tensor_in) == 3) THEN
1669 IF (PRESENT(dist1)) THEN
1670 IF (any(map1_2d == 1)) THEN
1671 i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1672 CALL get_ith_array(dist1, i, nd_dist_1)
1673 END IF
1674 END IF
1675
1676 IF (PRESENT(dist2)) THEN
1677 IF (any(map2_2d == 1)) THEN
1678 i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1679 CALL get_ith_array(dist2, i, nd_dist_1)
1680 END IF
1681 END IF
1682
1683 IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1684 ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1685 CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1686 END IF
1687 IF (PRESENT(dist1)) THEN
1688 IF (any(map1_2d == 2)) THEN
1689 i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1690 CALL get_ith_array(dist1, i, nd_dist_2)
1691 END IF
1692 END IF
1693
1694 IF (PRESENT(dist2)) THEN
1695 IF (any(map2_2d == 2)) THEN
1696 i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1697 CALL get_ith_array(dist2, i, nd_dist_2)
1698 END IF
1699 END IF
1700
1701 IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1702 ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1703 CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1704 END IF
1705 IF (PRESENT(dist1)) THEN
1706 IF (any(map1_2d == 3)) THEN
1707 i = minloc(map1_2d, dim=1, mask=map1_2d == 3) ! i is location of idim in map1_2d
1708 CALL get_ith_array(dist1, i, nd_dist_3)
1709 END IF
1710 END IF
1711
1712 IF (PRESENT(dist2)) THEN
1713 IF (any(map2_2d == 3)) THEN
1714 i = minloc(map2_2d, dim=1, mask=map2_2d == 3) ! i is location of idim in map2_2d
1715 CALL get_ith_array(dist2, i, nd_dist_3)
1716 END IF
1717 END IF
1718
1719 IF (.NOT. ALLOCATED(nd_dist_3)) THEN
1720 ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
1721 CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
1722 END IF
1723 CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1724 nd_dist_1, nd_dist_2, nd_dist_3, own_comm=.true.)
1725 END IF
1726 IF (ndims_tensor(tensor_in) == 4) THEN
1727 IF (PRESENT(dist1)) THEN
1728 IF (any(map1_2d == 1)) THEN
1729 i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1730 CALL get_ith_array(dist1, i, nd_dist_1)
1731 END IF
1732 END IF
1733
1734 IF (PRESENT(dist2)) THEN
1735 IF (any(map2_2d == 1)) THEN
1736 i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1737 CALL get_ith_array(dist2, i, nd_dist_1)
1738 END IF
1739 END IF
1740
1741 IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1742 ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1743 CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1744 END IF
1745 IF (PRESENT(dist1)) THEN
1746 IF (any(map1_2d == 2)) THEN
1747 i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1748 CALL get_ith_array(dist1, i, nd_dist_2)
1749 END IF
1750 END IF
1751
1752 IF (PRESENT(dist2)) THEN
1753 IF (any(map2_2d == 2)) THEN
1754 i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1755 CALL get_ith_array(dist2, i, nd_dist_2)
1756 END IF
1757 END IF
1758
1759 IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1760 ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1761 CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1762 END IF
1763 IF (PRESENT(dist1)) THEN
1764 IF (any(map1_2d == 3)) THEN
1765 i = minloc(map1_2d, dim=1, mask=map1_2d == 3) ! i is location of idim in map1_2d
1766 CALL get_ith_array(dist1, i, nd_dist_3)
1767 END IF
1768 END IF
1769
1770 IF (PRESENT(dist2)) THEN
1771 IF (any(map2_2d == 3)) THEN
1772 i = minloc(map2_2d, dim=1, mask=map2_2d == 3) ! i is location of idim in map2_2d
1773 CALL get_ith_array(dist2, i, nd_dist_3)
1774 END IF
1775 END IF
1776
1777 IF (.NOT. ALLOCATED(nd_dist_3)) THEN
1778 ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
1779 CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
1780 END IF
1781 IF (PRESENT(dist1)) THEN
1782 IF (any(map1_2d == 4)) THEN
1783 i = minloc(map1_2d, dim=1, mask=map1_2d == 4) ! i is location of idim in map1_2d
1784 CALL get_ith_array(dist1, i, nd_dist_4)
1785 END IF
1786 END IF
1787
1788 IF (PRESENT(dist2)) THEN
1789 IF (any(map2_2d == 4)) THEN
1790 i = minloc(map2_2d, dim=1, mask=map2_2d == 4) ! i is location of idim in map2_2d
1791 CALL get_ith_array(dist2, i, nd_dist_4)
1792 END IF
1793 END IF
1794
1795 IF (.NOT. ALLOCATED(nd_dist_4)) THEN
1796 ALLOCATE (nd_dist_4(SIZE(blk_sizes_4)))
1797 CALL dbt_default_distvec(SIZE(blk_sizes_4), pdims(4), blk_sizes_4, nd_dist_4)
1798 END IF
1799 CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1800 nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4, own_comm=.true.)
1801 END IF
1802
1803 IF (ndims_tensor(tensor_in) == 2) THEN
1804 CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1805 blk_sizes_1, blk_sizes_2)
1806 END IF
1807 IF (ndims_tensor(tensor_in) == 3) THEN
1808 CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1809 blk_sizes_1, blk_sizes_2, blk_sizes_3)
1810 END IF
1811 IF (ndims_tensor(tensor_in) == 4) THEN
1812 CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1813 blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
1814 END IF
1815
1816 IF (PRESENT(nodata)) THEN
1817 nodata_prv = nodata
1818 ELSE
1819 nodata_prv = .false.
1820 END IF
1821
1822 IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
1823 CALL dbt_distribution_destroy(dist)
1824
1825 CALL timestop(handle)
1826 END SUBROUTINE
1827
1828! **************************************************************************************************
1829!> \brief Align index with data
1830!> \param order permutation resulting from alignment
1831!> \author Patrick Seewald
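!>
!> \note A worked example (assuming a rank-3 tensor with map1_2d = [2] and
!>       map2_2d = [3,1]): the combined map is [2,3,1], so
!>       order = dbt_inverse_order([2,3,1]) = [3,1,2] and the resulting tensor
!>       has its dimensions permuted such that the index is mapped in its
!>       natural order (1|2,3).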
1832! **************************************************************************************************
1833 SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
1834 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1835 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1836 INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
1837 INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
1838 INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1839 INTENT(OUT), OPTIONAL :: order
1840 INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: order_prv
1841 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_align_index'
1842 INTEGER :: handle
1843
1844 CALL timeset(routinen, handle)
1845
1846 CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
1847 order_prv = dbt_inverse_order([map1_2d, map2_2d])
1848 CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
1849
1850 IF (PRESENT(order)) order = order_prv
1851
1852 CALL timestop(handle)
1853 END SUBROUTINE
1854
1855! **************************************************************************************************
1856!> \brief Create new tensor by reordering index; block data is not copied but
!>        shared with the source tensor (shallow copy)
1857!> \author Patrick Seewald
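!>
!> \note order follows the convention that dimension i of tensor_in becomes
!>       dimension order(i) of tensor_out (e.g. order = [2,3,1] sends dimension 1
!>       of the input to position 2 of the output).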
1858! **************************************************************************************************
1859 SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
1860 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1861 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1862 INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1863 INTENT(IN) :: order
1864
1865 TYPE(nd_to_2d_mapping) :: nd_index_blk_rs, nd_index_rs
1866 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_permute_index'
1867 INTEGER :: handle
1868 INTEGER :: ndims
1869
1870 CALL timeset(routinen, handle)
1871
1872 ndims = ndims_tensor(tensor_in)
1873
1874 CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
1875 CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
1876 CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
1877
1878 tensor_out%matrix_rep => tensor_in%matrix_rep
1879 tensor_out%owns_matrix = .false.
1880
1881 tensor_out%nd_index = nd_index_rs
1882 tensor_out%nd_index_blk = nd_index_blk_rs
1883 tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
1884 IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
1885 ALLOCATE (tensor_out%pgrid%tas_split_info, source=tensor_in%pgrid%tas_split_info)
1886 END IF
1887 tensor_out%refcount => tensor_in%refcount
1888 CALL dbt_hold(tensor_out)
1889
1890 CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
1891 CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
1892 CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
1893 CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
1894 ALLOCATE (tensor_out%nblks_local(ndims))
1895 ALLOCATE (tensor_out%nfull_local(ndims))
1896 tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
1897 tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
1898 tensor_out%name = tensor_in%name
1899 tensor_out%valid = .true.
1900
1901 IF (ALLOCATED(tensor_in%contraction_storage)) THEN
1902 ALLOCATE (tensor_out%contraction_storage, source=tensor_in%contraction_storage)
1903 CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
1904 CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
1905 END IF
1906
1907 CALL timestop(handle)
1908 END SUBROUTINE
1909
1910! **************************************************************************************************
1911!> \brief Map contraction bounds to bounds referring to tensor indices
1912!> see dbt_contract for documentation of dummy arguments
1913!> \param bounds_t1 bounds mapped to tensor_1
1914!> \param bounds_t2 bounds mapped to tensor_2
1915!> \param do_crop_1 whether tensor 1 should be cropped
1916!> \param do_crop_2 whether tensor 2 should be cropped
1917!> \author Patrick Seewald
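!>
!> \note Worked example (hypothetical rank-2 tensors): with contract_1 = [2],
!>       contract_2 = [1] and bounds_1 = reshape([3, 8], [2, 1]), dimension 2 of
!>       tensor_1 and dimension 1 of tensor_2 are both restricted to the element
!>       range 3..8; all other dimensions keep their full range [1, nfull_total]
!>       and both do_crop flags are set to .TRUE.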
1918! **************************************************************************************************
1919 SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
1920 contract_1, notcontract_1, &
1921 contract_2, notcontract_2, &
1922 bounds_t1, bounds_t2, &
1923 bounds_1, bounds_2, bounds_3, &
1924 do_crop_1, do_crop_2)
1925
1926 TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2
1927 INTEGER, DIMENSION(:), INTENT(IN) :: contract_1, contract_2, &
1928 notcontract_1, notcontract_2
1929 INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
1930 INTENT(OUT) :: bounds_t1
1931 INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
1932 INTENT(OUT) :: bounds_t2
1933 INTEGER, DIMENSION(2, SIZE(contract_1)), &
1934 INTENT(IN), OPTIONAL :: bounds_1
1935 INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
1936 INTENT(IN), OPTIONAL :: bounds_2
1937 INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
1938 INTENT(IN), OPTIONAL :: bounds_3
1939 LOGICAL, INTENT(OUT), OPTIONAL :: do_crop_1, do_crop_2
1940 LOGICAL, DIMENSION(2) :: do_crop
1941
1942 do_crop = .false.
1943
1944 bounds_t1(1, :) = 1
1945 CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
1946
1947 bounds_t2(1, :) = 1
1948 CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
1949
1950 IF (PRESENT(bounds_1)) THEN
1951 bounds_t1(:, contract_1) = bounds_1
1952 do_crop(1) = .true.
1953 bounds_t2(:, contract_2) = bounds_1
1954 do_crop(2) = .true.
1955 END IF
1956
1957 IF (PRESENT(bounds_2)) THEN
1958 bounds_t1(:, notcontract_1) = bounds_2
1959 do_crop(1) = .true.
1960 END IF
1961
1962 IF (PRESENT(bounds_3)) THEN
1963 bounds_t2(:, notcontract_2) = bounds_3
1964 do_crop(2) = .true.
1965 END IF
1966
1967 IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
1968 IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
1969
1970 END SUBROUTINE
1971
1972! **************************************************************************************************
1973!> \brief print tensor contraction indices in a human readable way
1974!> \param indchar1 characters printed for index of tensor 1
1975!> \param indchar2 characters printed for index of tensor 2
1976!> \param indchar3 characters printed for index of tensor 3
1977!> \param unit_nr output unit
1978!> \author Patrick Seewald
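!>
!> \note Example output (assuming a matrix-matrix like contraction with
!>       indchar1 = ['a','b'], indchar2 = ['b','c'], indchar3 = ['a','c'], each
!>       tensor mapping its first dimension to matrix rows):
!> \verbatim
!>   INDEX INFO
!>               tensor index: (ab) x (bc) = (ac)
!>               matrix index: (a|b) x (b|c) = (a|c)
!> \endverbatim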
1979! **************************************************************************************************
1980 SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
1981 TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
1982 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
1983 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
1984 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
1985 INTEGER, INTENT(IN) :: unit_nr
1986 INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
1987 INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
1988 INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
1989 INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
1990 INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
1991 INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
1992 INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
1993
1994 unit_nr_prv = prep_output_unit(unit_nr)
1995
1996 IF (unit_nr_prv /= 0) THEN
1997 CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
1998 CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
1999 CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
2000 END IF
2001
2002 IF (unit_nr_prv > 0) THEN
2003 WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
2004 WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
2005 DO ichar1 = 1, SIZE(indchar1)
2006 WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
2007 END DO
2008 WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
2009 DO ichar2 = 1, SIZE(indchar2)
2010 WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
2011 END DO
2012 WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
2013 DO ichar3 = 1, SIZE(indchar3)
2014 WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
2015 END DO
2016 WRITE (unit_nr_prv, '(A)') ")"
2017
2018 WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
2019 DO ichar1 = 1, SIZE(map11)
2020 WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
2021 END DO
2022 WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2023 DO ichar1 = 1, SIZE(map12)
2024 WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
2025 END DO
2026 WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
2027 DO ichar2 = 1, SIZE(map21)
2028 WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
2029 END DO
2030 WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2031 DO ichar2 = 1, SIZE(map22)
2032 WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
2033 END DO
2034 WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
2035 DO ichar3 = 1, SIZE(map31)
2036 WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
2037 END DO
2038 WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2039 DO ichar3 = 1, SIZE(map32)
2040 WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
2041 END DO
2042 WRITE (unit_nr_prv, '(A)') ")"
2043 END IF
2044
2045 END SUBROUTINE
2046
2047! **************************************************************************************************
2048!> \brief Initialize batched contraction for this tensor.
2049!>
2050!> Explanation: A batched contraction is a contraction performed in several consecutive steps
2051!> by specification of bounds in dbt_contract. This can be used to reduce memory by
2052!> a large factor. The routines dbt_batched_contract_init and
2053!> dbt_batched_contract_finalize should be called to define the scope of a batched
2054!> contraction as this enables important optimizations (adapting communication scheme to
2055!> batches and adapting process grid to multiplication algorithm). The routines
2056!> dbt_batched_contract_init and dbt_batched_contract_finalize must be
2057!> called before the first and after the last contraction step on all 3 tensors.
2058!>
2059!> Requirements:
2060!> - the tensors are in a compatible matrix layout (see documentation of
2061!> `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
2062!> disabled and a warning is issued.
2063!> - within the scope of a batched contraction, it is not allowed to access or change tensor
2064!> data except by calling the routines dbt_contract & dbt_copy.
2065!> - the bounds affecting indices of the smallest tensor must not change in the course of a
2066!> batched contraction (todo: get rid of this requirement).
2067!>
2068!> Side effects:
2069!> - the parallel layout (process grid and distribution) of all tensors may change. In order
2070!> to disable the process grid optimization including this side effect, call this routine
2071!> only on the smallest of the 3 tensors.
2072!>
2073!> \note
2074!> Note 1: for an example of batched contraction see `examples/dbt_example.F`.
2075!> (todo: the example is outdated and should be updated).
2076!>
2077!> Note 2: it is also meaningful to use this feature if the contraction consists of one
2078!> batch only but multiple contractions involving the same 3 tensors are performed
2079!> (batched_contract_init and batched_contract_finalize must then be called before/after each
2080!> contraction call). The process grid is then optimized after the first contraction
2081!> and future contractions may profit from this optimization.
2082!>
2083!> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2084!> a new range. The size should be the number of ranges plus one, the last
2085!> element being the block index plus one of the last block in the last range.
2086!> For internal load balancing optimizations, optionally specify the index
2087!> ranges of batched contraction.
2088!> \author Patrick Seewald
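!>
!> A minimal usage sketch (hypothetical rank-2 tensors t1, t2, t3 contracted as
!> t3(i,k) = sum_j t1(i,j)*t2(j,k); index j has 10 blocks of size 4, batched over
!> block ranges 1..4 and 5..10, i.e. element bounds 1..16 and 17..40; t3 is
!> assumed to have been cleared beforehand so the batches accumulate correctly):
!> \code{.f90}
!> INTEGER, DIMENSION(2, 1) :: bounds
!> INTEGER, DIMENSION(2, 2), PARAMETER :: bounds_el = RESHAPE([1, 16, 17, 40], [2, 2])
!> INTEGER :: ibatch
!>
!> CALL dbt_batched_contract_init(t1, batch_range_2=[1, 5, 11])
!> CALL dbt_batched_contract_init(t2, batch_range_1=[1, 5, 11])
!> CALL dbt_batched_contract_init(t3)
!> DO ibatch = 1, 2
!>    bounds(:, 1) = bounds_el(:, ibatch)  ! bounds of the contracted index j for this batch
!>    CALL dbt_contract(1.0_dp, t1, t2, 1.0_dp, t3, &
!>                      contract_1=[2], notcontract_1=[1], &
!>                      contract_2=[1], notcontract_2=[2], &
!>                      map_1=[1], map_2=[2], bounds_1=bounds)
!> END DO
!> CALL dbt_batched_contract_finalize(t1)
!> CALL dbt_batched_contract_finalize(t2)
!> CALL dbt_batched_contract_finalize(t3)
!> \endcode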
2089! **************************************************************************************************
2090 SUBROUTINE dbt_batched_contract_init(tensor, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
2091 TYPE(dbt_type), INTENT(INOUT) :: tensor
2092 INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2093 INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
2094 INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_prv_1, batch_range_prv_2, batch_range_prv_3,&
2095 & batch_range_prv_4
2096 LOGICAL :: static_range
2097
2098 CALL dbt_get_info(tensor, nblks_total=tdims)
2099
2100 static_range = .true.
2101 IF (ndims_tensor(tensor) >= 1) THEN
2102 IF (PRESENT(batch_range_1)) THEN
2103 ALLOCATE (batch_range_prv_1, source=batch_range_1)
2104 static_range = .false.
2105 ELSE
2106 ALLOCATE (batch_range_prv_1(2))
2107 batch_range_prv_1(1) = 1
2108 batch_range_prv_1(2) = tdims(1) + 1
2109 END IF
2110 END IF
2111 IF (ndims_tensor(tensor) >= 2) THEN
2112 IF (PRESENT(batch_range_2)) THEN
2113 ALLOCATE (batch_range_prv_2, source=batch_range_2)
2114 static_range = .false.
2115 ELSE
2116 ALLOCATE (batch_range_prv_2(2))
2117 batch_range_prv_2(1) = 1
2118 batch_range_prv_2(2) = tdims(2) + 1
2119 END IF
2120 END IF
2121 IF (ndims_tensor(tensor) >= 3) THEN
2122 IF (PRESENT(batch_range_3)) THEN
2123 ALLOCATE (batch_range_prv_3, source=batch_range_3)
2124 static_range = .false.
2125 ELSE
2126 ALLOCATE (batch_range_prv_3(2))
2127 batch_range_prv_3(1) = 1
2128 batch_range_prv_3(2) = tdims(3) + 1
2129 END IF
2130 END IF
2131 IF (ndims_tensor(tensor) >= 4) THEN
2132 IF (PRESENT(batch_range_4)) THEN
2133 ALLOCATE (batch_range_prv_4, source=batch_range_4)
2134 static_range = .false.
2135 ELSE
2136 ALLOCATE (batch_range_prv_4(2))
2137 batch_range_prv_4(1) = 1
2138 batch_range_prv_4(2) = tdims(4) + 1
2139 END IF
2140 END IF
2141
2142 ALLOCATE (tensor%contraction_storage)
2143 tensor%contraction_storage%static = static_range
2144 IF (static_range) THEN
2145 CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
2146 END IF
2147 tensor%contraction_storage%nsplit_avg = 0.0_dp
2148 tensor%contraction_storage%ibatch = 0
2149
2150 IF (ndims_tensor(tensor) == 1) THEN
2151 CALL create_array_list(tensor%contraction_storage%batch_ranges, 1, &
2152 batch_range_prv_1)
2153 END IF
2154 IF (ndims_tensor(tensor) == 2) THEN
2155 CALL create_array_list(tensor%contraction_storage%batch_ranges, 2, &
2156 batch_range_prv_1, batch_range_prv_2)
2157 END IF
2158 IF (ndims_tensor(tensor) == 3) THEN
2159 CALL create_array_list(tensor%contraction_storage%batch_ranges, 3, &
2160 batch_range_prv_1, batch_range_prv_2, batch_range_prv_3)
2161 END IF
2162 IF (ndims_tensor(tensor) == 4) THEN
2163 CALL create_array_list(tensor%contraction_storage%batch_ranges, 4, &
2164 batch_range_prv_1, batch_range_prv_2, batch_range_prv_3, batch_range_prv_4)
2165 END IF
2166
2167 END SUBROUTINE
2168
2169! **************************************************************************************************
2170!> \brief finalize batched contraction. This performs all communication that has been postponed in
2171!> the contraction calls.
2172!> \author Patrick Seewald
2173! **************************************************************************************************
2174 SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
2175 TYPE(dbt_type), INTENT(INOUT) :: tensor
2176 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2177 LOGICAL :: do_write
2178 INTEGER :: unit_nr_prv, handle
2179
2180 CALL tensor%pgrid%mp_comm_2d%sync()
2181 CALL timeset("dbt_total", handle)
2182 unit_nr_prv = prep_output_unit(unit_nr)
2183
2184 do_write = .false.
2185
2186 IF (tensor%contraction_storage%static) THEN
2187 IF (tensor%matrix_rep%do_batched > 0) THEN
2188 IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .true.
2189 END IF
2190 CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
2191 END IF
2192
2193 IF (do_write .AND. unit_nr_prv /= 0) THEN
2194 IF (unit_nr_prv > 0) THEN
2195 WRITE (unit_nr_prv, "(T2,A)") &
2196 "FINALIZING BATCHED PROCESSING OF MATMUL"
2197 END IF
2198 CALL dbt_write_tensor_info(tensor, unit_nr_prv)
2199 CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
2200 END IF
2201
2202 CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
2203 DEALLOCATE (tensor%contraction_storage)
2204 CALL tensor%pgrid%mp_comm_2d%sync()
2205 CALL timestop(handle)
2206
2207 END SUBROUTINE
2208
2209! **************************************************************************************************
2210!> \brief change the process grid of a tensor
2211!> \param nodata optionally don't copy the tensor data (then tensor is empty on return)
2212!> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2213!> a new range. The size should be the number of ranges plus one, the last
2214!> element being the block index plus one of the last block in the last range.
2215!> For internal load balancing optimizations, optionally specify the index
2216!> ranges of batched contraction.
2217!> \author Patrick Seewald
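!>
!> \note For example: batch_range_1 = [1, 5, 11] describes two ranges along the
!>       first dimension, blocks 1..4 and blocks 5..10.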
2218! **************************************************************************************************
2219 SUBROUTINE dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
2220 nodata, pgrid_changed, unit_nr)
2221 TYPE(dbt_type), INTENT(INOUT) :: tensor
2222 TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid
2223 INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2225 LOGICAL, INTENT(IN), OPTIONAL :: nodata
2226 LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2227 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2228 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_change_pgrid'
2229 CHARACTER(default_string_length) :: name
2230 INTEGER :: handle
2231 INTEGER, ALLOCATABLE, DIMENSION(:) :: bs_1, bs_2, bs_3, bs_4, &
2232 dist_1, dist_2, dist_3, dist_4
2233 INTEGER, DIMENSION(ndims_tensor(tensor)) :: pcoord, pcoord_ref, pdims, pdims_ref, &
2234 tdims
2235 TYPE(dbt_type) :: t_tmp
2236 TYPE(dbt_distribution_type) :: dist
2237 INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2238 INTEGER, &
2239 DIMENSION(ndims_matrix_column(tensor)) :: map2
2240 LOGICAL, DIMENSION(ndims_tensor(tensor)) :: mem_aware
2241 INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
2242 INTEGER :: ind1, ind2, batch_size, ibatch
2243
2244 IF (PRESENT(pgrid_changed)) pgrid_changed = .false.
2245 CALL mp_environ_pgrid(pgrid, pdims, pcoord)
2246 CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)
2247
2248 IF (all(pdims == pdims_ref)) THEN
2249 IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
2250 IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
2251 RETURN
2252 END IF
2253 END IF
2254 END IF
2255
2256 CALL timeset(routinen, handle)
2257
2258 IF (ndims_tensor(tensor) >= 1) THEN
2259 mem_aware(1) = PRESENT(batch_range_1)
2260 IF (mem_aware(1)) nbatch(1) = SIZE(batch_range_1) - 1
2261 END IF
2262 IF (ndims_tensor(tensor) >= 2) THEN
2263 mem_aware(2) = PRESENT(batch_range_2)
2264 IF (mem_aware(2)) nbatch(2) = SIZE(batch_range_2) - 1
2265 END IF
2266 IF (ndims_tensor(tensor) >= 3) THEN
2267 mem_aware(3) = PRESENT(batch_range_3)
2268 IF (mem_aware(3)) nbatch(3) = SIZE(batch_range_3) - 1
2269 END IF
2270 IF (ndims_tensor(tensor) >= 4) THEN
2271 mem_aware(4) = PRESENT(batch_range_4)
2272 IF (mem_aware(4)) nbatch(4) = SIZE(batch_range_4) - 1
2273 END IF
2274
2275 CALL dbt_get_info(tensor, nblks_total=tdims, name=name)
2276
2277 IF (ndims_tensor(tensor) >= 1) THEN
2278 ALLOCATE (bs_1(dbt_nblks_total(tensor, 1)))
2279 CALL get_ith_array(tensor%blk_sizes, 1, bs_1)
2280 ALLOCATE (dist_1(tdims(1)))
2281 dist_1 = 0
2282 IF (mem_aware(1)) THEN
2283 DO ibatch = 1, nbatch(1)
2284 ind1 = batch_range_1(ibatch)
2285 ind2 = batch_range_1(ibatch + 1) - 1
2286 batch_size = ind2 - ind1 + 1
2287 CALL dbt_default_distvec(batch_size, pdims(1), &
2288 bs_1(ind1:ind2), dist_1(ind1:ind2))
2289 END DO
2290 ELSE
2291 CALL dbt_default_distvec(tdims(1), pdims(1), bs_1, dist_1)
2292 END IF
2293 END IF
2294 IF (ndims_tensor(tensor) >= 2) THEN
2295 ALLOCATE (bs_2(dbt_nblks_total(tensor, 2)))
2296 CALL get_ith_array(tensor%blk_sizes, 2, bs_2)
2297 ALLOCATE (dist_2(tdims(2)))
2298 dist_2 = 0
2299 IF (mem_aware(2)) THEN
2300 DO ibatch = 1, nbatch(2)
2301 ind1 = batch_range_2(ibatch)
2302 ind2 = batch_range_2(ibatch + 1) - 1
2303 batch_size = ind2 - ind1 + 1
2304 CALL dbt_default_distvec(batch_size, pdims(2), &
2305 bs_2(ind1:ind2), dist_2(ind1:ind2))
2306 END DO
2307 ELSE
2308 CALL dbt_default_distvec(tdims(2), pdims(2), bs_2, dist_2)
2309 END IF
2310 END IF
2311 IF (ndims_tensor(tensor) >= 3) THEN
2312 ALLOCATE (bs_3(dbt_nblks_total(tensor, 3)))
2313 CALL get_ith_array(tensor%blk_sizes, 3, bs_3)
2314 ALLOCATE (dist_3(tdims(3)))
2315 dist_3 = 0
2316 IF (mem_aware(3)) THEN
2317 DO ibatch = 1, nbatch(3)
2318 ind1 = batch_range_3(ibatch)
2319 ind2 = batch_range_3(ibatch + 1) - 1
2320 batch_size = ind2 - ind1 + 1
2321 CALL dbt_default_distvec(batch_size, pdims(3), &
2322 bs_3(ind1:ind2), dist_3(ind1:ind2))
2323 END DO
2324 ELSE
2325 CALL dbt_default_distvec(tdims(3), pdims(3), bs_3, dist_3)
2326 END IF
2327 END IF
2328 IF (ndims_tensor(tensor) >= 4) THEN
2329 ALLOCATE (bs_4(dbt_nblks_total(tensor, 4)))
2330 CALL get_ith_array(tensor%blk_sizes, 4, bs_4)
2331 ALLOCATE (dist_4(tdims(4)))
2332 dist_4 = 0
2333 IF (mem_aware(4)) THEN
2334 DO ibatch = 1, nbatch(4)
2335 ind1 = batch_range_4(ibatch)
2336 ind2 = batch_range_4(ibatch + 1) - 1
2337 batch_size = ind2 - ind1 + 1
2338 CALL dbt_default_distvec(batch_size, pdims(4), &
2339 bs_4(ind1:ind2), dist_4(ind1:ind2))
2340 END DO
2341 ELSE
2342 CALL dbt_default_distvec(tdims(4), pdims(4), bs_4, dist_4)
2343 END IF
2344 END IF
2345
2346 CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
2347 IF (ndims_tensor(tensor) == 2) THEN
2348 CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2)
2349 CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2)
2350 END IF
2351 IF (ndims_tensor(tensor) == 3) THEN
2352 CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3)
2353 CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3)
2354 END IF
2355 IF (ndims_tensor(tensor) == 4) THEN
2356 CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3, dist_4)
2357 CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3, bs_4)
2358 END IF
2359 CALL dbt_distribution_destroy(dist)
2360
2361 IF (PRESENT(nodata)) THEN
2362 IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
2363 ELSE
2364 CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
2365 END IF
2366
2367 CALL dbt_copy_contraction_storage(tensor, t_tmp)
2368
2369 CALL dbt_destroy(tensor)
2370 tensor = t_tmp
2371
2372 IF (PRESENT(unit_nr)) THEN
2373 IF (unit_nr > 0) THEN
2374 WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", trim(tensor%name)
2375 WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
2376 CALL dbt_write_split_info(pgrid, unit_nr)
2377 END IF
2378 END IF
2379
2380 IF (PRESENT(pgrid_changed)) pgrid_changed = .true.
2381
2382 CALL timestop(handle)
2383 END SUBROUTINE
2384
2385! **************************************************************************************************
2386!> \brief map tensor to a new 2d process grid for the matrix representation.
2387!> \author Patrick Seewald
2388! **************************************************************************************************
2389 SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
2390 TYPE(dbt_type), INTENT(INOUT) :: tensor
2391 TYPE(mp_cart_type), INTENT(IN) :: mp_comm
2392 INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
2393 LOGICAL, INTENT(IN), OPTIONAL :: nodata
2394 INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
2395 LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2396 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2397 INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2398 INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
2399 INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
2400 TYPE(dbt_pgrid_type) :: pgrid
2401 INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2402 INTEGER, DIMENSION(:), ALLOCATABLE :: array
2403 INTEGER :: idim
2404
2405 CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
2406 CALL blk_dims_tensor(tensor, dims)
2407
2408 IF (ALLOCATED(tensor%contraction_storage)) THEN
2409 ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
2410 nbatches = sizes_of_arrays(batch_ranges) - 1
2411 ! For good load balancing, the process grid dimensions should be adapted to the
2412 ! tensor dimensions. For batched contraction, the tensor dimensions should be divided
2413 ! by the number of batches (number of index ranges).
2414 DO idim = 1, ndims_tensor(tensor)
2415 CALL get_ith_array(batch_ranges, idim, array)
2416 dims(idim) = array(nbatches(idim) + 1) - array(1)
2417 DEALLOCATE (array)
2418 dims(idim) = dims(idim)/nbatches(idim)
2419 IF (dims(idim) <= 0) dims(idim) = 1
2420 END DO
2421 END ASSOCIATE
2422 END IF
2423
2424 pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
2425 IF (ALLOCATED(tensor%contraction_storage)) THEN
2426 IF (ndims_tensor(tensor) == 1) THEN
2427 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1)
2428 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, &
2429 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2430 END IF
2431 IF (ndims_tensor(tensor) == 2) THEN
2432 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2)
2433 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, &
2434 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2435 END IF
2436 IF (ndims_tensor(tensor) == 3) THEN
2437 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3)
2438 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, &
2439 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2440 END IF
2441 IF (ndims_tensor(tensor) == 4) THEN
2442 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
2443 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
2444 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2445 END IF
2446 ELSE
2447 CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2448 END IF
2449 CALL dbt_pgrid_destroy(pgrid)
2450
2451 END SUBROUTINE
2452
2453END MODULE