From b69c08538fa2cf11bb42b3d0e9a3c2b35fcae5af Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 23 Jan 2026 20:25:54 +0800 Subject: [PATCH 1/3] Refactor: Encapsulate timer functionality in timer_wrapper.h --- source/source_base/timer_wrapper.h | 56 +++++++++++++++++++ source/source_esolver/esolver_fp.h | 10 +--- source/source_esolver/esolver_ks.cpp | 15 +---- source/source_esolver/esolver_of.cpp | 14 ++--- source/source_esolver/esolver_of_tddft.cpp | 6 +- .../source_pw/module_ofdft/of_print_info.cpp | 32 +++-------- source/source_pw/module_ofdft/of_print_info.h | 20 +++---- 7 files changed, 84 insertions(+), 69 deletions(-) create mode 100644 source/source_base/timer_wrapper.h diff --git a/source/source_base/timer_wrapper.h b/source/source_base/timer_wrapper.h new file mode 100644 index 0000000000..6da3f391e3 --- /dev/null +++ b/source/source_base/timer_wrapper.h @@ -0,0 +1,56 @@ +#ifndef TIMER_WRAPPER_H +#define TIMER_WRAPPER_H + +#include + +#ifdef __MPI +#include +#endif + +namespace ModuleBase { + +/** + * @brief Time point type that works in both MPI and non-MPI environments + */ +typedef double TimePoint; + +/** + * @brief Get current time as a TimePoint + * + * @return TimePoint Current time + */ +inline TimePoint get_time() +{ +#ifdef __MPI + int is_initialized = 0; + MPI_Initialized(&is_initialized); + if (is_initialized) + { + return MPI_Wtime(); + } + else + { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count() / 1e6; + } +#else + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count() / 1e6; +#endif +} + +/** + * @brief Calculate duration between two TimePoints in seconds + * + * @param start Start time point + * @param end End time point + * @return double Duration in seconds + */ +inline double get_duration(const TimePoint& start, const TimePoint& end) +{ + return end - start; +} + +} + +#endif // TIMER_WRAPPER_H \ No newline at end of file diff --git a/source/source_esolver/esolver_fp.h b/source/source_esolver/esolver_fp.h index b2bb8f065e..94faa31e74 100644 --- a/source/source_esolver/esolver_fp.h +++ b/source/source_esolver/esolver_fp.h @@ -3,9 +3,7 @@ #include "esolver.h" -#ifndef __MPI -#include -#endif +#include "source_base/timer_wrapper.h" #include "source_basis/module_pw/pw_basis.h" // plane wave basis #include "source_cell/module_symmetry/symmetry.h" // symmetry analysis @@ -83,11 +81,7 @@ class ESolver_FP: public ESolver bool pw_rho_flag = false; ///< flag for pw_rho, 0: not initialized, 1: initialized //! the start time of scf iteration - #ifdef __MPI - double iter_time; - #else - std::chrono::system_clock::time_point iter_time; - #endif + ModuleBase::TimePoint iter_time; }; } // namespace ModuleESolver diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 166e7b3fb9..8c0c651172 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -1,4 +1,5 @@ #include "esolver_ks.h" +#include "source_base/timer_wrapper.h" // for jason output information #include "source_io/json_output/init_info.h" @@ -190,11 +191,7 @@ void ESolver_KS::iter_init(UnitCell& ucell, const int istep, const in ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname); } -#ifdef __MPI - iter_time = MPI_Wtime(); -#else - iter_time = std::chrono::system_clock::now(); -#endif + iter_time = ModuleBase::get_time(); if (PARAM.inp.esolver_type == "ksdft") { @@ -281,13 +278,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i // the end, print time -#ifdef __MPI - double duration = (double)(MPI_Wtime() - iter_time); -#else - double duration - = (std::chrono::duration_cast(std::chrono::system_clock::now() - iter_time)).count() - / static_cast(1e6); -#endif + double duration = ModuleBase::get_duration(iter_time, ModuleBase::get_time()); // print energies elecstate::print_etot(ucell.magnet, *pelec, conv_esolver, iter, drho, diff --git a/source/source_esolver/esolver_of.cpp b/source/source_esolver/esolver_of.cpp index 4debfde4d5..b17cf6fb9d 100644 --- a/source/source_esolver/esolver_of.cpp +++ b/source/source_esolver/esolver_of.cpp @@ -27,10 +27,10 @@ ESolver_OF::ESolver_OF() ESolver_OF::~ESolver_OF() { - //**************************************************** - // do not add any codes in this deconstructor funcion - //**************************************************** - delete psi_; + //**************************************************** + // do not add any codes in this deconstructor funcion + //**************************************************** + delete psi_; delete[] this->pphi_; for (int i = 0; i < PARAM.inp.nspin; ++i) @@ -137,11 +137,7 @@ void ESolver_OF::runner(UnitCell& ucell, const int istep) this->iter_ = 0; bool conv_esolver = false; // this conv_esolver is added by mohan 20250302 -#ifdef __MPI - this->iter_time = MPI_Wtime(); -#else - this->iter_time = std::chrono::system_clock::now(); -#endif + this->iter_time = ModuleBase::get_time(); while (true) { diff --git a/source/source_esolver/esolver_of_tddft.cpp b/source/source_esolver/esolver_of_tddft.cpp index daeda628cb..12a398a2f7 100644 --- a/source/source_esolver/esolver_of_tddft.cpp +++ b/source/source_esolver/esolver_of_tddft.cpp @@ -41,11 +41,7 @@ void ESolver_OF_TDDFT::runner(UnitCell& ucell, const int istep) this->iter_ = 0; bool conv_esolver = false; // this conv_esolver is added by mohan 20250302 -#ifdef __MPI - this->iter_time = MPI_Wtime(); -#else - this->iter_time = std::chrono::system_clock::now(); -#endif + this->iter_time = ModuleBase::get_time(); if (this->phi_td.empty()) { diff --git a/source/source_pw/module_ofdft/of_print_info.cpp b/source/source_pw/module_ofdft/of_print_info.cpp index ea411bcb1b..fa19083dcd 100644 --- a/source/source_pw/module_ofdft/of_print_info.cpp +++ b/source/source_pw/module_ofdft/of_print_info.cpp @@ -8,17 +8,13 @@ * and write the components of the total energy into running_log. */ void OFDFT::print_info(const int iter, - #ifdef __MPI - double &iter_time, - #else - std::chrono::system_clock::time_point &iter_time, - #endif - const double &energy_current, - const double &energy_last, - const double &normdLdphi, - const elecstate::ElecState *pelec, - KEDF_Manager *kedf_manager, - const bool conv_esolver) + ModuleBase::TimePoint &iter_time, + const double &energy_current, + const double &energy_last, + const double &normdLdphi, + const elecstate::ElecState *pelec, + KEDF_Manager *kedf_manager, + const bool conv_esolver) { if (iter == 0) { @@ -35,13 +31,7 @@ void OFDFT::print_info(const int iter, {"tn", "TN"} }; std::string iteration = prefix_map[PARAM.inp.of_method] + std::to_string(iter); -#ifdef __MPI - double duration = (double)(MPI_Wtime() - iter_time); -#else - double duration - = (std::chrono::duration_cast(std::chrono::system_clock::now() - iter_time)).count() - / static_cast(1e6); -#endif + double duration = ModuleBase::get_duration(iter_time, ModuleBase::get_time()); std::cout << " " << std::setw(8) << iteration << std::setw(18) << std::scientific << std::setprecision(8) << energy_current * ModuleBase::Ry_to_eV << std::setw(18) << (energy_current - energy_last) * ModuleBase::Ry_to_eV @@ -141,9 +131,5 @@ void OFDFT::print_info(const int iter, GlobalV::ofs_running << table.str() << std::endl; // reset the iter_time for the next iteration -#ifdef __MPI - iter_time = MPI_Wtime(); -#else - iter_time = std::chrono::system_clock::now(); -#endif + iter_time = ModuleBase::get_time(); } diff --git a/source/source_pw/module_ofdft/of_print_info.h b/source/source_pw/module_ofdft/of_print_info.h index b60eeb69db..dd45e6bbc6 100644 --- a/source/source_pw/module_ofdft/of_print_info.h +++ b/source/source_pw/module_ofdft/of_print_info.h @@ -4,24 +4,20 @@ #include "source_estate/elecstate.h" // electronic states #include "source_pw/module_ofdft/kedf_manager.h" -#include +#include "source_base/timer_wrapper.h" namespace OFDFT { void print_info(const int iter, - #ifdef __MPI - double &iter_time, - #else - std::chrono::system_clock::time_point &iter_time, - #endif - const double &energy_current, - const double &energy_last, - const double &normdLdphi, - const elecstate::ElecState *pelec, - KEDF_Manager *kedf_manager, - const bool conv_esolver); + ModuleBase::TimePoint &iter_time, + const double &energy_current, + const double &energy_last, + const double &normdLdphi, + const elecstate::ElecState *pelec, + KEDF_Manager *kedf_manager, + const bool conv_esolver); } From 382926887a20b868df74a8eefef3b299c232fb1a Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Fri, 23 Jan 2026 20:54:14 +0800 Subject: [PATCH 2/3] Refactor timer code and clean_esolver function 1. Remove #ifdef __MPI from timer code, encapsulate in timer_wrapper.h 2. Move ESolver clean logic to after_all_runners method 3. Replace clean_esolver calls with direct delete p_esolver 4. Remove #ifdef __MPI from delete p_esolver 5. Add Cblacs_exit(1) in after_all_runners for LCAO calculations --- source/source_esolver/esolver.cpp | 51 ++++++++--------------- source/source_esolver/esolver.h | 2 +- source/source_esolver/esolver_ks_lcao.cpp | 28 ++++++++----- source/source_main/driver_run.cpp | 8 +--- 4 files changed, 39 insertions(+), 50 deletions(-) diff --git a/source/source_esolver/esolver.cpp b/source/source_esolver/esolver.cpp index 2d89673313..4809c6df77 100644 --- a/source/source_esolver/esolver.cpp +++ b/source/source_esolver/esolver.cpp @@ -311,23 +311,23 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell) // of LR-TDDFT is implemented. std::cout << " PREPARING FOR EXCITED STATES." << std::endl; // initialize the 2nd ESolver_LR at the temporary pointer - ModuleESolver::ESolver* p_esolver_lr = nullptr; - if (PARAM.globalv.gamma_only_local) - { - p_esolver_lr = new LR::ESolver_LR( - std::move(*dynamic_cast*>(p_esolver)), - inp, - ucell); - } - else - { - p_esolver_lr = new LR::ESolver_LR, double>( - std::move(*dynamic_cast, double>*>(p_esolver)), - inp, - ucell); - } - // clean the 1st ESolver_KS and swap the pointer - ModuleESolver::clean_esolver(p_esolver, false); // do not call Cblacs_exit, remain it for the 2nd ESolver + ModuleESolver::ESolver* p_esolver_lr = nullptr; + if (PARAM.globalv.gamma_only_local) + { + p_esolver_lr = new LR::ESolver_LR( + std::move(*dynamic_cast*>(p_esolver)), + inp, + ucell); + } + else + { + p_esolver_lr = new LR::ESolver_LR, double>( + std::move(*dynamic_cast, double>*>(p_esolver)), + inp, + ucell); + } + // clean the 1st ESolver_KS and swap the pointer + delete p_esolver; return p_esolver_lr; } #endif @@ -355,20 +355,5 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell) + " line " + std::to_string(__LINE__)); } -void clean_esolver(ESolver*& pesolver, const bool lcao_cblacs_exit) -{ -// Zhang Xiaoyang modified in 2024/7/6: -// Note: because of the init method of serial lcao hsolver -// it needs no release step for it, or this [delete] will cause Segmentation Fault -// Probably it will be modified later. -#ifdef __MPI - delete pesolver; -#ifdef __LCAO - if (lcao_cblacs_exit) - { - Cblacs_exit(1); - } -#endif -#endif -} + } // namespace ModuleESolver diff --git a/source/source_esolver/esolver.h b/source/source_esolver/esolver.h index 6716ea0c96..dd621cfe15 100644 --- a/source/source_esolver/esolver.h +++ b/source/source_esolver/esolver.h @@ -69,7 +69,7 @@ std::string determine_type(); */ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell); -void clean_esolver(ESolver*& pesolver, const bool lcao_cblacs_exit = false); + } // namespace ModuleESolver diff --git a/source/source_esolver/esolver_ks_lcao.cpp b/source/source_esolver/esolver_ks_lcao.cpp index 3a2fb57496..47b6648954 100644 --- a/source/source_esolver/esolver_ks_lcao.cpp +++ b/source/source_esolver/esolver_ks_lcao.cpp @@ -293,17 +293,25 @@ void ESolver_KS_LCAO::after_all_runners(UnitCell& ucell) ESolver_KS::after_all_runners(ucell); auto* hamilt_lcao = dynamic_cast*>(this->p_hamilt); - if(!hamilt_lcao) - { - ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::after_all_runners","p_hamilt does not exist"); - } + if(!hamilt_lcao) + { + ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::after_all_runners","p_hamilt does not exist"); + } - ModuleIO::ctrl_runner_lcao(ucell, - PARAM.inp, this->kv, this->pelec, this->dmat, this->pv, this->Pgrid, - this->gd, this->psi, this->chr, hamilt_lcao, - this->two_center_bundle_, - this->orb_, this->pw_rho, this->pw_rhod, - this->sf, this->locpp.vloc, this->exx_nao, this->solvent); + ModuleIO::ctrl_runner_lcao(ucell, + PARAM.inp, this->kv, this->pelec, this->dmat, this->pv, this->Pgrid, + this->gd, this->psi, this->chr, hamilt_lcao, + this->two_center_bundle_, + this->orb_, this->pw_rho, this->pw_rhod, + this->sf, this->locpp.vloc, this->exx_nao, this->solvent); + + +#ifdef __MPI +#ifdef __LCAO + // Exit BLACS environment for LCAO calculations + Cblacs_exit(1); +#endif +#endif ModuleBase::timer::tick("ESolver_KS_LCAO", "after_all_runners"); } diff --git a/source/source_main/driver_run.cpp b/source/source_main/driver_run.cpp index 990aa56751..895b06bf57 100644 --- a/source/source_main/driver_run.cpp +++ b/source/source_main/driver_run.cpp @@ -90,11 +90,6 @@ void Driver::driver_run() else if (cal == "get_pchg" || cal == "get_wf" || cal == "gen_bessel" || cal == "gen_opt_abfs" || cal == "test_memory" || cal == "test_neighbour") { - //! supported "other" functions: - //! get_pchg(LCAO), - //! test_memory(PW,LCAO), - //! test_neighbour(LCAO), - //! gen_bessel(PW), et al. const int istep = 0; p_esolver->others(ucell, istep); } @@ -106,7 +101,8 @@ void Driver::driver_run() //! 5: clean up esolver p_esolver->after_all_runners(ucell); - ModuleESolver::clean_esolver(p_esolver); + delete p_esolver; + this->finalize_hardware(); //! 6: output the json file From c2767b78c88a10def49984fe5aef60ae86b87067 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Sat, 24 Jan 2026 16:45:52 +0800 Subject: [PATCH 3/3] Refactor: Move heterogeneous parallel code to source_base/module_device --- .../source_base/module_device/device_check.h | 182 ++++++++++++++++ source/source_pw/module_pwdft/global.h | 196 +----------------- 2 files changed, 183 insertions(+), 195 deletions(-) create mode 100644 source/source_base/module_device/device_check.h diff --git a/source/source_base/module_device/device_check.h b/source/source_base/module_device/device_check.h new file mode 100644 index 0000000000..a708cc1d7f --- /dev/null +++ b/source/source_base/module_device/device_check.h @@ -0,0 +1,182 @@ +#ifndef DEVICE_CHECK_H +#define DEVICE_CHECK_H + +#include + +#ifdef __CUDA +#include "cublas_v2.h" +#include "cufft.h" +#include "source_base/module_device/cuda_compat.h" + +static const char* _cublasGetErrorString(cublasStatus_t error) +{ + switch (error) + { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + } + return ""; +} + +#define CHECK_CUDA(func) \ + { \ + cudaError_t status = (func); \ + if (status != cudaSuccess) \ + { \ + printf("In File %s : CUDA API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + cudaGetErrorString(status), status); \ + } \ + } + +#define CHECK_CUBLAS(func) \ + { \ + cublasStatus_t status = (func); \ + if (status != CUBLAS_STATUS_SUCCESS) \ + { \ + printf("In File %s : CUBLAS API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _cublasGetErrorString(status), status); \ + } \ + } + +#define CHECK_CUSOLVER(func) \ + { \ + cusolverStatus_t status = (func); \ + if (status != CUSOLVER_STATUS_SUCCESS) \ + { \ + printf("In File %s : CUSOLVER API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _cusolverGetErrorString(status), status); \ + } \ + } + +#define CHECK_CUFFT(func) \ + { \ + cufftResult_t status = (func); \ + if (status != CUFFT_SUCCESS) \ + { \ + printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + ModuleBase::cuda_compat::cufftGetErrorStringCompat(status), status); \ + } \ + } +#endif // __CUDA + +#ifdef __ROCM +#include +#include +#include + +static const char* _hipblasGetErrorString(hipblasStatus_t error) +{ + switch (error) + { + case HIPBLAS_STATUS_SUCCESS: + return "HIPBLAS_STATUS_SUCCESS"; + case HIPBLAS_STATUS_NOT_INITIALIZED: + return "HIPBLAS_STATUS_NOT_INITIALIZED"; + case HIPBLAS_STATUS_ALLOC_FAILED: + return "HIPBLAS_STATUS_ALLOC_FAILED"; + case HIPBLAS_STATUS_INVALID_VALUE: + return "HIPBLAS_STATUS_INVALID_VALUE"; + case HIPBLAS_STATUS_ARCH_MISMATCH: + return "HIPBLAS_STATUS_ARCH_MISMATCH"; + case HIPBLAS_STATUS_MAPPING_ERROR: + return "HIPBLAS_STATUS_MAPPING_ERROR"; + case HIPBLAS_STATUS_EXECUTION_FAILED: + return "HIPBLAS_STATUS_EXECUTION_FAILED"; + case HIPBLAS_STATUS_INTERNAL_ERROR: + return "HIPBLAS_STATUS_INTERNAL_ERROR"; + case HIPBLAS_STATUS_NOT_SUPPORTED: + return "HIPBLAS_STATUS_NOT_SUPPORTED"; + case HIPBLAS_STATUS_HANDLE_IS_NULLPTR: + return "HIPBLAS_STATUS_HANDLE_IS_NULLPTR"; + default: + return ""; + } + return ""; +} + +static const char* _hipfftGetErrorString(hipfftResult_t error) +{ + switch (error) + { + case HIPFFT_SUCCESS: + return "HIPFFT_SUCCESS"; + case HIPFFT_INVALID_PLAN: + return "HIPFFT_INVALID_PLAN"; + case HIPFFT_ALLOC_FAILED: + return "HIPFFT_ALLOC_FAILED"; + case HIPFFT_INVALID_TYPE: + return "HIPFFT_INVALID_TYPE"; + case HIPFFT_INVALID_VALUE: + return "HIPFFT_INVALID_VALUE"; + case HIPFFT_INTERNAL_ERROR: + return "HIPFFT_INTERNAL_ERROR"; + case HIPFFT_EXEC_FAILED: + return "HIPFFT_EXEC_FAILED"; + case HIPFFT_SETUP_FAILED: + return "HIPFFT_SETUP_FAILED"; + case HIPFFT_INVALID_SIZE: + return "HIPFFT_INVALID_SIZE"; + case HIPFFT_UNALIGNED_DATA: + return "HIPFFT_UNALIGNED_DATA"; + case HIPFFT_INCOMPLETE_PARAMETER_LIST: + return "HIPFFT_INCOMPLETE_PARAMETER_LIST"; + case HIPFFT_INVALID_DEVICE: + return "HIPFFT_INVALID_DEVICE"; + case HIPFFT_PARSE_ERROR: + return "HIPFFT_PARSE_ERROR"; + case HIPFFT_NO_WORKSPACE: + return "HIPFFT_NO_WORKSPACE"; + case HIPFFT_NOT_IMPLEMENTED: + return "HIPFFT_NOT_IMPLEMENTED"; + case HIPFFT_NOT_SUPPORTED: + return "HIPFFT_NOT_SUPPORTED"; + } + return ""; +} + +#define CHECK_CUDA(func) \ + { \ + hipError_t status = (func); \ + if (status != hipSuccess) \ + { \ + printf("In File %s : HIP API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + hipGetErrorString(status), status); \ + } \ + } + +#define CHECK_CUBLAS(func) \ + { \ + hipblasStatus_t status = (func); \ + if (status != HIPBLAS_STATUS_SUCCESS) \ + { \ + printf("In File %s : HIPBLAS API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _hipblasGetErrorString(status), status); \ + } \ + } + +#define CHECK_CUFFT(func) \ + { \ + hipfftResult_t status = (func); \ + if (status != HIPFFT_SUCCESS) \ + { \ + printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + _hipfftGetErrorString(status), status); \ + } \ + } +#endif // __ROCM + +#endif // DEVICE_CHECK_H \ No newline at end of file diff --git a/source/source_pw/module_pwdft/global.h b/source/source_pw/module_pwdft/global.h index bea93a9331..c8190386b9 100644 --- a/source/source_pw/module_pwdft/global.h +++ b/source/source_pw/module_pwdft/global.h @@ -13,201 +13,7 @@ #endif #include "source_estate/magnetism.h" #include "source_hamilt/module_xc/xc_functional.h" -#ifdef __CUDA -#include "cublas_v2.h" -#include "cufft.h" -#include "source_base/module_device/cuda_compat.h" - -static const char* _cublasGetErrorString(cublasStatus_t error) -{ - switch (error) - { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - } - return ""; -} - -#define CHECK_CUDA(func) \ - { \ - cudaError_t status = (func); \ - if (status != cudaSuccess) \ - { \ - printf("In File %s : CUDA API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - cudaGetErrorString(status), status); \ - } \ - } - -#define CHECK_CUBLAS(func) \ - { \ - cublasStatus_t status = (func); \ - if (status != CUBLAS_STATUS_SUCCESS) \ - { \ - printf("In File %s : CUBLAS API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _cublasGetErrorString(status), status); \ - } \ - } - -#define CHECK_CUSOLVER(func) \ - { \ - cusolverStatus_t status = (func); \ - if (status != CUSOLVER_STATUS_SUCCESS) \ - { \ - printf("In File %s : CUSOLVER API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _cusolverGetErrorString(status), status); \ - } \ - } - -#define CHECK_CUFFT(func) \ - { \ - cufftResult_t status = (func); \ - if (status != CUFFT_SUCCESS) \ - { \ - printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - ModuleBase::cuda_compat::cufftGetErrorStringCompat(status), status); \ - } \ - } -#endif // __CUDA - -#ifdef __ROCM -#include -#include -#include - -static const char* _hipblasGetErrorString(hipblasStatus_t error) -{ - switch (error) - { - case HIPBLAS_STATUS_SUCCESS: - return "HIPBLAS_STATUS_SUCCESS"; - case HIPBLAS_STATUS_NOT_INITIALIZED: - return "HIPBLAS_STATUS_NOT_INITIALIZED"; - case HIPBLAS_STATUS_ALLOC_FAILED: - return "HIPBLAS_STATUS_ALLOC_FAILED"; - case HIPBLAS_STATUS_INVALID_VALUE: - return "HIPBLAS_STATUS_INVALID_VALUE"; - case HIPBLAS_STATUS_ARCH_MISMATCH: - return "HIPBLAS_STATUS_ARCH_MISMATCH"; - case HIPBLAS_STATUS_MAPPING_ERROR: - return "HIPBLAS_STATUS_MAPPING_ERROR"; - case HIPBLAS_STATUS_EXECUTION_FAILED: - return "HIPBLAS_STATUS_EXECUTION_FAILED"; - case HIPBLAS_STATUS_INTERNAL_ERROR: - return "HIPBLAS_STATUS_INTERNAL_ERROR"; - case HIPBLAS_STATUS_NOT_SUPPORTED: - return "HIPBLAS_STATUS_NOT_SUPPORTED"; - case HIPBLAS_STATUS_HANDLE_IS_NULLPTR: - return "HIPBLAS_STATUS_HANDLE_IS_NULLPTR"; - default: - return ""; - } - return ""; -} - -// static const char *_rocsolverGetErrorString(rocsolver_status error) -// { -// switch (error) -// { -// // case ROCSOLVER_STATUS_SUCCESS: -// // return "CUSOLVER_STATUS_SUCCESS"; -// } -// return ""; -// } - -static const char* _hipfftGetErrorString(hipfftResult_t error) -{ - switch (error) - { - case HIPFFT_SUCCESS: - return "HIPFFT_SUCCESS"; - case HIPFFT_INVALID_PLAN: - return "HIPFFT_INVALID_PLAN"; - case HIPFFT_ALLOC_FAILED: - return "HIPFFT_ALLOC_FAILED"; - case HIPFFT_INVALID_TYPE: - return "HIPFFT_INVALID_TYPE"; - case HIPFFT_INVALID_VALUE: - return "HIPFFT_INVALID_VALUE"; - case HIPFFT_INTERNAL_ERROR: - return "HIPFFT_INTERNAL_ERROR"; - case HIPFFT_EXEC_FAILED: - return "HIPFFT_EXEC_FAILED"; - case HIPFFT_SETUP_FAILED: - return "HIPFFT_SETUP_FAILED"; - case HIPFFT_INVALID_SIZE: - return "HIPFFT_INVALID_SIZE"; - case HIPFFT_UNALIGNED_DATA: - return "HIPFFT_UNALIGNED_DATA"; - case HIPFFT_INCOMPLETE_PARAMETER_LIST: - return "HIPFFT_INCOMPLETE_PARAMETER_LIST"; - case HIPFFT_INVALID_DEVICE: - return "HIPFFT_INVALID_DEVICE"; - case HIPFFT_PARSE_ERROR: - return "HIPFFT_PARSE_ERROR"; - case HIPFFT_NO_WORKSPACE: - return "HIPFFT_NO_WORKSPACE"; - case HIPFFT_NOT_IMPLEMENTED: - return "HIPFFT_NOT_IMPLEMENTED"; - case HIPFFT_NOT_SUPPORTED: - return "HIPFFT_NOT_SUPPORTED"; - } - return ""; -} - -#define CHECK_CUDA(func) \ - { \ - hipError_t status = (func); \ - if (status != hipSuccess) \ - { \ - printf("In File %s : HIP API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - hipGetErrorString(status), status); \ - } \ - } - -#define CHECK_CUBLAS(func) \ - { \ - hipblasStatus_t status = (func); \ - if (status != HIPBLAS_STATUS_SUCCESS) \ - { \ - printf("In File %s : HIPBLAS API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _hipblasGetErrorString(status), status); \ - } \ - } - -// #define CHECK_CUSOLVER(func)\ -// {\ -// rocsolver_status status = (func);\ -// if(status != CUSOLVER_STATUS_SUCCESS)\ -// {\ -// printf("In File %s : CUSOLVER API failed at line %d with error: %s (%d)\n",\ -// __FILE__, __LINE__, _rocsolverGetErrorString(status), status);\ -// }\ -// } - -#define CHECK_CUFFT(func) \ - { \ - hipfftResult_t status = (func); \ - if (status != HIPFFT_SUCCESS) \ - { \ - printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ - _hipfftGetErrorString(status), status); \ - } \ - } -#endif // __ROCM +#include "source_base/module_device/device_check.h" //========================================================== // EXPLAIN : define "GLOBAL CLASS"