diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index c707240d098..760d004186b 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -115,6 +115,11 @@ def __init__(self, idx_list: Union[List, pd.Index, "Index", int]): if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list): raise TypeError("All elements in idx_list must be of the same datetime64 precision") self.idx_list = np.array(idx_list) + # Normalize datetime64 to nanosecond precision for consistent hashing + # Different precisions (e.g. 'ns' vs 's') are equal but have different hashes, + # which breaks dict lookups and set operations (see issue #1806) + if self.idx_list.dtype.kind == "M": + self.idx_list = self.idx_list.astype("datetime64[ns]") # NOTE: only the first appearance is indexed self.index_map = dict(zip(self.idx_list, range(len(self)))) self._is_sorted = False diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index 89fccb4d91f..94794fd7f79 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -112,6 +112,62 @@ def test_corner_cases(self): with self.assertRaises(TypeError): sd = idd.SingleData([1, 2, 3], index=timeindex) + def test_datetime64_precision_normalization(self): + """Test that datetime64 values with different precisions work correctly. + + numpy.datetime64 values with different precisions (e.g. 'ns' vs 's') are + equal but have different hashes, which breaks dict/set operations. + Normalizing to 'ns' in Index.__init__ fixes this (issue #1806). + """ + # Create two SingleData with different datetime64 precisions + ns_index = [ + np.datetime64("2017-01-04T00:00:00.000000000"), + np.datetime64("2017-01-05T00:00:00.000000000"), + np.datetime64("2017-01-06T00:00:00.000000000"), + ] + s_index = [ + np.datetime64("2017-01-04T00:00:00"), + np.datetime64("2017-01-05T00:00:00"), + np.datetime64("2017-01-06T00:00:00"), + ] + sd_ns = idd.SingleData([1, 2, 3], index=ns_index) + sd_s = idd.SingleData([4, 5, 6], index=s_index) + + # Both should be normalized to ns precision + self.assertEqual(sd_ns.index.idx_list.dtype, np.dtype("datetime64[ns]")) + self.assertEqual(sd_s.index.idx_list.dtype, np.dtype("datetime64[ns]")) + + # Cross-precision lookup should work + self.assertEqual(sd_ns.loc[np.datetime64("2017-01-04T00:00:00")], 1) + self.assertEqual(sd_s.loc[np.datetime64("2017-01-05T00:00:00.000000000")], 5) + + # Index.__or__ should work across precisions + combined = sd_ns.index | sd_s.index + self.assertEqual(len(combined), 3) + + # _align_indices should work across precisions + result = sd_ns + sd_s + self.assertEqual(result.iloc[0], 5) + self.assertEqual(result.iloc[1], 7) + self.assertEqual(result.iloc[2], 9) + + # concat should work across precisions + md = idd.concat([sd_ns, sd_s], axis=1) + self.assertEqual(md.data.shape, (3, 2)) + + # sum_by_index should work across precisions + summed = idd.sum_by_index([sd_ns, sd_s], sd_ns.index, fill_value=0) + self.assertEqual(summed.iloc[0], 5) + + # to_dict should produce consistent keys + d = sd_ns.to_dict() + self.assertIn(np.datetime64("2017-01-04T00:00:00"), d) + + # reindex across precisions should work + new_index = idd.Index(s_index) + reindexed = sd_ns.reindex(new_index) + self.assertEqual(reindexed.iloc[0], 1) + def test_ops(self): sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])