diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 3de7e8a907..99e7fdaf24 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -1,8 +1,8 @@ #include "elecstate_pw.h" +#include "elecstate_getters.h" #include "module_base/constants.h" #include "module_base/parallel_reduce.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_base/timer.h" #include "module_psi/kernels/device.h" @@ -20,7 +20,8 @@ ElecStatePW::~ElecStatePW() { if (psi::device::get_device_type(this->ctx) == psi::GpuDevice) { delmem_var_op()(this->ctx, this->rho_data); - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { delmem_var_op()(this->ctx, this->kin_r_data); } } @@ -37,7 +38,8 @@ void ElecStatePW::init_rho_data() for (int ii = 0; ii < this->charge->nspin; ii++) { this->rho[ii] = this->rho_data + ii * this->charge->nrxx; } - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { this->kin_r = new FPTYPE*[this->charge->nspin]; resmem_var_op()(this->ctx, this->kin_r_data, this->charge->nspin * this->charge->nrxx); for (int ii = 0; ii < this->charge->nspin; ii++) { @@ -47,7 +49,8 @@ void ElecStatePW::init_rho_data() } else { this->rho = reinterpret_cast(this->charge->rho); - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { this->kin_r = reinterpret_cast(this->charge->kin_r); } } @@ -74,8 +77,8 @@ void ElecStatePW::psiToRho(const psi::Psi, // denghui replaced at 20221110 // ModuleBase::GlobalFunc::ZEROS(this->rho[is], this->charge->nrxx); setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx); - if (XC_Functional::get_func_type() == 3) - { + if (get_xc_func_type() == 3) + { // ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx); setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx); } @@ -89,7 +92,8 @@ void ElecStatePW::psiToRho(const psi::Psi, if (GlobalV::device_flag == "gpu" || GlobalV::precision_flag == "single") { for (int ii = 0; ii < GlobalV::NSPIN; ii++) { castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx); - if (XC_Functional::get_func_type() == 3) { + if (get_xc_func_type() == 3) + { castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx); } } @@ -161,7 +165,7 @@ void ElecStatePW::rhoBandK(const psi::Psi, this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik); - const auto w1 = static_cast(this->wg(ik, ibnd) / GlobalC::ucell.omega); + const auto w1 = static_cast(this->wg(ik, ibnd) / get_ucell_omega()); // replaced by denghui at 20221110 elecstate_pw_op()(this->ctx, GlobalV::DOMAG, GlobalV::DOMAG_Z, this->charge->nrxx, w1, this->rho, this->wfcr, this->wfcr_another_spin); @@ -180,7 +184,7 @@ void ElecStatePW::rhoBandK(const psi::Psi, this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); - const auto w1 = static_cast(this->wg(ik, ibnd) / GlobalC::ucell.omega); + const auto w1 = static_cast(this->wg(ik, ibnd) / get_ucell_omega()); if (w1 != 0.0) { @@ -189,13 +193,22 @@ void ElecStatePW::rhoBandK(const psi::Psi, } // kinetic energy density - if (XC_Functional::get_func_type() == 3) + if (get_xc_func_type() == 3) { for (int j = 0; j < 3; j++) { setmem_complex_op()(this->ctx, this->wfcr, 0, this->charge->nrxx); - meta_op()(this->ctx, ik, j, npw, this->basis->npwk_max, static_cast(GlobalC::ucell.tpiba), this->basis->template get_gcar_data(), this->basis->template get_kvec_c_data(), &psi(ibnd, 0), this->wfcr); + meta_op()(this->ctx, + ik, + j, + npw, + this->basis->npwk_max, + static_cast(get_ucell_tpiba()), + this->basis->template get_gcar_data(), + this->basis->template get_kvec_c_data(), + &psi(ibnd, 0), + this->wfcr); this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik); diff --git a/source/module_elecstate/test/CMakeLists.txt b/source/module_elecstate/test/CMakeLists.txt index 196a6d85e8..7f0e06f009 100644 --- a/source/module_elecstate/test/CMakeLists.txt +++ b/source/module_elecstate/test/CMakeLists.txt @@ -1,6 +1,8 @@ remove_definitions(-D__MPI) remove_definitions(-D__EXX) remove_definitions(-D__CUDA) +remove_definitions(-D__UT_USE_CUDA) +remove_definitions(-D__UT_USE_ROCM) remove_definitions(-D__ROCM) remove_definitions(-D__DEEPKS) remove_definitions(-D_OPENMP) @@ -41,6 +43,26 @@ AddTest( SOURCES elecstate_base_test.cpp ../elecstate.cpp ../occupy.cpp ../../module_psi/psi.cpp ) +AddTest( + TARGET elecstate_pw + LIBS ${math_libs} base device + SOURCES elecstate_pw_test.cpp + ../elecstate_pw.cpp + ../elecstate.cpp + ../occupy.cpp + ../../module_psi/psi.cpp + ../../module_basis/module_pw/pw_basis_k.cpp + ../../module_basis/module_pw/pw_basis.cpp + ../../module_basis/module_pw/pw_init.cpp + ../../module_basis/module_pw/pw_distributeg.cpp + ../../module_basis/module_pw/pw_distributer.cpp + ../../module_basis/module_pw/pw_distributeg_method1.cpp + ../../module_basis/module_pw/pw_distributeg_method2.cpp + ../../module_basis/module_pw/pw_transform_k.cpp + ../../module_basis/module_pw/fft.cpp + ../../module_psi/kernels/memory_op.cpp +) + AddTest( TARGET elecstate_energy LIBS ${math_libs} base device diff --git a/source/module_elecstate/test/elecstate_pw_test.cpp b/source/module_elecstate/test/elecstate_pw_test.cpp new file mode 100644 index 0000000000..c5a6f5e68f --- /dev/null +++ b/source/module_elecstate/test/elecstate_pw_test.cpp @@ -0,0 +1,193 @@ +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#define protected public +#include "module_elecstate/elecstate_pw.h" + +// mock functions for testing +namespace elecstate +{ +double get_ucell_omega() +{ + return 500.0; +} +double get_ucell_tpiba() +{ + return 2.0; +} +int tmp_xc_func_type = 1; +int get_xc_func_type() +{ + return tmp_xc_func_type; +} +void Potential::init_pot(int, Charge const*) +{ +} +void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff) +{ +} +void Potential::cal_fixed_v(double* vl_pseudo) +{ +} +Potential::~Potential() +{ +} +} // namespace elecstate +Charge::Charge() +{ +} +Charge::~Charge() +{ +} +K_Vectors::K_Vectors() +{ +} +K_Vectors::~K_Vectors() +{ +} +void Charge::set_rho_core(ModuleBase::ComplexMatrix const&) +{ +} +void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&) +{ +} +void Charge::set_rhopw(ModulePW::PW_Basis*) +{ +} +void Charge::renormalize_rho() +{ +} + +void Set_GlobalV_Default() +{ + GlobalV::device_flag = "cpu"; + GlobalV::precision_flag = "double"; + GlobalV::DOMAG = false; + GlobalV::DOMAG_Z = false; + // Base class dependent + GlobalV::NSPIN = 1; + GlobalV::nelec = 10.0; + GlobalV::nupdown = 0.0; + GlobalV::TWO_EFERMI = false; + GlobalV::NBANDS = 6; + GlobalV::NLOCAL = 6; + GlobalV::ESOLVER_TYPE = "ksdft"; + GlobalV::LSPINORB = false; + GlobalV::BASIS_TYPE = "pw"; + GlobalV::md_prec_level = 0; + GlobalV::KPAR = 1; + GlobalV::NPROC_IN_POOL = 1; +} + +/************************************************ + * unit test of elecstate_pw.cpp + ***********************************************/ + +/** + * - Tested Functions: + * - Constructor: elecstate::ElecStatePW constructor and destructor + * - including double and single precision versions + * - InitRhoData: elecstate::ElecStatePW::init_rho_data() + * - get rho and kin_r for ElecStatePW + * - ParallelK: elecstate::ElecStatePW::parallelK() + * - trivial call due to removing of __MPI + * - todo: psiToRho: elecstate::ElecStatePW::psiToRho() + */ + +class ElecStatePWTest : public ::testing::Test +{ + protected: + elecstate::ElecStatePW* elecstate_pw_d = nullptr; + elecstate::ElecStatePW* elecstate_pw_s = nullptr; + ModulePW::PW_Basis_K* wfcpw = nullptr; + Charge* chg = nullptr; + K_Vectors* klist = nullptr; + ModulePW::PW_Basis* rhopw = nullptr; + ModulePW::PW_Basis_Big* bigpw = nullptr; + void SetUp() override + { + Set_GlobalV_Default(); + wfcpw = new ModulePW::PW_Basis_K; + chg = new Charge; + klist = new K_Vectors; + klist->nks = 5; + rhopw = new ModulePW::PW_Basis; + bigpw = new ModulePW::PW_Basis_Big; + } + + void TearDown() override + { + delete wfcpw; + delete chg; + delete klist; + delete rhopw; + if (elecstate_pw_d != nullptr) + { + delete elecstate_pw_d; + } + if (elecstate_pw_s != nullptr) + { + delete elecstate_pw_s; + } + } +}; + +TEST_F(ElecStatePWTest, ConstructorDouble) +{ + elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_EQ(elecstate_pw_d->classname, "ElecStatePW"); + EXPECT_EQ(elecstate_pw_d->charge, chg); + EXPECT_EQ(elecstate_pw_d->klist, klist); + EXPECT_EQ(elecstate_pw_d->bigpw, bigpw); +} + +TEST_F(ElecStatePWTest, ConstructorSingle) +{ + elecstate_pw_s = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_EQ(elecstate_pw_s->classname, "ElecStatePW"); + EXPECT_EQ(elecstate_pw_s->charge, chg); + EXPECT_EQ(elecstate_pw_s->klist, klist); + EXPECT_EQ(elecstate_pw_s->bigpw, bigpw); +} + +TEST_F(ElecStatePWTest, InitRhoDataDouble) +{ + elecstate::tmp_xc_func_type = 3; + chg->nrxx = 1000; + elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + elecstate_pw_d->init_rho_data(); + EXPECT_EQ(elecstate_pw_d->init_rho, true); + EXPECT_EQ(elecstate_pw_d->rho, chg->rho); + EXPECT_EQ(elecstate_pw_d->kin_r, chg->kin_r); +} + +TEST_F(ElecStatePWTest, InitRhoDataSingle) +{ + GlobalV::precision_flag = "single"; + elecstate::tmp_xc_func_type = 3; + chg->nspin = GlobalV::NSPIN; + chg->nrxx = 1000; + elecstate_pw_s = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + elecstate_pw_s->init_rho_data(); + EXPECT_EQ(elecstate_pw_s->init_rho, true); + EXPECT_NE(elecstate_pw_s->rho, nullptr); + EXPECT_NE(elecstate_pw_s->kin_r, nullptr); +} + +TEST_F(ElecStatePWTest, ParallelKDouble) +{ + //this is a trivial call due to removing of __MPI + elecstate_pw_d = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_NO_THROW(elecstate_pw_d->parallelK()); +} + +TEST_F(ElecStatePWTest, ParallelKSingle) +{ + //this is a trivial call due to removing of __MPI + elecstate_pw_s = new elecstate::ElecStatePW(wfcpw, chg, klist, rhopw, bigpw); + EXPECT_NO_THROW(elecstate_pw_s->parallelK()); +} + +#undef protected \ No newline at end of file