dbt_methods.F (git:374b731)
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
8! **************************************************************************************************
9!> \brief DBT tensor framework for block-sparse tensor contraction.
10!> Representation of n-rank tensors as DBT tall-and-skinny matrices.
11!> Support for arbitrary redistribution between different representations.
12!> Support for arbitrary tensor contractions.
13!> \todo implement checks and error messages
14!> \author Patrick Seewald
15! **************************************************************************************************
16 MODULE dbt_methods
17
18
19 USE dbcsr_api, ONLY: &
20 dbcsr_type, dbcsr_release, &
21 dbcsr_iterator_type, dbcsr_iterator_start, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
22 dbcsr_has_symmetry, dbcsr_desymmetrize, dbcsr_put_block, dbcsr_clear, dbcsr_iterator_stop
23 USE dbt_allocate_wrap, ONLY: &
24 allocate_any
25 USE dbt_array_list_methods, ONLY: &
26 array_list, get_arrays, reorder_arrays, array_sublist, &
27 check_equal, array_eq_i
28 USE dbm_api, ONLY: &
29 dbm_clear
30 USE dbt_tas_types, ONLY: &
31 dbt_tas_split_info
32 USE dbt_tas_base, ONLY: &
33 dbt_tas_copy, dbt_tas_finalize, dbt_tas_get_info, dbt_tas_info
34 USE dbt_tas_mm, ONLY: &
35 dbt_tas_multiply, dbt_tas_batched_mm_init, dbt_tas_batched_mm_finalize, &
36 dbt_tas_batched_mm_complete, dbt_tas_set_batched_state
37 USE dbt_block, ONLY: &
38 dbt_iterator_type, dbt_get_block, dbt_put_block, dbt_iterator_start, &
39 dbt_iterator_blocks_left, dbt_iterator_stop, dbt_iterator_next_block, &
40 dbt_reserve_blocks, block_nd, destroy_block, checker_tr
41 USE dbt_index, ONLY: &
42 dbt_get_mapping_info, nd_to_2d_mapping, dbt_inverse_order, permute_index, &
43 ndims_mapping_row, ndims_mapping_column
44 USE dbt_types, ONLY: &
45 dbt_create, dbt_type, ndims_tensor, dims_tensor, dbt_distribution_type, &
46 dbt_distribution, dbt_nd_mp_comm, dbt_destroy, dbt_distribution_destroy, &
47 dbt_distribution_new_expert, dbt_get_stored_coordinates, blk_dims_tensor, &
48 dbt_hold, dbt_pgrid_type, dbt_filter, dbt_clear, dbt_get_num_blocks, dbt_scale, &
49 dbt_contraction_storage, dbt_copy_contraction_storage, dbt_align_index, &
50 dbt_permute_index, dbt_map_bounds_to_tensors, dbt_change_pgrid_2d, &
51 dbt_default_distvec, mp_environ_pgrid, &
52 ndims_matrix_row, ndims_matrix_column
53 USE kinds, ONLY: &
54 dp, default_string_length, int_8
55 USE message_passing, ONLY: &
56 mp_cart_type
57 USE util, ONLY: &
58 sort
59 USE dbt_reshape_ops, ONLY: &
60 dbt_reshape
61 USE dbt_tas_split, ONLY: &
62 dbt_tas_create_split, dbt_tas_release_info, dbt_tas_mp_comm, rowsplit, colsplit, &
63 default_nsplit_accept_ratio, default_pdims_accept_ratio
64 USE dbt_split, ONLY: &
65 dbt_split_copyback, dbt_make_compatible_blocks, dbt_crop
66 USE dbt_io, ONLY: &
67 dbt_write_tensor_info, dbt_write_tensor_dist, prep_output_unit, &
68 dbt_print_contraction_index
69
70#include "../base/base_uses.f90"
71
72 IMPLICIT NONE
73 PRIVATE
74 CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_methods'
75
76 PUBLIC :: &
77 dbt_contract, &
78 dbt_copy, &
79 dbt_copy_matrix_to_tensor, &
80 dbt_copy_tensor_to_matrix, &
81 dbt_get_block, &
82 dbt_get_stored_coordinates, &
83 dbt_inverse_order, &
84 dbt_iterator_blocks_left, &
85 dbt_iterator_next_block, &
86 dbt_iterator_start, &
87 dbt_iterator_stop, &
88 dbt_iterator_type, &
89 dbt_put_block, &
90 dbt_reserve_blocks, &
91 dbt_batched_contract_init, &
92 dbt_batched_contract_finalize
93
94CONTAINS
95
96! **************************************************************************************************
97!> \brief Copy tensor data.
98!> Redistributes tensor data according to distributions of target and source tensor.
99!> Permutes tensor index according to `order` argument (if present).
100!> Source and target tensor formats are arbitrary as long as the following requirements are met:
101!> * source and target tensors have the same rank and the same sizes in each dimension in terms
102!> of tensor elements (block sizes don't need to be the same).
103!> If `order` argument is present, sizes must match after index permutation.
104!> OR
105!> * target tensor is not yet created; in that case an exact copy of the source tensor is returned.
106!> \param tensor_in Source
107!> \param tensor_out Target
108!> \param order Permutation of target tensor index.
109!> Exact same convention as order argument of RESHAPE intrinsic.
110!> \param bounds crop tensor data: start and end index for each tensor dimension
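!> \param summation whether to sum tensor data: tensor_out = tensor_out + tensor_in
!> \param move_data memory optimization: transfer data such that tensor_in is empty on return
!> \param unit_nr output unit for logging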
111!> \author Patrick Seewald
112! **************************************************************************************************
113 SUBROUTINE dbt_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
114 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
115 INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
116 INTENT(IN), OPTIONAL :: order
117 LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
118 INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
119 INTENT(IN), OPTIONAL :: bounds
120 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
121 INTEGER :: handle
122
123 CALL tensor_in%pgrid%mp_comm_2d%sync()
124 CALL timeset("dbt_total", handle)
125
126 ! make sure that it is safe to use dbt_copy during a batched contraction
127 CALL dbt_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.true.)
128 CALL dbt_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.true.)
129
130 CALL dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
131 CALL tensor_in%pgrid%mp_comm_2d%sync()
132 CALL timestop(handle)
133 END SUBROUTINE
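! Minimal usage sketch (hypothetical tensors t_in, t_out, created elsewhere with sizes that
! match after permutation): copy a 3-rank tensor into a target with permuted index, using the
! RESHAPE convention of the order argument, and sum into the target:
!
!    CALL dbt_copy(t_in, t_out, order=[2, 3, 1], summation=.TRUE.)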
134
135! **************************************************************************************************
136!> \brief Expert routine for copying a tensor. For internal use only.
137!> \author Patrick Seewald
138! **************************************************************************************************
139 SUBROUTINE dbt_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
140 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
141 INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
142 INTENT(IN), OPTIONAL :: order
143 LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
144 INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
145 INTENT(IN), OPTIONAL :: bounds
146 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
147
148 TYPE(dbt_type), POINTER :: in_tmp_1, in_tmp_2, &
149 in_tmp_3, out_tmp_1
150 INTEGER :: handle, unit_nr_prv
151 INTEGER, DIMENSION(:), ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2
152
153 CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy'
154 LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
155 summation_prv, new_in_1, new_in_2, &
156 new_in_3, new_out_1, block_compatible, &
157 move_prv
158 TYPE(array_list) :: blk_sizes_in
159
160 CALL timeset(routineN, handle)
161
162 CPASSERT(tensor_out%valid)
163
164 unit_nr_prv = prep_output_unit(unit_nr)
165
166 IF (PRESENT(move_data)) THEN
167 move_prv = move_data
168 ELSE
169 move_prv = .false.
170 END IF
171
172 dist_compatible_tas = .false.
173 dist_compatible_tensor = .false.
174 block_compatible = .false.
175 new_in_1 = .false.
176 new_in_2 = .false.
177 new_in_3 = .false.
178 new_out_1 = .false.
179
180 IF (PRESENT(summation)) THEN
181 summation_prv = summation
182 ELSE
183 summation_prv = .false.
184 END IF
185
186 IF (PRESENT(bounds)) THEN
187 ALLOCATE (in_tmp_1)
188 CALL dbt_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
189 new_in_1 = .true.
190 move_prv = .true.
191 ELSE
192 in_tmp_1 => tensor_in
193 END IF
194
195 IF (PRESENT(order)) THEN
196 CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
197 block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
198 ELSE
199 block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
200 END IF
201
202 IF (.NOT. block_compatible) THEN
203 ALLOCATE (in_tmp_2, out_tmp_1)
204 CALL dbt_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
205 nodata2=.NOT. summation_prv, move_data=move_prv)
206 new_in_2 = .true.; new_out_1 = .true.
207 move_prv = .true.
208 ELSE
209 in_tmp_2 => in_tmp_1
210 out_tmp_1 => tensor_out
211 END IF
212
213 IF (PRESENT(order)) THEN
214 ALLOCATE (in_tmp_3)
215 CALL dbt_permute_index(in_tmp_2, in_tmp_3, order)
216 new_in_3 = .true.
217 ELSE
218 in_tmp_3 => in_tmp_2
219 END IF
220
221 ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
222 ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
223 CALL dbt_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)
224
225 ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
226 ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
227 CALL dbt_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)
228
229 IF (.NOT. PRESENT(order)) THEN
230 IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
231 dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
232 ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
233 dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
234 END IF
235 END IF
236
237 IF (dist_compatible_tas) THEN
238 CALL dbt_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
239 IF (move_prv) CALL dbt_clear(in_tmp_3)
240 ELSEIF (dist_compatible_tensor) THEN
241 CALL dbt_copy_nocomm(in_tmp_3, out_tmp_1, summation)
242 IF (move_prv) CALL dbt_clear(in_tmp_3)
243 ELSE
244 CALL dbt_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
245 END IF
246
247 IF (new_in_1) THEN
248 CALL dbt_destroy(in_tmp_1)
249 DEALLOCATE (in_tmp_1)
250 END IF
251
252 IF (new_in_2) THEN
253 CALL dbt_destroy(in_tmp_2)
254 DEALLOCATE (in_tmp_2)
255 END IF
256
257 IF (new_in_3) THEN
258 CALL dbt_destroy(in_tmp_3)
259 DEALLOCATE (in_tmp_3)
260 END IF
261
262 IF (new_out_1) THEN
263 IF (unit_nr_prv /= 0) THEN
264 CALL dbt_write_tensor_dist(out_tmp_1, unit_nr_prv)
265 END IF
266 CALL dbt_split_copyback(out_tmp_1, tensor_out, summation)
267 CALL dbt_destroy(out_tmp_1)
268 DEALLOCATE (out_tmp_1)
269 END IF
270
271 CALL timestop(handle)
272
273 END SUBROUTINE
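! Sketch of the bounds handling above, as seen from the public dbt_copy interface: crop a
! 2-rank tensor to the element ranges 1:50 (dimension 1) and 26:75 (dimension 2) while
! copying (t_in, t_out are hypothetical tensors of matching full size):
!
!    INTEGER, DIMENSION(2, 2) :: bounds
!    bounds(:, 1) = [1, 50]
!    bounds(:, 2) = [26, 75]
!    CALL dbt_copy(t_in, t_out, bounds=bounds)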
274
275! **************************************************************************************************
276!> \brief Copy without communication; requires that both tensors have the same process grid and distribution.
277!> \param summation Whether to sum matrices b = a + b
278!> \author Patrick Seewald
279! **************************************************************************************************
280 SUBROUTINE dbt_copy_nocomm(tensor_in, tensor_out, summation)
281 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
282 TYPE(dbt_type), INTENT(INOUT) :: tensor_out
283 LOGICAL, INTENT(IN), OPTIONAL :: summation
284 TYPE(dbt_iterator_type) :: iter
285 INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
286 TYPE(block_nd) :: blk_data
287 LOGICAL :: found
288
289 CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_nocomm'
290 INTEGER :: handle
291
292 CALL timeset(routineN, handle)
293 CPASSERT(tensor_out%valid)
294
295 IF (PRESENT(summation)) THEN
296 IF (.NOT. summation) CALL dbt_clear(tensor_out)
297 ELSE
298 CALL dbt_clear(tensor_out)
299 END IF
300
301 CALL dbt_reserve_blocks(tensor_in, tensor_out)
302
303!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,tensor_out,summation) &
304!$OMP PRIVATE(iter,ind_nd,blk_data,found)
305 CALL dbt_iterator_start(iter, tensor_in)
306 DO WHILE (dbt_iterator_blocks_left(iter))
307 CALL dbt_iterator_next_block(iter, ind_nd)
308 CALL dbt_get_block(tensor_in, ind_nd, blk_data, found)
309 CPASSERT(found)
310 CALL dbt_put_block(tensor_out, ind_nd, blk_data, summation=summation)
311 CALL destroy_block(blk_data)
312 END DO
313 CALL dbt_iterator_stop(iter)
314!$OMP END PARALLEL
315
316 CALL timestop(handle)
317 END SUBROUTINE
318
319! **************************************************************************************************
320!> \brief Copy matrix to tensor.
321!> \param summation tensor_out = tensor_out + matrix_in
322!> \author Patrick Seewald
323! **************************************************************************************************
324 SUBROUTINE dbt_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
325 TYPE(dbcsr_type), TARGET, INTENT(IN) :: matrix_in
326 TYPE(dbt_type), INTENT(INOUT) :: tensor_out
327 LOGICAL, INTENT(IN), OPTIONAL :: summation
328 TYPE(dbcsr_type), POINTER :: matrix_in_desym
329
330 INTEGER, DIMENSION(2) :: ind_2d
331 REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: block_arr
332 REAL(kind=dp), DIMENSION(:, :), POINTER :: block
333 TYPE(dbcsr_iterator_type) :: iter
334 LOGICAL :: tr
335
336 INTEGER :: handle
337 CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_matrix_to_tensor'
338
339 CALL timeset(routineN, handle)
340 CPASSERT(tensor_out%valid)
341
342 NULLIFY (block)
343
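! a DBCSR matrix with symmetry stores only one triangle; desymmetrize it first so that the
! iterator below visits every block explicitly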
344 IF (dbcsr_has_symmetry(matrix_in)) THEN
345 ALLOCATE (matrix_in_desym)
346 CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
347 ELSE
348 matrix_in_desym => matrix_in
349 END IF
350
351 IF (PRESENT(summation)) THEN
352 IF (.NOT. summation) CALL dbt_clear(tensor_out)
353 ELSE
354 CALL dbt_clear(tensor_out)
355 END IF
356
357 CALL dbt_reserve_blocks(matrix_in_desym, tensor_out)
358
359!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in_desym,tensor_out,summation) &
360!$OMP PRIVATE(iter,ind_2d,block,tr,block_arr)
361 CALL dbcsr_iterator_start(iter, matrix_in_desym)
362 DO WHILE (dbcsr_iterator_blocks_left(iter))
363 CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
364 CALL allocate_any(block_arr, source=block)
365 CALL dbt_put_block(tensor_out, ind_2d, shape(block_arr), block_arr, summation=summation)
366 DEALLOCATE (block_arr)
367 END DO
368 CALL dbcsr_iterator_stop(iter)
369!$OMP END PARALLEL
370
371 IF (dbcsr_has_symmetry(matrix_in)) THEN
372 CALL dbcsr_release(matrix_in_desym)
373 DEALLOCATE (matrix_in_desym)
374 END IF
375
376 CALL timestop(handle)
377
378 END SUBROUTINE
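! Usage sketch (hypothetical objects matrix_a, t_a): import a DBCSR matrix into a 2-rank
! tensor created with the same row/column block sizes; symmetric matrices are desymmetrized
! internally, so the tensor holds both triangles afterwards:
!
!    CALL dbt_copy_matrix_to_tensor(matrix_a, t_a)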
379
380! **************************************************************************************************
381!> \brief Copy tensor to matrix.
382!> \param summation matrix_out = matrix_out + tensor_in
383!> \author Patrick Seewald
384! **************************************************************************************************
385 SUBROUTINE dbt_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
386 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
387 TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
388 LOGICAL, INTENT(IN), OPTIONAL :: summation
389 TYPE(dbt_iterator_type) :: iter
390 INTEGER :: handle
391 INTEGER, DIMENSION(2) :: ind_2d
392 REAL(kind=dp), DIMENSION(:, :), ALLOCATABLE :: block
393 CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_copy_tensor_to_matrix'
394 LOGICAL :: found
395
396 CALL timeset(routineN, handle)
397
398 IF (PRESENT(summation)) THEN
399 IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
400 ELSE
401 CALL dbcsr_clear(matrix_out)
402 END IF
403
404 CALL dbt_reserve_blocks(tensor_in, matrix_out)
405
406!$OMP PARALLEL DEFAULT(NONE) SHARED(tensor_in,matrix_out,summation) &
407!$OMP PRIVATE(iter,ind_2d,block,found)
408 CALL dbt_iterator_start(iter, tensor_in)
409 DO WHILE (dbt_iterator_blocks_left(iter))
410 CALL dbt_iterator_next_block(iter, ind_2d)
411 IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE
412
413 CALL dbt_get_block(tensor_in, ind_2d, block, found)
414 CPASSERT(found)
415
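! a symmetric output matrix stores only one triangle: blocks belonging to the other
! triangle are transposed into their stored counterpart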
416 IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
417 CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), transpose(block), summation=summation)
418 ELSE
419 CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
420 END IF
421 DEALLOCATE (block)
422 END DO
423 CALL dbt_iterator_stop(iter)
424!$OMP END PARALLEL
425
426 CALL timestop(handle)
427
428 END SUBROUTINE
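! Usage sketch (hypothetical objects t_a, matrix_a): accumulate a 2-rank tensor into an
! existing DBCSR matrix; with summation the result is matrix_a = matrix_a + t_a:
!
!    CALL dbt_copy_tensor_to_matrix(t_a, matrix_a, summation=.TRUE.)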
429
430! **************************************************************************************************
431!> \brief Contract tensors by multiplying matrix representations.
432!> tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
433!> * tensor_2(contract_2, notcontract_2)
434!> + beta * tensor_3(map_1, map_2)
435!>
436!> \note
437!> note 1: block sizes of the corresponding indices need to be the same in all tensors.
438!>
439!> note 2: for best performance the tensors should have been created in matrix layouts
440!> compatible with the contraction, e.g. tensor_1 should have been created with either
441!> map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
442!> map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
443!> tensor_3 and map_1 / map_2).
444!> Furthermore, the two largest tensors involved in the contraction should both map to either
445!> tall or short matrices: the largest matrix dimension should be "on the same side"
446!> and should have identical distribution (which is always the case if the distributions were
447!> obtained with dbt_default_distvec).
448!>
449!> note 3: if the same tensor occurs in multiple contractions, a different tensor object should
450!> be created for each contraction and the data should be copied between the tensors by use of
451!> dbt_copy. If the same tensor object is used in multiple contractions,
452!> the matrix layouts can not be compatible with all contractions (see note 2).
453!>
454!> note 4: automatic optimizations are enabled by using the feature of batched contraction, see
455!> dbt_batched_contract_init, dbt_batched_contract_finalize.
456!> The arguments bounds_1, bounds_2, bounds_3 give the index ranges of the batches.
457!>
458!> \param tensor_1 first tensor (in)
459!> \param tensor_2 second tensor (in)
460!> \param contract_1 indices of tensor_1 to contract
461!> \param contract_2 indices of tensor_2 to contract (1:1 with contract_1)
462!> \param map_1 which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
463!> \param map_2 which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
464!> \param notcontract_1 indices of tensor_1 not to contract
465!> \param notcontract_2 indices of tensor_2 not to contract
466!> \param tensor_3 contracted tensor (out)
467!> \param bounds_1 bounds corresponding to contract_1 AKA contract_2:
468!> start and end index of an index range over which to contract.
469!> For use in batched contraction.
470!> \param bounds_2 bounds corresponding to notcontract_1: start and end index of an index range.
471!> For use in batched contraction.
472!> \param bounds_3 bounds corresponding to notcontract_2: start and end index of an index range.
473!> For use in batched contraction.
474!> \param optimize_dist Whether distribution should be optimized internally. In the current
475!> implementation this guarantees optimal parameters only for dense matrices.
476!> \param pgrid_opt_1 Optionally return optimal process grid for tensor_1.
477!> This can be used to choose optimal process grids for subsequent tensor
478!> contractions with tensors of similar shape and sparsity. Under some conditions,
479!> pgrid_opt_1 can not be returned; in that case the pointer is not associated.
480!> \param pgrid_opt_2 Optionally return optimal process grid for tensor_2.
481!> \param pgrid_opt_3 Optionally return optimal process grid for tensor_3.
482!> \param filter_eps As in DBM mm
483!> \param flop As in DBM mm
484!> \param move_data memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
485!> \param retain_sparsity enforce the sparsity pattern of the existing tensor_3; default is no
486!> \param unit_nr output unit for logging.
487!> Set it to -1 on ranks that should not write (and to any valid unit number on
488!> ranks that should write output); if it is 0 on ALL ranks, no output is written.
489!> \param log_verbose verbose logging (for testing only)
490!> \author Patrick Seewald
491! **************************************************************************************************
492 SUBROUTINE dbt_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
493 contract_1, notcontract_1, &
494 contract_2, notcontract_2, &
495 map_1, map_2, &
496 bounds_1, bounds_2, bounds_3, &
497 optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
498 filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
499 REAL(dp), INTENT(IN) :: alpha
500 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
501 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
502 REAL(dp), INTENT(IN) :: beta
503 INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
504 INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
505 INTEGER, DIMENSION(:), INTENT(IN) :: map_1
506 INTEGER, DIMENSION(:), INTENT(IN) :: map_2
507 INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
508 INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
509 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
510 INTEGER, DIMENSION(2, SIZE(contract_1)), &
511 INTENT(IN), OPTIONAL :: bounds_1
512 INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
513 INTENT(IN), OPTIONAL :: bounds_2
514 INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
515 INTENT(IN), OPTIONAL :: bounds_3
516 LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
517 TYPE(dbt_pgrid_type), INTENT(OUT), &
518 POINTER, OPTIONAL :: pgrid_opt_1
519 TYPE(dbt_pgrid_type), INTENT(OUT), &
520 POINTER, OPTIONAL :: pgrid_opt_2
521 TYPE(dbt_pgrid_type), INTENT(OUT), &
522 POINTER, OPTIONAL :: pgrid_opt_3
523 REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
524 INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
525 LOGICAL, INTENT(IN), OPTIONAL :: move_data
526 LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
527 INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
528 LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
529
530 INTEGER :: handle
531
532 CALL tensor_1%pgrid%mp_comm_2d%sync()
533 CALL timeset("dbt_total", handle)
534 CALL dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
535 contract_1, notcontract_1, &
536 contract_2, notcontract_2, &
537 map_1, map_2, &
538 bounds_1=bounds_1, &
539 bounds_2=bounds_2, &
540 bounds_3=bounds_3, &
541 optimize_dist=optimize_dist, &
542 pgrid_opt_1=pgrid_opt_1, &
543 pgrid_opt_2=pgrid_opt_2, &
544 pgrid_opt_3=pgrid_opt_3, &
545 filter_eps=filter_eps, &
546 flop=flop, &
547 move_data=move_data, &
548 retain_sparsity=retain_sparsity, &
549 unit_nr=unit_nr, &
550 log_verbose=log_verbose)
551 CALL tensor_1%pgrid%mp_comm_2d%sync()
552 CALL timestop(handle)
553
554 END SUBROUTINE
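! Minimal contraction sketch: C(i,j,k) = sum_l A(i,j,l) * B(l,k), i.e.
! tensor_3 := 1.0 * tensor_1 x tensor_2 + 0.0 * tensor_3, contracting index l. The tensor
! names t_a, t_b, t_c are hypothetical; block sizes of the contracted index must match
! between t_a and t_b:
!
!    CALL dbt_contract(alpha=1.0_dp, tensor_1=t_a, tensor_2=t_b, beta=0.0_dp, tensor_3=t_c, &
!                      contract_1=[3], notcontract_1=[1, 2], &
!                      contract_2=[1], notcontract_2=[2], &
!                      map_1=[1, 2], map_2=[3], filter_eps=1.0E-9_dp)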
555
556! **************************************************************************************************
557!> \brief expert routine for tensor contraction. For internal use only.
558!> \param nblks_local number of local blocks on this MPI rank
559!> \author Patrick Seewald
560! **************************************************************************************************
561 SUBROUTINE dbt_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
562 contract_1, notcontract_1, &
563 contract_2, notcontract_2, &
564 map_1, map_2, &
565 bounds_1, bounds_2, bounds_3, &
566 optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
567 filter_eps, flop, move_data, retain_sparsity, &
568 nblks_local, unit_nr, log_verbose)
569 REAL(dp), INTENT(IN) :: alpha
570 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_1
571 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_2
572 REAL(dp), INTENT(IN) :: beta
573 INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
574 INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
575 INTEGER, DIMENSION(:), INTENT(IN) :: map_1
576 INTEGER, DIMENSION(:), INTENT(IN) :: map_2
577 INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
578 INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
579 TYPE(dbt_type), INTENT(INOUT), TARGET :: tensor_3
580 INTEGER, DIMENSION(2, SIZE(contract_1)), &
581 INTENT(IN), OPTIONAL :: bounds_1
582 INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
583 INTENT(IN), OPTIONAL :: bounds_2
584 INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
585 INTENT(IN), OPTIONAL :: bounds_3
586 LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
587 TYPE(dbt_pgrid_type), INTENT(OUT), &
588 POINTER, OPTIONAL :: pgrid_opt_1
589 TYPE(dbt_pgrid_type), INTENT(OUT), &
590 POINTER, OPTIONAL :: pgrid_opt_2
591 TYPE(dbt_pgrid_type), INTENT(OUT), &
592 POINTER, OPTIONAL :: pgrid_opt_3
593 REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
594 INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
595 LOGICAL, INTENT(IN), OPTIONAL :: move_data
596 LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
597 INTEGER, INTENT(OUT), OPTIONAL :: nblks_local
598 INTEGER, OPTIONAL, INTENT(IN) :: unit_nr
599 LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
600
601 TYPE(dbt_type), POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3
602 TYPE(dbt_type), TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3
603 TYPE(dbt_type), POINTER :: tensor_crop_1, tensor_crop_2
604 TYPE(dbt_type), POINTER :: tensor_small, tensor_large
605
606 LOGICAL :: assert_stmt, tensors_remapped
607 INTEGER :: max_mm_dim, max_tensor, &
608 unit_nr_prv, ref_tensor, handle
609 TYPE(mp_cart_type) :: mp_comm_opt
610 INTEGER, DIMENSION(SIZE(contract_1)) :: contract_1_mod
611 INTEGER, DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod
612 INTEGER, DIMENSION(SIZE(contract_2)) :: contract_2_mod
613 INTEGER, DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod
614 INTEGER, DIMENSION(SIZE(map_1)) :: map_1_mod
615 INTEGER, DIMENSION(SIZE(map_2)) :: map_2_mod
616 LOGICAL :: trans_1, trans_2, trans_3
617 LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2
618 INTEGER :: ndims1, ndims2, ndims3
619 INTEGER :: occ_1, occ_2
620 INTEGER, DIMENSION(:), ALLOCATABLE :: dims1, dims2, dims3
621
622 CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_contract'
623 CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, &
624 indchar2_mod, indchar3_mod
625 CHARACTER(LEN=1), DIMENSION(15), SAVE :: alph = &
626 ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
627 INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
628 INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
629 LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
630 pgrid_changed_any, do_change_pgrid(2)
631 TYPE(dbt_tas_split_info) :: split_opt, split, split_opt_avg
632 INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_sub, pdims_sub_opt
633 REAL(dp) :: pdim_ratio, pdim_ratio_opt
634
635 NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
636 tensor_small)
637
638 CALL timeset(routineN, handle)
639
640 CPASSERT(tensor_1%valid)
641 CPASSERT(tensor_2%valid)
642 CPASSERT(tensor_3%valid)
643
644 assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
645 CPASSERT(assert_stmt)
646
647 assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
648 CPASSERT(assert_stmt)
649
650 assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
651 CPASSERT(assert_stmt)
652
653 assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
654 CPASSERT(assert_stmt)
655
656 assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
657 CPASSERT(assert_stmt)
658
659 assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
660 CPASSERT(assert_stmt)
661
662 unit_nr_prv = prep_output_unit(unit_nr)
663
664 IF (PRESENT(flop)) flop = 0
665 IF (PRESENT(nblks_local)) nblks_local = 0
666
667 IF (PRESENT(move_data)) THEN
668 move_data_1 = move_data
669 move_data_2 = move_data
670 ELSE
671 move_data_1 = .false.
672 move_data_2 = .false.
673 END IF
674
675 nodata_3 = .true.
676 IF (PRESENT(retain_sparsity)) THEN
677 IF (retain_sparsity) nodata_3 = .false.
678 END IF
679
680 CALL dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
681 contract_1, notcontract_1, &
682 contract_2, notcontract_2, &
683 bounds_t1, bounds_t2, &
684 bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
685 do_crop_1=do_crop_1, do_crop_2=do_crop_2)
686
687 IF (do_crop_1) THEN
688 ALLOCATE (tensor_crop_1)
689 CALL dbt_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
690 move_data_1 = .true.
691 ELSE
692 tensor_crop_1 => tensor_1
693 END IF
694
695 IF (do_crop_2) THEN
696 ALLOCATE (tensor_crop_2)
697 CALL dbt_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
698 move_data_2 = .true.
699 ELSE
700 tensor_crop_2 => tensor_2
701 END IF
702
703 ! shortcut for empty tensors
704 ! this is needed to avoid unnecessary work in case the user contracts different portions of a
705 ! tensor consecutively to save memory
706 ASSOCIATE (mp_comm => tensor_crop_1%pgrid%mp_comm_2d)
707 occ_1 = dbt_get_num_blocks(tensor_crop_1)
708 CALL mp_comm%max(occ_1)
709 occ_2 = dbt_get_num_blocks(tensor_crop_2)
710 CALL mp_comm%max(occ_2)
711 END ASSOCIATE
712
713 IF (occ_1 == 0 .OR. occ_2 == 0) THEN
714 CALL dbt_scale(tensor_3, beta)
715 IF (do_crop_1) THEN
716 CALL dbt_destroy(tensor_crop_1)
717 DEALLOCATE (tensor_crop_1)
718 END IF
719 IF (do_crop_2) THEN
720 CALL dbt_destroy(tensor_crop_2)
721 DEALLOCATE (tensor_crop_2)
722 END IF
723
724 CALL timestop(handle)
725 RETURN
726 END IF
727
728 IF (unit_nr_prv /= 0) THEN
729 IF (unit_nr_prv > 0) THEN
730 WRITE (unit_nr_prv, '(A)') repeat("-", 80)
731 WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBT TENSOR CONTRACTION:", &
732 trim(tensor_crop_1%name), 'x', trim(tensor_crop_2%name), '=', trim(tensor_3%name)
733 WRITE (unit_nr_prv, '(A)') repeat("-", 80)
734 END IF
735 CALL dbt_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
736 CALL dbt_write_tensor_dist(tensor_crop_1, unit_nr_prv)
737 CALL dbt_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
738 CALL dbt_write_tensor_dist(tensor_crop_2, unit_nr_prv)
739 END IF
740
741 ! align tensor index with data, tensor data is not modified
742 ndims1 = ndims_tensor(tensor_crop_1)
743 ndims2 = ndims_tensor(tensor_crop_2)
744 ndims3 = ndims_tensor(tensor_3)
745 ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
746 ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
747 ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
748
749 ! labeling tensor index with letters
750
751 indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
752 indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
753 indchar2(contract_2) = indchar1(contract_1)
754 indchar3(map_1) = indchar1(notcontract_1)
755 indchar3(map_2) = indchar2(notcontract_2)
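! e.g. for C(i,j,k) = A(i,j,l) x B(l,k) (illustrative): the indices of A are labeled
! 'a','b','c'; the contracted index of B reuses the label of its partner index in A; the
! free index of B gets the next unused letter; C inherits the labels of all
! non-contracted indices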
756
757 IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_crop_1, indchar1, &
758 tensor_crop_2, indchar2, &
759 tensor_3, indchar3, unit_nr_prv)
760 IF (unit_nr_prv > 0) THEN
761 WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
762 END IF
763
764 CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
765 tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
766
767 CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
768 tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
769
770 CALL align_tensor(tensor_3, map_1, map_2, &
771 tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
772
773 IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_algn_1, indchar1_mod, &
774 tensor_algn_2, indchar2_mod, &
775 tensor_algn_3, indchar3_mod, unit_nr_prv)
776
777 ALLOCATE (dims1(ndims1))
778 ALLOCATE (dims2(ndims2))
779 ALLOCATE (dims3(ndims3))
780
781 ! ideally block sizes and occupancy should be considered to measure tensor sizes, but the current
782 ! solution works for most cases and is more elegant. Occupancy can not easily be considered anyway, since it is unknown for the result tensor.
783 CALL blk_dims_tensor(tensor_crop_1, dims1)
784 CALL blk_dims_tensor(tensor_crop_2, dims2)
785 CALL blk_dims_tensor(tensor_3, dims3)
786
787 max_mm_dim = maxloc([product(int(dims1(notcontract_1), int_8)), &
788 product(int(dims1(contract_1), int_8)), &
789 product(int(dims2(notcontract_2), int_8))], dim=1)
790 max_tensor = maxloc([product(int(dims1, int_8)), product(int(dims2, int_8)), product(int(dims3, int_8))], dim=1)
791 SELECT CASE (max_mm_dim)
792 CASE (1)
793 IF (unit_nr_prv > 0) THEN
794 WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
795 WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
796 END IF
797 CALL index_linked_sort(contract_1_mod, contract_2_mod)
798 CALL index_linked_sort(map_2_mod, notcontract_2_mod)
799 SELECT CASE (max_tensor)
800 CASE (1)
801 CALL index_linked_sort(notcontract_1_mod, map_1_mod)
802 CASE (3)
803 CALL index_linked_sort(map_1_mod, notcontract_1_mod)
804 CASE DEFAULT
805 cpabort("should not happen")
806 END SELECT
807
808 CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
809 contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
810 trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
811 move_data_1=move_data_1, unit_nr=unit_nr_prv)
812
813 CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
814 new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
815
816 SELECT CASE (ref_tensor)
817 CASE (1)
818 tensor_large => tensor_contr_1
819 CASE (2)
820 tensor_large => tensor_contr_3
821 END SELECT
822 tensor_small => tensor_contr_2
823
824 CASE (2)
825 IF (unit_nr_prv > 0) THEN
826 WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
827 WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
828 END IF
829
830 CALL index_linked_sort(notcontract_1_mod, map_1_mod)
831 CALL index_linked_sort(notcontract_2_mod, map_2_mod)
832 SELECT CASE (max_tensor)
833 CASE (1)
834 CALL index_linked_sort(contract_1_mod, contract_2_mod)
835 CASE (2)
836 CALL index_linked_sort(contract_2_mod, contract_1_mod)
837 CASE DEFAULT
838 cpabort("should not happen")
839 END SELECT
840
841 CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
842 notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
843 trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
844 move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
845 trans_1 = .NOT. trans_1
846
847 CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
848 new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
849
850 SELECT CASE (ref_tensor)
851 CASE (1)
852 tensor_large => tensor_contr_1
853 CASE (2)
854 tensor_large => tensor_contr_2
855 END SELECT
856 tensor_small => tensor_contr_3
857
858 CASE (3)
859 IF (unit_nr_prv > 0) THEN
860 WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
861 WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
862 END IF
863 CALL index_linked_sort(map_1_mod, notcontract_1_mod)
864 CALL index_linked_sort(contract_2_mod, contract_1_mod)
865 SELECT CASE (max_tensor)
866 CASE (2)
867 CALL index_linked_sort(notcontract_2_mod, map_2_mod)
868 CASE (3)
869 CALL index_linked_sort(map_2_mod, notcontract_2_mod)
870 CASE DEFAULT
871 cpabort("should not happen")
872 END SELECT
873
874 CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
875 contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
876 trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
877 move_data_1=move_data_2, unit_nr=unit_nr_prv)
878
879 trans_2 = .NOT. trans_2
880 trans_3 = .NOT. trans_3
881
882 CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
883 trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
884
885 SELECT CASE (ref_tensor)
886 CASE (1)
887 tensor_large => tensor_contr_2
888 CASE (2)
889 tensor_large => tensor_contr_3
890 END SELECT
891 tensor_small => tensor_contr_1
892
893 END SELECT
894
895 IF (unit_nr_prv /= 0) CALL dbt_print_contraction_index(tensor_contr_1, indchar1_mod, &
896 tensor_contr_2, indchar2_mod, &
897 tensor_contr_3, indchar3_mod, unit_nr_prv)
898 IF (unit_nr_prv /= 0) THEN
899 IF (new_1) CALL dbt_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
900 IF (new_1) CALL dbt_write_tensor_dist(tensor_contr_1, unit_nr_prv)
901 IF (new_2) CALL dbt_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
902 IF (new_2) CALL dbt_write_tensor_dist(tensor_contr_2, unit_nr_prv)
903 END IF
904
905 CALL dbt_tas_multiply(trans_1, trans_2, trans_3, alpha, &
906 tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
907 beta, &
908 tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
909 unit_nr=unit_nr_prv, log_verbose=log_verbose, &
910 split_opt=split_opt, &
911 move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
912
913 IF (PRESENT(pgrid_opt_1)) THEN
914 IF (.NOT. new_1) THEN
915 ALLOCATE (pgrid_opt_1)
916 pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
917 END IF
918 END IF
919
920 IF (PRESENT(pgrid_opt_2)) THEN
921 IF (.NOT. new_2) THEN
922 ALLOCATE (pgrid_opt_2)
923 pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
924 END IF
925 END IF
926
927 IF (PRESENT(pgrid_opt_3)) THEN
928 IF (.NOT. new_3) THEN
929 ALLOCATE (pgrid_opt_3)
930 pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
931 END IF
932 END IF
933
934 do_batched = tensor_small%matrix_rep%do_batched > 0
935
936 tensors_remapped = .false.
937 IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .true.
938
939 IF (tensors_remapped .AND. do_batched) THEN
940 CALL cp_warn(__location__, &
941 "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
942 END IF
943
944 ! optimize process grid during batched contraction
945 do_change_pgrid(:) = .false.
946 IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
947 ASSOCIATE (storage => tensor_small%contraction_storage)
948 CPASSERT(storage%static)
949 split = dbt_tas_info(tensor_large%matrix_rep)
950 do_change_pgrid(:) = &
951 update_contraction_storage(storage, split_opt, split)
952
953 IF (any(do_change_pgrid)) THEN
954 mp_comm_opt = dbt_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, nint(storage%nsplit_avg))
955 CALL dbt_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
956 nint(storage%nsplit_avg), own_comm=.true.)
957 pdims_2d_opt = split_opt_avg%mp_comm%num_pe_cart
958 END IF
959
960 END ASSOCIATE
961
962 IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
963 ! check if the new grid has a better subgrid; if not, there is no need to change the process grid
964 pdims_sub_opt = split_opt_avg%mp_comm_group%num_pe_cart
965 pdims_sub = split%mp_comm_group%num_pe_cart
966
967 pdim_ratio = maxval(real(pdims_sub, dp))/minval(pdims_sub)
968 pdim_ratio_opt = maxval(real(pdims_sub_opt, dp))/minval(pdims_sub_opt)
969 IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
970 do_change_pgrid(1) = .false.
971 CALL dbt_tas_release_info(split_opt_avg)
972 END IF
973 END IF
974 END IF
975
976 IF (unit_nr_prv /= 0) THEN
977 do_write_3 = .true.
978 IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
979 IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .false.
980 END IF
981 IF (do_write_3) THEN
982 CALL dbt_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
983 CALL dbt_write_tensor_dist(tensor_contr_3, unit_nr_prv)
984 END IF
985 END IF
986
987 IF (new_3) THEN
988 ! need to redistribute if we created a new tensor for tensor 3
989 CALL dbt_scale(tensor_algn_3, beta)
990 CALL dbt_copy_expert(tensor_contr_3, tensor_algn_3, summation=.true., move_data=.true.)
991 IF (PRESENT(filter_eps)) CALL dbt_filter(tensor_algn_3, filter_eps)
992 ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
993 ! pointer to data of tensor_3
994 END IF
995
996 ! transfer contraction storage
997 CALL dbt_copy_contraction_storage(tensor_contr_1, tensor_1)
998 CALL dbt_copy_contraction_storage(tensor_contr_2, tensor_2)
999 CALL dbt_copy_contraction_storage(tensor_contr_3, tensor_3)
1000
1001 IF (unit_nr_prv /= 0) THEN
1002 IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
1003 IF (new_3 .AND. do_write_3) CALL dbt_write_tensor_dist(tensor_3, unit_nr_prv)
1004 END IF
1005
1006 CALL dbt_destroy(tensor_algn_1)
1007 CALL dbt_destroy(tensor_algn_2)
1008 CALL dbt_destroy(tensor_algn_3)
1009
1010 IF (do_crop_1) THEN
1011 CALL dbt_destroy(tensor_crop_1)
1012 DEALLOCATE (tensor_crop_1)
1013 END IF
1014
1015 IF (do_crop_2) THEN
1016 CALL dbt_destroy(tensor_crop_2)
1017 DEALLOCATE (tensor_crop_2)
1018 END IF
1019
1020 IF (new_1) THEN
1021 CALL dbt_destroy(tensor_contr_1)
1022 DEALLOCATE (tensor_contr_1)
1023 END IF
1024 IF (new_2) THEN
1025 CALL dbt_destroy(tensor_contr_2)
1026 DEALLOCATE (tensor_contr_2)
1027 END IF
1028 IF (new_3) THEN
1029 CALL dbt_destroy(tensor_contr_3)
1030 DEALLOCATE (tensor_contr_3)
1031 END IF
1032
1033 IF (PRESENT(move_data)) THEN
1034 IF (move_data) THEN
1035 CALL dbt_clear(tensor_1)
1036 CALL dbt_clear(tensor_2)
1037 END IF
1038 END IF
1039
1040 IF (unit_nr_prv > 0) THEN
1041 WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1042 WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
1043 WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1044 END IF
1045
1046 IF (any(do_change_pgrid)) THEN
1047 pgrid_changed_any = .false.
1048 SELECT CASE (max_mm_dim)
1049 CASE (1)
1050 IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1051 CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1052 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1053 pgrid_changed=pgrid_changed, &
1054 unit_nr=unit_nr_prv)
1055 IF (pgrid_changed) pgrid_changed_any = .true.
1056 CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1057 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1058 pgrid_changed=pgrid_changed, &
1059 unit_nr=unit_nr_prv)
1060 IF (pgrid_changed) pgrid_changed_any = .true.
1061 END IF
1062 IF (pgrid_changed_any) THEN
1063 IF (tensor_2%matrix_rep%do_batched == 3) THEN
1064 ! set the flag that the process grid has been optimized, to make sure that no grid
1065 ! optimizations are done in the TAS multiply algorithm
1066 CALL dbt_tas_batched_mm_complete(tensor_2%matrix_rep)
1067 END IF
1068 END IF
1069 CASE (2)
1070 IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
1071 CALL dbt_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1072 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1073 pgrid_changed=pgrid_changed, &
1074 unit_nr=unit_nr_prv)
1075 IF (pgrid_changed) pgrid_changed_any = .true.
1076 CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1077 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1078 pgrid_changed=pgrid_changed, &
1079 unit_nr=unit_nr_prv)
1080 IF (pgrid_changed) pgrid_changed_any = .true.
1081 END IF
1082 IF (pgrid_changed_any) THEN
1083 IF (tensor_3%matrix_rep%do_batched == 3) THEN
1084 CALL dbt_tas_batched_mm_complete(tensor_3%matrix_rep)
1085 END IF
1086 END IF
1087 CASE (3)
1088 IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1089 CALL dbt_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1090 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1091 pgrid_changed=pgrid_changed, &
1092 unit_nr=unit_nr_prv)
1093 IF (pgrid_changed) pgrid_changed_any = .true.
1094 CALL dbt_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1095 nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1096 pgrid_changed=pgrid_changed, &
1097 unit_nr=unit_nr_prv)
1098 IF (pgrid_changed) pgrid_changed_any = .true.
1099 END IF
1100 IF (pgrid_changed_any) THEN
1101 IF (tensor_1%matrix_rep%do_batched == 3) THEN
1102 CALL dbt_tas_batched_mm_complete(tensor_1%matrix_rep)
1103 END IF
1104 END IF
1105 END SELECT
1106 CALL dbt_tas_release_info(split_opt_avg)
1107 END IF
1108
1109 IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
1110 ! freeze TAS process grids if tensor grids were optimized
1111 CALL dbt_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.true.)
1112 CALL dbt_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.true.)
1113 CALL dbt_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.true.)
1114 END IF
1115
1116 CALL dbt_tas_release_info(split_opt)
1117
1118 CALL timestop(handle)
1119
1120 END SUBROUTINE
1121
1122! **************************************************************************************************
1123!> \brief align tensor index with data
1124!> \author Patrick Seewald
1125! **************************************************************************************************
1126 SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
1127 tensor_out, contract_out, notcontract_out, indp_in, indp_out)
1128 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1129 INTEGER, DIMENSION(:), INTENT(IN) :: contract_in, notcontract_in
1130 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1131 INTEGER, DIMENSION(SIZE(contract_in)), &
1132 INTENT(OUT) :: contract_out
1133 INTEGER, DIMENSION(SIZE(notcontract_in)), &
1134 INTENT(OUT) :: notcontract_out
1135 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
1136 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
1137 INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align
1138
1139 CALL dbt_align_index(tensor_in, tensor_out, order=align)
1140 contract_out = align(contract_in)
1141 notcontract_out = align(notcontract_in)
1142 indp_out(align) = indp_in
1143
1144 END SUBROUTINE
1145
1146! **************************************************************************************************
1147!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1148!> matrix multiplication. This routine reshapes the two largest of the three tensors.
1149!> Redistribution is avoided if the tensors are already in a consistent layout.
1150!> \param ind1_free indices of tensor 1 that are "free" (not linked to any index of tensor 2)
1151!> \param ind1_linked indices of tensor 1 that are linked to indices of tensor 2
1152!> (in 1:1 correspondence with ind2_linked)
1153!> \param trans1 transpose flag of matrix rep. tensor 1
1154!> \param trans2 transpose flag of matrix rep. tensor 2
1155!> \param new1 whether a new tensor 1 was created
1156!> \param new2 whether a new tensor 2 was created
1157!> \param nodata1 don't copy data of tensor 1
1158!> \param nodata2 don't copy data of tensor 2
1159!> \param move_data_1 memory optimization: transfer data s.t. tensor1 may be empty on return
1160!> \param move_data_2 memory optimization: transfer data s.t. tensor2 may be empty on return
1161!> \param optimize_dist experimental: optimize distribution
1162!> \param unit_nr output unit
1163!> \author Patrick Seewald
1164! **************************************************************************************************
1165 SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
1166 ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
1167 nodata1, nodata2, move_data_1, &
1168 move_data_2, optimize_dist, unit_nr)
1169 TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor1
1170 TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor2
1171 TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor1_out, tensor2_out
1172 INTEGER, DIMENSION(:), INTENT(IN) :: ind1_free, ind2_free
1173 INTEGER, DIMENSION(:), INTENT(IN) :: ind1_linked, ind2_linked
1174 LOGICAL, INTENT(OUT) :: trans1, trans2
1175 LOGICAL, INTENT(OUT) :: new1, new2
1176 INTEGER, INTENT(OUT) :: ref_tensor
1177 LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
1178 LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
1179 LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
1180 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1181 INTEGER :: compat1, compat1_old, compat2, compat2_old, &
1182 unit_nr_prv
1183 TYPE(mp_cart_type) :: comm_2d
1184 TYPE(array_list) :: dist_list
1185 INTEGER, DIMENSION(:), ALLOCATABLE :: mp_dims
1186 TYPE(dbt_distribution_type) :: dist_in
1187 INTEGER(KIND=int_8) :: nblkrows, nblkcols
1188 LOGICAL :: optimize_dist_prv
1189 INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
1190 INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2
1191
1192 NULLIFY (tensor1_out, tensor2_out)
1193
1194 unit_nr_prv = prep_output_unit(unit_nr)
1195
1196 CALL blk_dims_tensor(tensor1, dims1)
1197 CALL blk_dims_tensor(tensor2, dims2)
1198
1199 IF (product(int(dims1, int_8)) .GE. product(int(dims2, int_8))) THEN
1200 ref_tensor = 1
1201 ELSE
1202 ref_tensor = 2
1203 END IF
1204
1205 IF (PRESENT(optimize_dist)) THEN
1206 optimize_dist_prv = optimize_dist
1207 ELSE
1208 optimize_dist_prv = .false.
1209 END IF
1210
1211 compat1 = compat_map(tensor1%nd_index, ind1_linked)
1212 compat2 = compat_map(tensor2%nd_index, ind2_linked)
1213 compat1_old = compat1
1214 compat2_old = compat2
1215
1216 IF (unit_nr_prv > 0) THEN
1217 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor1%name), ":"
1218 SELECT CASE (compat1)
1219 CASE (0)
1220 WRITE (unit_nr_prv, '(A)') "Not compatible"
1221 CASE (1)
1222 WRITE (unit_nr_prv, '(A)') "Normal"
1223 CASE (2)
1224 WRITE (unit_nr_prv, '(A)') "Transposed"
1225 END SELECT
1226 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor2%name), ":"
1227 SELECT CASE (compat2)
1228 CASE (0)
1229 WRITE (unit_nr_prv, '(A)') "Not compatible"
1230 CASE (1)
1231 WRITE (unit_nr_prv, '(A)') "Normal"
1232 CASE (2)
1233 WRITE (unit_nr_prv, '(A)') "Transposed"
1234 END SELECT
1235 END IF
1236
1237 new1 = .false.
1238 new2 = .false.
1239
1240 IF (compat1 == 0 .OR. optimize_dist_prv) THEN
1241 new1 = .true.
1242 END IF
1243
1244 IF (compat2 == 0 .OR. optimize_dist_prv) THEN
1245 new2 = .true.
1246 END IF
1247
1248 IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
1249 IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
1250 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor1%name)
1251 nblkrows = product(int(dims1(ind1_linked), kind=int_8))
1252 nblkcols = product(int(dims1(ind1_free), kind=int_8))
1253 comm_2d = dbt_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
1254 ALLOCATE (tensor1_out)
1255 CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
1256 nodata=nodata1, move_data=move_data_1)
1257 CALL comm_2d%free()
1258 compat1 = 1
1259 ELSE
1260 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor1%name)
1261 tensor1_out => tensor1
1262 END IF
1263 IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
1264 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
1265 trim(tensor2%name), "compatible with", trim(tensor1%name)
1266 dist_in = dbt_distribution(tensor1_out)
1267 dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
1268 IF (compat1 == 1) THEN ! linked index is first 2d dimension
1269 ! get distribution of linked index, tensor 2 must adopt this distribution
1270 ! get grid dimensions of linked index
1271 ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1272 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1273 ALLOCATE (tensor2_out)
1274 CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1275 dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
1276 ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
1277 ! get distribution of linked index, tensor 2 must adopt this distribution
1278 ! get grid dimensions of linked index
1279 ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1280 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1281 ALLOCATE (tensor2_out)
1282 CALL dbt_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1283 dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
1284 ELSE
1285 cpabort("should not happen")
1286 END IF
1287 compat2 = compat1
1288 ELSE
1289 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor2%name)
1290 tensor2_out => tensor2
1291 END IF
1292 ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
1293 IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
1294 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor2%name)
1295 nblkrows = product(int(dims2(ind2_linked), kind=int_8))
1296 nblkcols = product(int(dims2(ind2_free), kind=int_8))
1297 comm_2d = dbt_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
1298 ALLOCATE (tensor2_out)
1299 CALL dbt_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=comm_2d, nodata=nodata2, move_data=move_data_2)
1300 CALL comm_2d%free()
1301 compat2 = 1
1302 ELSE
1303 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor2%name)
1304 tensor2_out => tensor2
1305 END IF
1306 IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
1307 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", trim(tensor1%name), &
1308 "compatible with", trim(tensor2%name)
1309 dist_in = dbt_distribution(tensor2_out)
1310 dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
1311 IF (compat2 == 1) THEN
1312 ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1313 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1314 ALLOCATE (tensor1_out)
1315 CALL dbt_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1316 dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
1317 ELSEIF (compat2 == 2) THEN
1318 ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1319 CALL dbt_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1320 ALLOCATE (tensor1_out)
1321 CALL dbt_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1322 dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
1323 ELSE
1324 cpabort("should not happen")
1325 END IF
1326 compat1 = compat2
1327 ELSE
1328 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor1%name)
1329 tensor1_out => tensor1
1330 END IF
1331 END IF
1332
1333 SELECT CASE (compat1)
1334 CASE (1)
1335 trans1 = .false.
1336 CASE (2)
1337 trans1 = .true.
1338 CASE DEFAULT
1339 cpabort("should not happen")
1340 END SELECT
1341
1342 SELECT CASE (compat2)
1343 CASE (1)
1344 trans2 = .false.
1345 CASE (2)
1346 trans2 = .true.
1347 CASE DEFAULT
1348 cpabort("should not happen")
1349 END SELECT
1350
1351 IF (unit_nr_prv > 0) THEN
1352 IF (compat1 .NE. compat1_old) THEN
1353 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor1_out%name), ":"
1354 SELECT CASE (compat1)
1355 CASE (0)
1356 WRITE (unit_nr_prv, '(A)') "Not compatible"
1357 CASE (1)
1358 WRITE (unit_nr_prv, '(A)') "Normal"
1359 CASE (2)
1360 WRITE (unit_nr_prv, '(A)') "Transposed"
1361 END SELECT
1362 END IF
1363 IF (compat2 .NE. compat2_old) THEN
1364 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor2_out%name), ":"
1365 SELECT CASE (compat2)
1366 CASE (0)
1367 WRITE (unit_nr_prv, '(A)') "Not compatible"
1368 CASE (1)
1369 WRITE (unit_nr_prv, '(A)') "Normal"
1370 CASE (2)
1371 WRITE (unit_nr_prv, '(A)') "Transposed"
1372 END SELECT
1373 END IF
1374 END IF
1375
1376 IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .true.
1377 IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .true.
1378
1379 END SUBROUTINE
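! The compat codes used above map directly to the matrix orientation passed to the TAS
! multiply: compat == 1 ("Normal") keeps the matrix representation as is (trans = .FALSE.),
! compat == 2 ("Transposed") flags it for transposed multiplication (trans = .TRUE.), and
! compat == 0 forces the remap branch above.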
1380
1381! **************************************************************************************************
1382!> \brief Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1383!> matrix multiplication. This routine reshapes the smallest of the three tensors.
1384!> \param ind1 index that should be mapped to first matrix dimension
1385!> \param ind2 index that should be mapped to second matrix dimension
1386!> \param trans transpose flag of matrix rep.
1387!> \param new whether a new tensor was created for tensor_out
1388!> \param nodata don't copy tensor data
1389!> \param move_data memory optimization: transfer data s.t. tensor_in may be empty on return
1390!> \param unit_nr output unit
1391!> \author Patrick Seewald
1392! **************************************************************************************************
1393 SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
1394 TYPE(dbt_type), TARGET, INTENT(INOUT) :: tensor_in
1395 INTEGER, DIMENSION(:), INTENT(IN) :: ind1, ind2
1396 TYPE(dbt_type), POINTER, INTENT(OUT) :: tensor_out
1397 LOGICAL, INTENT(OUT) :: trans
1398 LOGICAL, INTENT(OUT) :: new
1399 LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1400 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1401 INTEGER :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
1402 LOGICAL :: nodata_prv
1403
1404 NULLIFY (tensor_out)
1405 IF (PRESENT(nodata)) THEN
1406 nodata_prv = nodata
1407 ELSE
1408 nodata_prv = .false.
1409 END IF
1410
1411 unit_nr_prv = prep_output_unit(unit_nr)
1412
1413 new = .false.
1414 compat1 = compat_map(tensor_in%nd_index, ind1)
1415 compat2 = compat_map(tensor_in%nd_index, ind2)
1416 compat1_old = compat1; compat2_old = compat2
1417 IF (unit_nr_prv > 0) THEN
1418 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_in%name), ":"
1419 IF (compat1 == 1 .AND. compat2 == 2) THEN
1420 WRITE (unit_nr_prv, '(A)') "Normal"
1421 ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1422 WRITE (unit_nr_prv, '(A)') "Transposed"
1423 ELSE
1424 WRITE (unit_nr_prv, '(A)') "Not compatible"
1425 END IF
1426 END IF
1427 IF (compat1 == 0 .OR. compat2 == 0) THEN ! index mapping not compatible with contract index
1428
1429 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", trim(tensor_in%name)
1430
1431 ALLOCATE (tensor_out)
1432 CALL dbt_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
1433 CALL dbt_copy_contraction_storage(tensor_in, tensor_out)
1434 compat1 = 1
1435 compat2 = 2
1436 new = .true.
1437 ELSE
1438 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", trim(tensor_in%name)
1439 tensor_out => tensor_in
1440 END IF
1441
1442 IF (compat1 == 1 .AND. compat2 == 2) THEN
1443 trans = .false.
1444 ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1445 trans = .true.
1446 ELSE
1447 cpabort("this should not happen")
1448 END IF
1449
1450 IF (unit_nr_prv > 0) THEN
1451 IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
1452 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", trim(tensor_out%name), ":"
1453 IF (compat1 == 1 .AND. compat2 == 2) THEN
1454 WRITE (unit_nr_prv, '(A)') "Normal"
1455 ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1456 WRITE (unit_nr_prv, '(A)') "Transposed"
1457 ELSE
1458 WRITE (unit_nr_prv, '(A)') "Not compatible"
1459 END IF
1460 END IF
1461 END IF
1462
1463 END SUBROUTINE
1464
1465! **************************************************************************************************
1466!> \brief Update the contraction storage that keeps track of process grids during a batched
1467!> contraction and decide whether the tensor process grid needs to be optimized
1468!> \param split_opt optimized TAS process grid
1469!> \param split current TAS process grid
1470!> \author Patrick Seewald
1471! **************************************************************************************************
1472 FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
1473 TYPE(dbt_contraction_storage), INTENT(INOUT) :: storage
1474 TYPE(dbt_tas_split_info), INTENT(IN) :: split_opt
1475 TYPE(dbt_tas_split_info), INTENT(IN) :: split
1476 INTEGER, DIMENSION(2) :: pdims, pdims_sub
1477 LOGICAL, DIMENSION(2) :: do_change_pgrid
1478 REAL(kind=dp) :: change_criterion, pdims_ratio
1479 INTEGER :: nsplit_opt, nsplit
1480
1481 cpassert(ALLOCATED(split_opt%ngroup_opt))
1482 nsplit_opt = split_opt%ngroup_opt
1483 nsplit = split%ngroup
1484
1485 pdims = split%mp_comm%num_pe_cart
1486
1487 storage%ibatch = storage%ibatch + 1
1488
1489 storage%nsplit_avg = (storage%nsplit_avg*real(storage%ibatch - 1, dp) + real(nsplit_opt, dp)) &
1490 /real(storage%ibatch, dp)
1491
1492 SELECT CASE (split_opt%split_rowcol)
1493 CASE (rowsplit)
1494 pdims_ratio = real(pdims(1), dp)/pdims(2)
1495 CASE (colsplit)
1496 pdims_ratio = real(pdims(2), dp)/pdims(1)
1497 END SELECT
1498
1499 do_change_pgrid(:) = .false.
1500
1501 ! check for process grid dimensions
1502 pdims_sub = split%mp_comm_group%num_pe_cart
1503 change_criterion = maxval(real(pdims_sub, dp))/minval(pdims_sub)
1504 IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .true.
1505
1506 ! check for split factor
1507 change_criterion = max(real(nsplit, dp)/storage%nsplit_avg, real(storage%nsplit_avg, dp)/nsplit)
1508 IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .true.
1509
1510 END FUNCTION
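! **************************************************************************************************
!> \note Worked example for the two criteria above (a sketch; the quoted threshold value is an
!>       assumption, the actual defaults are defined in dbt_tas_split):
!>       * grid shape: for a subgroup grid pdims_sub = (8, 2) the criterion is
!>         maxval((8, 2))/minval((8, 2)) = 4; if default_pdims_accept_ratio were 1.5, the
!>         acceptance threshold would be 1.5**2 = 2.25, hence do_change_pgrid(1) = .true.
!>       * split factor: with running average nsplit_avg = 2.0 and current nsplit = 8 the
!>         criterion is max(8/2.0, 2.0/8) = 4, compared against default_nsplit_accept_ratio.
! **************************************************************************************************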
1511
1512! **************************************************************************************************
1513!> \brief Check if 2d index is compatible with tensor index (returns 0: not compatible, 1: index matches matrix row mapping, 2: index matches matrix column mapping)
1514!> \author Patrick Seewald
1515! **************************************************************************************************
1516 FUNCTION compat_map(nd_index, compat_ind)
1517 TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
1518 INTEGER, DIMENSION(:), INTENT(IN) :: compat_ind
1519 INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
1520 INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
1521 INTEGER :: compat_map
1522
1523 CALL dbt_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
1524
1525 compat_map = 0
1526 IF (array_eq_i(map1, compat_ind)) THEN
1527 compat_map = 1
1528 ELSEIF (array_eq_i(map2, compat_ind)) THEN
1529 compat_map = 2
1530 END IF
1531
1532 END FUNCTION
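! **************************************************************************************************
!> \note Example (sketch): for a rank-3 tensor whose nd_index maps dimensions (1, 2) to matrix
!>       rows and dimension (3) to matrix columns:
!>         compat_map(nd_index, [1, 2]) = 1  ! matches row mapping    ("Normal")
!>         compat_map(nd_index, [3])    = 2  ! matches column mapping ("Transposed")
!>         compat_map(nd_index, [2, 1]) = 0  ! order differs          ("Not compatible")
! **************************************************************************************************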
1533
1534! **************************************************************************************************
1535!> \brief sort index set ind_ref and apply the same permutation to the linked index set ind_dep
1536!> \author Patrick Seewald
1537! **************************************************************************************************
1538 SUBROUTINE index_linked_sort(ind_ref, ind_dep)
1539 INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
1540 INTEGER, DIMENSION(SIZE(ind_ref)) :: sort_indices
1541
1542 CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
1543 ind_dep(:) = ind_dep(sort_indices)
1544
1545 END SUBROUTINE
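! **************************************************************************************************
!> \note Example (sketch): ind_ref = [3, 1, 2], ind_dep = [30, 10, 20] becomes
!>       ind_ref = [1, 2, 3], ind_dep = [10, 20, 30]; ind_dep is permuted in lockstep with the
!>       sort of ind_ref.
! **************************************************************************************************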
1546
1547! **************************************************************************************************
1548!> \brief create an nd process grid for the tensor that is consistent with the given TAS split info
1549!> \author Patrick Seewald
1550! **************************************************************************************************
1551 FUNCTION opt_pgrid(tensor, tas_split_info)
1552 TYPE(dbt_type), INTENT(IN) :: tensor
1553 TYPE(dbt_tas_split_info), INTENT(IN) :: tas_split_info
1554 INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
1555 INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
1556 TYPE(dbt_pgrid_type) :: opt_pgrid
1557 INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
1558
1559 CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
1560 CALL blk_dims_tensor(tensor, dims)
1561 opt_pgrid = dbt_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
1562
1563 ALLOCATE (opt_pgrid%tas_split_info, source=tas_split_info)
1564 CALL dbt_tas_info_hold(opt_pgrid%tas_split_info)
1565 END FUNCTION
1566
1567! **************************************************************************************************
1568!> \brief Copy tensor to tensor with modified index mapping
1569!> \param map1_2d new index mapping
1570!> \param map2_2d new index mapping
1571!> \author Patrick Seewald
1572! **************************************************************************************************
1573 SUBROUTINE dbt_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
1574 mp_dims_1, mp_dims_2, name, nodata, move_data)
1575 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1576 INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
1577 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1578 CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
1579 LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1580 CLASS(mp_comm_type), INTENT(IN), OPTIONAL :: comm_2d
1581 TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
1582 INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
1583 INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
1584 CHARACTER(len=default_string_length) :: name_tmp
1585 INTEGER, DIMENSION(:), ALLOCATABLE :: blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4, &
1586 nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4
1587 TYPE(dbt_distribution_type) :: dist
1588 TYPE(mp_cart_type) :: comm_2d_prv
1589 INTEGER :: handle, i
1590 INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
1591 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_remap'
1592 LOGICAL :: nodata_prv
1593 TYPE(dbt_pgrid_type) :: comm_nd
1594
1595 CALL timeset(routinen, handle)
1596
1597 IF (PRESENT(name)) THEN
1598 name_tmp = name
1599 ELSE
1600 name_tmp = tensor_in%name
1601 END IF
1602 IF (PRESENT(dist1)) THEN
1603 cpassert(PRESENT(mp_dims_1))
1604 END IF
1605
1606 IF (PRESENT(dist2)) THEN
1607 cpassert(PRESENT(mp_dims_2))
1608 END IF
1609
1610 IF (PRESENT(comm_2d)) THEN
1611 comm_2d_prv = comm_2d
1612 ELSE
1613 comm_2d_prv = tensor_in%pgrid%mp_comm_2d
1614 END IF
1615
1616 comm_nd = dbt_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
1617 CALL mp_environ_pgrid(comm_nd, pdims, myploc)
1618
1619 IF (ndims_tensor(tensor_in) == 2) THEN
1620 CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2)
1621 END IF
1622 IF (ndims_tensor(tensor_in) == 3) THEN
1623 CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3)
1624 END IF
1625 IF (ndims_tensor(tensor_in) == 4) THEN
1626 CALL get_arrays(tensor_in%blk_sizes, blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
1627 END IF
1628
1629 IF (ndims_tensor(tensor_in) == 2) THEN
1630 IF (PRESENT(dist1)) THEN
1631 IF (any(map1_2d == 1)) THEN
1632 i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1633 CALL get_ith_array(dist1, i, nd_dist_1)
1634 END IF
1635 END IF
1636
1637 IF (PRESENT(dist2)) THEN
1638 IF (any(map2_2d == 1)) THEN
1639 i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1640 CALL get_ith_array(dist2, i, nd_dist_1)
1641 END IF
1642 END IF
1643
1644 IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1645 ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1646 CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1647 END IF
1648 IF (PRESENT(dist1)) THEN
1649 IF (any(map1_2d == 2)) THEN
1650 i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1651 CALL get_ith_array(dist1, i, nd_dist_2)
1652 END IF
1653 END IF
1654
1655 IF (PRESENT(dist2)) THEN
1656 IF (any(map2_2d == 2)) THEN
1657 i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1658 CALL get_ith_array(dist2, i, nd_dist_2)
1659 END IF
1660 END IF
1661
1662 IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1663 ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1664 CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1665 END IF
1666 CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1667 nd_dist_1, nd_dist_2, own_comm=.true.)
1668 END IF
1669 IF (ndims_tensor(tensor_in) == 3) THEN
1670 IF (PRESENT(dist1)) THEN
1671 IF (any(map1_2d == 1)) THEN
1672 i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1673 CALL get_ith_array(dist1, i, nd_dist_1)
1674 END IF
1675 END IF
1676
1677 IF (PRESENT(dist2)) THEN
1678 IF (any(map2_2d == 1)) THEN
1679 i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1680 CALL get_ith_array(dist2, i, nd_dist_1)
1681 END IF
1682 END IF
1683
1684 IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1685 ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1686 CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1687 END IF
1688 IF (PRESENT(dist1)) THEN
1689 IF (any(map1_2d == 2)) THEN
1690 i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1691 CALL get_ith_array(dist1, i, nd_dist_2)
1692 END IF
1693 END IF
1694
1695 IF (PRESENT(dist2)) THEN
1696 IF (any(map2_2d == 2)) THEN
1697 i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1698 CALL get_ith_array(dist2, i, nd_dist_2)
1699 END IF
1700 END IF
1701
1702 IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1703 ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1704 CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1705 END IF
1706 IF (PRESENT(dist1)) THEN
1707 IF (any(map1_2d == 3)) THEN
1708 i = minloc(map1_2d, dim=1, mask=map1_2d == 3) ! i is location of idim in map1_2d
1709 CALL get_ith_array(dist1, i, nd_dist_3)
1710 END IF
1711 END IF
1712
1713 IF (PRESENT(dist2)) THEN
1714 IF (any(map2_2d == 3)) THEN
1715 i = minloc(map2_2d, dim=1, mask=map2_2d == 3) ! i is location of idim in map2_2d
1716 CALL get_ith_array(dist2, i, nd_dist_3)
1717 END IF
1718 END IF
1719
1720 IF (.NOT. ALLOCATED(nd_dist_3)) THEN
1721 ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
1722 CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
1723 END IF
1724 CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1725 nd_dist_1, nd_dist_2, nd_dist_3, own_comm=.true.)
1726 END IF
1727 IF (ndims_tensor(tensor_in) == 4) THEN
1728 IF (PRESENT(dist1)) THEN
1729 IF (any(map1_2d == 1)) THEN
1730 i = minloc(map1_2d, dim=1, mask=map1_2d == 1) ! i is location of idim in map1_2d
1731 CALL get_ith_array(dist1, i, nd_dist_1)
1732 END IF
1733 END IF
1734
1735 IF (PRESENT(dist2)) THEN
1736 IF (any(map2_2d == 1)) THEN
1737 i = minloc(map2_2d, dim=1, mask=map2_2d == 1) ! i is location of idim in map2_2d
1738 CALL get_ith_array(dist2, i, nd_dist_1)
1739 END IF
1740 END IF
1741
1742 IF (.NOT. ALLOCATED(nd_dist_1)) THEN
1743 ALLOCATE (nd_dist_1(SIZE(blk_sizes_1)))
1744 CALL dbt_default_distvec(SIZE(blk_sizes_1), pdims(1), blk_sizes_1, nd_dist_1)
1745 END IF
1746 IF (PRESENT(dist1)) THEN
1747 IF (any(map1_2d == 2)) THEN
1748 i = minloc(map1_2d, dim=1, mask=map1_2d == 2) ! i is location of idim in map1_2d
1749 CALL get_ith_array(dist1, i, nd_dist_2)
1750 END IF
1751 END IF
1752
1753 IF (PRESENT(dist2)) THEN
1754 IF (any(map2_2d == 2)) THEN
1755 i = minloc(map2_2d, dim=1, mask=map2_2d == 2) ! i is location of idim in map2_2d
1756 CALL get_ith_array(dist2, i, nd_dist_2)
1757 END IF
1758 END IF
1759
1760 IF (.NOT. ALLOCATED(nd_dist_2)) THEN
1761 ALLOCATE (nd_dist_2(SIZE(blk_sizes_2)))
1762 CALL dbt_default_distvec(SIZE(blk_sizes_2), pdims(2), blk_sizes_2, nd_dist_2)
1763 END IF
1764 IF (PRESENT(dist1)) THEN
1765 IF (any(map1_2d == 3)) THEN
1766 i = minloc(map1_2d, dim=1, mask=map1_2d == 3) ! i is location of idim in map1_2d
1767 CALL get_ith_array(dist1, i, nd_dist_3)
1768 END IF
1769 END IF
1770
1771 IF (PRESENT(dist2)) THEN
1772 IF (any(map2_2d == 3)) THEN
1773 i = minloc(map2_2d, dim=1, mask=map2_2d == 3) ! i is location of idim in map2_2d
1774 CALL get_ith_array(dist2, i, nd_dist_3)
1775 END IF
1776 END IF
1777
1778 IF (.NOT. ALLOCATED(nd_dist_3)) THEN
1779 ALLOCATE (nd_dist_3(SIZE(blk_sizes_3)))
1780 CALL dbt_default_distvec(SIZE(blk_sizes_3), pdims(3), blk_sizes_3, nd_dist_3)
1781 END IF
1782 IF (PRESENT(dist1)) THEN
1783 IF (any(map1_2d == 4)) THEN
1784 i = minloc(map1_2d, dim=1, mask=map1_2d == 4) ! i is location of idim in map1_2d
1785 CALL get_ith_array(dist1, i, nd_dist_4)
1786 END IF
1787 END IF
1788
1789 IF (PRESENT(dist2)) THEN
1790 IF (any(map2_2d == 4)) THEN
1791 i = minloc(map2_2d, dim=1, mask=map2_2d == 4) ! i is location of idim in map2_2d
1792 CALL get_ith_array(dist2, i, nd_dist_4)
1793 END IF
1794 END IF
1795
1796 IF (.NOT. ALLOCATED(nd_dist_4)) THEN
1797 ALLOCATE (nd_dist_4(SIZE(blk_sizes_4)))
1798 CALL dbt_default_distvec(SIZE(blk_sizes_4), pdims(4), blk_sizes_4, nd_dist_4)
1799 END IF
1800 CALL dbt_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1801 nd_dist_1, nd_dist_2, nd_dist_3, nd_dist_4, own_comm=.true.)
1802 END IF
1803
1804 IF (ndims_tensor(tensor_in) == 2) THEN
1805 CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1806 blk_sizes_1, blk_sizes_2)
1807 END IF
1808 IF (ndims_tensor(tensor_in) == 3) THEN
1809 CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1810 blk_sizes_1, blk_sizes_2, blk_sizes_3)
1811 END IF
1812 IF (ndims_tensor(tensor_in) == 4) THEN
1813 CALL dbt_create(tensor_out, name_tmp, dist, map1_2d, map2_2d, &
1814 blk_sizes_1, blk_sizes_2, blk_sizes_3, blk_sizes_4)
1815 END IF
1816
1817 IF (PRESENT(nodata)) THEN
1818 nodata_prv = nodata
1819 ELSE
1820 nodata_prv = .false.
1821 END IF
1822
1823 IF (.NOT. nodata_prv) CALL dbt_copy_expert(tensor_in, tensor_out, move_data=move_data)
1824 CALL dbt_distribution_destroy(dist)
1825
1826 CALL timestop(handle)
1827 END SUBROUTINE
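! **************************************************************************************************
!> \note Minimal usage sketch for this internal helper (t3 and t3_rows are hypothetical rank-3
!>       tensors):
!>
!>         TYPE(dbt_type) :: t3, t3_rows
!>         ! remap s.t. dimensions 1 and 2 form the matrix rows and dimension 3 the columns
!>         CALL dbt_remap(t3, map1_2d=[1, 2], map2_2d=[3], tensor_out=t3_rows)
!>         ! ... use t3_rows ...
!>         CALL dbt_destroy(t3_rows)
! **************************************************************************************************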
1828
1829! **************************************************************************************************
1830!> \brief Align index with data
1831!> \param order permutation resulting from alignment
1832!> \author Patrick Seewald
1833! **************************************************************************************************
1834 SUBROUTINE dbt_align_index(tensor_in, tensor_out, order)
1835 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1836 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1837 INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
1838 INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
1839 INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1840 INTENT(OUT), OPTIONAL :: order
1841 INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: order_prv
1842 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_align_index'
1843 INTEGER :: handle
1844
1845 CALL timeset(routinen, handle)
1846
1847 CALL dbt_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
1848 order_prv = dbt_inverse_order([map1_2d, map2_2d])
1849 CALL dbt_permute_index(tensor_in, tensor_out, order=order_prv)
1850
1851 IF (PRESENT(order)) order = order_prv
1852
1853 CALL timestop(handle)
1854 END SUBROUTINE
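! **************************************************************************************************
!> \note Example (sketch): for map1_2d = [3] and map2_2d = [1, 2] the concatenated mapping is
!>       [3, 1, 2] with inverse order_prv = [2, 3, 1]; permuting with this order yields a
!>       tensor whose matrix mapping is in natural index order, map1_2d = [1], map2_2d = [2, 3].
! **************************************************************************************************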
1855
1856! **************************************************************************************************
1857!> \brief Create new tensor by reordering index; data is not copied but shared with the input tensor (shallow copy)
1858!> \author Patrick Seewald
1859! **************************************************************************************************
1860 SUBROUTINE dbt_permute_index(tensor_in, tensor_out, order)
1861 TYPE(dbt_type), INTENT(INOUT) :: tensor_in
1862 TYPE(dbt_type), INTENT(OUT) :: tensor_out
1863 INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1864 INTENT(IN) :: order
1865
1866 TYPE(nd_to_2d_mapping) :: nd_index_blk_rs, nd_index_rs
1867 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_permute_index'
1868 INTEGER :: handle
1869 INTEGER :: ndims
1870
1871 CALL timeset(routinen, handle)
1872
1873 ndims = ndims_tensor(tensor_in)
1874
1875 CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
1876 CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
1877 CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
1878
1879 tensor_out%matrix_rep => tensor_in%matrix_rep
1880 tensor_out%owns_matrix = .false.
1881
1882 tensor_out%nd_index = nd_index_rs
1883 tensor_out%nd_index_blk = nd_index_blk_rs
1884 tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
1885 IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
1886 ALLOCATE (tensor_out%pgrid%tas_split_info, source=tensor_in%pgrid%tas_split_info)
1887 END IF
1888 tensor_out%refcount => tensor_in%refcount
1889 CALL dbt_hold(tensor_out)
1890
1891 CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
1892 CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
1893 CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
1894 CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
1895 ALLOCATE (tensor_out%nblks_local(ndims))
1896 ALLOCATE (tensor_out%nfull_local(ndims))
1897 tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
1898 tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
1899 tensor_out%name = tensor_in%name
1900 tensor_out%valid = .true.
1901
1902 IF (ALLOCATED(tensor_in%contraction_storage)) THEN
1903 ALLOCATE (tensor_out%contraction_storage, source=tensor_in%contraction_storage)
1904 CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
1905 CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
1906 END IF
1907
1908 CALL timestop(handle)
1909 END SUBROUTINE
1910
1911! **************************************************************************************************
1912!> \brief Map contraction bounds to bounds referring to tensor indices
1913!> see dbt_contract for documentation of dummy arguments
1914!> \param bounds_t1 bounds mapped to tensor_1
1915!> \param bounds_t2 bounds mapped to tensor_2
1916!> \param do_crop_1 whether tensor 1 should be cropped
1917!> \param do_crop_2 whether tensor 2 should be cropped
1918!> \author Patrick Seewald
1919! **************************************************************************************************
1920 SUBROUTINE dbt_map_bounds_to_tensors(tensor_1, tensor_2, &
1921 contract_1, notcontract_1, &
1922 contract_2, notcontract_2, &
1923 bounds_t1, bounds_t2, &
1924 bounds_1, bounds_2, bounds_3, &
1925 do_crop_1, do_crop_2)
1926
1927 TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2
1928 INTEGER, DIMENSION(:), INTENT(IN) :: contract_1, contract_2, &
1929 notcontract_1, notcontract_2
1930 INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
1931 INTENT(OUT) :: bounds_t1
1932 INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
1933 INTENT(OUT) :: bounds_t2
1934 INTEGER, DIMENSION(2, SIZE(contract_1)), &
1935 INTENT(IN), OPTIONAL :: bounds_1
1936 INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
1937 INTENT(IN), OPTIONAL :: bounds_2
1938 INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
1939 INTENT(IN), OPTIONAL :: bounds_3
1940 LOGICAL, INTENT(OUT), OPTIONAL :: do_crop_1, do_crop_2
1941 LOGICAL, DIMENSION(2) :: do_crop
1942
1943 do_crop = .false.
1944
1945 bounds_t1(1, :) = 1
1946 CALL dbt_get_info(tensor_1, nfull_total=bounds_t1(2, :))
1947
1948 bounds_t2(1, :) = 1
1949 CALL dbt_get_info(tensor_2, nfull_total=bounds_t2(2, :))
1950
1951 IF (PRESENT(bounds_1)) THEN
1952 bounds_t1(:, contract_1) = bounds_1
1953 do_crop(1) = .true.
1954 bounds_t2(:, contract_2) = bounds_1
1955 do_crop(2) = .true.
1956 END IF
1957
1958 IF (PRESENT(bounds_2)) THEN
1959 bounds_t1(:, notcontract_1) = bounds_2
1960 do_crop(1) = .true.
1961 END IF
1962
1963 IF (PRESENT(bounds_3)) THEN
1964 bounds_t2(:, notcontract_2) = bounds_3
1965 do_crop(2) = .true.
1966 END IF
1967
1968 IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
1969 IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
1970
1971 END SUBROUTINE
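! **************************************************************************************************
!> \note Example (sketch): contracting rank-2 tensors over index k with contract_1 = [2],
!>       notcontract_1 = [1], contract_2 = [1], notcontract_2 = [2] and
!>       bounds_1 = reshape([5, 10], [2, 1]): the contracted index is cropped to 5..10 in both
!>       tensors, i.e. bounds_t1(:, 2) = [5, 10] and bounds_t2(:, 1) = [5, 10], while all other
!>       dimensions keep their full range 1..nfull_total.
! **************************************************************************************************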
1972
1973! **************************************************************************************************
1974!> \brief print tensor contraction indices in a human-readable way
1975!> \param indchar1 characters printed for index of tensor 1
1976!> \param indchar2 characters printed for index of tensor 2
1977!> \param indchar3 characters printed for index of tensor 3
1978!> \param unit_nr output unit
1979!> \author Patrick Seewald
1980! **************************************************************************************************
1981 SUBROUTINE dbt_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
1982 TYPE(dbt_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
1983 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
1984 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
1985 CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
1986 INTEGER, INTENT(IN) :: unit_nr
1987 INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
1988 INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
1989 INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
1990 INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
1991 INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
1992 INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
1993 INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
1994
1995 unit_nr_prv = prep_output_unit(unit_nr)
1996
1997 IF (unit_nr_prv /= 0) THEN
1998 CALL dbt_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
1999 CALL dbt_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
2000 CALL dbt_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
2001 END IF
2002
2003 IF (unit_nr_prv > 0) THEN
2004 WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
2005 WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
2006 DO ichar1 = 1, SIZE(indchar1)
2007 WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
2008 END DO
2009 WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
2010 DO ichar2 = 1, SIZE(indchar2)
2011 WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
2012 END DO
2013 WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
2014 DO ichar3 = 1, SIZE(indchar3)
2015 WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
2016 END DO
2017 WRITE (unit_nr_prv, '(A)') ")"
2018
2019 WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
2020 DO ichar1 = 1, SIZE(map11)
2021 WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
2022 END DO
2023 WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2024 DO ichar1 = 1, SIZE(map12)
2025 WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
2026 END DO
2027 WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
2028 DO ichar2 = 1, SIZE(map21)
2029 WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
2030 END DO
2031 WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2032 DO ichar2 = 1, SIZE(map22)
2033 WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
2034 END DO
2035 WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
2036 DO ichar3 = 1, SIZE(map31)
2037 WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
2038 END DO
2039 WRITE (unit_nr_prv, '(A1)', advance='no') "|"
2040 DO ichar3 = 1, SIZE(map32)
2041 WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
2042 END DO
2043 WRITE (unit_nr_prv, '(A)') ")"
2044 END IF
2045
2046 END SUBROUTINE
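! **************************************************************************************************
!> \note Example output (sketch) for a contraction (ijk) x (kl) = (ijl), where the tensors map
!>       (i,j)|(k), (k)|(l) and (i,j)|(l) to matrix rows|columns:
!>
!>         INDEX INFO
!>                      tensor index: (ijk) x (kl) = (ijl)
!>                      matrix index: (ij|k) x (k|l) = (ij|l)
! **************************************************************************************************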
2047
2048! **************************************************************************************************
2049!> \brief Initialize batched contraction for this tensor.
2050!>
2051!> Explanation: A batched contraction is a contraction performed in several consecutive steps
2052!> by specification of bounds in dbt_contract. This can be used to reduce memory by
2053!> a large factor. The routines dbt_batched_contract_init and
2054!> dbt_batched_contract_finalize should be called to define the scope of a batched
2055!> contraction as this enables important optimizations (adapting communication scheme to
2056!> batches and adapting process grid to multiplication algorithm). The routines
2057!> dbt_batched_contract_init and dbt_batched_contract_finalize must be
2058!> called before the first and after the last contraction step on all 3 tensors.
2059!>
2060!> Requirements:
2061!> - the tensors are in a compatible matrix layout (see documentation of
2062!> `dbt_contract`, note 2 & 3). If they are not, process grid optimizations are
2063!> disabled and a warning is issued.
2064!> - within the scope of a batched contraction, it is not allowed to access or change tensor
2065!> data except by calling the routines dbt_contract & dbt_copy.
2066!> - the bounds affecting indices of the smallest tensor must not change in the course of a
2067!> batched contraction (todo: get rid of this requirement).
2068!>
2069!> Side effects:
2070!> - the parallel layout (process grid and distribution) of all tensors may change. In order
2071!> to disable the process grid optimization including this side effect, call this routine
2072!> only on the smallest of the 3 tensors.
2073!>
2074!> \note
2075!> Note 1: for an example of batched contraction see `examples/dbt_example.F`.
2076!> (todo: the example is outdated and should be updated).
2077!>
2078!> Note 2: it can be meaningful to use this feature even if the contraction consists of a
2079!> single batch, provided that multiple contractions involving the same 3 tensors are
2080!> performed (batched_contract_init and batched_contract_finalize must then be called
2081!> before/after each contraction call). The process grid is then optimized after the first
2082!> contraction and subsequent contractions may profit from this optimization.
2083!>
2084!> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2085!> a new range. The size should be the number of ranges plus one, the last
2086!> element being the block index plus one of the last block in the last range.
2087!> For internal load balancing optimizations, optionally specify the index
2088!> ranges of batched contraction.
2089!> \author Patrick Seewald
2090! **************************************************************************************************
2091 SUBROUTINE dbt_batched_contract_init(tensor, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
2092 TYPE(dbt_type), INTENT(INOUT) :: tensor
2093 INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2094 INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
2095 INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_prv_1, batch_range_prv_2, batch_range_prv_3,&
2096 & batch_range_prv_4
2097 LOGICAL :: static_range
2098
2099 CALL dbt_get_info(tensor, nblks_total=tdims)
2100
2101 static_range = .true.
2102 IF (ndims_tensor(tensor) >= 1) THEN
2103 IF (PRESENT(batch_range_1)) THEN
2104 ALLOCATE (batch_range_prv_1, source=batch_range_1)
2105 static_range = .false.
2106 ELSE
2107 ALLOCATE (batch_range_prv_1(2))
2108 batch_range_prv_1(1) = 1
2109 batch_range_prv_1(2) = tdims(1) + 1
2110 END IF
2111 END IF
2112 IF (ndims_tensor(tensor) >= 2) THEN
2113 IF (PRESENT(batch_range_2)) THEN
2114 ALLOCATE (batch_range_prv_2, source=batch_range_2)
2115 static_range = .false.
2116 ELSE
2117 ALLOCATE (batch_range_prv_2(2))
2118 batch_range_prv_2(1) = 1
2119 batch_range_prv_2(2) = tdims(2) + 1
2120 END IF
2121 END IF
2122 IF (ndims_tensor(tensor) >= 3) THEN
2123 IF (PRESENT(batch_range_3)) THEN
2124 ALLOCATE (batch_range_prv_3, source=batch_range_3)
2125 static_range = .false.
2126 ELSE
2127 ALLOCATE (batch_range_prv_3(2))
2128 batch_range_prv_3(1) = 1
2129 batch_range_prv_3(2) = tdims(3) + 1
2130 END IF
2131 END IF
2132 IF (ndims_tensor(tensor) >= 4) THEN
2133 IF (PRESENT(batch_range_4)) THEN
2134 ALLOCATE (batch_range_prv_4, source=batch_range_4)
2135 static_range = .false.
2136 ELSE
2137 ALLOCATE (batch_range_prv_4(2))
2138 batch_range_prv_4(1) = 1
2139 batch_range_prv_4(2) = tdims(4) + 1
2140 END IF
2141 END IF
2142
2143 ALLOCATE (tensor%contraction_storage)
2144 tensor%contraction_storage%static = static_range
2145 IF (static_range) THEN
2146 CALL dbt_tas_batched_mm_init(tensor%matrix_rep)
2147 END IF
2148 tensor%contraction_storage%nsplit_avg = 0.0_dp
2149 tensor%contraction_storage%ibatch = 0
2150
2151 IF (ndims_tensor(tensor) == 1) THEN
2152 CALL create_array_list(tensor%contraction_storage%batch_ranges, 1, &
2153 batch_range_prv_1)
2154 END IF
2155 IF (ndims_tensor(tensor) == 2) THEN
2156 CALL create_array_list(tensor%contraction_storage%batch_ranges, 2, &
2157 batch_range_prv_1, batch_range_prv_2)
2158 END IF
2159 IF (ndims_tensor(tensor) == 3) THEN
2160 CALL create_array_list(tensor%contraction_storage%batch_ranges, 3, &
2161 batch_range_prv_1, batch_range_prv_2, batch_range_prv_3)
2162 END IF
2163 IF (ndims_tensor(tensor) == 4) THEN
2164 CALL create_array_list(tensor%contraction_storage%batch_ranges, 4, &
2165 batch_range_prv_1, batch_range_prv_2, batch_range_prv_3, batch_range_prv_4)
2166 END IF
2167
2168 END SUBROUTINE
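! **************************************************************************************************
!> \note Minimal sketch of a batched contraction (illustrative only; tensors t1, t2, t3, the
!>       index/bound arguments and nbatch are hypothetical). For 10 blocks split into the two
!>       ranges 1..4 and 5..10 the batch range vector is [1, 5, 11].
!>
!>         CALL dbt_batched_contract_init(t1, batch_range_1=[1, 5, 11])
!>         CALL dbt_batched_contract_init(t2)
!>         CALL dbt_batched_contract_init(t3)
!>         DO ibatch = 1, nbatch
!>            ! bounds_2 restricts the non-contracted indices of t1 to the current batch
!>            CALL dbt_contract(alpha, t1, t2, beta, t3, &
!>                              contract_1, notcontract_1, contract_2, notcontract_2, &
!>                              map_1, map_2, bounds_2=bounds_batch)
!>         END DO
!>         CALL dbt_batched_contract_finalize(t1)
!>         CALL dbt_batched_contract_finalize(t2)
!>         CALL dbt_batched_contract_finalize(t3)
! **************************************************************************************************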
2169
2170! **************************************************************************************************
2171!> \brief finalize batched contraction. This performs all communication that has been postponed in
2172!> the contraction calls.
2173!> \author Patrick Seewald
2174! **************************************************************************************************
2175 SUBROUTINE dbt_batched_contract_finalize(tensor, unit_nr)
2176 TYPE(dbt_type), INTENT(INOUT) :: tensor
2177 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2178 LOGICAL :: do_write
2179 INTEGER :: unit_nr_prv, handle
2180
2181 CALL tensor%pgrid%mp_comm_2d%sync()
2182 CALL timeset("dbt_total", handle)
2183 unit_nr_prv = prep_output_unit(unit_nr)
2184
2185 do_write = .false.
2186
2187 IF (tensor%contraction_storage%static) THEN
2188 IF (tensor%matrix_rep%do_batched > 0) THEN
2189 IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .true.
2190 END IF
2191 CALL dbt_tas_batched_mm_finalize(tensor%matrix_rep)
2192 END IF
2193
2194 IF (do_write .AND. unit_nr_prv /= 0) THEN
2195 IF (unit_nr_prv > 0) THEN
2196 WRITE (unit_nr_prv, "(T2,A)") &
2197 "FINALIZING BATCHED PROCESSING OF MATMUL"
2198 END IF
2199 CALL dbt_write_tensor_info(tensor, unit_nr_prv)
2200 CALL dbt_write_tensor_dist(tensor, unit_nr_prv)
2201 END IF
2202
2203 CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
2204 DEALLOCATE (tensor%contraction_storage)
2205 CALL tensor%pgrid%mp_comm_2d%sync()
2206 CALL timestop(handle)
2207
2208 END SUBROUTINE
2209
2210! **************************************************************************************************
2211!> \brief change the process grid of a tensor
2212!> \param nodata optionally don't copy the tensor data (the tensor is then empty on return)
2213!> \param batch_range_i refers to the ith tensor dimension and contains all block indices starting
2214!> a new range. The size should be the number of ranges plus one, the last
2215!> element being the block index plus one of the last block in the last range.
2216!> For internal load balancing optimizations, optionally specify the index
2217!> ranges of batched contraction.
2218!> \author Patrick Seewald
2219! **************************************************************************************************
2220 SUBROUTINE dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
2221 nodata, pgrid_changed, unit_nr)
2222 TYPE(dbt_type), INTENT(INOUT) :: tensor
2223 TYPE(dbt_pgrid_type), INTENT(IN) :: pgrid
2224 INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2225 !!
2226 LOGICAL, INTENT(IN), OPTIONAL :: nodata
2227 LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2228 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2229 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_change_pgrid'
2230 CHARACTER(default_string_length) :: name
2231 INTEGER :: handle
2232 INTEGER, ALLOCATABLE, DIMENSION(:) :: bs_1, bs_2, bs_3, bs_4, &
2233 dist_1, dist_2, dist_3, dist_4
2234 INTEGER, DIMENSION(ndims_tensor(tensor)) :: pcoord, pcoord_ref, pdims, pdims_ref, &
2235 tdims
2236 TYPE(dbt_type) :: t_tmp
2237 TYPE(dbt_distribution_type) :: dist
2238 INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2239 INTEGER, &
2240 DIMENSION(ndims_matrix_column(tensor)) :: map2
2241 LOGICAL, DIMENSION(ndims_tensor(tensor)) :: mem_aware
2242 INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
2243 INTEGER :: ind1, ind2, batch_size, ibatch
2244
2245 IF (PRESENT(pgrid_changed)) pgrid_changed = .false.
2246 CALL mp_environ_pgrid(pgrid, pdims, pcoord)
2247 CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)
2248
2249 IF (all(pdims == pdims_ref)) THEN
2250 IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
2251 IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
2252 RETURN
2253 END IF
2254 END IF
2255 END IF
2256
2257 CALL timeset(routinen, handle)
2258
2259 IF (ndims_tensor(tensor) >= 1) THEN
2260 mem_aware(1) = PRESENT(batch_range_1)
2261 IF (mem_aware(1)) nbatch(1) = SIZE(batch_range_1) - 1
2262 END IF
2263 IF (ndims_tensor(tensor) >= 2) THEN
2264 mem_aware(2) = PRESENT(batch_range_2)
2265 IF (mem_aware(2)) nbatch(2) = SIZE(batch_range_2) - 1
2266 END IF
2267 IF (ndims_tensor(tensor) >= 3) THEN
2268 mem_aware(3) = PRESENT(batch_range_3)
2269 IF (mem_aware(3)) nbatch(3) = SIZE(batch_range_3) - 1
2270 END IF
2271 IF (ndims_tensor(tensor) >= 4) THEN
2272 mem_aware(4) = PRESENT(batch_range_4)
2273 IF (mem_aware(4)) nbatch(4) = SIZE(batch_range_4) - 1
2274 END IF
2275
2276 CALL dbt_get_info(tensor, nblks_total=tdims, name=name)
2277
2278 IF (ndims_tensor(tensor) >= 1) THEN
2279 ALLOCATE (bs_1(dbt_nblks_total(tensor, 1)))
2280 CALL get_ith_array(tensor%blk_sizes, 1, bs_1)
2281 ALLOCATE (dist_1(tdims(1)))
2282 dist_1 = 0
2283 IF (mem_aware(1)) THEN
2284 DO ibatch = 1, nbatch(1)
2285 ind1 = batch_range_1(ibatch)
2286 ind2 = batch_range_1(ibatch + 1) - 1
2287 batch_size = ind2 - ind1 + 1
2288 CALL dbt_default_distvec(batch_size, pdims(1), &
2289 bs_1(ind1:ind2), dist_1(ind1:ind2))
2290 END DO
2291 ELSE
2292 CALL dbt_default_distvec(tdims(1), pdims(1), bs_1, dist_1)
2293 END IF
2294 END IF
2295 IF (ndims_tensor(tensor) >= 2) THEN
2296 ALLOCATE (bs_2(dbt_nblks_total(tensor, 2)))
2297 CALL get_ith_array(tensor%blk_sizes, 2, bs_2)
2298 ALLOCATE (dist_2(tdims(2)))
2299 dist_2 = 0
2300 IF (mem_aware(2)) THEN
2301 DO ibatch = 1, nbatch(2)
2302 ind1 = batch_range_2(ibatch)
2303 ind2 = batch_range_2(ibatch + 1) - 1
2304 batch_size = ind2 - ind1 + 1
2305 CALL dbt_default_distvec(batch_size, pdims(2), &
2306 bs_2(ind1:ind2), dist_2(ind1:ind2))
2307 END DO
2308 ELSE
2309 CALL dbt_default_distvec(tdims(2), pdims(2), bs_2, dist_2)
2310 END IF
2311 END IF
2312 IF (ndims_tensor(tensor) >= 3) THEN
2313 ALLOCATE (bs_3(dbt_nblks_total(tensor, 3)))
2314 CALL get_ith_array(tensor%blk_sizes, 3, bs_3)
2315 ALLOCATE (dist_3(tdims(3)))
2316 dist_3 = 0
2317 IF (mem_aware(3)) THEN
2318 DO ibatch = 1, nbatch(3)
2319 ind1 = batch_range_3(ibatch)
2320 ind2 = batch_range_3(ibatch + 1) - 1
2321 batch_size = ind2 - ind1 + 1
2322 CALL dbt_default_distvec(batch_size, pdims(3), &
2323 bs_3(ind1:ind2), dist_3(ind1:ind2))
2324 END DO
2325 ELSE
2326 CALL dbt_default_distvec(tdims(3), pdims(3), bs_3, dist_3)
2327 END IF
2328 END IF
2329 IF (ndims_tensor(tensor) >= 4) THEN
2330 ALLOCATE (bs_4(dbt_nblks_total(tensor, 4)))
2331 CALL get_ith_array(tensor%blk_sizes, 4, bs_4)
2332 ALLOCATE (dist_4(tdims(4)))
2333 dist_4 = 0
2334 IF (mem_aware(4)) THEN
2335 DO ibatch = 1, nbatch(4)
2336 ind1 = batch_range_4(ibatch)
2337 ind2 = batch_range_4(ibatch + 1) - 1
2338 batch_size = ind2 - ind1 + 1
2339 CALL dbt_default_distvec(batch_size, pdims(4), &
2340 bs_4(ind1:ind2), dist_4(ind1:ind2))
2341 END DO
2342 ELSE
2343 CALL dbt_default_distvec(tdims(4), pdims(4), bs_4, dist_4)
2344 END IF
2345 END IF
2346
2347 CALL dbt_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
2348 IF (ndims_tensor(tensor) == 2) THEN
2349 CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2)
2350 CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2)
2351 END IF
2352 IF (ndims_tensor(tensor) == 3) THEN
2353 CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3)
2354 CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3)
2355 END IF
2356 IF (ndims_tensor(tensor) == 4) THEN
2357 CALL dbt_distribution_new(dist, pgrid, dist_1, dist_2, dist_3, dist_4)
2358 CALL dbt_create(t_tmp, name, dist, map1, map2, bs_1, bs_2, bs_3, bs_4)
2359 END IF
2360 CALL dbt_distribution_destroy(dist)
2361
2362 IF (PRESENT(nodata)) THEN
2363 IF (.NOT. nodata) CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
2364 ELSE
2365 CALL dbt_copy_expert(tensor, t_tmp, move_data=.true.)
2366 END IF
2367
2368 CALL dbt_copy_contraction_storage(tensor, t_tmp)
2369
2370 CALL dbt_destroy(tensor)
2371 tensor = t_tmp
2372
2373 IF (PRESENT(unit_nr)) THEN
2374 IF (unit_nr > 0) THEN
2375 WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", trim(tensor%name)
2376 WRITE (unit_nr, "(T4,A,1X,4I6)") "process grid dimensions:", pdims
2377 CALL dbt_write_split_info(pgrid, unit_nr)
2378 END IF
2379 END IF
2380
2381 IF (PRESENT(pgrid_changed)) pgrid_changed = .true.
2382
2383 CALL timestop(handle)
2384 END SUBROUTINE
2385
2386! **************************************************************************************************
2387!> \brief map tensor to a new 2d process grid for the matrix representation.
2388!> \author Patrick Seewald
2389! **************************************************************************************************
2390 SUBROUTINE dbt_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
2391 TYPE(dbt_type), INTENT(INOUT) :: tensor
2392 TYPE(mp_cart_type), INTENT(IN) :: mp_comm
2393 INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
2394 LOGICAL, INTENT(IN), OPTIONAL :: nodata
2395 INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
2396 LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2397 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2398 INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2399 INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
2400 INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
2401 TYPE(dbt_pgrid_type) :: pgrid
2402 INTEGER, DIMENSION(:), ALLOCATABLE :: batch_range_1, batch_range_2, batch_range_3, batch_range_4
2403 INTEGER, DIMENSION(:), ALLOCATABLE :: array
2404 INTEGER :: idim
2405
2406 CALL dbt_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
2407 CALL blk_dims_tensor(tensor, dims)
2408
2409 IF (ALLOCATED(tensor%contraction_storage)) THEN
2410 ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
2411 nbatches = sizes_of_arrays(batch_ranges) - 1
2412 ! for good load balancing, the process grid dimensions should be adapted to the
2413 ! tensor dimensions. For batched contraction the tensor dimensions should be divided by
2414 ! the number of batches (number of index ranges).
2415 DO idim = 1, ndims_tensor(tensor)
2416 CALL get_ith_array(batch_ranges, idim, array)
2417 dims(idim) = array(nbatches(idim) + 1) - array(1)
2418 DEALLOCATE (array)
2419 dims(idim) = dims(idim)/nbatches(idim)
2420 IF (dims(idim) <= 0) dims(idim) = 1
2421 END DO
2422 END ASSOCIATE
2423 END IF
2424
2425 pgrid = dbt_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
2426 IF (ALLOCATED(tensor%contraction_storage)) THEN
2427 IF (ndims_tensor(tensor) == 1) THEN
2428 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1)
2429 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, &
2430 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2431 END IF
2432 IF (ndims_tensor(tensor) == 2) THEN
2433 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2)
2434 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, &
2435 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2436 END IF
2437 IF (ndims_tensor(tensor) == 3) THEN
2438 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3)
2439 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, &
2440 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2441 END IF
2442 IF (ndims_tensor(tensor) == 4) THEN
2443 CALL get_arrays(tensor%contraction_storage%batch_ranges, batch_range_1, batch_range_2, batch_range_3, batch_range_4)
2444 CALL dbt_change_pgrid(tensor, pgrid, batch_range_1, batch_range_2, batch_range_3, batch_range_4, &
2445 nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2446 END IF
2447 ELSE
2448 CALL dbt_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2449 END IF
2450 CALL dbt_pgrid_destroy(pgrid)
2451
2452 END SUBROUTINE
2453
2454END MODULE