From 151336aecb34485945ed239883a0fde4338b79c3 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 11 Mar 2025 15:13:55 -0400 Subject: [PATCH 1/2] [DO NOT MERGE] Add support for collections dict subclasses --- .../typehints/native_type_compatibility.py | 12 ++++--- .../native_type_compatibility_test.py | 9 ++++++ .../python/apache_beam/typehints/typehints.py | 2 +- .../apache_beam/typehints/typehints_test.py | 31 +++++++++++++++++++ 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index da5bd6b0c0c4..18a492e36756 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -68,6 +68,8 @@ collections.abc.Sequence, ] +_CONVERTED_MODULES = ('typing', 'collections', 'collections.abc') + def _get_args(typ): """Returns a list of arguments to the given type. @@ -127,6 +129,10 @@ def _match_is_primitive(match_against): return lambda user_type: _is_primitive(user_type, match_against) +def _match_is_dict(user_type): + return _is_primitive(user_type, dict) or _safe_issubclass(user_type, dict) + + def _match_is_exactly_mapping(user_type): # Avoid unintentionally catching all subtypes (e.g. strings and mappings). expected_origin = collections.abc.Mapping @@ -353,8 +359,7 @@ def convert_to_beam_type(typ): # This is needed to fix https://github.com/apache/beam/issues/33356 pass - elif (typ_module != 'typing') and (typ_module != - 'collections.abc') and not is_builtin(typ): + elif typ_module not in _CONVERTED_MODULES and not is_builtin(typ): # Only translate primitives and types from collections.abc and typing. return typ if (typ_module == 'collections.abc' and @@ -371,8 +376,7 @@ def convert_to_beam_type(typ): # unsupported. _TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any), _TypeMapEntry(match=is_any, arity=0, beam_type=typehints.Any), - _TypeMapEntry( - match=_match_is_primitive(dict), arity=2, beam_type=typehints.Dict), + _TypeMapEntry(match=_match_is_dict, arity=2, beam_type=typehints.Dict), _TypeMapEntry( match=_match_is_exactly_iterable, arity=1, diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index e5366260c88e..df4f7153a056 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -253,6 +253,15 @@ def test_convert_to_beam_type_with_collections_types(self): 'sequence of tuples', collections.abc.Sequence[tuple[str, int]], typehints.Sequence[typehints.Tuple[str, int]]), + ( + 'ordered dict', + collections.OrderedDict[str, int], + typehints.Dict[str, int]), + ( + 'default dict', + collections.defaultdict[str, int], + typehints.Dict[str, int]), + ('count', collections.Counter[str, int], typehints.Dict[str, int]), ] for test_case in test_cases: diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 51b1b1ca68d0..217e31de9eda 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1315,7 +1315,7 @@ def normalize(x, none_as_type=False): elif x in _KNOWN_PRIMITIVE_TYPES: return _KNOWN_PRIMITIVE_TYPES[x] elif getattr(x, '__module__', - None) in ('typing', 'collections.abc') or getattr( + None) in ('typing', 'collections', 'collections.abc') or getattr( x, '__origin__', None) in _KNOWN_PRIMITIVE_TYPES: beam_type = native_type_compatibility.convert_to_beam_type(x) if beam_type != x: diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index a81da5abec40..7a174708dd1b 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -714,6 +714,11 @@ def test_type_checks_not_dict(self): 'Dict type-constraint violated. All passed instances ' 'must be of type dict. [1, 2] is of type list.', e.exception.args[0]) + + def test_type_check_collection(self): + hint = typehints.Dict[str, int] + l = collections.defaultdict(list[("blue", 2)]) + self.assertIsNone(hint.type_check(l)) def test_type_check_invalid_key_type(self): hint = typehints.Dict[typehints.Tuple[int, int, int], typehints.List[str]] @@ -767,12 +772,38 @@ def test_normalize_with_builtin_dict(self): converted_beam_type = typehints.normalize(dict[str, int], False) self.assertEqual(converted_beam_type, expected_beam_type) + def test_normalize_with_collections_dicts(self): + test_cases = [ + ( + 'default dict', + collections.defaultdict[str, bool], + typehints.Dict[str, bool]), + ( + 'ordered dict', + collections.OrderedDict[str, bool], + typehints.Dict[str, bool]), + ('counter', collections.Counter[str, int], typehints.Dict[str, int]), + ] + for test_case in test_cases: + description = test_case[0] + collections_type = test_case[1] + expected_beam_type = test_case[2] + converted_beam_type = typehints.normalize(collections_type) + self.assertEqual(converted_beam_type, expected_beam_type, description) + def test_builtin_and_type_compatibility(self): self.assertCompatible(dict, typing.Dict) self.assertCompatible(dict[str, int], typing.Dict[str, int]) self.assertCompatible( dict[str, list[int]], typing.Dict[str, typing.List[int]]) + def test_collections_subclass_compatibility(self): + self.assertCompatible( + collections.defaultdict[str, bool], typing.Dict[str, bool]) + self.assertCompatible( + collections.OrderedDict[str, int], typing.Dict[str, int]) + self.assertCompatible(collections.Counter[str, int], typing.Dict[str, int]) + class BaseSetHintTest: class CommonTests(TypeHintTestCase): From 9f037436cdbe3a17dfbc68f7581455bed99e5809 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 11 Mar 2025 15:32:43 -0400 Subject: [PATCH 2/2] yapf --- sdks/python/apache_beam/typehints/typehints_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 7a174708dd1b..3b87de23d199 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -714,7 +714,7 @@ def test_type_checks_not_dict(self): 'Dict type-constraint violated. All passed instances ' 'must be of type dict. [1, 2] is of type list.', e.exception.args[0]) - + def test_type_check_collection(self): hint = typehints.Dict[str, int] l = collections.defaultdict(list[("blue", 2)])