diff --git a/dpdata/deepmd/mixed.py b/dpdata/deepmd/mixed.py index e8abce21b..0d0ad89d9 100644 --- a/dpdata/deepmd/mixed.py +++ b/dpdata/deepmd/mixed.py @@ -46,6 +46,17 @@ def _load_set(folder, nopbc: bool): def to_system_data(folder, type_map=None, labels=True): # data is empty data = load_type(folder) + old_type_map = data["atom_names"].copy() + if type_map is not None: + assert isinstance(type_map, list) + missing_type = [i for i in old_type_map if i not in type_map] + assert ( + not missing_type + ), f"These types are missing in selected type_map: {missing_type} !" + index_map = np.array([type_map.index(i) for i in old_type_map]) + data["atom_names"] = type_map.copy() + else: + index_map = None data["orig"] = np.zeros([3]) if os.path.isfile(os.path.join(folder, "nopbc")): data["nopbc"] = True @@ -63,7 +74,12 @@ def to_system_data(folder, type_map=None, labels=True): nframes = np.reshape(cells, [-1, 3, 3]).shape[0] all_cells.append(np.reshape(cells, [nframes, 3, 3])) all_coords.append(np.reshape(coords, [nframes, -1, 3])) - all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1])) + if index_map is None: + all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1])) + else: + all_real_atom_types.append( + np.reshape(index_map[real_atom_types], [nframes, -1]) + ) if eners is not None: eners = np.reshape(eners, [nframes]) if labels: