diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 5f68294972c6..6245bb7703fb 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -585,7 +585,7 @@ def default_type_hints(self): # TODO(sourabhbajaj): Do we want to remove the responsibility of these from # the DoFn or maybe the runner def infer_output_type(self, input_type): - # TODO(robertwb): Side inputs types. + # TODO(BEAM-8247): Side inputs types. # TODO(robertwb): Assert compatibility with input type hint? return self._strip_output_annotations( trivial_inference.infer_return_type(self.process, [input_type])) @@ -618,7 +618,9 @@ def _fn_takes_side_inputs(fn): # We can't tell; maybe it does. return True - return len(signature.parameters) > 1 + return (len(signature.parameters) > 1 or + any(p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD + for p in signature.parameters.values())) class CallableWrapperDoFn(DoFn): diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 82c4a63c4225..48129ecbec56 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -221,9 +221,52 @@ def repeat(s, *times): result = ['a', 'bb', 'c'] | beam.Map(repeat, 3) self.assertEqual(['aaa', 'bbbbbb', 'ccc'], sorted(result)) - # TODO(robertwb): Support partially defined varargs. - # with self.assertRaises(typehints.TypeCheckError): - # ['a', 'bb', 'c'] | beam.Map(repeat, 'z') + if sys.version_info >= (3,): + with self.assertRaisesRegexp( + typehints.TypeCheckError, + r'requires Tuple\[int, ...\] but got Tuple\[str, ...\]'): + ['a', 'bb', 'c'] | beam.Map(repeat, 'z') + + def test_var_positional_only_side_input_hint(self): + # Test that a lambda that accepts only a VAR_POSITIONAL can accept + # side-inputs. + # TODO(BEAM-8247): There's a bug with trivial_inference inferring the output + # type when side-inputs are used (their type hints are not passed). Remove + # with_output_types(...) when this bug is fixed. + result = (['a', 'b', 'c'] + | beam.Map(lambda *args: args, 5).with_input_types(int, str) + .with_output_types(typehints.Tuple[str, int])) + self.assertEqual([('a', 5), ('b', 5), ('c', 5)], sorted(result)) + + # Type hint order doesn't matter for VAR_POSITIONAL. + result = (['a', 'b', 'c'] + | beam.Map(lambda *args: args, 5).with_input_types(int, str) + .with_output_types(typehints.Tuple[str, int])) + self.assertEqual([('a', 5), ('b', 5), ('c', 5)], sorted(result)) + + if sys.version_info >= (3,): + with self.assertRaisesRegexp( + typehints.TypeCheckError, + r'requires Tuple\[Union\[int, str\], ...\] but got ' + r'Tuple\[Union\[float, int\], ...\]'): + _ = [1.2] | beam.Map(lambda *_: 'a', 5).with_input_types(int, str) + + def test_var_keyword_side_input_hint(self): + # Test that a lambda that accepts a VAR_KEYWORD can accept + # side-inputs. + result = (['a', 'b', 'c'] + | beam.Map(lambda e, **kwargs: (e, kwargs), kw=5) + .with_input_types(str, ignored=int)) + self.assertEqual([('a', {'kw': 5}), ('b', {'kw': 5}), ('c', {'kw': 5})], + sorted(result)) + + if sys.version_info >= (3,): + with self.assertRaisesRegexp( + typehints.TypeCheckError, + r'requires Dict\[str, str\] but got Dict\[str, int\]'): + _ = (['a', 'b', 'c'] + | beam.Map(lambda e, **_: 'a', kw=5) + .with_input_types(str, ignored=str)) def test_deferred_side_inputs(self): @typehints.with_input_types(str, int)