#include "../../offload/offload_runtime.h"
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_GRID)
#define GRID_DO_COLLOCATE 0
#include "../common/grid_common.h"
#include "grid_gpu_collint.h"
#include "../common/grid_process_vab.h"
#if defined(_OMP_H)
#error "OpenMP should not be used in .cu files to accommodate HIP."
#endif
#define GRID_N_CXYZ_REGISTERS 20
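/*
 * The coefficients cxyz are processed in batches of GRID_N_CXYZ_REGISTERS.
 * Quick sanity check of the batch size: ncoset(l) = (l+1)(l+2)(l+3)/6 counts
 * all Cartesian monomials up to angular momentum l, hence ncoset(3) = 20
 * fits into a single register batch, while ncoset(4) = 35 already requires
 * two batches (offsets 0 and 20).
 */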
/*******************************************************************************
 * \brief Adds the given value to the register selected by index.
 *        The explicit switch keeps every index a compile-time constant.
 ******************************************************************************/
__device__ static inline void
add_to_register(const double value, const int index, cxyz_store *store) {
  switch (index) {
  case 0:  store->regs[0] += value;  break;
  case 1:  store->regs[1] += value;  break;
  case 2:  store->regs[2] += value;  break;
  case 3:  store->regs[3] += value;  break;
  case 4:  store->regs[4] += value;  break;
  case 5:  store->regs[5] += value;  break;
  case 6:  store->regs[6] += value;  break;
  case 7:  store->regs[7] += value;  break;
  case 8:  store->regs[8] += value;  break;
  case 9:  store->regs[9] += value;  break;
  case 10: store->regs[10] += value; break;
  case 11: store->regs[11] += value; break;
  case 12: store->regs[12] += value; break;
  case 13: store->regs[13] += value; break;
  case 14: store->regs[14] += value; break;
  case 15: store->regs[15] += value; break;
  case 16: store->regs[16] += value; break;
  case 17: store->regs[17] += value; break;
  case 18: store->regs[18] += value; break;
  case 19: store->regs[19] += value; break;
  }
}
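// Note: a plain "store->regs[index] += value" would also compile, but CUDA
// places arrays that are indexed with runtime values into (slow) local
// memory. The switch above trades code size for register-resident storage.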
/*******************************************************************************
 * \brief Adds the contribution of one grid point to the cxyz registers.
 ******************************************************************************/
__device__ static void
gridpoint_to_cxyz(const double dx, const double dy, const double dz,
                  const double zetp, const int lp, const double *gridpoint,
                  cxyz_store *store) {

  // Squared distance from the Gaussian's center and the resulting weight.
  const double r2 = dx * dx + dy * dy + dz * dz;
  const double gaussian = exp(-zetp * r2);

  // Load the grid point via the read-only cache while exp() completes.
  const double prefactor = __ldg(gridpoint) * gaussian;
  // The loops below are unrolled by hand so that every register access uses
  // a compile-time constant index.
  if (store->offset == 0) {
    store->regs[0] += prefactor; // lp: 0
    if (lp >= 1) {
      store->regs[1] += prefactor * dx; // lp: 1
      store->regs[2] += prefactor * dy;
      store->regs[3] += prefactor * dz;
      if (lp >= 2) {
        store->regs[4] += prefactor * dx * dx; // lp: 2
        store->regs[5] += prefactor * dx * dy;
        store->regs[6] += prefactor * dx * dz;
        store->regs[7] += prefactor * dy * dy;
        store->regs[8] += prefactor * dy * dz;
        store->regs[9] += prefactor * dz * dz;
        if (lp >= 3) {
          store->regs[10] += prefactor * dx * dx * dx; // lp: 3
          store->regs[11] += prefactor * dx * dx * dy;
          store->regs[12] += prefactor * dx * dx * dz;
          store->regs[13] += prefactor * dx * dy * dy;
          store->regs[14] += prefactor * dx * dy * dz;
          store->regs[15] += prefactor * dx * dz * dz;
          store->regs[16] += prefactor * dy * dy * dy;
          store->regs[17] += prefactor * dy * dy * dz;
          store->regs[18] += prefactor * dy * dz * dz;
          store->regs[19] += prefactor * dz * dz * dz;
        }
      }
    }

  } else if (store->offset == 20) {
    store->regs[0] += prefactor * dx * dx * dx * dx; // lp: 4
    store->regs[1] += prefactor * dx * dx * dx * dy;
    store->regs[2] += prefactor * dx * dx * dx * dz;
    store->regs[3] += prefactor * dx * dx * dy * dy;
    store->regs[4] += prefactor * dx * dx * dy * dz;
    store->regs[5] += prefactor * dx * dx * dz * dz;
    store->regs[6] += prefactor * dx * dy * dy * dy;
    store->regs[7] += prefactor * dx * dy * dy * dz;
    store->regs[8] += prefactor * dx * dy * dz * dz;
    store->regs[9] += prefactor * dx * dz * dz * dz;
    store->regs[10] += prefactor * dy * dy * dy * dy;
    store->regs[11] += prefactor * dy * dy * dy * dz;
    store->regs[12] += prefactor * dy * dy * dz * dz;
    store->regs[13] += prefactor * dy * dz * dz * dz;
    store->regs[14] += prefactor * dz * dz * dz * dz;
    if (lp >= 5) {
      store->regs[15] += prefactor * dx * dx * dx * dx * dx; // lp: 5
      store->regs[16] += prefactor * dx * dx * dx * dx * dy;
      store->regs[17] += prefactor * dx * dx * dx * dx * dz;
      store->regs[18] += prefactor * dx * dx * dx * dy * dy;
      store->regs[19] += prefactor * dx * dx * dx * dy * dz;
    }

  } else if (store->offset == 40) {
    store->regs[0] += prefactor * dx * dx * dx * dz * dz; // lp: 5
    store->regs[1] += prefactor * dx * dx * dy * dy * dy;
    store->regs[2] += prefactor * dx * dx * dy * dy * dz;
    store->regs[3] += prefactor * dx * dx * dy * dz * dz;
    store->regs[4] += prefactor * dx * dx * dz * dz * dz;
    store->regs[5] += prefactor * dx * dy * dy * dy * dy;
    store->regs[6] += prefactor * dx * dy * dy * dy * dz;
    store->regs[7] += prefactor * dx * dy * dy * dz * dz;
    store->regs[8] += prefactor * dx * dy * dz * dz * dz;
    store->regs[9] += prefactor * dx * dz * dz * dz * dz;
    store->regs[10] += prefactor * dy * dy * dy * dy * dy;
    store->regs[11] += prefactor * dy * dy * dy * dy * dz;
    store->regs[12] += prefactor * dy * dy * dy * dz * dz;
    store->regs[13] += prefactor * dy * dy * dz * dz * dz;
    store->regs[14] += prefactor * dy * dz * dz * dz * dz;
    store->regs[15] += prefactor * dz * dz * dz * dz * dz;
    if (lp >= 6) {
      store->regs[16] += prefactor * dx * dx * dx * dx * dx * dx; // lp: 6
      store->regs[17] += prefactor * dx * dx * dx * dx * dx * dy;
      store->regs[18] += prefactor * dx * dx * dx * dx * dx * dz;
      store->regs[19] += prefactor * dx * dx * dx * dx * dy * dy;
    }

  } else if (store->offset == 60) {
    store->regs[0] += prefactor * dx * dx * dx * dx * dy * dz; // lp: 6
    store->regs[1] += prefactor * dx * dx * dx * dx * dz * dz;
    store->regs[2] += prefactor * dx * dx * dx * dy * dy * dy;
    store->regs[3] += prefactor * dx * dx * dx * dy * dy * dz;
    store->regs[4] += prefactor * dx * dx * dx * dy * dz * dz;
    store->regs[5] += prefactor * dx * dx * dx * dz * dz * dz;
    store->regs[6] += prefactor * dx * dx * dy * dy * dy * dy;
    store->regs[7] += prefactor * dx * dx * dy * dy * dy * dz;
    store->regs[8] += prefactor * dx * dx * dy * dy * dz * dz;
    store->regs[9] += prefactor * dx * dx * dy * dz * dz * dz;
    store->regs[10] += prefactor * dx * dx * dz * dz * dz * dz;
    store->regs[11] += prefactor * dx * dy * dy * dy * dy * dy;
    store->regs[12] += prefactor * dx * dy * dy * dy * dy * dz;
    store->regs[13] += prefactor * dx * dy * dy * dy * dz * dz;
    store->regs[14] += prefactor * dx * dy * dy * dz * dz * dz;
    store->regs[15] += prefactor * dx * dy * dz * dz * dz * dz;
    store->regs[16] += prefactor * dx * dz * dz * dz * dz * dz;
    store->regs[17] += prefactor * dy * dy * dy * dy * dy * dy;
    store->regs[18] += prefactor * dy * dy * dy * dy * dy * dz;
    store->regs[19] += prefactor * dy * dy * dy * dy * dz * dz;

  } else if (store->offset == 80) {
    store->regs[0] += prefactor * dy * dy * dy * dz * dz * dz; // lp: 6
    store->regs[1] += prefactor * dy * dy * dz * dz * dz * dz;
    store->regs[2] += prefactor * dy * dz * dz * dz * dz * dz;
    store->regs[3] += prefactor * dz * dz * dz * dz * dz * dz;
    if (lp >= 7) {
      store->regs[4] += prefactor * dx * dx * dx * dx * dx * dx * dx; // lp: 7
      store->regs[5] += prefactor * dx * dx * dx * dx * dx * dx * dy;
      store->regs[6] += prefactor * dx * dx * dx * dx * dx * dx * dz;
      store->regs[7] += prefactor * dx * dx * dx * dx * dx * dy * dy;
      store->regs[8] += prefactor * dx * dx * dx * dx * dx * dy * dz;
      store->regs[9] += prefactor * dx * dx * dx * dx * dx * dz * dz;
      store->regs[10] += prefactor * dx * dx * dx * dx * dy * dy * dy;
      store->regs[11] += prefactor * dx * dx * dx * dx * dy * dy * dz;
      store->regs[12] += prefactor * dx * dx * dx * dx * dy * dz * dz;
      store->regs[13] += prefactor * dx * dx * dx * dx * dz * dz * dz;
      store->regs[14] += prefactor * dx * dx * dx * dy * dy * dy * dy;
      store->regs[15] += prefactor * dx * dx * dx * dy * dy * dy * dz;
      store->regs[16] += prefactor * dx * dx * dx * dy * dy * dz * dz;
      store->regs[17] += prefactor * dx * dx * dx * dy * dz * dz * dz;
      store->regs[18] += prefactor * dx * dx * dx * dz * dz * dz * dz;
      store->regs[19] += prefactor * dx * dx * dy * dy * dy * dy * dy;
    }

  } else {
    // Offsets above 80 take the generic (slower) path via a lookup table.
    for (int i = 0; i < GRID_N_CXYZ_REGISTERS; i++) {
      double val = prefactor;
      const orbital a = coset_inv[i + store->offset];
      for (int j = 0; j < a.l[0]; j++) {
        val *= dx;
      }
      for (int j = 0; j < a.l[1]; j++) {
        val *= dy;
      }
      for (int j = 0; j < a.l[2]; j++) {
        val *= dz;
      }
      add_to_register(val, i, store);
    }
  }
}
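/*
 * Summary of the unrolled branches above (derivable from the term lists):
 *   offset  0: monomials  0..19 (lp 0..3 complete)
 *   offset 20: monomials 20..39 (all of lp 4, first five of lp 5)
 *   offset 40: monomials 40..59 (rest of lp 5, first four of lp 6)
 *   offset 60: monomials 60..79 (middle of lp 6)
 *   offset 80: monomials 80..99 (rest of lp 6, first sixteen of lp 7)
 *   generic loop: any higher offset via the coset_inv lookup.
 */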
/*******************************************************************************
 * \brief Integrates the grid into the coefficients cxyz.
 ******************************************************************************/
__device__ static void grid_to_cxyz(const kernel_params *params,
                                    const smem_task *task, const double *grid,
                                    double *cxyz) {

  // Accumulate into registers, GRID_N_CXYZ_REGISTERS coefficients at a time.
  for (int offset = 0; offset < ncoset(task->lp);
       offset += GRID_N_CXYZ_REGISTERS) {

    double cxyz_regs[GRID_N_CXYZ_REGISTERS] = {0.0};
    cxyz_store store = {.regs = cxyz_regs, .offset = offset};

    if (task->use_orthorhombic_kernel) {
      ortho_cxyz_to_grid(params, task, &store, grid);
    } else {
      general_cxyz_to_grid(params, task, &store, grid);
    }

    // Add the register batch to the coefficients kept in shared memory.
    for (int i = 0; i < GRID_N_CXYZ_REGISTERS; i++) {
      if (i + offset < ncoset(task->lp)) {
        atomicAddDouble(&cxyz[i + offset], cxyz_regs[i]);
      }
    }
  }
  __syncthreads(); // because of concurrent writes to cxyz
}
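/*
 * Design note: each thread walks its share of grid points with the register
 * batch above, so shared memory sees only one atomicAddDouble per coefficient
 * and batch instead of one per grid point.
 */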
/*******************************************************************************
 * \brief Transforms cab into hab, going from primitive cartesian to
 *        contracted spherical basis functions, and adds it to the hab block.
 ******************************************************************************/
template <bool COMPUTE_TAU>
__device__ static void store_hab(const smem_task *task, const cab_store *cab) {

  // This is a double matrix product. Since the blocks can be quite large the
  // two products are fused to conserve shared memory.
  for (int i = threadIdx.x; i < task->nsgf_setb; i += blockDim.x) {
    for (int j = threadIdx.y; j < task->nsgf_seta; j += blockDim.y) {
      double block_val = 0.0;
      const int jco_start = ncoset(task->lb_min_basis - 1) + threadIdx.z;
      const int jco_end = ncoset(task->lb_max_basis);
      for (int jco = jco_start; jco < jco_end; jco += blockDim.z) {
        const orbital b = coset_inv[jco];
        const double sphib = task->sphib[i * task->maxcob + jco];
        const int ico_start = ncoset(task->la_min_basis - 1);
        const int ico_end = ncoset(task->la_max_basis);
        for (int ico = ico_start; ico < ico_end; ico++) {
          const orbital a = coset_inv[ico];
          const double hab =
              get_hab(a, b, task->zeta, task->zetb, cab, COMPUTE_TAU);
          const double sphia = task->sphia[j * task->maxcoa + ico];
          block_val += hab * sphia * sphib;
        }
      }
      if (task->block_transposed) {
        atomicAddDouble(&task->hab_block[j * task->nsgfb + i], block_val);
      } else {
        atomicAddDouble(&task->hab_block[i * task->nsgfa + j], block_val);
      }
    }
  }
}
/*******************************************************************************
 * \brief Adds contributions from cab to the forces and the virial.
 ******************************************************************************/
template <bool COMPUTE_TAU>
__device__ static void store_forces_and_virial(const kernel_params *params,
                                               const smem_task *task,
                                               const cab_store *cab) {

  for (int i = threadIdx.x; i < task->nsgf_setb; i += blockDim.x) {
    for (int j = threadIdx.y; j < task->nsgf_seta; j += blockDim.y) {
      double block_val;
      if (task->block_transposed) {
        block_val = task->pab_block[j * task->nsgfb + i] * task->off_diag_twice;
      } else {
        block_val = task->pab_block[i * task->nsgfa + j] * task->off_diag_twice;
      }
      const int jco_start = ncoset(task->lb_min_basis - 1) + threadIdx.z;
      const int jco_end = ncoset(task->lb_max_basis);
      for (int jco = jco_start; jco < jco_end; jco += blockDim.z) {
        const double sphib = task->sphib[i * task->maxcob + jco];
        const int ico_start = ncoset(task->la_min_basis - 1);
        const int ico_end = ncoset(task->la_max_basis);
        for (int ico = ico_start; ico < ico_end; ico++) {
          const double sphia = task->sphia[j * task->maxcoa + ico];
          const double pabval = block_val * sphia * sphib;
          const orbital b = coset_inv[jco];
          const orbital a = coset_inv[ico];
          for (int k = 0; k < 3; k++) {
            const double force_a =
                get_force_a(a, b, k, task->zeta, task->zetb, cab, COMPUTE_TAU);
            atomicAddDouble(&task->forces_a[k], force_a * pabval);
            const double force_b = get_force_b(a, b, k, task->zeta, task->zetb,
                                               task->rab, cab, COMPUTE_TAU);
            atomicAddDouble(&task->forces_b[k], force_b * pabval);
          }
          if (params->virial != NULL) {
            for (int k = 0; k < 3; k++) {
              for (int l = 0; l < 3; l++) {
                const double virial_a = get_virial_a(
                    a, b, k, l, task->zeta, task->zetb, cab, COMPUTE_TAU);
                const double virial_b =
                    get_virial_b(a, b, k, l, task->zeta, task->zetb, task->rab,
                                 cab, COMPUTE_TAU);
                const double virial = pabval * (virial_a + virial_b);
                atomicAddDouble(&params->virial[k * 3 + l], virial);
              }
            }
          }
        }
      }
    }
  }
}
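// The factor task->off_diag_twice presumably compensates for the symmetric
// storage of the density matrix: off-diagonal pab blocks appear only once in
// memory but contribute twice to the forces and the virial.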
/*******************************************************************************
 * \brief Initializes the cxyz array in shared memory with zeros.
 ******************************************************************************/
__device__ static void zero_cxyz(const smem_task *task, double *cxyz) {
  if (threadIdx.z == 0 && threadIdx.y == 0) {
    for (int i = threadIdx.x; i < ncoset(task->lp); i += blockDim.x) {
      cxyz[i] = 0.0;
    }
  }
  __syncthreads(); // because of concurrent writes to cxyz
}
/*******************************************************************************
 * \brief Cuda kernel that integrates all tasks of one grid level.
 ******************************************************************************/
template <bool COMPUTE_TAU, bool CALCULATE_FORCES>
__device__ static void integrate_kernel(const kernel_params *params) {

  // Copy the task from global to shared memory and precompute some stuff.
  __shared__ smem_task task;
  load_task(params, &task);

  // Check if the radius is below the resolution of the grid.
  if (2.0 * task.radius < task.dh_max) {
    return; // nothing to do
  }

  // Allot the dynamic shared memory.
  extern __shared__ double shared_memory[];
  double *smem_cab = &shared_memory[params->smem_cab_offset];
  double *smem_alpha = &shared_memory[params->smem_alpha_offset];
  double *smem_cxyz = &shared_memory[params->smem_cxyz_offset];

  // Fall back to global memory if cab does not fit into shared memory.
  cab_store cab;
  cab.n1 = task.n1;
  if (params->smem_cab_length < task.n1 * task.n2) {
    cab.data = malloc_cab(&task);
  } else {
    cab.data = smem_cab;
  }

  zero_cab(&cab, task.n1 * task.n2);
  compute_alpha(&task, smem_alpha);

  zero_cxyz(&task, smem_cxyz);
  grid_to_cxyz(params, &task, params->grid, smem_cxyz);
  cab_to_cxyz(&task, smem_alpha, &cab, smem_cxyz); // cxyz -> cab on integrate

  store_hab<COMPUTE_TAU>(&task, &cab);
  if (CALCULATE_FORCES) {
    store_forces_and_virial<COMPUTE_TAU>(params, &task, &cab);
  }

  if (params->smem_cab_length < task.n1 * task.n2) {
    free_cab(cab.data);
  }
}
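/*
 * Layout of the dynamic shared memory as carved up by the offsets above
 * (the offsets are set by the host code below; sizes in doubles):
 *
 *   | cab (cab_len) | alpha (alpha_len) | cxyz (cxyz_len) |
 *   ^ smem_cab_offset ^ smem_alpha_offset ^ smem_cxyz_offset
 */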
/*******************************************************************************
 * \brief Specialized Cuda kernel that can only integrate the density.
 ******************************************************************************/
__global__ static void grid_integrate_density(const kernel_params params) {
  integrate_kernel<false, false>(&params);
}

/*******************************************************************************
 * \brief Specialized Cuda kernel that can only integrate tau.
 ******************************************************************************/
__global__ static void grid_integrate_tau(const kernel_params params) {
  integrate_kernel<true, false>(&params);
}

/*******************************************************************************
 * \brief Specialized Cuda kernel that integrates the density and forces.
 ******************************************************************************/
__global__ static void
grid_integrate_density_forces(const kernel_params params) {
  integrate_kernel<false, true>(&params);
}

/*******************************************************************************
 * \brief Specialized Cuda kernel that integrates tau and forces.
 ******************************************************************************/
__global__ static void grid_integrate_tau_forces(const kernel_params params) {
  integrate_kernel<true, true>(&params);
}
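/*
 * The four wrappers above exist because COMPUTE_TAU and CALCULATE_FORCES are
 * template parameters: they must be fixed at compile time, which lets the
 * compiler prune the unused branches from each kernel instantiation. The
 * host-side launcher below picks the matching instantiation at runtime.
 */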
/*******************************************************************************
 * \brief Launches the Cuda kernel that integrates all tasks of one grid level.
 ******************************************************************************/
void grid_gpu_integrate_one_grid_level(
    const grid_gpu_task_list *task_list, const int first_task,
    const int last_task, const bool compute_tau, const grid_gpu_layout *layout,
    const offloadStream_t stream, const double *pab_blocks_dev,
    const double *grid_dev, double *hab_blocks_dev, double *forces_dev,
    double *virial_dev, int *lp_diff) {

  // Compute the required angular momentum ranges.
  const bool calculate_forces = (forces_dev != NULL);
  const bool calculate_virial = (virial_dev != NULL);
  assert(!calculate_virial || calculate_forces);
  const process_ldiffs ldiffs =
      process_get_ldiffs(calculate_forces, calculate_virial, compute_tau);
  *lp_diff = ldiffs.la_max_diff + ldiffs.lb_max_diff; // for reporting stats
  const int la_max = task_list->lmax + ldiffs.la_max_diff;
  const int lb_max = task_list->lmax + ldiffs.lb_max_diff;
  const int lp_max = la_max + lb_max;

  const int ntasks = last_task - first_task + 1;
  if (ntasks == 0) {
    return; // Nothing to do and lp_diff is already set.
  }

  init_constant_memory();
  // Compute the required shared memory.
  const int alpha_len = 3 * (lb_max + 1) * (la_max + 1) * (lp_max + 1);
  const int cxyz_len = ncoset(lp_max);
  const int cab_len = ncoset(lb_max) * ncoset(la_max);
  const size_t smem_per_block =
      (alpha_len + cxyz_len + cab_len) * sizeof(double);
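// Example budget (assuming the cab_len formula above): for la_max = lb_max =
// 3 one gets lp_max = 6, alpha_len = 3 * 4 * 4 * 7 = 336, cxyz_len =
// ncoset(6) = 84 and cab_len = 20 * 20 = 400, ie. 820 doubles or 6560 bytes
// of dynamic shared memory per block.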
  // Set up the kernel parameters.
  kernel_params params;
  params.smem_cab_length = cab_len;
  params.smem_cab_offset = 0;
  params.smem_alpha_offset = params.smem_cab_offset + cab_len;
  params.smem_cxyz_offset = params.smem_alpha_offset + alpha_len;
  params.first_task = first_task;
  params.grid = grid_dev;
  params.tasks = task_list->tasks_dev;
  params.pab_blocks = pab_blocks_dev;
  params.hab_blocks = hab_blocks_dev;
  params.forces = forces_dev;
  params.virial = virial_dev;
  params.la_min_diff = ldiffs.la_min_diff;
  params.lb_min_diff = ldiffs.lb_min_diff;
  params.la_max_diff = ldiffs.la_max_diff;
  params.lb_max_diff = ldiffs.lb_max_diff;
  memcpy(params.dh, layout->dh, 9 * sizeof(double));
  memcpy(params.dh_inv, layout->dh_inv, 9 * sizeof(double));
  memcpy(params.npts_global, layout->npts_global, 3 * sizeof(int));
  memcpy(params.npts_local, layout->npts_local, 3 * sizeof(int));
  memcpy(params.shift_local, layout->shift_local, 3 * sizeof(int));
  // Launch!
  const int nblocks = ntasks;
  const dim3 threads_per_block(4, 4, 4);
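// The 4 x 4 x 4 geometry (64 threads per block, one block per task) matches
// the device code above: threadIdx.x and threadIdx.y stride the spherical
// basis functions of sets b and a, while threadIdx.z strides the cartesian
// functions of set b.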
  if (!compute_tau && !calculate_forces) {
    grid_integrate_density<<<nblocks, threads_per_block, smem_per_block,
                             stream>>>(params);
  } else if (compute_tau && !calculate_forces) {
    grid_integrate_tau<<<nblocks, threads_per_block, smem_per_block, stream>>>(
        params);
  } else if (!compute_tau && calculate_forces) {
    grid_integrate_density_forces<<<nblocks, threads_per_block, smem_per_block,
                                    stream>>>(params);
  } else if (compute_tau && calculate_forces) {
    grid_integrate_tau_forces<<<nblocks, threads_per_block, smem_per_block,
                                stream>>>(params);
  }
  OFFLOAD_CHECK(offloadGetLastError());
}

#endif // defined(__OFFLOAD) && !defined(__NO_OFFLOAD_GRID)