diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 44408e2586a8..8a7050e2dab4 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -1475,6 +1475,7 @@ def windows(self): def group_by_key_input_visitor(deterministic_key_coders=True): # Importing here to avoid a circular dependency + # pylint: disable=wrong-import-order, wrong-import-position from apache_beam.pipeline import PipelineVisitor from apache_beam.transforms.core import GroupByKey @@ -1492,8 +1493,6 @@ def enter_composite_transform(self, transform_node): self.visit_transform(transform_node) def visit_transform(self, transform_node): - # Imported here to avoid circular dependencies. - # pylint: disable=wrong-import-order, wrong-import-position if isinstance(transform_node.transform, GroupByKey): pcoll = transform_node.inputs[0] pcoll.element_type = typehints.coerce_to_kv_type( diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 14ed2ad559bc..756d1bdd2d5d 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -64,6 +64,7 @@ from apache_beam.transforms.window import NonMergingWindowFn from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import TimestampedValue +from apache_beam.typehints import trivial_inference from apache_beam.typehints.decorators import get_signature from apache_beam.typehints.sharded_key_type import ShardedKeyType from apache_beam.utils import windowed_value @@ -166,7 +167,8 @@ def _extract_input_pvalues(self, pvalueish): def expand(self, pcolls): if isinstance(pcolls, dict): - if all(isinstance(tag, str) and len(tag) < 10 for tag in pcolls.keys()): + tags = list(pcolls.keys()) + if all(isinstance(tag, str) and len(tag) < 10 for tag in tags): # Small, string tags. Pass them as data. pcolls_dict = pcolls restore_tags = None @@ -180,17 +182,43 @@ def expand(self, pcolls): } else: # Tags are tuple indices. - num_tags = len(pcolls) - pcolls_dict = {str(ix): pcolls[ix] for ix in range(num_tags)} - restore_tags = lambda vs: tuple(vs[str(ix)] for ix in range(num_tags)) - + tags = [str(ix) for ix in range(len(pcolls))] + pcolls_dict = dict(zip(tags, pcolls)) + restore_tags = lambda vs: tuple(vs[tag] for tag in tags) + + input_key_types = [] + input_value_types = [] + for pcoll in pcolls_dict.values(): + key_type, value_type = typehints.trivial_inference.key_value_types( + pcoll.element_type) + input_key_types.append(key_type) + input_value_types.append(value_type) + output_key_type = typehints.Union[tuple(input_key_types)] + iterable_input_value_types = tuple( + # TODO: Change List[t] to Iterable[t] + typehints.List[t] for t in input_value_types) + + output_value_type = typehints.Dict[ + str, typehints.Union[iterable_input_value_types or [typehints.Any]]] result = ( - pcolls_dict | 'CoGroupByKeyImpl' >> _CoGBKImpl(pipeline=self.pipeline)) + pcolls_dict + | 'CoGroupByKeyImpl' >> + _CoGBKImpl(pipeline=self.pipeline).with_output_types( + typehints.Tuple[output_key_type, output_value_type])) + if restore_tags: - return result | 'RestoreTags' >> MapTuple( - lambda k, vs: (k, restore_tags(vs))) - else: - return result + if isinstance(pcolls, dict): + dict_key_type = typehints.Union[tuple( + trivial_inference.instance_to_type(tag) for tag in tags)] + output_value_type = typehints.Dict[ + dict_key_type, typehints.Union[iterable_input_value_types]] + else: + output_value_type = typehints.Tuple[iterable_input_value_types] + result |= 'RestoreTags' >> MapTuple( + lambda k, vs: (k, restore_tags(vs))).with_output_types( + typehints.Tuple[output_key_type, output_value_type]) + + return result class _CoGBKImpl(PTransform): diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 02a095060c8f..e7693ef8b3f4 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1200,9 +1200,9 @@ def is_consistent_with(sub, base): return True sub = normalize(sub, none_as_type=True) base = normalize(base, none_as_type=True) - if isinstance(base, TypeConstraint): - if isinstance(sub, UnionConstraint): - return all(is_consistent_with(c, base) for c in sub.union_types) + if isinstance(sub, UnionConstraint): + return all(is_consistent_with(c, base) for c in sub.union_types) + elif isinstance(base, TypeConstraint): return base._consistent_with_check_(sub) elif isinstance(sub, TypeConstraint): # Nothing but object lives above any type constraints. diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 5711f834fa2c..02a6ec6feaf8 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -218,6 +218,7 @@ def test_nested_compatibility(self): Union[int, Tuple[Any, Any]], Union[Tuple[int, Any], Tuple[Any, int]]) self.assertCompatible(Union[int, SuperClass], SubClass) self.assertCompatible(Union[int, float, SuperClass], Union[int, SubClass]) + self.assertCompatible(int, Union[()]) self.assertNotCompatible(Union[int, SubClass], SuperClass) self.assertNotCompatible(