Skip to content
35 changes: 24 additions & 11 deletions source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "elecstate_pw.h"

#include "elecstate_getters.h"
#include "module_base/constants.h"
#include "module_base/parallel_reduce.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_base/timer.h"
#include "module_psi/kernels/device.h"

Expand All @@ -20,7 +20,8 @@ ElecStatePW<FPTYPE, Device>::~ElecStatePW()
{
if (psi::device::get_device_type<Device>(this->ctx) == psi::GpuDevice) {
delmem_var_op()(this->ctx, this->rho_data);
if (XC_Functional::get_func_type() == 3) {
if (get_xc_func_type() == 3)
{
delmem_var_op()(this->ctx, this->kin_r_data);
}
}
Expand All @@ -37,7 +38,8 @@ void ElecStatePW<FPTYPE, Device>::init_rho_data()
for (int ii = 0; ii < this->charge->nspin; ii++) {
this->rho[ii] = this->rho_data + ii * this->charge->nrxx;
}
if (XC_Functional::get_func_type() == 3) {
if (get_xc_func_type() == 3)
{
this->kin_r = new FPTYPE*[this->charge->nspin];
resmem_var_op()(this->ctx, this->kin_r_data, this->charge->nspin * this->charge->nrxx);
for (int ii = 0; ii < this->charge->nspin; ii++) {
Expand All @@ -47,7 +49,8 @@ void ElecStatePW<FPTYPE, Device>::init_rho_data()
}
else {
this->rho = reinterpret_cast<FPTYPE **>(this->charge->rho);
if (XC_Functional::get_func_type() == 3) {
if (get_xc_func_type() == 3)
{
this->kin_r = reinterpret_cast<FPTYPE **>(this->charge->kin_r);
}
}
Expand All @@ -74,8 +77,8 @@ void ElecStatePW<FPTYPE, Device>::psiToRho(const psi::Psi<std::complex<FPTYPE>,
// denghui replaced at 20221110
// ModuleBase::GlobalFunc::ZEROS(this->rho[is], this->charge->nrxx);
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
if (XC_Functional::get_func_type() == 3)
{
if (get_xc_func_type() == 3)
{
// ModuleBase::GlobalFunc::ZEROS(this->charge->kin_r[is], this->charge->nrxx);
setmem_var_op()(this->ctx, this->kin_r[is], 0, this->charge->nrxx);
}
Expand All @@ -89,7 +92,8 @@ void ElecStatePW<FPTYPE, Device>::psiToRho(const psi::Psi<std::complex<FPTYPE>,
if (GlobalV::device_flag == "gpu" || GlobalV::precision_flag == "single") {
for (int ii = 0; ii < GlobalV::NSPIN; ii++) {
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->rho[ii], this->rho[ii], this->charge->nrxx);
if (XC_Functional::get_func_type() == 3) {
if (get_xc_func_type() == 3)
{
castmem_var_d2h_op()(cpu_ctx, this->ctx, this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
}
}
Expand Down Expand Up @@ -161,7 +165,7 @@ void ElecStatePW<FPTYPE, Device>::rhoBandK(const psi::Psi<std::complex<FPTYPE>,

this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik);

const auto w1 = static_cast<FPTYPE>(this->wg(ik, ibnd) / GlobalC::ucell.omega);
const auto w1 = static_cast<FPTYPE>(this->wg(ik, ibnd) / get_ucell_omega());

// replaced by denghui at 20221110
elecstate_pw_op()(this->ctx, GlobalV::DOMAG, GlobalV::DOMAG_Z, this->charge->nrxx, w1, this->rho, this->wfcr, this->wfcr_another_spin);
Expand All @@ -180,7 +184,7 @@ void ElecStatePW<FPTYPE, Device>::rhoBandK(const psi::Psi<std::complex<FPTYPE>,

this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);

const auto w1 = static_cast<FPTYPE>(this->wg(ik, ibnd) / GlobalC::ucell.omega);
const auto w1 = static_cast<FPTYPE>(this->wg(ik, ibnd) / get_ucell_omega());

if (w1 != 0.0)
{
Expand All @@ -189,13 +193,22 @@ void ElecStatePW<FPTYPE, Device>::rhoBandK(const psi::Psi<std::complex<FPTYPE>,
}

// kinetic energy density
if (XC_Functional::get_func_type() == 3)
if (get_xc_func_type() == 3)
{
for (int j = 0; j < 3; j++)
{
setmem_complex_op()(this->ctx, this->wfcr, 0, this->charge->nrxx);

meta_op()(this->ctx, ik, j, npw, this->basis->npwk_max, static_cast<FPTYPE>(GlobalC::ucell.tpiba), this->basis->template get_gcar_data<FPTYPE>(), this->basis->template get_kvec_c_data<FPTYPE>(), &psi(ibnd, 0), this->wfcr);
meta_op()(this->ctx,
ik,
j,
npw,
this->basis->npwk_max,
static_cast<FPTYPE>(get_ucell_tpiba()),
this->basis->template get_gcar_data<FPTYPE>(),
this->basis->template get_kvec_c_data<FPTYPE>(),
&psi(ibnd, 0),
this->wfcr);

this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);

Expand Down
22 changes: 22 additions & 0 deletions source/module_elecstate/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
remove_definitions(-D__MPI)
remove_definitions(-D__EXX)
remove_definitions(-D__CUDA)
remove_definitions(-D__UT_USE_CUDA)
remove_definitions(-D__UT_USE_ROCM)
remove_definitions(-D__ROCM)
remove_definitions(-D__DEEPKS)
remove_definitions(-D_OPENMP)
Expand Down Expand Up @@ -41,6 +43,26 @@ AddTest(
SOURCES elecstate_base_test.cpp ../elecstate.cpp ../occupy.cpp ../../module_psi/psi.cpp
)

AddTest(
TARGET elecstate_pw
LIBS ${math_libs} base device
SOURCES elecstate_pw_test.cpp
../elecstate_pw.cpp
../elecstate.cpp
../occupy.cpp
../../module_psi/psi.cpp
../../module_basis/module_pw/pw_basis_k.cpp
../../module_basis/module_pw/pw_basis.cpp
../../module_basis/module_pw/pw_init.cpp
../../module_basis/module_pw/pw_distributeg.cpp
../../module_basis/module_pw/pw_distributer.cpp
../../module_basis/module_pw/pw_distributeg_method1.cpp
../../module_basis/module_pw/pw_distributeg_method2.cpp
../../module_basis/module_pw/pw_transform_k.cpp
../../module_basis/module_pw/fft.cpp
../../module_psi/kernels/memory_op.cpp
)

AddTest(
TARGET elecstate_energy
LIBS ${math_libs} base device
Expand Down
193 changes: 193 additions & 0 deletions source/module_elecstate/test/elecstate_pw_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#include <string>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#define protected public
#include "module_elecstate/elecstate_pw.h"

// mock functions for testing
namespace elecstate
{
double get_ucell_omega()
{
return 500.0;
}
double get_ucell_tpiba()
{
return 2.0;
}
int tmp_xc_func_type = 1;
int get_xc_func_type()
{
return tmp_xc_func_type;
}
void Potential::init_pot(int, Charge const*)
{
}
void Potential::cal_v_eff(const Charge* chg, const UnitCell* ucell, ModuleBase::matrix& v_eff)
{
}
void Potential::cal_fixed_v(double* vl_pseudo)
{
}
Potential::~Potential()
{
}
} // namespace elecstate
Charge::Charge()
{
}
Charge::~Charge()
{
}
K_Vectors::K_Vectors()
{
}
K_Vectors::~K_Vectors()
{
}
void Charge::set_rho_core(ModuleBase::ComplexMatrix const&)
{
}
void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&)
{
}
void Charge::set_rhopw(ModulePW::PW_Basis*)
{
}
void Charge::renormalize_rho()
{
}

void Set_GlobalV_Default()
{
GlobalV::device_flag = "cpu";
GlobalV::precision_flag = "double";
GlobalV::DOMAG = false;
GlobalV::DOMAG_Z = false;
// Base class dependent
GlobalV::NSPIN = 1;
GlobalV::nelec = 10.0;
GlobalV::nupdown = 0.0;
GlobalV::TWO_EFERMI = false;
GlobalV::NBANDS = 6;
GlobalV::NLOCAL = 6;
GlobalV::ESOLVER_TYPE = "ksdft";
GlobalV::LSPINORB = false;
GlobalV::BASIS_TYPE = "pw";
GlobalV::md_prec_level = 0;
GlobalV::KPAR = 1;
GlobalV::NPROC_IN_POOL = 1;
}

/************************************************
* unit test of elecstate_pw.cpp
***********************************************/

/**
* - Tested Functions:
* - Constructor: elecstate::ElecStatePW constructor and destructor
* - including double and single precision versions
* - InitRhoData: elecstate::ElecStatePW::init_rho_data()
* - get rho and kin_r for ElecStatePW
* - ParallelK: elecstate::ElecStatePW::parallelK()
* - trivial call due to removing of __MPI
* - todo: psiToRho: elecstate::ElecStatePW::psiToRho()
*/

class ElecStatePWTest : public ::testing::Test
{
protected:
elecstate::ElecStatePW<double, psi::DEVICE_CPU>* elecstate_pw_d = nullptr;
elecstate::ElecStatePW<float, psi::DEVICE_CPU>* elecstate_pw_s = nullptr;
ModulePW::PW_Basis_K* wfcpw = nullptr;
Charge* chg = nullptr;
K_Vectors* klist = nullptr;
ModulePW::PW_Basis* rhopw = nullptr;
ModulePW::PW_Basis_Big* bigpw = nullptr;
void SetUp() override
{
Set_GlobalV_Default();
wfcpw = new ModulePW::PW_Basis_K;
chg = new Charge;
klist = new K_Vectors;
klist->nks = 5;
rhopw = new ModulePW::PW_Basis;
bigpw = new ModulePW::PW_Basis_Big;
}

void TearDown() override
{
delete wfcpw;
delete chg;
delete klist;
delete rhopw;
if (elecstate_pw_d != nullptr)
{
delete elecstate_pw_d;
}
if (elecstate_pw_s != nullptr)
{
delete elecstate_pw_s;
}
}
};

TEST_F(ElecStatePWTest, ConstructorDouble)
{
elecstate_pw_d = new elecstate::ElecStatePW<double, psi::DEVICE_CPU>(wfcpw, chg, klist, rhopw, bigpw);
EXPECT_EQ(elecstate_pw_d->classname, "ElecStatePW");
EXPECT_EQ(elecstate_pw_d->charge, chg);
EXPECT_EQ(elecstate_pw_d->klist, klist);
EXPECT_EQ(elecstate_pw_d->bigpw, bigpw);
}

TEST_F(ElecStatePWTest, ConstructorSingle)
{
elecstate_pw_s = new elecstate::ElecStatePW<float, psi::DEVICE_CPU>(wfcpw, chg, klist, rhopw, bigpw);
EXPECT_EQ(elecstate_pw_s->classname, "ElecStatePW");
EXPECT_EQ(elecstate_pw_s->charge, chg);
EXPECT_EQ(elecstate_pw_s->klist, klist);
EXPECT_EQ(elecstate_pw_s->bigpw, bigpw);
}

TEST_F(ElecStatePWTest, InitRhoDataDouble)
{
elecstate::tmp_xc_func_type = 3;
chg->nrxx = 1000;
elecstate_pw_d = new elecstate::ElecStatePW<double, psi::DEVICE_CPU>(wfcpw, chg, klist, rhopw, bigpw);
elecstate_pw_d->init_rho_data();
EXPECT_EQ(elecstate_pw_d->init_rho, true);
EXPECT_EQ(elecstate_pw_d->rho, chg->rho);
EXPECT_EQ(elecstate_pw_d->kin_r, chg->kin_r);
}

TEST_F(ElecStatePWTest, InitRhoDataSingle)
{
GlobalV::precision_flag = "single";
elecstate::tmp_xc_func_type = 3;
chg->nspin = GlobalV::NSPIN;
chg->nrxx = 1000;
elecstate_pw_s = new elecstate::ElecStatePW<float, psi::DEVICE_CPU>(wfcpw, chg, klist, rhopw, bigpw);
elecstate_pw_s->init_rho_data();
EXPECT_EQ(elecstate_pw_s->init_rho, true);
EXPECT_NE(elecstate_pw_s->rho, nullptr);
EXPECT_NE(elecstate_pw_s->kin_r, nullptr);
}

TEST_F(ElecStatePWTest, ParallelKDouble)
{
//this is a trivial call due to removing of __MPI
elecstate_pw_d = new elecstate::ElecStatePW<double, psi::DEVICE_CPU>(wfcpw, chg, klist, rhopw, bigpw);
EXPECT_NO_THROW(elecstate_pw_d->parallelK());
}

TEST_F(ElecStatePWTest, ParallelKSingle)
{
//this is a trivial call due to removing of __MPI
elecstate_pw_s = new elecstate::ElecStatePW<float, psi::DEVICE_CPU>(wfcpw, chg, klist, rhopw, bigpw);
EXPECT_NO_THROW(elecstate_pw_s->parallelK());
}

#undef protected