From b0cf81f29fdaf8ae2aec76b6224e12edac2afbd1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 May 2023 09:59:57 +0000 Subject: [PATCH 1/5] Test: add unit test for elecstate_pw.cpp --- source/module_elecstate/elecstate_getters.cpp | 10 ++ source/module_elecstate/elecstate_getters.h | 4 + source/module_elecstate/elecstate_pw.cpp | 35 +++-- source/module_elecstate/test/CMakeLists.txt | 20 +++ .../test/elecstate_pw_test.cpp | 139 ++++++++++++++++++ 5 files changed, 197 insertions(+), 11 deletions(-) create mode 100644 source/module_elecstate/test/elecstate_pw_test.cpp diff --git a/source/module_elecstate/elecstate_getters.cpp b/source/module_elecstate/elecstate_getters.cpp index 51e2c6bd47..33297a6bd1 100644 --- a/source/module_elecstate/elecstate_getters.cpp +++ b/source/module_elecstate/elecstate_getters.cpp @@ -12,6 +12,16 @@ double get_ucell_omega() return GlobalC::ucell.omega; } +double get_ucell_tpiba() +{ + return GlobalC::ucell.tpiba; +} + +int get_xc_func_type() +{ + return XC_Functional::get_func_type(); +} + std::string get_input_vdw_method() { return INPUT.vdw_method; diff --git a/source/module_elecstate/elecstate_getters.h b/source/module_elecstate/elecstate_getters.h index a14198305e..567ecaf05b 100644 --- a/source/module_elecstate/elecstate_getters.h +++ b/source/module_elecstate/elecstate_getters.h @@ -8,6 +8,10 @@ namespace elecstate /// @brief get the value of GlobalC::ucell.omega double get_ucell_omega(); +/// @brief get the value of GlobalC::ucell.tpiba +double get_ucell_tpiba(); +/// @brief get the value of XC_Functional::func_type +int get_xc_func_type(); /// @brief get the value of INPUT.vdw_method std::string get_input_vdw_method(); /// @brief get the value of GlobalC::ucell.magnet.tot_magnetization diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 3de7e8a907..99e7fdaf24 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -1,8 +1,8 @@ #include "elecstate_pw.h" +#include "elecstate_getters.h" #include "module_base/constants.h" #include "module_base/parallel_reduce.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_base/timer.h" #include "module_psi/kernels/device.h" @@ -20,7 +20,8 @@ ElecStatePW::~ElecStatePW() { if (psi::device::get_device_type(this->ctx) == psi::GpuDevice) { delmem_var_op()(this->ctx, this->rho_data); - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { delmem_var_op()(this->ctx, this->kin_r_data); } } @@ -37,7 +38,8 @@ void ElecStatePW::init_rho_data() for (int ii = 0; ii < this->charge->nspin; ii++) { this->rho[ii] = this->rho_data + ii * this->charge->nrxx; } - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { this->kin_r = new FPTYPE*[this->charge->nspin]; resmem_var_op()(this->ctx, this->kin_r_data, this->charge->nspin * this->charge->nrxx); for (int ii = 0; ii < this->charge->nspin; ii++) { @@ -47,7 +49,8 @@ void ElecStatePW::init_rho_data() } else { this->rho = reinterpret_cast(this->charge->rho); - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { this->kin_r = reinterpret_cast(this->charge->kin_r); } } @@ -74,8 +77,8 @@ void ElecStatePW::psiToRho(const psi::Psi, // denghui replaced at 20221110 // ModuleBase::GlobalFunc::ZEROS(this->rho[is], this->charge->nrxx); setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx); - if (XC_Functional::get_func_type() == 3) - { + if (get_xc_func_type() == 3) + { // ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx); } @@ -89,7 +92,8 @@ void ElecStatePW::psiToRho(const psi::Psi, if (GlobalV::device_flag == "gpu" || GlobalV::precision_flag == "single") { for (int ii = 0; ii < GlobalV::NSPIN; ii++) { castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx); - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx); } } @@ -161,7 +165,7 @@ void ElecStatePW::rhoBandK(const psi::Psi, this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik); - const auto w1 = static_cast(this->wg(ik, ibnd) / GlobalC::ucell.omega); + const auto w1 = static_cast(this->wg(ik, ibnd) / get_ucell_omega()); // replaced by denghui at 20221110 elecstate_pw_op()(this->ctx, GlobalV::DOMAG, GlobalV::DOMAG_Z, this->charge->nrxx, w1, this->rho, this->wfcr, this->wfcr_another_spin); @@ -180,7 +184,7 @@ void ElecStatePW::rhoBandK(const psi::Psi, this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); - const auto w1 = static_cast(this->wg(ik, ibnd) / GlobalC::ucell.omega); + const auto w1 = static_cast(this->wg(ik, ibnd) / get_ucell_omega()); if (w1 != 0.0) { @@ -189,13 +193,22 @@ void ElecStatePW::rhoBandK(const psi::Psi, } // kinetic energy density - if (XC_Functional::get_func_type() == 3) + if (get_xc_func_type() == 3) { for (int j = 0; j < 3; j++) { setmem_complex_op()(this->ctx, this->wfcr, 0, this->charge->nrxx); - meta_op()(this->ctx, ik, j, npw, this->basis->npwk_max, static_cast(GlobalC::ucell.tpiba), this->basis->template get_gcar_data(), this->basis->template get_kvec_c_data(), &psi(ibnd, 0), this->wfcr); + meta_op()(this->ctx, + ik, + j, + npw, + this->basis->npwk_max, + static_cast(get_ucell_tpiba()), + this->basis->template get_gcar_data(), + this->basis->template get_kvec_c_data(), + &psi(ibnd, 0), + this->wfcr); this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik); diff --git a/source/module_elecstate/test/CMakeLists.txt b/source/module_elecstate/test/CMakeLists.txt index 2c205320ba..e9832b891e 100644 --- a/source/module_elecstate/test/CMakeLists.txt +++ b/source/module_elecstate/test/CMakeLists.txt @@ -40,3 +40,23 @@ AddTest( LIBS ${math_libs} base device SOURCES elecstate_base_test.cpp ../elecstate.cpp ../occupy.cpp ../../module_psi/psi.cpp ) + +AddTest( + TARGET elecstate_pw + LIBS ${math_libs} base device + SOURCES elecstate_pw_test.cpp + ../elecstate_pw.cpp + ../elecstate.cpp + ../occupy.cpp + ../../module_psi/psi.cpp + ../../module_basis/module_pw/pw_basis_k.cpp + ../../module_basis/module_pw/pw_basis.cpp + ../../module_basis/module_pw/pw_init.cpp + ../../module_basis/module_pw/pw_distributeg.cpp + ../../module_basis/module_pw/pw_distributer.cpp + ../../module_basis/module_pw/pw_distributeg_method1.cpp + ../../module_basis/module_pw/pw_distributeg_method2.cpp + ../../module_basis/module_pw/pw_transform_k.cpp + ../../module_basis/module_pw/fft.cpp + ../../module_psi/kernels/memory_op.cpp +) \ No newline at end of file diff --git a/source/module_elecstate/test/elecstate_pw_test.cpp b/source/module_elecstate/test/elecstate_pw_test.cpp new file mode 100644 index 0000000000..cb4938a5f3 --- /dev/null +++ b/source/module_elecstate/test/elecstate_pw_test.cpp @@ -0,0 +1,139 @@ +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#define protected public +#include "module_elecstate/elecstate_pw.h" + +// mock functions for testing +namespace elecstate +{ +double get_ucell_omega() +{ + return 500.0; +} +double get_ucell_tpiba() +{ + return 2.0; +} +int get_xc_func_type() +{ + return 1; +} +void Potential::init_pot(int, Charge const*) +{ +} +void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) +{ +} +void Potential::cal_fixed_v(double* vl_pseudo) +{ +} +Potential::~Potential() +{ +} +} // namespace elecstate +Charge::Charge() +{ +} +Charge::~Charge() +{ +} +K_Vectors::K_Vectors() +{ +} +K_Vectors::~K_Vectors() +{ +} +void Charge::set_rho_core(ModuleBase::ComplexMatrix const&) +{ +} +void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&) +{ +} +void Charge::set_rhopw(ModulePW::PW_Basis*) +{ +} +void Charge::renormalize_rho() +{ +} + +/************************************************ + * unit test of elecstate_pw.cpp + ***********************************************/ + +/** + * - Tested Functions: + * - InitNelecSpin: + */ + +void Set_GlobalV_Default() +{ + GlobalV::device_flag = "cpu"; + GlobalV::precision_flag = "double"; + GlobalV::DOMAG = false; + GlobalV::DOMAG_Z = false; + // Base class dependent + GlobalV::NSPIN = 1; + GlobalV::nelec = 10.0; + GlobalV::nupdown = 0.0; + GlobalV::TWO_EFERMI = false; + GlobalV::NBANDS = 6; + GlobalV::NLOCAL = 6; + GlobalV::ESOLVER_TYPE = "ksdft"; + GlobalV::LSPINORB = false; + GlobalV::BASIS_TYPE = "pw"; + GlobalV::md_prec_level = 0; + GlobalV::KPAR = 1; + GlobalV::NPROC_IN_POOL = 1; +} + +class ElecStatePWTest : public ::testing::Test +{ + protected: + elecstate::ElecStatePW* elecstate_pw_d; + elecstate::ElecStatePW* elecstate_pw_f; + ModulePW::PW_Basis_K* wfcpw; + Charge* chg; + K_Vectors* klist; + ModulePW::PW_Basis* rhopw; + ModulePW::PW_Basis_Big* bigpw; + void SetUp() override + { + Set_GlobalV_Default(); + wfcpw = new ModulePW::PW_Basis_K; + chg = new Charge; + klist = new K_Vectors; + klist->nks = 5; + rhopw = new ModulePW::PW_Basis; + bigpw = new ModulePW::PW_Basis_Big; + } + + void TearDown() override + { + delete wfcpw; + delete chg; + delete klist; + delete rhopw; + delete elecstate_pw_d; + delete elecstate_pw_f; + } +}; + +TEST_F(ElecStatePWTest, Constructor) +{ + // test constructor + elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_EQ(elecstate_pw_d->classname, "ElecStatePW"); + EXPECT_EQ(elecstate_pw_d->charge, chg); + EXPECT_EQ(elecstate_pw_d->klist, klist); + EXPECT_EQ(elecstate_pw_d->bigpw, bigpw); + elecstate_pw_f = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_EQ(elecstate_pw_f->classname, "ElecStatePW"); + EXPECT_EQ(elecstate_pw_f->charge, chg); + EXPECT_EQ(elecstate_pw_f->klist, klist); + EXPECT_EQ(elecstate_pw_f->bigpw, bigpw); +} + +#undef protected \ No newline at end of file From 83570e58cf8ba7e841654c4e4a179d776c891313 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 19 May 2023 02:24:50 +0000 Subject: [PATCH 2/5] fix compiling error --- source/module_elecstate/test/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/source/module_elecstate/test/CMakeLists.txt b/source/module_elecstate/test/CMakeLists.txt index e9832b891e..806aeb2361 100644 --- a/source/module_elecstate/test/CMakeLists.txt +++ b/source/module_elecstate/test/CMakeLists.txt @@ -1,6 +1,8 @@ remove_definitions(-D__MPI) remove_definitions(-D__EXX) remove_definitions(-D__CUDA) +remove_definitions(-D__UT_USE_CUDA) +remove_definitions(-D__UT_USE_ROCM) remove_definitions(-D__ROCM) remove_definitions(-D__DEEPKS) remove_definitions(-D_OPENMP) @@ -59,4 +61,4 @@ AddTest( ../../module_basis/module_pw/pw_transform_k.cpp ../../module_basis/module_pw/fft.cpp ../../module_psi/kernels/memory_op.cpp -) \ No newline at end of file +) From 6680559cb9a5ab73a172b597243d876ac1095db8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 19 May 2023 04:42:55 +0000 Subject: [PATCH 3/5] add more unit tests --- .../test/elecstate_pw_test.cpp | 104 +++++++++++++----- 1 file changed, 78 insertions(+), 26 deletions(-) diff --git a/source/module_elecstate/test/elecstate_pw_test.cpp b/source/module_elecstate/test/elecstate_pw_test.cpp index cb4938a5f3..4c4e2bdc30 100644 --- a/source/module_elecstate/test/elecstate_pw_test.cpp +++ b/source/module_elecstate/test/elecstate_pw_test.cpp @@ -17,9 +17,10 @@ double get_ucell_tpiba() { return 2.0; } +int tmp_xc_func_type = 1; int get_xc_func_type() { - return 1; + return tmp_xc_func_type; } void Potential::init_pot(int, Charge const*) { @@ -59,15 +60,6 @@ void Charge::renormalize_rho() { } -/************************************************ - * unit test of elecstate_pw.cpp - ***********************************************/ - -/** - * - Tested Functions: - * - InitNelecSpin: - */ - void Set_GlobalV_Default() { GlobalV::device_flag = "cpu"; @@ -89,16 +81,29 @@ void Set_GlobalV_Default() GlobalV::NPROC_IN_POOL = 1; } +/************************************************ + * unit test of elecstate_pw.cpp + ***********************************************/ + +/** + * - Tested Functions: + * - Constructor: elecstate::ElecStatePW constructor and destructor + * - including double and single precision versions + * - IinitRhoData: elecstate::ElecStatePW::init_rho_data() + * - get rho and kin_r for ElecStatePW + * - todo: psiToRho: elecstate::ElecStatePW::psiToRho() + */ + class ElecStatePWTest : public ::testing::Test { protected: - elecstate::ElecStatePW* elecstate_pw_d; - elecstate::ElecStatePW* elecstate_pw_f; - ModulePW::PW_Basis_K* wfcpw; - Charge* chg; - K_Vectors* klist; - ModulePW::PW_Basis* rhopw; - ModulePW::PW_Basis_Big* bigpw; + elecstate::ElecStatePW* elecstate_pw_d = nullptr; + elecstate::ElecStatePW* elecstate_pw_s = nullptr; + ModulePW::PW_Basis_K* wfcpw = nullptr; + Charge* chg = nullptr; + K_Vectors* klist = nullptr; + ModulePW::PW_Basis* rhopw = nullptr; + ModulePW::PW_Basis_Big* bigpw = nullptr; void SetUp() override { Set_GlobalV_Default(); @@ -116,24 +121,71 @@ class ElecStatePWTest : public ::testing::Test delete chg; delete klist; delete rhopw; - delete elecstate_pw_d; - delete elecstate_pw_f; + if (elecstate_pw_d != nullptr) + { + delete elecstate_pw_d; + } + if (elecstate_pw_s != nullptr) + { + delete elecstate_pw_s; + } } }; -TEST_F(ElecStatePWTest, Constructor) +TEST_F(ElecStatePWTest, ConstructorDouble) { - // test constructor elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); EXPECT_EQ(elecstate_pw_d->classname, "ElecStatePW"); EXPECT_EQ(elecstate_pw_d->charge, chg); EXPECT_EQ(elecstate_pw_d->klist, klist); EXPECT_EQ(elecstate_pw_d->bigpw, bigpw); - elecstate_pw_f = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); - EXPECT_EQ(elecstate_pw_f->classname, "ElecStatePW"); - EXPECT_EQ(elecstate_pw_f->charge, chg); - EXPECT_EQ(elecstate_pw_f->klist, klist); - EXPECT_EQ(elecstate_pw_f->bigpw, bigpw); +} + +TEST_F(ElecStatePWTest, ConstructorSingle) +{ + elecstate_pw_s = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_EQ(elecstate_pw_s->classname, "ElecStatePW"); + EXPECT_EQ(elecstate_pw_s->charge, chg); + EXPECT_EQ(elecstate_pw_s->klist, klist); + EXPECT_EQ(elecstate_pw_s->bigpw, bigpw); +} + +TEST_F(ElecStatePWTest, InitRhoDataDouble) +{ + elecstate::tmp_xc_func_type = 3; + chg->nrxx = 1000; + elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + elecstate_pw_d->init_rho_data(); + EXPECT_EQ(elecstate_pw_d->init_rho, true); + EXPECT_EQ(elecstate_pw_d->rho, chg->rho); + EXPECT_EQ(elecstate_pw_d->kin_r, chg->kin_r); +} + +TEST_F(ElecStatePWTest, InitRhoDataSingle) +{ + GlobalV::precision_flag = "single"; + elecstate::tmp_xc_func_type = 3; + chg->nspin = GlobalV::NSPIN; + chg->nrxx = 1000; + elecstate_pw_s = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + elecstate_pw_s->init_rho_data(); + EXPECT_EQ(elecstate_pw_s->init_rho, true); + EXPECT_NE(elecstate_pw_s->rho, nullptr); + EXPECT_NE(elecstate_pw_s->kin_r, nullptr); +} + +TEST_F(ElecStatePWTest, ParallelKDouble) +{ + //this is a trivial call due to removing of __MPI + elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_NO_THROW(elecstate_pw_d->parallelK()); +} + +TEST_F(ElecStatePWTest, ParallelKSingle) +{ + //this is a trivial call due to removing of __MPI + elecstate_pw_s = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_NO_THROW(elecstate_pw_s->parallelK()); } #undef protected \ No newline at end of file From 0963dffd1e2ce86060b6c960a495c70bf8e628fe Mon Sep 17 00:00:00 2001 From: root Date: Fri, 19 May 2023 04:44:21 +0000 Subject: [PATCH 4/5] update comments --- source/module_elecstate/test/elecstate_pw_test.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/module_elecstate/test/elecstate_pw_test.cpp b/source/module_elecstate/test/elecstate_pw_test.cpp index 4c4e2bdc30..bd9b42f520 100644 --- a/source/module_elecstate/test/elecstate_pw_test.cpp +++ b/source/module_elecstate/test/elecstate_pw_test.cpp @@ -91,6 +91,8 @@ void Set_GlobalV_Default() * - including double and single precision versions * - IinitRhoData: elecstate::ElecStatePW::init_rho_data() * - get rho and kin_r for ElecStatePW + * - ParallelK: elecstate::ElecStatePW::parallelK() + * - trivial call due to removing of __MPI * - todo: psiToRho: elecstate::ElecStatePW::psiToRho() */ From 9fb96a3b631abcddc0d6dfd4bb8a5a5db5243ee2 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 19 May 2023 04:48:04 +0000 Subject: [PATCH 5/5] update comments --- source/module_elecstate/test/elecstate_pw_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_elecstate/test/elecstate_pw_test.cpp b/source/module_elecstate/test/elecstate_pw_test.cpp index bd9b42f520..c5a6f5e68f 100644 --- a/source/module_elecstate/test/elecstate_pw_test.cpp +++ b/source/module_elecstate/test/elecstate_pw_test.cpp @@ -89,7 +89,7 @@ void Set_GlobalV_Default() * - Tested Functions: * - Constructor: elecstate::ElecStatePW constructor and destructor * - including double and single precision versions - * - IinitRhoData: elecstate::ElecStatePW::init_rho_data() + * - InitRhoData: elecstate::ElecStatePW::init_rho_data() * - get rho and kin_r for ElecStatePW * - ParallelK: elecstate::ElecStatePW::parallelK() * - trivial call due to removing of __MPI