dbt_tas_mm.F (git:374b731)
1!--------------------------------------------------------------------------------------------------!
2! CP2K: A general program to perform molecular dynamics simulations !
3! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4! !
5! SPDX-License-Identifier: GPL-2.0-or-later !
6!--------------------------------------------------------------------------------------------------!
7
8! **************************************************************************************************
9!> \brief Matrix multiplication for tall-and-skinny matrices.
10!> This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
11!> as long as the two smaller dimensions have the same size.
12!> Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
13!> submatrices uses the DBM Cannon algorithm. Because the sparsity pattern of the result
14!> matrix is not known in advance, parameters (group sizes and process grid dimensions)
15!> cannot be derived from the matrix dimensions and need to be set manually.
16!> \author Patrick Seewald
17! **************************************************************************************************
18 MODULE dbt_tas_mm
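! **************************************************************************************************
!> \par Example
!>      A minimal usage sketch (illustrative only; the TAS matrices are assumed to have been
!>      created with dbt_tas_create and filled beforehand, flop is an INTEGER(int_8) variable,
!>      and the filter threshold is an arbitrary example value):
!> \code
!>      CALL dbt_tas_multiply(.FALSE., .FALSE., .FALSE., 1.0_dp, matrix_a, matrix_b, &
!>                            0.0_dp, matrix_c, filter_eps=1.0E-12_dp, flop=flop)
!> \endcode
!>      This computes matrix_c = matrix_a x matrix_b, applies the filter threshold filter_eps to
!>      the result and returns the average flop count per MPI rank in flop.
! **************************************************************************************************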
19 USE dbm_api, ONLY: &
24 USE dbt_tas_base, ONLY: &
44 USE dbt_tas_split, ONLY: &
52 USE dbt_tas_util, ONLY: array_eq,&
53 swap
54 USE kinds, ONLY: default_string_length,&
55 dp,&
56 int_8
58#include "../../base/base_uses.f90"
59
60 IMPLICIT NONE
61 PRIVATE
62
63 CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
64
65 PUBLIC :: &
71
72CONTAINS
73
74! **************************************************************************************************
75!> \brief tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical
76!> to arguments of dbm_multiply (see dbm_mm, dbm_multiply_generic).
77!> \param transa ...
78!> \param transb ...
79!> \param transc ...
80!> \param alpha ...
81!> \param matrix_a ...
82!> \param matrix_b ...
83!> \param beta ...
84!> \param matrix_c ...
85!> \param optimize_dist Whether distribution should be optimized internally. In the current
86!> implementation this guarantees optimal parameters only for dense matrices.
87!> \param split_opt optionally return split info containing optimal grid and split parameters.
88!> This can be used to choose optimal process grids for subsequent matrix
89!> multiplications with matrices of similar shape and sparsity.
90!> \param filter_eps ...
91!> \param flop ...
92!> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
93!> (for internal use only)
94!> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
95!> (for internal use only)
96!> \param retain_sparsity ...
97!> \param simple_split ...
98!> \param unit_nr unit number for logging output
99!> \param log_verbose only for testing: verbose output
100!> \author Patrick Seewald
101! **************************************************************************************************
102 RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
103 optimize_dist, split_opt, filter_eps, flop, move_data_a, &
104 move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
105
106 LOGICAL, INTENT(IN) :: transa, transb, transc
107 REAL(dp), INTENT(IN) :: alpha
108 TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b
109 REAL(dp), INTENT(IN) :: beta
110 TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_c
111 LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
112 TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL :: split_opt
113 REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
114 INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
115 LOGICAL, INTENT(IN), OPTIONAL :: move_data_a, move_data_b, &
116 retain_sparsity, simple_split
117 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
118 LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
119
120 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_tas_multiply'
121
122 INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
123 nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
124 unit_nr_prv
125 INTEGER(KIND=int_8) :: nze_a, nze_b, nze_c, nze_c_sum
126 INTEGER(KIND=int_8), DIMENSION(2) :: dims_a, dims_b, dims_c
127 INTEGER(KIND=int_8), DIMENSION(3) :: dims
128 INTEGER, DIMENSION(2) :: pdims, pdims_sub
129 LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
130 simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
131 REAL(kind=dp) :: filter_eps_prv
132 TYPE(dbm_type) :: matrix_a_mm, matrix_b_mm, matrix_c_mm
133 TYPE(dbt_tas_split_info) :: info, info_a, info_b, info_c
134 TYPE(dbt_tas_type), POINTER :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
135 matrix_b_rs, matrix_c_rep, matrix_c_rs
136 TYPE(mp_cart_type) :: comm_tmp, mp_comm, mp_comm_group, &
137 mp_comm_mm, mp_comm_opt
138
139 CALL timeset(routinen, handle)
140 CALL matrix_a%dist%info%mp_comm%sync()
141 CALL timeset("dbt_tas_total", handle2)
142
143 NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
144
145 unit_nr_prv = prep_output_unit(unit_nr)
146
147 IF (PRESENT(simple_split)) THEN
148 simple_split_prv = simple_split
149 ELSE
150 simple_split_prv = .false.
151
152 info_a = dbt_tas_info(matrix_a); info_b = dbt_tas_info(matrix_b); info_c = dbt_tas_info(matrix_c)
153 IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .true.
154 END IF
155
156 nodata_3 = .true.
157 IF (PRESENT(retain_sparsity)) THEN
158 IF (retain_sparsity) nodata_3 = .false.
159 END IF
160
161 ! get prestored info for multiplication strategy in case of batched mm
162 batched_repl = 0
163 do_batched = .false.
164 IF (matrix_a%do_batched > 0) THEN
165 do_batched = .true.
166 IF (matrix_a%do_batched == 3) THEN
167 CPASSERT(batched_repl == 0)
168 batched_repl = 1
169 CALL dbt_tas_get_split_info( &
170 dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
171 nsplit=nsplit_batched)
172 CPASSERT(nsplit_batched > 0)
173 max_mm_dim_batched = 3
174 END IF
175 END IF
176
177 IF (matrix_b%do_batched > 0) THEN
178 do_batched = .true.
179 IF (matrix_b%do_batched == 3) THEN
180 CPASSERT(batched_repl == 0)
181 batched_repl = 2
182 CALL dbt_tas_get_split_info( &
183 dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
184 nsplit=nsplit_batched)
185 CPASSERT(nsplit_batched > 0)
186 max_mm_dim_batched = 1
187 END IF
188 END IF
189
190 IF (matrix_c%do_batched > 0) THEN
191 do_batched = .true.
192 IF (matrix_c%do_batched == 3) THEN
193 CPASSERT(batched_repl == 0)
194 batched_repl = 3
195 CALL dbt_tas_get_split_info( &
196 dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
197 nsplit=nsplit_batched)
198 CPASSERT(nsplit_batched > 0)
199 max_mm_dim_batched = 2
200 END IF
201 END IF
202
203 move_a = .false.
204 move_b = .false.
205
206 IF (PRESENT(move_data_a)) move_a = move_data_a
207 IF (PRESENT(move_data_b)) move_b = move_data_b
208
209 transa_prv = transa; transb_prv = transb; transc_prv = transc
210
211 dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
212 dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
213 dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]
214
215 IF (unit_nr_prv > 0) THEN
216 WRITE (unit_nr_prv, "(A)") repeat("-", 80)
217 WRITE (unit_nr_prv, "(A)") &
218 "DBT TAS MATRIX MULTIPLICATION: "// &
219 trim(dbm_get_name(matrix_a%matrix))//" x "// &
220 trim(dbm_get_name(matrix_b%matrix))//" = "// &
221 trim(dbm_get_name(matrix_c%matrix))
222 WRITE (unit_nr_prv, "(A)") repeat("-", 80)
223 END IF
224 IF (do_batched) THEN
225 IF (unit_nr_prv > 0) THEN
226 WRITE (unit_nr_prv, "(T2,A)") &
227 "BATCHED PROCESSING OF MATMUL"
228 IF (batched_repl > 0) THEN
229 WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
230 END IF
231 END IF
232 END IF
233
234 IF (transa_prv) THEN
235 CALL swap(dims_a)
236 END IF
237
238 IF (transb_prv) THEN
239 CALL swap(dims_b)
240 END IF
241
242 dims_c = [dims_a(1), dims_b(2)]
243
244 IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
245 CPABORT("inconsistent matrix dimensions")
246 END IF
247
248 dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
249
250 IF (unit_nr_prv > 0) THEN
251 WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
252 END IF
253
254 CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
255 numproc = mp_comm%num_pe
256
257 ! derive optimal matrix layout and split factor from occupancies
258 nze_a = dbt_tas_get_nze_total(matrix_a)
259 nze_b = dbt_tas_get_nze_total(matrix_b)
260
261 IF (.NOT. simple_split_prv) THEN
262 CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
263 estimated_nze=nze_c, filter_eps=filter_eps, &
264 retain_sparsity=retain_sparsity)
265
266 max_mm_dim = maxloc(dims, 1)
267 nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
268 nsplit_opt = nsplit
269
270 IF (unit_nr_prv > 0) THEN
271 WRITE (unit_nr_prv, "(T2,A)") &
272 "MM PARAMETERS"
273 WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
274 (nze_c + numproc - 1)/numproc
275
276 WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
277 END IF
278
279 ELSEIF (batched_repl > 0) THEN
280 nsplit = nsplit_batched
281 nsplit_opt = nsplit
282 max_mm_dim = max_mm_dim_batched
283 IF (unit_nr_prv > 0) THEN
284 WRITE (unit_nr_prv, "(T2,A)") &
285 "MM PARAMETERS"
286 WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
287 END IF
288
289 ELSE
290 nsplit = 0
291 max_mm_dim = maxloc(dims, 1)
292 END IF
293
294 ! reshape matrices to the optimal layout and split factor
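! max_mm_dim indexes the largest of the three multiplication dimensions dims = [m, k, n]:
!   case 1 (m largest): split rows of A and C, replicate the small matrix B
!   case 2 (k largest): split columns of A and rows of B, replicate C and sum the partial results
!   case 3 (n largest): split columns of B and C, replicate the small matrix A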
295 split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
296 SELECT CASE (max_mm_dim)
297 CASE (1)
298
299 split_a = rowsplit; split_c = rowsplit
300 CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
301 new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
302 nsplit=nsplit, &
303 opt_nsplit=batched_repl == 0, &
304 split_rc_1=split_a, split_rc_2=split_c, &
305 nodata2=nodata_3, comm_new=comm_tmp, &
306 move_data_1=move_a, unit_nr=unit_nr_prv)
307
308 info = dbt_tas_info(matrix_a_rs)
309 CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
310
311 new_b = .false.
312 IF (matrix_b%do_batched <= 2) THEN
313 ALLOCATE (matrix_b_rs)
314 CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
315 transb_prv = .false.
316 new_b = .true.
317 END IF
318
319 tr_case = transa_prv
320
321 IF (unit_nr_prv > 0) THEN
322 IF (.NOT. tr_case) THEN
323 WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
324 ELSE
325 WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
326 END IF
327 END IF
328
329 CASE (2)
330
331 split_a = colsplit; split_b = rowsplit
332 CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
333 optimize_dist=optimize_dist, &
334 nsplit=nsplit, &
335 opt_nsplit=batched_repl == 0, &
336 split_rc_1=split_a, split_rc_2=split_b, &
337 comm_new=comm_tmp, &
338 move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
339
340 info = dbt_tas_info(matrix_a_rs)
341 CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
342
343 IF (matrix_c%do_batched == 1) THEN
344 matrix_c%mm_storage%batched_beta = beta
345 ELSEIF (matrix_c%do_batched > 1) THEN
346 matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
347 END IF
348
349 IF (matrix_c%do_batched <= 2) THEN
350 ALLOCATE (matrix_c_rs)
351 CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
352 transc_prv = .false.
353
354 ! for retain_sparsity, keep only the sparsity structure and no values
355 IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)
356
357 IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
358 ELSEIF (matrix_c%do_batched == 3) THEN
359 matrix_c_rs => matrix_c%mm_storage%store_batched
360 END IF
361
362 new_c = matrix_c%do_batched == 0
363 tr_case = transa_prv
364
365 IF (unit_nr_prv > 0) THEN
366 IF (.NOT. tr_case) THEN
367 WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
368 ELSE
369 WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
370 END IF
371 END IF
372
373 CASE (3)
374
375 split_b = colsplit; split_c = colsplit
376 CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
377 transc_prv, optimize_dist=optimize_dist, &
378 nsplit=nsplit, &
379 opt_nsplit=batched_repl == 0, &
380 split_rc_1=split_b, split_rc_2=split_c, &
381 nodata2=nodata_3, comm_new=comm_tmp, &
382 move_data_1=move_b, unit_nr=unit_nr_prv)
383 info = dbt_tas_info(matrix_b_rs)
384 CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
385
386 new_a = .false.
387 IF (matrix_a%do_batched <= 2) THEN
388 ALLOCATE (matrix_a_rs)
389 CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
390 transa_prv = .false.
391 new_a = .true.
392 END IF
393
394 tr_case = transb_prv
395
396 IF (unit_nr_prv > 0) THEN
397 IF (.NOT. tr_case) THEN
398 WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
399 ELSE
400 WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
401 END IF
402 END IF
403
404 END SELECT
405
406 CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
407
408 numproc = mp_comm%num_pe
409 pdims_sub = mp_comm_group%num_pe_cart
410
411 opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.true.)
412
413 IF (PRESENT(filter_eps)) THEN
414 filter_eps_prv = filter_eps
415 ELSE
416 filter_eps_prv = 0.0_dp
417 END IF
418
419 IF (unit_nr_prv /= 0) THEN
420 IF (unit_nr_prv > 0) THEN
421 WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
422 END IF
423 CALL dbt_tas_write_split_info(info, unit_nr_prv)
424 IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
425 IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
426 IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
427 IF (unit_nr_prv > 0) THEN
428 IF (opt_pgrid) THEN
429 WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
430 ELSE
431 WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
432 END IF
433 END IF
434 END IF
435
436 pdims = 0
437 CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
438
439 ! Convert DBM submatrices to optimized process grids and multiply
440 SELECT CASE (max_mm_dim)
441 CASE (1)
442 IF (matrix_b%do_batched <= 2) THEN
443 ALLOCATE (matrix_b_rep)
444 CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.true.)
445 IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
446 matrix_b%mm_storage%store_batched_repl => matrix_b_rep
447 CALL dbt_tas_set_batched_state(matrix_b, state=3)
448 END IF
449 ELSEIF (matrix_b%do_batched == 3) THEN
450 matrix_b_rep => matrix_b%mm_storage%store_batched_repl
451 END IF
452
453 IF (new_b) THEN
454 CALL dbt_tas_destroy(matrix_b_rs)
455 DEALLOCATE (matrix_b_rs)
456 END IF
457 IF (unit_nr_prv /= 0) THEN
458 CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
459 CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
460 END IF
461
462 CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
463
464 ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
465 info_a = dbt_tas_info(matrix_a_rs)
466 CALL dbt_tas_info_hold(info_a)
467
468 IF (new_a) THEN
469 CALL dbt_tas_destroy(matrix_a_rs)
470 DEALLOCATE (matrix_a_rs)
471 END IF
472 CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
473 move_data=matrix_b%do_batched == 0)
474
475 info_b = dbt_tas_info(matrix_b_rep)
476 CALL dbt_tas_info_hold(info_b)
477
478 IF (matrix_b%do_batched == 0) THEN
479 CALL dbt_tas_destroy(matrix_b_rep)
480 DEALLOCATE (matrix_b_rep)
481 END IF
482
483 CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
484
485 info_c = dbt_tas_info(matrix_c_rs)
486 CALL dbt_tas_info_hold(info_c)
487
488 CALL matrix_a%dist%info%mp_comm%sync()
489 CALL timeset("dbt_tas_dbm", handle4)
490 IF (.NOT. tr_case) THEN
491 CALL timeset("dbt_tas_mm_1N", handle3)
492
493 CALL dbm_multiply(transa=.false., transb=.false., alpha=alpha, &
494 matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
495 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
496 CALL timestop(handle3)
497 ELSE
498 CALL timeset("dbt_tas_mm_1T", handle3)
499 CALL dbm_multiply(transa=.true., transb=.false., alpha=alpha, &
500 matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
501 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
502
503 CALL timestop(handle3)
504 END IF
505 CALL matrix_a%dist%info%mp_comm%sync()
506 CALL timestop(handle4)
507
508 CALL dbm_release(matrix_a_mm)
509 CALL dbm_release(matrix_b_mm)
510
511 nze_c = dbm_get_nze(matrix_c_mm)
512
513 IF (.NOT. new_c) THEN
514 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
515 ELSE
516 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
517 END IF
518
519 CALL dbm_release(matrix_c_mm)
520
521 IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
522
523 IF (unit_nr_prv /= 0) THEN
524 CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
525 END IF
526
527 CASE (2)
528 IF (matrix_c%do_batched <= 1) THEN
529 ALLOCATE (matrix_c_rep)
530 CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
531 IF (matrix_c%do_batched == 1) THEN
532 matrix_c%mm_storage%store_batched_repl => matrix_c_rep
533 CALL dbt_tas_set_batched_state(matrix_c, state=3)
534 END IF
535 ELSEIF (matrix_c%do_batched == 2) THEN
536 ALLOCATE (matrix_c_rep)
537 CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
538 ! for retain_sparsity, keep only the sparsity structure and no values
539 IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
540 matrix_c%mm_storage%store_batched_repl => matrix_c_rep
541 CALL dbt_tas_set_batched_state(matrix_c, state=3)
542 ELSEIF (matrix_c%do_batched == 3) THEN
543 matrix_c_rep => matrix_c%mm_storage%store_batched_repl
544 END IF
545
546 IF (unit_nr_prv /= 0) THEN
547 CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
548 CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
549 END IF
550
551 CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
552
553 ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
554 info_a = dbt_tas_info(matrix_a_rs)
555 CALL dbt_tas_info_hold(info_a)
556
557 IF (new_a) THEN
558 CALL dbt_tas_destroy(matrix_a_rs)
559 DEALLOCATE (matrix_a_rs)
560 END IF
561
562 CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
563
564 info_b = dbt_tas_info(matrix_b_rs)
565 CALL dbt_tas_info_hold(info_b)
566
567 IF (new_b) THEN
568 CALL dbt_tas_destroy(matrix_b_rs)
569 DEALLOCATE (matrix_b_rs)
570 END IF
571
572 CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
573
574 info_c = dbt_tas_info(matrix_c_rep)
575 CALL dbt_tas_info_hold(info_c)
576
577 CALL matrix_a%dist%info%mp_comm%sync()
578 CALL timeset("dbt_tas_dbm", handle4)
579 CALL timeset("dbt_tas_mm_2", handle3)
580 CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
581 matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
582 filter_eps=filter_eps_prv/real(nsplit, kind=dp), retain_sparsity=retain_sparsity, flop=flop)
583 CALL matrix_a%dist%info%mp_comm%sync()
584 CALL timestop(handle3)
585 CALL timestop(handle4)
586
587 CALL dbm_release(matrix_a_mm)
588 CALL dbm_release(matrix_b_mm)
589
590 nze_c = dbm_get_nze(matrix_c_mm)
591
592 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
593 nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)
594
595 CALL dbm_release(matrix_c_mm)
596
597 IF (unit_nr_prv /= 0) THEN
598 CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
599 END IF
600
601 IF (matrix_c%do_batched == 0) THEN
602 CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.true.)
603 ELSE
604 matrix_c%mm_storage%batched_out = .true. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
605 END IF
606
607 IF (matrix_c%do_batched == 0) THEN
608 CALL dbt_tas_destroy(matrix_c_rep)
609 DEALLOCATE (matrix_c_rep)
610 END IF
611
612 IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
613
614 ! enforce an upper limit on the memory consumption of the replicated matrix and complete
615 ! the batched mm if the limit is exceeded
616 IF (nze_c_sum > default_nsplit_accept_ratio*max(nze_a, nze_b)) THEN
617 CALL dbt_tas_batched_mm_complete(matrix_c)
618 END IF
619
620 CASE (3)
621 IF (matrix_a%do_batched <= 2) THEN
622 ALLOCATE (matrix_a_rep)
623 CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.true.)
624 IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
625 matrix_a%mm_storage%store_batched_repl => matrix_a_rep
626 CALL dbt_tas_set_batched_state(matrix_a, state=3)
627 END IF
628 ELSEIF (matrix_a%do_batched == 3) THEN
629 matrix_a_rep => matrix_a%mm_storage%store_batched_repl
630 END IF
631
632 IF (new_a) THEN
633 CALL dbt_tas_destroy(matrix_a_rs)
634 DEALLOCATE (matrix_a_rs)
635 END IF
636 IF (unit_nr_prv /= 0) THEN
637 CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
638 CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
639 END IF
640
641 CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
642 move_data=matrix_a%do_batched == 0)
643
644 ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
645 info_a = dbt_tas_info(matrix_a_rep)
646 CALL dbt_tas_info_hold(info_a)
647
648 IF (matrix_a%do_batched == 0) THEN
649 CALL dbt_tas_destroy(matrix_a_rep)
650 DEALLOCATE (matrix_a_rep)
651 END IF
652
653 CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
654
655 info_b = dbt_tas_info(matrix_b_rs)
656 CALL dbt_tas_info_hold(info_b)
657
658 IF (new_b) THEN
659 CALL dbt_tas_destroy(matrix_b_rs)
660 DEALLOCATE (matrix_b_rs)
661 END IF
662 CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
663
664 info_c = dbt_tas_info(matrix_c_rs)
665 CALL dbt_tas_info_hold(info_c)
666
667 CALL matrix_a%dist%info%mp_comm%sync()
668 CALL timeset("dbt_tas_dbm", handle4)
669 IF (.NOT. tr_case) THEN
670 CALL timeset("dbt_tas_mm_3N", handle3)
671 CALL dbm_multiply(transa=.false., transb=.false., alpha=alpha, &
672 matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
673 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
674 CALL timestop(handle3)
675 ELSE
676 CALL timeset("dbt_tas_mm_3T", handle3)
677 CALL dbm_multiply(transa=.false., transb=.true., alpha=alpha, &
678 matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
679 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
680 CALL timestop(handle3)
681 END IF
682 CALL matrix_a%dist%info%mp_comm%sync()
683 CALL timestop(handle4)
684
685 CALL dbm_release(matrix_a_mm)
686 CALL dbm_release(matrix_b_mm)
687
688 nze_c = dbm_get_nze(matrix_c_mm)
689
690 IF (.NOT. new_c) THEN
691 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
692 ELSE
693 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
694 END IF
695
696 CALL dbm_release(matrix_c_mm)
697
698 IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
699
700 IF (unit_nr_prv /= 0) THEN
701 CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
702 END IF
703 END SELECT
704
705 CALL mp_comm_mm%free()
706
707 CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)
708
709 IF (PRESENT(split_opt)) THEN
710 SELECT CASE (max_mm_dim)
711 CASE (1, 3)
712 CALL mp_comm%sum(nze_c)
713 CASE (2)
714 CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
715 CALL mp_comm%sum(nze_c)
716 CALL mp_comm%max(nze_c)
717
718 END SELECT
719 nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
720 ! ideally we should rederive the split factor from the actual sparsity of C, but
721 ! due to the beta parameter, we cannot get the sparsity of A x B from DBM unless new_c is set
722 mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
723 CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.true.)
724 IF (unit_nr_prv > 0) THEN
725 WRITE (unit_nr_prv, "(T2,A)") &
726 "MM PARAMETERS"
727 WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
728 (nze_c + numproc - 1)/numproc
729
730 WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
731 END IF
732
733 END IF
734
735 IF (new_c) THEN
736 CALL dbm_scale(matrix_c%matrix, beta)
737 CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.true., &
738 transposed=(transc_prv .NEQV. transc), &
739 move_data=.true.)
740 CALL dbt_tas_destroy(matrix_c_rs)
741 DEALLOCATE (matrix_c_rs)
742 IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
743 ELSEIF (matrix_c%do_batched > 0) THEN
744 IF (matrix_c%mm_storage%batched_out) THEN
745 matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
746 END IF
747 END IF
748
749 IF (PRESENT(move_data_a)) THEN
750 IF (move_data_a) CALL dbt_tas_clear(matrix_a)
751 END IF
752 IF (PRESENT(move_data_b)) THEN
753 IF (move_data_b) CALL dbt_tas_clear(matrix_b)
754 END IF
755
756 IF (PRESENT(flop)) THEN
757 CALL mp_comm%sum(flop)
758 flop = (flop + numproc - 1)/numproc
759 END IF
760
761 IF (PRESENT(optimize_dist)) THEN
762 IF (optimize_dist) CALL comm_tmp%free()
763 END IF
764 IF (unit_nr_prv > 0) THEN
765 WRITE (unit_nr_prv, '(A)') repeat("-", 80)
766 WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
767 WRITE (unit_nr_prv, '(A)') repeat("-", 80)
768 END IF
769
770 CALL dbt_tas_release_info(info_a)
771 CALL dbt_tas_release_info(info_b)
772 CALL dbt_tas_release_info(info_c)
773
774 CALL matrix_a%dist%info%mp_comm%sync()
775 CALL timestop(handle2)
776 CALL timestop(handle)
777
778 END SUBROUTINE
779
780! **************************************************************************************************
781!> \brief ...
782!> \param matrix_in ...
783!> \param matrix_out ...
784!> \param local_copy ...
785!> \param alpha ...
786!> \author Patrick Seewald
787! **************************************************************************************************
788 SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
789 TYPE(dbm_type), INTENT(IN) :: matrix_in
790 TYPE(dbm_type), INTENT(INOUT) :: matrix_out
791 LOGICAL, INTENT(IN), OPTIONAL :: local_copy
792 REAL(dp), INTENT(IN) :: alpha
793
794 LOGICAL :: local_copy_prv
795 TYPE(dbm_type) :: matrix_tmp
796
797 IF (PRESENT(local_copy)) THEN
798 local_copy_prv = local_copy
799 ELSE
800 local_copy_prv = .false.
801 END IF
802
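! matrix_out <- alpha*matrix_out + matrix_in; matrix_in is first redistributed to the layout of
! matrix_out unless local_copy signals that both already live on the same distribution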
803 IF (alpha /= 1.0_dp) THEN
804 CALL dbm_scale(matrix_out, alpha)
805 END IF
806
807 IF (.NOT. local_copy_prv) THEN
808 CALL dbm_create_from_template(matrix_tmp, name="tmp", template=matrix_out)
809 CALL dbm_redistribute(matrix_in, matrix_tmp)
810 CALL dbm_add(matrix_out, matrix_tmp)
811 CALL dbm_release(matrix_tmp)
812 ELSE
813 CALL dbm_add(matrix_out, matrix_in)
814 END IF
815
816 END SUBROUTINE
817
818! **************************************************************************************************
819!> \brief Make sure that the smallest matrix involved in a multiplication is not split, and bring it to
820!> the same process grid as the other two matrices.
821!> \param mp_comm communicator that defines Cartesian topology
822!> \param matrix_in ...
823!> \param matrix_out ...
824!> \param transposed Whether matrix_out should be transposed
825!> \param nodata Data of matrix_in should not be copied to matrix_out
826!> \param move_data memory optimization: move data such that matrix_in is empty on return.
827!> \author Patrick Seewald
828! **************************************************************************************************
829 SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
830 TYPE(mp_cart_type), INTENT(IN) :: mp_comm
831 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
832 TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
833 LOGICAL, INTENT(IN) :: transposed
834 LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
835
836 CHARACTER(LEN=*), PARAMETER :: routinen = 'reshape_mm_small'
837
838 INTEGER :: handle
839 INTEGER(KIND=int_8), DIMENSION(2) :: dims
840 INTEGER, DIMENSION(2) :: pdims
841 LOGICAL :: nodata_prv
842 TYPE(dbt_tas_dist_arb) :: new_col_dist, new_row_dist
843 TYPE(dbt_tas_distribution_type) :: dist
844
845 CALL timeset(routinen, handle)
846
847 IF (PRESENT(nodata)) THEN
848 nodata_prv = nodata
849 ELSE
850 nodata_prv = .false.
851 END IF
852
853 pdims = mp_comm%num_pe_cart
854
855 dims = [dbt_tas_nblkrows_total(matrix_in), dbt_tas_nblkcols_total(matrix_in)]
856
857 IF (transposed) CALL swap(dims)
858
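! build a distribution without splitting (nosplit=.true.) over the full 2D process grid; for a
! transposed output the row/column block sizes are exchanged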
859 IF (.NOT. transposed) THEN
860 new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
861 new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
862 CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.true.)
863 CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
864 matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.true.)
865 ELSE
866 new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size)
867 new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
868 CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.true.)
869 CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
870 matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.true.)
871 END IF
872 IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
873
874 CALL timestop(handle)
875
876 END SUBROUTINE
877
878! **************************************************************************************************
879!> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
880!> with the same split factor.
881!> \param matrix1_in ...
882!> \param matrix2_in ...
883!> \param matrix1_out ...
884!> \param matrix2_out ...
885!> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
886!> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
887!> \param trans1 transpose flag of matrix1_in for multiplication
888!> \param trans2 transpose flag of matrix2_in for multiplication
889!> \param optimize_dist experimental: optimize matrix splitting and distribution
890!> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
891!> \param opt_nsplit ...
892!> \param split_rc_1 Whether to split rows or columns for matrix 1
893!> \param split_rc_2 Whether to split rows or columns for matrix 2
894!> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
895!> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
896!> \param move_data_1 memory optimization: move data such that matrix1_in may be empty on return.
897!> \param move_data_2 memory optimization: move data such that matrix2_in may be empty on return.
898!> \param comm_new returns the new communicator only if optimize_dist
899!> \param unit_nr output unit
900!> \author Patrick Seewald
901! **************************************************************************************************
902 SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
903 optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
904 move_data_1, move_data_2, comm_new, unit_nr)
905 TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix1_in, matrix2_in
906 TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix1_out, matrix2_out
907 LOGICAL, INTENT(OUT) :: new1, new2
908 LOGICAL, INTENT(INOUT) :: trans1, trans2
909 LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
910 INTEGER, INTENT(IN), OPTIONAL :: nsplit
911 LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
912 INTEGER, INTENT(INOUT) :: split_rc_1, split_rc_2
913 LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
914 LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
915 TYPE(mp_cart_type), INTENT(OUT), OPTIONAL :: comm_new
916 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
917
918 CHARACTER(LEN=*), PARAMETER :: routinen = 'reshape_mm_compatible'
919
920 INTEGER :: handle, nsplit_prv, ref, split_rc_ref, &
921 unit_nr_prv
922 INTEGER(KIND=int_8) :: d1, d2, nze1, nze2
923 INTEGER(KIND=int_8), DIMENSION(2) :: dims1, dims2, dims_ref
924 INTEGER, DIMENSION(2) :: pdims
925 LOGICAL :: nodata1_prv, nodata2_prv, &
926 optimize_dist_prv, trans1_newdist, &
927 trans2_newdist
928 TYPE(dbt_tas_dist_cyclic) :: col_dist_1, col_dist_2, row_dist_1, &
929 row_dist_2
930 TYPE(dbt_tas_distribution_type) :: dist_1, dist_2
931 TYPE(dbt_tas_split_info) :: split_info
932 TYPE(mp_cart_type) :: mp_comm
933
934 CALL timeset(routinen, handle)
935 new1 = .false.; new2 = .false.
936
937 IF (PRESENT(nodata1)) THEN
938 nodata1_prv = nodata1
939 ELSE
940 nodata1_prv = .false.
941 END IF
942
943 IF (PRESENT(nodata2)) THEN
944 nodata2_prv = nodata2
945 ELSE
946 nodata2_prv = .false.
947 END IF
948
949 unit_nr_prv = prep_output_unit(unit_nr)
950
951 NULLIFY (matrix1_out, matrix2_out)
952
953 IF (PRESENT(optimize_dist)) THEN
954 optimize_dist_prv = optimize_dist
955 ELSE
956 optimize_dist_prv = .false.
957 END IF
958
959 dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
960 dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
961 nze1 = dbt_tas_get_nze_total(matrix1_in)
962 nze2 = dbt_tas_get_nze_total(matrix2_in)
963
964 IF (trans1) split_rc_1 = mod(split_rc_1, 2) + 1
965
966 IF (trans2) split_rc_2 = mod(split_rc_2, 2) + 1
967
968 IF (nze1 >= nze2) THEN
969 ref = 1
970 split_rc_ref = split_rc_1
971 dims_ref = dims1
972 ELSE
973 ref = 2
974 split_rc_ref = split_rc_2
975 dims_ref = dims2
976 END IF
977
978 IF (PRESENT(nsplit)) THEN
979 nsplit_prv = nsplit
980 ELSE
981 nsplit_prv = 0
982 END IF
983
984 IF (optimize_dist_prv) THEN
985 CPASSERT(PRESENT(comm_new))
986 END IF
987
988 IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
989 CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
990 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
991 CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
992 CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
993 move_data=move_data_2, nodata=nodata2, opt_nsplit=.false.)
994 IF (unit_nr_prv > 0) THEN
995 WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
996 trim(dbm_get_name(matrix1_in%matrix)), &
997 "and", trim(dbm_get_name(matrix2_in%matrix))
998 IF (new1) THEN
999 WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1000 trim(dbm_get_name(matrix1_in%matrix)), ": Yes"
1001 ELSE
1002 WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1003 trim(dbm_get_name(matrix1_in%matrix)), ": No"
1004 END IF
1005 IF (new2) THEN
1006 WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1007 trim(dbm_get_name(matrix2_in%matrix)), ": Yes"
1008 ELSE
1009 WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1010 trim(dbm_get_name(matrix2_in%matrix)), ": No"
1011 END IF
1012 END IF
1013 ELSE
1014
1015 IF (optimize_dist_prv) THEN
1016 IF (unit_nr_prv > 0) THEN
1017 WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
1018 trim(dbm_get_name(matrix1_in%matrix)), &
1019 "and", trim(dbm_get_name(matrix2_in%matrix))
1020 END IF
1021
1022 trans1_newdist = (split_rc_1 == colsplit)
1023 trans2_newdist = (split_rc_2 == colsplit)
1024
1025 IF (trans1_newdist) THEN
1026 CALL swap(dims1)
1027 trans1 = .NOT. trans1
1028 END IF
1029
1030 IF (trans2_newdist) THEN
1031 CALL swap(dims2)
1032 trans2 = .NOT. trans2
1033 END IF
1034
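! no split factor given: default to the ratio of the split dimension to the other
! dimension, rounded up, i.e. nsplit = ceiling(d1/d2)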
1035 IF (nsplit_prv == 0) THEN
1036 SELECT CASE (split_rc_ref)
1037 CASE (rowsplit)
1038 d1 = dims_ref(1)
1039 d2 = dims_ref(2)
1040 CASE (colsplit)
1041 d1 = dims_ref(2)
1042 d2 = dims_ref(1)
1043 END SELECT
1044 nsplit_prv = int((d1 - 1)/d2 + 1)
1045 END IF
1046
1047 CPASSERT(nsplit_prv > 0)
1048
1049 CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
1050 comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
1051 CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
1052
1053 pdims = comm_new%num_pe_cart
1054
1055 ! use a very simple cyclic distribution that may not be load balanced if block
1056 ! sizes are not equal. However, arbitrary distributions cannot be used for large
1057 ! dimensions since they would require storing full distribution vectors as arrays,
1058 ! which is not feasible for large dimensions.
1059 row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
1060 col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))
1061
1062 row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
1063 col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))
1064
1065 CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
1066 CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
1067 CALL dbt_tas_release_info(split_info)
1068
1069 ALLOCATE (matrix1_out)
1070 IF (.NOT. trans1_newdist) THEN
1071 CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1072 matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.true.)
1073
1074 ELSE
1075 CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1076 matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.true.)
1077 END IF
1078
1079 ALLOCATE (matrix2_out)
1080 IF (.NOT. trans2_newdist) THEN
1081 CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1082 matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.true.)
1083 ELSE
1084 CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1085 matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.true.)
1086 END IF
1087
1088 IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
1089 IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
1090 new1 = .true.
1091 new2 = .true.
1092
1093 ELSE
1094 SELECT CASE (ref)
1095 CASE (1)
1096 IF (unit_nr_prv > 0) THEN
1097 WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1098 trim(dbm_get_name(matrix2_in%matrix))
1099 END IF
1100
1101 CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
1102 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
1103
1104 ALLOCATE (matrix2_out)
1105 CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
1106 nodata=nodata2, move_data=move_data_2)
1107 new2 = .true.
1108 CASE (2)
1109 IF (unit_nr_prv > 0) THEN
1110 WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1111 trim(dbm_get_name(matrix1_in%matrix))
1112 END IF
1113
1114 CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
1115 move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
1116
1117 ALLOCATE (matrix1_out)
1118 CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
1119 nodata=nodata1, move_data=move_data_1)
1120 new1 = .true.
1121 END SELECT
1122 END IF
1123 END IF
1124
1125 IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .true.
1126 IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .true.
1127
1128 CALL timestop(handle)
1129
1130 END SUBROUTINE
1131
1132! **************************************************************************************************
1133!> \brief Change split factor without redistribution
1134!> \param matrix_in ...
1135!> \param matrix_out ...
1136!> \param nsplit new split factor, set to 0 to not change split of matrix_in
1137!> \param split_rowcol split rows or columns
1138!> \param is_new whether matrix_out is new or a pointer to matrix_in
1139!> \param opt_nsplit whether nsplit should be optimized for current process grid
1140!> \param move_data memory optimization: move data such that matrix_in is empty on return.
1141!> \param nodata Data of matrix_in should not be copied to matrix_out
1142!> \author Patrick Seewald
1143! **************************************************************************************************
1144 SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
1145 TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_in
1146 TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix_out
1147 INTEGER, INTENT(IN) :: nsplit, split_rowcol
1148 LOGICAL, INTENT(OUT) :: is_new
1149 LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
1150 LOGICAL, INTENT(INOUT), OPTIONAL :: move_data
1151 LOGICAL, INTENT(IN), OPTIONAL :: nodata
1152
1153 CHARACTER(len=default_string_length) :: name
1154 INTEGER :: handle, nsplit_new, nsplit_old, &
1155 nsplit_prv, split_rc
1156 LOGICAL :: nodata_prv
1157 TYPE(dbt_tas_distribution_type) :: dist
1158 TYPE(dbt_tas_split_info) :: split_info
1159 TYPE(mp_cart_type) :: mp_comm
1160
1161 CLASS(dbt_tas_distribution), ALLOCATABLE :: rdist, cdist
1162 CLASS(dbt_tas_rowcol_data), ALLOCATABLE :: rbsize, cbsize
1163 CHARACTER(LEN=*), PARAMETER :: routinen = 'change_split'
1164
1165 NULLIFY (matrix_out)
1166
1167 is_new = .true.
1168
1169 CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=mp_comm, &
1170 split_rowcol=split_rc, nsplit=nsplit_old)
1171
1172 IF (nsplit == 0) THEN
1173 IF (split_rowcol == split_rc) THEN
1174 matrix_out => matrix_in
1175 is_new = .false.
1176 RETURN
1177 ELSE
1178 nsplit_prv = 1
1179 END IF
1180 ELSE
1181 nsplit_prv = nsplit
1182 END IF
1183
1184 CALL timeset(routinen, handle)
1185
1186 nodata_prv = .false.
1187 IF (PRESENT(nodata)) nodata_prv = nodata
1188
1189 CALL dbt_tas_get_info(matrix_in, name=name, &
1190 row_blk_size=rbsize, col_blk_size=cbsize, &
1191 proc_row_dist=rdist, proc_col_dist=cdist)
1192
1193 CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)
1194
1195 CALL dbt_tas_get_split_info(split_info, nsplit=nsplit_new)
1196
1197 IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
1198 matrix_out => matrix_in
1199 is_new = .false.
1200 CALL dbt_tas_release_info(split_info)
1201 CALL timestop(handle)
1202 RETURN
1203 END IF
1204
1205 CALL dbt_tas_distribution_new(dist, mp_comm, rdist, cdist, &
1206 split_info=split_info)
1207
1208 CALL dbt_tas_release_info(split_info)
1209
1210 ALLOCATE (matrix_out)
1211 CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.true.)
1212
1213 IF (.NOT. nodata_prv) CALL dbt_tas_copy(matrix_out, matrix_in)
1214
1215 IF (PRESENT(move_data)) THEN
1216 IF (.NOT. nodata_prv) THEN
1217 IF (move_data) CALL dbt_tas_clear(matrix_in)
1218 move_data = .true.
1219 END IF
1220 END IF
1221
1222 CALL timestop(handle)
1223 END SUBROUTINE
1224
1225! **************************************************************************************************
1226!> \brief Check whether matrices have same distribution and same split.
1227!> \param mat_a ...
1228!> \param mat_b ...
1229!> \param split_rc_a ...
1230!> \param split_rc_b ...
1231!> \param unit_nr ...
1232!> \return ...
1233!> \author Patrick Seewald
1234! **************************************************************************************************
1235 FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
1236 TYPE(dbt_tas_type), INTENT(IN) :: mat_a, mat_b
1237 INTEGER, INTENT(IN) :: split_rc_a, split_rc_b
1238 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1239 LOGICAL :: dist_compatible
1240
1241 INTEGER :: numproc, same_local_rowcols, &
1242 split_check_a, split_check_b, &
1243 unit_nr_prv
1244 INTEGER(int_8), ALLOCATABLE, DIMENSION(:) :: local_rowcols_a, local_rowcols_b
1245 INTEGER, DIMENSION(2) :: pdims_a, pdims_b
1246 TYPE(dbt_tas_split_info) :: info_a, info_b
1247
1248 unit_nr_prv = prep_output_unit(unit_nr)
1249
1250 dist_compatible = .false.
1251
1252 info_a = dbt_tas_info(mat_a)
1253 info_b = dbt_tas_info(mat_b)
1254 CALL dbt_tas_get_split_info(info_a, split_rowcol=split_check_a)
1255 CALL dbt_tas_get_split_info(info_b, split_rowcol=split_check_b)
1256 IF (split_check_b /= split_rc_b .OR. split_check_a /= split_rc_a .OR. split_rc_a /= split_rc_b) THEN
1257 IF (unit_nr_prv > 0) THEN
1258 WRITE (unit_nr_prv, *) "matrix layout a not compatible", split_check_a, split_rc_a
1259 WRITE (unit_nr_prv, *) "matrix layout b not compatible", split_check_b, split_rc_b
1260 END IF
1261 RETURN
1262 END IF
1263
1264 ! check if communicators are equivalent
1265 ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
1266 ! It is sufficient to check the dimensions of the global grid; subgrids will be determined later on (change_split)
1267 numproc = info_b%mp_comm%num_pe
1268 pdims_a = info_a%mp_comm%num_pe_cart
1269 pdims_b = info_b%mp_comm%num_pe_cart
1270 IF (.NOT. array_eq(pdims_a, pdims_b)) THEN
1271 IF (unit_nr_prv > 0) THEN
1272 WRITE (unit_nr_prv, *) "mp dims not compatible:", pdims_a, "|", pdims_b
1273 END IF
1274 RETURN
1275 END IF
1276
1277 ! check that distribution is the same by comparing local rows / columns for each matrix
1278 SELECT CASE (split_rc_a)
1279 CASE (rowsplit)
1280 CALL dbt_tas_get_info(mat_a, local_rows=local_rowcols_a)
1281 CALL dbt_tas_get_info(mat_b, local_rows=local_rowcols_b)
1282 CASE (colsplit)
1283 CALL dbt_tas_get_info(mat_a, local_cols=local_rowcols_a)
1284 CALL dbt_tas_get_info(mat_b, local_cols=local_rowcols_b)
1285 END SELECT
1286
1287 same_local_rowcols = merge(1, 0, array_eq(local_rowcols_a, local_rowcols_b))
1288
1289 CALL info_a%mp_comm%sum(same_local_rowcols)
1290
1291 IF (same_local_rowcols == numproc) THEN
1292 dist_compatible = .true.
1293 ELSE
1294 IF (unit_nr_prv > 0) THEN
1295 WRITE (unit_nr_prv, *) "local rowcols not compatible"
1296 WRITE (unit_nr_prv, *) "local rowcols A", local_rowcols_a
1297 WRITE (unit_nr_prv, *) "local rowcols B", local_rowcols_b
1298 END IF
1299 END IF
1300
1301 END FUNCTION
1302
1303! **************************************************************************************************
1304!> \brief Reshape matrix_in such that it has the same process grid, distribution and split as the template
1305!> \param template ...
1306!> \param matrix_in ...
1307!> \param matrix_out ...
1308!> \param trans ...
1309!> \param split_rc ...
1310!> \param nodata ...
1311!> \param move_data ...
1312!> \author Patrick Seewald
1313! **************************************************************************************************
1314 SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
1315 TYPE(dbt_tas_type), INTENT(IN) :: template
1316 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
1317 TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
1318 LOGICAL, INTENT(INOUT) :: trans
1319 INTEGER, INTENT(IN) :: split_rc
1320 LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1321
1322 CLASS(dbt_tas_distribution), ALLOCATABLE :: row_dist, col_dist
1323
1324 TYPE(dbt_tas_distribution_type) :: dist_new
1325 TYPE(dbt_tas_split_info) :: info_template, info_matrix
1326 INTEGER :: dim_split_template, dim_split_matrix, &
1327 handle
1328 INTEGER, DIMENSION(2) :: pdims
1329 LOGICAL :: nodata_prv, transposed
1330 TYPE(mp_cart_type) :: mp_comm
1331 CHARACTER(LEN=*), PARAMETER :: routinen = 'reshape_mm_template'
1332
1333 CALL timeset(routinen, handle)
1334
1335 IF (PRESENT(nodata)) THEN
1336 nodata_prv = nodata
1337 ELSE
1338 nodata_prv = .false.
1339 END IF
1340
1341 info_template = dbt_tas_info(template)
1342 info_matrix = dbt_tas_info(matrix_in)
1343
1344 dim_split_template = info_template%split_rowcol
1345 dim_split_matrix = split_rc
1346
1347 transposed = dim_split_template .NE. dim_split_matrix
1348 IF (transposed) trans = .NOT. trans
1349
1350 pdims = info_template%mp_comm%num_pe_cart
1351
1352 SELECT CASE (dim_split_template)
1353 CASE (1)
1354 IF (.NOT. transposed) THEN
1355 ALLOCATE (row_dist, source=template%dist%row_dist)
1356 ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
1357 ELSE
1358 ALLOCATE (row_dist, source=template%dist%row_dist)
1359 ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
1360 END IF
1361 CASE (2)
1362 IF (.NOT. transposed) THEN
1363 ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
1364 ALLOCATE (col_dist, source=template%dist%col_dist)
1365 ELSE
1366 ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
1367 ALLOCATE (col_dist, source=template%dist%col_dist)
1368 END IF
1369 END SELECT
1370
1371 CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
1372 CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
1373 IF (.NOT. transposed) THEN
1374 CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
1375 matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.true.)
1376 ELSE
1377 CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
1378 matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.true.)
1379 END IF
1380
1381 IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
1382
1383 CALL timestop(handle)
1384
1385 END SUBROUTINE
1386
1387! **************************************************************************************************
1388!> \brief Estimate sparsity pattern of C resulting from A x B = C
1389!> by multiplying the block norms of A and B. Same dummy arguments as dbt_tas_multiply.
1390!> \param transa ...
1391!> \param transb ...
1392!> \param transc ...
1393!> \param matrix_a ...
1394!> \param matrix_b ...
1395!> \param matrix_c ...
1396!> \param estimated_nze ...
1397!> \param filter_eps ...
1398!> \param unit_nr ...
1399!> \param retain_sparsity ...
1400!> \author Patrick Seewald
1401! **************************************************************************************************
1402 SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
1403 estimated_nze, filter_eps, unit_nr, retain_sparsity)
1404 LOGICAL, INTENT(IN) :: transa, transb, transc
1405 TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b, matrix_c
1406 INTEGER(int_8), INTENT(OUT) :: estimated_nze
1407 REAL(kind=dp), INTENT(IN), OPTIONAL :: filter_eps
1408 INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1409 LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
1410
1411 CHARACTER(LEN=*), PARAMETER :: routinen = 'dbt_tas_estimate_result_nze'
1412
1413 INTEGER :: col_size, handle, row_size
1414 INTEGER(int_8) :: col, row
1415 LOGICAL :: retain_sparsity_prv
1416 TYPE(dbt_tas_iterator) :: iter
1417 TYPE(dbt_tas_type), POINTER :: matrix_a_bnorm, matrix_b_bnorm, &
1418 matrix_c_bnorm
1419 TYPE(mp_cart_type) :: mp_comm
1420
1421 CALL timeset(routinen, handle)
1422
1423 IF (PRESENT(retain_sparsity)) THEN
1424 retain_sparsity_prv = retain_sparsity
1425 ELSE
1426 retain_sparsity_prv = .false.
1427 END IF
1428
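! build matrices whose 1x1 blocks hold the block norms of A and B, multiply them with the same
! (simply split) algorithm and take the block structure of the product as an estimate for C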
1429 IF (.NOT. retain_sparsity_prv) THEN
1430 ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
1431 CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
1432 CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
1433 CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.true.)
1434
1435 CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
1436 matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
1437 filter_eps=filter_eps, move_data_a=.true., move_data_b=.true., &
1438 simple_split=.true., unit_nr=unit_nr)
1439 CALL dbt_tas_destroy(matrix_a_bnorm)
1440 CALL dbt_tas_destroy(matrix_b_bnorm)
1441
1442 DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
1443 ELSE
1444 matrix_c_bnorm => matrix_c
1445 END IF
1446
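! sum up the full sizes (rows x columns) of all blocks present in the estimated matrix for C;
! the reduction over all ranks below yields the estimated total number of non-zero elements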
1447 estimated_nze = 0
1448!$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
1449!$OMP PRIVATE(iter,row,col,row_size,col_size)
1450 CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
1451 DO WHILE (dbt_tas_iterator_blocks_left(iter))
1452 CALL dbt_tas_iterator_next_block(iter, row, col)
1453 row_size = matrix_c%row_blk_size%data(row)
1454 col_size = matrix_c%col_blk_size%data(col)
1455 estimated_nze = estimated_nze + row_size*col_size
1456 END DO
1457 CALL dbt_tas_iterator_stop(iter)
1458!$OMP END PARALLEL
1459
1460 CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
1461 CALL mp_comm%sum(estimated_nze)
1462
1463 IF (.NOT. retain_sparsity_prv) THEN
1464 CALL dbt_tas_destroy(matrix_c_bnorm)
1465 DEALLOCATE (matrix_c_bnorm)
1466 END IF
1467
1468 CALL timestop(handle)
1469
1470 END SUBROUTINE
1471
1472! **************************************************************************************************
1473!> \brief Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
1474!> This estimate is based on minimizing the communication volume, considering both the
1475!> communication of the CARMA n-split step and of the Cannon multiplication of the submatrices.
1476!> \param max_mm_dim ...
1477!> \param nze_a number of non-zeroes in A
1478!> \param nze_b number of non-zeroes in B
1479!> \param nze_c number of non-zeroes in C
1480!> \param numnodes number of MPI ranks
1481!> \return estimated split factor
1482!> \author Patrick Seewald
1483! **************************************************************************************************
1484 FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
1485 INTEGER, INTENT(IN) :: max_mm_dim
1486 INTEGER(KIND=int_8), INTENT(IN) :: nze_a, nze_b, nze_c
1487 INTEGER, INTENT(IN) :: numnodes
1488 INTEGER :: nsplit
1489
1490 INTEGER(KIND=int_8) :: max_nze, min_nze
1491 REAL(dp) :: s_opt_factor
1492
1493 s_opt_factor = 1.0_dp ! Could be further tuned.
1494
1495 SELECT CASE (max_mm_dim)
1496 CASE (1)
1497 min_nze = max(nze_b, 1_int_8)
1498 max_nze = max(maxval([nze_a, nze_c]), 1_int_8)
1499 CASE (2)
1500 min_nze = max(nze_c, 1_int_8)
1501 max_nze = max(maxval([nze_a, nze_b]), 1_int_8)
1502 CASE (3)
1503 min_nze = max(nze_a, 1_int_8)
1504 max_nze = max(maxval([nze_b, nze_c]), 1_int_8)
1505 CASE DEFAULT
1506 CPABORT("")
1507 END SELECT
1508
1509 nsplit = int(min(int(numnodes, kind=int_8), nint(real(max_nze, dp)/(real(min_nze, dp)*s_opt_factor), kind=int_8)))
1510 IF (nsplit == 0) nsplit = 1
1511
1512 END FUNCTION
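! Worked example with illustrative numbers: for max_mm_dim = 2, nze_a = nze_b = 10**9,
! nze_c = 10**6 and numnodes = 1024, min_nze = 10**6 and max_nze = 10**9, so
! nsplit = min(1024, nint(1.0e9/1.0e6)) = min(1024, 1000) = 1000.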
1513
1514! **************************************************************************************************
1515!> \brief Create a matrix with 1x1 blocks that contains the block norms of matrix_in
1516!> \param matrix_in ...
1517!> \param matrix_out ...
1518!> \param nodata ...
1519!> \author Patrick Seewald
1520! **************************************************************************************************
1521 SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
1522 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
1523 TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
1524 LOGICAL, INTENT(IN), OPTIONAL :: nodata
1525
1526 CHARACTER(len=default_string_length) :: name
1527 INTEGER(KIND=int_8) :: column, nblkcols, nblkrows, row
1528 LOGICAL :: nodata_prv
1529 REAL(dp), DIMENSION(1, 1) :: blk_put
1530 REAL(dp), DIMENSION(:, :), POINTER :: blk_get
1531 TYPE(dbt_tas_blk_size_one) :: col_blk_size, row_blk_size
1532 TYPE(dbt_tas_iterator) :: iter
1533
1534!REAL(dp), DIMENSION(:, :), POINTER :: dbt_put
1535
1536 cpassert(matrix_in%valid)
1537
1538 IF (PRESENT(nodata)) THEN
1539 nodata_prv = nodata
1540 ELSE
1541 nodata_prv = .false.
1542 END IF
1543
1544 CALL dbt_tas_get_info(matrix_in, name=name, nblkrows_total=nblkrows, nblkcols_total=nblkcols)
1545 row_blk_size = dbt_tas_blk_size_one(nblkrows)
1546 col_blk_size = dbt_tas_blk_size_one(nblkcols)
1547
1548 ! not sure if assumption that same distribution can be taken still holds
1549 CALL dbt_tas_create(matrix_out, name, matrix_in%dist, row_blk_size, col_blk_size)
1550
1551 IF (.NOT. nodata_prv) THEN
1552 CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
1553!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
1554!$OMP PRIVATE(iter,row,column,blk_get,blk_put)
1555 CALL dbt_tas_iterator_start(iter, matrix_in)
1556 DO WHILE (dbt_tas_iterator_blocks_left(iter))
1557 CALL dbt_tas_iterator_next_block(iter, row, column, blk_get)
1558 blk_put(1, 1) = norm2(blk_get)
1559 CALL dbt_tas_put_block(matrix_out, row, column, blk_put)
1560 END DO
1561 CALL dbt_tas_iterator_stop(iter)
1562!$OMP END PARALLEL
1563 END IF
1564
1565 END SUBROUTINE
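! Note: such 1x1-block norm matrices act as cheap sparsity sketches of the full matrices;
! the occupancy estimate above iterates over the sketch of C (matrix_c_bnorm) to predict the
! number of non-zero elements of the result without building the result blocks themselves.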
1566
1567! **************************************************************************************************
1568!> \brief Convert a DBM matrix to a new process grid
1569!> \param mp_comm_cart new process grid
1570!> \param matrix_in ...
1571!> \param matrix_out ...
1572!> \param move_data memory optimization: move data such that matrix_in is empty on return.
1573!> \param nodata if .TRUE., the data of matrix_in is not copied to matrix_out
1574!> \param optimize_pgrid whether to redistribute to the new process grid (if .FALSE., matrix_out is created on the grid of matrix_in)
1575!> \author Patrick Seewald
1576! **************************************************************************************************
1577 SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
1578 TYPE(mp_cart_type), INTENT(IN) :: mp_comm_cart
1579 TYPE(dbm_type), INTENT(INOUT) :: matrix_in
1580 TYPE(dbm_type), INTENT(OUT) :: matrix_out
1581 LOGICAL, INTENT(IN), OPTIONAL :: move_data, nodata, optimize_pgrid
1582
1583 CHARACTER(LEN=*), PARAMETER :: routinen = 'convert_to_new_pgrid'
1584
1585 CHARACTER(len=default_string_length) :: name
1586 INTEGER :: handle, nbcols, nbrows
1587 INTEGER, CONTIGUOUS, DIMENSION(:), POINTER :: col_dist, rbsize, rcsize, row_dist
1588 INTEGER, DIMENSION(2) :: pdims
1589 LOGICAL :: nodata_prv, optimize_pgrid_prv
1590 TYPE(dbm_distribution_obj) :: dist, dist_old
1591
1592 NULLIFY (row_dist, col_dist, rbsize, rcsize)
1593
1594 CALL timeset(routinen, handle)
1595
1596 IF (PRESENT(optimize_pgrid)) THEN
1597 optimize_pgrid_prv = optimize_pgrid
1598 ELSE
1599 optimize_pgrid_prv = .true.
1600 END IF
1601
1602 IF (PRESENT(nodata)) THEN
1603 nodata_prv = nodata
1604 ELSE
1605 nodata_prv = .false.
1606 END IF
1607
1608 name = dbm_get_name(matrix_in)
1609
1610 IF (.NOT. optimize_pgrid_prv) THEN
1611 CALL dbm_create_from_template(matrix_out, name=name, template=matrix_in)
1612 IF (.NOT. nodata_prv) CALL dbm_copy(matrix_out, matrix_in)
1613 CALL timestop(handle)
1614 RETURN
1615 END IF
1616
1617 rbsize => dbm_get_row_block_sizes(matrix_in)
1618 rcsize => dbm_get_col_block_sizes(matrix_in)
1619 nbrows = SIZE(rbsize)
1620 nbcols = SIZE(rcsize)
1621 dist_old = dbm_get_distribution(matrix_in)
1622 pdims = mp_comm_cart%num_pe_cart
1623
1624 ALLOCATE (row_dist(nbrows), col_dist(nbcols))
1625 CALL dbt_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
1626 CALL dbt_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)
1627
1628 CALL dbm_distribution_new(dist, mp_comm_cart, row_dist, col_dist)
1629 DEALLOCATE (row_dist, col_dist)
1630
1631 CALL dbm_create(matrix_out, name, dist, rbsize, rcsize)
1632 CALL dbm_distribution_release(dist)
1633
1634 IF (.NOT. nodata_prv) THEN
1635 CALL dbm_redistribute(matrix_in, matrix_out)
1636 IF (PRESENT(move_data)) THEN
1637 IF (move_data) CALL dbm_clear(matrix_in)
1638 END IF
1639 END IF
1640
1641 CALL timestop(handle)
1642 END SUBROUTINE
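! Illustrative call sketch (variable names are placeholders, not part of the original source):
!   ! redistribute matrix_in onto the 2D grid mp_comm_cart with a default load-balanced distribution
!   CALL convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data=.TRUE.)
!   ! keep the original grid and distribution, only copy the data
!   CALL convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, optimize_pgrid=.FALSE.)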
1643
1644! **************************************************************************************************
1645!> \brief Initialize batched multiplication for a matrix: allocate the internal multiplication storage and enter the batched state.
1646!> \param matrix ...
1647!> \author Patrick Seewald
1648! **************************************************************************************************
1649 SUBROUTINE dbt_tas_batched_mm_init(matrix)
1650 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1651
1652 CALL dbt_tas_set_batched_state(matrix, state=1)
1653 ALLOCATE (matrix%mm_storage)
1654 matrix%mm_storage%batched_out = .false.
1655 END SUBROUTINE
1656
1657! **************************************************************************************************
1658!> \brief Finalize batched multiplication: merge accumulated results into the matrix, release the internal storage and leave the batched state.
1659!> \param matrix ...
1660!> \author Patrick Seewald
1661! **************************************************************************************************
1662 SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
1663      TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1664
1665 INTEGER :: handle
1666
1667 CALL matrix%dist%info%mp_comm%sync()
1668 CALL timeset("dbt_tas_total", handle)
1669
1670 IF (matrix%do_batched == 0) RETURN
1671
1672 IF (matrix%mm_storage%batched_out) THEN
1673 CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
1674 END IF
1675
1676 CALL dbt_tas_batched_mm_complete(matrix)
1677
1678 matrix%mm_storage%batched_out = .false.
1679
1680 DEALLOCATE (matrix%mm_storage)
1681 CALL dbt_tas_set_batched_state(matrix, state=0)
1682
1683 CALL matrix%dist%info%mp_comm%sync()
1684 CALL timestop(handle)
1685
1686 END SUBROUTINE
1687
1688! **************************************************************************************************
1689!> \brief set state flags during batched multiplication
1690!> \param matrix ...
1691!> \param state 0 no batched MM
1692!> 1 batched MM but mm_storage not yet initialized
1693!> 2 batched MM and mm_storage requires update
1694!> 3 batched MM and mm_storage initialized
1695!> \param opt_grid whether process grid was already optimized and should not be changed
1696!> \author Patrick Seewald
1697! **************************************************************************************************
1698 SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
1699 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1700 INTEGER, INTENT(IN), OPTIONAL :: state
1701 LOGICAL, INTENT(IN), OPTIONAL :: opt_grid
1702
1703 IF (PRESENT(opt_grid)) THEN
1704 matrix%has_opt_pgrid = opt_grid
1705 matrix%dist%info%strict_split(1) = .true.
1706 END IF
1707
1708 IF (PRESENT(state)) THEN
1709 matrix%do_batched = state
1710 SELECT CASE (state)
1711 CASE (0, 1)
1712 ! reset to default
1713 IF (matrix%has_opt_pgrid) THEN
1714 matrix%dist%info%strict_split(1) = .true.
1715 ELSE
1716 matrix%dist%info%strict_split(1) = matrix%dist%info%strict_split(2)
1717 END IF
1718 CASE (2, 3)
1719 matrix%dist%info%strict_split(1) = .true.
1720 CASE DEFAULT
1721 cpabort("should not happen")
1722 END SELECT
1723 END IF
1724 END SUBROUTINE
1725
1726! **************************************************************************************************
1727!> \brief Complete the current batch: merge data accumulated in the internal multiplication storage back into the matrix.
1728!> \param matrix ...
1729!> \param warn if .TRUE., warn in case batched multiplication optimizations are disabled due to conflicting data access
1730!> \author Patrick Seewald
1731! **************************************************************************************************
1732 SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
1733 TYPE(dbt_tas_type), INTENT(INOUT) :: matrix
1734 LOGICAL, INTENT(IN), OPTIONAL :: warn
1735
1736 IF (matrix%do_batched == 0) RETURN
1737 associate(storage => matrix%mm_storage)
1738 IF (PRESENT(warn)) THEN
1739 IF (warn .AND. matrix%do_batched == 3) THEN
1740 CALL cp_warn(__location__, &
1741 "Optimizations for batched multiplication are disabled because of conflicting data access")
1742 END IF
1743 END IF
1744 IF (storage%batched_out .AND. matrix%do_batched == 3) THEN
1745
1746 CALL dbt_tas_merge(storage%store_batched%matrix, &
1747 storage%store_batched_repl, move_data=.true.)
1748
1749 CALL dbt_tas_reshape(storage%store_batched, matrix, summation=.true., &
1750 transposed=storage%batched_trans, move_data=.true.)
1751 CALL dbt_tas_destroy(storage%store_batched)
1752 DEALLOCATE (storage%store_batched)
1753 END IF
1754
1755 IF (ASSOCIATED(storage%store_batched_repl)) THEN
1756 CALL dbt_tas_destroy(storage%store_batched_repl)
1757 DEALLOCATE (storage%store_batched_repl)
1758 END IF
1759 END associate
1760
1761 CALL dbt_tas_set_batched_state(matrix, state=2)
1762
1763 END SUBROUTINE
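! Illustrative batched-multiplication sketch (matrix and loop names are placeholders, not part
! of the original source): accumulate several products into matrix_c, then merge once at the end.
!   CALL dbt_tas_batched_mm_init(matrix_c)
!   DO ibatch = 1, nbatch
!      CALL dbt_tas_multiply(transa, transb, transc, alpha, matrix_a_batch(ibatch), &
!                            matrix_b_batch(ibatch), beta, matrix_c, filter_eps=filter_eps)
!   END DO
!   CALL dbt_tas_batched_mm_finalize(matrix_c)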
1764
1765END MODULE