Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,7 +2016,7 @@ def to_runner_api(self, unused_context):
return beam_runner_api_pb2.FunctionSpec(urn=self._urn)


def identity(x: T) -> T:
def identity(x):
return x


Expand Down Expand Up @@ -2053,6 +2053,7 @@ def FlatMap(fn=identity, *args, **kwargs): # pylint: disable=invalid-name

pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs)
pardo.label = label

return pardo


Expand Down
32 changes: 32 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
import os
import tempfile
import unittest
from typing import TypeVar

import pytest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.window import FixedWindows
from apache_beam.typehints import TypeCheckError
from apache_beam.typehints import typehints

RETURN_NONE_PARTIAL_WARNING = "No iterator is returned"

Expand Down Expand Up @@ -279,6 +282,16 @@ def failure_callback(e, el):
self.assertFalse(os.path.isfile(tmp_path))


def test_callablewrapper_typehint():
T = TypeVar("T")

def identity(x: T) -> T:
return x

dofn = beam.core.CallableWrapperDoFn(identity)
assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any


class FlatMapTest(unittest.TestCase):
def test_default(self):

Expand All @@ -289,6 +302,25 @@ def test_default(self):
| beam.FlatMap())
assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f']))

def test_default_identity_function_with_typehint(self):
with beam.Pipeline() as pipeline:
letters = (
pipeline
| beam.Create([["abc"]], reshuffle=False)
| beam.FlatMap()
| beam.Map(lambda s: s.upper()).with_input_types(str))

assert_that(letters, equal_to(["ABC"]))

def test_typecheck_with_default(self):
with pytest.raises(TypeCheckError):
with beam.Pipeline() as pipeline:
_ = (
pipeline
| beam.Create([[1, 2, 3]], reshuffle=False)
| beam.FlatMap()
| beam.Map(lambda s: s.upper()).with_input_types(str))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/apache_beam/typehints/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,11 @@ def strip_iterable(self) -> 'IOTypeHints':
output_type = types[0]
except ValueError:
pass
if isinstance(output_type, typehints.TypeVariable):
# We don't know what T yields, so we just assume Any.
return self._replace(
output_types=((typehints.Any, ), {}),
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))

yielded_type = typehints.get_yielded_type(output_type)
return self._replace(
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/typehints/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _test_strip_iterable_fail(self, before):
def test_strip_iterable(self):
self._test_strip_iterable(None, None)
self._test_strip_iterable(typehints.Any, typehints.Any)
self._test_strip_iterable(T, typehints.Any)
self._test_strip_iterable(typehints.Iterable[str], str)
self._test_strip_iterable(typehints.List[str], str)
self._test_strip_iterable(typehints.Iterator[str], str)
Expand Down
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,8 @@ def get_yielded_type(type_hint):
Raises:
ValueError if not iterable.
"""
if isinstance(type_hint, typing.TypeVar):
return typing.Any
if isinstance(type_hint, AnyTypeConstraint):
return type_hint
if is_consistent_with(type_hint, Iterator[Any]):
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/typehints/typehints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,7 @@ def test_iterables(self):
typehints.get_yielded_type(typehints.Tuple[int, str]))
self.assertEqual(int, typehints.get_yielded_type(typehints.Set[int]))
self.assertEqual(int, typehints.get_yielded_type(typehints.FrozenSet[int]))
self.assertEqual(typing.Any, typehints.get_yielded_type(T))
self.assertEqual(
typehints.Union[int, str],
typehints.get_yielded_type(
Expand Down
Loading