Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions source/source_base/timer_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef TIMER_WRAPPER_H
#define TIMER_WRAPPER_H

#include <chrono>

#ifdef __MPI
#include <mpi.h>
#endif

namespace ModuleBase {

/**
* @brief Time point type that works in both MPI and non-MPI environments
*/
typedef double TimePoint;

/**
* @brief Get current time as a TimePoint
*
* @return TimePoint Current time
*/
inline TimePoint get_time()
{
#ifdef __MPI
int is_initialized = 0;
MPI_Initialized(&is_initialized);
if (is_initialized)
{
return MPI_Wtime();
}
else
{
return std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count() / 1e6;
}
#else
return std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch()).count() / 1e6;
#endif
}

/**
* @brief Calculate duration between two TimePoints in seconds
*
* @param start Start time point
* @param end End time point
* @return double Duration in seconds
*/
inline double get_duration(const TimePoint& start, const TimePoint& end)
{
return end - start;
}

}

#endif // TIMER_WRAPPER_H
51 changes: 18 additions & 33 deletions source/source_esolver/esolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,23 +311,23 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
// of LR-TDDFT is implemented.
std::cout << " PREPARING FOR EXCITED STATES." << std::endl;
// initialize the 2nd ESolver_LR at the temporary pointer
ModuleESolver::ESolver* p_esolver_lr = nullptr;
if (PARAM.globalv.gamma_only_local)
{
p_esolver_lr = new LR::ESolver_LR<double, double>(
std::move(*dynamic_cast<ModuleESolver::ESolver_KS_LCAO<double, double>*>(p_esolver)),
inp,
ucell);
}
else
{
p_esolver_lr = new LR::ESolver_LR<std::complex<double>, double>(
std::move(*dynamic_cast<ModuleESolver::ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver)),
inp,
ucell);
}
// clean the 1st ESolver_KS and swap the pointer
ModuleESolver::clean_esolver(p_esolver, false); // do not call Cblacs_exit, remain it for the 2nd ESolver
ModuleESolver::ESolver* p_esolver_lr = nullptr;
if (PARAM.globalv.gamma_only_local)
{
p_esolver_lr = new LR::ESolver_LR<double, double>(
std::move(*dynamic_cast<ModuleESolver::ESolver_KS_LCAO<double, double>*>(p_esolver)),
inp,
ucell);
}
else
{
p_esolver_lr = new LR::ESolver_LR<std::complex<double>, double>(
std::move(*dynamic_cast<ModuleESolver::ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver)),
inp,
ucell);
}
// clean the 1st ESolver_KS and swap the pointer
delete p_esolver;
return p_esolver_lr;
}
#endif
Expand Down Expand Up @@ -355,20 +355,5 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
+ " line " + std::to_string(__LINE__));
}

void clean_esolver(ESolver*& pesolver, const bool lcao_cblacs_exit)
{
// Zhang Xiaoyang modified in 2024/7/6:
// Note: because of the init method of serial lcao hsolver
// it needs no release step for it, or this [delete] will cause Segmentation Fault
// Probably it will be modified later.
#ifdef __MPI
delete pesolver;
#ifdef __LCAO
if (lcao_cblacs_exit)
{
Cblacs_exit(1);
}
#endif
#endif
}

} // namespace ModuleESolver
2 changes: 1 addition & 1 deletion source/source_esolver/esolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ std::string determine_type();
*/
ESolver* init_esolver(const Input_para& inp, UnitCell& ucell);

void clean_esolver(ESolver*& pesolver, const bool lcao_cblacs_exit = false);


} // namespace ModuleESolver

Expand Down
10 changes: 2 additions & 8 deletions source/source_esolver/esolver_fp.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

#include "esolver.h"

#ifndef __MPI
#include <chrono>
#endif
#include "source_base/timer_wrapper.h"

#include "source_basis/module_pw/pw_basis.h" // plane wave basis
#include "source_cell/module_symmetry/symmetry.h" // symmetry analysis
Expand Down Expand Up @@ -83,11 +81,7 @@ class ESolver_FP: public ESolver
bool pw_rho_flag = false; ///< flag for pw_rho, 0: not initialized, 1: initialized

//! the start time of scf iteration
#ifdef __MPI
double iter_time;
#else
std::chrono::system_clock::time_point iter_time;
#endif
ModuleBase::TimePoint iter_time;
};
} // namespace ModuleESolver

Expand Down
15 changes: 3 additions & 12 deletions source/source_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "esolver_ks.h"
#include "source_base/timer_wrapper.h"

// for jason output information
#include "source_io/json_output/init_info.h"
Expand Down Expand Up @@ -190,11 +191,7 @@ void ESolver_KS<T, Device>::iter_init(UnitCell& ucell, const int istep, const in
ModuleIO::write_head(GlobalV::ofs_running, istep, iter, this->basisname);
}

#ifdef __MPI
iter_time = MPI_Wtime();
#else
iter_time = std::chrono::system_clock::now();
#endif
iter_time = ModuleBase::get_time();

if (PARAM.inp.esolver_type == "ksdft")
{
Expand Down Expand Up @@ -281,13 +278,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i


// the end, print time
#ifdef __MPI
double duration = (double)(MPI_Wtime() - iter_time);
#else
double duration
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iter_time)).count()
/ static_cast<double>(1e6);
#endif
double duration = ModuleBase::get_duration(iter_time, ModuleBase::get_time());

// print energies
elecstate::print_etot(ucell.magnet, *pelec, conv_esolver, iter, drho,
Expand Down
28 changes: 18 additions & 10 deletions source/source_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,17 +293,25 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
ESolver_KS<TK>::after_all_runners(ucell);

auto* hamilt_lcao = dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt);
if(!hamilt_lcao)
{
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::after_all_runners","p_hamilt does not exist");
}
if(!hamilt_lcao)
{
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::after_all_runners","p_hamilt does not exist");
}

ModuleIO::ctrl_runner_lcao<TK, TR>(ucell,
PARAM.inp, this->kv, this->pelec, this->dmat, this->pv, this->Pgrid,
this->gd, this->psi, this->chr, hamilt_lcao,
this->two_center_bundle_,
this->orb_, this->pw_rho, this->pw_rhod,
this->sf, this->locpp.vloc, this->exx_nao, this->solvent);
ModuleIO::ctrl_runner_lcao<TK, TR>(ucell,
PARAM.inp, this->kv, this->pelec, this->dmat, this->pv, this->Pgrid,
this->gd, this->psi, this->chr, hamilt_lcao,
this->two_center_bundle_,
this->orb_, this->pw_rho, this->pw_rhod,
this->sf, this->locpp.vloc, this->exx_nao, this->solvent);


#ifdef __MPI
#ifdef __LCAO
// Exit BLACS environment for LCAO calculations
Cblacs_exit(1);
#endif
#endif

ModuleBase::timer::tick("ESolver_KS_LCAO", "after_all_runners");
}
Expand Down
14 changes: 5 additions & 9 deletions source/source_esolver/esolver_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ ESolver_OF::ESolver_OF()

ESolver_OF::~ESolver_OF()
{
//****************************************************
// do not add any codes in this deconstructor funcion
//****************************************************
delete psi_;
//****************************************************
// do not add any codes in this deconstructor funcion
//****************************************************
delete psi_;
delete[] this->pphi_;

for (int i = 0; i < PARAM.inp.nspin; ++i)
Expand Down Expand Up @@ -137,11 +137,7 @@ void ESolver_OF::runner(UnitCell& ucell, const int istep)
this->iter_ = 0;

bool conv_esolver = false; // this conv_esolver is added by mohan 20250302
#ifdef __MPI
this->iter_time = MPI_Wtime();
#else
this->iter_time = std::chrono::system_clock::now();
#endif
this->iter_time = ModuleBase::get_time();

while (true)
{
Expand Down
6 changes: 1 addition & 5 deletions source/source_esolver/esolver_of_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ void ESolver_OF_TDDFT::runner(UnitCell& ucell, const int istep)
this->iter_ = 0;

bool conv_esolver = false; // this conv_esolver is added by mohan 20250302
#ifdef __MPI
this->iter_time = MPI_Wtime();
#else
this->iter_time = std::chrono::system_clock::now();
#endif
this->iter_time = ModuleBase::get_time();

if (this->phi_td.empty())
{
Expand Down
8 changes: 2 additions & 6 deletions source/source_main/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,6 @@ void Driver::driver_run()
else if (cal == "get_pchg" || cal == "get_wf" || cal == "gen_bessel" || cal == "gen_opt_abfs" ||
cal == "test_memory" || cal == "test_neighbour")
{
//! supported "other" functions:
//! get_pchg(LCAO),
//! test_memory(PW,LCAO),
//! test_neighbour(LCAO),
//! gen_bessel(PW), et al.
const int istep = 0;
p_esolver->others(ucell, istep);
}
Expand All @@ -106,7 +101,8 @@ void Driver::driver_run()
//! 5: clean up esolver
p_esolver->after_all_runners(ucell);

ModuleESolver::clean_esolver(p_esolver);
delete p_esolver;

this->finalize_hardware();

//! 6: output the json file
Expand Down
32 changes: 9 additions & 23 deletions source/source_pw/module_ofdft/of_print_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@
* and write the components of the total energy into running_log.
*/
void OFDFT::print_info(const int iter,
#ifdef __MPI
double &iter_time,
#else
std::chrono::system_clock::time_point &iter_time,
#endif
const double &energy_current,
const double &energy_last,
const double &normdLdphi,
const elecstate::ElecState *pelec,
KEDF_Manager *kedf_manager,
const bool conv_esolver)
ModuleBase::TimePoint &iter_time,
const double &energy_current,
const double &energy_last,
const double &normdLdphi,
const elecstate::ElecState *pelec,
KEDF_Manager *kedf_manager,
const bool conv_esolver)
{
if (iter == 0)
{
Expand All @@ -35,13 +31,7 @@ void OFDFT::print_info(const int iter,
{"tn", "TN"}
};
std::string iteration = prefix_map[PARAM.inp.of_method] + std::to_string(iter);
#ifdef __MPI
double duration = (double)(MPI_Wtime() - iter_time);
#else
double duration
= (std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - iter_time)).count()
/ static_cast<double>(1e6);
#endif
double duration = ModuleBase::get_duration(iter_time, ModuleBase::get_time());
std::cout << " " << std::setw(8) << iteration
<< std::setw(18) << std::scientific << std::setprecision(8) << energy_current * ModuleBase::Ry_to_eV
<< std::setw(18) << (energy_current - energy_last) * ModuleBase::Ry_to_eV
Expand Down Expand Up @@ -141,9 +131,5 @@ void OFDFT::print_info(const int iter,
GlobalV::ofs_running << table.str() << std::endl;

// reset the iter_time for the next iteration
#ifdef __MPI
iter_time = MPI_Wtime();
#else
iter_time = std::chrono::system_clock::now();
#endif
iter_time = ModuleBase::get_time();
}
20 changes: 8 additions & 12 deletions source/source_pw/module_ofdft/of_print_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,20 @@
#include "source_estate/elecstate.h" // electronic states
#include "source_pw/module_ofdft/kedf_manager.h"

#include <chrono>
#include "source_base/timer_wrapper.h"


namespace OFDFT
{

void print_info(const int iter,
#ifdef __MPI
double &iter_time,
#else
std::chrono::system_clock::time_point &iter_time,
#endif
const double &energy_current,
const double &energy_last,
const double &normdLdphi,
const elecstate::ElecState *pelec,
KEDF_Manager *kedf_manager,
const bool conv_esolver);
ModuleBase::TimePoint &iter_time,
const double &energy_current,
const double &energy_last,
const double &normdLdphi,
const elecstate::ElecState *pelec,
KEDF_Manager *kedf_manager,
const bool conv_esolver);

}

Expand Down
Loading