From 837a4b1ed070506d86e960b435852c052de0264b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Oct 2024 15:05:57 +0800 Subject: [PATCH 1/6] abacus: add checks on pp and orb in construction of STRU --- dpdata/abacus/scf.py | 74 +++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index 9919e9128..b34c8c88b 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -628,6 +628,9 @@ def make_unlabeled_stru( link_file : bool, optional Whether to link the pseudo potential files and orbital files in the STRU file. If True, then only filename will be written in the STRU file, and make a soft link to the real file. + dest_dir : str, optional + The destination directory to make the soft link of the pseudo potential files and orbital files. + For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written. ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment. If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used. @@ -682,6 +685,20 @@ def ndarray2list(i): out = "ATOMIC_SPECIES\n" if pp_file is not None: pp_file = ndarray2list(pp_file) + ppfiles = None + if isinstance(pp_file,(list, tuple)): + assert len(pp_file) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types" + ppfiles = pp_file + elif isinstance(pp_file, dict): + for iele in data["atom_names"]: + if iele not in pp_file: + raise RuntimeError(f"ERROR: make_unlabeled_stru: pp_file does not contain {iele}") + ppfiles = [pp_file[data["atom_names"][i]] for i in range(len(data["atom_names"]))] + else: + raise RuntimeError(f"ERROR: invalid pp_file: {pp_file}") + else: + ppfiles = None + for iele in range(len(data["atom_names"])): if data["atom_numbs"][iele] == 0: continue @@ -690,19 +707,8 @@ def ndarray2list(i): out += f"{mass[iele]:.3f} " else: out += "1 " - if pp_file is not None: - if isinstance(pp_file, (list, tuple)): - ipp_file = pp_file[iele] - elif isinstance(pp_file, dict): - if data["atom_names"][iele] not in pp_file: - print( - f"ERROR: make_unlabeled_stru: pp_file does not contain {data['atom_names'][iele]}" - ) - ipp_file = None - else: - ipp_file = pp_file[data["atom_names"][iele]] - else: - ipp_file = None + if ppfiles is not None: + ipp_file = ppfiles[iele] if ipp_file is not None: if not link_file: out += ipp_file @@ -710,37 +716,33 @@ def ndarray2list(i): out += os.path.basename(ipp_file.rstrip("/")) if dest_dir is not None: _link_file(dest_dir, ipp_file) - out += "\n" out += "\n" # NUMERICAL_ORBITAL block if numerical_orbital is not None: - assert len(numerical_orbital) == len(data["atom_names"]) numerical_orbital = ndarray2list(numerical_orbital) + orbfiles = [] + if isinstance(numerical_orbital, (list, tuple)): + assert len(numerical_orbital) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types" + orbfiles = [numerical_orbital[i] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] + elif isinstance(numerical_orbital, dict): + for iele in data["atom_names"]: + if iele not in numerical_orbital: + raise RuntimeError(f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {iele}") + orbfiles = [numerical_orbital[data["atom_names"][i]] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] + else: + raise RuntimeError(f"ERROR: invalid numerical_orbital: {numerical_orbital}") + + out += "NUMERICAL_ORBITAL\n" - for iele in range(len(data["atom_names"])): - if data["atom_numbs"][iele] == 0: - continue - if isinstance(numerical_orbital, (list, tuple)): - inum_orbital = numerical_orbital[iele] - elif isinstance(numerical_orbital, dict): - if data["atom_names"][iele] not in numerical_orbital: - print( - f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {data['atom_names'][iele]}" - ) - inum_orbital = None - else: - inum_orbital = numerical_orbital[data["atom_names"][iele]] + for iorb in orbfiles: + if not link_file: + out += iorb else: - inum_orbital = None - if inum_orbital is not None: - if not link_file: - out += inum_orbital - else: - out += os.path.basename(inum_orbital.rstrip("/")) - if dest_dir is not None: - _link_file(dest_dir, inum_orbital) + out += os.path.basename(iorb.rstrip("/")) + if dest_dir is not None: + _link_file(dest_dir, iorb) out += "\n" out += "\n" From aba74b217432971a5e77ff2b3d57b952d1d2c7f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 07:06:41 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/abacus/scf.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index b34c8c88b..6c62cffdb 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -630,7 +630,6 @@ def make_unlabeled_stru( If True, then only filename will be written in the STRU file, and make a soft link to the real file. dest_dir : str, optional The destination directory to make the soft link of the pseudo potential files and orbital files. - For velocity, mag, angle1, angle2, sc, and lambda_, if the value is None, then the corresponding information will not be written. ABACUS support defining "mag" and "angle1"/"angle2" at the same time, and in this case, the "mag" only define the norm of the magnetic moment, and "angle1" and "angle2" define the direction of the magnetic moment. If data has spins, then it will be written as mag to STRU file; while if mag is passed at the same time, then mag will be used. @@ -686,19 +685,25 @@ def ndarray2list(i): if pp_file is not None: pp_file = ndarray2list(pp_file) ppfiles = None - if isinstance(pp_file,(list, tuple)): - assert len(pp_file) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types" + if isinstance(pp_file, (list, tuple)): + assert ( + len(pp_file) == len(data["atom_names"]) + ), "ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types" ppfiles = pp_file elif isinstance(pp_file, dict): for iele in data["atom_names"]: if iele not in pp_file: - raise RuntimeError(f"ERROR: make_unlabeled_stru: pp_file does not contain {iele}") - ppfiles = [pp_file[data["atom_names"][i]] for i in range(len(data["atom_names"]))] + raise RuntimeError( + f"ERROR: make_unlabeled_stru: pp_file does not contain {iele}" + ) + ppfiles = [ + pp_file[data["atom_names"][i]] for i in range(len(data["atom_names"])) + ] else: raise RuntimeError(f"ERROR: invalid pp_file: {pp_file}") else: ppfiles = None - + for iele in range(len(data["atom_names"])): if data["atom_numbs"][iele] == 0: continue @@ -724,17 +729,28 @@ def ndarray2list(i): numerical_orbital = ndarray2list(numerical_orbital) orbfiles = [] if isinstance(numerical_orbital, (list, tuple)): - assert len(numerical_orbital) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types" - orbfiles = [numerical_orbital[i] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] + assert ( + len(numerical_orbital) == len(data["atom_names"]) + ), "ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types" + orbfiles = [ + numerical_orbital[i] + for i in range(len(data["atom_names"])) + if data["atom_numbs"][i] != 0 + ] elif isinstance(numerical_orbital, dict): for iele in data["atom_names"]: if iele not in numerical_orbital: - raise RuntimeError(f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {iele}") - orbfiles = [numerical_orbital[data["atom_names"][i]] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] + raise RuntimeError( + f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {iele}" + ) + orbfiles = [ + numerical_orbital[data["atom_names"][i]] + for i in range(len(data["atom_names"])) + if data["atom_numbs"][i] != 0 + ] else: raise RuntimeError(f"ERROR: invalid numerical_orbital: {numerical_orbital}") - - + out += "NUMERICAL_ORBITAL\n" for iorb in orbfiles: if not link_file: From 9d8a512f62dd6c3a061b6aea6133e0a80f52409f Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Oct 2024 15:34:18 +0800 Subject: [PATCH 3/6] add UT --- dpdata/abacus/scf.py | 6 ++++-- tests/test_abacus_stru_dump.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index b34c8c88b..423419d3e 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -687,7 +687,8 @@ def ndarray2list(i): pp_file = ndarray2list(pp_file) ppfiles = None if isinstance(pp_file,(list, tuple)): - assert len(pp_file) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types" + if len(pp_file) != len(data["atom_names"]): + raise RuntimeError("ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types") ppfiles = pp_file elif isinstance(pp_file, dict): for iele in data["atom_names"]: @@ -724,7 +725,8 @@ def ndarray2list(i): numerical_orbital = ndarray2list(numerical_orbital) orbfiles = [] if isinstance(numerical_orbital, (list, tuple)): - assert len(numerical_orbital) == len(data["atom_names"]), "ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types" + if len(numerical_orbital) != len(data["atom_names"]): + raise RuntimeError("ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types") orbfiles = [numerical_orbital[i] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] elif isinstance(numerical_orbital, dict): for iele in data["atom_names"]: diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 084a5473c..302f9f21d 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -50,6 +50,28 @@ def test_dumpStruLinkFile(self): if os.path.isdir("abacus.scf/tmp"): shutil.rmtree("abacus.scf/tmp") + + def test_dump_stru_pporb_mismatch(self): + self.assertRaises(RuntimeError, + self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], + pp_file={"C": "C.upf", "O": "O.upf"}, + numerical_orbital={"C": "C.orb", "H": "H.orb"}), "pp_file is a dict and lack of pp for H" + + self.assertRaises(RuntimeError, + self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], + pp_file=["C.upf"], + numerical_orbital={"C": "C.orb", "H": "H.orb"}), "pp_file is a list and lack of pp for H" + + self.assertRaises(RuntimeError, + self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], + pp_file={"C": "C.upf", "H": "H.upf"}, + numerical_orbital={"C": "C.orb", "O": "O.orb"}), "numerical_orbital is a dict and lack of orbital for H" + + self.assertRaises(RuntimeError, + self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], + pp_file=["C.upf", "H.upf"], + numerical_orbital=["C.orb"]), "numerical_orbital is a list and lack of orbital for H" + def test_dump_spinconstrain(self): self.system_ch4.to( From 2511d646195cb58752877dfe7a8a59b4db053271 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Oct 2024 07:35:37 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/abacus/scf.py | 16 ++++++-- tests/test_abacus_stru_dump.py | 73 ++++++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 25 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index 04586967d..37cd76ab0 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -685,9 +685,11 @@ def ndarray2list(i): if pp_file is not None: pp_file = ndarray2list(pp_file) ppfiles = None - if isinstance(pp_file,(list, tuple)): + if isinstance(pp_file, (list, tuple)): if len(pp_file) != len(data["atom_names"]): - raise RuntimeError("ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types") + raise RuntimeError( + "ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types" + ) ppfiles = pp_file elif isinstance(pp_file, dict): for iele in data["atom_names"]: @@ -729,8 +731,14 @@ def ndarray2list(i): orbfiles = [] if isinstance(numerical_orbital, (list, tuple)): if len(numerical_orbital) != len(data["atom_names"]): - raise RuntimeError("ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types") - orbfiles = [numerical_orbital[i] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] + raise RuntimeError( + "ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types" + ) + orbfiles = [ + numerical_orbital[i] + for i in range(len(data["atom_names"])) + if data["atom_numbs"][i] != 0 + ] elif isinstance(numerical_orbital, dict): for iele in data["atom_names"]: if iele not in numerical_orbital: diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 302f9f21d..7e9d1c654 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -50,28 +50,59 @@ def test_dumpStruLinkFile(self): if os.path.isdir("abacus.scf/tmp"): shutil.rmtree("abacus.scf/tmp") - + def test_dump_stru_pporb_mismatch(self): - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file={"C": "C.upf", "O": "O.upf"}, - numerical_orbital={"C": "C.orb", "H": "H.orb"}), "pp_file is a dict and lack of pp for H" - - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file=["C.upf"], - numerical_orbital={"C": "C.orb", "H": "H.orb"}), "pp_file is a list and lack of pp for H" - - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file={"C": "C.upf", "H": "H.upf"}, - numerical_orbital={"C": "C.orb", "O": "O.orb"}), "numerical_orbital is a dict and lack of orbital for H" - - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file=["C.upf", "H.upf"], - numerical_orbital=["C.orb"]), "numerical_orbital is a list and lack of orbital for H" - + ( + self.assertRaises( + RuntimeError, + self.system_ch4.to, + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file={"C": "C.upf", "O": "O.upf"}, + numerical_orbital={"C": "C.orb", "H": "H.orb"}, + ), + "pp_file is a dict and lack of pp for H", + ) + + ( + self.assertRaises( + RuntimeError, + self.system_ch4.to, + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file=["C.upf"], + numerical_orbital={"C": "C.orb", "H": "H.orb"}, + ), + "pp_file is a list and lack of pp for H", + ) + + ( + self.assertRaises( + RuntimeError, + self.system_ch4.to, + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file={"C": "C.upf", "H": "H.upf"}, + numerical_orbital={"C": "C.orb", "O": "O.orb"}, + ), + "numerical_orbital is a dict and lack of orbital for H", + ) + + ( + self.assertRaises( + RuntimeError, + self.system_ch4.to, + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file=["C.upf", "H.upf"], + numerical_orbital=["C.orb"], + ), + "numerical_orbital is a list and lack of orbital for H", + ) def test_dump_spinconstrain(self): self.system_ch4.to( From 9228db64a71fbd8b4f55cb20dfabebb98de36651 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 10 Oct 2024 09:47:00 +0800 Subject: [PATCH 5/6] fix --- dpdata/abacus/scf.py | 85 ++++++++++++++-------------------- tests/test_abacus_stru_dump.py | 53 +++++++++++++-------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index 04586967d..3afdbd5fb 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -578,7 +578,7 @@ def get_frame_from_stru(fname): def make_unlabeled_stru( data, frame_idx, - pp_file=None, + pp_file, numerical_orbital=None, numerical_descriptor=None, mass=None, @@ -601,7 +601,7 @@ def make_unlabeled_stru( System data frame_idx : int The index of the frame to dump - pp_file : list of string or dict, optional + pp_file : list of string or dict List of pseudo potential files, or a dictionary of pseudo potential files for each atomnames numerical_orbital : list of string or dict, optional List of orbital files, or a dictionary of orbital files for each atomnames @@ -656,6 +656,24 @@ def ndarray2list(i): return i.tolist() else: return i + + def process_file_input(file_input, atom_names, input_name): + # For pp_file and numerical_orbital, process the file input, and return a list of file names + # file_input can be a list of file names, or a dictionary of file names for each atom names + if isinstance(file_input, (list, tuple)): + if len(file_input) != len(atom_names): + raise ValueError( + f"{input_name} length is not equal to the number of atom types" + ) + return file_input + elif isinstance(file_input, dict): + for element in atom_names: + if element not in file_input: + raise KeyError(f"{input_name} does not contain {element}") + return [file_input[element] for element in atom_names] + else: + raise ValueError(f"Invalid {input_name}: {file_input}") + if link_file and dest_dir is None: print( @@ -682,26 +700,7 @@ def ndarray2list(i): # ATOMIC_SPECIES block out = "ATOMIC_SPECIES\n" - if pp_file is not None: - pp_file = ndarray2list(pp_file) - ppfiles = None - if isinstance(pp_file,(list, tuple)): - if len(pp_file) != len(data["atom_names"]): - raise RuntimeError("ERROR: make_unlabeled_stru: pp_file length is not equal to the number of atom types") - ppfiles = pp_file - elif isinstance(pp_file, dict): - for iele in data["atom_names"]: - if iele not in pp_file: - raise RuntimeError( - f"ERROR: make_unlabeled_stru: pp_file does not contain {iele}" - ) - ppfiles = [ - pp_file[data["atom_names"][i]] for i in range(len(data["atom_names"])) - ] - else: - raise RuntimeError(f"ERROR: invalid pp_file: {pp_file}") - else: - ppfiles = None + ppfiles = process_file_input(ndarray2list(pp_file), data["atom_names"], "pp_file") for iele in range(len(data["atom_names"])): if data["atom_numbs"][iele] == 0: @@ -711,40 +710,26 @@ def ndarray2list(i): out += f"{mass[iele]:.3f} " else: out += "1 " - if ppfiles is not None: - ipp_file = ppfiles[iele] - if ipp_file is not None: - if not link_file: - out += ipp_file - else: - out += os.path.basename(ipp_file.rstrip("/")) - if dest_dir is not None: - _link_file(dest_dir, ipp_file) + + ipp_file = ppfiles[iele] + if not link_file: + out += ipp_file + else: + out += os.path.basename(ipp_file.rstrip("/")) + if dest_dir is not None: + _link_file(dest_dir, ipp_file) out += "\n" out += "\n" # NUMERICAL_ORBITAL block if numerical_orbital is not None: numerical_orbital = ndarray2list(numerical_orbital) - orbfiles = [] - if isinstance(numerical_orbital, (list, tuple)): - if len(numerical_orbital) != len(data["atom_names"]): - raise RuntimeError("ERROR: make_unlabeled_stru: numerical_orbital length is not equal to the number of atom types") - orbfiles = [numerical_orbital[i] for i in range(len(data["atom_names"])) if data["atom_numbs"][i] != 0] - elif isinstance(numerical_orbital, dict): - for iele in data["atom_names"]: - if iele not in numerical_orbital: - raise RuntimeError( - f"ERROR: make_unlabeled_stru: numerical_orbital does not contain {iele}" - ) - orbfiles = [ - numerical_orbital[data["atom_names"][i]] - for i in range(len(data["atom_names"])) - if data["atom_numbs"][i] != 0 - ] - else: - raise RuntimeError(f"ERROR: invalid numerical_orbital: {numerical_orbital}") - + orbfiles = process_file_input(numerical_orbital, data["atom_names"], "numerical_orbital") + orbfiles = [ + orbfiles[i] + for i in range(len(data["atom_names"])) + if data["atom_numbs"][i] != 0 + ] out += "NUMERICAL_ORBITAL\n" for iorb in orbfiles: if not link_file: diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 302f9f21d..4c28d2734 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -52,26 +52,41 @@ def test_dumpStruLinkFile(self): shutil.rmtree("abacus.scf/tmp") def test_dump_stru_pporb_mismatch(self): - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file={"C": "C.upf", "O": "O.upf"}, - numerical_orbital={"C": "C.orb", "H": "H.orb"}), "pp_file is a dict and lack of pp for H" - - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file=["C.upf"], - numerical_orbital={"C": "C.orb", "H": "H.orb"}), "pp_file is a list and lack of pp for H" - - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file={"C": "C.upf", "H": "H.upf"}, - numerical_orbital={"C": "C.orb", "O": "O.orb"}), "numerical_orbital is a dict and lack of orbital for H" - - self.assertRaises(RuntimeError, - self.system_ch4.to,"stru","STRU_tmp",mass=[12, 1], - pp_file=["C.upf", "H.upf"], - numerical_orbital=["C.orb"]), "numerical_orbital is a list and lack of orbital for H" + with self.assertRaises(KeyError, msg="pp_file is a dict and lack of pp for H"): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file={"C": "C.upf", "O": "O.upf"}, + numerical_orbital={"C": "C.orb", "H": "H.orb"}, + ) + + with self.assertRaises(ValueError, msg="pp_file is a list and lack of pp for H"): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file=["C.upf"], + numerical_orbital={"C": "C.orb", "H": "H.orb"}, + ) + with self.assertRaises(KeyError, msg="numerical_orbital is a dict and lack of orbital for H"): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file={"C": "C.upf", "H": "H.upf"}, + numerical_orbital={"C": "C.orb", "O": "O.orb"}, + ) + + with self.assertRaises(ValueError, msg="numerical_orbital is a list and lack of orbital for H"): + self.system_ch4.to( + "stru", + "STRU_tmp", + mass=[12, 1], + pp_file=["C.upf", "H.upf"], + numerical_orbital=["C.orb"], + ) def test_dump_spinconstrain(self): self.system_ch4.to( From 1246f02b30a26487b6d01188e4416fa3b80147d1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Oct 2024 01:49:37 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/abacus/scf.py | 9 +++++---- tests/test_abacus_stru_dump.py | 18 ++++++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index 3afdbd5fb..b1b2cfed9 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -656,7 +656,7 @@ def ndarray2list(i): return i.tolist() else: return i - + def process_file_input(file_input, atom_names, input_name): # For pp_file and numerical_orbital, process the file input, and return a list of file names # file_input can be a list of file names, or a dictionary of file names for each atom names @@ -673,7 +673,6 @@ def process_file_input(file_input, atom_names, input_name): return [file_input[element] for element in atom_names] else: raise ValueError(f"Invalid {input_name}: {file_input}") - if link_file and dest_dir is None: print( @@ -700,7 +699,7 @@ def process_file_input(file_input, atom_names, input_name): # ATOMIC_SPECIES block out = "ATOMIC_SPECIES\n" - ppfiles = process_file_input(ndarray2list(pp_file), data["atom_names"], "pp_file") + ppfiles = process_file_input(ndarray2list(pp_file), data["atom_names"], "pp_file") for iele in range(len(data["atom_names"])): if data["atom_numbs"][iele] == 0: @@ -724,7 +723,9 @@ def process_file_input(file_input, atom_names, input_name): # NUMERICAL_ORBITAL block if numerical_orbital is not None: numerical_orbital = ndarray2list(numerical_orbital) - orbfiles = process_file_input(numerical_orbital, data["atom_names"], "numerical_orbital") + orbfiles = process_file_input( + numerical_orbital, data["atom_names"], "numerical_orbital" + ) orbfiles = [ orbfiles[i] for i in range(len(data["atom_names"])) diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 5b8477bb9..4549b6d16 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -60,8 +60,10 @@ def test_dump_stru_pporb_mismatch(self): pp_file={"C": "C.upf", "O": "O.upf"}, numerical_orbital={"C": "C.orb", "H": "H.orb"}, ) - - with self.assertRaises(ValueError, msg="pp_file is a list and lack of pp for H"): + + with self.assertRaises( + ValueError, msg="pp_file is a list and lack of pp for H" + ): self.system_ch4.to( "stru", "STRU_tmp", @@ -69,8 +71,10 @@ def test_dump_stru_pporb_mismatch(self): pp_file=["C.upf"], numerical_orbital={"C": "C.orb", "H": "H.orb"}, ) - - with self.assertRaises(KeyError, msg="numerical_orbital is a dict and lack of orbital for H"): + + with self.assertRaises( + KeyError, msg="numerical_orbital is a dict and lack of orbital for H" + ): self.system_ch4.to( "stru", "STRU_tmp", @@ -78,8 +82,10 @@ def test_dump_stru_pporb_mismatch(self): pp_file={"C": "C.upf", "H": "H.upf"}, numerical_orbital={"C": "C.orb", "O": "O.orb"}, ) - - with self.assertRaises(ValueError, msg="numerical_orbital is a list and lack of orbital for H"): + + with self.assertRaises( + ValueError, msg="numerical_orbital is a list and lack of orbital for H" + ): self.system_ch4.to( "stru", "STRU_tmp",