dbt_tas_mm.F
1 !--------------------------------------------------------------------------------------------------!
2 ! CP2K: A general program to perform molecular dynamics simulations !
3 ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 ! !
5 ! SPDX-License-Identifier: GPL-2.0-or-later !
6 !--------------------------------------------------------------------------------------------------!
7 
8 ! **************************************************************************************************
9 !> \brief Matrix multiplication for tall-and-skinny matrices.
10 !> This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
11 !> as long as the two smaller dimensions have the same size.
12 !> Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
13 !> submatrices uses the DBM Cannon algorithm. Since the sparsity pattern of the result matrix is
14 !> unknown, the parameters (group sizes and process grid dimensions) cannot be derived from the
15 !> matrix dimensions and need to be set manually.
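!> Illustration (added for documentation, with hypothetical sizes): for C(m x n) = A(m x k) x B(k x n)
!> with k much larger than m and n, the process grid is split along k into nsplit groups; the
!> smallest matrix (here C) is replicated on every group, each group multiplies its submatrices
!> with the DBM Cannon algorithm, and the replicated partial results are summed afterwards. If
!> instead m (or n) is the largest dimension, B (or A) is the smallest matrix and is replicated,
!> and no summation over groups is needed.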
16 !> \author Patrick Seewald
17 ! **************************************************************************************************
18 MODULE dbt_tas_mm
19  USE dbm_api, ONLY: &
21  dbm_distribution_obj, dbm_distribution_release, dbm_get_col_block_sizes, &
24  USE dbt_tas_base, ONLY: &
27  dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
29  dbt_tas_reserve_blocks
30  USE dbt_tas_global, ONLY: dbt_tas_blk_size_one,&
32  dbt_tas_dist_arb,&
34  dbt_tas_dist_cyclic,&
35  dbt_tas_distribution,&
36  dbt_tas_rowcol_data
37  USE dbt_tas_io, ONLY: dbt_tas_write_dist,&
44  USE dbt_tas_split, ONLY: &
47  rowsplit
48  USE dbt_tas_types, ONLY: dbt_tas_distribution_type,&
49  dbt_tas_iterator,&
50  dbt_tas_split_info,&
51  dbt_tas_type
52  USE dbt_tas_util, ONLY: array_eq,&
53  swap
54  USE kinds, ONLY: default_string_length,&
55  dp,&
56  int_8
57  USE message_passing, ONLY: mp_cart_type
58 #include "../../base/base_uses.f90"
59 
60  IMPLICIT NONE
61  PRIVATE
62 
63  CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
64 
65  PUBLIC :: &
71 
72 CONTAINS
73 
74 ! **************************************************************************************************
75 !> \brief tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical
76 !> to arguments of dbm_multiply (see dbm_mm, dbm_multiply_generic).
77 !> \param transa ...
78 !> \param transb ...
79 !> \param transc ...
80 !> \param alpha ...
81 !> \param matrix_a ...
82 !> \param matrix_b ...
83 !> \param beta ...
84 !> \param matrix_c ...
85 !> \param optimize_dist Whether distribution should be optimized internally. In the current
86 !> implementation this guarantees optimal parameters only for dense matrices.
87 !> \param split_opt optionally return split info containing optimal grid and split parameters.
88 !> This can be used to choose optimal process grids for subsequent matrix
89 !> multiplications with matrices of similar shape and sparsity.
90 !> \param filter_eps ...
91 !> \param flop ...
92 !> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
93 !> (for internal use only)
94 !> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
95 !> (for internal use only)
96 !> \param retain_sparsity ...
97 !> \param simple_split ...
98 !> \param unit_nr unit number for logging output
99 !> \param log_verbose only for testing: verbose output
100 !> \author Patrick Seewald
101 ! **************************************************************************************************
102  RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
103  optimize_dist, split_opt, filter_eps, flop, move_data_a, &
104  move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
105 
106  LOGICAL, INTENT(IN) :: transa, transb, transc
107  REAL(dp), INTENT(IN) :: alpha
108  TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b
109  REAL(dp), INTENT(IN) :: beta
110  TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_c
111  LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
112  TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL :: split_opt
113  REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
114  INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
115  LOGICAL, INTENT(IN), OPTIONAL :: move_data_a, move_data_b, &
116  retain_sparsity, simple_split
117  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
118  LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
119 
120  CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_tas_multiply'
121 
122  INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
123  nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
124  unit_nr_prv
125  INTEGER(KIND=int_8) :: nze_a, nze_b, nze_c, nze_c_sum
126  INTEGER(KIND=int_8), DIMENSION(2) :: dims_a, dims_b, dims_c
127  INTEGER(KIND=int_8), DIMENSION(3) :: dims
128  INTEGER, DIMENSION(2) :: pdims, pdims_sub
129  LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
130  simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
131  REAL(kind=dp) :: filter_eps_prv
132  TYPE(dbm_type) :: matrix_a_mm, matrix_b_mm, matrix_c_mm
133  TYPE(dbt_tas_split_info) :: info, info_a, info_b, info_c
134  TYPE(dbt_tas_type), POINTER :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
135  matrix_b_rs, matrix_c_rep, matrix_c_rs
136  TYPE(mp_cart_type) :: comm_tmp, mp_comm, mp_comm_group, &
137  mp_comm_mm, mp_comm_opt
138 
139  CALL timeset(routinen, handle)
140  CALL matrix_a%dist%info%mp_comm%sync()
141  CALL timeset("dbt_tas_total", handle2)
142 
143  NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
144 
145  unit_nr_prv = prep_output_unit(unit_nr)
146 
147  IF (PRESENT(simple_split)) THEN
148  simple_split_prv = simple_split
149  ELSE
150  simple_split_prv = .false.
151 
152  info_a = dbt_tas_info(matrix_a); info_b = dbt_tas_info(matrix_b); info_c = dbt_tas_info(matrix_c)
153  IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .true.
154  END IF
155 
156  nodata_3 = .true.
157  IF (PRESENT(retain_sparsity)) THEN
158  IF (retain_sparsity) nodata_3 = .false.
159  END IF
160 
161  ! get prestored info for multiplication strategy in case of batched mm
162  batched_repl = 0
163  do_batched = .false.
164  IF (matrix_a%do_batched > 0) THEN
165  do_batched = .true.
166  IF (matrix_a%do_batched == 3) THEN
167 CPASSERT(batched_repl == 0)
168  batched_repl = 1
169  CALL dbt_tas_get_split_info( &
170  dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
171  nsplit=nsplit_batched)
172 CPASSERT(nsplit_batched > 0)
173  max_mm_dim_batched = 3
174  END IF
175  END IF
176 
177  IF (matrix_b%do_batched > 0) THEN
178  do_batched = .true.
179  IF (matrix_b%do_batched == 3) THEN
180 CPASSERT(batched_repl == 0)
181  batched_repl = 2
182  CALL dbt_tas_get_split_info( &
183  dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
184  nsplit=nsplit_batched)
185 CPASSERT(nsplit_batched > 0)
186  max_mm_dim_batched = 1
187  END IF
188  END IF
189 
190  IF (matrix_c%do_batched > 0) THEN
191  do_batched = .true.
192  IF (matrix_c%do_batched == 3) THEN
193 CPASSERT(batched_repl == 0)
194  batched_repl = 3
195  CALL dbt_tas_get_split_info( &
196  dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
197  nsplit=nsplit_batched)
198 CPASSERT(nsplit_batched > 0)
199  max_mm_dim_batched = 2
200  END IF
201  END IF
202 
203  move_a = .false.
204  move_b = .false.
205 
206  IF (PRESENT(move_data_a)) move_a = move_data_a
207  IF (PRESENT(move_data_b)) move_b = move_data_b
208 
209  transa_prv = transa; transb_prv = transb; transc_prv = transc
210 
211  dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
212  dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
213  dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]
214 
215  IF (unit_nr_prv > 0) THEN
216  WRITE (unit_nr_prv, "(A)") repeat("-", 80)
217  WRITE (unit_nr_prv, "(A)") &
218  "DBT TAS MATRIX MULTIPLICATION: "// &
219  trim(dbm_get_name(matrix_a%matrix))//" x "// &
220  trim(dbm_get_name(matrix_b%matrix))//" = "// &
221  trim(dbm_get_name(matrix_c%matrix))
222  WRITE (unit_nr_prv, "(A)") repeat("-", 80)
223  END IF
224  IF (do_batched) THEN
225  IF (unit_nr_prv > 0) THEN
226  WRITE (unit_nr_prv, "(T2,A)") &
227  "BATCHED PROCESSING OF MATMUL"
228  IF (batched_repl > 0) THEN
229  WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
230  END IF
231  END IF
232  END IF
233 
234  IF (transa_prv) THEN
235  CALL swap(dims_a)
236  END IF
237 
238  IF (transb_prv) THEN
239  CALL swap(dims_b)
240  END IF
241 
242  dims_c = [dims_a(1), dims_b(2)]
243 
244  IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
245  cpabort("inconsistent matrix dimensions")
246  END IF
247 
248  dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
249 
250  IF (unit_nr_prv > 0) THEN
251  WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
252  END IF
253 
254  CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
255  numproc = mp_comm%num_pe
256 
257  ! derive optimal matrix layout and split factor from occupancies
258  nze_a = dbt_tas_get_nze_total(matrix_a)
259  nze_b = dbt_tas_get_nze_total(matrix_b)
260 
261  IF (.NOT. simple_split_prv) THEN
262  CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
263  estimated_nze=nze_c, filter_eps=filter_eps, &
264  retain_sparsity=retain_sparsity)
265 
266  max_mm_dim = maxloc(dims, 1)
267  nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
268  nsplit_opt = nsplit
269 
270  IF (unit_nr_prv > 0) THEN
271  WRITE (unit_nr_prv, "(T2,A)") &
272  "MM PARAMETERS"
273  WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
274  (nze_c + numproc - 1)/numproc
275 
276  WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
277  END IF
278 
279  ELSEIF (batched_repl > 0) THEN
280  nsplit = nsplit_batched
281  nsplit_opt = nsplit
282  max_mm_dim = max_mm_dim_batched
283  IF (unit_nr_prv > 0) THEN
284  WRITE (unit_nr_prv, "(T2,A)") &
285  "MM PARAMETERS"
286  WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
287  END IF
288 
289  ELSE
290  nsplit = 0
291  max_mm_dim = maxloc(dims, 1)
292  END IF
293 
294  ! reshape matrices to the optimal layout and split factor
295  split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
296  SELECT CASE (max_mm_dim)
297  CASE (1)
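 ! the largest dimension is the row dimension of C: split A and C along it,
 ! keep B unsplit (small) and replicate it across the split groups later on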
298 
299  split_a = rowsplit; split_c = rowsplit
300  CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
301  new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
302  nsplit=nsplit, &
303  opt_nsplit=batched_repl == 0, &
304  split_rc_1=split_a, split_rc_2=split_c, &
305  nodata2=nodata_3, comm_new=comm_tmp, &
306  move_data_1=move_a, unit_nr=unit_nr_prv)
307 
308  info = dbt_tas_info(matrix_a_rs)
309  CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
310 
311  new_b = .false.
312  IF (matrix_b%do_batched <= 2) THEN
313  ALLOCATE (matrix_b_rs)
314  CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
315  transb_prv = .false.
316  new_b = .true.
317  END IF
318 
319  tr_case = transa_prv
320 
321  IF (unit_nr_prv > 0) THEN
322  IF (.NOT. tr_case) THEN
323  WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
324  ELSE
325  WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
326  END IF
327  END IF
328 
329  CASE (2)
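 ! the largest dimension is the shared (summed) dimension: split A and B along it,
 ! keep C unsplit, replicate it on each group and sum the partial results afterwards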
330 
331  split_a = colsplit; split_b = rowsplit
332  CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
333  optimize_dist=optimize_dist, &
334  nsplit=nsplit, &
335  opt_nsplit=batched_repl == 0, &
336  split_rc_1=split_a, split_rc_2=split_b, &
337  comm_new=comm_tmp, &
338  move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
339 
340  info = dbt_tas_info(matrix_a_rs)
341  CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
342 
343  IF (matrix_c%do_batched == 1) THEN
344  matrix_c%mm_storage%batched_beta = beta
345  ELSEIF (matrix_c%do_batched > 1) THEN
346  matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
347  END IF
348 
349  IF (matrix_c%do_batched <= 2) THEN
350  ALLOCATE (matrix_c_rs)
351  CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
352  transc_prv = .false.
353 
354 ! for retain_sparsity, keep only the sparsity structure but not the values
355  IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)
356 
357  IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
358  ELSEIF (matrix_c%do_batched == 3) THEN
359  matrix_c_rs => matrix_c%mm_storage%store_batched
360  END IF
361 
362  new_c = matrix_c%do_batched == 0
363  tr_case = transa_prv
364 
365  IF (unit_nr_prv > 0) THEN
366  IF (.NOT. tr_case) THEN
367  WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
368  ELSE
369  WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
370  END IF
371  END IF
372 
373  CASE (3)
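 ! the largest dimension is the column dimension of C: split B and C along it,
 ! keep A unsplit (small) and replicate it across the split groups later on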
374 
375  split_b = colsplit; split_c = colsplit
376  CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
377  transc_prv, optimize_dist=optimize_dist, &
378  nsplit=nsplit, &
379  opt_nsplit=batched_repl == 0, &
380  split_rc_1=split_b, split_rc_2=split_c, &
381  nodata2=nodata_3, comm_new=comm_tmp, &
382  move_data_1=move_b, unit_nr=unit_nr_prv)
383  info = dbt_tas_info(matrix_b_rs)
384  CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
385 
386  new_a = .false.
387  IF (matrix_a%do_batched <= 2) THEN
388  ALLOCATE (matrix_a_rs)
389  CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
390  transa_prv = .false.
391  new_a = .true.
392  END IF
393 
394  tr_case = transb_prv
395 
396  IF (unit_nr_prv > 0) THEN
397  IF (.NOT. tr_case) THEN
398  WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
399  ELSE
400  WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
401  END IF
402  END IF
403 
404  END SELECT
405 
406  CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
407 
408  numproc = mp_comm%num_pe
409  pdims_sub = mp_comm_group%num_pe_cart
410 
411  opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.true.)
412 
413  IF (PRESENT(filter_eps)) THEN
414  filter_eps_prv = filter_eps
415  ELSE
416  filter_eps_prv = 0.0_dp
417  END IF
418 
419  IF (unit_nr_prv /= 0) THEN
420  IF (unit_nr_prv > 0) THEN
421  WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
422  END IF
423  CALL dbt_tas_write_split_info(info, unit_nr_prv)
424  IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
425  IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
426  IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
427  IF (unit_nr_prv > 0) THEN
428  IF (opt_pgrid) THEN
429  WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
430  ELSE
431  WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
432  END IF
433  END IF
434  END IF
435 
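 ! create a 2D process grid for the DBM multiplication within each group
 ! (pdims = 0 lets the Cartesian dimensions be chosen automatically)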
436  pdims = 0
437  CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
438 
439  ! Convert DBM submatrices to optimized process grids and multiply
440  SELECT CASE (max_mm_dim)
441  CASE (1)
442  IF (matrix_b%do_batched <= 2) THEN
443  ALLOCATE (matrix_b_rep)
444  CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.true.)
445  IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
446  matrix_b%mm_storage%store_batched_repl => matrix_b_rep
447  CALL dbt_tas_set_batched_state(matrix_b, state=3)
448  END IF
449  ELSEIF (matrix_b%do_batched == 3) THEN
450  matrix_b_rep => matrix_b%mm_storage%store_batched_repl
451  END IF
452 
453  IF (new_b) THEN
454  CALL dbt_tas_destroy(matrix_b_rs)
455  DEALLOCATE (matrix_b_rs)
456  END IF
457  IF (unit_nr_prv /= 0) THEN
458  CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
459  CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
460  END IF
461 
462  CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
463 
464  ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
465  info_a = dbt_tas_info(matrix_a_rs)
466  CALL dbt_tas_info_hold(info_a)
467 
468  IF (new_a) THEN
469  CALL dbt_tas_destroy(matrix_a_rs)
470  DEALLOCATE (matrix_a_rs)
471  END IF
472  CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
473  move_data=matrix_b%do_batched == 0)
474 
475  info_b = dbt_tas_info(matrix_b_rep)
476  CALL dbt_tas_info_hold(info_b)
477 
478  IF (matrix_b%do_batched == 0) THEN
479  CALL dbt_tas_destroy(matrix_b_rep)
480  DEALLOCATE (matrix_b_rep)
481  END IF
482 
483  CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
484 
485  info_c = dbt_tas_info(matrix_c_rs)
486  CALL dbt_tas_info_hold(info_c)
487 
488  CALL matrix_a%dist%info%mp_comm%sync()
489  CALL timeset("dbt_tas_dbm", handle4)
490  IF (.NOT. tr_case) THEN
491  CALL timeset("dbt_tas_mm_1N", handle3)
492 
493  CALL dbm_multiply(transa=.false., transb=.false., alpha=alpha, &
494  matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
495  filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
496  CALL timestop(handle3)
497  ELSE
498  CALL timeset("dbt_tas_mm_1T", handle3)
499  CALL dbm_multiply(transa=.true., transb=.false., alpha=alpha, &
500  matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
501  filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
502 
503  CALL timestop(handle3)
504  END IF
505  CALL matrix_a%dist%info%mp_comm%sync()
506  CALL timestop(handle4)
507 
508  CALL dbm_release(matrix_a_mm)
509  CALL dbm_release(matrix_b_mm)
510 
511  nze_c = dbm_get_nze(matrix_c_mm)
512 
513  IF (.NOT. new_c) THEN
514  CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
515  ELSE
516  CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
517  END IF
518 
519  CALL dbm_release(matrix_c_mm)
520 
521  IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
522 
523  IF (unit_nr_prv /= 0) THEN
524  CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
525  END IF
526 
527  CASE (2)
528  IF (matrix_c%do_batched <= 1) THEN
529  ALLOCATE (matrix_c_rep)
530  CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
531  IF (matrix_c%do_batched == 1) THEN
532  matrix_c%mm_storage%store_batched_repl => matrix_c_rep
533  CALL dbt_tas_set_batched_state(matrix_c, state=3)
534  END IF
535  ELSEIF (matrix_c%do_batched == 2) THEN
536  ALLOCATE (matrix_c_rep)
537  CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
538 ! for retain_sparsity, keep only the sparsity structure but not the values
539  IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
540  matrix_c%mm_storage%store_batched_repl => matrix_c_rep
541  CALL dbt_tas_set_batched_state(matrix_c, state=3)
542  ELSEIF (matrix_c%do_batched == 3) THEN
543  matrix_c_rep => matrix_c%mm_storage%store_batched_repl
544  END IF
545 
546  IF (unit_nr_prv /= 0) THEN
547  CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
548  CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
549  END IF
550 
551  CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
552 
553  ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
554  info_a = dbt_tas_info(matrix_a_rs)
555  CALL dbt_tas_info_hold(info_a)
556 
557  IF (new_a) THEN
558  CALL dbt_tas_destroy(matrix_a_rs)
559  DEALLOCATE (matrix_a_rs)
560  END IF
561 
562  CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
563 
564  info_b = dbt_tas_info(matrix_b_rs)
565  CALL dbt_tas_info_hold(info_b)
566 
567  IF (new_b) THEN
568  CALL dbt_tas_destroy(matrix_b_rs)
569  DEALLOCATE (matrix_b_rs)
570  END IF
571 
572  CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
573 
574  info_c = dbt_tas_info(matrix_c_rep)
575  CALL dbt_tas_info_hold(info_c)
576 
577  CALL matrix_a%dist%info%mp_comm%sync()
578  CALL timeset("dbt_tas_dbm", handle4)
579  CALL timeset("dbt_tas_mm_2", handle3)
580  CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
581  matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
582  filter_eps=filter_eps_prv/real(nsplit, kind=dp), retain_sparsity=retain_sparsity, flop=flop)
583  CALL matrix_a%dist%info%mp_comm%sync()
584  CALL timestop(handle3)
585  CALL timestop(handle4)
586 
587  CALL dbm_release(matrix_a_mm)
588  CALL dbm_release(matrix_b_mm)
589 
590  nze_c = dbm_get_nze(matrix_c_mm)
591 
592  CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
593  nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)
594 
595  CALL dbm_release(matrix_c_mm)
596 
597  IF (unit_nr_prv /= 0) THEN
598  CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
599  END IF
600 
601  IF (matrix_c%do_batched == 0) THEN
602  CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.true.)
603  ELSE
604  matrix_c%mm_storage%batched_out = .true. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
605  END IF
606 
607  IF (matrix_c%do_batched == 0) THEN
608  CALL dbt_tas_destroy(matrix_c_rep)
609  DEALLOCATE (matrix_c_rep)
610  END IF
611 
612  IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
613 
614 ! set an upper limit on the memory consumption of the replicated matrix and complete the
615 ! batched mm if the limit is exceeded
616  IF (nze_c_sum > default_nsplit_accept_ratio*max(nze_a, nze_b)) THEN
617  CALL dbt_tas_batched_mm_complete(matrix_c)
618  END IF
619 
620  CASE (3)
621  IF (matrix_a%do_batched <= 2) THEN
622  ALLOCATE (matrix_a_rep)
623  CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.true.)
624  IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
625  matrix_a%mm_storage%store_batched_repl => matrix_a_rep
626  CALL dbt_tas_set_batched_state(matrix_a, state=3)
627  END IF
628  ELSEIF (matrix_a%do_batched == 3) THEN
629  matrix_a_rep => matrix_a%mm_storage%store_batched_repl
630  END IF
631 
632  IF (new_a) THEN
633  CALL dbt_tas_destroy(matrix_a_rs)
634  DEALLOCATE (matrix_a_rs)
635  END IF
636  IF (unit_nr_prv /= 0) THEN
637  CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
638  CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
639  END IF
640 
641  CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
642  move_data=matrix_a%do_batched == 0)
643 
644  ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
645  info_a = dbt_tas_info(matrix_a_rep)
646  CALL dbt_tas_info_hold(info_a)
647 
648  IF (matrix_a%do_batched == 0) THEN
649  CALL dbt_tas_destroy(matrix_a_rep)
650  DEALLOCATE (matrix_a_rep)
651  END IF
652 
653  CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
654 
655  info_b = dbt_tas_info(matrix_b_rs)
656  CALL dbt_tas_info_hold(info_b)
657 
658  IF (new_b) THEN
659  CALL dbt_tas_destroy(matrix_b_rs)
660  DEALLOCATE (matrix_b_rs)
661  END IF
662  CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
663 
664  info_c = dbt_tas_info(matrix_c_rs)
665  CALL dbt_tas_info_hold(info_c)
666 
667  CALL matrix_a%dist%info%mp_comm%sync()
668  CALL timeset("dbt_tas_dbm", handle4)
669  IF (.NOT. tr_case) THEN
670  CALL timeset("dbt_tas_mm_3N", handle3)
671  CALL dbm_multiply(transa=.false., transb=.false., alpha=alpha, &
672  matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
673  filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
674  CALL timestop(handle3)
675  ELSE
676  CALL timeset("dbt_tas_mm_3T", handle3)
677  CALL dbm_multiply(transa=.false., transb=.true., alpha=alpha, &
678  matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
679  filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
680  CALL timestop(handle3)
681  END IF
682  CALL matrix_a%dist%info%mp_comm%sync()
683  CALL timestop(handle4)
684 
685  CALL dbm_release(matrix_a_mm)
686  CALL dbm_release(matrix_b_mm)
687 
688  nze_c = dbm_get_nze(matrix_c_mm)
689 
690  IF (.NOT. new_c) THEN
691  CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
692  ELSE
693  CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
694  END IF
695 
696  CALL dbm_release(matrix_c_mm)
697 
698  IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
699 
700  IF (unit_nr_prv /= 0) THEN
701  CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
702  END IF
703  END SELECT
704 
705  CALL mp_comm_mm%free()
706 
707  CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)
708 
709  IF (PRESENT(split_opt)) THEN
710  SELECT CASE (max_mm_dim)
711  CASE (1, 3)
712  CALL mp_comm%sum(nze_c)
713  CASE (2)
714  CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
715  CALL mp_comm%sum(nze_c)
716  CALL mp_comm%max(nze_c)
717 
718  END SELECT
719  nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
720  ! ideally we should rederive the split factor from the actual sparsity of C, but
721 ! due to the parameter beta, we cannot get the sparsity of AxB from DBM if not new_c
722  mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
723  CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.true.)
724  IF (unit_nr_prv > 0) THEN
725  WRITE (unit_nr_prv, "(T2,A)") &
726  "MM PARAMETERS"
727  WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
728  (nze_c + numproc - 1)/numproc
729 
730  WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
731  END IF
732 
733  END IF
734 
735  IF (new_c) THEN
736  CALL dbm_scale(matrix_c%matrix, beta)
737  CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.true., &
738  transposed=(transc_prv .NEQV. transc), &
739  move_data=.true.)
740  CALL dbt_tas_destroy(matrix_c_rs)
741  DEALLOCATE (matrix_c_rs)
742  IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
743  ELSEIF (matrix_c%do_batched > 0) THEN
744  IF (matrix_c%mm_storage%batched_out) THEN
745  matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
746  END IF
747  END IF
748 
749  IF (PRESENT(move_data_a)) THEN
750  IF (move_data_a) CALL dbt_tas_clear(matrix_a)
751  END IF
752  IF (PRESENT(move_data_b)) THEN
753  IF (move_data_b) CALL dbt_tas_clear(matrix_b)
754  END IF
755 
756  IF (PRESENT(flop)) THEN
757  CALL mp_comm%sum(flop)
758  flop = (flop + numproc - 1)/numproc
759  END IF
760 
761  IF (PRESENT(optimize_dist)) THEN
762  IF (optimize_dist) CALL comm_tmp%free()
763  END IF
764  IF (unit_nr_prv > 0) THEN
765  WRITE (unit_nr_prv, '(A)') repeat("-", 80)
766  WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
767  WRITE (unit_nr_prv, '(A)') repeat("-", 80)
768  END IF
769 
770  CALL dbt_tas_release_info(info_a)
771  CALL dbt_tas_release_info(info_b)
772  CALL dbt_tas_release_info(info_c)
773 
774  CALL matrix_a%dist%info%mp_comm%sync()
775  CALL timestop(handle2)
776  CALL timestop(handle)
777 
778  END SUBROUTINE
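
! **************************************************************************************************
!> \brief Illustrative usage sketch (documentation only, not part of the original module): shows a
!>        hypothetical call of dbt_tas_multiply computing C = A x B with filtering, flop count and
!>        an optimal split estimate for subsequent multiplications. The matrices are assumed to be
!>        created and filled elsewhere; "example_tas_gemm" is a made-up helper name.
!> \param matrix_a first factor (assumed created and filled elsewhere)
!> \param matrix_b second factor (assumed created and filled elsewhere)
!> \param matrix_c result matrix (assumed created elsewhere)
!> \param unit_nr unit number for logging output
! **************************************************************************************************
 SUBROUTINE example_tas_gemm(matrix_a, matrix_b, matrix_c, unit_nr)
 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_a, matrix_b, matrix_c
 INTEGER, INTENT(IN) :: unit_nr

 INTEGER(KIND=int_8) :: flop
 TYPE(dbt_tas_split_info) :: split_opt

 ! C = 1.0 * A x B + 0.0 * C, dropping blocks with norm below 1.0E-10
 CALL dbt_tas_multiply(.false., .false., .false., 1.0_dp, matrix_a, matrix_b, 0.0_dp, matrix_c, &
 filter_eps=1.0E-10_dp, flop=flop, split_opt=split_opt, unit_nr=unit_nr)

 IF (unit_nr > 0) WRITE (unit_nr, "(A,1X,I20)") "average flop count per rank:", flop

 ! split_opt could now be used to set up process grids for further multiplications of
 ! similarly shaped matrices; release it once it is no longer needed
 CALL dbt_tas_release_info(split_opt)
 END SUBROUTINE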
779 
780 ! **************************************************************************************************
781 !> \brief ...
782 !> \param matrix_in ...
783 !> \param matrix_out ...
784 !> \param local_copy ...
785 !> \param alpha ...
786 !> \author Patrick Seewald
787 ! **************************************************************************************************
788  SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
789  TYPE(dbm_type), INTENT(IN) :: matrix_in
790  TYPE(dbm_type), INTENT(INOUT) :: matrix_out
791  LOGICAL, INTENT(IN), OPTIONAL :: local_copy
792  REAL(dp), INTENT(IN) :: alpha
793 
794  LOGICAL :: local_copy_prv
795  TYPE(dbm_type) :: matrix_tmp
796 
797  IF (PRESENT(local_copy)) THEN
798  local_copy_prv = local_copy
799  ELSE
800  local_copy_prv = .false.
801  END IF
802 
803  IF (alpha /= 1.0_dp) THEN
804  CALL dbm_scale(matrix_out, alpha)
805  END IF
806 
807  IF (.NOT. local_copy_prv) THEN
808  CALL dbm_create_from_template(matrix_tmp, name="tmp", template=matrix_out)
809  CALL dbm_redistribute(matrix_in, matrix_tmp)
810  CALL dbm_add(matrix_out, matrix_tmp)
811  CALL dbm_release(matrix_tmp)
812  ELSE
813  CALL dbm_add(matrix_out, matrix_in)
814  END IF
815 
816  END SUBROUTINE
817 
818 ! **************************************************************************************************
819 !> \brief Make sure that the smallest matrix involved in a multiplication is not split, and bring it to
820 !> the same process grid as the other two matrices.
821 !> \param mp_comm communicator that defines Cartesian topology
822 !> \param matrix_in ...
823 !> \param matrix_out ...
824 !> \param transposed Whether matrix_out should be transposed
825 !> \param nodata Data of matrix_in should not be copied to matrix_out
826 !> \param move_data memory optimization: move data such that matrix_in is empty on return.
827 !> \author Patrick Seewald
828 ! **************************************************************************************************
829  SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
830  TYPE(mp_cart_type), INTENT(IN) :: mp_comm
831  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
832  TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
833  LOGICAL, INTENT(IN) :: transposed
834  LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
835 
836  CHARACTER(LEN=*), PARAMETER :: routinen = 'reshape_mm_small'
837 
838  INTEGER :: handle
839  INTEGER(KIND=int_8), DIMENSION(2) :: dims
840  INTEGER, DIMENSION(2) :: pdims
841  LOGICAL :: nodata_prv
842  TYPE(dbt_tas_dist_arb) :: new_col_dist, new_row_dist
843  TYPE(dbt_tas_distribution_type) :: dist
844 
845  CALL timeset(routinen, handle)
846 
847  IF (PRESENT(nodata)) THEN
848  nodata_prv = nodata
849  ELSE
850  nodata_prv = .false.
851  END IF
852 
853  pdims = mp_comm%num_pe_cart
854 
855  dims = [dbt_tas_nblkrows_total(matrix_in), dbt_tas_nblkcols_total(matrix_in)]
856 
857  IF (transposed) CALL swap(dims)
858 
859  IF (.NOT. transposed) THEN
860  new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
861  new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
862  CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.true.)
863  CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
864  matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.true.)
865  ELSE
866  new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size)
867  new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
868  CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.true.)
869  CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
870  matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.true.)
871  END IF
872  IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
873 
874  CALL timestop(handle)
875 
876  END SUBROUTINE
877 
878 ! **************************************************************************************************
879 !> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
880 !> with the same split factor.
881 !> \param matrix1_in ...
882 !> \param matrix2_in ...
883 !> \param matrix1_out ...
884 !> \param matrix2_out ...
885 !> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
886 !> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
887 !> \param trans1 transpose flag of matrix1_in for multiplication
888 !> \param trans2 transpose flag of matrix2_in for multiplication
889 !> \param optimize_dist experimental: optimize matrix splitting and distribution
890 !> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
891 !> \param opt_nsplit ...
892 !> \param split_rc_1 Whether to split rows or columns for matrix 1
893 !> \param split_rc_2 Whether to split rows or columns for matrix 2
894 !> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
895 !> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
896 !> \param move_data_1 memory optimization: move data such that matrix1_in may be empty on return.
897 !> \param move_data_2 memory optimization: move data such that matrix2_in may be empty on return.
898 !> \param comm_new returns the new communicator only if optimize_dist
899 !> \param unit_nr output unit
900 !> \author Patrick Seewald
901 ! **************************************************************************************************
902  SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
903  optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
904  move_data_1, move_data_2, comm_new, unit_nr)
905  TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix1_in, matrix2_in
906  TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix1_out, matrix2_out
907  LOGICAL, INTENT(OUT) :: new1, new2
908  LOGICAL, INTENT(INOUT) :: trans1, trans2
909  LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
910  INTEGER, INTENT(IN), OPTIONAL :: nsplit
911  LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
912  INTEGER, INTENT(INOUT) :: split_rc_1, split_rc_2
913  LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
914  LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
915  TYPE(mp_cart_type), INTENT(OUT), OPTIONAL :: comm_new
916  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
917 
918  CHARACTER(LEN=*), PARAMETER :: routinen = 'reshape_mm_compatible'
919 
920  INTEGER :: handle, nsplit_prv, ref, split_rc_ref, &
921  unit_nr_prv
922  INTEGER(KIND=int_8) :: d1, d2, nze1, nze2
923  INTEGER(KIND=int_8), DIMENSION(2) :: dims1, dims2, dims_ref
924  INTEGER, DIMENSION(2) :: pdims
925  LOGICAL :: nodata1_prv, nodata2_prv, &
926  optimize_dist_prv, trans1_newdist, &
927  trans2_newdist
928  TYPE(dbt_tas_dist_cyclic) :: col_dist_1, col_dist_2, row_dist_1, &
929  row_dist_2
930  TYPE(dbt_tas_distribution_type) :: dist_1, dist_2
931  TYPE(dbt_tas_split_info) :: split_info
932  TYPE(mp_cart_type) :: mp_comm
933 
934  CALL timeset(routinen, handle)
935  new1 = .false.; new2 = .false.
936 
937  IF (PRESENT(nodata1)) THEN
938  nodata1_prv = nodata1
939  ELSE
940  nodata1_prv = .false.
941  END IF
942 
943  IF (PRESENT(nodata2)) THEN
944  nodata2_prv = nodata2
945  ELSE
946  nodata2_prv = .false.
947  END IF
948 
949  unit_nr_prv = prep_output_unit(unit_nr)
950 
951  NULLIFY (matrix1_out, matrix2_out)
952 
953  IF (PRESENT(optimize_dist)) THEN
954  optimize_dist_prv = optimize_dist
955  ELSE
956  optimize_dist_prv = .false.
957  END IF
958 
959  dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
960  dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
961  nze1 = dbt_tas_get_nze_total(matrix1_in)
962  nze2 = dbt_tas_get_nze_total(matrix2_in)
963 
964  IF (trans1) split_rc_1 = mod(split_rc_1, 2) + 1
965 
966  IF (trans2) split_rc_2 = mod(split_rc_2, 2) + 1
967 
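 ! the matrix with more non-zero elements serves as the reference; if the layouts are not
 ! already compatible, the other (smaller) matrix is redistributed to match it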
968  IF (nze1 >= nze2) THEN
969  ref = 1
970  split_rc_ref = split_rc_1
971  dims_ref = dims1
972  ELSE
973  ref = 2
974  split_rc_ref = split_rc_2
975  dims_ref = dims2
976  END IF
977 
978  IF (PRESENT(nsplit)) THEN
979  nsplit_prv = nsplit
980  ELSE
981  nsplit_prv = 0
982  END IF
983 
984  IF (optimize_dist_prv) THEN
985 CPASSERT(PRESENT(comm_new))
986  END IF
987 
988  IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
989  CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
990  move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
991  CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
992  CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
993  move_data=move_data_2, nodata=nodata2, opt_nsplit=.false.)
994  IF (unit_nr_prv > 0) THEN
995  WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
996  trim(dbm_get_name(matrix1_in%matrix)), &
997  "and", trim(dbm_get_name(matrix2_in%matrix))
998  IF (new1) THEN
999  WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1000  trim(dbm_get_name(matrix1_in%matrix)), ": Yes"
1001  ELSE
1002  WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1003  trim(dbm_get_name(matrix1_in%matrix)), ": No"
1004  END IF
1005  IF (new2) THEN
1006  WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1007  trim(dbm_get_name(matrix2_in%matrix)), ": Yes"
1008  ELSE
1009  WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1010  trim(dbm_get_name(matrix2_in%matrix)), ": No"
1011  END IF
1012  END IF
1013  ELSE
1014 
1015  IF (optimize_dist_prv) THEN
1016  IF (unit_nr_prv > 0) THEN
1017  WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
1018  trim(dbm_get_name(matrix1_in%matrix)), &
1019  "and", trim(dbm_get_name(matrix2_in%matrix))
1020  END IF
1021 
1022  trans1_newdist = (split_rc_1 == colsplit)
1023  trans2_newdist = (split_rc_2 == colsplit)
1024 
1025  IF (trans1_newdist) THEN
1026  CALL swap(dims1)
1027  trans1 = .NOT. trans1
1028  END IF
1029 
1030  IF (trans2_newdist) THEN
1031  CALL swap(dims2)
1032  trans2 = .NOT. trans2
1033  END IF
1034 
1035  IF (nsplit_prv == 0) THEN
1036  SELECT CASE (split_rc_ref)
1037  CASE (rowsplit)
1038  d1 = dims_ref(1)
1039  d2 = dims_ref(2)
1040  CASE (colsplit)
1041  d1 = dims_ref(2)
1042  d2 = dims_ref(1)
1043  END SELECT
1044  nsplit_prv = int((d1 - 1)/d2 + 1)
1045  END IF
1046 
1047 CPASSERT(nsplit_prv > 0)
1048 
1049  CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
1050  comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
1051  CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
1052 
1053  pdims = comm_new%num_pe_cart
1054 
1055 ! use a very simple cyclic distribution that may not be load balanced if block
1056 ! sizes are not equal. However, we cannot use arbitrary distributions for large
1057 ! dimensions since this would require storing the distribution vectors as explicit
1058 ! arrays, which is not feasible for very large dimensions.
1059  row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
1060  col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))
1061 
1062  row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
1063  col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))
1064 
1065  CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
1066  CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
1067  CALL dbt_tas_release_info(split_info)
1068 
1069  ALLOCATE (matrix1_out)
1070  IF (.NOT. trans1_newdist) THEN
1071  CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1072  matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.true.)
1073 
1074  ELSE
1075  CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1076  matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.true.)
1077  END IF
1078 
1079  ALLOCATE (matrix2_out)
1080  IF (.NOT. trans2_newdist) THEN
1081  CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1082  matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.true.)
1083  ELSE
1084  CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1085  matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.true.)
1086  END IF
1087 
1088  IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
1089  IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
1090  new1 = .true.
1091  new2 = .true.
1092 
1093  ELSE
1094  SELECT CASE (ref)
1095  CASE (1)
1096  IF (unit_nr_prv > 0) THEN
1097  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1098  trim(dbm_get_name(matrix2_in%matrix))
1099  END IF
1100 
1101  CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
1102  move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
1103 
1104  ALLOCATE (matrix2_out)
1105  CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
1106  nodata=nodata2, move_data=move_data_2)
1107  new2 = .true.
1108  CASE (2)
1109  IF (unit_nr_prv > 0) THEN
1110  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1111  trim(dbm_get_name(matrix1_in%matrix))
1112  END IF
1113 
1114  CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
1115  move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
1116 
1117  ALLOCATE (matrix1_out)
1118  CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
1119  nodata=nodata1, move_data=move_data_1)
1120  new1 = .true.
1121  END SELECT
1122  END IF
1123  END IF
1124 
1125  IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .true.
1126  IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .true.
1127 
1128  CALL timestop(handle)
1129 
1130  END SUBROUTINE
1131 
1132 ! **************************************************************************************************
1133 !> \brief Change split factor without redistribution
1134 !> \param matrix_in ...
1135 !> \param matrix_out ...
1136 !> \param nsplit new split factor, set to 0 to not change split of matrix_in
1137 !> \param split_rowcol split rows or columns
1138 !> \param is_new whether matrix_out is new or a pointer to matrix_in
1139 !> \param opt_nsplit whether nsplit should be optimized for current process grid
1140 !> \param move_data memory optimization: move data such that matrix_in is empty on return.
1141 !> \param nodata Data of matrix_in should not be copied to matrix_out
1142 !> \author Patrick Seewald
1143 ! **************************************************************************************************
1144  SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
1145  TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_in
1146  TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix_out
1147  INTEGER, INTENT(IN) :: nsplit, split_rowcol
1148  LOGICAL, INTENT(OUT) :: is_new
1149  LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
1150  LOGICAL, INTENT(INOUT), OPTIONAL :: move_data
1151  LOGICAL, INTENT(IN), OPTIONAL :: nodata
1152 
1153  CHARACTER(len=default_string_length) :: name
1154  INTEGER :: handle, nsplit_new, nsplit_old, &
1155  nsplit_prv, split_rc
1156  LOGICAL :: nodata_prv
1157  TYPE(dbt_tas_distribution_type) :: dist
1158  TYPE(dbt_tas_split_info) :: split_info
1159  TYPE(mp_cart_type) :: mp_comm
1160 
1161  CLASS(dbt_tas_distribution), ALLOCATABLE :: rdist, cdist
1162  CLASS(dbt_tas_rowcol_data), ALLOCATABLE :: rbsize, cbsize
1163  CHARACTER(LEN=*), PARAMETER :: routinen = 'change_split'
1164 
1165  NULLIFY (matrix_out)
1166 
1167  is_new = .true.
1168 
1169  CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=mp_comm, &
1170  split_rowcol=split_rc, nsplit=nsplit_old)
1171 
1172  IF (nsplit == 0) THEN
1173  IF (split_rowcol == split_rc) THEN
1174  matrix_out => matrix_in
1175  is_new = .false.
1176  RETURN
1177  ELSE
1178  nsplit_prv = 1
1179  END IF
1180  ELSE
1181  nsplit_prv = nsplit
1182  END IF
1183 
1184  CALL timeset(routinen, handle)
1185 
1186  nodata_prv = .false.
1187  IF (PRESENT(nodata)) nodata_prv = nodata
1188 
1189  CALL dbt_tas_get_info(matrix_in, name=name, &
1190  row_blk_size=rbsize, col_blk_size=cbsize, &
1191  proc_row_dist=rdist, proc_col_dist=cdist)
1192 
1193  CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)
1194 
1195  CALL dbt_tas_get_split_info(split_info, nsplit=nsplit_new)
1196 
1197  IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
1198  matrix_out => matrix_in
1199  is_new = .false.
1200  CALL dbt_tas_release_info(split_info)
1201  CALL timestop(handle)
1202  RETURN
1203  END IF
1204 
1205  CALL dbt_tas_distribution_new(dist, mp_comm, rdist, cdist, &
1206  split_info=split_info)
1207 
1208  CALL dbt_tas_release_info(split_info)
1209 
1210  ALLOCATE (matrix_out)
1211  CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.true.)
1212 
1213  IF (.NOT. nodata_prv) CALL dbt_tas_copy(matrix_out, matrix_in)
1214 
1215  IF (PRESENT(move_data)) THEN
1216  IF (.NOT. nodata_prv) THEN
1217  IF (move_data) CALL dbt_tas_clear(matrix_in)
1218  move_data = .true.
1219  END IF
1220  END IF
1221 
1222  CALL timestop(handle)
1223  END SUBROUTINE
1224 
1225 ! **************************************************************************************************
1226 !> \brief Check whether the matrices have the same distribution and the same split.
1227 !> \param mat_a ...
1228 !> \param mat_b ...
1229 !> \param split_rc_a ...
1230 !> \param split_rc_b ...
1231 !> \param unit_nr ...
1232 !> \return ...
1233 !> \author Patrick Seewald
1234 ! **************************************************************************************************
1235  FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
1236  TYPE(dbt_tas_type), INTENT(IN) :: mat_a, mat_b
1237  INTEGER, INTENT(IN) :: split_rc_a, split_rc_b
1238  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1239  LOGICAL :: dist_compatible
1240 
1241  INTEGER :: numproc, same_local_rowcols, &
1242  split_check_a, split_check_b, &
1243  unit_nr_prv
1244  INTEGER(int_8), ALLOCATABLE, DIMENSION(:) :: local_rowcols_a, local_rowcols_b
1245  INTEGER, DIMENSION(2) :: pdims_a, pdims_b
1246  TYPE(dbt_tas_split_info) :: info_a, info_b
1247 
1248  unit_nr_prv = prep_output_unit(unit_nr)
1249 
1250  dist_compatible = .false.
1251 
1252  info_a = dbt_tas_info(mat_a)
1253  info_b = dbt_tas_info(mat_b)
1254  CALL dbt_tas_get_split_info(info_a, split_rowcol=split_check_a)
1255  CALL dbt_tas_get_split_info(info_b, split_rowcol=split_check_b)
1256  IF (split_check_b /= split_rc_b .OR. split_check_a /= split_rc_a .OR. split_rc_a /= split_rc_b) THEN
1257  IF (unit_nr_prv > 0) THEN
1258  WRITE (unit_nr_prv, *) "matrix layout a not compatible", split_check_a, split_rc_a
1259  WRITE (unit_nr_prv, *) "matrix layout b not compatible", split_check_b, split_rc_b
1260  END IF
1261  RETURN
1262  END IF
1263 
1264  ! check if communicators are equivalent
1265  ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
1266 ! It's sufficient to check the dimensions of the global grid; the subgrids will be determined later on (change_split)
1267  numproc = info_b%mp_comm%num_pe
1268  pdims_a = info_a%mp_comm%num_pe_cart
1269  pdims_b = info_b%mp_comm%num_pe_cart
1270  IF (.NOT. array_eq(pdims_a, pdims_b)) THEN
1271  IF (unit_nr_prv > 0) THEN
1272  WRITE (unit_nr_prv, *) "mp dims not compatible:", pdims_a, "|", pdims_b
1273  END IF
1274  RETURN
1275  END IF
1276 
1277  ! check that distribution is the same by comparing local rows / columns for each matrix
1278  SELECT CASE (split_rc_a)
1279  CASE (rowsplit)
1280  CALL dbt_tas_get_info(mat_a, local_rows=local_rowcols_a)
1281  CALL dbt_tas_get_info(mat_b, local_rows=local_rowcols_b)
1282  CASE (colsplit)
1283  CALL dbt_tas_get_info(mat_a, local_cols=local_rowcols_a)
1284  CALL dbt_tas_get_info(mat_b, local_cols=local_rowcols_b)
1285  END SELECT
1286 
1287  same_local_rowcols = merge(1, 0, array_eq(local_rowcols_a, local_rowcols_b))
1288 
1289  CALL info_a%mp_comm%sum(same_local_rowcols)
1290 
1291  IF (same_local_rowcols == numproc) THEN
1292  dist_compatible = .true.
1293  ELSE
1294  IF (unit_nr_prv > 0) THEN
1295  WRITE (unit_nr_prv, *) "local rowcols not compatible"
1296  WRITE (unit_nr_prv, *) "local rowcols A", local_rowcols_a
1297  WRITE (unit_nr_prv, *) "local rowcols B", local_rowcols_b
1298  END IF
1299  END IF
1300 
1301  END FUNCTION
1302 
1303 ! **************************************************************************************************
1304 !> \brief Reshape matrix_in such that it has the same process grid, distribution and split as template
1305 !> \param template ...
1306 !> \param matrix_in ...
1307 !> \param matrix_out ...
1308 !> \param trans ...
1309 !> \param split_rc ...
1310 !> \param nodata ...
1311 !> \param move_data ...
1312 !> \author Patrick Seewald
1313 ! **************************************************************************************************
1314  SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
1315  TYPE(dbt_tas_type), INTENT(IN) :: template
1316  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
1317  TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
1318  LOGICAL, INTENT(INOUT) :: trans
1319  INTEGER, INTENT(IN) :: split_rc
1320  LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1321 
1322  CLASS(dbt_tas_distribution), ALLOCATABLE :: row_dist, col_dist
1323 
1324  TYPE(dbt_tas_distribution_type) :: dist_new
1325  TYPE(dbt_tas_split_info) :: info_template, info_matrix
1326  INTEGER :: dim_split_template, dim_split_matrix, &
1327  handle
1328  INTEGER, DIMENSION(2) :: pdims
1329  LOGICAL :: nodata_prv, transposed
1330  TYPE(mp_cart_type) :: mp_comm
1331  CHARACTER(LEN=*), PARAMETER :: routinen = 'reshape_mm_template'
1332 
1333  CALL timeset(routinen, handle)
1334 
1335  IF (PRESENT(nodata)) THEN
1336  nodata_prv = nodata
1337  ELSE
1338  nodata_prv = .false.
1339  END IF
1340 
1341  info_template = dbt_tas_info(template)
1342  info_matrix = dbt_tas_info(matrix_in)
1343 
1344  dim_split_template = info_template%split_rowcol
1345  dim_split_matrix = split_rc
1346 
1347  transposed = dim_split_template .NE. dim_split_matrix
1348  IF (transposed) trans = .NOT. trans
1349 
1350  pdims = info_template%mp_comm%num_pe_cart
1351 
1352  SELECT CASE (dim_split_template)
1353  CASE (1)
1354  IF (.NOT. transposed) THEN
1355  ALLOCATE (row_dist, source=template%dist%row_dist)
1356  ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
1357  ELSE
1358  ALLOCATE (row_dist, source=template%dist%row_dist)
1359  ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
1360  END IF
1361  CASE (2)
1362  IF (.NOT. transposed) THEN
1363  ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
1364  ALLOCATE (col_dist, source=template%dist%col_dist)
1365  ELSE
1366  ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
1367  ALLOCATE (col_dist, source=template%dist%col_dist)
1368  END IF
1369  END SELECT
1370 
1371  CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
1372  CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
1373  IF (.NOT. transposed) THEN
1374  CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
1375  matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.true.)
1376  ELSE
1377  CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
1378  matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.true.)
1379  END IF
1380 
1381  IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
1382 
1383  CALL timestop(handle)
1384 
1385  END SUBROUTINE
1386 
1387 ! **************************************************************************************************
1388 !> \brief Estimate the sparsity pattern of C resulting from A x B = C
1389 !> by multiplying the block norms of A and B. Same dummy arguments as dbt_tas_multiply.
1390 !> \param transa ...
1391 !> \param transb ...
1392 !> \param transc ...
1393 !> \param matrix_a ...
1394 !> \param matrix_b ...
1395 !> \param matrix_c ...
1396 !> \param estimated_nze ...
1397 !> \param filter_eps ...
1398 !> \param unit_nr ...
1399 !> \param retain_sparsity ...
1400 !> \author Patrick Seewald
1401 ! **************************************************************************************************
1402  SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
1403  estimated_nze, filter_eps, unit_nr, retain_sparsity)
1404  LOGICAL, INTENT(IN) :: transa, transb, transc
1405  TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b, matrix_c
1406  INTEGER(int_8), INTENT(OUT) :: estimated_nze
1407  REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
1408  INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1409  LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
1410 
1411  CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_tas_estimate_result_nze'
1412 
1413  INTEGER :: col_size, handle, row_size
1414  INTEGER(int_8) :: col, row
1415  LOGICAL :: retain_sparsity_prv
1416  TYPE(dbt_tas_iterator) :: iter
1417  TYPE(dbt_tas_type), POINTER :: matrix_a_bnorm, matrix_b_bnorm, &
1418  matrix_c_bnorm
1419  TYPE(mp_cart_type) :: mp_comm
1420 
1421  CALL timeset(routinen, handle)
1422 
1423  IF (PRESENT(retain_sparsity)) THEN
1424  retain_sparsity_prv = retain_sparsity
1425  ELSE
1426  retain_sparsity_prv = .false.
1427  END IF
1428 
1429  IF (.NOT. retain_sparsity_prv) THEN
1430  ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
1431  CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
1432  CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
1433  CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.true.)
1434 
1435  CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
1436  matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
1437  filter_eps=filter_eps, move_data_a=.true., move_data_b=.true., &
1438  simple_split=.true., unit_nr=unit_nr)
1439  CALL dbt_tas_destroy(matrix_a_bnorm)
1440  CALL dbt_tas_destroy(matrix_b_bnorm)
1441 
1442  DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
1443  ELSE
1444  matrix_c_bnorm => matrix_c
1445  END IF
1446 
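 ! sum up the sizes of all blocks present in the estimated sparsity pattern of C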
1447  estimated_nze = 0
1448 !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
1449 !$OMP PRIVATE(iter,row,col,row_size,col_size)
1450  CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
1451  DO WHILE (dbt_tas_iterator_blocks_left(iter))
1452  CALL dbt_tas_iterator_next_block(iter, row, col)
1453  row_size = matrix_c%row_blk_size%data(row)
1454  col_size = matrix_c%col_blk_size%data(col)
1455  estimated_nze = estimated_nze + row_size*col_size
1456  END DO
1457  CALL dbt_tas_iterator_stop(iter)
1458 !$OMP END PARALLEL
1459 
1460  CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
1461  CALL mp_comm%sum(estimated_nze)
1462 
1463  IF (.NOT. retain_sparsity_prv) THEN
1464  CALL dbt_tas_destroy(matrix_c_bnorm)
1465  DEALLOCATE (matrix_c_bnorm)
1466  END IF
1467 
1468  CALL timestop(handle)
1469 
1470  END SUBROUTINE
1471 
1472 ! **************************************************************************************************
1473 !> \brief Estimate the optimal split factor for AxB=C from the occupancies (number of non-zero elements).
1474 !> The estimate is based on minimizing the communication volume, considering both the
1475 !> communication of the CARMA n-split step and the Cannon multiplication of the submatrices (a worked example follows this function).
1476 !> \param max_mm_dim ...
1477 !> \param nze_a number of non-zeroes in A
1478 !> \param nze_b number of non-zeroes in B
1479 !> \param nze_c number of non-zeroes in C
1480 !> \param numnodes number of MPI ranks
1481 !> \return estimated split factor
1482 !> \author Patrick Seewald
1483 ! **************************************************************************************************
1484  FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
1485  INTEGER, INTENT(IN) :: max_mm_dim
1486  INTEGER(KIND=int_8), INTENT(IN) :: nze_a, nze_b, nze_c
1487  INTEGER, INTENT(IN) :: numnodes
1488  INTEGER :: nsplit
1489 
1490  INTEGER(KIND=int_8) :: max_nze, min_nze
1491  REAL(dp) :: s_opt_factor
1492 
1493  s_opt_factor = 1.0_dp ! Could be further tuned.
1494 
1495  SELECT CASE (max_mm_dim)
1496  CASE (1)
1497  min_nze = max(nze_b, 1_int_8)
1498  max_nze = max(maxval([nze_a, nze_c]), 1_int_8)
1499  CASE (2)
1500  min_nze = max(nze_c, 1_int_8)
1501  max_nze = max(maxval([nze_a, nze_b]), 1_int_8)
1502  CASE (3)
1503  min_nze = max(nze_a, 1_int_8)
1504  max_nze = max(maxval([nze_b, nze_c]), 1_int_8)
1505  CASE DEFAULT
1506  cpabort("invalid max_mm_dim")
1507  END SELECT
1508 
1509  nsplit = int(min(int(numnodes, kind=int_8), nint(real(max_nze, dp)/(real(min_nze, dp)*s_opt_factor), kind=int_8)))
1510  IF (nsplit == 0) nsplit = 1
1511 
1512  END FUNCTION
1513 
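! **************************************************************************************************
!  Editorial worked example, not part of the original module: with max_mm_dim = 2 (the shared
!  dimension k is the largest), nze_a = 10**8, nze_b = 10**8, nze_c = 10**6 and numnodes = 512,
!  the function above yields min_nze = 10**6, max_nze = 10**8 and, with s_opt_factor = 1,
!  nsplit = min(512, nint(10**8/10**6)) = min(512, 100) = 100, i.e. the process grid is split
!  into 100 groups for the CARMA step.
! **************************************************************************************************
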
1514 ! **************************************************************************************************
1515 !> \brief Create a matrix with all block sizes equal to one, containing the block norms of matrix_in
1516 !> \param matrix_in ...
1517 !> \param matrix_out ...
1518 !> \param nodata ...
1519 !> \author Patrick Seewald
1520 ! **************************************************************************************************
1521  SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
1522  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
1523  TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
1524  LOGICAL, INTENT(IN), OPTIONAL :: nodata
1525 
1526  CHARACTER(len=default_string_length) :: name
1527  INTEGER(KIND=int_8) :: column, nblkcols, nblkrows, row
1528  LOGICAL :: nodata_prv
1529  REAL(dp), DIMENSION(1, 1) :: blk_put
1530  REAL(dp), DIMENSION(:, :), POINTER :: blk_get
1531  TYPE(dbt_tas_blk_size_one) :: col_blk_size, row_blk_size
1532  TYPE(dbt_tas_iterator) :: iter
1533 
1534 !REAL(dp), DIMENSION(:, :), POINTER :: dbt_put
1535 
1536  cpassert(matrix_in%valid)
1537 
1538  IF (PRESENT(nodata)) THEN
1539  nodata_prv = nodata
1540  ELSE
1541  nodata_prv = .false.
1542  END IF
1543 
1544  CALL dbt_tas_get_info(matrix_in, name=name, nblkrows_total=nblkrows, nblkcols_total=nblkcols)
1545  row_blk_size = dbt_tas_blk_size_one(nblkrows)
1546  col_blk_size = dbt_tas_blk_size_one(nblkcols)
1547 
1548  ! Note: it is unclear whether the assumption that the same distribution can be reused still holds.
1549  CALL dbt_tas_create(matrix_out, name, matrix_in%dist, row_blk_size, col_blk_size)
1550 
1551  IF (.NOT. nodata_prv) THEN
1552  CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
1553 !$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
1554 !$OMP PRIVATE(iter,row,column,blk_get,blk_put)
1555  CALL dbt_tas_iterator_start(iter, matrix_in)
1556  DO WHILE (dbt_tas_iterator_blocks_left(iter))
1557  CALL dbt_tas_iterator_next_block(iter, row, column, blk_get)
1558  blk_put(1, 1) = norm2(blk_get)
1559  CALL dbt_tas_put_block(matrix_out, row, column, blk_put)
1560  END DO
1561  CALL dbt_tas_iterator_stop(iter)
1562 !$OMP END PARALLEL
1563  END IF
1564 
1565  END SUBROUTINE
1566 
1567 ! **************************************************************************************************
1568 !> \brief Convert a DBM matrix to a new process grid
1569 !> \param mp_comm_cart new process grid
1570 !> \param matrix_in ...
1571 !> \param matrix_out ...
1572 !> \param move_data memory optimization: move data such that matrix_in is empty on return.
1573 !> \param nodata If .TRUE., the data of matrix_in is not copied to matrix_out
1574 !> \param optimize_pgrid Whether to move the matrix to the new process grid; if .FALSE., matrix_out keeps the grid of matrix_in
1575 !> \author Patrick Seewald
1576 ! **************************************************************************************************
1577  SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
1578  TYPE(mp_cart_type), INTENT(IN) :: mp_comm_cart
1579  TYPE(dbm_type), INTENT(INOUT) :: matrix_in
1580  TYPE(dbm_type), INTENT(OUT) :: matrix_out
1581  LOGICAL, INTENT(IN), OPTIONAL :: move_data, nodata, optimize_pgrid
1582 
1583  CHARACTER(LEN=*), PARAMETER :: routinen = 'convert_to_new_pgrid'
1584 
1585  CHARACTER(len=default_string_length) :: name
1586  INTEGER :: handle, nbcols, nbrows
1587  INTEGER, CONTIGUOUS, DIMENSION(:), POINTER :: col_dist, rbsize, rcsize, row_dist
1588  INTEGER, DIMENSION(2) :: pdims
1589  LOGICAL :: nodata_prv, optimize_pgrid_prv
1590  TYPE(dbm_distribution_obj) :: dist, dist_old
1591 
1592  NULLIFY (row_dist, col_dist, rbsize, rcsize)
1593 
1594  CALL timeset(routinen, handle)
1595 
1596  IF (PRESENT(optimize_pgrid)) THEN
1597  optimize_pgrid_prv = optimize_pgrid
1598  ELSE
1599  optimize_pgrid_prv = .true.
1600  END IF
1601 
1602  IF (PRESENT(nodata)) THEN
1603  nodata_prv = nodata
1604  ELSE
1605  nodata_prv = .false.
1606  END IF
1607 
1608  name = dbm_get_name(matrix_in)
1609 
1610  IF (.NOT. optimize_pgrid_prv) THEN
1611  CALL dbm_create_from_template(matrix_out, name=name, template=matrix_in)
1612  IF (.NOT. nodata_prv) CALL dbm_copy(matrix_out, matrix_in)
1613  CALL timestop(handle)
1614  RETURN
1615  END IF
1616 
1617  rbsize => dbm_get_row_block_sizes(matrix_in)
1618  rcsize => dbm_get_col_block_sizes(matrix_in)
1619  nbrows = SIZE(rbsize)
1620  nbcols = SIZE(rcsize)
1621  dist_old = dbm_get_distribution(matrix_in)
1622  pdims = mp_comm_cart%num_pe_cart
1623 
1624  ALLOCATE (row_dist(nbrows), col_dist(nbcols))
1625  CALL dbt_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
1626  CALL dbt_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)
1627 
1628  CALL dbm_distribution_new(dist, mp_comm_cart, row_dist, col_dist)
1629  DEALLOCATE (row_dist, col_dist)
1630 
1631  CALL dbm_create(matrix_out, name, dist, rbsize, rcsize)
1632  CALL dbm_distribution_release(dist)
1633 
1634  IF (.NOT. nodata_prv) THEN
1635  CALL dbm_redistribute(matrix_in, matrix_out)
1636  IF (PRESENT(move_data)) THEN
1637  IF (move_data) CALL dbm_clear(matrix_in)
1638  END IF
1639  END IF
1640 
1641  CALL timestop(handle)
1642  END SUBROUTINE
1643 
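! **************************************************************************************************
!  Editorial usage sketch, not part of the original module: how the helper above is typically
!  driven. A Cartesian grid consistent with the split heuristic is created via dbt_tas_mp_comm,
!  the DBM matrix is redistributed onto it, and the source data is moved to save memory.
!  The subroutine name and the split factor of 4 are illustrative assumptions only.
! **************************************************************************************************
  SUBROUTINE example_redistribute_to_new_grid(mp_comm, matrix_in, matrix_out)
     TYPE(mp_cart_type), INTENT(IN) :: mp_comm
     TYPE(dbm_type), INTENT(INOUT)  :: matrix_in
     TYPE(dbm_type), INTENT(OUT)    :: matrix_out

     TYPE(mp_cart_type) :: mp_comm_new

     ! grid consistent with the default split heuristic (row split into 4 groups)
     mp_comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, 4)

     ! redistribute onto the new grid; move_data empties matrix_in afterwards
     CALL convert_to_new_pgrid(mp_comm_new, matrix_in, matrix_out, move_data=.TRUE.)

     CALL mp_comm_new%free()
  END SUBROUTINE
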
1644 ! **************************************************************************************************
1645 !> \brief ...
1646 !> \param matrix ...
1647 !> \author Patrick Seewald
1648 ! **************************************************************************************************
1649  SUBROUTINE dbt_tas_batched_mm_init(matrix)
1650  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1651 
1652  CALL dbt_tas_set_batched_state(matrix, state=1)
1653  ALLOCATE (matrix%mm_storage)
1654  matrix%mm_storage%batched_out = .false.
1655  END SUBROUTINE
1656 
1657 ! **************************************************************************************************
1658 !> \brief ...
1659 !> \param matrix ...
1660 !> \author Patrick Seewald
1661 ! **************************************************************************************************
1662  SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
1663  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1664 
1665  INTEGER :: handle
1666 
1667  IF (matrix%do_batched == 0) RETURN
1668 
1669  CALL matrix%dist%info%mp_comm%sync()
1670  CALL timeset("dbt_tas_total", handle)
1671 
1672  IF (matrix%mm_storage%batched_out) THEN
1673  CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
1674  END IF
1675 
1676  CALL dbt_tas_batched_mm_complete(matrix)
1677 
1678  matrix%mm_storage%batched_out = .false.
1679 
1680  DEALLOCATE (matrix%mm_storage)
1681  CALL dbt_tas_set_batched_state(matrix, state=0)
1682 
1683  CALL matrix%dist%info%mp_comm%sync()
1684  CALL timestop(handle)
1685 
1686  END SUBROUTINE
1687 
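! **************************************************************************************************
!  Editorial usage sketch, not part of the original module: the intended lifecycle of the batched
!  multiplication interface. Between init and finalize, repeated calls to dbt_tas_multiply on the
!  same matrices can reuse batched/replicated intermediates; finalize merges and scales the
!  accumulated result. Matrix roles, loop bounds and the filter threshold are illustrative
!  assumptions only.
! **************************************************************************************************
  SUBROUTINE example_batched_multiply(matrix_a, matrix_b, matrix_c, nbatch)
     TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_a, matrix_b, matrix_c
     INTEGER, INTENT(IN)               :: nbatch

     INTEGER :: ibatch

     CALL dbt_tas_batched_mm_init(matrix_a)
     CALL dbt_tas_batched_mm_init(matrix_b)
     CALL dbt_tas_batched_mm_init(matrix_c)

     DO ibatch = 1, nbatch
        ! in a real application, A and/or B would change between iterations
        CALL dbt_tas_multiply(.FALSE., .FALSE., .FALSE., 1.0_dp, matrix_a, matrix_b, &
                              1.0_dp, matrix_c, filter_eps=1.0E-12_dp)
     END DO

     CALL dbt_tas_batched_mm_finalize(matrix_a)
     CALL dbt_tas_batched_mm_finalize(matrix_b)
     CALL dbt_tas_batched_mm_finalize(matrix_c)
  END SUBROUTINE
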
1688 ! **************************************************************************************************
1689 !> \brief set state flags during batched multiplication
1690 !> \param matrix ...
1691 !> \param state 0 no batched MM
1692 !> 1 batched MM but mm_storage not yet initialized
1693 !> 2 batched MM and mm_storage requires update
1694 !> 3 batched MM and mm_storage initialized
1695 !> \param opt_grid whether process grid was already optimized and should not be changed
1696 !> \author Patrick Seewald
1697 ! **************************************************************************************************
1698  SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
1699  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1700  INTEGER, INTENT(IN), OPTIONAL :: state
1701  LOGICAL, INTENT(IN), OPTIONAL :: opt_grid
1702 
1703  IF (PRESENT(opt_grid)) THEN
1704  matrix%has_opt_pgrid = opt_grid
1705  matrix%dist%info%strict_split(1) = .true.
1706  END IF
1707 
1708  IF (PRESENT(state)) THEN
1709  matrix%do_batched = state
1710  SELECT CASE (state)
1711  CASE (0, 1)
1712  ! reset to default
1713  IF (matrix%has_opt_pgrid) THEN
1714  matrix%dist%info%strict_split(1) = .true.
1715  ELSE
1716  matrix%dist%info%strict_split(1) = matrix%dist%info%strict_split(2)
1717  END IF
1718  CASE (2, 3)
1719  matrix%dist%info%strict_split(1) = .true.
1720  CASE DEFAULT
1721  cpabort("should not happen")
1722  END SELECT
1723  END IF
1724  END SUBROUTINE
1725 
1726 ! **************************************************************************************************
1727 !> \brief ...
1728 !> \param matrix ...
1729 !> \param warn ...
1730 !> \author Patrick Seewald
1731 ! **************************************************************************************************
1732  SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
1733  TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1734  LOGICAL, INTENT(IN), OPTIONAL :: warn
1735 
1736  IF (matrix%do_batched == 0) RETURN
1737  associate(storage => matrix%mm_storage)
1738  IF (PRESENT(warn)) THEN
1739  IF (warn .AND. matrix%do_batched == 3) THEN
1740  CALL cp_warn(__location__, &
1741  "Optimizations for batched multiplication are disabled because of conflicting data access")
1742  END IF
1743  END IF
1744  IF (storage%batched_out .AND. matrix%do_batched == 3) THEN
1745 
1746  CALL dbt_tas_merge(storage%store_batched%matrix, &
1747  storage%store_batched_repl, move_data=.true.)
1748 
1749  CALL dbt_tas_reshape(storage%store_batched, matrix, summation=.true., &
1750  transposed=storage%batched_trans, move_data=.true.)
1751  CALL dbt_tas_destroy(storage%store_batched)
1752  DEALLOCATE (storage%store_batched)
1753  END IF
1754 
1755  IF (ASSOCIATED(storage%store_batched_repl)) THEN
1756  CALL dbt_tas_destroy(storage%store_batched_repl)
1757  DEALLOCATE (storage%store_batched_repl)
1758  END IF
1759  END associate
1760 
1761  CALL dbt_tas_set_batched_state(matrix, state=2)
1762 
1763  END SUBROUTINE
1764 
1765 END MODULE
Definition: dbm_api.F:8
subroutine, public dbm_multiply(transa, transb, alpha, matrix_a, matrix_b, beta, matrix_c, retain_sparsity, filter_eps, flop)
Computes matrix product: matrix_c = alpha * matrix_a * matrix_b + beta * matrix_c.
Definition: dbm_api.F:736
subroutine, public dbm_redistribute(matrix, redist)
Copies the content of matrix into redist. The two matrices may have different distributions.
Definition: dbm_api.F:412
subroutine, public dbm_zero(matrix)
Sets all blocks in the given matrix to zero.
Definition: dbm_api.F:662
subroutine, public dbm_clear(matrix)
Remove all blocks from given matrix, but does not release the underlying memory.
Definition: dbm_api.F:529
subroutine, public dbm_create_from_template(matrix, name, template)
Creates a new matrix from given template, reusing dist and row/col_block_sizes.
Definition: dbm_api.F:265
pure integer function, public dbm_get_nze(matrix)
Returns the number of local Non-Zero Elements of the given matrix.
Definition: dbm_api.F:1057
subroutine, public dbm_scale(matrix, alpha)
Multiplies all entries in the given matrix by the given factor alpha.
Definition: dbm_api.F:631
subroutine, public dbm_distribution_release(dist)
Decreases the reference counter of the given distribution.
Definition: dbm_api.F:1401
type(dbm_distribution_obj) function, public dbm_get_distribution(matrix)
Returns the distribution of the given matrix.
Definition: dbm_api.F:1266
subroutine, public dbm_create(matrix, name, dist, row_block_sizes, col_block_sizes)
Creates a new matrix.
Definition: dbm_api.F:293
integer function, dimension(:), pointer, contiguous, public dbm_get_row_block_sizes(matrix)
Returns the row block sizes of the given matrix.
Definition: dbm_api.F:1103
character(len=default_string_length) function, public dbm_get_name(matrix)
Returns the name of the matrix of the given matrix.
Definition: dbm_api.F:1023
subroutine, public dbm_add(matrix_a, matrix_b)
Adds matrix_b to matrix_a.
Definition: dbm_api.F:692
subroutine, public dbm_copy(matrix_a, matrix_b)
Copies content of matrix_b into matrix_a. Matrices must have the same row/col block sizes and distrib...
Definition: dbm_api.F:380
subroutine, public dbm_release(matrix)
Releases a matrix and all its resources.
Definition: dbm_api.F:354
integer function, dimension(:), pointer, contiguous, public dbm_get_col_block_sizes(matrix)
Returns the column block sizes of the given matrix.
Definition: dbm_api.F:1130
subroutine, public dbm_distribution_new(dist, mp_comm, row_dist_block, col_dist_block)
Creates a new two dimensional distribution.
Definition: dbm_api.F:1294
Tall-and-skinny matrices: base routines similar to DBM API, mostly wrappers around existing DBM routi...
Definition: dbt_tas_base.F:13
integer(kind=int_8) function, public dbt_tas_get_nze_total(matrix)
Get total number of non-zero elements.
Definition: dbt_tas_base.F:958
subroutine, public dbt_tas_iterator_start(iter, matrix_in)
As dbm_iterator_start.
Definition: dbt_tas_base.F:670
logical function, public dbt_tas_iterator_blocks_left(iter)
As dbm_iterator_blocks_left.
Definition: dbt_tas_base.F:698
integer(kind=int_8) function, public dbt_tas_nblkrows_total(matrix)
...
Definition: dbt_tas_base.F:835
subroutine, public dbt_tas_get_info(matrix, nblkrows_total, nblkcols_total, local_rows, local_cols, proc_row_dist, proc_col_dist, row_blk_size, col_blk_size, distribution, name)
...
Definition: dbt_tas_base.F:999
subroutine, public dbt_tas_copy(matrix_b, matrix_a, summation)
Copy matrix_a to matrix_b.
Definition: dbt_tas_base.F:250
subroutine, public dbt_tas_iterator_stop(iter)
As dbm_iterator_stop.
Definition: dbt_tas_base.F:710
subroutine, public dbt_tas_distribution_new(dist, mp_comm, row_dist, col_dist, split_info, nosplit)
create new distribution. Exactly like dbm_distribution_new but with custom types for row_dist and col...
Definition: dbt_tas_base.F:346
type(dbt_tas_split_info) function, public dbt_tas_info(matrix)
get info on mpi grid splitting
Definition: dbt_tas_base.F:822
integer(kind=int_8) function, public dbt_tas_nblkcols_total(matrix)
...
Definition: dbt_tas_base.F:861
subroutine, public dbt_tas_filter(matrix, eps)
As dbm_filter.
subroutine, public dbt_tas_clear(matrix)
Clear matrix (erase all data)
Definition: dbt_tas_base.F:974
subroutine, public dbt_tas_destroy(matrix)
...
Definition: dbt_tas_base.F:233
subroutine, public dbt_tas_put_block(matrix, row, col, block, summation)
As dbm_put_block.
Global data (distribution and block sizes) for tall-and-skinny matrices. For very sparse matrices with...
subroutine, public dbt_tas_default_distvec(nblk, nproc, blk_size, dist)
get a load-balanced and randomized distribution along one tensor dimension
type(dbt_tas_dist_arb) function, public dbt_tas_dist_arb_default(nprowcol, nmrowcol, dbt_sizes)
Distribution that is more or less cyclic (round robin) and load balanced with different weights for e...
tall-and-skinny matrices: Input / Output
Definition: dbt_tas_io.F:12
integer function, public prep_output_unit(unit_nr)
...
Definition: dbt_tas_io.F:264
subroutine, public dbt_tas_write_matrix_info(matrix, unit_nr, full_info)
Write basic info of tall-and-skinny matrix: block dimensions, full dimensions, process grid dimensio...
Definition: dbt_tas_io.F:59
subroutine, public dbt_tas_write_dist(matrix, unit_nr, full_info)
Write info on tall-and-skinny matrix distribution & load balance.
Definition: dbt_tas_io.F:122
subroutine, public dbt_tas_write_split_info(info, unit_nr, name)
Print info on how matrix is split.
Definition: dbt_tas_io.F:214
Matrix multiplication for tall-and-skinny matrices. This uses the k-split (non-recursive) CARMA algor...
Definition: dbt_tas_mm.F:18
subroutine, public dbt_tas_set_batched_state(matrix, state, opt_grid)
set state flags during batched multiplication
Definition: dbt_tas_mm.F:1699
subroutine, public dbt_tas_batched_mm_init(matrix)
...
Definition: dbt_tas_mm.F:1650
subroutine, public dbt_tas_batched_mm_finalize(matrix)
...
Definition: dbt_tas_mm.F:1663
recursive subroutine, public dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, optimize_dist, split_opt, filter_eps, flop, move_data_a, move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical to arguments...
Definition: dbt_tas_mm.F:105
subroutine, public dbt_tas_batched_mm_complete(matrix, warn)
...
Definition: dbt_tas_mm.F:1733
communication routines to reshape / replicate / merge tall-and-skinny matrices.
subroutine, public dbt_tas_merge(matrix_out, matrix_in, summation, move_data)
Merge submatrices of matrix_in to matrix_out by sum.
subroutine, public dbt_tas_replicate(matrix_in, info, matrix_out, nodata, move_data)
Replicate matrix_in such that each submatrix of matrix_out is an exact copy of matrix_in.
recursive subroutine, public dbt_tas_reshape(matrix_in, matrix_out, summation, transposed, move_data)
copy data (involves reshape)
methods to split tall-and-skinny matrices along longest dimension. Basically, we are splitting proces...
Definition: dbt_tas_split.F:13
subroutine, public dbt_tas_release_info(split_info)
...
integer, parameter, public rowsplit
Definition: dbt_tas_split.F:50
subroutine, public dbt_tas_get_split_info(info, mp_comm, nsplit, igroup, mp_comm_group, split_rowcol, pgrid_offset)
Get info on split.
integer, parameter, public colsplit
Definition: dbt_tas_split.F:50
subroutine, public dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit, own_comm, opt_nsplit)
Split Cartesian process grid using a default split heuristic.
type(mp_cart_type) function, public dbt_tas_mp_comm(mp_comm, split_rowcol, nsplit)
Create default cartesian process grid that is consistent with default split heuristic of dbt_tas_crea...
real(dp), parameter, public default_nsplit_accept_ratio
Definition: dbt_tas_split.F:52
subroutine, public dbt_tas_info_hold(split_info)
...
logical function, public accept_pgrid_dims(dims, relative)
Whether to accept proposed process grid dimensions (based on ratio of dimensions)
DBT tall-and-skinny base types. Mostly wrappers around existing DBM routines.
Definition: dbt_tas_types.F:13
often used utilities for tall-and-skinny matrices
Definition: dbt_tas_util.F:12
Defines the basic variable types.
Definition: kinds.F:23
integer, parameter, public int_8
Definition: kinds.F:54
integer, parameter, public dp
Definition: kinds.F:34
integer, parameter, public default_string_length
Definition: kinds.F:57
Interface to the message passing library MPI.