(git:1155b05)
Loading...
Searching...
No Matches
cp_fm_cusolver.c
Go to the documentation of this file.
1/*----------------------------------------------------------------------------*/
2/* CP2K: A general program to perform molecular dynamics simulations */
3/* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4/* */
5/* SPDX-License-Identifier: GPL-2.0-or-later */
6/*----------------------------------------------------------------------------*/
7
8#if defined(__CUSOLVERMP)
9
#include "../offload/offload_library.h"
#include <assert.h>
#include <cuda_runtime.h>
#include <cusolverMp.h>
#include <limits.h>
#include <math.h>
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
18
19#if defined(__CUSOLVERMP_NCCL)
20#include <nccl.h>
21#else
22#include <cal.h>
23#endif
24
/*******************************************************************************
 * \brief Run given CUDA command and upon failure abort with a nice message.
 * \author Ole Schuett
 ******************************************************************************/
#define CUDA_CHECK(cmd)                                                        \
  do {                                                                         \
    const cudaError_t err_ = (cmd);                                            \
    if (err_ != cudaSuccess) {                                                 \
      fprintf(stderr, "ERROR: %s %s %d\n", cudaGetErrorString(err_),           \
              __FILE__, __LINE__);                                             \
      abort();                                                                 \
    }                                                                          \
  } while (0)
38
39#if defined(__CUSOLVERMP_NCCL)
/*******************************************************************************
 * \brief Run given NCCL command and upon failure abort with a nice message.
 * \author Jiri Vyskocil
 ******************************************************************************/
#define NCCL_CHECK(cmd)                                                        \
  do {                                                                         \
    const ncclResult_t err_ = (cmd);                                           \
    if (err_ != ncclSuccess) {                                                 \
      fprintf(stderr, "ERROR: %s %s %d\n", ncclGetErrorString(err_),           \
              __FILE__, __LINE__);                                             \
      abort();                                                                 \
    }                                                                          \
  } while (0)
53
54#else
55/*******************************************************************************
56 * \brief Decode given cal error.
57 * \author Ole Schuett
58 ******************************************************************************/
59static char *calGetErrorString(calError_t status) {
60 switch (status) {
61 case CAL_OK:
62 return "CAL_OK";
63 case CAL_ERROR:
64 return "CAL_ERROR";
65 case CAL_ERROR_INVALID_PARAMETER:
66 return "CAL_ERROR_INVALID_PARAMETER";
67 case CAL_ERROR_INTERNAL:
68 return "CAL_ERROR_INTERNAL";
69 case CAL_ERROR_CUDA:
70 return "CAL_ERROR_CUDA";
71 case CAL_ERROR_UCC:
72 return "CAL_ERROR_UCC";
73 case CAL_ERROR_NOT_SUPPORTED:
74 return "CAL_ERROR_NOT_SUPPORTED";
75 case CAL_ERROR_INPROGRESS:
76 return "CAL_ERROR_INPROGRESS";
77 default:
78 return "CAL UNKNOWN ERROR";
79 }
80}
81
/*******************************************************************************
 * \brief Run given cal command and upon failure abort with a nice message.
 * \author Ole Schuett
 ******************************************************************************/
#define CAL_CHECK(cmd)                                                         \
  do {                                                                         \
    const calError_t err_ = (cmd);                                             \
    if (err_ != CAL_OK) {                                                      \
      fprintf(stderr, "ERROR: %s %s %d\n", calGetErrorString(err_),            \
              __FILE__, __LINE__);                                             \
      abort();                                                                 \
    }                                                                          \
  } while (0)
