58#include "../../base/base_uses.f90"
63 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'dbt_tas_mm'
102 RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
103 optimize_dist, split_opt, filter_eps, flop, move_data_a, &
104 move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
106 LOGICAL,
INTENT(IN) :: transa, transb, transc
107 REAL(
dp),
INTENT(IN) :: alpha
108 TYPE(
dbt_tas_type),
INTENT(INOUT),
TARGET :: matrix_a, matrix_b
109 REAL(
dp),
INTENT(IN) :: beta
111 LOGICAL,
INTENT(IN),
OPTIONAL :: optimize_dist
113 REAL(kind=
dp),
INTENT(IN),
OPTIONAL :: filter_eps
114 INTEGER(KIND=int_8),
INTENT(OUT),
OPTIONAL :: flop
115 LOGICAL,
INTENT(IN),
OPTIONAL :: move_data_a, move_data_b, &
116 retain_sparsity, simple_split
117 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
118 LOGICAL,
INTENT(IN),
OPTIONAL :: log_verbose
120 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_tas_multiply'
122 INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
123 nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
125 INTEGER(KIND=int_8) :: nze_a, nze_b, nze_c, nze_c_sum
126 INTEGER(KIND=int_8),
DIMENSION(2) :: dims_a, dims_b, dims_c
127 INTEGER(KIND=int_8),
DIMENSION(3) :: dims
128 INTEGER,
DIMENSION(2) :: pdims, pdims_sub
129 LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
130 simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
131 REAL(kind=
dp) :: filter_eps_prv
132 TYPE(
dbm_type) :: matrix_a_mm, matrix_b_mm, matrix_c_mm
134 TYPE(
dbt_tas_type),
POINTER :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
135 matrix_b_rs, matrix_c_rep, matrix_c_rs
136 TYPE(
mp_cart_type) :: comm_tmp, mp_comm, mp_comm_group, &
137 mp_comm_mm, mp_comm_opt
139 CALL timeset(routinen, handle)
140 CALL matrix_a%dist%info%mp_comm%sync()
141 CALL timeset(
"dbt_tas_total", handle2)
143 NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
147 IF (
PRESENT(simple_split))
THEN
148 simple_split_prv = simple_split
150 simple_split_prv = .false.
153 IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .true.
157 IF (
PRESENT(retain_sparsity))
THEN
158 IF (retain_sparsity) nodata_3 = .false.
164 IF (matrix_a%do_batched > 0)
THEN
166 IF (matrix_a%do_batched == 3)
THEN
167 cpassert(batched_repl == 0)
171 nsplit=nsplit_batched)
172 cpassert(nsplit_batched > 0)
173 max_mm_dim_batched = 3
177 IF (matrix_b%do_batched > 0)
THEN
179 IF (matrix_b%do_batched == 3)
THEN
180 cpassert(batched_repl == 0)
184 nsplit=nsplit_batched)
185 cpassert(nsplit_batched > 0)
186 max_mm_dim_batched = 1
190 IF (matrix_c%do_batched > 0)
THEN
192 IF (matrix_c%do_batched == 3)
THEN
193 cpassert(batched_repl == 0)
197 nsplit=nsplit_batched)
198 cpassert(nsplit_batched > 0)
199 max_mm_dim_batched = 2
206 IF (
PRESENT(move_data_a)) move_a = move_data_a
207 IF (
PRESENT(move_data_b)) move_b = move_data_b
209 transa_prv = transa; transb_prv = transb; transc_prv = transc
215 IF (unit_nr_prv > 0)
THEN
216 WRITE (unit_nr_prv,
"(A)") repeat(
"-", 80)
217 WRITE (unit_nr_prv,
"(A)") &
218 "DBT TAS MATRIX MULTIPLICATION: "// &
222 WRITE (unit_nr_prv,
"(A)") repeat(
"-", 80)
225 IF (unit_nr_prv > 0)
THEN
226 WRITE (unit_nr_prv,
"(T2,A)") &
227 "BATCHED PROCESSING OF MATMUL"
228 IF (batched_repl > 0)
THEN
229 WRITE (unit_nr_prv,
"(T4,A,T80,I1)")
"reusing replicated matrix:", batched_repl
242 dims_c = [dims_a(1), dims_b(2)]
244 IF (.NOT. (dims_a(2) .EQ. dims_b(1)))
THEN
245 cpabort(
"inconsistent matrix dimensions")
248 dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
250 IF (unit_nr_prv > 0)
THEN
251 WRITE (unit_nr_prv,
"(T2,A, 1X, I12, 1X, I12, 1X, I12)")
"mm dims:", dims(1), dims(2), dims(3)
255 numproc = mp_comm%num_pe
261 IF (.NOT. simple_split_prv)
THEN
262 CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
263 estimated_nze=nze_c, filter_eps=filter_eps, &
264 retain_sparsity=retain_sparsity)
266 max_mm_dim = maxloc(dims, 1)
267 nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
270 IF (unit_nr_prv > 0)
THEN
271 WRITE (unit_nr_prv,
"(T2,A)") &
273 WRITE (unit_nr_prv,
"(T4,A,T68,I13)")
"Est. number of matrix elements per CPU of result matrix:", &
274 (nze_c + numproc - 1)/numproc
276 WRITE (unit_nr_prv,
"(T4,A,T68,I13)")
"Est. optimal split factor:", nsplit
279 ELSEIF (batched_repl > 0)
THEN
280 nsplit = nsplit_batched
282 max_mm_dim = max_mm_dim_batched
283 IF (unit_nr_prv > 0)
THEN
284 WRITE (unit_nr_prv,
"(T2,A)") &
286 WRITE (unit_nr_prv,
"(T4,A,T68,I13)")
"Est. optimal split factor:", nsplit
291 max_mm_dim = maxloc(dims, 1)
296 SELECT CASE (max_mm_dim)
300 CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
301 new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
303 opt_nsplit=batched_repl == 0, &
304 split_rc_1=split_a, split_rc_2=split_c, &
305 nodata2=nodata_3, comm_new=comm_tmp, &
306 move_data_1=move_a, unit_nr=unit_nr_prv)
312 IF (matrix_b%do_batched <= 2)
THEN
313 ALLOCATE (matrix_b_rs)
314 CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
321 IF (unit_nr_prv > 0)
THEN
322 IF (.NOT. tr_case)
THEN
323 WRITE (unit_nr_prv,
"(T2,A, 1X, A)")
"mm case:",
"| x + = |"
325 WRITE (unit_nr_prv,
"(T2,A, 1X, A)")
"mm case:",
"--T x + = --T"
332 CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
333 optimize_dist=optimize_dist, &
335 opt_nsplit=batched_repl == 0, &
336 split_rc_1=split_a, split_rc_2=split_b, &
338 move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
343 IF (matrix_c%do_batched == 1)
THEN
344 matrix_c%mm_storage%batched_beta = beta
345 ELSEIF (matrix_c%do_batched > 1)
THEN
346 matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
349 IF (matrix_c%do_batched <= 2)
THEN
350 ALLOCATE (matrix_c_rs)
351 CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
355 IF (.NOT. nodata_3)
CALL dbm_zero(matrix_c_rs%matrix)
357 IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
358 ELSEIF (matrix_c%do_batched == 3)
THEN
359 matrix_c_rs => matrix_c%mm_storage%store_batched
362 new_c = matrix_c%do_batched == 0
365 IF (unit_nr_prv > 0)
THEN
366 IF (.NOT. tr_case)
THEN
367 WRITE (unit_nr_prv,
"(T2,A, 1X, A)")
"mm case:",
"-- x --T = +"
369 WRITE (unit_nr_prv,
"(T2,A, 1X, A)")
"mm case:",
"|T x | = +"
376 CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
377 transc_prv, optimize_dist=optimize_dist, &
379 opt_nsplit=batched_repl == 0, &
380 split_rc_1=split_b, split_rc_2=split_c, &
381 nodata2=nodata_3, comm_new=comm_tmp, &
382 move_data_1=move_b, unit_nr=unit_nr_prv)
387 IF (matrix_a%do_batched <= 2)
THEN
388 ALLOCATE (matrix_a_rs)
389 CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
396 IF (unit_nr_prv > 0)
THEN
397 IF (.NOT. tr_case)
THEN
398 WRITE (unit_nr_prv,
"(T2,A, 1X, A)")
"mm case:",
"+ x -- = --"
400 WRITE (unit_nr_prv,
"(T2,A, 1X, A)")
"mm case:",
"+ x |T = |T"
408 numproc = mp_comm%num_pe
409 pdims_sub = mp_comm_group%num_pe_cart
413 IF (
PRESENT(filter_eps))
THEN
414 filter_eps_prv = filter_eps
416 filter_eps_prv = 0.0_dp
419 IF (unit_nr_prv /= 0)
THEN
420 IF (unit_nr_prv > 0)
THEN
421 WRITE (unit_nr_prv,
"(T2, A)")
"SPLIT / PARALLELIZATION INFO"
427 IF (unit_nr_prv > 0)
THEN
429 WRITE (unit_nr_prv,
"(T4, A, 1X, A)")
"Change process grid:",
"Yes"
431 WRITE (unit_nr_prv,
"(T4, A, 1X, A)")
"Change process grid:",
"No"
437 CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
440 SELECT CASE (max_mm_dim)
442 IF (matrix_b%do_batched <= 2)
THEN
443 ALLOCATE (matrix_b_rep)
445 IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2)
THEN
446 matrix_b%mm_storage%store_batched_repl => matrix_b_rep
449 ELSEIF (matrix_b%do_batched == 3)
THEN
450 matrix_b_rep => matrix_b%mm_storage%store_batched_repl
455 DEALLOCATE (matrix_b_rs)
457 IF (unit_nr_prv /= 0)
THEN
462 CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
470 DEALLOCATE (matrix_a_rs)
472 CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
473 move_data=matrix_b%do_batched == 0)
478 IF (matrix_b%do_batched == 0)
THEN
480 DEALLOCATE (matrix_b_rep)
483 CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
488 CALL matrix_a%dist%info%mp_comm%sync()
489 CALL timeset(
"dbt_tas_dbm", handle4)
490 IF (.NOT. tr_case)
THEN
491 CALL timeset(
"dbt_tas_mm_1N", handle3)
493 CALL dbm_multiply(transa=.false., transb=.false., alpha=alpha, &
494 matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
495 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
496 CALL timestop(handle3)
498 CALL timeset(
"dbt_tas_mm_1T", handle3)
499 CALL dbm_multiply(transa=.true., transb=.false., alpha=alpha, &
500 matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
501 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
503 CALL timestop(handle3)
505 CALL matrix_a%dist%info%mp_comm%sync()
506 CALL timestop(handle4)
513 IF (.NOT. new_c)
THEN
514 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
516 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
521 IF (
PRESENT(filter_eps))
CALL dbt_tas_filter(matrix_c_rs, filter_eps)
523 IF (unit_nr_prv /= 0)
THEN
528 IF (matrix_c%do_batched <= 1)
THEN
529 ALLOCATE (matrix_c_rep)
531 IF (matrix_c%do_batched == 1)
THEN
532 matrix_c%mm_storage%store_batched_repl => matrix_c_rep
535 ELSEIF (matrix_c%do_batched == 2)
THEN
536 ALLOCATE (matrix_c_rep)
539 IF (.NOT. nodata_3)
CALL dbm_zero(matrix_c_rep%matrix)
540 matrix_c%mm_storage%store_batched_repl => matrix_c_rep
542 ELSEIF (matrix_c%do_batched == 3)
THEN
543 matrix_c_rep => matrix_c%mm_storage%store_batched_repl
546 IF (unit_nr_prv /= 0)
THEN
551 CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
559 DEALLOCATE (matrix_a_rs)
562 CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
569 DEALLOCATE (matrix_b_rs)
572 CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
577 CALL matrix_a%dist%info%mp_comm%sync()
578 CALL timeset(
"dbt_tas_dbm", handle4)
579 CALL timeset(
"dbt_tas_mm_2", handle3)
580 CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
581 matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
582 filter_eps=filter_eps_prv/real(nsplit, kind=
dp), retain_sparsity=retain_sparsity, flop=flop)
583 CALL matrix_a%dist%info%mp_comm%sync()
584 CALL timestop(handle3)
585 CALL timestop(handle4)
592 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
597 IF (unit_nr_prv /= 0)
THEN
601 IF (matrix_c%do_batched == 0)
THEN
602 CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.true.)
604 matrix_c%mm_storage%batched_out = .true.
607 IF (matrix_c%do_batched == 0)
THEN
609 DEALLOCATE (matrix_c_rep)
612 IF (
PRESENT(filter_eps))
CALL dbt_tas_filter(matrix_c_rs, filter_eps)
621 IF (matrix_a%do_batched <= 2)
THEN
622 ALLOCATE (matrix_a_rep)
624 IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2)
THEN
625 matrix_a%mm_storage%store_batched_repl => matrix_a_rep
628 ELSEIF (matrix_a%do_batched == 3)
THEN
629 matrix_a_rep => matrix_a%mm_storage%store_batched_repl
634 DEALLOCATE (matrix_a_rs)
636 IF (unit_nr_prv /= 0)
THEN
641 CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
642 move_data=matrix_a%do_batched == 0)
648 IF (matrix_a%do_batched == 0)
THEN
650 DEALLOCATE (matrix_a_rep)
653 CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
660 DEALLOCATE (matrix_b_rs)
662 CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
667 CALL matrix_a%dist%info%mp_comm%sync()
668 CALL timeset(
"dbt_tas_dbm", handle4)
669 IF (.NOT. tr_case)
THEN
670 CALL timeset(
"dbt_tas_mm_3N", handle3)
671 CALL dbm_multiply(transa=.false., transb=.false., alpha=alpha, &
672 matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
673 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
674 CALL timestop(handle3)
676 CALL timeset(
"dbt_tas_mm_3T", handle3)
677 CALL dbm_multiply(transa=.false., transb=.true., alpha=alpha, &
678 matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
679 filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
680 CALL timestop(handle3)
682 CALL matrix_a%dist%info%mp_comm%sync()
683 CALL timestop(handle4)
690 IF (.NOT. new_c)
THEN
691 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
693 CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
698 IF (
PRESENT(filter_eps))
CALL dbt_tas_filter(matrix_c_rs, filter_eps)
700 IF (unit_nr_prv /= 0)
THEN
705 CALL mp_comm_mm%free()
709 IF (
PRESENT(split_opt))
THEN
710 SELECT CASE (max_mm_dim)
712 CALL mp_comm%sum(nze_c)
715 CALL mp_comm%sum(nze_c)
716 CALL mp_comm%max(nze_c)
719 nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
724 IF (unit_nr_prv > 0)
THEN
725 WRITE (unit_nr_prv,
"(T2,A)") &
727 WRITE (unit_nr_prv,
"(T4,A,T68,I13)")
"Number of matrix elements per CPU of result matrix:", &
728 (nze_c + numproc - 1)/numproc
730 WRITE (unit_nr_prv,
"(T4,A,T68,I13)")
"Optimal split factor:", nsplit_opt
738 transposed=(transc_prv .NEQV. transc), &
741 DEALLOCATE (matrix_c_rs)
742 IF (
PRESENT(filter_eps))
CALL dbt_tas_filter(matrix_c, filter_eps)
743 ELSEIF (matrix_c%do_batched > 0)
THEN
744 IF (matrix_c%mm_storage%batched_out)
THEN
745 matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
749 IF (
PRESENT(move_data_a))
THEN
752 IF (
PRESENT(move_data_b))
THEN
756 IF (
PRESENT(flop))
THEN
757 CALL mp_comm%sum(flop)
758 flop = (flop + numproc - 1)/numproc
761 IF (
PRESENT(optimize_dist))
THEN
762 IF (optimize_dist)
CALL comm_tmp%free()
764 IF (unit_nr_prv > 0)
THEN
765 WRITE (unit_nr_prv,
'(A)') repeat(
"-", 80)
766 WRITE (unit_nr_prv,
'(A,1X,A,1X,A,1X,A,1X,A,1X,A)')
"TAS MATRIX MULTIPLICATION DONE"
767 WRITE (unit_nr_prv,
'(A)') repeat(
"-", 80)
774 CALL matrix_a%dist%info%mp_comm%sync()
775 CALL timestop(handle2)
776 CALL timestop(handle)
788 SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
789 TYPE(
dbm_type),
INTENT(IN) :: matrix_in
790 TYPE(
dbm_type),
INTENT(INOUT) :: matrix_out
791 LOGICAL,
INTENT(IN),
OPTIONAL :: local_copy
792 REAL(
dp),
INTENT(IN) :: alpha
794 LOGICAL :: local_copy_prv
797 IF (
PRESENT(local_copy))
THEN
798 local_copy_prv = local_copy
800 local_copy_prv = .false.
803 IF (alpha /= 1.0_dp)
THEN
807 IF (.NOT. local_copy_prv)
THEN
810 CALL dbm_add(matrix_out, matrix_tmp)
813 CALL dbm_add(matrix_out, matrix_in)
829 SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
833 LOGICAL,
INTENT(IN) :: transposed
834 LOGICAL,
INTENT(IN),
OPTIONAL :: nodata, move_data
836 CHARACTER(LEN=*),
PARAMETER :: routinen =
'reshape_mm_small'
839 INTEGER(KIND=int_8),
DIMENSION(2) :: dims
840 INTEGER,
DIMENSION(2) :: pdims
841 LOGICAL :: nodata_prv
845 CALL timeset(routinen, handle)
847 IF (
PRESENT(nodata))
THEN
853 pdims = mp_comm%num_pe_cart
857 IF (transposed)
CALL swap(dims)
859 IF (.NOT. transposed)
THEN
864 matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.true.)
870 matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.true.)
872 IF (.NOT. nodata_prv)
CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
874 CALL timestop(handle)
902 SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
903 optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
904 move_data_1, move_data_2, comm_new, unit_nr)
905 TYPE(
dbt_tas_type),
INTENT(INOUT),
TARGET :: matrix1_in, matrix2_in
906 TYPE(
dbt_tas_type),
INTENT(OUT),
POINTER :: matrix1_out, matrix2_out
907 LOGICAL,
INTENT(OUT) :: new1, new2
908 LOGICAL,
INTENT(INOUT) :: trans1, trans2
909 LOGICAL,
INTENT(IN),
OPTIONAL :: optimize_dist
910 INTEGER,
INTENT(IN),
OPTIONAL :: nsplit
911 LOGICAL,
INTENT(IN),
OPTIONAL :: opt_nsplit
912 INTEGER,
INTENT(INOUT) :: split_rc_1, split_rc_2
913 LOGICAL,
INTENT(IN),
OPTIONAL :: nodata1, nodata2
914 LOGICAL,
INTENT(INOUT),
OPTIONAL :: move_data_1, move_data_2
916 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
918 CHARACTER(LEN=*),
PARAMETER :: routinen =
'reshape_mm_compatible'
920 INTEGER :: handle, nsplit_prv, ref, split_rc_ref, &
922 INTEGER(KIND=int_8) :: d1, d2, nze1, nze2
923 INTEGER(KIND=int_8),
DIMENSION(2) :: dims1, dims2, dims_ref
924 INTEGER,
DIMENSION(2) :: pdims
925 LOGICAL :: nodata1_prv, nodata2_prv, &
926 optimize_dist_prv, trans1_newdist, &
934 CALL timeset(routinen, handle)
935 new1 = .false.; new2 = .false.
937 IF (
PRESENT(nodata1))
THEN
938 nodata1_prv = nodata1
940 nodata1_prv = .false.
943 IF (
PRESENT(nodata2))
THEN
944 nodata2_prv = nodata2
946 nodata2_prv = .false.
951 NULLIFY (matrix1_out, matrix2_out)
953 IF (
PRESENT(optimize_dist))
THEN
954 optimize_dist_prv = optimize_dist
956 optimize_dist_prv = .false.
964 IF (trans1) split_rc_1 = mod(split_rc_1, 2) + 1
966 IF (trans2) split_rc_2 = mod(split_rc_2, 2) + 1
968 IF (nze1 >= nze2)
THEN
970 split_rc_ref = split_rc_1
974 split_rc_ref = split_rc_2
978 IF (
PRESENT(nsplit))
THEN
984 IF (optimize_dist_prv)
THEN
985 cpassert(
PRESENT(comm_new))
988 IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2))
THEN
989 CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
990 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
992 CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
993 move_data=move_data_2, nodata=nodata2, opt_nsplit=.false.)
994 IF (unit_nr_prv > 0)
THEN
995 WRITE (unit_nr_prv,
"(T2,A,1X,A,1X,A,1X,A)")
"No redistribution of", &
999 WRITE (unit_nr_prv,
"(T2,A,1X,A,1X,A)")
"Change split factor of", &
1002 WRITE (unit_nr_prv,
"(T2,A,1X,A,1X,A)")
"Change split factor of", &
1006 WRITE (unit_nr_prv,
"(T2,A,1X,A,1X,A)")
"Change split factor of", &
1009 WRITE (unit_nr_prv,
"(T2,A,1X,A,1X,A)")
"Change split factor of", &
1015 IF (optimize_dist_prv)
THEN
1016 IF (unit_nr_prv > 0)
THEN
1017 WRITE (unit_nr_prv,
"(T2,A,1X,A,1X,A,1X,A)")
"Optimizing distribution of", &
1022 trans1_newdist = (split_rc_1 ==
colsplit)
1023 trans2_newdist = (split_rc_2 ==
colsplit)
1025 IF (trans1_newdist)
THEN
1027 trans1 = .NOT. trans1
1030 IF (trans2_newdist)
THEN
1032 trans2 = .NOT. trans2
1035 IF (nsplit_prv == 0)
THEN
1036 SELECT CASE (split_rc_ref)
1044 nsplit_prv = int((d1 - 1)/d2 + 1)
1047 cpassert(nsplit_prv > 0)
1053 pdims = comm_new%num_pe_cart
1069 ALLOCATE (matrix1_out)
1070 IF (.NOT. trans1_newdist)
THEN
1072 matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.true.)
1076 matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.true.)
1079 ALLOCATE (matrix2_out)
1080 IF (.NOT. trans2_newdist)
THEN
1082 matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.true.)
1085 matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.true.)
1088 IF (.NOT. nodata1_prv)
CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
1089 IF (.NOT. nodata2_prv)
CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
1096 IF (unit_nr_prv > 0)
THEN
1097 WRITE (unit_nr_prv,
"(T2,A,1X,A)")
"Redistribution of", &
1101 CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
1102 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
1104 ALLOCATE (matrix2_out)
1105 CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
1106 nodata=nodata2, move_data=move_data_2)
1109 IF (unit_nr_prv > 0)
THEN
1110 WRITE (unit_nr_prv,
"(T2,A,1X,A)")
"Redistribution of", &
1114 CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
1115 move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
1117 ALLOCATE (matrix1_out)
1118 CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
1119 nodata=nodata1, move_data=move_data_1)
1125 IF (
PRESENT(move_data_1) .AND. new1) move_data_1 = .true.
1126 IF (
PRESENT(move_data_2) .AND. new2) move_data_2 = .true.
1128 CALL timestop(handle)
1144 SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
1147 INTEGER,
INTENT(IN) :: nsplit, split_rowcol
1148 LOGICAL,
INTENT(OUT) :: is_new
1149 LOGICAL,
INTENT(IN),
OPTIONAL :: opt_nsplit
1150 LOGICAL,
INTENT(INOUT),
OPTIONAL :: move_data
1151 LOGICAL,
INTENT(IN),
OPTIONAL :: nodata
1153 CHARACTER(len=default_string_length) :: name
1154 INTEGER :: handle, nsplit_new, nsplit_old, &
1155 nsplit_prv, split_rc
1156 LOGICAL :: nodata_prv
1163 CHARACTER(LEN=*),
PARAMETER :: routinen =
'change_split'
1165 NULLIFY (matrix_out)
1170 split_rowcol=split_rc, nsplit=nsplit_old)
1172 IF (nsplit == 0)
THEN
1173 IF (split_rowcol == split_rc)
THEN
1174 matrix_out => matrix_in
1184 CALL timeset(routinen, handle)
1186 nodata_prv = .false.
1187 IF (
PRESENT(nodata)) nodata_prv = nodata
1190 row_blk_size=rbsize, col_blk_size=cbsize, &
1191 proc_row_dist=rdist, proc_col_dist=cdist)
1193 CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)
1197 IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol)
THEN
1198 matrix_out => matrix_in
1201 CALL timestop(handle)
1206 split_info=split_info)
1210 ALLOCATE (matrix_out)
1211 CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.true.)
1213 IF (.NOT. nodata_prv)
CALL dbt_tas_copy(matrix_out, matrix_in)
1215 IF (
PRESENT(move_data))
THEN
1216 IF (.NOT. nodata_prv)
THEN
1222 CALL timestop(handle)
1235 FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
1237 INTEGER,
INTENT(IN) :: split_rc_a, split_rc_b
1238 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
1239 LOGICAL :: dist_compatible
1241 INTEGER :: numproc, same_local_rowcols, &
1242 split_check_a, split_check_b, &
1244 INTEGER(int_8),
ALLOCATABLE,
DIMENSION(:) :: local_rowcols_a, local_rowcols_b
1245 INTEGER,
DIMENSION(2) :: pdims_a, pdims_b
1250 dist_compatible = .false.
1256 IF (split_check_b /= split_rc_b .OR. split_check_a /= split_rc_a .OR. split_rc_a /= split_rc_b)
THEN
1257 IF (unit_nr_prv > 0)
THEN
1258 WRITE (unit_nr_prv, *)
"matrix layout a not compatible", split_check_a, split_rc_a
1259 WRITE (unit_nr_prv, *)
"matrix layout b not compatible", split_check_b, split_rc_b
1267 numproc = info_b%mp_comm%num_pe
1268 pdims_a = info_a%mp_comm%num_pe_cart
1269 pdims_b = info_b%mp_comm%num_pe_cart
1270 IF (.NOT.
array_eq(pdims_a, pdims_b))
THEN
1271 IF (unit_nr_prv > 0)
THEN
1272 WRITE (unit_nr_prv, *)
"mp dims not compatible:", pdims_a,
"|", pdims_b
1278 SELECT CASE (split_rc_a)
1287 same_local_rowcols = merge(1, 0,
array_eq(local_rowcols_a, local_rowcols_b))
1289 CALL info_a%mp_comm%sum(same_local_rowcols)
1291 IF (same_local_rowcols == numproc)
THEN
1292 dist_compatible = .true.
1294 IF (unit_nr_prv > 0)
THEN
1295 WRITE (unit_nr_prv, *)
"local rowcols not compatible"
1296 WRITE (unit_nr_prv, *)
"local rowcols A", local_rowcols_a
1297 WRITE (unit_nr_prv, *)
"local rowcols B", local_rowcols_b
1314 SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
1318 LOGICAL,
INTENT(INOUT) :: trans
1319 INTEGER,
INTENT(IN) :: split_rc
1320 LOGICAL,
INTENT(IN),
OPTIONAL :: nodata, move_data
1326 INTEGER :: dim_split_template, dim_split_matrix, &
1328 INTEGER,
DIMENSION(2) :: pdims
1329 LOGICAL :: nodata_prv, transposed
1331 CHARACTER(LEN=*),
PARAMETER :: routinen =
'reshape_mm_template'
1333 CALL timeset(routinen, handle)
1335 IF (
PRESENT(nodata))
THEN
1338 nodata_prv = .false.
1344 dim_split_template = info_template%split_rowcol
1345 dim_split_matrix = split_rc
1347 transposed = dim_split_template .NE. dim_split_matrix
1348 IF (transposed) trans = .NOT. trans
1350 pdims = info_template%mp_comm%num_pe_cart
1352 SELECT CASE (dim_split_template)
1354 IF (.NOT. transposed)
THEN
1355 ALLOCATE (row_dist, source=template%dist%row_dist)
1358 ALLOCATE (row_dist, source=template%dist%row_dist)
1362 IF (.NOT. transposed)
THEN
1364 ALLOCATE (col_dist, source=template%dist%col_dist)
1367 ALLOCATE (col_dist, source=template%dist%col_dist)
1373 IF (.NOT. transposed)
THEN
1375 matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.true.)
1378 matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.true.)
1381 IF (.NOT. nodata_prv)
CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
1383 CALL timestop(handle)
1402 SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
1403 estimated_nze, filter_eps, unit_nr, retain_sparsity)
1404 LOGICAL,
INTENT(IN) :: transa, transb, transc
1405 TYPE(
dbt_tas_type),
INTENT(INOUT),
TARGET :: matrix_a, matrix_b, matrix_c
1406 INTEGER(int_8),
INTENT(OUT) :: estimated_nze
1407 REAL(kind=
dp),
INTENT(IN),
OPTIONAL :: filter_eps
1408 INTEGER,
INTENT(IN),
OPTIONAL :: unit_nr
1409 LOGICAL,
INTENT(IN),
OPTIONAL :: retain_sparsity
1411 CHARACTER(LEN=*),
PARAMETER :: routinen =
'dbt_tas_estimate_result_nze'
1413 INTEGER :: col_size, handle, row_size
1414 INTEGER(int_8) :: col, row
1415 LOGICAL :: retain_sparsity_prv
1417 TYPE(
dbt_tas_type),
POINTER :: matrix_a_bnorm, matrix_b_bnorm, &
1421 CALL timeset(routinen, handle)
1423 IF (
PRESENT(retain_sparsity))
THEN
1424 retain_sparsity_prv = retain_sparsity
1426 retain_sparsity_prv = .false.
1429 IF (.NOT. retain_sparsity_prv)
THEN
1430 ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
1431 CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
1432 CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
1433 CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.true.)
1436 matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
1437 filter_eps=filter_eps, move_data_a=.true., move_data_b=.true., &
1438 simple_split=.true., unit_nr=unit_nr)
1442 DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
1444 matrix_c_bnorm => matrix_c
1453 row_size = matrix_c%row_blk_size%data(row)
1454 col_size = matrix_c%col_blk_size%data(col)
1455 estimated_nze = estimated_nze + row_size*col_size
1461 CALL mp_comm%sum(estimated_nze)
1463 IF (.NOT. retain_sparsity_prv)
THEN
1465 DEALLOCATE (matrix_c_bnorm)
1468 CALL timestop(handle)
1484 FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes)
RESULT(nsplit)
1485 INTEGER,
INTENT(IN) :: max_mm_dim
1486 INTEGER(KIND=int_8),
INTENT(IN) :: nze_a, nze_b, nze_c
1487 INTEGER,
INTENT(IN) :: numnodes
1490 INTEGER(KIND=int_8) :: max_nze, min_nze
1491 REAL(
dp) :: s_opt_factor
1493 s_opt_factor = 1.0_dp
1495 SELECT CASE (max_mm_dim)
1497 min_nze = max(nze_b, 1_int_8)
1498 max_nze = max(maxval([nze_a, nze_c]), 1_int_8)
1500 min_nze = max(nze_c, 1_int_8)
1501 max_nze = max(maxval([nze_a, nze_b]), 1_int_8)
1503 min_nze = max(nze_a, 1_int_8)
1504 max_nze = max(maxval([nze_b, nze_c]), 1_int_8)
1509 nsplit = int(min(int(numnodes, kind=
int_8), nint(real(max_nze,
dp)/(real(min_nze,
dp)*s_opt_factor), kind=
int_8)))
1510 IF (nsplit == 0) nsplit = 1
1521 SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
1524 LOGICAL,
INTENT(IN),
OPTIONAL :: nodata
1526 CHARACTER(len=default_string_length) :: name
1527 INTEGER(KIND=int_8) :: column, nblkcols, nblkrows, row
1528 LOGICAL :: nodata_prv
1529 REAL(
dp),
DIMENSION(1, 1) :: blk_put
1530 REAL(
dp),
DIMENSION(:, :),
POINTER :: blk_get
1536 cpassert(matrix_in%valid)
1538 IF (
PRESENT(nodata))
THEN
1541 nodata_prv = .false.
1544 CALL dbt_tas_get_info(matrix_in, name=name, nblkrows_total=nblkrows, nblkcols_total=nblkcols)
1549 CALL dbt_tas_create(matrix_out, name, matrix_in%dist, row_blk_size, col_blk_size)
1551 IF (.NOT. nodata_prv)
THEN
1558 blk_put(1, 1) = norm2(blk_get)
1577 SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
1579 TYPE(
dbm_type),
INTENT(INOUT) :: matrix_in
1580 TYPE(
dbm_type),
INTENT(OUT) :: matrix_out
1581 LOGICAL,
INTENT(IN),
OPTIONAL :: move_data, nodata, optimize_pgrid
1583 CHARACTER(LEN=*),
PARAMETER :: routinen =
'convert_to_new_pgrid'
1585 CHARACTER(len=default_string_length) :: name
1586 INTEGER :: handle, nbcols, nbrows
1587 INTEGER,
CONTIGUOUS,
DIMENSION(:),
POINTER :: col_dist, rbsize, rcsize, row_dist
1588 INTEGER,
DIMENSION(2) :: pdims
1589 LOGICAL :: nodata_prv, optimize_pgrid_prv
1592 NULLIFY (row_dist, col_dist, rbsize, rcsize)
1594 CALL timeset(routinen, handle)
1596 IF (
PRESENT(optimize_pgrid))
THEN
1597 optimize_pgrid_prv = optimize_pgrid
1599 optimize_pgrid_prv = .true.
1602 IF (
PRESENT(nodata))
THEN
1605 nodata_prv = .false.
1610 IF (.NOT. optimize_pgrid_prv)
THEN
1612 IF (.NOT. nodata_prv)
CALL dbm_copy(matrix_out, matrix_in)
1613 CALL timestop(handle)
1619 nbrows =
SIZE(rbsize)
1620 nbcols =
SIZE(rcsize)
1622 pdims = mp_comm_cart%num_pe_cart
1624 ALLOCATE (row_dist(nbrows), col_dist(nbcols))
1629 DEALLOCATE (row_dist, col_dist)
1631 CALL dbm_create(matrix_out, name, dist, rbsize, rcsize)
1634 IF (.NOT. nodata_prv)
THEN
1636 IF (
PRESENT(move_data))
THEN
1637 IF (move_data)
CALL dbm_clear(matrix_in)
1641 CALL timestop(handle)
1653 ALLOCATE (matrix%mm_storage)
1654 matrix%mm_storage%batched_out = .false.
1667 CALL matrix%dist%info%mp_comm%sync()
1668 CALL timeset(
"dbt_tas_total", handle)
1670 IF (matrix%do_batched == 0)
RETURN
1672 IF (matrix%mm_storage%batched_out)
THEN
1673 CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
1678 matrix%mm_storage%batched_out = .false.
1680 DEALLOCATE (matrix%mm_storage)
1683 CALL matrix%dist%info%mp_comm%sync()
1684 CALL timestop(handle)
1700 INTEGER,
INTENT(IN),
OPTIONAL :: state
1701 LOGICAL,
INTENT(IN),
OPTIONAL :: opt_grid
1703 IF (
PRESENT(opt_grid))
THEN
1704 matrix%has_opt_pgrid = opt_grid
1705 matrix%dist%info%strict_split(1) = .true.
1708 IF (
PRESENT(state))
THEN
1709 matrix%do_batched = state
1713 IF (matrix%has_opt_pgrid)
THEN
1714 matrix%dist%info%strict_split(1) = .true.
1716 matrix%dist%info%strict_split(1) = matrix%dist%info%strict_split(2)
1719 matrix%dist%info%strict_split(1) = .true.
1721 cpabort(
"should not happen")
1734 LOGICAL,
INTENT(IN),
OPTIONAL :: warn
1736 IF (matrix%do_batched == 0)
RETURN
1737 associate(storage => matrix%mm_storage)
1738 IF (
PRESENT(warn))
THEN
1739 IF (warn .AND. matrix%do_batched == 3)
THEN
1740 CALL cp_warn(__location__, &
1741 "Optimizations for batched multiplication are disabled because of conflicting data access")
1744 IF (storage%batched_out .AND. matrix%do_batched == 3)
THEN
1747 storage%store_batched_repl, move_data=.true.)
1749 CALL dbt_tas_reshape(storage%store_batched, matrix, summation=.true., &
1750 transposed=storage%batched_trans, move_data=.true.)
1752 DEALLOCATE (storage%store_batched)
1755 IF (
ASSOCIATED(storage%store_batched_repl))
THEN
1757 DEALLOCATE (storage%store_batched_repl)
subroutine, public dbm_multiply(transa, transb, alpha, matrix_a, matrix_b, beta, matrix_c, retain_sparsity, filter_eps, flop)
Computes matrix product: matrix_c = alpha * matrix_a * matrix_b + beta * matrix_c.
subroutine, public dbm_redistribute(matrix, redist)
Copies content of matrix_b into matrix_a. Matrices may have different distributions.
subroutine, public dbm_zero(matrix)
Sets all blocks in the given matrix to zero.
subroutine, public dbm_clear(matrix)
Remove all blocks from given matrix, but does not release the underlying memory.
subroutine, public dbm_create_from_template(matrix, name, template)
Creates a new matrix from given template, reusing dist and row/col_block_sizes.
pure integer function, public dbm_get_nze(matrix)
Returns the number of local Non-Zero Elements of the given matrix.
subroutine, public dbm_scale(matrix, alpha)
Multiplies all entries in the given matrix by the given factor alpha.
subroutine, public dbm_distribution_release(dist)
Decreases the reference counter of the given distribution.
type(dbm_distribution_obj) function, public dbm_get_distribution(matrix)
Returns the distribution of the given matrix.
subroutine, public dbm_create(matrix, name, dist, row_block_sizes, col_block_sizes)
Creates a new matrix.
integer function, dimension(:), pointer, contiguous, public dbm_get_row_block_sizes(matrix)
Returns the row block sizes of the given matrix.
character(len=default_string_length) function, public dbm_get_name(matrix)
Returns the name of the given matrix.
subroutine, public dbm_add(matrix_a, matrix_b)
Adds matrix_b to matrix_a.
subroutine, public dbm_copy(matrix_a, matrix_b)
Copies content of matrix_b into matrix_a. Matrices must have the same row/col block sizes and distributions.
subroutine, public dbm_release(matrix)
Releases a matrix and all its resources.
integer function, dimension(:), pointer, contiguous, public dbm_get_col_block_sizes(matrix)
Returns the column block sizes of the given matrix.
subroutine, public dbm_distribution_new(dist, mp_comm, row_dist_block, col_dist_block)
Creates a new two dimensional distribution.
Tall-and-skinny matrices: base routines similar to DBM API, mostly wrappers around existing DBM routines.
integer(kind=int_8) function, public dbt_tas_get_nze_total(matrix)
Get total number of non-zero elements.
subroutine, public dbt_tas_iterator_start(iter, matrix_in)
As dbm_iterator_start.
logical function, public dbt_tas_iterator_blocks_left(iter)
As dbm_iterator_blocks_left.
subroutine, public dbt_tas_get_info(matrix, nblkrows_total, nblkcols_total, local_rows, local_cols, proc_row_dist, proc_col_dist, row_blk_size, col_blk_size, distribution, name)
...
pure integer(kind=int_8) function, public dbt_tas_nblkrows_total(matrix)
...
subroutine, public dbt_tas_copy(matrix_b, matrix_a, summation)
Copy matrix_a to matrix_b.
subroutine, public dbt_tas_iterator_stop(iter)
As dbm_iterator_stop.
pure integer(kind=int_8) function, public dbt_tas_nblkcols_total(matrix)
...
subroutine, public dbt_tas_distribution_new(dist, mp_comm, row_dist, col_dist, split_info, nosplit)
Create a new distribution. Exactly like dbm_distribution_new but with custom types for row_dist and col_dist.
subroutine, public dbt_tas_filter(matrix, eps)
As dbm_filter.
subroutine, public dbt_tas_clear(matrix)
Clear matrix (erase all data)
type(dbt_tas_split_info) function, pointer, public dbt_tas_info(matrix)
Get info on MPI grid splitting.
subroutine, public dbt_tas_destroy(matrix)
...
subroutine, public dbt_tas_put_block(matrix, row, col, block, summation)
As dbm_put_block.
Global data (distribution and block sizes) for tall-and-skinny matrices For very sparse matrices with...
subroutine, public dbt_tas_default_distvec(nblk, nproc, blk_size, dist)
get a load-balanced and randomized distribution along one tensor dimension
type(dbt_tas_dist_arb) function, public dbt_tas_dist_arb_default(nprowcol, nmrowcol, dbt_sizes)
Distribution that is more or less cyclic (round robin) and load balanced with different weights for e...
tall-and-skinny matrices: Input / Output
integer function, public prep_output_unit(unit_nr)
...
subroutine, public dbt_tas_write_matrix_info(matrix, unit_nr, full_info)
Write basic infos of tall-and-skinny matrix: block dimensions, full dimensions, process grid dimensio...
subroutine, public dbt_tas_write_dist(matrix, unit_nr, full_info)
Write info on tall-and-skinny matrix distribution & load balance.
subroutine, public dbt_tas_write_split_info(info, unit_nr, name)
Print info on how matrix is split.
Matrix multiplication for tall-and-skinny matrices. This uses the k-split (non-recursive) CARMA algorithm.
subroutine, public dbt_tas_set_batched_state(matrix, state, opt_grid)
set state flags during batched multiplication
subroutine, public dbt_tas_batched_mm_init(matrix)
...
subroutine, public dbt_tas_batched_mm_finalize(matrix)
...
recursive subroutine, public dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, optimize_dist, split_opt, filter_eps, flop, move_data_a, move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical to arguments...
subroutine, public dbt_tas_batched_mm_complete(matrix, warn)
...
communication routines to reshape / replicate / merge tall-and-skinny matrices.
subroutine, public dbt_tas_merge(matrix_out, matrix_in, summation, move_data)
Merge submatrices of matrix_in to matrix_out by sum.
subroutine, public dbt_tas_replicate(matrix_in, info, matrix_out, nodata, move_data)
Replicate matrix_in such that each submatrix of matrix_out is an exact copy of matrix_in.
recursive subroutine, public dbt_tas_reshape(matrix_in, matrix_out, summation, transposed, move_data)
copy data (involves reshape)
methods to split tall-and-skinny matrices along longest dimension. Basically, we are splitting proces...
subroutine, public dbt_tas_release_info(split_info)
...
integer, parameter, public rowsplit
subroutine, public dbt_tas_get_split_info(info, mp_comm, nsplit, igroup, mp_comm_group, split_rowcol, pgrid_offset)
Get info on split.
integer, parameter, public colsplit
subroutine, public dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit, own_comm, opt_nsplit)
Split Cartesian process grid using a default split heuristic.
real(dp), parameter, public default_nsplit_accept_ratio
subroutine, public dbt_tas_info_hold(split_info)
...
logical function, public accept_pgrid_dims(dims, relative)
Whether to accept proposed process grid dimensions (based on ratio of dimensions)
DBT tall-and-skinny base types. Mostly wrappers around existing DBM routines.
often used utilities for tall-and-skinny matrices
Defines the basic variable types.
integer, parameter, public int_8
integer, parameter, public dp
integer, parameter, public default_string_length
Interface to the message passing library MPI.
type for blocks of size one
type for arbitrary distributions
type for cyclic (round robin) distribution: