From 7ca622128a37f8bb90aaa99665020ef734a204d9 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Wed, 4 Jun 2025 23:05:48 -0400 Subject: [PATCH 01/10] create unit test --- sdks/python/apache_beam/transforms/core_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 54afb365d2d8..dc0827428f3d 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -289,6 +289,16 @@ def test_default(self): | beam.FlatMap()) assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f'])) + def test_default_with_typehint(self): + with beam.Pipeline() as pipeline: + letters = ( + pipeline + | beam.Create([["abc"]], reshuffle=False) + | beam.FlatMap() + | beam.Map(lambda s: s.upper()).with_input_types(str)) + + assert_that(letters, equal_to(["ABC"])) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From fa1fdf48b23039d3dd67c2d0e1851eb820ecdf7c Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 08:00:07 -0400 Subject: [PATCH 02/10] add some unit tests --- sdks/python/apache_beam/transforms/core_test.py | 12 ++++++++++++ sdks/python/apache_beam/typehints/decorators_test.py | 1 + sdks/python/apache_beam/typehints/typehints_test.py | 2 ++ 3 files changed, 15 insertions(+) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index dc0827428f3d..03980c9c0f83 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -21,7 +21,9 @@ import logging import os import tempfile +import typing import unittest +from typing import TypeVar import pytest @@ -289,6 +291,16 @@ def test_default(self): | beam.FlatMap()) assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f'])) + def test_callablewrapper_typehint(self): + T = TypeVar("T") + + def identity(x: T) -> T: + return x + + dofn = beam.core.CallableWrapperDoFn(identity) + assert dofn.get_type_hints() is None + assert dofn.get_type_hints().strip_iterable()[1][0][0] is typing.Any + def test_default_with_typehint(self): with beam.Pipeline() as pipeline: letters = ( diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index dd110ced5bb8..cdac821dcb82 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -131,6 +131,7 @@ def _test_strip_iterable_fail(self, before): def test_strip_iterable(self): self._test_strip_iterable(None, None) self._test_strip_iterable(typehints.Any, typehints.Any) + self._test_strip_iterable(T, None) self._test_strip_iterable(typehints.Iterable[str], str) self._test_strip_iterable(typehints.List[str], str) self._test_strip_iterable(typehints.Iterator[str], str) diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index f5e551f56d83..fe913ddef124 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1662,6 +1662,8 @@ def test_iterables(self): def test_not_iterable(self): with self.assertRaisesRegex(ValueError, r'not iterable'): typehints.get_yielded_type(int) + with self.assertRaisesRegex(ValueError, r'not iterable'): + typehints.get_yielded_type(T) def test_union_not_iterable(self): with self.assertRaisesRegex(ValueError, r'not iterable'): From 1659891a512c0b5f4ce58047c0523600cd2e6c98 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 08:50:05 -0400 Subject: [PATCH 03/10] fix bug where T is considered iterable --- sdks/python/apache_beam/typehints/decorators.py | 7 +++++++ sdks/python/apache_beam/typehints/decorators_test.py | 2 +- sdks/python/apache_beam/typehints/typehints.py | 3 +++ sdks/python/apache_beam/typehints/typehints_test.py | 2 ++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index dd650aa2fc0b..2057a55f41e2 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -424,6 +424,13 @@ def strip_iterable(self) -> 'IOTypeHints': output_type = types[0] except ValueError: pass + # if output_type == T, we can't strip it. + print( + 'output_type:', + output_type, + f"{isinstance(output_type, typehints.TypeVariable) = }") + if isinstance(output_type, typehints.TypeVariable): + return self yielded_type = typehints.get_yielded_type(output_type) return self._replace( diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index cdac821dcb82..cbe9000e63ad 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -131,7 +131,6 @@ def _test_strip_iterable_fail(self, before): def test_strip_iterable(self): self._test_strip_iterable(None, None) self._test_strip_iterable(typehints.Any, typehints.Any) - self._test_strip_iterable(T, None) self._test_strip_iterable(typehints.Iterable[str], str) self._test_strip_iterable(typehints.List[str], str) self._test_strip_iterable(typehints.Iterator[str], str) @@ -144,6 +143,7 @@ def test_strip_iterable(self): self._test_strip_iterable(typehints.Set[str], str) self._test_strip_iterable(typehints.FrozenSet[str], str) + #self._test_strip_iterable_fail(T) self._test_strip_iterable_fail(typehints.Union[str, int]) self._test_strip_iterable_fail(typehints.Optional[str]) self._test_strip_iterable_fail( diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index ef86a0eeb075..e3193e08aebf 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1563,7 +1563,10 @@ def get_yielded_type(type_hint): Raises: ValueError if not iterable. """ + if isinstance(type_hint, TypeVariable): + raise ValueError('%s is not iterable' % type_hint) if isinstance(type_hint, AnyTypeConstraint): + print(f"{type_hint} is considered a AnyTypeConstraint") return type_hint if is_consistent_with(type_hint, Iterator[Any]): return type_hint.yielded_type diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index fe913ddef124..466da2a73319 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1664,6 +1664,8 @@ def test_not_iterable(self): typehints.get_yielded_type(int) with self.assertRaisesRegex(ValueError, r'not iterable'): typehints.get_yielded_type(T) + with self.assertRaisesRegex(ValueError, r'not iterable'): + typehints.get_yielded_type(typehints.TypeVariable("T")) def test_union_not_iterable(self): with self.assertRaisesRegex(ValueError, r'not iterable'): From 264b1f09bb12aa77137576cd64cd8944b780842f Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 09:09:11 -0400 Subject: [PATCH 04/10] update strip_iterable to return Any for "stripped iterable" type of TypeVariable --- sdks/python/apache_beam/transforms/core_test.py | 5 ++--- sdks/python/apache_beam/typehints/decorators.py | 10 ++++------ sdks/python/apache_beam/typehints/decorators_test.py | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 03980c9c0f83..3e4b62d4144d 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -21,7 +21,6 @@ import logging import os import tempfile -import typing import unittest from typing import TypeVar @@ -31,6 +30,7 @@ from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.window import FixedWindows +from apache_beam.typehints import typehints RETURN_NONE_PARTIAL_WARNING = "No iterator is returned" @@ -298,8 +298,7 @@ def identity(x: T) -> T: return x dofn = beam.core.CallableWrapperDoFn(identity) - assert dofn.get_type_hints() is None - assert dofn.get_type_hints().strip_iterable()[1][0][0] is typing.Any + assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any def test_default_with_typehint(self): with beam.Pipeline() as pipeline: diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 2057a55f41e2..d7bf1ca9248e 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -424,13 +424,11 @@ def strip_iterable(self) -> 'IOTypeHints': output_type = types[0] except ValueError: pass - # if output_type == T, we can't strip it. - print( - 'output_type:', - output_type, - f"{isinstance(output_type, typehints.TypeVariable) = }") if isinstance(output_type, typehints.TypeVariable): - return self + # We don't know what T yields, so we just assume Any. + return self._replace( + output_types=((typehints.Any, ), {}), + origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) yielded_type = typehints.get_yielded_type(output_type) return self._replace( diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index cbe9000e63ad..a2909b4e545f 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -131,6 +131,7 @@ def _test_strip_iterable_fail(self, before): def test_strip_iterable(self): self._test_strip_iterable(None, None) self._test_strip_iterable(typehints.Any, typehints.Any) + self._test_strip_iterable(T, typehints.Any) self._test_strip_iterable(typehints.Iterable[str], str) self._test_strip_iterable(typehints.List[str], str) self._test_strip_iterable(typehints.Iterator[str], str) @@ -143,7 +144,6 @@ def test_strip_iterable(self): self._test_strip_iterable(typehints.Set[str], str) self._test_strip_iterable(typehints.FrozenSet[str], str) - #self._test_strip_iterable_fail(T) self._test_strip_iterable_fail(typehints.Union[str, int]) self._test_strip_iterable_fail(typehints.Optional[str]) self._test_strip_iterable_fail( From 9b51d3a2af4843bad334aad8fbcb334ad18931dd Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 09:18:42 -0400 Subject: [PATCH 05/10] remove typehint from identity function and add a test to test for proper typechecking --- sdks/python/apache_beam/transforms/core.py | 3 ++- sdks/python/apache_beam/transforms/core_test.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 53c36a00738f..9818fb1e14b7 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2016,7 +2016,7 @@ def to_runner_api(self, unused_context): return beam_runner_api_pb2.FunctionSpec(urn=self._urn) -def identity(x: T) -> T: +def identity(x): return x @@ -2053,6 +2053,7 @@ def FlatMap(fn=identity, *args, **kwargs): # pylint: disable=invalid-name pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs) pardo.label = label + return pardo diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 3e4b62d4144d..3a885976e100 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -30,7 +30,7 @@ from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.window import FixedWindows -from apache_beam.typehints import typehints +from apache_beam.typehints import typehints, TypeCheckError RETURN_NONE_PARTIAL_WARNING = "No iterator is returned" @@ -300,7 +300,7 @@ def identity(x: T) -> T: dofn = beam.core.CallableWrapperDoFn(identity) assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any - def test_default_with_typehint(self): + def test_default_identity_function_with_typehint(self): with beam.Pipeline() as pipeline: letters = ( pipeline @@ -310,6 +310,15 @@ def test_default_with_typehint(self): assert_that(letters, equal_to(["ABC"])) + def test_typecheck_with_default(self): + with pytest.raises(TypeCheckError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | beam.Create([[1, 2, 3]], reshuffle=False) + | beam.FlatMap() + | beam.Map(lambda s: s.upper()).with_input_types(str)) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From 5f96c1b21f89defb4d0fca7ef8f1edfa945176fa Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 09:21:22 -0400 Subject: [PATCH 06/10] Move callablewrapp typehint test --- .../apache_beam/transforms/core_test.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 3a885976e100..254ff56c4b0e 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -281,6 +281,16 @@ def failure_callback(e, el): self.assertFalse(os.path.isfile(tmp_path)) +def test_callablewrapper_typehint(): + T = TypeVar("T") + + def identity(x: T) -> T: + return x + + dofn = beam.core.CallableWrapperDoFn(identity) + assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any + + class FlatMapTest(unittest.TestCase): def test_default(self): @@ -291,15 +301,6 @@ def test_default(self): | beam.FlatMap()) assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f'])) - def test_callablewrapper_typehint(self): - T = TypeVar("T") - - def identity(x: T) -> T: - return x - - dofn = beam.core.CallableWrapperDoFn(identity) - assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any - def test_default_identity_function_with_typehint(self): with beam.Pipeline() as pipeline: letters = ( From 5ff3deaacdbcdb58f58e0c476c0d6c7419a6b938 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 09:27:52 -0400 Subject: [PATCH 07/10] Remove print --- sdks/python/apache_beam/typehints/typehints.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index e3193e08aebf..67dff3b7b4db 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1566,7 +1566,6 @@ def get_yielded_type(type_hint): if isinstance(type_hint, TypeVariable): raise ValueError('%s is not iterable' % type_hint) if isinstance(type_hint, AnyTypeConstraint): - print(f"{type_hint} is considered a AnyTypeConstraint") return type_hint if is_consistent_with(type_hint, Iterator[Any]): return type_hint.yielded_type From e51198187e7072c88077b91dd11c36ab19a8a14c Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 10:39:40 -0400 Subject: [PATCH 08/10] isort --- sdks/python/apache_beam/transforms/core_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 254ff56c4b0e..f9918115e350 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -20,17 +20,17 @@ import logging import os +import pytest import tempfile import unittest from typing import TypeVar -import pytest - import apache_beam as beam from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.window import FixedWindows -from apache_beam.typehints import typehints, TypeCheckError +from apache_beam.typehints import TypeCheckError +from apache_beam.typehints import typehints RETURN_NONE_PARTIAL_WARNING = "No iterator is returned" From 4505c37e265eef62c7c5b27bee9a847c42416311 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 11:13:21 -0400 Subject: [PATCH 09/10] isort --- sdks/python/apache_beam/transforms/core_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index f9918115e350..542544bce3c1 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -20,11 +20,12 @@ import logging import os -import pytest import tempfile import unittest from typing import TypeVar +import pytest + import apache_beam as beam from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to From a3665940a860076bed1ffb93ab2319b77c41f2b9 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 5 Jun 2025 22:59:15 -0400 Subject: [PATCH 10/10] return any for yielded type of T --- sdks/python/apache_beam/typehints/typehints.py | 4 ++-- sdks/python/apache_beam/typehints/typehints_test.py | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 67dff3b7b4db..26c727380fc0 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1563,8 +1563,8 @@ def get_yielded_type(type_hint): Raises: ValueError if not iterable. """ - if isinstance(type_hint, TypeVariable): - raise ValueError('%s is not iterable' % type_hint) + if isinstance(type_hint, typing.TypeVar): + return typing.Any if isinstance(type_hint, AnyTypeConstraint): return type_hint if is_consistent_with(type_hint, Iterator[Any]): diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 466da2a73319..c5c8b85f8c08 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1654,6 +1654,7 @@ def test_iterables(self): typehints.get_yielded_type(typehints.Tuple[int, str])) self.assertEqual(int, typehints.get_yielded_type(typehints.Set[int])) self.assertEqual(int, typehints.get_yielded_type(typehints.FrozenSet[int])) + self.assertEqual(typing.Any, typehints.get_yielded_type(T)) self.assertEqual( typehints.Union[int, str], typehints.get_yielded_type( @@ -1662,10 +1663,6 @@ def test_iterables(self): def test_not_iterable(self): with self.assertRaisesRegex(ValueError, r'not iterable'): typehints.get_yielded_type(int) - with self.assertRaisesRegex(ValueError, r'not iterable'): - typehints.get_yielded_type(T) - with self.assertRaisesRegex(ValueError, r'not iterable'): - typehints.get_yielded_type(typehints.TypeVariable("T")) def test_union_not_iterable(self): with self.assertRaisesRegex(ValueError, r'not iterable'):