95
96/*******************************************************************************
97 * \brief Callback for cal library to initiate an allgather operation.
98 * \author Ole Schuett
99 ******************************************************************************/
100static calError_t allgather(void *src_buf, void *recv_buf, size_t size,
101 void *data, void **req) {
102 const MPI_Comm comm = *(MPI_Comm *)data;
103 MPI_Request *request = malloc(sizeof(MPI_Request));
104 *req = request;
105 const int status = MPI_Iallgather(src_buf, size, MPI_BYTE, recv_buf, size,
106 MPI_BYTE, comm, request);
107 return (status == MPI_SUCCESS) ? CAL_OK : CAL_ERROR;
108}
109
110/*******************************************************************************
111 * \brief Callback for cal library to test if a request has completed.
112 * \author Ole Schuett
113 ******************************************************************************/
114static calError_t req_test(void *req) {
115 MPI_Request *request = (MPI_Request *)(req);
116 int completed;
117 const int status = MPI_Test(request, &completed, MPI_STATUS_IGNORE);
118 if (status != MPI_SUCCESS) {
119 return CAL_ERROR;
120 }
121 return completed ? CAL_OK : CAL_ERROR_INPROGRESS;
122}
123
124/*******************************************************************************
125 * \brief Callback for cal library to free a request.
126 * \author Ole Schuett
127 ******************************************************************************/
128static calError_t req_free(void *req) {
129 free(req);
130 return CAL_OK;
131}
132#endif /* __CUSOLVERMP_NCCL */
133
134/*******************************************************************************
135 * \brief Decode given cusolver error.
136 * \author Ole Schuett
137 ******************************************************************************/
138static char *cusolverGetErrorString(cusolverStatus_t status) {
139 switch (status) {
140 case CUSOLVER_STATUS_SUCCESS:
141 return "CUSOLVER_STATUS_SUCCESS";
142 case CUSOLVER_STATUS_NOT_INITIALIZED:
143 return "CUSOLVER_STATUS_NOT_INITIALIZED";
144 case CUSOLVER_STATUS_ALLOC_FAILED:
145 return "CUSOLVER_STATUS_ALLOC_FAILED";
146 case CUSOLVER_STATUS_INVALID_VALUE:
147 return "CUSOLVER_STATUS_INVALID_VALUE";
148 case CUSOLVER_STATUS_ARCH_MISMATCH:
149 return "CUSOLVER_STATUS_ARCH_MISMATCH";
150 case CUSOLVER_STATUS_MAPPING_ERROR:
151 return "CUSOLVER_STATUS_MAPPING_ERROR";
152 case CUSOLVER_STATUS_EXECUTION_FAILED:
153 return "CUSOLVER_STATUS_EXECUTION_FAILED";
154 case CUSOLVER_STATUS_INTERNAL_ERROR:
155 return "CUSOLVER_STATUS_INTERNAL_ERROR";
156 case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
157 return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
158 case CUSOLVER_STATUS_NOT_SUPPORTED:
159 return "CUSOLVER_STATUS_NOT_SUPPORTED";
160 case CUSOLVER_STATUS_ZERO_PIVOT:
161 return "CUSOLVER_STATUS_ZERO_PIVOT";
162 case CUSOLVER_STATUS_INVALID_LICENSE:
163 return "CUSOLVER_STATUS_INVALID_LICENSE";
164 case CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED:
165 return "CUSOLVER_STATUS_IRS_PARAMS_NOT_INITIALIZED";
166 case CUSOLVER_STATUS_IRS_PARAMS_INVALID:
167 return "CUSOLVER_STATUS_IRS_PARAMS_INVALID";
168 case CUSOLVER_STATUS_IRS_PARAMS_INVALID_PREC:
169 return "CUSOLVER_STATUS_IRS_PARAMS_INVALID_PREC";
170 case CUSOLVER_STATUS_IRS_PARAMS_INVALID_REFINE:
171 return "CUSOLVER_STATUS_IRS_PARAMS_INVALID_REFINE";
172 case CUSOLVER_STATUS_IRS_PARAMS_INVALID_MAXITER:
173 return "CUSOLVER_STATUS_IRS_PARAMS_INVALID_MAXITER";
174 case CUSOLVER_STATUS_IRS_INTERNAL_ERROR:
175 return "CUSOLVER_STATUS_IRS_INTERNAL_ERROR";
176 case CUSOLVER_STATUS_IRS_NOT_SUPPORTED:
177 return "CUSOLVER_STATUS_IRS_NOT_SUPPORTED";
178 case CUSOLVER_STATUS_IRS_OUT_OF_RANGE:
179 return "CUSOLVER_STATUS_IRS_OUT_OF_RANGE";
180 case CUSOLVER_STATUS_IRS_NRHS_NOT_SUPPORTED_FOR_REFINE_GMRES:
181 return "CUSOLVER_STATUS_IRS_NRHS_NOT_SUPPORTED_FOR_REFINE_GMRES";
182 case CUSOLVER_STATUS_IRS_INFOS_NOT_INITIALIZED:
183 return "CUSOLVER_STATUS_IRS_INFOS_NOT_INITIALIZED";
184 case CUSOLVER_STATUS_IRS_INFOS_NOT_DESTROYED:
185 return "CUSOLVER_STATUS_IRS_INFOS_NOT_DESTROYED";
186 case CUSOLVER_STATUS_IRS_MATRIX_SINGULAR:
187 return "CUSOLVER_STATUS_IRS_MATRIX_SINGULAR";
188 case CUSOLVER_STATUS_INVALID_WORKSPACE:
189 return "CUSOLVER_STATUS_INVALID_WORKSPACE";
190 default:
191 return "CUSOLVER UNKNOWN ERROR";
192 }
193}
194
/*******************************************************************************
 * \brief Run given cusolver command and upon failure abort with a nice message.
 * \author Ole Schuett
 ******************************************************************************/
#define CUSOLVER_CHECK(cmd)                                                    \
  do {                                                                         \
    const cusolverStatus_t err_ = (cmd);                                       \
    if (err_ != CUSOLVER_STATUS_SUCCESS) {                                     \
      fprintf(stderr, "ERROR: %s %s %d\n", cusolverGetErrorString(err_),       \
              __FILE__, __LINE__);                                             \
      abort();                                                                 \
    }                                                                          \
  } while (0)
208
209/*******************************************************************************
210 * \brief Driver routine to diagonalize a matrix with the cuSOLVERMp library.
211 * \author Ole Schuett
212 ******************************************************************************/
213void cp_fm_diag_cusolver(const int fortran_comm, const int matrix_desc[9],
214 const int nprow, const int npcol, const int myprow,
215 const int mypcol, const int n, const double *matrix,
216 double *eigenvectors, double *eigenvalues) {
217
219 const int local_device = offload_get_chosen_device();
220
221 MPI_Comm comm = MPI_Comm_f2c(fortran_comm);
222 int rank, nranks;
223 MPI_Comm_rank(comm, &rank);
224 MPI_Comm_size(comm, &nranks);
225
226#if defined(__CUSOLVERMP_NCCL)
227 // Create NCCL communicator.
228 ncclUniqueId nccl_id;
229 if (rank == 0) {
230 NCCL_CHECK(ncclGetUniqueId(&nccl_id));
231 }
232 MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, comm);
233
234 ncclComm_t nccl_comm;
235 NCCL_CHECK(ncclCommInitRank(&nccl_comm, nranks, nccl_id, rank));
236#else
237 // Create CAL communicator.
238 cal_comm_t cal_comm = NULL;
239 cal_comm_create_params_t params;
240 params.allgather = &allgather;
241 params.req_test = &req_test;
242 params.req_free = &req_free;
243 params.data = &comm;
244 params.rank = rank;
245 params.nranks = nranks;
246 params.local_device = local_device;
247 CAL_CHECK(cal_comm_create(params, &cal_comm));
248#endif
249
250 // Create various handles.
251 cudaStream_t stream = NULL;
252 CUDA_CHECK(cudaStreamCreate(&stream));
253
254 cusolverMpHandle_t cusolvermp_handle = NULL;
255 CUSOLVER_CHECK(cusolverMpCreate(&cusolvermp_handle, local_device, stream));
256
257 cusolverMpGrid_t grid = NULL;
258#if defined(__CUSOLVERMP_NCCL)
259 CUSOLVER_CHECK(cusolverMpCreateDeviceGrid(cusolvermp_handle, &grid, nccl_comm,
260 nprow, npcol,
261 CUSOLVERMP_GRID_MAPPING_ROW_MAJOR));
262#else
263 CUSOLVER_CHECK(cusolverMpCreateDeviceGrid(cusolvermp_handle, &grid, cal_comm,
264 nprow, npcol,
265 CUSOLVERMP_GRID_MAPPING_ROW_MAJOR));
266#endif
267 const int mb = matrix_desc[4];
268 const int nb = matrix_desc[5];
269 const int rsrc = matrix_desc[6];
270 const int csrc = matrix_desc[7];
271 const int ldA = matrix_desc[8];
272 assert(rsrc == csrc);
273 assert(ldA >= 1);
274
275 const int np = cusolverMpNUMROC(n, mb, myprow, rsrc, nprow);
276 const int nq = cusolverMpNUMROC(n, nb, mypcol, csrc, npcol);
277 assert(np == ldA);
278
279 const cublasFillMode_t fill_mode = CUBLAS_FILL_MODE_UPPER;
280 const cudaDataType_t data_type = CUDA_R_64F; // double
281
282 cusolverMpMatrixDescriptor_t cusolvermp_matrix_desc = NULL;
283 CUSOLVER_CHECK(cusolverMpCreateMatrixDesc(
284 &cusolvermp_matrix_desc, grid, data_type, n, n, mb, nb, rsrc, csrc, np));
285
286 // Allocate workspaces.
287 size_t work_dev_size, work_host_size;
288 void *DUMMY = (void *)1; // Workaround to avoid crash when passing NULL.
289 CUSOLVER_CHECK(cusolverMpSyevd_bufferSize(
290 cusolvermp_handle, "V", fill_mode, n, DUMMY, 1, 1, cusolvermp_matrix_desc,
291 NULL, NULL, 1, 1, cusolvermp_matrix_desc, data_type, &work_dev_size,
292 &work_host_size));
293
294 double *work_dev = NULL;
295 CUDA_CHECK(cudaMalloc((void **)&work_dev, work_dev_size));
296
297 double *work_host = NULL;
298 CUDA_CHECK(cudaMallocHost((void **)&work_host, work_host_size));
299
300 // Upload input matrix.
301 const size_t matrix_local_size = ldA * nq * sizeof(double);
302 double *matrix_dev = NULL;
303 CUDA_CHECK(cudaMalloc((void **)&matrix_dev, matrix_local_size));
304 CUDA_CHECK(cudaMemcpyAsync(matrix_dev, matrix, matrix_local_size,
305 cudaMemcpyHostToDevice, stream));
306
307 // Allocate result buffers.
308 int *info_dev = NULL;
309 CUDA_CHECK(cudaMalloc((void **)&info_dev, sizeof(int)));
310
311 double *eigenvectors_dev = NULL;
312 CUDA_CHECK(cudaMalloc((void **)&eigenvectors_dev, matrix_local_size));
313
314 double *eigenvalues_dev = NULL;
315 CUDA_CHECK(cudaMalloc((void **)&eigenvalues_dev, n * sizeof(double)));
316
317 // Call solver.
318 CUSOLVER_CHECK(
319 cusolverMpSyevd(cusolvermp_handle, "V", fill_mode, n, matrix_dev, 1, 1,
320 cusolvermp_matrix_desc, eigenvalues_dev, eigenvectors_dev,
321 1, 1, cusolvermp_matrix_desc, data_type, work_dev,
322 work_dev_size, work_host, work_host_size, info_dev));
323
324 // Wait for solver to finish.
325 CUDA_CHECK(cudaStreamSynchronize(stream));
326#if !defined(__CUSOLVERMP_NCCL)
327 CAL_CHECK(cal_stream_sync(cal_comm, stream));
328#endif
329
330 // Check info.
331 int info = -1;
332 CUDA_CHECK(cudaMemcpy(&info, info_dev, sizeof(int), cudaMemcpyDeviceToHost));
333 assert(info == 0);
334
335 // Download results.
336 CUDA_CHECK(cudaMemcpyAsync(eigenvectors, eigenvectors_dev, matrix_local_size,
337 cudaMemcpyDeviceToHost, stream));
338 CUDA_CHECK(cudaMemcpyAsync(eigenvalues, eigenvalues_dev, n * sizeof(double),
339 cudaMemcpyDeviceToHost, stream));
340
341 // Wait for download to finish.
342 CUDA_CHECK(cudaStreamSynchronize(stream));
343
344 // Free buffers.
345 CUDA_CHECK(cudaFree(matrix_dev));
346 CUDA_CHECK(cudaFree(info_dev));
347 CUDA_CHECK(cudaFree(eigenvectors_dev));
348 CUDA_CHECK(cudaFree(eigenvalues_dev));
349 CUDA_CHECK(cudaFree(work_dev));
350 CUDA_CHECK(cudaFreeHost(work_host));
351
352 // Destroy handles.
353 CUSOLVER_CHECK(cusolverMpDestroyMatrixDesc(cusolvermp_matrix_desc));
354 CUSOLVER_CHECK(cusolverMpDestroyGrid(grid));
355 CUSOLVER_CHECK(cusolverMpDestroy(cusolvermp_handle));
356 CUDA_CHECK(cudaStreamDestroy(stream));
357#if defined(__CUSOLVERMP_NCCL)
358 NCCL_CHECK(ncclCommDestroy(nccl_comm));
359#else
360 CAL_CHECK(cal_comm_destroy(cal_comm));
361#endif
362
363 // Sync MPI ranks to include load imbalance in total timings.
364 MPI_Barrier(comm);
365}
366
367/*******************************************************************************
368 * \brief Driver routine to solve A*x = lambda*B*x with cuSOLVERMp sygvd.
369 * \author Jiri Vyskocil
370 ******************************************************************************/
371void cp_fm_diag_cusolver_sygvd(const int fortran_comm,
372 const int a_matrix_desc[9],
373 const int b_matrix_desc[9], const int nprow,
374 const int npcol, const int myprow,
375 const int mypcol, const int n,
376 const double *aMatrix, const double *bMatrix,
377 double *eigenvectors, double *eigenvalues) {
378
380 const int local_device = offload_get_chosen_device();
381
382 MPI_Comm comm = MPI_Comm_f2c(fortran_comm);
383 int rank, nranks;
384 MPI_Comm_rank(comm, &rank);
385 MPI_Comm_size(comm, &nranks);
386
387#if defined(__CUSOLVERMP_NCCL)
388 // Create NCCL communicator.
389 ncclUniqueId nccl_id;
390 if (rank == 0) {
391 NCCL_CHECK(ncclGetUniqueId(&nccl_id));
392 }
393 MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, comm);
394
395 ncclComm_t nccl_comm;
396 NCCL_CHECK(ncclCommInitRank(&nccl_comm, nranks, nccl_id, rank));
397#else
398 // Create CAL communicator
399 cal_comm_t cal_comm = NULL;
400 cal_comm_create_params_t params;
401 params.allgather = &allgather;
402 params.req_test = &req_test;
403 params.req_free = &req_free;
404 params.data = &comm;
405 params.rank = rank;
406 params.nranks = nranks;
407 params.local_device = local_device;
408 CAL_CHECK(cal_comm_create(params, &cal_comm));
409#endif
410
411 // Create CUDA stream and cuSOLVER handle
412 cudaStream_t stream = NULL;
413 CUDA_CHECK(cudaStreamCreate(&stream));
414
415 cusolverMpHandle_t cusolvermp_handle = NULL;
416 CUSOLVER_CHECK(cusolverMpCreate(&cusolvermp_handle, local_device, stream));
417
418 // Define grid for device computation
419 cusolverMpGrid_t grid = NULL;
420#if defined(__CUSOLVERMP_NCCL)
421 CUSOLVER_CHECK(cusolverMpCreateDeviceGrid(cusolvermp_handle, &grid, nccl_comm,
422 nprow, npcol,
423 CUSOLVERMP_GRID_MAPPING_ROW_MAJOR));
424#else
425 CUSOLVER_CHECK(cusolverMpCreateDeviceGrid(cusolvermp_handle, &grid, cal_comm,
426 nprow, npcol,
427 CUSOLVERMP_GRID_MAPPING_ROW_MAJOR));
428#endif
429
430 // Matrix descriptors for A, B, and Z
431 const int mb_a = a_matrix_desc[4];
432 const int nb_a = a_matrix_desc[5];
433 const int rsrc_a = a_matrix_desc[6];
434 const int csrc_a = a_matrix_desc[7];
435 const int ldA = a_matrix_desc[8];
436
437 const int mb_b = b_matrix_desc[4];
438 const int nb_b = b_matrix_desc[5];
439 const int rsrc_b = b_matrix_desc[6];
440 const int csrc_b = b_matrix_desc[7];
441 const int ldB = b_matrix_desc[8];
442
443 // Ensure consistency in block sizes, sources, and leading dimensions
444 assert(mb_a == mb_b && nb_a == nb_b);
445 assert(rsrc_a == rsrc_b && csrc_a == csrc_b);
446 (void)ldB; // Suppress unused variable warning
447
448 const int np_a = cusolverMpNUMROC(n, mb_a, myprow, rsrc_a, nprow);
449 const int nq_a = cusolverMpNUMROC(n, nb_a, mypcol, csrc_a, npcol);
450
451 const cublasFillMode_t uplo = CUBLAS_FILL_MODE_LOWER;
452 const cusolverEigType_t itype = CUSOLVER_EIG_TYPE_1;
453 const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
454 const cudaDataType_t data_type = CUDA_R_64F;
455
456 // Create matrix descriptors
457 cusolverMpMatrixDescriptor_t descrA = NULL;
458 cusolverMpMatrixDescriptor_t descrB = NULL;
459 cusolverMpMatrixDescriptor_t descrZ = NULL;
460
461 // Create matrix descriptors using ldA as local leading dimension (LLD)
462 // Note: We use ldA for all matrices. The assertion above verifies ldA == ldB.
463 CUSOLVER_CHECK(cusolverMpCreateMatrixDesc(&descrA, grid, data_type, n, n,
464 mb_a, nb_a, rsrc_a, csrc_a, ldA));
465 CUSOLVER_CHECK(cusolverMpCreateMatrixDesc(&descrB, grid, data_type, n, n,
466 mb_b, nb_b, rsrc_b, csrc_b, ldA));
467 CUSOLVER_CHECK(cusolverMpCreateMatrixDesc(&descrZ, grid, data_type, n, n,
468 mb_a, nb_a, rsrc_a, csrc_a, ldA));
469
470 // Allocate device memory for matrices
471 double *dev_A = NULL, *dev_B = NULL;
472 size_t matrix_local_size = ldA * nq_a * sizeof(double);
473 CUDA_CHECK(cudaMalloc((void **)&dev_A, matrix_local_size));
474 CUDA_CHECK(cudaMalloc((void **)&dev_B, matrix_local_size));
475
476 // Copy matrices from host to device
477 CUDA_CHECK(cudaMemcpyAsync(dev_A, aMatrix, matrix_local_size,
478 cudaMemcpyHostToDevice, stream));
479 CUDA_CHECK(cudaMemcpyAsync(dev_B, bMatrix, matrix_local_size,
480 cudaMemcpyHostToDevice, stream));
481
482 // Allocate device memory for eigenvalues and eigenvectors
483 double *dev_Z = NULL, *eigenvalues_dev = NULL;
484 CUDA_CHECK(cudaMalloc((void **)&dev_Z, matrix_local_size));
485 CUDA_CHECK(cudaMalloc((void **)&eigenvalues_dev, n * sizeof(double)));
486
487 // Query workspace size
488 size_t work_dev_size = 0, work_host_size = 0;
489 const int64_t ia = 1, ja = 1, ib = 1, jb = 1, iz = 1, jz = 1;
490 const int64_t m = (int64_t)n;
491
492 cusolverStatus_t status_bufsize = cusolverMpSygvd_bufferSize(
493 cusolvermp_handle, itype, jobz, uplo, m, ia, ja, descrA, ib, jb, descrB,
494 iz, jz, descrZ, data_type, &work_dev_size, &work_host_size);
495 if (status_bufsize != CUSOLVER_STATUS_SUCCESS) {
496 fprintf(stderr, "ERROR: cusolverMpSygvd_bufferSize failed with status=%d\n",
497 (int)status_bufsize);
498 abort();
499 }
500
501 void *work_dev = NULL, *work_host = NULL;
502 CUDA_CHECK(cudaMalloc(&work_dev, work_dev_size));
503 CUDA_CHECK(cudaMallocHost(&work_host, work_host_size));
504
505 // Allocate and initialize device memory for info
506 int *info_dev = NULL;
507 CUDA_CHECK(cudaMalloc((void **)&info_dev, sizeof(int)));
508 CUDA_CHECK(cudaMemset(info_dev, 0, sizeof(int)));
509
510 // Call cusolverMpSygvd
511 cusolverStatus_t status_sygvd = cusolverMpSygvd(
512 cusolvermp_handle, itype, jobz, uplo, m, dev_A, ia, ja, descrA, dev_B, ib,
513 jb, descrB, eigenvalues_dev, dev_Z, iz, jz, descrZ, data_type, work_dev,
514 work_dev_size, work_host, work_host_size, info_dev);
515 if (status_sygvd != CUSOLVER_STATUS_SUCCESS) {
516 fprintf(stderr, "ERROR: cusolverMpSygvd failed with status=%d\n",
517 (int)status_sygvd);
518 abort();
519 }
520
521 // Wait for computation to finish
522 CUDA_CHECK(cudaStreamSynchronize(stream));
523#if !defined(__CUSOLVERMP_NCCL)
524 CAL_CHECK(cal_stream_sync(cal_comm, stream));
525#endif
526
527 // Check info
528 int info;
529 CUDA_CHECK(cudaMemcpy(&info, info_dev, sizeof(int), cudaMemcpyDeviceToHost));
530 if (info != 0) {
531 fprintf(stderr, "ERROR: cusolverMpSygvd failed with info = %d\n", info);
532 abort();
533 }
534
535 // Copy results back to host
536 CUDA_CHECK(cudaMemcpyAsync(eigenvectors, dev_Z, matrix_local_size,
537 cudaMemcpyDeviceToHost, stream));
538 CUDA_CHECK(cudaMemcpyAsync(eigenvalues, eigenvalues_dev, n * sizeof(double),
539 cudaMemcpyDeviceToHost, stream));
540
541 // Wait for copy to finish
542 CUDA_CHECK(cudaStreamSynchronize(stream));
543
544 // Clean up resources
545 CUDA_CHECK(cudaFree(dev_A));
546 CUDA_CHECK(cudaFree(dev_B));
547 CUDA_CHECK(cudaFree(dev_Z));
548 CUDA_CHECK(cudaFree(eigenvalues_dev));
549 CUDA_CHECK(cudaFree(info_dev));
550 CUDA_CHECK(cudaFree(work_dev));
551 CUDA_CHECK(cudaFreeHost(work_host));
552 CUSOLVER_CHECK(cusolverMpDestroyMatrixDesc(descrA));
553 CUSOLVER_CHECK(cusolverMpDestroyMatrixDesc(descrB));
554 CUSOLVER_CHECK(cusolverMpDestroyMatrixDesc(descrZ));
555 CUSOLVER_CHECK(cusolverMpDestroyGrid(grid));
556 CUSOLVER_CHECK(cusolverMpDestroy(cusolvermp_handle));
557 CUDA_CHECK(cudaStreamDestroy(stream));
558#if defined(__CUSOLVERMP_NCCL)
559 NCCL_CHECK(ncclCommDestroy(nccl_comm));
560#else
561 CAL_CHECK(cal_comm_destroy(cal_comm));
562#endif
563
564 MPI_Barrier(comm); // Synchronize MPI ranks
565}
566#endif
567
568// EOF
static void const int const int const int const int const int const double const int const int const int int GRID_CONST_WHEN_COLLOCATE double GRID_CONST_WHEN_INTEGRATE double * grid
subroutine, public cp_fm_diag_cusolver(matrix, eigenvectors, eigenvalues)
Driver routine to diagonalize a FM matrix with the cuSOLVERMp library.
subroutine, public offload_activate_chosen_device()
Activates the device selected via offload_set_chosen_device()
integer function, public offload_get_chosen_device()
Returns the chosen device.