diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 9edf4ace6750..1f0ce1334b9d 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -304,6 +304,9 @@ def expand(self, pcolls):
     elif isinstance(pcolls, dict):
       pipeline_arg = {}
       pcolls = tuple(pcolls.values())
+    elif isinstance(pcolls, (tuple, list)):
+      pipeline_arg = {}
+      pcolls = tuple(pcolls)
     else:
       pipeline_arg = {'pipeline': pcolls.pipeline}
       pcolls = ()
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 19dabbe73a80..418adbe65967 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -123,6 +123,18 @@ def compute_all(self):
     for transform_id in self._transforms_by_uuid.keys():
       self.compute_outputs(transform_id)
 
+  def get_input(self, inputs, transform_label, key):
+    if isinstance(inputs, str):
+      return self.get_pcollection(inputs)
+    else:
+      pipeline = (
+          self.root
+          if isinstance(self.root, beam.Pipeline) else self.root.pipeline)
+      label = f'{transform_label}-FlattenInputs[{key}]'
+      return (
+          tuple(self.get_pcollection(x) for x in inputs)
+          | label >> beam.Flatten(pipeline=pipeline))
+
   def get_pcollection(self, name):
     if name in self._inputs:
       return self._inputs[name]
@@ -250,29 +262,34 @@ def expand_transform(spec, scope):
 
 
 def expand_leaf_transform(spec, scope):
+  _LOGGER.info("Expanding %s ", identify_object(spec))
   spec = normalize_inputs_outputs(spec)
-  inputs_dict = {
-      key: scope.get_pcollection(value)
-      for (key, value) in spec['input'].items()
-  }
-  input_type = spec.get('input_type', 'default')
-  if input_type == 'list':
-    inputs = tuple(inputs_dict.values())
-  elif input_type == 'map':
-    inputs = inputs_dict
+  ptransform = scope.create_ptransform(spec)
+  transform_label = scope.unique_name(spec, ptransform)
+
+  if spec['type'] == 'Flatten':
+    # Avoid flattening before the flatten, just to make a nicer graph.
+    inputs = tuple(
+        scope.get_pcollection(input) for (key, value) in spec['input'].items()
+        for input in ([value] if isinstance(value, str) else value))
+  else:
+    inputs_dict = {
+        key: scope.get_input(value, transform_label, key)
+        for (key, value) in spec['input'].items()
+    }
+
     if len(inputs_dict) == 0:
       inputs = scope.root
     elif len(inputs_dict) == 1:
       inputs = next(iter(inputs_dict.values()))
     else:
       inputs = inputs_dict
-  _LOGGER.info("Expanding %s ", identify_object(spec))
-  ptransform = scope.create_ptransform(spec)
+
   try:
     # TODO: Move validation to construction?
     with FullyQualifiedNamedTransform.with_filter('*'):
-      outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform
+      outputs = inputs | transform_label >> ptransform
   except Exception as exn:
     raise ValueError(
         f"Error apply transform {identify_object(spec)}: {exn}") from exn
 
@@ -291,12 +308,15 @@ def expand_leaf_transform(spec, scope):
 
 
 def expand_composite_transform(spec, scope):
   spec = normalize_inputs_outputs(normalize_source_sink(spec))
+  if 'name' not in spec:
+    spec['name'] = 'Composite'
+  transform_label = scope.unique_name(spec, None)
   inner_scope = Scope(
-      scope.root, {
-          key: scope.get_pcollection(value)
-          for key,
-          value in spec['input'].items()
+      scope.root,
+      {
+          key: scope.get_input(value, transform_label, key)
+          for (key, value) in spec['input'].items()
       },
       spec['transforms'],
       yaml_provider.merge_providers(
@@ -312,17 +332,14 @@ def expand(inputs):
           for (key, value) in spec['output'].items()
       }
 
-  if 'name' not in spec:
-    spec['name'] = 'Composite'
   if spec['name'] is None:
     # top-level pipeline, don't nest
     return CompositePTransform.expand(None)
   else:
     _LOGGER.info("Expanding %s ", identify_object(spec))
     return ({
-        key: scope.get_pcollection(value)
-        for key,
-        value in spec['input'].items()
-    } or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()
+        key: scope.get_input(value, transform_label, key)
+        for (key, value) in spec['input'].items()
+    } or scope.root) | transform_label >> CompositePTransform()
 
 
 def expand_chain_transform(spec, scope):
@@ -395,10 +412,8 @@ def normalize_inputs_outputs(spec):
 
   def normalize_io(tag):
     io = spec.get(tag, {})
-    if isinstance(io, str):
+    if isinstance(io, (str, list)):
       return {tag: io}
-    elif isinstance(io, list):
-      return {f'{tag}{ix}': value for ix, value in enumerate(io)}
     else:
       return SafeLineLoader.strip_metadata(io, tagged_str=False)
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 926d570fa250..052d89280d40 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -105,6 +105,26 @@ def test_chain_with_root(self):
         ''')
       assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131]))
 
+  def test_implicit_flatten(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: CreateSmall
+              elements: [1, 2, 3]
+            - type: Create
+              name: CreateBig
+              elements: [100, 200]
+            - type: PyMap
+              input: [CreateBig, CreateSmall]
+              fn: "lambda x: x * x"
+              output: PyMap
+          ''')
+      assert_that(result, equal_to([1, 4, 9, 10000, 40000]))
+
   def test_csv_to_json(self):
     try:
       import pandas as pd