Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
collections.abc.Sequence,
]

# Modules whose types convert_to_beam_type attempts to translate. A type whose
# module is not listed here, and which is not a builtin, is returned unchanged.
_CONVERTED_MODULES = ('typing', 'collections', 'collections.abc')


def _get_args(typ):
"""Returns a list of arguments to the given type.
Expand Down Expand Up @@ -127,6 +129,10 @@ def _match_is_primitive(match_against):
return lambda user_type: _is_primitive(user_type, match_against)


def _match_is_dict(user_type):
  """Match predicate for dict-like types.

  Args:
    user_type: The type hint being tested.

  Returns:
    The truthy result of ``_is_primitive(user_type, dict)`` when that check
    succeeds; otherwise the result of ``_safe_issubclass(user_type, dict)``.
  """
  primitive_match = _is_primitive(user_type, dict)
  if primitive_match:
    return primitive_match
  # Not a direct dict match; fall back to the subclass check.
  return _safe_issubclass(user_type, dict)


def _match_is_exactly_mapping(user_type):
# Avoid unintentionally catching all subtypes (e.g. strings and mappings).
expected_origin = collections.abc.Mapping
Expand Down Expand Up @@ -353,8 +359,7 @@ def convert_to_beam_type(typ):
# This is needed to fix https://github.com/apache/beam/issues/33356
pass

elif (typ_module != 'typing') and (typ_module !=
'collections.abc') and not is_builtin(typ):
elif typ_module not in _CONVERTED_MODULES and not is_builtin(typ):
# Only translate primitives and types from collections.abc and typing.
return typ
if (typ_module == 'collections.abc' and
Expand All @@ -371,8 +376,7 @@ def convert_to_beam_type(typ):
# unsupported.
_TypeMapEntry(match=is_forward_ref, arity=0, beam_type=typehints.Any),
_TypeMapEntry(match=is_any, arity=0, beam_type=typehints.Any),
_TypeMapEntry(
match=_match_is_primitive(dict), arity=2, beam_type=typehints.Dict),
_TypeMapEntry(match=_match_is_dict, arity=2, beam_type=typehints.Dict),
_TypeMapEntry(
match=_match_is_exactly_iterable,
arity=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,15 @@ def test_convert_to_beam_type_with_collections_types(self):
'sequence of tuples',
collections.abc.Sequence[tuple[str, int]],
typehints.Sequence[typehints.Tuple[str, int]]),
(
'ordered dict',
collections.OrderedDict[str, int],
typehints.Dict[str, int]),
(
'default dict',
collections.defaultdict[str, int],
typehints.Dict[str, int]),
('count', collections.Counter[str, int], typehints.Dict[str, int]),
]

for test_case in test_cases:
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ def normalize(x, none_as_type=False):
elif x in _KNOWN_PRIMITIVE_TYPES:
return _KNOWN_PRIMITIVE_TYPES[x]
elif getattr(x, '__module__',
None) in ('typing', 'collections.abc') or getattr(
None) in ('typing', 'collections', 'collections.abc') or getattr(
x, '__origin__', None) in _KNOWN_PRIMITIVE_TYPES:
beam_type = native_type_compatibility.convert_to_beam_type(x)
if beam_type != x:
Expand Down
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/typehints/typehints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,11 @@ def test_type_checks_not_dict(self):
'must be of type dict. [1, 2] is of type list.',
e.exception.args[0])

def test_type_check_collection(self):
  """A populated dict-subclass instance passes a Dict hint's type check."""
  hint = typehints.Dict[str, int]
  # Build a non-empty defaultdict so the key and value types are actually
  # checked. Writing ``defaultdict(list[("blue", 2)])`` instead would pass a
  # generic alias as the default_factory and leave the mapping empty, making
  # the type check vacuous.
  counts = collections.defaultdict(int, [("blue", 2)])
  self.assertIsNone(hint.type_check(counts))

def test_type_check_invalid_key_type(self):
hint = typehints.Dict[typehints.Tuple[int, int, int], typehints.List[str]]
d = {(1, 2): ['m', '1', '2', '3']}
Expand Down Expand Up @@ -767,12 +772,38 @@ def test_normalize_with_builtin_dict(self):
converted_beam_type = typehints.normalize(dict[str, int], False)
self.assertEqual(converted_beam_type, expected_beam_type)

def test_normalize_with_collections_dicts(self):
  """normalize() maps parameterized collections dict types to typehints.Dict."""
  # Each case: (description, input collections type, expected Beam type).
  # NOTE(review): Counter is subscripted with two arguments here although
  # Counter is normally parameterized by its key type only -- confirm this
  # is intentional.
  cases = (
      (
          'default dict',
          collections.defaultdict[str, bool],
          typehints.Dict[str, bool]),
      (
          'ordered dict',
          collections.OrderedDict[str, bool],
          typehints.Dict[str, bool]),
      ('counter', collections.Counter[str, int], typehints.Dict[str, int]),
  )
  for description, collections_type, expected_beam_type in cases:
    actual_beam_type = typehints.normalize(collections_type)
    self.assertEqual(actual_beam_type, expected_beam_type, description)

def test_builtin_and_type_compatibility(self):
  """Builtin dict hints are compatible with their typing.Dict counterparts."""
  # Each pair is passed to assertCompatible in the order listed.
  hint_pairs = (
      (dict, typing.Dict),
      (dict[str, int], typing.Dict[str, int]),
      (dict[str, list[int]], typing.Dict[str, typing.List[int]]),
  )
  for builtin_hint, typing_hint in hint_pairs:
    self.assertCompatible(builtin_hint, typing_hint)

def test_collections_subclass_compatibility(self):
  """Parameterized collections dict subclasses are compatible with typing.Dict."""
  # Each pair is passed to assertCompatible in the order listed.
  hint_pairs = (
      (collections.defaultdict[str, bool], typing.Dict[str, bool]),
      (collections.OrderedDict[str, int], typing.Dict[str, int]),
      (collections.Counter[str, int], typing.Dict[str, int]),
  )
  for collections_hint, typing_hint in hint_pairs:
    self.assertCompatible(collections_hint, typing_hint)


class BaseSetHintTest:
class CommonTests(TypeHintTestCase):
Expand Down
Loading