diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 0ea95228aeee..eceac0b1857f 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -28,6 +28,7 @@ from apache_beam.yaml.yaml_transform import ensure_errors_consumed from apache_beam.yaml.yaml_transform import ensure_transforms_have_types from apache_beam.yaml.yaml_transform import expand_composite_transform +from apache_beam.yaml.yaml_transform import expand_pipeline from apache_beam.yaml.yaml_transform import extract_name from apache_beam.yaml.yaml_transform import identify_object from apache_beam.yaml.yaml_transform import normalize_inputs_outputs @@ -988,6 +989,69 @@ def test_init_with_dict(self): self.assertEqual(result._spec['type'], "composite") # preprocessed spec +class ExpandPipelineTest(unittest.TestCase): + def test_expand_pipeline_with_pipeline_key_only(self): + spec = ''' + pipeline: + type: chain + transforms: + - type: Create + config: + elements: [1,2,3] + - type: LogForTesting + ''' + with new_pipeline() as p: + expand_pipeline(p, spec, validate_schema=None) + + def test_expand_pipeline_with_pipeline_and_option_keys(self): + spec = ''' + pipeline: + type: chain + transforms: + - type: Create + config: + elements: [1,2,3] + - type: LogForTesting + options: + streaming: false + ''' + with new_pipeline() as p: + expand_pipeline(p, spec, validate_schema=None) + + def test_expand_pipeline_with_extra_top_level_keys(self): + spec = ''' + template: + version: "1.0" + author: "test_user" + + pipeline: + type: chain + transforms: + - type: Create + config: + elements: [1,2,3] + - type: LogForTesting + + other_metadata: "This is an ignored comment." + ''' + with new_pipeline() as p: + expand_pipeline(p, spec, validate_schema=None) + + def test_expand_pipeline_with_incorrect_pipelines_key_fails(self): + spec = ''' + pipelines: + type: chain + transforms: + - type: Create + config: + elements: [1,2,3] + - type: LogForTesting + ''' + with new_pipeline() as p: + with self.assertRaises(KeyError): + expand_pipeline(p, spec, validate_schema=None) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()