From 4d80585238d212189aae6770f74cb7b1c7887806 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Sep 2023 14:20:39 -0400 Subject: [PATCH 1/2] add a public API to register data types dynamically --- dpdata/system.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index daf0c8858..d84bcb948 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -954,6 +954,17 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): idx = pick_by_amber_mask(parm, maskstr) return self.pick_atom_idx(idx, nopbc=nopbc) + @classmethod + def register_data_type(cls, *data_type: Tuple[DataType]): + """Register data type. + + Parameters + ---------- + *data_type : tuple[DataType] + data type to be regiestered + """ + cls.DTYPES = cls.DTYPES + tuple(data_type) + def get_cell_perturb_matrix(cell_pert_fraction): if cell_pert_fraction < 0: @@ -1599,9 +1610,9 @@ def to_format(self, *args, **kwargs): setattr(MultiSystems, method, get_func(formatcls)) # at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized - System.DTYPES = System.DTYPES + get_data_types(labeled=False) - LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False) - LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True) + System.register_data_type(*get_data_types(labeled=False)) + LabeledSystem.register_data_type(*get_data_types(labeled=False)) + LabeledSystem.register_data_type(*get_data_types(labeled=True)) add_format_methods() From bed865fcc01e29c397320770b789a4e546303392 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:21:45 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpdata/system.py b/dpdata/system.py index d84bcb948..8af8e4a59 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -957,7 +957,7 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): @classmethod def register_data_type(cls, *data_type: Tuple[DataType]): """Register data type. - + Parameters ---------- *data_type : tuple[DataType]