diff --git a/dpdata/system.py b/dpdata/system.py index daf0c8858..8af8e4a59 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -954,6 +954,17 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): idx = pick_by_amber_mask(parm, maskstr) return self.pick_atom_idx(idx, nopbc=nopbc) + @classmethod + def register_data_type(cls, *data_type: Tuple[DataType]): + """Register data type. + + Parameters + ---------- + *data_type : tuple[DataType] + data type to be regiestered + """ + cls.DTYPES = cls.DTYPES + tuple(data_type) + def get_cell_perturb_matrix(cell_pert_fraction): if cell_pert_fraction < 0: @@ -1599,9 +1610,9 @@ def to_format(self, *args, **kwargs): setattr(MultiSystems, method, get_func(formatcls)) # at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized - System.DTYPES = System.DTYPES + get_data_types(labeled=False) - LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False) - LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True) + System.register_data_type(*get_data_types(labeled=False)) + LabeledSystem.register_data_type(*get_data_types(labeled=False)) + LabeledSystem.register_data_type(*get_data_types(labeled=True)) add_format_methods()