From a1fc9929d395d7e957fdf40b57451fd56b138a13 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Sep 2023 22:06:40 +0800 Subject: [PATCH 1/2] support assigning 'type_map' for mixed_type --- dpdata/deepmd/mixed.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dpdata/deepmd/mixed.py b/dpdata/deepmd/mixed.py index e8abce21b..75dbecd40 100644 --- a/dpdata/deepmd/mixed.py +++ b/dpdata/deepmd/mixed.py @@ -46,6 +46,15 @@ def _load_set(folder, nopbc: bool): def to_system_data(folder, type_map=None, labels=True): # data is empty data = load_type(folder) + old_type_map = data["atom_names"].copy() + if type_map is not None: + assert isinstance(type_map, list) + missing_type = [i for i in old_type_map if i not in type_map] + assert not missing_type, f"These types are missing in selected type_map: {missing_type} !" + index_map = np.array([type_map.index(i) for i in old_type_map]) + data["atom_names"] = type_map.copy() + else: + index_map = None data["orig"] = np.zeros([3]) if os.path.isfile(os.path.join(folder, "nopbc")): data["nopbc"] = True @@ -63,7 +72,10 @@ def to_system_data(folder, type_map=None, labels=True): nframes = np.reshape(cells, [-1, 3, 3]).shape[0] all_cells.append(np.reshape(cells, [nframes, 3, 3])) all_coords.append(np.reshape(coords, [nframes, -1, 3])) - all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1])) + if index_map is None: + all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1])) + else: + all_real_atom_types.append(np.reshape(index_map[real_atom_types], [nframes, -1])) if eners is not None: eners = np.reshape(eners, [nframes]) if labels: From 8a160e9d98d1aec0d4a13aaabe9dd54dc7526778 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:08:07 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/deepmd/mixed.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dpdata/deepmd/mixed.py b/dpdata/deepmd/mixed.py index 75dbecd40..0d0ad89d9 100644 --- a/dpdata/deepmd/mixed.py +++ b/dpdata/deepmd/mixed.py @@ -50,7 +50,9 @@ def to_system_data(folder, type_map=None, labels=True): if type_map is not None: assert isinstance(type_map, list) missing_type = [i for i in old_type_map if i not in type_map] - assert not missing_type, f"These types are missing in selected type_map: {missing_type} !" + assert ( + not missing_type + ), f"These types are missing in selected type_map: {missing_type} !" index_map = np.array([type_map.index(i) for i in old_type_map]) data["atom_names"] = type_map.copy() else: @@ -75,7 +77,9 @@ def to_system_data(folder, type_map=None, labels=True): if index_map is None: all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1])) else: - all_real_atom_types.append(np.reshape(index_map[real_atom_types], [nframes, -1])) + all_real_atom_types.append( + np.reshape(index_map[real_atom_types], [nframes, -1]) + ) if eners is not None: eners = np.reshape(eners, [nframes]) if labels: