Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion dpdata/deepmd/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def _load_set(folder, nopbc: bool):
def to_system_data(folder, type_map=None, labels=True):
# data is empty
data = load_type(folder)
old_type_map = data["atom_names"].copy()
if type_map is not None:
assert isinstance(type_map, list)
missing_type = [i for i in old_type_map if i not in type_map]
assert (
not missing_type
), f"These types are missing in selected type_map: {missing_type} !"
index_map = np.array([type_map.index(i) for i in old_type_map])
data["atom_names"] = type_map.copy()
else:
index_map = None
data["orig"] = np.zeros([3])
if os.path.isfile(os.path.join(folder, "nopbc")):
data["nopbc"] = True
Expand All @@ -63,7 +74,12 @@ def to_system_data(folder, type_map=None, labels=True):
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
if index_map is None:
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
else:
all_real_atom_types.append(
np.reshape(index_map[real_atom_types], [nframes, -1])
)
if eners is not None:
eners = np.reshape(eners, [nframes])
if labels:
Expand Down