9 USE iso_c_binding,
ONLY: c_loc, &
12 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
14 USE spla,
ONLY: spla_pu_host, &
18 spla_op_conj_transpose, &
25 spla_ctx_set_op_threshold_gpu, &
31 #include "./base/base_uses.f90"
37 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'local_gemm_api'
45 INTEGER,
PARAMETER,
PUBLIC :: &
49 INTEGER,
PRIVATE :: do_dgemm = 1
71 alpha, A, lda, B, ldb, &
73 CHARACTER,
INTENT(in) :: opa
74 CHARACTER,
INTENT(in) :: opb
75 INTEGER,
INTENT(in) :: m
76 INTEGER,
INTENT(in) :: n
77 INTEGER,
INTENT(in) :: k
78 REAL(8),
INTENT(in) :: alpha
79 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
80 REAL(8),
DIMENSION(*),
INTENT(in),
TARGET :: a
82 REAL(8),
DIMENSION(:, :),
INTENT(in),
TARGET :: a
84 INTEGER,
INTENT(in) :: lda
85 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
86 REAL(8),
DIMENSION(*),
INTENT(in),
TARGET :: b
88 REAL(8),
DIMENSION(:, :),
INTENT(in),
TARGET :: b
91 INTEGER,
INTENT(in) :: ldb
92 REAL(8),
INTENT(in) :: beta
93 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
94 REAL(8),
DIMENSION(*),
INTENT(inout),
TARGET ::c
96 REAL(8),
DIMENSION(:, :),
INTENT(inout),
TARGET :: c
98 INTEGER,
INTENT(in) :: ldc
99 TYPE(c_ptr),
OPTIONAL,
INTENT(inout) :: ctx
103 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
104 INTEGER :: spla_op_a, spla_op_b, spla_error
106 CHARACTER(LEN=*),
PARAMETER :: routinen =
'local_gemm'
107 CALL timeset(routinen, handle)
110 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
113 IF (opa ==
'N') spla_op_a = spla_op_none
114 IF (opa ==
'T') spla_op_a = spla_op_transpose
116 IF (opb ==
'N') spla_op_b = spla_op_none
117 IF (opb ==
'T') spla_op_b = spla_op_transpose
120 cpassert(is_contiguous(a))
121 cpassert(is_contiguous(b))
122 cpassert(is_contiguous(c))
126 spla_error = spla_dgemm(spla_op_a, spla_op_b, &
130 beta, c_loc(c), ldc, ctx)
131 cpassert(spla_error == spla_success)
134 CALL dgemm(opa, opb, m, n, k, alpha, &
136 b, ldb, beta, c, ldc)
137 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
142 CALL timestop(handle)
152 TYPE(c_ptr),
INTENT(out) :: ctx
153 INTEGER,
INTENT(in) :: pu
155 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
161 error_ = spla_ctx_create(ctx, pu)
162 cpassert(error_ == spla_success)
178 TYPE(c_ptr),
INTENT(inout) :: ctx
180 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
186 error_ = spla_ctx_destroy(ctx)
187 cpassert(error_ == spla_success)
202 INTEGER,
INTENT(in) :: opthresholdgpu
204 #if defined(__SPLA) && defined(__OFFLOAD_GEMM)
208 error__ = spla_ctx_set_op_threshold_gpu(ctx, opthresholdgpu)
211 mark_used(opthresholdgpu)
220 INTEGER,
INTENT(IN) :: dgemm_library
222 do_dgemm = dgemm_library
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenient wrapper to hide Fortran nature of dgemm_, swapping a and b.
subroutine, public local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
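passes opThresholdGPU to spla_ctx_set_op_threshold_gpu for the given context (only when built with __SPLA and __OFFLOAD_GEMM)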
subroutine, public local_gemm_set_library(dgemm_library)
stores dgemm_library in the private module variable do_dgemm
integer, parameter, public local_gemm_pu_gpu
subroutine, public local_gemm_destroy(ctx)
release resources associated with a gemm context
subroutine, public local_gemm_create(ctx, pu)
create a context for handling gemm offloading
subroutine, public local_gemm(opA, opB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, ctx)
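double-precision matrix multiply that calls spla_dgemm (in builds with __SPLA and __OFFLOAD_GEMM) or dgemm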
integer, parameter, public local_gemm_pu_host
Fortran API for the offload package, which is written in C.
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()