From aca8ee8a030bf1b6d52c242dc8a9f38bd8c04ae1 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 8 Jan 2022 20:54:53 -0800 Subject: [PATCH 1/4] Stronger typing inference for CoGBK. --- sdks/python/apache_beam/runners/common.py | 3 +- .../apache_beam/transforms/combiners.py | 1 + sdks/python/apache_beam/transforms/util.py | 53 ++++++++++++++----- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 44408e2586a8..8a7050e2dab4 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -1475,6 +1475,7 @@ def windows(self): def group_by_key_input_visitor(deterministic_key_coders=True): # Importing here to avoid a circular dependency + # pylint: disable=wrong-import-order, wrong-import-position from apache_beam.pipeline import PipelineVisitor from apache_beam.transforms.core import GroupByKey @@ -1492,8 +1493,6 @@ def enter_composite_transform(self, transform_node): self.visit_transform(transform_node) def visit_transform(self, transform_node): - # Imported here to avoid circular dependencies. - # pylint: disable=wrong-import-order, wrong-import-position if isinstance(transform_node.transform, GroupByKey): pcoll = transform_node.inputs[0] pcoll.element_type = typehints.coerce_to_kv_type( diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index a22b408378e6..ba126f3366b6 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -119,6 +119,7 @@ def for_input_type(self, input_type): class Count(object): """Combiners for counting elements.""" + @with_input_types(T) @with_output_types(int) class Globally(CombinerWithoutDefaults): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 14ed2ad559bc..46ff59528e09 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -64,6 +64,7 @@ from apache_beam.transforms.window import NonMergingWindowFn from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import TimestampedValue +from apache_beam.typehints import trivial_inference from apache_beam.typehints.decorators import get_signature from apache_beam.typehints.sharded_key_type import ShardedKeyType from apache_beam.utils import windowed_value @@ -166,7 +167,8 @@ def _extract_input_pvalues(self, pvalueish): def expand(self, pcolls): if isinstance(pcolls, dict): - if all(isinstance(tag, str) and len(tag) < 10 for tag in pcolls.keys()): + tags = list(pcolls.keys()) + if all(isinstance(tag, str) and len(tag) < 10 for tag in tags): # Small, string tags. Pass them as data. pcolls_dict = pcolls restore_tags = None @@ -175,22 +177,49 @@ def expand(self, pcolls): tags = list(pcolls.keys()) pcolls_dict = {str(ix): pcolls[tag] for (ix, tag) in enumerate(tags)} restore_tags = lambda vs: { - tag: vs[str(ix)] - for (ix, tag) in enumerate(tags) + tag: vs[str(ix)] for (ix, tag) in enumerate(tags) } else: # Tags are tuple indices. - num_tags = len(pcolls) - pcolls_dict = {str(ix): pcolls[ix] for ix in range(num_tags)} - restore_tags = lambda vs: tuple(vs[str(ix)] for ix in range(num_tags)) - + tags = [str(ix) for ix in range(len(pcolls))] + pcolls_dict = dict(zip(tags, pcolls)) + restore_tags = lambda vs: tuple(vs[tag] for tag in tags) + + input_key_types = [] + input_value_types = [] + for pcoll in pcolls_dict.values(): + key_type, value_type = typehints.trivial_inference.key_value_types( + pcoll.element_type) + input_key_types.append(key_type) + input_value_types.append(value_type) + output_key_type = typehints.Union[tuple(input_key_types)] + iterable_input_value_types = tuple( + # TODO: Change List[t] to Iterable[t] + typehints.List[t] for t in input_value_types) + + output_value_type = typehints.Dict[ + str, + typehints.Union[iterable_input_value_types]] result = ( - pcolls_dict | 'CoGroupByKeyImpl' >> _CoGBKImpl(pipeline=self.pipeline)) + pcolls_dict + | 'CoGroupByKeyImpl' + >> _CoGBKImpl(pipeline=self.pipeline).with_output_types( + typehints.Tuple[output_key_type, output_value_type])) + if restore_tags: - return result | 'RestoreTags' >> MapTuple( - lambda k, vs: (k, restore_tags(vs))) - else: - return result + if isinstance(pcolls, dict): + dict_key_type = typehints.Union[ + tuple(trivial_inference.instance_to_type(tag) for tag in tags)] + output_value_type = typehints.Dict[ + dict_key_type, + typehints.Union[iterable_input_value_types]] + else: + output_value_type = typehints.Tuple[iterable_input_value_types] + result |= 'RestoreTags' >> MapTuple( + lambda k, vs: (k, restore_tags(vs))).with_output_types( + typehints.Tuple[output_key_type, output_value_type]) + + return result class _CoGBKImpl(PTransform): From 18bb0aa2f80dfbbabf768fa4f47c42de9480931d Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 11 Jan 2022 21:05:06 -0800 Subject: [PATCH 2/4] Empty union consistency check fix. --- sdks/python/apache_beam/typehints/typehints.py | 6 +++--- sdks/python/apache_beam/typehints/typehints_test.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 02a095060c8f..e7693ef8b3f4 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1200,9 +1200,9 @@ def is_consistent_with(sub, base): return True sub = normalize(sub, none_as_type=True) base = normalize(base, none_as_type=True) - if isinstance(base, TypeConstraint): - if isinstance(sub, UnionConstraint): - return all(is_consistent_with(c, base) for c in sub.union_types) + if isinstance(sub, UnionConstraint): + return all(is_consistent_with(c, base) for c in sub.union_types) + elif isinstance(base, TypeConstraint): return base._consistent_with_check_(sub) elif isinstance(sub, TypeConstraint): # Nothing but object lives above any type constraints. diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 5711f834fa2c..02a6ec6feaf8 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -218,6 +218,7 @@ def test_nested_compatibility(self): Union[int, Tuple[Any, Any]], Union[Tuple[int, Any], Tuple[Any, int]]) self.assertCompatible(Union[int, SuperClass], SubClass) self.assertCompatible(Union[int, float, SuperClass], Union[int, SubClass]) + self.assertCompatible(int, Union[()]) self.assertNotCompatible(Union[int, SubClass], SuperClass) self.assertNotCompatible( From 5f7482ece35e3df59a367b5104a039b8ac3bf36d Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 11 Jan 2022 21:06:00 -0800 Subject: [PATCH 3/4] yapf --- sdks/python/apache_beam/transforms/combiners.py | 1 - sdks/python/apache_beam/transforms/util.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index ba126f3366b6..a22b408378e6 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -119,7 +119,6 @@ def for_input_type(self, input_type): class Count(object): """Combiners for counting elements.""" - @with_input_types(T) @with_output_types(int) class Globally(CombinerWithoutDefaults): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 46ff59528e09..c7c20263c633 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -177,7 +177,8 @@ def expand(self, pcolls): tags = list(pcolls.keys()) pcolls_dict = {str(ix): pcolls[tag] for (ix, tag) in enumerate(tags)} restore_tags = lambda vs: { - tag: vs[str(ix)] for (ix, tag) in enumerate(tags) + tag: vs[str(ix)] + for (ix, tag) in enumerate(tags) } else: # Tags are tuple indices. @@ -198,21 +199,19 @@ def expand(self, pcolls): typehints.List[t] for t in input_value_types) output_value_type = typehints.Dict[ - str, - typehints.Union[iterable_input_value_types]] + str, typehints.Union[iterable_input_value_types]] result = ( pcolls_dict - | 'CoGroupByKeyImpl' - >> _CoGBKImpl(pipeline=self.pipeline).with_output_types( + | 'CoGroupByKeyImpl' >> + _CoGBKImpl(pipeline=self.pipeline).with_output_types( typehints.Tuple[output_key_type, output_value_type])) if restore_tags: if isinstance(pcolls, dict): - dict_key_type = typehints.Union[ - tuple(trivial_inference.instance_to_type(tag) for tag in tags)] + dict_key_type = typehints.Union[tuple( + trivial_inference.instance_to_type(tag) for tag in tags)] output_value_type = typehints.Dict[ - dict_key_type, - typehints.Union[iterable_input_value_types]] + dict_key_type, typehints.Union[iterable_input_value_types]] else: output_value_type = typehints.Tuple[iterable_input_value_types] result |= 'RestoreTags' >> MapTuple( From 697241eedcc753e9c9d2c74d1ffc578d37b0790e Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 12 Jan 2022 15:51:30 -0800 Subject: [PATCH 4/4] Fix empty CoGBK. --- sdks/python/apache_beam/transforms/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index c7c20263c633..756d1bdd2d5d 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -199,7 +199,7 @@ def expand(self, pcolls): typehints.List[t] for t in input_value_types) output_value_type = typehints.Dict[ - str, typehints.Union[iterable_input_value_types]] + str, typehints.Union[iterable_input_value_types or [typehints.Any]]] result = ( pcolls_dict | 'CoGroupByKeyImpl' >>