From c7b363bac34980d2117bb04d80e5f26f3c643d34 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 25 Apr 2023 15:28:01 -0700 Subject: [PATCH 1/3] Allow implicit flattening for yaml inputs. --- sdks/python/apache_beam/yaml/yaml_provider.py | 3 + .../python/apache_beam/yaml/yaml_transform.py | 64 +++++++++++-------- .../apache_beam/yaml/yaml_transform_test.py | 20 ++++++ 3 files changed, 62 insertions(+), 25 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index 9edf4ace6750..f4e44e8b81bc 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -304,6 +304,9 @@ def expand(self, pcolls): elif isinstance(pcolls, dict): pipeline_arg = {} pcolls = tuple(pcolls.values()) + elif isinstance(pcolls, (tuple, list)): + pipeline_arg = {} + pcolls = pcolls else: pipeline_arg = {'pipeline': pcolls.pipeline} pcolls = () diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 19dabbe73a80..3e9fe6f41342 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -123,6 +123,16 @@ def compute_all(self): for transform_id in self._transforms_by_uuid.keys(): self.compute_outputs(transform_id) + def get_input(self, inputs, transform_label, key): + if isinstance(inputs, str): + return self.get_pcollection(inputs) + else: + return tuple( + self.get_pcollection(x) for x in + inputs) | f'{transform_label}-FlattenInputs[{key}]' >> beam.Flatten( + pipeline=self.root if isinstance(self.root, beam.Pipeline + ) else self.root.pipeline) + def get_pcollection(self, name): if name in self._inputs: return self._inputs[name] @@ -250,29 +260,35 @@ def expand_transform(spec, scope): def expand_leaf_transform(spec, scope): + _LOGGER.info("Expanding %s ", identify_object(spec)) spec = normalize_inputs_outputs(spec) - inputs_dict = { - key: scope.get_pcollection(value) - for (key, value) in spec['input'].items() - } - input_type = spec.get('input_type', 'default') - if input_type == 'list': - inputs = tuple(inputs_dict.values()) - elif input_type == 'map': - inputs = inputs_dict + ptransform = scope.create_ptransform(spec) + transform_label = scope.unique_name(spec, ptransform) + + if spec['type'] == 'Flatten': + # Avoid flattening before the flatten, just to make a nicer graph. + inputs = tuple( + scope.get_pcollection(input) for key, + value in spec['input'].items() + for input in ([value] if isinstance(value, str) else value)) + else: + inputs_dict = { + key: scope.get_input(value, transform_label, key) + for (key, value) in spec['input'].items() + } + if len(inputs_dict) == 0: inputs = scope.root elif len(inputs_dict) == 1: inputs = next(iter(inputs_dict.values())) else: inputs = inputs_dict - _LOGGER.info("Expanding %s ", identify_object(spec)) - ptransform = scope.create_ptransform(spec) + try: # TODO: Move validation to construction? with FullyQualifiedNamedTransform.with_filter('*'): - outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform + outputs = inputs | transform_label >> ptransform except Exception as exn: raise ValueError( f"Error apply transform {identify_object(spec)}: {exn}") from exn @@ -291,12 +307,15 @@ def expand_leaf_transform(spec, scope): def expand_composite_transform(spec, scope): spec = normalize_inputs_outputs(normalize_source_sink(spec)) + if 'name' not in spec: + spec['name'] = 'Composite' + transform_label = scope.unique_name(spec, None) inner_scope = Scope( - scope.root, { - key: scope.get_pcollection(value) - for key, - value in spec['input'].items() + scope.root, + { + key: scope.get_input(value, transform_label, key) + for (key, value) in spec['input'].items() }, spec['transforms'], yaml_provider.merge_providers( @@ -312,17 +331,14 @@ def expand(inputs): for (key, value) in spec['output'].items() } - if 'name' not in spec: - spec['name'] = 'Composite' if spec['name'] is None: # top-level pipeline, don't nest return CompositePTransform.expand(None) else: _LOGGER.info("Expanding %s ", identify_object(spec)) return ({ - key: scope.get_pcollection(value) - for key, - value in spec['input'].items() - } or scope.root) | scope.unique_name(spec, None) >> CompositePTransform() + key: scope.get_input(value, transform_label, key) + for (key, value) in spec['input'].items() + } or scope.root) | transform_label >> CompositePTransform() def expand_chain_transform(spec, scope): @@ -395,10 +411,8 @@ def normalize_inputs_outputs(spec): def normalize_io(tag): io = spec.get(tag, {}) - if isinstance(io, str): + if isinstance(io, (str, list)): return {tag: io} - elif isinstance(io, list): - return {f'{tag}{ix}': value for ix, value in enumerate(io)} else: return SafeLineLoader.strip_metadata(io, tagged_str=False) diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index 926d570fa250..052d89280d40 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -105,6 +105,26 @@ def test_chain_with_root(self): ''') assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131])) + def test_implicit_flatten(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + result = p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + name: CreateSmall + elements: [1, 2, 3] + - type: Create + name: CreateBig + elements: [100, 200] + - type: PyMap + input: [CreateBig, CreateSmall] + fn: "lambda x: x * x" + output: PyMap + ''') + assert_that(result, equal_to([1, 4, 9, 10000, 40000])) + def test_csv_to_json(self): try: import pandas as pd From 33a1dace545a16b475098df369c9c93c553eba16 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 25 Apr 2023 16:10:44 -0700 Subject: [PATCH 2/3] lint --- sdks/python/apache_beam/yaml/yaml_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index f4e44e8b81bc..1f0ce1334b9d 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -306,7 +306,7 @@ def expand(self, pcolls): pcolls = tuple(pcolls.values()) elif isinstance(pcolls, (tuple, list)): pipeline_arg = {} - pcolls = pcolls + pcolls = tuple(pcolls) else: pipeline_arg = {'pipeline': pcolls.pipeline} pcolls = () From 2f337a48f3c78c0ab803c12bb4a4442f60021140 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 27 Apr 2023 12:40:24 -0700 Subject: [PATCH 3/3] reformatting for readability --- sdks/python/apache_beam/yaml/yaml_transform.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 3e9fe6f41342..418adbe65967 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -127,11 +127,13 @@ def get_input(self, inputs, transform_label, key): if isinstance(inputs, str): return self.get_pcollection(inputs) else: - return tuple( - self.get_pcollection(x) for x in - inputs) | f'{transform_label}-FlattenInputs[{key}]' >> beam.Flatten( - pipeline=self.root if isinstance(self.root, beam.Pipeline - ) else self.root.pipeline) + pipeline = ( + self.root + if isinstance(self.root, beam.Pipeline) else self.root.pipeline) + label = f'{transform_label}-FlattenInputs[{key}]' + return ( + tuple(self.get_pcollection(x) for x in inputs) + | label >> beam.Flatten(pipeline=pipeline)) def get_pcollection(self, name): if name in self._inputs: @@ -268,8 +270,7 @@ def expand_leaf_transform(spec, scope): if spec['type'] == 'Flatten': # Avoid flattening before the flatten, just to make a nicer graph. inputs = tuple( - scope.get_pcollection(input) for key, - value in spec['input'].items() + scope.get_pcollection(input) for (key, value) in spec['input'].items() for input in ([value] if isinstance(value, str) else value)) else: