From 73b7f8269f66fe5047f2e13a274e8224fa77b228 Mon Sep 17 00:00:00 2001 From: kirk0830 Date: Tue, 9 Jan 2024 10:16:47 +0800 Subject: [PATCH 1/3] support list input of bessel_nao_rcut --- source/module_esolver/esolver_ks_pw.cpp | 23 +++++- source/module_io/input.cpp | 40 +++++++++- source/module_io/input.h | 8 +- source/module_io/test/input_test.cpp | 97 ++++++++++++++++++++++++ source/module_io/test/support/INPUT_list | 9 +++ 5 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 source/module_io/test/support/INPUT_list diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index b1a32282f6..294f1636f7 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -1065,7 +1065,28 @@ void ESolver_KS_PW::postprocess() if (winput::out_spillage <= 2) { Numerical_Basis numerical_basis; - numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc); + if(INPUT.bessel_nao_rcuts.size() == 1) + { + numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc); + } + else + { + for(int i = 0; i < INPUT.bessel_nao_rcuts.size(); i++) + { + if(GlobalV::MY_RANK == 0) {std::cout << "update value: bessel_nao_rcut <- " << std::fixed << INPUT.bessel_nao_rcuts[i] << " a.u." << std::endl;} + INPUT.bessel_nao_rcut = INPUT.bessel_nao_rcuts[i]; + numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc); + std::string old_fname_header = winput::spillage_outdir + "/" + "orb_matrix."; + std::string new_fname_header = winput::spillage_outdir + "/" + "orb_matrix_rcut" + std::to_string(int(INPUT.bessel_nao_rcut)) + "deriv"; + for(int derivative_order = 0; derivative_order <= 1; derivative_order++) + { + // rename generated files + std::string old_fname = old_fname_header + std::to_string(derivative_order) + ".dat"; + std::string new_fname = new_fname_header + std::to_string(derivative_order) + ".dat"; + std::rename(old_fname.c_str(), new_fname.c_str()); + } + } + } ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); } } diff --git a/source/module_io/input.cpp b/source/module_io/input.cpp index be8e87a011..0db736c1cc 100644 --- a/source/module_io/input.cpp +++ b/source/module_io/input.cpp @@ -2203,7 +2203,9 @@ bool Input::Read(const std::string& fn) } else if (strcmp("bessel_nao_rcut", word) == 0) { - read_value(ifs, bessel_nao_rcut); + //read_value(ifs, bessel_nao_rcut); + read_value2stdvector(ifs, bessel_nao_rcuts); + bessel_nao_rcut = bessel_nao_rcuts[0]; // also compatible with old input file } else if (strcmp("bessel_nao_tolerence", word) == 0) { @@ -3563,6 +3565,12 @@ void Input::Bcast() Parallel_Common::bcast_bool(bessel_nao_smooth); Parallel_Common::bcast_double(bessel_nao_sigma); Parallel_Common::bcast_string(bessel_nao_ecut); + /* newly support vector/list input of bessel_nao_rcut */ + int nrcut = bessel_nao_rcuts.size(); + Parallel_Common::bcast_int(nrcut); + if (GlobalV::MY_RANK != 0) bessel_nao_rcuts.resize(nrcut); + Parallel_Common::bcast_double(bessel_nao_rcuts.data(), nrcut); + /* end */ Parallel_Common::bcast_double(bessel_nao_rcut); Parallel_Common::bcast_double(bessel_nao_tolerence); Parallel_Common::bcast_int(bessel_descriptor_lmax); @@ -4223,6 +4231,36 @@ void Input::strtolower(char* sa, char* sb) sb[len] = '\0'; } +template +void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var) +{ + // reset var + var.clear(); var.shrink_to_fit(); + std::string line; + std::getline(ifs, line); // read the whole rest of line + std::vector temp; + for(char &c: line) + { + if(c == '\t' || c == ' ' || c == '\n' || c == '#') // space or tab seperates values + { + if(temp.size() > 0) // if temp is not empty, excludes the case of multiple spaces or tabs + { + std::string str(temp.begin(), temp.end()); + var.push_back(std::stod(str)); + temp.clear(); + } + if(c == '\n' || c == '#' || c == '\0') break; // end of line + } + else temp.push_back(c); // other characters + } + if(temp.size() > 0) // the last value + { + std::string str(temp.begin(), temp.end()); + var.push_back(std::stod(str)); + } +} +template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var); +template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var); // Conut how many types of atoms are listed in STRU int Input::count_ntype(const std::string& fn) { diff --git a/source/module_io/input.h b/source/module_io/input.h index 46ceddcb8e..fe9b749ef8 100644 --- a/source/module_io/input.h +++ b/source/module_io/input.h @@ -546,6 +546,7 @@ class Input double bessel_nao_sigma; // spherical bessel smearing_sigma std::string bessel_nao_ecut; // energy cutoff for spherical bessel functions(Ry) double bessel_nao_rcut; // radial cutoff for spherical bessel functions(a.u.) + std::vector bessel_nao_rcuts; double bessel_nao_tolerence; // tolerence for spherical bessel root // the following are used when generating jle.orb int bessel_descriptor_lmax; // lmax used in descriptor @@ -628,7 +629,7 @@ class Input { ifs >> var; std::string line; - getline(ifs, line); + getline(ifs, line); // read the rest of the line, directly discard it. return; } void read_kspacing(std::ifstream &ifs) @@ -658,6 +659,11 @@ class Input // << std::endl; }; + /* I hope this function would be more and more useful if want to support + vector/list of input */ + template + void read_value2stdvector(std::ifstream& ifs, std::vector& var); + void strtolower(char *sa, char *sb); void read_bool(std::ifstream &ifs, bool &var); }; diff --git a/source/module_io/test/input_test.cpp b/source/module_io/test/input_test.cpp index 9c3c2b42a6..df1de6bb05 100644 --- a/source/module_io/test/input_test.cpp +++ b/source/module_io/test/input_test.cpp @@ -1646,7 +1646,104 @@ TEST_F(InputTest, Check) */ } +bool strcmp_inbuilt(const std::string& str1, const std::string& str2) +{ + if(str1.size() != str2.size()) + return false; + for(int i=0; i value; + while(!ifs.eof()) + { + ifs >> word; + if(strcmp_inbuilt(word, "bessel_nao_rcut_case0")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 1); + EXPECT_EQ(value[0], 7); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case1")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 1); + EXPECT_EQ(value[0], 7); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case2")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 1); + EXPECT_EQ(value[0], 7); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case3")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 1); + EXPECT_EQ(value[0], 7); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case4")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 1); + EXPECT_EQ(value[0], 7); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case5")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 4); + EXPECT_EQ(value[0], 7); + EXPECT_EQ(value[1], 8); + EXPECT_EQ(value[2], 9); + EXPECT_EQ(value[3], 10); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case6")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 4); + EXPECT_EQ(value[0], 7); + EXPECT_EQ(value[1], 8); + EXPECT_EQ(value[2], 9); + EXPECT_EQ(value[3], 10); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case7")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 4); + EXPECT_EQ(value[0], 7); + EXPECT_EQ(value[1], 8); + EXPECT_EQ(value[2], 9); + EXPECT_EQ(value[3], 10); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case8")) + { + value.clear(); value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, value); + EXPECT_EQ(value.size(), 4); + EXPECT_EQ(value[0], 7); + EXPECT_EQ(value[1], 8); + EXPECT_EQ(value[2], 9); + EXPECT_EQ(value[3], 10); + } + } +} #undef private diff --git a/source/module_io/test/support/INPUT_list b/source/module_io/test/support/INPUT_list new file mode 100644 index 0000000000..aba0e49980 --- /dev/null +++ b/source/module_io/test/support/INPUT_list @@ -0,0 +1,9 @@ +bessel_nao_rcut_case0 7 +bessel_nao_rcut_case1 7# case0: test whitespace, 1 value, 1 space between key and value, no comment, 0 space after value: w1100, case1: w1110 +bessel_nao_rcut_case2 7 # w1111 +bessel_nao_rcut_case3 7# t1110 +bessel_nao_rcut_case4 7 # t1111 +bessel_nao_rcut_case5 7 8 9 10# s4110 +bessel_nao_rcut_case6 7 8 9 10 # s4111 +bessel_nao_rcut_case7 7 8 9 10# t4t10 +bessel_nao_rcut_case8 7 8 9 10 # t4t11 \ No newline at end of file From 031605df0966fb7aa327ff9117d71fa356af5816 Mon Sep 17 00:00:00 2001 From: kirk0830 Date: Tue, 9 Jan 2024 13:00:09 +0800 Subject: [PATCH 2/3] use type_trais to do type conversion --- source/module_io/input.cpp | 6 ++-- source/module_io/input.h | 10 ++++++- source/module_io/test/input_test.cpp | 36 ++++++++++++++++++++++++ source/module_io/test/support/INPUT_list | 6 +++- 4 files changed, 54 insertions(+), 4 deletions(-) diff --git a/source/module_io/input.cpp b/source/module_io/input.cpp index 0db736c1cc..d74fe0a74f 100644 --- a/source/module_io/input.cpp +++ b/source/module_io/input.cpp @@ -4246,7 +4246,7 @@ void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var) if(temp.size() > 0) // if temp is not empty, excludes the case of multiple spaces or tabs { std::string str(temp.begin(), temp.end()); - var.push_back(std::stod(str)); + var.push_back(cast_string(str)); temp.clear(); } if(c == '\n' || c == '#' || c == '\0') break; // end of line @@ -4256,11 +4256,13 @@ void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var) if(temp.size() > 0) // the last value { std::string str(temp.begin(), temp.end()); - var.push_back(std::stod(str)); + var.push_back(cast_string(str)); } } template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var); template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var); +template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var); + // Conut how many types of atoms are listed in STRU int Input::count_ntype(const std::string& fn) { diff --git a/source/module_io/input.h b/source/module_io/input.h index fe9b749ef8..b796976417 100644 --- a/source/module_io/input.h +++ b/source/module_io/input.h @@ -5,6 +5,7 @@ #include #include #include +#include #include "module_base/vector3.h" #include "module_md/md_para.h" @@ -663,7 +664,14 @@ class Input vector/list of input */ template void read_value2stdvector(std::ifstream& ifs, std::vector& var); - + template + typename std::enable_if::value, T>::type cast_string(const std::string& str) { return std::stod(str); } + template + typename std::enable_if::value, T>::type cast_string(const std::string& str) { return std::stoi(str); } + template + typename std::enable_if::value, T>::type cast_string(const std::string& str) { return (str == "true" || str == "1"); } + template + typename std::enable_if::value, T>::type cast_string(const std::string& str) { return str; } void strtolower(char *sa, char *sb); void read_bool(std::ifstream &ifs, bool &var); }; diff --git a/source/module_io/test/input_test.cpp b/source/module_io/test/input_test.cpp index df1de6bb05..50c8ff1475 100644 --- a/source/module_io/test/input_test.cpp +++ b/source/module_io/test/input_test.cpp @@ -1742,6 +1742,42 @@ TEST_F(InputTest, ReadValue2stdvector) EXPECT_EQ(value[2], 9); EXPECT_EQ(value[3], 10); } + std::vector str_value; + if(strcmp_inbuilt(word, "bessel_nao_rcut_case9")) + { + str_value.clear(); str_value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, str_value); + EXPECT_EQ(str_value.size(), 1); + EXPECT_EQ(str_value[0], "string1"); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case10")) + { + str_value.clear(); str_value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, str_value); + EXPECT_EQ(str_value.size(), 4); + EXPECT_EQ(str_value[0], "string1"); + EXPECT_EQ(str_value[1], "string2"); + EXPECT_EQ(str_value[2], "string3"); + EXPECT_EQ(str_value[3], "string4"); + } + std::vector double_value; + if(strcmp_inbuilt(word, "bessel_nao_rcut_case11")) + { + double_value.clear(); double_value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, double_value); + EXPECT_EQ(double_value.size(), 1); + EXPECT_EQ(double_value[0], 1.23456789); + } + if(strcmp_inbuilt(word, "bessel_nao_rcut_case12")) + { + double_value.clear(); double_value.shrink_to_fit(); + INPUT.read_value2stdvector(ifs, double_value); + EXPECT_EQ(double_value.size(), 4); + EXPECT_EQ(double_value[0], -1.23456789); + EXPECT_EQ(double_value[1], 2.3456789); + EXPECT_EQ(double_value[2], -3.456789); + EXPECT_EQ(double_value[3], 4.56789); + } } } #undef private diff --git a/source/module_io/test/support/INPUT_list b/source/module_io/test/support/INPUT_list index aba0e49980..a1efcb8e1a 100644 --- a/source/module_io/test/support/INPUT_list +++ b/source/module_io/test/support/INPUT_list @@ -6,4 +6,8 @@ bessel_nao_rcut_case4 7 # t1111 bessel_nao_rcut_case5 7 8 9 10# s4110 bessel_nao_rcut_case6 7 8 9 10 # s4111 bessel_nao_rcut_case7 7 8 9 10# t4t10 -bessel_nao_rcut_case8 7 8 9 10 # t4t11 \ No newline at end of file +bessel_nao_rcut_case8 7 8 9 10 # t4t11 +bessel_nao_rcut_case9 string1 # something +bessel_nao_rcut_case10 string1 string2 string3 string4 +bessel_nao_rcut_case11 1.23456789 +bessel_nao_rcut_case12 -1.23456789 2.3456789 -3.456789 4.56789 \ No newline at end of file From 51f3e872839ae5b49e34da629ceeb8ae86e8727e Mon Sep 17 00:00:00 2001 From: kirk0830 Date: Tue, 9 Jan 2024 14:12:13 +0800 Subject: [PATCH 3/3] turn to use std::transform to iterate vector instead of loop one-by-one --- source/module_io/input.cpp | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/source/module_io/input.cpp b/source/module_io/input.cpp index d74fe0a74f..f5c93034b3 100644 --- a/source/module_io/input.cpp +++ b/source/module_io/input.cpp @@ -4238,26 +4238,17 @@ void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var) var.clear(); var.shrink_to_fit(); std::string line; std::getline(ifs, line); // read the whole rest of line - std::vector temp; - for(char &c: line) + line = (line.find('#') == std::string::npos) ? line : line.substr(0, line.find('#')); // remove comments + std::vector tmp; + std::string::size_type start = 0, end = 0; + while ((start = line.find_first_not_of(" \t\n", end)) != std::string::npos) // find the first not of delimiters but not reaches the end { - if(c == '\t' || c == ' ' || c == '\n' || c == '#') // space or tab seperates values - { - if(temp.size() > 0) // if temp is not empty, excludes the case of multiple spaces or tabs - { - std::string str(temp.begin(), temp.end()); - var.push_back(cast_string(str)); - temp.clear(); - } - if(c == '\n' || c == '#' || c == '\0') break; // end of line - } - else temp.push_back(c); // other characters - } - if(temp.size() > 0) // the last value - { - std::string str(temp.begin(), temp.end()); - var.push_back(cast_string(str)); + end = line.find_first_of(" \t\n", start); // find the first of delimiters starting from start pos + tmp.push_back(line.substr(start, end - start)); // push back the substring } + var.resize(tmp.size()); + // capture "this"'s member function cast_string and iterate from tmp.begin() to tmp.end(), transform to var.begin() + std::transform(tmp.begin(), tmp.end(), var.begin(), [this](const std::string& s) { return cast_string(s); }); } template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var); template void Input::read_value2stdvector(std::ifstream& ifs, std::vector& var);