From 9cca23cbaae0c754a747383780f8296d07eb0423 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Mon, 27 Feb 2023 17:34:32 +0800 Subject: [PATCH 1/7] refactor: remove global variable from module_pw. --- source/module_pw/fft.cpp | 25 +++++++++++----------- source/module_pw/fft.h | 4 ++++ source/module_pw/pw_basis.cpp | 2 +- source/module_pw/pw_basis_k.cpp | 32 ++++++++++++++++++---------- source/module_pw/pw_basis_k.h | 6 ++++++ source/module_pw/pw_distributeg.cpp | 4 ++-- source/module_pw/test/test-other.cpp | 6 ++++-- 7 files changed, 50 insertions(+), 29 deletions(-) diff --git a/source/module_pw/fft.cpp b/source/module_pw/fft.cpp index c3e5c582e1..05408e3453 100644 --- a/source/module_pw/fft.cpp +++ b/source/module_pw/fft.cpp @@ -1,6 +1,5 @@ #include "fft.h" -#include "module_base/global_variable.h" #include "module_base/memory.h" #ifdef _OPENMP @@ -36,8 +35,8 @@ void FFT::clear() if(z_auxr!=nullptr) {fftw_free(z_auxr); z_auxr = nullptr;} d_rspace = nullptr; #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { if (c_auxr_3d != nullptr) { delmem_cd_op()(gpu_ctx, c_auxr_3d); c_auxr_3d = nullptr; @@ -51,7 +50,7 @@ void FFT::clear() } } #endif - if (GlobalV::precision_flag == "single") { + if (this->precision == "single") { this->cleanfFFT(); if (c_auxg != nullptr) { fftw_free(c_auxg); @@ -98,8 +97,8 @@ void FFT:: initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, in // auxr_3d = static_cast *>( // fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz))); #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz); } else { @@ -107,7 +106,7 @@ void FFT:: initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, in } } #endif - if (GlobalV::precision_flag == "single") { + if (this->precision == "single") { c_auxg = (std::complex *) fftw_malloc(sizeof(fftwf_complex) * maxgrids); c_auxr = (std::complex *) fftw_malloc(sizeof(fftwf_complex) * maxgrids); ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids); @@ -126,7 +125,7 @@ void FFT:: setupFFT() if(!this->mpifft) { this->initplan(); - if (GlobalV::precision_flag == "single") { + if (this->precision == "single") { this->initplanf(); } } @@ -134,7 +133,7 @@ void FFT:: setupFFT() else { // this->initplan_mpi(); - // if (GlobalV::precision_flag == "single") { + // if (this->precision == "single") { // this->initplanf_mpi(); // } } @@ -232,8 +231,8 @@ void FFT :: initplan() // FFTW_BACKWARD, FFTW_MEASURE); #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { #if defined(__CUDA) cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C); #elif defined(__ROCM) @@ -376,8 +375,8 @@ void FFT:: cleanFFT() // fftw_destroy_plan(this->plan3dforward); // fftw_destroy_plan(this->plan3dbackward); #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { #if defined(__CUDA) cufftDestroy(c_handle); #elif defined(__ROCM) diff --git a/source/module_pw/fft.h b/source/module_pw/fft.h index 301769436d..1e3551e5c8 100644 --- a/source/module_pw/fft.h +++ b/source/module_pw/fft.h @@ -22,6 +22,7 @@ //Temporary: we donot need psi. However some GPU ops are defined in psi, which should be moved into module_base or module_gpu #include "module_psi/psi.h" +#include "module_base/global_variable.h" // #ifdef __MIX_PRECISION // #include "fftw3f.h" // #if defined(__FFTW3_MPI) && defined(__MPI) @@ -139,6 +140,9 @@ class FFT float * s_rspace=nullptr; //real number space for r, [nplane * nx *ny] double * d_rspace=nullptr; //real number space for r, [nplane * nx *ny] + + std::string device = GlobalV::device_flag; + std::string precision = GlobalV::precision_flag; }; } diff --git a/source/module_pw/pw_basis.cpp b/source/module_pw/pw_basis.cpp index 118efa440b..b779d024d7 100644 --- a/source/module_pw/pw_basis.cpp +++ b/source/module_pw/pw_basis.cpp @@ -31,7 +31,7 @@ PW_Basis:: ~PW_Basis() delete[] ig2igg; delete[] gg_uniq; #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { + if (this->device == "gpu") { delmem_int_op()(gpu_ctx, this->d_is2fftixy); } #endif diff --git a/source/module_pw/pw_basis_k.cpp b/source/module_pw/pw_basis_k.cpp index 5ec947b4fd..15324fdf1b 100644 --- a/source/module_pw/pw_basis_k.cpp +++ b/source/module_pw/pw_basis_k.cpp @@ -1,4 +1,6 @@ #include "pw_basis_k.h" + +#include #include "../module_base/constants.h" #include "../module_base/timer.h" #include "module_base/memory.h" @@ -20,8 +22,8 @@ PW_Basis_K::~PW_Basis_K() delete[] gk2; delete[] ig2ixyz_k_; #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { delmem_sd_op()(gpu_ctx, this->s_kvec_c); delmem_sd_op()(gpu_ctx, this->s_gcar); delmem_sd_op()(gpu_ctx, this->s_gk2); @@ -36,7 +38,7 @@ PW_Basis_K::~PW_Basis_K() } else { #endif - if (GlobalV::precision_flag == "single") { + if (this->precision == "single") { delmem_sh_op()(cpu_ctx, this->s_kvec_c); delmem_sh_op()(cpu_ctx, this->s_gcar); delmem_sh_op()(cpu_ctx, this->s_gk2); @@ -91,8 +93,8 @@ void PW_Basis_K:: initparameters( this->fftnxyz = this->fftnxy * this->fftnz; this->distribution_type = distribution_type_in; #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { resmem_sd_op()(gpu_ctx, this->s_kvec_c, this->nks * 3); castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } @@ -103,7 +105,7 @@ void PW_Basis_K:: initparameters( } else { #endif - if (GlobalV::precision_flag == "single") { + if (this->precision == "single") { resmem_sh_op()(cpu_ctx, this->s_kvec_c, this->nks * 3); castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_kvec_c, reinterpret_cast(&this->kvec_c[0][0]), this->nks * 3); } @@ -160,7 +162,7 @@ void PW_Basis_K::setupIndGk() } } #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { + if (this->device == "gpu") { resmem_int_op()(gpu_ctx, this->d_igl2isz_k, this->npwk_max * this->nks); syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks); } @@ -221,8 +223,8 @@ void PW_Basis_K::collect_local_pw() } } #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { - if (GlobalV::precision_flag == "single") { + if (this->device == "gpu") { + if (this->precision == "single") { resmem_sd_op()(gpu_ctx, this->s_gk2, this->npwk_max * this->nks); resmem_sd_op()(gpu_ctx, this->s_gcar, this->npwk_max * this->nks * 3); castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_gk2, this->gk2, this->npwk_max * this->nks); @@ -237,7 +239,7 @@ void PW_Basis_K::collect_local_pw() } else { #endif - if (GlobalV::precision_flag == "single") { + if (this->precision == "single") { resmem_sh_op()(cpu_ctx, this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); resmem_sh_op()(cpu_ctx, this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar"); castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_gk2, this->gk2, this->npwk_max * this->nks); @@ -325,13 +327,21 @@ void PW_Basis_K::get_ig2ixyz_k() } } #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { + if (this->device == "gpu") { resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks); syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, this->ig2ixyz_k_, this->npwk_max * this->nks); } #endif } +void PW_Basis_K::set_device(std::string device_) { + this->device = std::move(device_); +} + +void PW_Basis_K::set_precision(std::string precision_) { + this->device = std::move(precision_); +} + template <> float * PW_Basis_K::get_kvec_c_data() const { return this->s_kvec_c; diff --git a/source/module_pw/pw_basis_k.h b/source/module_pw/pw_basis_k.h index 6181efa634..ee5cef453b 100644 --- a/source/module_pw/pw_basis_k.h +++ b/source/module_pw/pw_basis_k.h @@ -129,9 +129,15 @@ class PW_Basis_K : public PW_Basis template FPTYPE * get_gcar_data() const; template FPTYPE * get_kvec_c_data() const; + void set_device(std::string device_); + void set_precision(std::string precision_); + private: float * s_gcar = nullptr, * s_kvec_c = nullptr; double * d_gcar = nullptr, * d_kvec_c = nullptr; + + std::string device = GlobalV::device_flag; + std::string precision = GlobalV::precision_flag; }; } diff --git a/source/module_pw/pw_distributeg.cpp b/source/module_pw/pw_distributeg.cpp index f84934ef59..3ec17b6e18 100644 --- a/source/module_pw/pw_distributeg.cpp +++ b/source/module_pw/pw_distributeg.cpp @@ -149,7 +149,7 @@ void PW_Basis::get_ig2isz_is2fftixy( delete[] this->ig2isz; this->ig2isz = nullptr; // map ig to the z coordinate of this planewave. delete[] this->is2fftixy; this->is2fftixy = nullptr; // map is (index of sticks) to ixy (iy + ix * fftny). #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { + if (this->device == "gpu") { delmem_int_op()(gpu_ctx, this->d_is2fftixy); d_is2fftixy = nullptr; } @@ -186,7 +186,7 @@ void PW_Basis::get_ig2isz_is2fftixy( if (st_move == this->nst && pw_filled == this->npw) break; } #if defined(__CUDA) || defined(__ROCM) - if (GlobalV::device_flag == "gpu") { + if (this->device == "gpu") { resmem_int_op()(gpu_ctx, d_is2fftixy, this->nst); syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_is2fftixy, this->is2fftixy, this->nst); } diff --git a/source/module_pw/test/test-other.cpp b/source/module_pw/test/test-other.cpp index 7b4f941cd4..9ef123f9fc 100644 --- a/source/module_pw/test/test-other.cpp +++ b/source/module_pw/test/test-other.cpp @@ -35,13 +35,15 @@ TEST_F(PWTEST,test_other) ModuleBase::Vector3 *kvec_d = new ModuleBase::Vector3[nks]; kvec_d[0].set(0,0,0.5); kvec_d[1].set(0.5,0.5,0.5); - GlobalV::precision_flag = "double"; //temporary + // GlobalV::precision_flag = "double"; //temporary + pwktest.set_precision("double"); pwktest.initgrids(2, latvec, 4,4,4); pwktest.initparameters(true, 20, nks, kvec_d); pwktest.setuptransform(); pwktest.collect_local_pw(); #ifdef __MIX_PRECISION - GlobalV::precision_flag = "single"; + // GlobalV::precision_flag = "single"; + pwktest.set_precision("single"); #endif pwktest.initparameters(true, 8, nks, kvec_d); pwktest.setuptransform(); From b4388062c2b0672474359a6df2670ea5047cfeb5 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Mon, 27 Feb 2023 17:45:23 +0800 Subject: [PATCH 2/7] fix: fix compilation error. --- source/module_pw/pw_basis.h | 4 ++++ source/module_pw/pw_basis_k.h | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/source/module_pw/pw_basis.h b/source/module_pw/pw_basis.h index 2c9171fa26..40fae098dc 100644 --- a/source/module_pw/pw_basis.h +++ b/source/module_pw/pw_basis.h @@ -260,6 +260,10 @@ class PW_Basis using resmem_int_op = psi::memory::resize_memory_op; using delmem_int_op = psi::memory::delete_memory_op; using syncmem_int_h2d_op = psi::memory::synchronize_memory_op; + +protected: + std::string device = GlobalV::device_flag; + std::string precision = GlobalV::precision_flag; }; } diff --git a/source/module_pw/pw_basis_k.h b/source/module_pw/pw_basis_k.h index ee5cef453b..31fafb8d97 100644 --- a/source/module_pw/pw_basis_k.h +++ b/source/module_pw/pw_basis_k.h @@ -136,8 +136,6 @@ class PW_Basis_K : public PW_Basis float * s_gcar = nullptr, * s_kvec_c = nullptr; double * d_gcar = nullptr, * d_kvec_c = nullptr; - std::string device = GlobalV::device_flag; - std::string precision = GlobalV::precision_flag; }; } From 7f9f9b2dfc94c099a8deda35f757757a8945ab12 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Mon, 27 Feb 2023 23:44:39 +0800 Subject: [PATCH 3/7] refactor: remove direct usage of GlobalV in module_pw. --- source/module_esolver/esolver_fp.cpp | 2 +- source/module_esolver/esolver_ks.cpp | 2 +- source/module_pw/fft.cpp | 8 ++++++++ source/module_pw/fft.h | 9 +++++++-- source/module_pw/pw_basis.cpp | 9 +++++++++ source/module_pw/pw_basis.h | 7 +++++-- source/module_pw/pw_basis_big.h | 7 +++++++ source/module_pw/pw_basis_k.cpp | 8 -------- source/module_pw/pw_basis_k.h | 3 --- source/module_pw/pw_basis_k_big.h | 7 +++++++ 10 files changed, 45 insertions(+), 17 deletions(-) diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index 4257de0708..ed8dd44cbe 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -6,7 +6,7 @@ namespace ModuleESolver { // pw_rho = new ModuleBase::PW_Basis(); - pw_rho = new ModulePW::PW_Basis_Big(); + pw_rho = new ModulePW::PW_Basis_Big(GlobalV::device_flag, GlobalV::precision_flag); GlobalC::rhopw = this->pw_rho; //Temporary //temporary, it will be removed GlobalC::bigpw = static_cast(pw_rho); diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index e82cf28d4a..b1626c0ff3 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -31,7 +31,7 @@ namespace ModuleESolver // pw_rho = new ModuleBase::PW_Basis(); //temporary, it will be removed - pw_wfc = new ModulePW::PW_Basis_K_Big(); + pw_wfc = new ModulePW::PW_Basis_K_Big(GlobalV::device_flag, GlobalV::precision_flag); GlobalC::wfcpw = this->pw_wfc; //Temporary ModulePW::PW_Basis_K_Big* tmp = static_cast(pw_wfc); tmp->setbxyz(INPUT.bx,INPUT.by,INPUT.bz); diff --git a/source/module_pw/fft.cpp b/source/module_pw/fft.cpp index 05408e3453..18b54821f2 100644 --- a/source/module_pw/fft.cpp +++ b/source/module_pw/fft.cpp @@ -777,4 +777,12 @@ std::complex * FFT::get_auxr_3d_data() { } #endif +void FFT::set_device(std::string device_) { + this->device = std::move(device_); } + +void FFT::set_precision(std::string precision_) { + this->precision = std::move(precision_); +} + +} // namespace ModulePW diff --git a/source/module_pw/fft.h b/source/module_pw/fft.h index 1e3551e5c8..5655588df0 100644 --- a/source/module_pw/fft.h +++ b/source/module_pw/fft.h @@ -141,8 +141,13 @@ class FFT float * s_rspace=nullptr; //real number space for r, [nplane * nx *ny] double * d_rspace=nullptr; //real number space for r, [nplane * nx *ny] - std::string device = GlobalV::device_flag; - std::string precision = GlobalV::precision_flag; + std::string device = "cpu"; + std::string precision = "double"; + +public: + void set_device(std::string device_); + void set_precision(std::string precision_); + }; } diff --git a/source/module_pw/pw_basis.cpp b/source/module_pw/pw_basis.cpp index b779d024d7..6dc5a8c212 100644 --- a/source/module_pw/pw_basis.cpp +++ b/source/module_pw/pw_basis.cpp @@ -1,4 +1,6 @@ #include "pw_basis.h" + +#include #include "../module_base/mymath.h" #include "../module_base/timer.h" #include "../module_base/global_function.h" @@ -215,5 +217,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is) } } +void PW_Basis::set_device(std::string device_) { + this->device = std::move(device_); +} + +void PW_Basis::set_precision(std::string precision_) { + this->precision = std::move(precision_); +} } \ No newline at end of file diff --git a/source/module_pw/pw_basis.h b/source/module_pw/pw_basis.h index 40fae098dc..6648ab69fb 100644 --- a/source/module_pw/pw_basis.h +++ b/source/module_pw/pw_basis.h @@ -261,9 +261,12 @@ class PW_Basis using delmem_int_op = psi::memory::delete_memory_op; using syncmem_int_h2d_op = psi::memory::synchronize_memory_op; + void set_device(std::string device_); + void set_precision(std::string precision_); + protected: - std::string device = GlobalV::device_flag; - std::string precision = GlobalV::precision_flag; + std::string device = "cpu"; + std::string precision = "double"; }; } diff --git a/source/module_pw/pw_basis_big.h b/source/module_pw/pw_basis_big.h index 85b110e0db..645dfdbb58 100644 --- a/source/module_pw/pw_basis_big.h +++ b/source/module_pw/pw_basis_big.h @@ -24,6 +24,13 @@ class PW_Basis_Big: public PW_Basis by = 1; bz = 1; } + PW_Basis_Big(std::string device_, std::string precision_) + { + this->device = std::move(device_); + this->precision = std::move(precision_); + this->ft.set_device(this->device); + this->ft.set_precision(this->precision); + } ~PW_Basis_Big(){}; void setbxyz(const int bx_in, const int by_in, const int bz_in) { diff --git a/source/module_pw/pw_basis_k.cpp b/source/module_pw/pw_basis_k.cpp index 15324fdf1b..cbcfe09960 100644 --- a/source/module_pw/pw_basis_k.cpp +++ b/source/module_pw/pw_basis_k.cpp @@ -334,14 +334,6 @@ void PW_Basis_K::get_ig2ixyz_k() #endif } -void PW_Basis_K::set_device(std::string device_) { - this->device = std::move(device_); -} - -void PW_Basis_K::set_precision(std::string precision_) { - this->device = std::move(precision_); -} - template <> float * PW_Basis_K::get_kvec_c_data() const { return this->s_kvec_c; diff --git a/source/module_pw/pw_basis_k.h b/source/module_pw/pw_basis_k.h index 31fafb8d97..1e9e412e9b 100644 --- a/source/module_pw/pw_basis_k.h +++ b/source/module_pw/pw_basis_k.h @@ -129,9 +129,6 @@ class PW_Basis_K : public PW_Basis template FPTYPE * get_gcar_data() const; template FPTYPE * get_kvec_c_data() const; - void set_device(std::string device_); - void set_precision(std::string precision_); - private: float * s_gcar = nullptr, * s_kvec_c = nullptr; double * d_gcar = nullptr, * d_kvec_c = nullptr; diff --git a/source/module_pw/pw_basis_k_big.h b/source/module_pw/pw_basis_k_big.h index 23a2754ac5..a08de90755 100644 --- a/source/module_pw/pw_basis_k_big.h +++ b/source/module_pw/pw_basis_k_big.h @@ -22,6 +22,13 @@ class PW_Basis_K_Big: public PW_Basis_K by = 1; bz = 1; } + PW_Basis_K_Big(std::string device_, std::string precision_) + { + this->device = std::move(device_); + this->precision = std::move(precision_); + this->ft.set_device(this->device); + this->ft.set_precision(this->precision); + } ~PW_Basis_K_Big(){}; void setbxyz(const int bx_in, const int by_in, const int bz_in) { From d305af82dabd08cb998d33cbf60910061c7b8bd1 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Tue, 28 Feb 2023 17:01:40 +0800 Subject: [PATCH 4/7] address comments --- source/module_pw/pw_basis_k.h | 6 ++++++ source/module_pw/test/pw_test.cpp | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/source/module_pw/pw_basis_k.h b/source/module_pw/pw_basis_k.h index 1e9e412e9b..816f11316a 100644 --- a/source/module_pw/pw_basis_k.h +++ b/source/module_pw/pw_basis_k.h @@ -54,6 +54,12 @@ class PW_Basis_K : public PW_Basis public: PW_Basis_K(); + PW_Basis_K(std::string device_, std::string precision_) { + this->device = std::move(device_); + this->precision = std::move(precision_); + this->ft.set_device(this->device); + this->ft.set_precision(this->precision); + } ~PW_Basis_K(); //init parameters of pw_basis_k class diff --git a/source/module_pw/test/pw_test.cpp b/source/module_pw/test/pw_test.cpp index a92e697ad3..6933c77dcf 100644 --- a/source/module_pw/test/pw_test.cpp +++ b/source/module_pw/test/pw_test.cpp @@ -37,8 +37,8 @@ int main(int argc, char **argv) int kpar; kpar = 1; #ifdef __ENABLE_FLOAT_FFTW - //Temporary, pw_basis should not contain global variables - GlobalV::precision_flag = "single"; + // Temporary, pw_basis should not contain global variables + // GlobalV::precision_flag = "single"; #endif #ifdef __MPI int nproc, myrank,mypool; From 261b7cac7d10fb7e0778a07460bef536ce4e5b07 Mon Sep 17 00:00:00 2001 From: denghuilu Date: Tue, 28 Feb 2023 20:34:02 +0800 Subject: [PATCH 5/7] address comments --- CMakeLists.txt | 5 ++++- source/module_pw/pw_basis.cpp | 6 ++++++ source/module_pw/pw_basis.h | 3 ++- source/module_pw/pw_basis_big.h | 9 ++------- source/module_pw/pw_basis_k.h | 7 +------ source/module_pw/pw_basis_k_big.h | 8 +------- source/module_pw/test/generate.cpp | 2 +- source/module_pw/test/pw_test.cpp | 4 ---- source/module_pw/test/test-big.cpp | 8 ++++---- source/module_pw/test/test-other.cpp | 8 ++++---- source/module_pw/test/test1-1-1.cpp | 2 +- source/module_pw/test/test1-1-2.cpp | 2 +- source/module_pw/test/test1-2-2.cpp | 2 +- source/module_pw/test/test1-2.cpp | 2 +- source/module_pw/test/test1-3.cpp | 2 +- source/module_pw/test/test1-4.cpp | 2 +- source/module_pw/test/test1-5.cpp | 2 +- source/module_pw/test/test2-1-1.cpp | 2 +- source/module_pw/test/test2-1-2.cpp | 2 +- source/module_pw/test/test2-2.cpp | 2 +- source/module_pw/test/test2-3.cpp | 2 +- source/module_pw/test/test3-1.cpp | 2 +- source/module_pw/test/test3-2.cpp | 2 +- source/module_pw/test/test3-3-2.cpp | 2 +- source/module_pw/test/test3-3.cpp | 2 +- source/module_pw/test/test4-1.cpp | 2 +- source/module_pw/test/test4-2.cpp | 2 +- source/module_pw/test/test4-3.cpp | 2 +- source/module_pw/test/test4-4.cpp | 2 +- source/module_pw/test/test4-5.cpp | 2 +- source/module_pw/test/test5-1-1.cpp | 2 +- source/module_pw/test/test5-1-2.cpp | 2 +- source/module_pw/test/test5-2-1.cpp | 2 +- source/module_pw/test/test5-2-2.cpp | 2 +- source/module_pw/test/test5-3-1.cpp | 2 +- source/module_pw/test/test5-4-1.cpp | 2 +- source/module_pw/test/test5-4-2.cpp | 2 +- source/module_pw/test/test6-1-1.cpp | 2 +- source/module_pw/test/test6-1-2.cpp | 2 +- source/module_pw/test/test6-2-1.cpp | 2 +- source/module_pw/test/test6-2-2.cpp | 2 +- source/module_pw/test/test6-3-1.cpp | 2 +- source/module_pw/test/test6-4-1.cpp | 2 +- source/module_pw/test/test6-4-2.cpp | 2 +- source/module_pw/test/test7-1.cpp | 2 +- source/module_pw/test/test7-2-1.cpp | 2 +- source/module_pw/test/test7-3-1.cpp | 2 +- source/module_pw/test/test7-3-2.cpp | 2 +- source/module_pw/test/test8-1.cpp | 2 +- source/module_pw/test/test8-2-1.cpp | 2 +- source/module_pw/test/test8-3-1.cpp | 2 +- source/module_pw/test/test8-3-2.cpp | 2 +- source/module_pw/test/time.cpp | 8 ++++---- 53 files changed, 71 insertions(+), 81 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dddb66aa70..a87e472c4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -289,7 +289,6 @@ else() endif() if (ENABLE_FLOAT_FFTW) list(APPEND math_libs FFTW3::FFTW3_FLOAT) - add_definitions(-D__ENABLE_FLOAT_FFTW) endif() if(CMAKE_CXX_COMPILER_ID MATCHES GNU) list(APPEND math_libs -lgfortran) @@ -302,6 +301,10 @@ else() endif() endif() +if (ENABLE_FLOAT_FFTW) + add_definitions(-D__ENABLE_FLOAT_FFTW) +endif() + if(ENABLE_DEEPKS) set(CMAKE_CXX_STANDARD 14) find_package(Torch REQUIRED) diff --git a/source/module_pw/pw_basis.cpp b/source/module_pw/pw_basis.cpp index 6dc5a8c212..2bfe6130ba 100644 --- a/source/module_pw/pw_basis.cpp +++ b/source/module_pw/pw_basis.cpp @@ -13,6 +13,12 @@ PW_Basis::PW_Basis() classname="PW_Basis"; } +PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) { + classname="PW_Basis"; + this->ft.set_device(this->device); + this->ft.set_precision(this->precision); +} + PW_Basis:: ~PW_Basis() { delete[] ig2isz; diff --git a/source/module_pw/pw_basis.h b/source/module_pw/pw_basis.h index 6648ab69fb..6961eed378 100644 --- a/source/module_pw/pw_basis.h +++ b/source/module_pw/pw_basis.h @@ -56,6 +56,7 @@ class PW_Basis public: std::string classname; PW_Basis(); + PW_Basis(std::string device_, std::string precision_); virtual ~PW_Basis(); //Init mpi parameters #ifdef __MPI @@ -270,6 +271,6 @@ class PW_Basis }; } -#endif //PlaneWave +#endif // PWBASIS_H #include "./pw_basis_big.h" //temporary it will be removed \ No newline at end of file diff --git a/source/module_pw/pw_basis_big.h b/source/module_pw/pw_basis_big.h index 645dfdbb58..443f68a50c 100644 --- a/source/module_pw/pw_basis_big.h +++ b/source/module_pw/pw_basis_big.h @@ -24,13 +24,8 @@ class PW_Basis_Big: public PW_Basis by = 1; bz = 1; } - PW_Basis_Big(std::string device_, std::string precision_) - { - this->device = std::move(device_); - this->precision = std::move(precision_); - this->ft.set_device(this->device); - this->ft.set_precision(this->precision); - } + PW_Basis_Big(std::string device_, std::string precision_) : PW_Basis(device_, precision_) {} + ~PW_Basis_Big(){}; void setbxyz(const int bx_in, const int by_in, const int bz_in) { diff --git a/source/module_pw/pw_basis_k.h b/source/module_pw/pw_basis_k.h index 816f11316a..bb35d581fc 100644 --- a/source/module_pw/pw_basis_k.h +++ b/source/module_pw/pw_basis_k.h @@ -54,12 +54,7 @@ class PW_Basis_K : public PW_Basis public: PW_Basis_K(); - PW_Basis_K(std::string device_, std::string precision_) { - this->device = std::move(device_); - this->precision = std::move(precision_); - this->ft.set_device(this->device); - this->ft.set_precision(this->precision); - } + PW_Basis_K(std::string device_, std::string precision_) : PW_Basis(device_, precision_) {} ~PW_Basis_K(); //init parameters of pw_basis_k class diff --git a/source/module_pw/pw_basis_k_big.h b/source/module_pw/pw_basis_k_big.h index a08de90755..3db43b7a56 100644 --- a/source/module_pw/pw_basis_k_big.h +++ b/source/module_pw/pw_basis_k_big.h @@ -22,13 +22,7 @@ class PW_Basis_K_Big: public PW_Basis_K by = 1; bz = 1; } - PW_Basis_K_Big(std::string device_, std::string precision_) - { - this->device = std::move(device_); - this->precision = std::move(precision_); - this->ft.set_device(this->device); - this->ft.set_precision(this->precision); - } + PW_Basis_K_Big(std::string device_, std::string precision_) : PW_Basis_K(device_, precision_) {} ~PW_Basis_K_Big(){}; void setbxyz(const int bx_in, const int by_in, const int bz_in) { diff --git a/source/module_pw/test/generate.cpp b/source/module_pw/test/generate.cpp index fb55160049..8d845bd082 100644 --- a/source/module_pw/test/generate.cpp +++ b/source/module_pw/test/generate.cpp @@ -43,7 +43,7 @@ int main(int argc, char **argv) create_pools(totnproc, myrank, nproc); if(myrank < nproc) { - ModulePW::PW_Basis pwtest; + ModulePW::PW_Basis pwtest(GlobalV::device_flag, GlobalV::precision_flag); #ifdef __MPI pwtest.initmpi(nproc, myrank, POOL_WORLD); #endif diff --git a/source/module_pw/test/pw_test.cpp b/source/module_pw/test/pw_test.cpp index 6933c77dcf..e132b25535 100644 --- a/source/module_pw/test/pw_test.cpp +++ b/source/module_pw/test/pw_test.cpp @@ -36,10 +36,6 @@ int main(int argc, char **argv) int kpar; kpar = 1; -#ifdef __ENABLE_FLOAT_FFTW - // Temporary, pw_basis should not contain global variables - // GlobalV::precision_flag = "single"; -#endif #ifdef __MPI int nproc, myrank,mypool; setupmpi(argc,argv,nproc, myrank); diff --git a/source/module_pw/test/test-big.cpp b/source/module_pw/test/test-big.cpp index de5b335953..14f11bb1b0 100644 --- a/source/module_pw/test/test-big.cpp +++ b/source/module_pw/test/test-big.cpp @@ -15,8 +15,8 @@ using namespace std; TEST_F(PWTEST,test_big) { cout<<"Temporary: test for pw_basis_big and pw_basis_k_big. (They should be removed in the future)"< Date: Wed, 1 Mar 2023 00:11:31 +0800 Subject: [PATCH 6/7] make module_pw independent of global_variable.cpp and global_parallel.cpp (#19) * fix: bug when stru_file=../STRU * Test: add UT for math_chebyshev * make module_pw independent of global_variable.cpp and global_parallel.cpp --- source/module_pw/test/CMakeLists.txt | 9 ++++----- source/module_pw/test/Makefile | 6 ++---- source/module_pw/test/generate.cpp | 3 +-- source/module_pw/test/pw_test.cpp | 22 ++++++++++++++++++++-- source/module_pw/test/pw_test.h | 19 +++++++++++++++++++ source/module_pw/test/test-big.cpp | 8 ++++---- source/module_pw/test/test-other.cpp | 9 ++++----- source/module_pw/test/test1-1-1.cpp | 2 +- source/module_pw/test/test1-1-2.cpp | 2 +- source/module_pw/test/test1-2-2.cpp | 2 +- source/module_pw/test/test1-2.cpp | 2 +- source/module_pw/test/test1-3.cpp | 2 +- source/module_pw/test/test1-4.cpp | 2 +- source/module_pw/test/test1-5.cpp | 2 +- source/module_pw/test/test2-1-1.cpp | 2 +- source/module_pw/test/test2-1-2.cpp | 2 +- source/module_pw/test/test2-2.cpp | 2 +- source/module_pw/test/test2-3.cpp | 2 +- source/module_pw/test/test3-1.cpp | 2 +- source/module_pw/test/test3-2.cpp | 2 +- source/module_pw/test/test3-3-2.cpp | 2 +- source/module_pw/test/test3-3.cpp | 2 +- source/module_pw/test/test4-1.cpp | 2 +- source/module_pw/test/test4-2.cpp | 2 +- source/module_pw/test/test4-3.cpp | 2 +- source/module_pw/test/test4-4.cpp | 2 +- source/module_pw/test/test4-5.cpp | 2 +- source/module_pw/test/test5-1-1.cpp | 2 +- source/module_pw/test/test5-1-2.cpp | 2 +- source/module_pw/test/test5-2-1.cpp | 2 +- source/module_pw/test/test5-2-2.cpp | 2 +- source/module_pw/test/test5-3-1.cpp | 2 +- source/module_pw/test/test5-4-1.cpp | 2 +- source/module_pw/test/test5-4-2.cpp | 2 +- source/module_pw/test/test6-1-1.cpp | 2 +- source/module_pw/test/test6-1-2.cpp | 2 +- source/module_pw/test/test6-2-1.cpp | 2 +- source/module_pw/test/test6-2-2.cpp | 2 +- source/module_pw/test/test6-3-1.cpp | 2 +- source/module_pw/test/test6-4-1.cpp | 2 +- source/module_pw/test/test6-4-2.cpp | 2 +- source/module_pw/test/test7-1.cpp | 2 +- source/module_pw/test/test7-2-1.cpp | 2 +- source/module_pw/test/test7-3-1.cpp | 2 +- source/module_pw/test/test7-3-2.cpp | 2 +- source/module_pw/test/test8-1.cpp | 2 +- source/module_pw/test/test8-2-1.cpp | 2 +- source/module_pw/test/test8-3-1.cpp | 2 +- source/module_pw/test/test8-3-2.cpp | 2 +- source/module_pw/test/test_gnu.sh | 8 ++++---- source/module_pw/test/test_tool.cpp | 8 ++++++++ source/module_pw/test/time.cpp | 8 ++++---- 52 files changed, 112 insertions(+), 72 deletions(-) diff --git a/source/module_pw/test/CMakeLists.txt b/source/module_pw/test/CMakeLists.txt index 18adbc94bf..7cb5f88069 100644 --- a/source/module_pw/test/CMakeLists.txt +++ b/source/module_pw/test/CMakeLists.txt @@ -1,11 +1,10 @@ -add_definitions(-D__NORMAL -D__MIX_PRECISION) +add_definitions(-D__NORMAL __NOMPICOMPLEX) AddTest( TARGET pw_test - LIBS ${math_libs} planewave psi device + LIBS ${math_libs} planewave device SOURCES ../../module_base/matrix.cpp ../../module_base/complexmatrix.cpp ../../module_base/matrix3.cpp ../../module_base/tool_quit.cpp - ../../module_base/mymath.cpp ../../module_base/timer.cpp ../../module_base/memory.cpp ../../module_base/global_variable.cpp - ../../module_base/parallel_common.cpp - ../../module_base/parallel_global.cpp ../../module_base/parallel_reduce.cpp + ../../module_base/mymath.cpp ../../module_base/timer.cpp ../../module_base/memory.cpp + ../../module_psi/kernels/memory_op.cpp pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp test2-1-1.cpp test2-1-2.cpp test2-2.cpp test2-3.cpp test3-1.cpp test3-2.cpp test3-3.cpp test3-3-2.cpp diff --git a/source/module_pw/test/Makefile b/source/module_pw/test/Makefile index fca331dc6d..f356074dfb 100644 --- a/source/module_pw/test/Makefile +++ b/source/module_pw/test/Makefile @@ -26,7 +26,7 @@ FFTW_DIR = /home/qianrui/gnucompile/fftw_3.3.8 #========================== # Compiler information #========================== -HONG = -D__NORMAL +HONG = -D__NORMAL -D__NOMPICOMPLEX INCLUDES = -I. -I../../ LIBS = OPTS = -Ofast -march=native -std=c++11 -m64 ${INCLUDES} @@ -130,9 +130,7 @@ pw_transform_k.o\ memory.o\ memory_op.o -OTHER_OBJS0=global_variable.o\ -parallel_global.o\ -parallel_reduce.o +OTHER_OBJS0= TESTFILE0=test1-1-1.o\ diff --git a/source/module_pw/test/generate.cpp b/source/module_pw/test/generate.cpp index 8d845bd082..51dd34bab5 100644 --- a/source/module_pw/test/generate.cpp +++ b/source/module_pw/test/generate.cpp @@ -43,7 +43,7 @@ int main(int argc, char **argv) create_pools(totnproc, myrank, nproc); if(myrank < nproc) { - ModulePW::PW_Basis pwtest(GlobalV::device_flag, GlobalV::precision_flag); + ModulePW::PW_Basis pwtest; #ifdef __MPI pwtest.initmpi(nproc, myrank, POOL_WORLD); #endif @@ -133,7 +133,6 @@ int main(int argc, char **argv) } - MPI_Type_free(&mpicomplex); MPI_Finalize(); diff --git a/source/module_pw/test/pw_test.cpp b/source/module_pw/test/pw_test.cpp index e132b25535..a0a719879d 100644 --- a/source/module_pw/test/pw_test.cpp +++ b/source/module_pw/test/pw_test.cpp @@ -2,13 +2,25 @@ #include "test_tool.h" #include "mpi.h" #endif -#include "../../module_base/global_variable.h" #include "fftw3.h" #include "pw_test.h" using namespace std; int nproc_in_pool, rank_in_pool; - +string precision_flag, device_flag; +namespace GlobalV +{ + std::ofstream ofs_running; +} +#ifdef __MPI +MPI_Comm POOL_WORLD; +namespace Parallel_Reduce +{ + void reduce_double_all(double &object){return;}; + void reduce_double_pool(double &object){return;}; + void reduce_double_pool(float &object){return;}; +} +#endif class TestEnv : public testing::Environment { public: @@ -36,6 +48,12 @@ int main(int argc, char **argv) int kpar; kpar = 1; +#ifdef __ENABLE_FLOAT_FFTW + precision_flag = "single"; +#else + precision_flag = "double"; +#endif + device_flag = "cpu"; #ifdef __MPI int nproc, myrank,mypool; setupmpi(argc,argv,nproc, myrank); diff --git a/source/module_pw/test/pw_test.h b/source/module_pw/test/pw_test.h index 7571ba5155..1ee806847d 100644 --- a/source/module_pw/test/pw_test.h +++ b/source/module_pw/test/pw_test.h @@ -2,8 +2,10 @@ #define __PWTEST #include "gtest/gtest.h" #include +#include using namespace std; extern int nproc_in_pool, rank_in_pool; +extern string precision_flag, device_flag; class PWTEST: public testing::Test { @@ -32,4 +34,21 @@ class PWTEST: public testing::Test } void TearDown(){} }; + +//memory.cpp depends on GlobalV::ofs_running and reduce_double_all +//GPU depends on reduce_double_pool +namespace GlobalV +{ + extern std::ofstream ofs_running; +} +#ifdef __MPI +#include "mpi.h" +extern MPI_Comm POOL_WORLD; +namespace Parallel_Reduce +{ + void reduce_double_all(double &object); + void reduce_double_pool(double &object); + void reduce_double_pool(float &object); +} +#endif #endif \ No newline at end of file diff --git a/source/module_pw/test/test-big.cpp b/source/module_pw/test/test-big.cpp index 14f11bb1b0..cc7e935c1b 100644 --- a/source/module_pw/test/test-big.cpp +++ b/source/module_pw/test/test-big.cpp @@ -15,8 +15,8 @@ using namespace std; TEST_F(PWTEST,test_big) { cout<<"Temporary: test for pw_basis_big and pw_basis_k_big. (They should be removed in the future)"< *kvec_d = new ModuleBase::Vector3[nks]; kvec_d[0].set(0,0,0.5); kvec_d[1].set(0.5,0.5,0.5); - // GlobalV::precision_flag = "double"; //temporary pwktest.set_precision("double"); pwktest.initgrids(2, latvec, 4,4,4); pwktest.initparameters(true, 20, nks, kvec_d); @@ -132,8 +131,8 @@ TEST_F(PWTEST,test_other) delete[] kvec_d; - ModulePW::PW_Basis *p_pw = new ModulePW::PW_Basis(GlobalV::device_flag, GlobalV::precision_flag); - ModulePW::PW_Basis_K *p_pwk = new ModulePW::PW_Basis_K(GlobalV::device_flag, GlobalV::precision_flag); + ModulePW::PW_Basis *p_pw = new ModulePW::PW_Basis(device_flag, precision_flag); + ModulePW::PW_Basis_K *p_pwk = new ModulePW::PW_Basis_K(device_flag, precision_flag); delete p_pw; delete p_pwk; fftw_cleanup(); diff --git a/source/module_pw/test/test1-1-1.cpp b/source/module_pw/test/test1-1-1.cpp index b72de2cfab..42c23114a5 100644 --- a/source/module_pw/test/test1-1-1.cpp +++ b/source/module_pw/test/test1-1-1.cpp @@ -22,7 +22,7 @@ TEST_F(PWTEST,test1_1_1) bool xprime = false; //-------------------------------------------------- - ModulePW::PW_Basis pwtest(GlobalV::device_flag, GlobalV::precision_flag); + ModulePW::PW_Basis pwtest(device_flag, precision_flag); #ifdef __MPI pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD); #endif diff --git a/source/module_pw/test/test1-1-2.cpp b/source/module_pw/test/test1-1-2.cpp index 439c88487c..e5bafc2b2f 100644 --- a/source/module_pw/test/test1-1-2.cpp +++ b/source/module_pw/test/test1-1-2.cpp @@ -22,7 +22,7 @@ TEST_F(PWTEST,test1_1_2) bool xprime = false; //-------------------------------------------------- - ModulePW::PW_Basis pwtest(GlobalV::device_flag, GlobalV::precision_flag); + ModulePW::PW_Basis pwtest(device_flag, precision_flag); #ifdef __MPI pwtest.initmpi(nproc_in_pool, rank_in_pool, POOL_WORLD); #endif diff --git a/source/module_pw/test/test1-2-2.cpp b/source/module_pw/test/test1-2-2.cpp index 120e8a3a00..c7784bf86d 100644 --- a/source/module_pw/test/test1-2-2.cpp +++ b/source/module_pw/test/test1-2-2.cpp @@ -15,7 +15,7 @@ using namespace std; TEST_F(PWTEST,test1_2_2) { cout<<"dividemthd 1, gamma_only: off, check fft between double and complex"< /dev/null +make -j12 CC=g++ DEBUG=ON > /dev/null 2>&1 echo "Test for Serial Version:" ./pw_test.exe elif ((i==1)) ;then -make -j12 CC=g++ FLOAT=ON DEBUG=ON > /dev/null +make -j12 CC=g++ FLOAT=ON DEBUG=ON > /dev/null 2>&1 echo "Test for Serial Version with single precision:" ./pw_test.exe elif ((i==2)) ;then -make -j12 CC=mpicxx DEBUG=ON > /dev/null +make -j12 CC=mpicxx DEBUG=ON > /dev/null 2>&1 echo "Test for MPI Version:" elif ((i==3)) ;then -make -j12 CC=mpicxx FLOAT=ON DEBUG=ON > /dev/null +make -j12 CC=mpicxx FLOAT=ON DEBUG=ON > /dev/null 2>&1 echo "Test for MPI Version with single precision:" fi if ((i>=2)) ; then diff --git a/source/module_pw/test/test_tool.cpp b/source/module_pw/test/test_tool.cpp index 699a135337..bb7ecedb98 100644 --- a/source/module_pw/test/test_tool.cpp +++ b/source/module_pw/test/test_tool.cpp @@ -1,5 +1,8 @@ #ifdef __MPI +#ifndef __NOMPICOMPLEX #include "../../module_base/parallel_global.h" +#endif +#include "pw_test.h" #include "mpi.h" #include void setupmpi(int argc,char **argv,int &nproc, int &myrank) @@ -10,6 +13,8 @@ void setupmpi(int argc,char **argv,int &nproc, int &myrank) std::cout<<"MPI_Init_thread request "< Date: Wed, 1 Mar 2023 00:30:33 +0800 Subject: [PATCH 7/7] fix bug caused by the last commit (#20) * fix: bug when stru_file=../STRU * Test: add UT for math_chebyshev * make module_pw independent of global_variable.cpp and global_parallel.cpp * fix bug --- source/module_pw/test/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/module_pw/test/CMakeLists.txt b/source/module_pw/test/CMakeLists.txt index 7cb5f88069..6794369d22 100644 --- a/source/module_pw/test/CMakeLists.txt +++ b/source/module_pw/test/CMakeLists.txt @@ -1,4 +1,4 @@ -add_definitions(-D__NORMAL __NOMPICOMPLEX) +add_definitions(-D__NORMAL -D__NOMPICOMPLEX) AddTest( TARGET pw_test LIBS ${math_libs} planewave device