From e4ce75deaa4ec5df82053ad7b5a338e8b823e64b Mon Sep 17 00:00:00 2001 From: caic99 Date: Fri, 1 Aug 2025 06:14:07 +0000 Subject: [PATCH 1/4] fix: skip datatype registration warning for duplicate types --- dpdata/data_type.py | 37 +++++++++++++++++++ dpdata/system.py | 9 ++--- tests/test_custom_data_type.py | 65 +++++++++++++++++++++++++++++----- 3 files changed, 99 insertions(+), 12 deletions(-) diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 8cea28c18..c689e99b9 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -64,6 +64,43 @@ def __init__( self.required = required self.deepmd_name = name if deepmd_name is None else deepmd_name + def __eq__(self, other) -> bool: + """Check if two DataType instances are equal. + + Parameters + ---------- + other : object + object to compare with + + Returns + ------- + bool + True if equal, False otherwise + """ + if not isinstance(other, DataType): + return False + return ( + self.name == other.name + and self.dtype == other.dtype + and self.shape == other.shape + and self.required == other.required + and self.deepmd_name == other.deepmd_name + ) + + def __repr__(self) -> str: + """Return string representation of DataType. 
+ + Returns + ------- + str + string representation + """ + return ( + f"DataType(name='{self.name}', dtype={self.dtype.__name__}, " + f"shape={self.shape}, required={self.required}, " + f"deepmd_name='{self.deepmd_name}')" + ) + def real_shape(self, system: System) -> tuple[int]: """Returns expected real shape of a system.""" assert self.shape is not None diff --git a/dpdata/system.py b/dpdata/system.py index d7cf26572..ede561a99 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1103,10 +1103,11 @@ def register_data_type(cls, *data_type: DataType): dtypes_dict = {} for dt in all_dtypes: if dt.name in dtypes_dict: - warnings.warn( - f"Data type {dt.name} is registered twice; only the newly registered one will be used.", - UserWarning, - ) + if dt != dtypes_dict[dt.name]: + warnings.warn( + f"Data type {dt.name} is registered twice with different definitions; only the newly registered one will be used.", + UserWarning, + ) dtypes_dict[dt.name] = dt cls.DTYPES = tuple(dtypes_dict.values()) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 08dc9eba2..b303ab839 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -1,6 +1,7 @@ from __future__ import annotations import unittest +import warnings import h5py # noqa: TID253 import numpy as np @@ -9,6 +10,62 @@ from dpdata.data_type import Axis, DataType +class TestDataType(unittest.TestCase): + """Test DataType class methods.""" + + def test_eq(self): + """Test equality method.""" + dt1 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) + dt2 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) + dt3 = DataType("other", np.ndarray, shape=(Axis.NFRAMES, 3)) + + self.assertTrue(dt1 == dt2) + self.assertFalse(dt1 == dt3) + self.assertFalse(dt1 == "not a DataType") + + def test_repr(self): + """Test string representation.""" + dt = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) + expected = ( + "DataType(name='test', dtype=ndarray, " + 
"shape=(<Axis.NFRAMES: 'nframes'>, 3), required=True, " + "deepmd_name='test')" + ) + self.assertEqual(repr(dt), expected) + + def test_register_same_data_type_no_warning(self): + """Test registering identical DataType instances should not warn.""" + dt1 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3)) + dt2 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3)) + + # Register first time + dpdata.System.register_data_type(dt1) + + # Register same DataType again - should not warn + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + dpdata.System.register_data_type(dt2) + # Check no warnings were issued + self.assertEqual(len(w), 0) + + def test_register_different_data_type_with_warning(self): + """Test registering different DataType instances with same name should warn.""" + dt1 = DataType("test_diff", np.ndarray, shape=(Axis.NFRAMES, 3)) + dt2 = DataType("test_diff", list, shape=(Axis.NFRAMES, 4)) # Different dtype and shape + + # Register first time + dpdata.System.register_data_type(dt1) + + # Register different DataType with same name - should warn + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + dpdata.System.register_data_type(dt2) + # Check warning was issued + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, UserWarning)) + self.assertIn("registered twice with different definitions", str(w[-1].message)) + + class DeepmdLoadDumpCompTest: def setUp(self): self.system = self.cls( @@ -49,14 +106,6 @@ def test_from_deepmd_hdf5(self): x = self.cls("data_foo.h5", fmt="deepmd/hdf5") np.testing.assert_allclose(x.data["foo"], self.foo) - def test_duplicated_data_type(self): - dt = DataType("foo", np.ndarray, (Axis.NFRAMES, *self.shape), required=False) - n_dtypes_old = len(self.cls.DTYPES) - with self.assertWarns(UserWarning): - self.cls.register_data_type(dt) - n_dtypes_new = len(self.cls.DTYPES) - self.assertEqual(n_dtypes_old, n_dtypes_new) - def test_to_deepmd_npy_mixed(self): 
ms = dpdata.MultiSystems(self.system) ms.to_deepmd_npy_mixed("data_foo_mixed") From cfa28d63abc8af39dcd5d55b24c39470b60dfd49 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Aug 2025 06:16:24 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_custom_data_type.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index b303ab839..47ddab498 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -51,7 +51,9 @@ def test_register_same_data_type_no_warning(self): def test_register_different_data_type_with_warning(self): """Test registering different DataType instances with same name should warn.""" dt1 = DataType("test_diff", np.ndarray, shape=(Axis.NFRAMES, 3)) - dt2 = DataType("test_diff", list, shape=(Axis.NFRAMES, 4)) # Different dtype and shape + dt2 = DataType( + "test_diff", list, shape=(Axis.NFRAMES, 4) + ) # Different dtype and shape # Register first time dpdata.System.register_data_type(dt1) @@ -63,7 +65,9 @@ def test_register_different_data_type_with_warning(self): # Check warning was issued self.assertEqual(len(w), 1) self.assertTrue(issubclass(w[-1].category, UserWarning)) - self.assertIn("registered twice with different definitions", str(w[-1].message)) + self.assertIn( + "registered twice with different definitions", str(w[-1].message) + ) class DeepmdLoadDumpCompTest: From 51d67908e7910a6e315c6692d9d30db5bbb8bc02 Mon Sep 17 00:00:00 2001 From: caic99 Date: Fri, 1 Aug 2025 06:30:57 +0000 Subject: [PATCH 3/4] fix ut --- tests/test_custom_data_type.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 47ddab498..b26b0f172 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ 
-13,6 +13,14 @@ class TestDataType(unittest.TestCase): """Test DataType class methods.""" + def setUp(self): + # Store original DTYPES to restore later + self.original_dtypes = dpdata.System.DTYPES + + def tearDown(self): + # Restore original DTYPES + dpdata.System.DTYPES = self.original_dtypes + def test_eq(self): """Test equality method.""" dt1 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) From 2a8ac2ee2995b5559c7d36d0f38378a5a187baac Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Sat, 2 Aug 2025 15:36:56 +0800 Subject: [PATCH 4/4] merge stmt Signed-off-by: Chun Cai --- dpdata/system.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index ede561a99..4c8f350a2 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1102,12 +1102,11 @@ def register_data_type(cls, *data_type: DataType): all_dtypes = cls.DTYPES + tuple(data_type) dtypes_dict = {} for dt in all_dtypes: - if dt.name in dtypes_dict: - if dt != dtypes_dict[dt.name]: - warnings.warn( - f"Data type {dt.name} is registered twice with different definitions; only the newly registered one will be used.", - UserWarning, - ) + if dt.name in dtypes_dict and dt != dtypes_dict[dt.name]: + warnings.warn( + f"Data type {dt.name} is registered twice with different definitions; only the newly registered one will be used.", + UserWarning, + ) dtypes_dict[dt.name] = dt cls.DTYPES = tuple(dtypes_dict.values())