diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 53c36a00738f..9818fb1e14b7 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2016,7 +2016,7 @@ def to_runner_api(self, unused_context):
     return beam_runner_api_pb2.FunctionSpec(urn=self._urn)
 
 
-def identity(x: T) -> T:
+def identity(x):
   return x
 
 
@@ -2053,6 +2053,7 @@ def FlatMap(fn=identity, *args, **kwargs): # pylint: disable=invalid-name
 
   pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs)
   pardo.label = label
+
   return pardo
 
 
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 54afb365d2d8..542544bce3c1 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -22,6 +22,7 @@
 import os
 import tempfile
 import unittest
+from typing import TypeVar
 
 import pytest
 
@@ -29,6 +30,8 @@
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms.window import FixedWindows
+from apache_beam.typehints import TypeCheckError
+from apache_beam.typehints import typehints
 
 RETURN_NONE_PARTIAL_WARNING = "No iterator is returned"
 
@@ -279,6 +282,16 @@ def failure_callback(e, el):
     self.assertFalse(os.path.isfile(tmp_path))
 
 
+def test_callablewrapper_typehint():
+  T = TypeVar("T")
+
+  def identity(x: T) -> T:
+    return x
+
+  dofn = beam.core.CallableWrapperDoFn(identity)
+  assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any
+
+
 class FlatMapTest(unittest.TestCase):
   def test_default(self):
 
@@ -289,6 +302,25 @@ def test_default(self):
           | beam.FlatMap())
       assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f']))
 
+  def test_default_identity_function_with_typehint(self):
+    with beam.Pipeline() as pipeline:
+      letters = (
+          pipeline
+          | beam.Create([["abc"]], reshuffle=False)
+          | beam.FlatMap()
+          | beam.Map(lambda s: s.upper()).with_input_types(str))
+
+      assert_that(letters, equal_to(["ABC"]))
+
+  def test_typecheck_with_default(self):
+    with pytest.raises(TypeCheckError):
+      with beam.Pipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create([[1, 2, 3]], reshuffle=False)
+            | beam.FlatMap()
+            | beam.Map(lambda s: s.upper()).with_input_types(str))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index dd650aa2fc0b..d7bf1ca9248e 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -424,6 +424,11 @@ def strip_iterable(self) -> 'IOTypeHints':
           output_type = types[0]
         except ValueError:
           pass
+    if isinstance(output_type, typehints.TypeVariable):
+      # We don't know what T yields, so we just assume Any.
+      return self._replace(
+          output_types=((typehints.Any, ), {}),
+          origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
 
     yielded_type = typehints.get_yielded_type(output_type)
     return self._replace(
diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py
index dd110ced5bb8..a2909b4e545f 100644
--- a/sdks/python/apache_beam/typehints/decorators_test.py
+++ b/sdks/python/apache_beam/typehints/decorators_test.py
@@ -131,6 +131,7 @@ def _test_strip_iterable_fail(self, before):
   def test_strip_iterable(self):
     self._test_strip_iterable(None, None)
     self._test_strip_iterable(typehints.Any, typehints.Any)
+    self._test_strip_iterable(T, typehints.Any)
     self._test_strip_iterable(typehints.Iterable[str], str)
     self._test_strip_iterable(typehints.List[str], str)
     self._test_strip_iterable(typehints.Iterator[str], str)
diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py
index ef86a0eeb075..26c727380fc0 100644
--- a/sdks/python/apache_beam/typehints/typehints.py
+++ b/sdks/python/apache_beam/typehints/typehints.py
@@ -1563,6 +1563,8 @@ def get_yielded_type(type_hint):
   Raises:
     ValueError if not iterable.
   """
+  if isinstance(type_hint, typing.TypeVar):
+    return typing.Any
   if isinstance(type_hint, AnyTypeConstraint):
     return type_hint
   if is_consistent_with(type_hint, Iterator[Any]):
diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py
index f5e551f56d83..c5c8b85f8c08 100644
--- a/sdks/python/apache_beam/typehints/typehints_test.py
+++ b/sdks/python/apache_beam/typehints/typehints_test.py
@@ -1654,6 +1654,7 @@ def test_iterables(self):
         typehints.get_yielded_type(typehints.Tuple[int, str]))
     self.assertEqual(int, typehints.get_yielded_type(typehints.Set[int]))
     self.assertEqual(int, typehints.get_yielded_type(typehints.FrozenSet[int]))
+    self.assertEqual(typing.Any, typehints.get_yielded_type(T))
     self.assertEqual(
         typehints.Union[int, str],
         typehints.get_yielded_type(