Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sdks/python/apache_beam/runners/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,7 @@ def windows(self):

def group_by_key_input_visitor(deterministic_key_coders=True):
# Importing here to avoid a circular dependency
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.pipeline import PipelineVisitor
from apache_beam.transforms.core import GroupByKey

Expand All @@ -1492,8 +1493,6 @@ def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)

def visit_transform(self, transform_node):
# Imported here to avoid circular dependencies.
# pylint: disable=wrong-import-order, wrong-import-position
if isinstance(transform_node.transform, GroupByKey):
pcoll = transform_node.inputs[0]
pcoll.element_type = typehints.coerce_to_kv_type(
Expand Down
48 changes: 38 additions & 10 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import windowed_value
Expand Down Expand Up @@ -166,7 +167,8 @@ def _extract_input_pvalues(self, pvalueish):

def expand(self, pcolls):
if isinstance(pcolls, dict):
if all(isinstance(tag, str) and len(tag) < 10 for tag in pcolls.keys()):
tags = list(pcolls.keys())
if all(isinstance(tag, str) and len(tag) < 10 for tag in tags):
# Small, string tags. Pass them as data.
pcolls_dict = pcolls
restore_tags = None
Expand All @@ -180,17 +182,43 @@ def expand(self, pcolls):
}
else:
# Tags are tuple indices.
num_tags = len(pcolls)
pcolls_dict = {str(ix): pcolls[ix] for ix in range(num_tags)}
restore_tags = lambda vs: tuple(vs[str(ix)] for ix in range(num_tags))

tags = [str(ix) for ix in range(len(pcolls))]
pcolls_dict = dict(zip(tags, pcolls))
restore_tags = lambda vs: tuple(vs[tag] for tag in tags)

input_key_types = []
input_value_types = []
for pcoll in pcolls_dict.values():
key_type, value_type = typehints.trivial_inference.key_value_types(
pcoll.element_type)
input_key_types.append(key_type)
input_value_types.append(value_type)
output_key_type = typehints.Union[tuple(input_key_types)]
iterable_input_value_types = tuple(
# TODO: Change List[t] to Iterable[t]
typehints.List[t] for t in input_value_types)

output_value_type = typehints.Dict[
str, typehints.Union[iterable_input_value_types or [typehints.Any]]]
result = (
pcolls_dict | 'CoGroupByKeyImpl' >> _CoGBKImpl(pipeline=self.pipeline))
pcolls_dict
| 'CoGroupByKeyImpl' >>
_CoGBKImpl(pipeline=self.pipeline).with_output_types(
typehints.Tuple[output_key_type, output_value_type]))

if restore_tags:
return result | 'RestoreTags' >> MapTuple(
lambda k, vs: (k, restore_tags(vs)))
else:
return result
if isinstance(pcolls, dict):
dict_key_type = typehints.Union[tuple(
trivial_inference.instance_to_type(tag) for tag in tags)]
output_value_type = typehints.Dict[
dict_key_type, typehints.Union[iterable_input_value_types]]
else:
output_value_type = typehints.Tuple[iterable_input_value_types]
result |= 'RestoreTags' >> MapTuple(
lambda k, vs: (k, restore_tags(vs))).with_output_types(
typehints.Tuple[output_key_type, output_value_type])

return result


class _CoGBKImpl(PTransform):
Expand Down
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,9 +1200,9 @@ def is_consistent_with(sub, base):
return True
sub = normalize(sub, none_as_type=True)
base = normalize(base, none_as_type=True)
if isinstance(base, TypeConstraint):
if isinstance(sub, UnionConstraint):
return all(is_consistent_with(c, base) for c in sub.union_types)
if isinstance(sub, UnionConstraint):
return all(is_consistent_with(c, base) for c in sub.union_types)
elif isinstance(base, TypeConstraint):
return base._consistent_with_check_(sub)
elif isinstance(sub, TypeConstraint):
# Nothing but object lives above any type constraints.
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/typehints/typehints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def test_nested_compatibility(self):
Union[int, Tuple[Any, Any]], Union[Tuple[int, Any], Tuple[Any, int]])
self.assertCompatible(Union[int, SuperClass], SubClass)
self.assertCompatible(Union[int, float, SuperClass], Union[int, SubClass])
self.assertCompatible(int, Union[()])

self.assertNotCompatible(Union[int, SubClass], SuperClass)
self.assertNotCompatible(
Expand Down