diff --git a/dpdata/system.py b/dpdata/system.py index 8af8e4a59..0748db284 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,6 +1,7 @@ # %% import glob import os +import warnings from copy import deepcopy from typing import Any, Dict, Optional, Tuple, Union @@ -963,7 +964,16 @@ def register_data_type(cls, *data_type: Tuple[DataType]): *data_type : tuple[DataType] data type to be regiestered """ - cls.DTYPES = cls.DTYPES + tuple(data_type) + all_dtypes = cls.DTYPES + tuple(data_type) + dtypes_dict = {} + for dt in all_dtypes: + if dt.name in dtypes_dict: + warnings.warn( + f"Data type {dt.name} is registered twice; only the newly registered one will be used.", + UserWarning, + ) + dtypes_dict[dt.name] = dt + cls.DTYPES = tuple(dtypes_dict.values()) def get_cell_perturb_matrix(cell_pert_fraction): diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 5a0e2bab9..006d6b01e 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -4,6 +4,7 @@ import numpy as np import dpdata +from dpdata.data_type import Axis, DataType class TestDeepmdLoadDumpComp(unittest.TestCase): @@ -44,6 +45,14 @@ def test_from_deepmd_hdf5(self): x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5") np.testing.assert_allclose(x.data["foo"], self.foo) + def test_duplicated_data_type(self): + dt = DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False) + n_dtypes_old = len(dpdata.LabeledSystem.DTYPES) + with self.assertWarns(UserWarning): + dpdata.LabeledSystem.register_data_type(dt) + n_dtypes_new = len(dpdata.LabeledSystem.DTYPES) + self.assertEqual(n_dtypes_old, n_dtypes_new) + class TestDeepmdLoadDumpCompAny(unittest.TestCase): def setUp(self):