diff --git a/source/source_pw/module_pwdft/operator_pw/exx_pw_ace.cpp b/source/source_pw/module_pwdft/operator_pw/exx_pw_ace.cpp index 80c4a0dc6a..c1fcb65de7 100644 --- a/source/source_pw/module_pwdft/operator_pw/exx_pw_ace.cpp +++ b/source/source_pw/module_pwdft/operator_pw/exx_pw_ace.cpp @@ -1,4 +1,6 @@ #include "op_exx_pw.h" +#include "source_base/parallel_comm.h" +#include "source_io/module_parameter/parameter.h" namespace hamilt { @@ -60,10 +62,6 @@ void OperatorEXXPW::act_op_ace(const int nbands, nbasis ); - - // // negative sign, add to hpsi - // vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1); - // delmem_complex_op()(hpsi); delmem_complex_op()(Xi_psi); ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace"); @@ -72,13 +70,12 @@ void OperatorEXXPW::act_op_ace(const int nbands, template void OperatorEXXPW::construct_ace() const { - // int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk(); int nbands = psi.get_nbands(); int nbasis = psi.get_nbasis(); int nk = psi.get_nk(); + int* ik_ = const_cast(&this->ik); int ik_save = this->ik; - int * ik_ = const_cast(&this->ik); T intermediate_one = 1.0, intermediate_zero = 0.0; @@ -116,93 +113,167 @@ void OperatorEXXPW::construct_ace() const if (first_iter) return; ModuleBase::timer::tick("OperatorEXXPW", "construct_ace"); - for (int ik = 0; ik < nk; ik++) + int nk_max = kv->para_k.get_max_nks_pool(); + int nspin_fac = PARAM.inp.nspin == 2 ? 2 : 1; + for (int ispin = 0; ispin < nspin_fac; ispin++) { - int npwk = wfcpw->npwk[ik]; - - T* Xi_ace = Xi_ace_k[ik]; - psi.fix_kb(ik, 0); - T* p_psi = psi.get_pointer(); - - setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); + for (int ik0 = 0; ik0 < nk_max; ik0++) + { + int ik = ik0 + ispin * nk_max; + int npwk = wfcpw->npwk[ik]; + + T* Xi_ace = Xi_ace_k[ik]; + psi.fix_kb(ik, 0); + T* p_psi = psi.get_pointer(); + + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); + + setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); + setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + int nqs = kv->get_nkstot_full(); + + bool skip_ik = false; + if (ik >= wfcpw->nks) + { + skip_ik = true; + } + if (skip_ik) + { + // ik fixed here, select band n + for (int iq0 = 0; iq0 < nqs; iq0++) + { + int iq = iq0 + ik; + // for \psi_nk, get the pw of iq and band m + get_exx_potential(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega, ik, iq); + + // decide which pool does the iq belong to + int iq_pool = kv->para_k.whichpool[iq0]; + int iq_loc = iq - kv->para_k.startk_pool[iq_pool]; + + for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++) + { + double wg_mqb = 0; + bool skip = false; + if (iq_pool == GlobalV::MY_POOL) + { + wg_mqb = (*wg)(iq_loc, m_iband); + } +#ifdef __MPI + MPI_Bcast(&wg_mqb, 1, MPI_DOUBLE, kv->para_k.get_startpro_pool(iq_pool), MPI_COMM_WORLD); +#endif + if (wg_mqb < 1e-12) + continue; + + if (iq_pool == GlobalV::MY_POOL) + { + const T* psi_mq = get_pw(m_iband, iq_loc); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc); + // send + } + // if (iq == 0) + // std::cout << "Bcast psi_mq_real" << std::endl; +#ifdef __MPI + MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); +#endif - *ik_ = ik; - - act_op( - nbands, - nbasis, - 1, - p_psi, - h_psi_ace, - nbasis, - false - ); - - // psi_h_psi_ace = psi^\dagger * h_psi_ace - // p_exx_helper->psi.fix_kb(0, 0); - gemm_complex_op()('C', - 'N', - nbands, - nbands, - npwk, - &intermediate_one, - p_psi, - nbasis, - h_psi_ace, - nbasis, - &intermediate_zero, - psi_h_psi_ace, - nbands); - - // reduction of psi_h_psi_ace, due to distributed memory - Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands); - - T intermediate_minus_one = -1.0; - axpy_complex_op()(nbands * nbands, - &intermediate_minus_one, - psi_h_psi_ace, - 1, - L_ace, - 1); - - - int info = 0; - char up = 'U', lo = 'L'; - - lapack_potrf()(lo, nbands, L_ace, nbands); - - // expand for-loop - for (int i = 0; i < nbands; ++i) { - setmem_complex_op()(L_ace + i * nbands, 0, i); + } // end of iq + + } + } + else + { + *ik_ = ik; + act_op_kpar(nbands, nbasis, 1, p_psi, h_psi_ace, nbasis, false); + // psi_h_psi_ace = psi^\dagger * h_psi_ace + // p_exx_helper->psi.fix_kb(0, 0); + gemm_complex_op()('C', + 'N', + nbands, + nbands, + npwk, + &intermediate_one, + p_psi, + nbasis, + h_psi_ace, + nbasis, + &intermediate_zero, + psi_h_psi_ace, + nbands); + + // reduction of psi_h_psi_ace, due to distributed memory + Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands); + + T intermediate_minus_one = -1.0; + axpy_complex_op()(nbands * nbands, + &intermediate_minus_one, + psi_h_psi_ace, + 1, + L_ace, + 1); + + + int info = 0; + char up = 'U', lo = 'L'; + + // for (int i = 0; i < nbands; ++i) + // { + // for (int j = 0; j < nbands; ++j) + // { + // // std::cout << L_ace[i * nbands + j]. << " "; + // if (L_ace[i * nbands + j].imag() >= 0.0) + // { + // std::cout << L_ace[i * nbands + j].real() << "+" << L_ace[i * nbands + j].imag() << "im "; + // } + // else + // { + // std::cout << L_ace[i * nbands + j].real() << L_ace[i * nbands + j].imag() << "im "; + // } + // } + // std::cout << ";" << std::endl; + // } + // MPI_Barrier(MPI_COMM_WORLD); + // MPI_Abort(MPI_COMM_WORLD, 0); + + lapack_potrf()(lo, nbands, L_ace, nbands); + + // expand for-loop + for (int i = 0; i < nbands; ++i) { + setmem_complex_op()(L_ace + i * nbands, 0, i); + } + + // L_ace inv in place + char non = 'N'; + lapack_trtri()(lo, non, nbands, L_ace, nbands); + + // Xi_ace = L_ace^-1 * h_psi_ace^dagger + gemm_complex_op()('N', + 'C', + nbands, + npwk, + nbands, + &intermediate_one, + L_ace, + nbands, + h_psi_ace, + nbasis, + &intermediate_zero, + Xi_ace, + nbands); + + // clear mem + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); + setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands); + setmem_complex_op()(L_ace, 0, nbands * nbands); + } } - - // L_ace inv in place - char non = 'N'; - lapack_trtri()(lo, non, nbands, L_ace, nbands); - - // Xi_ace = L_ace^-1 * h_psi_ace^dagger - gemm_complex_op()('N', - 'C', - nbands, - npwk, - nbands, - &intermediate_one, - L_ace, - nbands, - h_psi_ace, - nbasis, - &intermediate_zero, - Xi_ace, - nbands); - - // clear mem - setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); - setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands); - setmem_complex_op()(L_ace, 0, nbands * nbands); - } *ik_ = ik_save; + ModuleBase::timer::tick("OperatorEXXPW", "construct_ace"); } @@ -234,7 +305,7 @@ double OperatorEXXPW::cal_exx_energy_ace(psi::Psi* ppsi_) } } - Parallel_Reduce::reduce_pool(Eexx); + Parallel_Reduce::reduce_all(Eexx); *ik_ = ik_save; return Eexx; } diff --git a/source/source_pw/module_pwdft/operator_pw/exx_pw_pot.cpp b/source/source_pw/module_pwdft/operator_pw/exx_pw_pot.cpp index 8c2c6dd140..36700fedd3 100644 --- a/source/source_pw/module_pwdft/operator_pw/exx_pw_pot.cpp +++ b/source/source_pw/module_pwdft/operator_pw/exx_pw_pot.cpp @@ -30,6 +30,20 @@ void get_exx_potential(const K_Vectors* kv, // fill zero setmem_real_cpu_op()(pot_cpu, 0, npw); + std::vector> qvec_c, qvec_d; +#ifdef __MPI + kv->para_k.gatherkvec(kv->kvec_c, qvec_c); + kv->para_k.gatherkvec(kv->kvec_d, qvec_d); +#else + qvec_c = kv->kvec_c; + qvec_d = kv->kvec_d; +#endif + + if (ik > nks) + { + return; + } + // calculate Fock pot auto param_fock = GlobalC::exx_info.info_global.coulomb_param[Conv_Coulomb_Pot_K::Coulomb_Type::Fock]; for (int i = 0; i < param_fock.size(); i++) @@ -39,8 +53,8 @@ void get_exx_potential(const K_Vectors* kv, double alpha = std::stod(param["alpha"]); const ModuleBase::Vector3 k_c = wfcpw->kvec_c[ik]; const ModuleBase::Vector3 k_d = wfcpw->kvec_d[ik]; - const ModuleBase::Vector3 q_c = wfcpw->kvec_c[iq]; - const ModuleBase::Vector3 q_d = wfcpw->kvec_d[iq]; + const ModuleBase::Vector3 q_c = qvec_c[iq]; + const ModuleBase::Vector3 q_d = qvec_d[iq]; #ifdef _OPENMP #pragma omp parallel for schedule(static) @@ -109,8 +123,8 @@ void get_exx_potential(const K_Vectors* kv, ucell_omega); const ModuleBase::Vector3 k_c = wfcpw->kvec_c[ik]; const ModuleBase::Vector3 k_d = wfcpw->kvec_d[ik]; - const ModuleBase::Vector3 q_c = wfcpw->kvec_c[iq]; - const ModuleBase::Vector3 q_d = wfcpw->kvec_d[iq]; + const ModuleBase::Vector3 q_c = qvec_c[iq]; + const ModuleBase::Vector3 q_d = qvec_d[iq]; #ifdef _OPENMP #pragma omp parallel for schedule(static) @@ -146,6 +160,10 @@ void get_exx_potential(const K_Vectors* kv, // const int ig_kq = ik * nks * npw + iq * npw + ig; Real gg = (k_c - q_c + rhopw_dev->gcar[ig]).norm2() * tpiba2; + // if (ig == 0 && GlobalV::MY_RANK==1) + // { + // printf("k-q+G: %f %f %f\n", (k_c - q_c + rhopw_dev->gcar[ig])[0], (k_c - q_c + rhopw_dev->gcar[ig])[1], (k_c - q_c + rhopw_dev->gcar[ig])[2]); + // } // if (kqgcar2 > 1e-12) // vasp uses 1/40 of the smallest (k spacing)**2 if (gg >= 1e-8) { @@ -388,7 +406,7 @@ double exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type coulomb_type, // this is the \sum_q F(q) part // temporarily for all k points, should be replaced to q points later - for (int ik = 0; ik < wfcpw->nks; ik++) + for (int ik = 0; ik < wfcpw->nks / nk_fac; ik++) { const ModuleBase::Vector3 k_c = wfcpw->kvec_c[ik]; const ModuleBase::Vector3 k_d = wfcpw->kvec_d[ik]; @@ -437,7 +455,7 @@ double exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type coulomb_type, } } - Parallel_Reduce::reduce_pool(div); + Parallel_Reduce::reduce_all(div); // std::cout << "EXX div: " << div << std::endl; // if (PARAM.inp.dft_functional == "hse") @@ -454,8 +472,8 @@ double exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type coulomb_type, } } - div *= ModuleBase::e2 * ModuleBase::FOUR_PI / tpiba2 / wfcpw->nks; - // std::cout << "div: " << div << std::endl; + div *= ModuleBase::e2 * ModuleBase::FOUR_PI / tpiba2 / kv->get_nkstot_full(); + // std::cout << "div: " << div << std::endl; // numerically value the mean value of F(q) in the reciprocal space // This means we need to calculate the average of F(q) in the first brillouin zone @@ -481,9 +499,9 @@ double exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type coulomb_type, aa += 1.0 / std::sqrt(alpha * ModuleBase::PI); div -= ModuleBase::e2 * ucell_omega * aa; - exx_div = div * wfcpw->nks / nk_fac; + exx_div = div * kv->get_nkstot_full(); // exx_div = 0; - // std::cout << "EXX divergence: " << exx_div << std::endl; + // std::cout << "EXX divergence: " << exx_div << std::endl; return exx_div; } diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp index 404c660877..9a7737f821 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp @@ -37,6 +37,11 @@ OperatorEXXPW::OperatorEXXPW(const int* isk_in, const UnitCell *ucell) : isk(isk_in), wfcpw(wfcpw_in), rhopw(rhopw_in), kv(kv_in), ucell(ucell) { + if (GlobalV::KPAR != 1 && PARAM.inp.exxace == false) + { + // GlobalV::ofs_running << "EXX Calculation does not support k-point parallelism" << std::endl; + ModuleBase::WARNING_QUIT("OperatorEXXPW", "EXX Calculation does not support k-point parallelism when exxace is set to false"); + } gamma_extrapolation = PARAM.inp.exx_gamma_extrapolation; bool is_mp = kv_in->get_is_mp(); #ifdef __MPI @@ -46,11 +51,6 @@ OperatorEXXPW::OperatorEXXPW(const int* isk_in, { gamma_extrapolation = false; } - if (GlobalV::KPAR != 1) - { - // GlobalV::ofs_running << "EXX Calculation does not support k-point parallelism" << std::endl; - ModuleBase::WARNING_QUIT("OperatorEXXPW", "EXX Calculation does not support k-point parallelism"); - } this->classname = "OperatorEXXPW"; this->ctx = nullptr; @@ -192,27 +192,21 @@ void OperatorEXXPW::act_op(const int nbands, const int ngk_ik, const bool is_first_node) const { -// std::cout << "nbands: " << nbands -// << " nbasis: " << nbasis -// << " npol: " << npol -// << " ngk_ik: " << ngk_ik -// << " is_first_node: " << is_first_node -// << std::endl; - // get_exx_potential(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega, ik, iq); - -// set_psi(&p_exx_helper->psi); - ModuleBase::timer::tick("OperatorEXXPW", "act_op"); setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); setmem_complex_op()(density_recip, 0, rhopw_dev->npw); - // setmem_complex_op()(psi_all_real, 0, wfcpw->nrxx * GlobalV::NBANDS); - // std::map, bool> has_real; setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + auto q_points = get_q_points(this->ik); + // std::cout << "kpoint " << this->ik << ", qpoints: "; + // for (auto iq: q_points) + // std::cout << iq << ", "; + // std::cout << std::endl; + // ik fixed here, select band n for (int n_iband = 0; n_iband < nbands; n_iband++) { @@ -221,12 +215,11 @@ void OperatorEXXPW::act_op(const int nbands, wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik); // for \psi_nk, get the pw of iq and band m - auto q_points = get_q_points(this->ik); + Real nqs = q_points.size(); for (int iq: q_points) { get_exx_potential(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega, this->ik, iq); -// std::cout << "ik" << this->ik << " iq" << iq << std::endl; for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++) { // double wg_mqb_real = GlobalC::exx_helper.wg(iq, m_iband); @@ -243,14 +236,10 @@ void OperatorEXXPW::act_op(const int nbands, // direct multiplication in real space, \psi_nk(r) * \psi_mq(r) cal_density_recip(psi_nk_real, psi_mq_real, ucell->omega); - // bring the density to recip space - // rhopw->real2recip(density_real, density_recip); - // multiply the density with the potential in recip space multiply_potential(density_recip, this->ik, iq); // bring the potential back to real space - // rhopw_dev->recip2real(density_recip, density_real); rho_recip2real(density_recip, density_real); if (false) @@ -290,6 +279,109 @@ void OperatorEXXPW::act_op(const int nbands, } +template +void OperatorEXXPW::act_op_kpar(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik, + const bool is_first_node) const +{ + ModuleBase::timer::tick("OperatorEXXPW", "act_op_kpar"); + + setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); + // setmem_complex_op()(psi_all_real, 0, wfcpw->nrxx * GlobalV::NBANDS); + // std::map, bool> has_real; + setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + int nqs = kv->get_nkstot_full(); + + // ik fixed here, select band n + for (int iq = 0; iq < nqs; iq++) + { + // for \psi_nk, get the pw of iq and band m + get_exx_potential(kv, wfcpw, rhopw_dev, pot, tpiba, gamma_extrapolation, ucell->omega, this->ik, iq); + + // decide which pool does the iq belong to + int iq_pool = kv->para_k.whichpool[iq]; + int iq_loc = iq - kv->para_k.startk_pool[iq_pool]; + + for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++) + { + double wg_mqb = 0; + if (iq_pool == GlobalV::MY_POOL) + { + wg_mqb = (*wg)(iq_loc, m_iband); + } +#ifdef __MPI + MPI_Bcast(&wg_mqb, 1, MPI_DOUBLE, kv->para_k.get_startpro_pool(iq_pool), MPI_COMM_WORLD); +#endif + if (wg_mqb < 1e-12) + continue; + + if (iq_pool == GlobalV::MY_POOL) + { + const T* psi_mq = get_pw(m_iband, iq_loc); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc); + // send + } +#ifdef __MPI + MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); +#endif + for (int n_iband = 0; n_iband < nbands; n_iband++) + { + double wg_nkb = (*wg)(this->ik, n_iband); + const T* psi_nk = tmpsi_in + n_iband * nbasis; + // retrieve \psi_nk in real space + wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik); + + + // direct multiplication in real space, \psi_nk(r) * \psi_mq(r) + cal_density_recip(psi_nk_real, psi_mq_real, ucell->omega); + + mul_potential_op()(pot, density_recip, rhopw_dev->npw, wfcpw->nks, this->ik, iq); + + // bring the potential back to real space + rho_recip2real(density_recip, density_real); + + if (false) + { + // do nothing + } + else + { + vec_mul_vec_complex_op()(density_real, psi_mq_real, density_real, wfcpw->nrxx); + } + + + Real wk_iq = kv->wk[iq]; + Real wk_ik = kv->wk[this->ik]; + // std::cout << "wk_iq: " << wk_iq << " wk_ik: " << wk_ik << std::endl; + + Real tmp_scalar = wg_mqb / wk_ik / nqs; + + T* h_psi_nk = tmhpsi + n_iband * nbasis; + Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha; + wfcpw->real_to_recip(ctx, density_real, h_psi_nk, this->ik, true, hybrid_alpha * tmp_scalar); + + + } // end of m_iband + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + + } // end of iq + + } + + ModuleBase::timer::tick("OperatorEXXPW", "act_op_kpar"); + +} + template std::vector OperatorEXXPW::get_q_points(const int ik) const { @@ -440,9 +532,29 @@ double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) c // for \psi_nk, get the pw of iq and band m // q_points is a vector of integers, 0 to nks-1 std::vector q_points; - for (int iq = 0; iq < wfcpw->nks; iq++) + if (PARAM.inp.nspin == 1) + { + for (int iq = 0; iq < wfcpw->nks; iq++) + { + q_points.push_back(iq); + } + } + else if (PARAM.inp.nspin == 2) + { + int nk = wfcpw->nks / nk_fac; + int k_spin = ik / nk; + for (int iq = 0; iq < wfcpw->nks; iq++) + { + int q_spin = iq / nk; + if (k_spin == q_spin) + { + q_points.push_back(iq); + } + } + } + else { - q_points.push_back(iq); + ModuleBase::WARNING_QUIT("OperatorEXXPW", "nspin == 4 not supported"); } double nqs = q_points.size(); diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h index d823eda015..4e9d953afc 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h @@ -78,6 +78,14 @@ class OperatorEXXPW : public OperatorPW const int ngk_ik = 0, const bool is_first_node = false) const; + void act_op_kpar(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const; + void act_op_ace(const int nbands, const int nbasis, const int npol,