diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 7e6b1f9d5b88..d8b93133e448 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -169,7 +169,8 @@ class Top(object): """Combiners for obtaining extremal elements.""" # pylint: disable=no-self-argument - + @with_input_types(T) + @with_output_types(List[T]) class Of(CombinerWithoutDefaults): """Obtain a list of the compare-most N elements in a PCollection. @@ -203,10 +204,12 @@ def expand(self, pcoll): # This is a more efficient global algorithm. top_per_bundle = pcoll | core.ParDo( _TopPerBundle(self._n, self._key, self._reverse)) - # If pcoll is empty, we can't guerentee that top_per_bundle + # If pcoll is empty, we can't guarantee that top_per_bundle # won't be empty, so inject at least one empty accumulator - # so that downstream is guerenteed to produce non-empty output. - empty_bundle = pcoll.pipeline | core.Create([(None, [])]) + # so that downstream is guaranteed to produce non-empty output. + empty_bundle = ( + pcoll.pipeline | core.Create([(None, [])]).with_output_types( + top_per_bundle.element_type)) return ((top_per_bundle, empty_bundle) | core.Flatten() | core.GroupByKey() | core.ParDo( @@ -220,6 +223,8 @@ def expand(self, pcoll): TopCombineFn(self._n, self._key, self._reverse)).without_defaults() + @with_input_types(Tuple[K, V]) + @with_output_types(Tuple[K, List[V]]) class PerKey(ptransform.PTransform): """Identifies the compare-most N elements associated with each key. @@ -525,6 +530,8 @@ class Sample(object): # pylint: disable=no-self-argument + @with_input_types(T) + @with_output_types(List[T]) class FixedSizeGlobally(CombinerWithoutDefaults): """Sample n elements from the input PCollection without replacement.""" def __init__(self, n): @@ -544,6 +551,8 @@ def display_data(self): def default_label(self): return 'FixedSizeGlobally(%d)' % self._n + @with_input_types(Tuple[K, V]) + @with_output_types(Tuple[K, List[V]]) class FixedSizePerKey(ptransform.PTransform): """Sample n elements associated with each key without replacement.""" def __init__(self, n): @@ -692,6 +701,8 @@ def add_input(self, accumulator, element, *args, **kwargs): ] +@with_input_types(T) +@with_output_types(List[T]) class ToList(CombinerWithoutDefaults): """A global CombineFn that condenses a PCollection into a single list.""" def expand(self, pcoll): @@ -720,6 +731,8 @@ def extract_output(self, accumulator): return accumulator +@with_input_types(Tuple[K, V]) +@with_output_types(Dict[K, V]) class ToDict(CombinerWithoutDefaults): """A global CombineFn that condenses a PCollection into a single dict. @@ -757,6 +770,8 @@ def extract_output(self, accumulator): return accumulator +@with_input_types(T) +@with_output_types(Set[T]) class ToSet(CombinerWithoutDefaults): """A global CombineFn that condenses a PCollection into a set.""" def expand(self, pcoll): diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index ec001e284fbb..9ce33f501294 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2333,9 +2333,8 @@ def test_per_key_pipeline_checking_violated(self): self.assertStartswith( e.exception.args[0], - "Type hint violation for 'CombinePerKey(TopCombineFn)': " - "requires Tuple[TypeVariable[K], TypeVariable[T]] " - "but got {} for element".format(int)) + "Input type hint violation at TopMod: expected Tuple[TypeVariable[K], " + "TypeVariable[V]], got {}".format(int)) def test_per_key_pipeline_checking_satisfied(self): d = ( @@ -2492,10 +2491,8 @@ def test_to_dict_pipeline_check_violated(self): self.assertStartswith( e.exception.args[0], - "Type hint violation for 'CombinePerKey': " - "requires " - "Tuple[TypeVariable[K], Tuple[TypeVariable[K], TypeVariable[V]]] " - "but got Tuple[None, int] for element") + "Input type hint violation at ToDict: expected Tuple[TypeVariable[K], " + "TypeVariable[V]], got {}".format(int)) def test_to_dict_pipeline_check_satisfied(self): d = (