#include "../../offload/offload_runtime.h"
#if defined(__OFFLOAD) && !defined(__NO_OFFLOAD_GRID)

#define GRID_DO_COLLOCATE 0
#include "../common/grid_common.h"
#include "../common/grid_process_vab.h"

#if defined(_OMP_H)
#error "OpenMP should not be used in .cu files to accommodate HIP."
#endif

#define GRID_N_CXYZ_REGISTERS 10
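// Each thread accumulates the cxyz coefficients in chunks of
// GRID_N_CXYZ_REGISTERS plain registers, i.e. a task with ncoset(lp)
// coefficients takes ceil(ncoset(lp) / 10) passes over the grid points
// (see grid_to_cxyz below).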
/*******************************************************************************
 * \brief Adds the given value to the specified cxyz register.
 ******************************************************************************/
__device__ static inline void add_to_register(const double value,
                                              const int index,
                                              cxyz_store *store) {
  // Dispatch on a compile-time-constant index for each case.
  switch (index) {
  case 0: store->regs[0] += value; break;
  case 1: store->regs[1] += value; break;
  case 2: store->regs[2] += value; break;
  case 3: store->regs[3] += value; break;
  case 4: store->regs[4] += value; break;
  case 5: store->regs[5] += value; break;
  case 6: store->regs[6] += value; break;
  case 7: store->regs[7] += value; break;
  case 8: store->regs[8] += value; break;
  case 9: store->regs[9] += value; break;
  }
}
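// Why the switch above: indexing a thread-local array with a runtime value,
// e.g.
//
//   double regs[GRID_N_CXYZ_REGISTERS];
//   regs[index] += value; // index unknown at compile time
//
// forces the compiler to spill the array from registers to the much slower
// local memory. With compile-time-constant indices the array stays in
// registers.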
/*******************************************************************************
 * \brief Accumulates a single grid point into the cxyz registers.
 ******************************************************************************/
__device__ static void gridpoint_to_cxyz(const double dx, const double dy,
                                         const double dz, const double zetp,
                                         const int lp, const double *gridpoint,
                                         cxyz_store *store) {

  // Squared distance from the Gaussian's center and resulting weight.
  const double r2 = dx * dx + dy * dy + dz * dz;
  const double gaussian = exp(-zetp * r2);

  // Load the grid point through the read-only data cache.
  const double prefactor = __ldg(gridpoint) * gaussian;
  if (store->offset == 0) {
    store->regs[0] += prefactor;

    store->regs[1] += prefactor * dx;
    store->regs[2] += prefactor * dy;
    store->regs[3] += prefactor * dz;

    store->regs[4] += prefactor * dx * dx;
    store->regs[5] += prefactor * dx * dy;
    store->regs[6] += prefactor * dx * dz;
    store->regs[7] += prefactor * dy * dy;
    store->regs[8] += prefactor * dy * dz;
    store->regs[9] += prefactor * dz * dz;
  } else if (store->offset == 10) {
    store->regs[0] += prefactor * dx * dx * dx;
    store->regs[1] += prefactor * dx * dx * dy;
    store->regs[2] += prefactor * dx * dx * dz;
    store->regs[3] += prefactor * dx * dy * dy;
    store->regs[4] += prefactor * dx * dy * dz;
    store->regs[5] += prefactor * dx * dz * dz;
    store->regs[6] += prefactor * dy * dy * dy;
    store->regs[7] += prefactor * dy * dy * dz;
    store->regs[8] += prefactor * dy * dz * dz;
    store->regs[9] += prefactor * dz * dz * dz;
  } else if (store->offset == 20) {
    store->regs[0] += prefactor * dx * dx * dx * dx;
    store->regs[1] += prefactor * dx * dx * dx * dy;
    store->regs[2] += prefactor * dx * dx * dx * dz;
    store->regs[3] += prefactor * dx * dx * dy * dy;
    store->regs[4] += prefactor * dx * dx * dy * dz;
    store->regs[5] += prefactor * dx * dx * dz * dz;
    store->regs[6] += prefactor * dx * dy * dy * dy;
    store->regs[7] += prefactor * dx * dy * dy * dz;
    store->regs[8] += prefactor * dx * dy * dz * dz;
    store->regs[9] += prefactor * dx * dz * dz * dz;
  } else if (store->offset == 30) {
    store->regs[0] += prefactor * dy * dy * dy * dy;
    store->regs[1] += prefactor * dy * dy * dy * dz;
    store->regs[2] += prefactor * dy * dy * dz * dz;
    store->regs[3] += prefactor * dy * dz * dz * dz;
    store->regs[4] += prefactor * dz * dz * dz * dz;

    store->regs[5] += prefactor * dx * dx * dx * dx * dx;
    store->regs[6] += prefactor * dx * dx * dx * dx * dy;
    store->regs[7] += prefactor * dx * dx * dx * dx * dz;
    store->regs[8] += prefactor * dx * dx * dx * dy * dy;
    store->regs[9] += prefactor * dx * dx * dx * dy * dz;
  } else if (store->offset == 40) {
    store->regs[0] += prefactor * dx * dx * dx * dz * dz;
    store->regs[1] += prefactor * dx * dx * dy * dy * dy;
    store->regs[2] += prefactor * dx * dx * dy * dy * dz;
    store->regs[3] += prefactor * dx * dx * dy * dz * dz;
    store->regs[4] += prefactor * dx * dx * dz * dz * dz;
    store->regs[5] += prefactor * dx * dy * dy * dy * dy;
    store->regs[6] += prefactor * dx * dy * dy * dy * dz;
    store->regs[7] += prefactor * dx * dy * dy * dz * dz;
    store->regs[8] += prefactor * dx * dy * dz * dz * dz;
    store->regs[9] += prefactor * dx * dz * dz * dz * dz;
  } else if (store->offset == 50) {
    store->regs[0] += prefactor * dy * dy * dy * dy * dy;
    store->regs[1] += prefactor * dy * dy * dy * dy * dz;
    store->regs[2] += prefactor * dy * dy * dy * dz * dz;
    store->regs[3] += prefactor * dy * dy * dz * dz * dz;
    store->regs[4] += prefactor * dy * dz * dz * dz * dz;
    store->regs[5] += prefactor * dz * dz * dz * dz * dz;

    store->regs[6] += prefactor * dx * dx * dx * dx * dx * dx;
    store->regs[7] += prefactor * dx * dx * dx * dx * dx * dy;
    store->regs[8] += prefactor * dx * dx * dx * dx * dx * dz;
    store->regs[9] += prefactor * dx * dx * dx * dx * dy * dy;
  } else if (store->offset == 60) {
    store->regs[0] += prefactor * dx * dx * dx * dx * dy * dz;
    store->regs[1] += prefactor * dx * dx * dx * dx * dz * dz;
    store->regs[2] += prefactor * dx * dx * dx * dy * dy * dy;
    store->regs[3] += prefactor * dx * dx * dx * dy * dy * dz;
    store->regs[4] += prefactor * dx * dx * dx * dy * dz * dz;
    store->regs[5] += prefactor * dx * dx * dx * dz * dz * dz;
    store->regs[6] += prefactor * dx * dx * dy * dy * dy * dy;
    store->regs[7] += prefactor * dx * dx * dy * dy * dy * dz;
    store->regs[8] += prefactor * dx * dx * dy * dy * dz * dz;
    store->regs[9] += prefactor * dx * dx * dy * dz * dz * dz;
  } else if (store->offset == 70) {
    store->regs[0] += prefactor * dx * dx * dz * dz * dz * dz;
    store->regs[1] += prefactor * dx * dy * dy * dy * dy * dy;
    store->regs[2] += prefactor * dx * dy * dy * dy * dy * dz;
    store->regs[3] += prefactor * dx * dy * dy * dy * dz * dz;
    store->regs[4] += prefactor * dx * dy * dy * dz * dz * dz;
    store->regs[5] += prefactor * dx * dy * dz * dz * dz * dz;
    store->regs[6] += prefactor * dx * dz * dz * dz * dz * dz;
    store->regs[7] += prefactor * dy * dy * dy * dy * dy * dy;
    store->regs[8] += prefactor * dy * dy * dy * dy * dy * dz;
    store->regs[9] += prefactor * dy * dy * dy * dy * dz * dz;
  } else if (store->offset == 80) {
    store->regs[0] += prefactor * dy * dy * dy * dz * dz * dz;
    store->regs[1] += prefactor * dy * dy * dz * dz * dz * dz;
    store->regs[2] += prefactor * dy * dz * dz * dz * dz * dz;
    store->regs[3] += prefactor * dz * dz * dz * dz * dz * dz;

    store->regs[4] += prefactor * dx * dx * dx * dx * dx * dx * dx;
    store->regs[5] += prefactor * dx * dx * dx * dx * dx * dx * dy;
    store->regs[6] += prefactor * dx * dx * dx * dx * dx * dx * dz;
    store->regs[7] += prefactor * dx * dx * dx * dx * dx * dy * dy;
    store->regs[8] += prefactor * dx * dx * dx * dx * dx * dy * dz;
    store->regs[9] += prefactor * dx * dx * dx * dx * dx * dz * dz;
  } else {
    // Generic fallback for the remaining offsets.
    for (int i = 0; i < GRID_N_CXYZ_REGISTERS; i++) {
      double val = prefactor;
      const orbital a = coset_inv[i + store->offset];
      for (int j = 0; j < a.l[0]; j++) {
        val *= dx;
      }
      for (int j = 0; j < a.l[1]; j++) {
        val *= dy;
      }
      for (int j = 0; j < a.l[2]; j++) {
        val *= dz;
      }
      add_to_register(val, i, store);
    }
  }
}
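// The unrolled branches above follow the "coset" ordering of cartesian
// monomials: there are ncoset(l) = (l+1)(l+2)(l+3)/6 monomials of total
// degree <= l, hence registers [0,10) hold everything up to l = 2
// (ncoset(2) = 10), [10,20) the ten l = 3 monomials, [20,35) the fifteen
// l = 4 monomials, and so on. Offsets beyond 80 require lp >= 7 and are
// rare enough to go through the generic coset_inv loop instead of further
// unrolled code.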
/*******************************************************************************
 * \brief Integrates the grid into the coefficients cxyz.
 ******************************************************************************/
__device__ static void grid_to_cxyz(const kernel_params *params,
                                    const smem_task *task, const double *grid,
                                    double *cxyz) {

  // Process the coefficients in chunks of GRID_N_CXYZ_REGISTERS.
  for (int offset = 0; offset < ncoset(task->lp);
       offset += GRID_N_CXYZ_REGISTERS) {

    double cxyz_regs[GRID_N_CXYZ_REGISTERS] = {0.0};
    cxyz_store store = {.regs = cxyz_regs, .offset = offset};

    if (task->use_orthorhombic_kernel) {
      ortho_cxyz_to_grid(task->lp, task->zetp, params->dh, params->dh_inv,
                         task->rp, params->npts_global, params->npts_local,
                         params->shift_local, task->radius, &store, grid);
    } else {
      general_cxyz_to_grid(task->border_mask, task->lp, task->zetp, params->dh,
                           params->dh_inv, task->rp, params->npts_global,
                           params->npts_local, params->shift_local,
                           params->border_width, task->radius, &store, grid);
    }

    // Add the thread-local register values to cxyz in shared memory.
    for (int i = 0; i < GRID_N_CXYZ_REGISTERS; i++) {
      if (i + offset < ncoset(task->lp)) {
        atomicAddDouble(&cxyz[i + offset], cxyz_regs[i]);
      }
    }
  }
  __syncthreads(); // Because of concurrent writes to cxyz.
}
/*******************************************************************************
 * \brief Transforms the cab matrix into the contracted spherical basis and
 *        adds the result to the hab block.
 ******************************************************************************/
template <bool COMPUTE_TAU>
__device__ static void store_hab(const smem_task *task, const cab_store *cab) {

  // This is a double matrix product; the two products are fused to conserve
  // shared memory.
  for (int i = threadIdx.x; i < task->nsgf_setb; i += blockDim.x) {
    for (int j = threadIdx.y; j < task->nsgf_seta; j += blockDim.y) {
      double block_val = 0.0;
      const int jco_start = ncoset(task->lb_min_basis - 1) + threadIdx.z;
      const int jco_end = ncoset(task->lb_max_basis);
      for (int jco = jco_start; jco < jco_end; jco += blockDim.z) {
        const orbital b = coset_inv[jco];
        const double sphib = task->sphib[i * task->maxcob + jco];
        const int ico_start = ncoset(task->la_min_basis - 1);
        const int ico_end = ncoset(task->la_max_basis);
        for (int ico = ico_start; ico < ico_end; ico++) {
          const orbital a = coset_inv[ico];
          const double hab =
              get_hab(a, b, task->zeta, task->zetb, cab, COMPUTE_TAU);
          const double sphia = task->sphia[j * task->maxcoa + ico];
          block_val += hab * sphia * sphib;
        }
      }
      if (task->block_transposed) {
        atomicAddDouble(&task->hab_block[j * task->nsgfb + i], block_val);
      } else {
        atomicAddDouble(&task->hab_block[i * task->nsgfa + j], block_val);
      }
    }
  }
}
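// In matrix form, the fused loops above compute per spherical pair (i, j):
//
//   hab_block(i,j) += sum_{ico,jco} sphia(j,ico) * hab(ico,jco) * sphib(i,jco)
//
// i.e. the contraction H_sph = S_a * H_cart * S_b^T, evaluated without
// materializing the intermediate product.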
/*******************************************************************************
 * \brief Adds contributions from cab to the forces and the virial.
 ******************************************************************************/
template <bool COMPUTE_TAU>
__device__ static void store_forces_and_virial(const kernel_params *params,
                                               const smem_task *task,
                                               const cab_store *cab) {

  for (int i = threadIdx.x; i < task->nsgf_setb; i += blockDim.x) {
    for (int j = threadIdx.y; j < task->nsgf_seta; j += blockDim.y) {
      double block_val;
      if (task->block_transposed) {
        block_val = task->pab_block[j * task->nsgfb + i] * task->off_diag_twice;
      } else {
        block_val = task->pab_block[i * task->nsgfa + j] * task->off_diag_twice;
      }
      const int jco_start = ncoset(task->lb_min_basis - 1) + threadIdx.z;
      const int jco_end = ncoset(task->lb_max_basis);
      for (int jco = jco_start; jco < jco_end; jco += blockDim.z) {
        const orbital b = coset_inv[jco];
        const double sphib = task->sphib[i * task->maxcob + jco];
        const int ico_start = ncoset(task->la_min_basis - 1);
        const int ico_end = ncoset(task->la_max_basis);
        for (int ico = ico_start; ico < ico_end; ico++) {
          const orbital a = coset_inv[ico];
          const double sphia = task->sphia[j * task->maxcoa + ico];
          const double pabval = block_val * sphia * sphib;
          for (int k = 0; k < 3; k++) {
            const double force_a =
                get_force_a(a, b, k, task->zeta, task->zetb, cab, COMPUTE_TAU);
            atomicAddDouble(&task->forces_a[k], force_a * pabval);
            const double force_b = get_force_b(a, b, k, task->zeta, task->zetb,
                                               task->rab, cab, COMPUTE_TAU);
            atomicAddDouble(&task->forces_b[k], force_b * pabval);
          }
          if (params->virial != NULL) {
            for (int k = 0; k < 3; k++) {
              for (int l = 0; l < 3; l++) {
                const double virial_a = get_virial_a(
                    a, b, k, l, task->zeta, task->zetb, cab, COMPUTE_TAU);
                const double virial_b =
                    get_virial_b(a, b, k, l, task->zeta, task->zetb, task->rab,
                                 cab, COMPUTE_TAU);
                const double virial = pabval * (virial_a + virial_b);
                atomicAddDouble(&params->virial[k * 3 + l], virial);
              }
            }
          }
        }
      }
    }
  }
}
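// Both the force accumulators and the 3x3 virial matrix are shared between
// all threads (and, across tasks, between thread blocks), hence every update
// goes through atomicAddDouble.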
/*******************************************************************************
 * \brief Initializes the cxyz accumulator in shared memory to zero.
 ******************************************************************************/
__device__ static void zero_cxyz(const smem_task *task, double *cxyz) {
  if (threadIdx.z == 0 && threadIdx.y == 0) {
    for (int i = threadIdx.x; i < ncoset(task->lp); i += blockDim.x) {
      cxyz[i] = 0.0;
    }
  }
  __syncthreads(); // Because of concurrent writes to cxyz.
}
/*******************************************************************************
 * \brief The actual integrate kernel, templated on the task flavor.
 ******************************************************************************/
template <bool COMPUTE_TAU, bool CALCULATE_FORCES>
__device__ static void integrate_kernel(const kernel_params *params) {

  // Copy the task from global to shared memory and precompute derived values.
  __shared__ smem_task task;
  load_task(params, &task);

  // Check if the radius is below the resolution of the grid.
  if (2.0 * task.radius < task.dh_max) {
    return; // nothing to do
  }

  // Allot the dynamic shared memory.
  extern __shared__ double shared_memory[];
  double *smem_cab = &shared_memory[params->smem_cab_offset];
  double *smem_alpha = &shared_memory[params->smem_alpha_offset];
  double *smem_cxyz = &shared_memory[params->smem_cxyz_offset];

  // Fall back to global memory when cab does not fit into shared memory.
  cab_store cab;
  if (params->smem_cab_length < task.n1 * task.n2) {
    cab.data = malloc_cab(&task);
  } else {
    cab.data = smem_cab;
  }

  zero_cab(&cab, task.n1 * task.n2);

  zero_cxyz(&task, smem_cxyz);
  grid_to_cxyz(params, &task, params->grid, smem_cxyz);

  // Transform the coefficients cxyz into cab.
  cab_to_cxyz(&task, smem_alpha, &cab, smem_cxyz);

  store_hab<COMPUTE_TAU>(&task, &cab);
  if (CALCULATE_FORCES) {
    store_forces_and_virial<COMPUTE_TAU>(params, &task, &cab);
  }

  if (params->smem_cab_length < task.n1 * task.n2) {
    free_cab(cab.data);
  }
}
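// Shared memory layout used by the kernel above (offsets are set up by the
// host in grid_gpu_integrate_one_grid_level):
//
//   shared_memory: [ cab (cab_len) | alpha (alpha_len) | cxyz (cxyz_len) ]
//
// with smem_cab_offset = 0, smem_alpha_offset = cab_len and
// smem_cxyz_offset = cab_len + alpha_len.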
/*******************************************************************************
 * \brief Cuda kernels for integrating all tasks of one grid level.
 ******************************************************************************/
__global__ static void grid_integrate_density(const kernel_params params) {
  integrate_kernel<false, false>(&params);
}

__global__ static void grid_integrate_tau(const kernel_params params) {
  integrate_kernel<true, false>(&params);
}

__global__ static void grid_integrate_density_forces(
    const kernel_params params) {
  integrate_kernel<false, true>(&params);
}

__global__ static void grid_integrate_tau_forces(const kernel_params params) {
  integrate_kernel<true, true>(&params);
}
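// A minimal usage sketch for the host-side entry point below; the variable
// names (first, last, *_dev, ...) are placeholders, not part of the API:
//
//   int lp_diff;
//   grid_gpu_integrate_one_grid_level(task_list, first, last, compute_tau,
//                                     layout, stream, pab_dev, grid_dev,
//                                     hab_dev, forces_dev, virial_dev,
//                                     &lp_diff);
//
// Passing NULL for forces_dev (and virial_dev) skips the force (and virial)
// computation.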
/*******************************************************************************
 * \brief Launches the Cuda kernel that integrates all tasks of one grid level.
 ******************************************************************************/
void grid_gpu_integrate_one_grid_level(
    const grid_gpu_task_list *task_list, const int first_task,
    const int last_task, const bool compute_tau, const grid_gpu_layout *layout,
    const offloadStream_t stream, const double *pab_blocks_dev,
    const double *grid_dev, double *hab_blocks_dev, double *forces_dev,
    double *virial_dev, int *lp_diff) {

  const bool calculate_forces = (forces_dev != NULL);
  const bool calculate_virial = (virial_dev != NULL);
  assert(!calculate_virial || calculate_forces);

  // Compute the required angular momentum ranges.
  const process_ldiffs ldiffs =
      process_get_ldiffs(calculate_forces, calculate_virial, compute_tau);
  *lp_diff = ldiffs.la_max_diff + ldiffs.lb_max_diff; // passed back to caller
  const int la_max = task_list->lmax + ldiffs.la_max_diff;
  const int lb_max = task_list->lmax + ldiffs.lb_max_diff;
  const int lp_max = la_max + lb_max;

  const int ntasks = last_task - first_task + 1;
  if (ntasks == 0) {
    return; // nothing to do
  }

  init_constant_memory();

  // Compute the required shared memory.
  const int alpha_len = 3 * (lb_max + 1) * (la_max + 1) * (lp_max + 1);
  const int cxyz_len = ncoset(lp_max);
  const int cab_len = ncoset(lb_max) * ncoset(la_max);
  const size_t smem_per_block =
      (alpha_len + cxyz_len + cab_len) * sizeof(double);
  kernel_params params;
  params.smem_cab_length = cab_len;
  params.smem_cab_offset = 0;
  params.smem_alpha_offset = params.smem_cab_offset + cab_len;
  params.smem_cxyz_offset = params.smem_alpha_offset + alpha_len;
  params.first_task = first_task;
  params.grid = grid_dev;
  params.tasks = task_list->tasks_dev;
  params.pab_blocks = pab_blocks_dev;
  params.hab_blocks = hab_blocks_dev;
  params.forces = forces_dev;
  params.virial = virial_dev;
  memcpy(params.dh, layout->dh, 9 * sizeof(double));
  memcpy(params.dh_inv, layout->dh_inv, 9 * sizeof(double));
  memcpy(params.npts_global, layout->npts_global, 3 * sizeof(int));
  memcpy(params.npts_local, layout->npts_local, 3 * sizeof(int));
  memcpy(params.shift_local, layout->shift_local, 3 * sizeof(int));
  const int nblocks = ntasks;
  const dim3 threads_per_block(4, 4, 4);
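  // One thread block per task; the 4x4x4 = 64 threads map onto the
  // threadIdx.x/y/z-strided loops in store_hab and store_forces_and_virial.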
  if (!compute_tau && !calculate_forces) {
    grid_integrate_density<<<nblocks, threads_per_block, smem_per_block,
                             stream>>>(params);
  } else if (compute_tau && !calculate_forces) {
    grid_integrate_tau<<<nblocks, threads_per_block, smem_per_block, stream>>>(
        params);
  } else if (!compute_tau && calculate_forces) {
    grid_integrate_density_forces<<<nblocks, threads_per_block, smem_per_block,
                                    stream>>>(params);
  } else if (compute_tau && calculate_forces) {
    grid_integrate_tau_forces<<<nblocks, threads_per_block, smem_per_block,
                                stream>>>(params);
  }
  OFFLOAD_CHECK(offloadGetLastError());
}

#endif // defined(__OFFLOAD) && !defined(__NO_OFFLOAD_GRID)