9 USE iso_c_binding,
ONLY: c_null_ptr, &
11#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
13 USE iso_c_binding,
ONLY: c_associated, &
15 USE spla,
ONLY: spla_pu_host, &
19 spla_op_conj_transpose, &
26 spla_ctx_set_op_threshold_gpu, &
32#include "./base/base_uses.f90"
38 CHARACTER(len=*),
PARAMETER,
PRIVATE :: moduleN =
'local_gemm_api'
43 INTEGER,
PARAMETER,
PUBLIC :: &
47 INTEGER,
PRIVATE :: do_dgemm = 1
50 TYPE(c_ptr) :: spla_context = c_null_ptr
52 PROCEDURE, pass(ctx), non_overridable :: create => local_gemm_create
53 PROCEDURE, pass(ctx), non_overridable :: destroy => local_gemm_destroy
54 PROCEDURE, pass(ctx), non_overridable :: set_op_threshold_gpu => local_gemm_set_op_threshold_gpu
55 PROCEDURE, pass(ctx), non_overridable :: gemm =>
local_gemm
78 alpha, A, lda, B, ldb, &
80 CHARACTER,
INTENT(in) :: opA
81 CHARACTER,
INTENT(in) :: opB
82 INTEGER,
INTENT(in) :: m
83 INTEGER,
INTENT(in) :: n
84 INTEGER,
INTENT(in) :: k
85 REAL(8),
INTENT(in) :: alpha
86#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
87 REAL(8),
DIMENSION(*),
INTENT(in),
TARGET :: A
89 REAL(8),
DIMENSION(:, :),
INTENT(in),
TARGET :: A
91 INTEGER,
INTENT(in) :: lda
92#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
93 REAL(8),
DIMENSION(*),
INTENT(in),
TARGET :: B
95 REAL(8),
DIMENSION(:, :),
INTENT(in),
TARGET :: B
98 INTEGER,
INTENT(in) :: ldb
99 REAL(8),
INTENT(in) :: beta
100#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
101 REAL(8),
DIMENSION(*),
INTENT(inout),
TARGET ::C
103 REAL(8),
DIMENSION(:, :),
INTENT(inout),
TARGET :: C
105 INTEGER,
INTENT(in) :: ldc
110#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
111 INTEGER :: spla_op_A, spla_op_B, spla_error
113 CHARACTER(LEN=*),
PARAMETER :: routineN =
'local_gemm'
114 CALL timeset(routinen, handle)
117#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
120 IF (opa ==
'N') spla_op_a = spla_op_none
121 IF (opa ==
'T') spla_op_a = spla_op_transpose
123 IF (opb ==
'N') spla_op_b = spla_op_none
124 IF (opb ==
'T') spla_op_b = spla_op_transpose
127 cpassert(is_contiguous(a))
128 cpassert(is_contiguous(b))
129 cpassert(is_contiguous(c))
133 spla_error = spla_dgemm(spla_op_a, spla_op_b, &
137 beta, c_loc(c), ldc, ctx%spla_context)
138 cpassert(spla_error == spla_success)
141 CALL dgemm(opa, opb, m, n, k, alpha, &
143 b, ldb, beta, c, ldc)
144#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
149 CALL timestop(handle)
!> \brief Sets up the SPLA context handle stored in ctx%spla_context.
!> \param ctx context object whose spla_context component is created or reset
!> \param pu value passed through to spla_ctx_create
!>        (presumably local_gemm_pu_host or local_gemm_pu_gpu -- confirm against callers)
!> \note NOTE(review): this listing is missing lines (local declarations, the
!>       #else/#endif of the preprocessor guard, END IF), so which branch each
!>       reset to c_null_ptr belongs to cannot be confirmed from here.
158 SUBROUTINE local_gemm_create(ctx, pu)
160 INTEGER,
INTENT(in) :: pu
162#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
! Only act when no SPLA context is associated with the handle yet.
165 IF (.NOT. c_associated(ctx%spla_context))
THEN
! Create the context and assert that SPLA reported success.
169 error_ = spla_ctx_create(ctx%spla_context, pu)
170 cpassert(error_ == spla_success)
! Reset the handle to the null pointer (presumably an alternative branch -- confirm).
172 ctx%spla_context = c_null_ptr
! Reset the handle to the null pointer (presumably the build without SPLA offload -- confirm).
177 ctx%spla_context = c_null_ptr
179 END SUBROUTINE local_gemm_create
!> \brief Destroys the SPLA context held in ctx%spla_context and nulls the handle.
!> \param ctx context object whose spla_context component is destroyed and reset
!> \note NOTE(review): lines are missing from this listing (declarations, any guard
!>       around the destroy call, #endif), so the exact conditions under which the
!>       destroy call runs cannot be confirmed from here.
185 SUBROUTINE local_gemm_destroy(ctx)
188#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
! Destroy the SPLA context and assert that SPLA reported success.
194 error_ = spla_ctx_destroy(ctx%spla_context)
195 cpassert(error_ == spla_success)
! Leave the handle as the null pointer so it no longer refers to the destroyed context.
198 ctx%spla_context = c_null_ptr
199 END SUBROUTINE local_gemm_destroy
!> \brief Forwards a GPU operation threshold to the SPLA context stored in ctx.
!> \param ctx context object whose spla_context handle is passed to SPLA
!> \param opThresholdGPU threshold value handed to spla_ctx_set_op_threshold_gpu
!> \note NOTE(review): lines are missing from this listing (declarations,
!>       #else/#endif), so the branch structure below is only partly visible.
206 SUBROUTINE local_gemm_set_op_threshold_gpu(ctx, opThresholdGPU)
208 INTEGER,
INTENT(in) :: opThresholdGPU
210#if defined(__SPLA) && defined(__OFFLOAD_GEMM)
! NOTE(review): the status stored in error__ is not checked on any line visible
! here, unlike the cpassert after spla_ctx_create / spla_ctx_destroy -- confirm
! whether a failure should be asserted as well.
214 error__ = spla_ctx_set_op_threshold_gpu(ctx%spla_context, opthresholdgpu)
! Marks the argument as used (presumably the build without SPLA offload -- confirm).
217 mark_used(opthresholdgpu)
219 END SUBROUTINE local_gemm_set_op_threshold_gpu
226 INTEGER,
INTENT(IN) :: dgemm_library
228 do_dgemm = dgemm_library
static void dgemm(const char transa, const char transb, const int m, const int n, const int k, const double alpha, const double *a, const int lda, const double *b, const int ldb, const double beta, double *c, const int ldc)
Convenient wrapper to hide Fortran nature of dgemm_, swapping a and b.
subroutine, public local_gemm_set_library(dgemm_library)
...
integer, parameter, public local_gemm_pu_gpu
subroutine local_gemm(opa, opb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, ctx)
...
integer, parameter, public local_gemm_pu_host
Fortran API for the offload package, which is written in C.
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()