Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ else()
endif()
if (ENABLE_FLOAT_FFTW)
list(APPEND math_libs FFTW3::FFTW3_FLOAT)
add_definitions(-D__ENABLE_FLOAT_FFTW)
endif()
if(CMAKE_CXX_COMPILER_ID MATCHES GNU)
list(APPEND math_libs -lgfortran)
Expand All @@ -302,6 +301,10 @@ else()
endif()
endif()

# Expose the single-precision FFTW switch to the C++ sources as the
# __ENABLE_FLOAT_FFTW preprocessor macro whenever ENABLE_FLOAT_FFTW is on.
# NOTE(review): this sits after the preceding endif() lines, so it appears to
# apply regardless of which math-library branch was taken above - confirm.
if (ENABLE_FLOAT_FFTW)
add_definitions(-D__ENABLE_FLOAT_FFTW)
endif()

if(ENABLE_DEEPKS)
set(CMAKE_CXX_STANDARD 14)
find_package(Torch REQUIRED)
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace ModuleESolver
{
// pw_rho = new ModuleBase::PW_Basis();

pw_rho = new ModulePW::PW_Basis_Big();
pw_rho = new ModulePW::PW_Basis_Big(GlobalV::device_flag, GlobalV::precision_flag);
GlobalC::rhopw = this->pw_rho; //Temporary
//temporary, it will be removed
GlobalC::bigpw = static_cast<ModulePW::PW_Basis_Big*>(pw_rho);
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace ModuleESolver

// pw_rho = new ModuleBase::PW_Basis();
//temporary, it will be removed
pw_wfc = new ModulePW::PW_Basis_K_Big();
pw_wfc = new ModulePW::PW_Basis_K_Big(GlobalV::device_flag, GlobalV::precision_flag);
GlobalC::wfcpw = this->pw_wfc; //Temporary
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);
tmp->setbxyz(INPUT.bx,INPUT.by,INPUT.bz);
Expand Down
33 changes: 20 additions & 13 deletions source/module_pw/fft.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "fft.h"

#include "module_base/global_variable.h"
#include "module_base/memory.h"
#include "module_base/tool_quit.h"

Expand Down Expand Up @@ -37,8 +36,8 @@ void FFT::clear()
if(z_auxr!=nullptr) {fftw_free(z_auxr); z_auxr = nullptr;}
d_rspace = nullptr;
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
if (c_auxr_3d != nullptr) {
delmem_cd_op()(gpu_ctx, c_auxr_3d);
c_auxr_3d = nullptr;
Expand All @@ -53,7 +52,7 @@ void FFT::clear()
}
#endif // defined(__CUDA) || defined(__ROCM)
#if defined(__ENABLE_FLOAT_FFTW)
if (GlobalV::precision_flag == "single") {
if (this->precision == "single") {
this->cleanfFFT();
if (c_auxg != nullptr) {
fftw_free(c_auxg);
Expand Down Expand Up @@ -101,8 +100,8 @@ void FFT:: initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, in
// auxr_3d = static_cast<std::complex<double> *>(
// fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz)));
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
resmem_cd_op()(gpu_ctx, this->c_auxr_3d, this->nx * this->ny * this->nz);
}
else {
Expand All @@ -111,7 +110,7 @@ void FFT:: initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, in
}
#endif // defined(__CUDA) || defined(__ROCM)
#if defined(__ENABLE_FLOAT_FFTW)
if (GlobalV::precision_flag == "single") {
if (this->precision == "single") {
c_auxg = (std::complex<float> *) fftw_malloc(sizeof(fftwf_complex) * maxgrids);
c_auxr = (std::complex<float> *) fftw_malloc(sizeof(fftwf_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
Expand All @@ -132,7 +131,7 @@ void FFT:: setupFFT()
{
this->initplan();
#if defined(__ENABLE_FLOAT_FFTW)
if (GlobalV::precision_flag == "single") {
if (this->precision == "single") {
this->initplanf();
}
#endif // defined(__ENABLE_FLOAT_FFTW)
Expand All @@ -141,7 +140,7 @@ void FFT:: setupFFT()
else
{
// this->initplan_mpi();
// if (GlobalV::precision_flag == "single") {
// if (this->precision == "single") {
// this->initplanf_mpi();
// }
}
Expand Down Expand Up @@ -239,8 +238,8 @@ void FFT :: initplan()
// FFTW_BACKWARD, FFTW_MEASURE);

#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
#if defined(__CUDA)
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
#elif defined(__ROCM)
Expand Down Expand Up @@ -384,8 +383,8 @@ void FFT:: cleanFFT()
// fftw_destroy_plan(this->plan3dforward);
// fftw_destroy_plan(this->plan3dbackward);
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
#if defined(__CUDA)
cufftDestroy(c_handle);
#elif defined(__ROCM)
Expand Down Expand Up @@ -818,4 +817,12 @@ std::complex<double> * FFT::get_auxr_3d_data() {
}
#endif

void FFT::set_device(std::string device_) {
this->device = std::move(device_);
}

void FFT::set_precision(std::string precision_) {
this->precision = std::move(precision_);
}

} // namespace ModulePW
8 changes: 8 additions & 0 deletions source/module_pw/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ class FFT

float * s_rspace=nullptr; //real number space for r, [nplane * nx *ny]
double * d_rspace=nullptr; //real number space for r, [nplane * nx *ny]

std::string device = "cpu";
std::string precision = "double";

public:
void set_device(std::string device_);
void set_precision(std::string precision_);

};
}

Expand Down
17 changes: 16 additions & 1 deletion source/module_pw/pw_basis.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "pw_basis.h"

#include <utility>
#include "../module_base/mymath.h"
#include "../module_base/timer.h"
#include "../module_base/global_function.h"
Expand All @@ -11,6 +13,12 @@ PW_Basis::PW_Basis()
classname="PW_Basis";
}

/// Construct a plane-wave basis bound to a device and a precision.
/// Both settings are stored on the basis and then forwarded to the FFT member
/// so the two objects start out in agreement.
/// @param device_    device name (moved into the member).
/// @param precision_ precision name (moved into the member).
PW_Basis::PW_Basis(std::string device_, std::string precision_)
    : device(std::move(device_)),
      precision(std::move(precision_))
{
    classname = "PW_Basis";
    // Hand the freshly stored settings to the FFT helper.
    ft.set_device(device);
    ft.set_precision(precision);
}

PW_Basis:: ~PW_Basis()
{
delete[] ig2isz;
Expand All @@ -31,7 +39,7 @@ PW_Basis:: ~PW_Basis()
delete[] ig2igg;
delete[] gg_uniq;
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (this->device == "gpu") {
delmem_int_op()(gpu_ctx, this->d_is2fftixy);
}
#endif
Expand Down Expand Up @@ -215,5 +223,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is)
}
}

/// Change the device this basis targets.
/// The value is stored on the basis and forwarded to the FFT member, matching
/// what the (device, precision) constructor does; without the forward the FFT
/// object would keep its previous device and disagree with the basis.
/// NOTE(review): presumably this should be called before any device buffers
/// are allocated - confirm against callers.
/// @param device_ device name; taken by value and moved into the member.
void PW_Basis::set_device(std::string device_) {
    this->device = std::move(device_);
    // Keep the FFT helper consistent with the basis.
    this->ft.set_device(this->device);
}

/// Change the floating-point precision this basis uses.
/// The value is stored on the basis and forwarded to the FFT member, matching
/// what the (device, precision) constructor does; without the forward the FFT
/// object would keep its previous precision and disagree with the basis.
/// NOTE(review): presumably this should be called before FFT buffers/plans
/// are set up - confirm against callers.
/// @param precision_ precision name; taken by value and moved into the member.
void PW_Basis::set_precision(std::string precision_) {
    this->precision = std::move(precision_);
    // Keep the FFT helper consistent with the basis.
    this->ft.set_precision(this->precision);
}

}
10 changes: 9 additions & 1 deletion source/module_pw/pw_basis.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class PW_Basis
public:
std::string classname;
PW_Basis();
PW_Basis(std::string device_, std::string precision_);
virtual ~PW_Basis();
//Init mpi parameters
#ifdef __MPI
Expand Down Expand Up @@ -260,9 +261,16 @@ class PW_Basis
using resmem_int_op = psi::memory::resize_memory_op<int, psi::DEVICE_GPU>;
using delmem_int_op = psi::memory::delete_memory_op<int, psi::DEVICE_GPU>;
using syncmem_int_h2d_op = psi::memory::synchronize_memory_op<int, psi::DEVICE_GPU, psi::DEVICE_CPU>;

void set_device(std::string device_);
void set_precision(std::string precision_);

protected:
std::string device = "cpu";
std::string precision = "double";
};

}
#endif //PlaneWave
#endif // PWBASIS_H

#include "./pw_basis_big.h" //temporary it will be removed
2 changes: 2 additions & 0 deletions source/module_pw/pw_basis_big.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class PW_Basis_Big: public PW_Basis
by = 1;
bz = 1;
}
PW_Basis_Big(std::string device_, std::string precision_) : PW_Basis(device_, precision_) {}

~PW_Basis_Big(){};
void setbxyz(const int bx_in, const int by_in, const int bz_in)
{
Expand Down
24 changes: 13 additions & 11 deletions source/module_pw/pw_basis_k.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "pw_basis_k.h"

#include <utility>
#include "../module_base/constants.h"
#include "../module_base/timer.h"
#include "module_base/memory.h"
Expand All @@ -20,8 +22,8 @@ PW_Basis_K::~PW_Basis_K()
delete[] gk2;
delete[] ig2ixyz_k_;
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
delmem_sd_op()(gpu_ctx, this->s_kvec_c);
delmem_sd_op()(gpu_ctx, this->s_gcar);
delmem_sd_op()(gpu_ctx, this->s_gk2);
Expand All @@ -36,7 +38,7 @@ PW_Basis_K::~PW_Basis_K()
}
else {
#endif
if (GlobalV::precision_flag == "single") {
if (this->precision == "single") {
delmem_sh_op()(cpu_ctx, this->s_kvec_c);
delmem_sh_op()(cpu_ctx, this->s_gcar);
delmem_sh_op()(cpu_ctx, this->s_gk2);
Expand Down Expand Up @@ -91,8 +93,8 @@ void PW_Basis_K:: initparameters(
this->fftnxyz = this->fftnxy * this->fftnz;
this->distribution_type = distribution_type_in;
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
resmem_sd_op()(gpu_ctx, this->s_kvec_c, this->nks * 3);
castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
}
Expand All @@ -103,7 +105,7 @@ void PW_Basis_K:: initparameters(
}
else {
#endif
if (GlobalV::precision_flag == "single") {
if (this->precision == "single") {
resmem_sh_op()(cpu_ctx, this->s_kvec_c, this->nks * 3);
castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_kvec_c, reinterpret_cast<double *>(&this->kvec_c[0][0]), this->nks * 3);
}
Expand Down Expand Up @@ -160,7 +162,7 @@ void PW_Basis_K::setupIndGk()
}
}
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (this->device == "gpu") {
resmem_int_op()(gpu_ctx, this->d_igl2isz_k, this->npwk_max * this->nks);
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks);
}
Expand Down Expand Up @@ -221,8 +223,8 @@ void PW_Basis_K::collect_local_pw()
}
}
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (GlobalV::precision_flag == "single") {
if (this->device == "gpu") {
if (this->precision == "single") {
resmem_sd_op()(gpu_ctx, this->s_gk2, this->npwk_max * this->nks);
resmem_sd_op()(gpu_ctx, this->s_gcar, this->npwk_max * this->nks * 3);
castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, this->s_gk2, this->gk2, this->npwk_max * this->nks);
Expand All @@ -237,7 +239,7 @@ void PW_Basis_K::collect_local_pw()
}
else {
#endif
if (GlobalV::precision_flag == "single") {
if (this->precision == "single") {
resmem_sh_op()(cpu_ctx, this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2");
resmem_sh_op()(cpu_ctx, this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar");
castmem_d2s_h2h_op()(cpu_ctx, cpu_ctx, this->s_gk2, this->gk2, this->npwk_max * this->nks);
Expand Down Expand Up @@ -325,7 +327,7 @@ void PW_Basis_K::get_ig2ixyz_k()
}
}
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (this->device == "gpu") {
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, this->ig2ixyz_k_, this->npwk_max * this->nks);
}
Expand Down
2 changes: 2 additions & 0 deletions source/module_pw/pw_basis_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class PW_Basis_K : public PW_Basis

public:
PW_Basis_K();
// Construct with an explicit device and precision; both are simply handed on
// to the PW_Basis base-class constructor.
PW_Basis_K(std::string device_, std::string precision_)
    : PW_Basis(device_, precision_)
{
}
~PW_Basis_K();

//init parameters of pw_basis_k class
Expand Down Expand Up @@ -132,6 +133,7 @@ class PW_Basis_K : public PW_Basis
private:
float * s_gcar = nullptr, * s_kvec_c = nullptr;
double * d_gcar = nullptr, * d_kvec_c = nullptr;

};

}
Expand Down
1 change: 1 addition & 0 deletions source/module_pw/pw_basis_k_big.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class PW_Basis_K_Big: public PW_Basis_K
by = 1;
bz = 1;
}
PW_Basis_K_Big(std::string device_, std::string precision_) : PW_Basis_K(device_, precision_) {}
~PW_Basis_K_Big(){};
void setbxyz(const int bx_in, const int by_in, const int bz_in)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_pw/pw_distributeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void PW_Basis::get_ig2isz_is2fftixy(
delete[] this->ig2isz; this->ig2isz = nullptr; // map ig to the z coordinate of this planewave.
delete[] this->is2fftixy; this->is2fftixy = nullptr; // map is (index of sticks) to ixy (iy + ix * fftny).
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (this->device == "gpu") {
delmem_int_op()(gpu_ctx, this->d_is2fftixy);
d_is2fftixy = nullptr;
}
Expand Down Expand Up @@ -186,7 +186,7 @@ void PW_Basis::get_ig2isz_is2fftixy(
if (st_move == this->nst && pw_filled == this->npw) break;
}
#if defined(__CUDA) || defined(__ROCM)
if (GlobalV::device_flag == "gpu") {
if (this->device == "gpu") {
resmem_int_op()(gpu_ctx, d_is2fftixy, this->nst);
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_is2fftixy, this->is2fftixy, this->nst);
}
Expand Down
9 changes: 4 additions & 5 deletions source/module_pw/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
add_definitions(-D__NORMAL -D__MIX_PRECISION)
add_definitions(-D__NORMAL -D__NOMPICOMPLEX)
AddTest(
TARGET pw_test
LIBS ${math_libs} planewave psi device
LIBS ${math_libs} planewave device
SOURCES ../../module_base/matrix.cpp ../../module_base/complexmatrix.cpp ../../module_base/matrix3.cpp ../../module_base/tool_quit.cpp
../../module_base/mymath.cpp ../../module_base/timer.cpp ../../module_base/memory.cpp ../../module_base/global_variable.cpp
../../module_base/parallel_common.cpp
../../module_base/parallel_global.cpp ../../module_base/parallel_reduce.cpp
../../module_base/mymath.cpp ../../module_base/timer.cpp ../../module_base/memory.cpp
../../module_psi/kernels/memory_op.cpp
pw_test.cpp test1-1-1.cpp test1-1-2.cpp test1-2.cpp test1-3.cpp test1-4.cpp test1-5.cpp
test2-1-1.cpp test2-1-2.cpp test2-2.cpp test2-3.cpp
test3-1.cpp test3-2.cpp test3-3.cpp test3-3-2.cpp
Expand Down
6 changes: 2 additions & 4 deletions source/module_pw/test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ FFTW_DIR = /home/qianrui/gnucompile/fftw_3.3.8
#==========================
# Compiler information
#==========================
HONG = -D__NORMAL
HONG = -D__NORMAL -D__NOMPICOMPLEX
INCLUDES = -I. -I../../
LIBS =
OPTS = -Ofast -march=native -std=c++11 -m64 ${INCLUDES}
Expand Down Expand Up @@ -130,9 +130,7 @@ pw_transform_k.o\
memory.o\
memory_op.o

OTHER_OBJS0=global_variable.o\
parallel_global.o\
parallel_reduce.o
OTHER_OBJS0=


TESTFILE0=test1-1-1.o\
Expand Down
Loading