diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 42087565013c..9ce9b8e21804 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -756,6 +756,43 @@ def dicts_to_rows(o):
   return o
 
 
+def _unify_element_with_schema(element, target_schema):
+  """Convert an element to match the target schema, copying the fields the
+  element already has and filling missing fields with None."""
+  if target_schema is None:
+    return element
+
+  # If the element is already a named tuple, convert it to a dict first.
+  if hasattr(element, '_asdict'):
+    element_dict = element._asdict()
+  elif isinstance(element, dict):
+    element_dict = element
+  else:
+    # This element is not a row, so it can't be unified with a
+    # row schema.
+    return element
+
+  # Build the new element: keep the fields that exist in the original
+  # element and set fields that are expected but missing to None.
+  unified_dict = {}
+  for field_name in target_schema._fields:
+    if field_name in element_dict:
+      value = element_dict[field_name]
+      # Ensure the value matches the expected type.
+      # This is particularly important for list fields.
+      if value is not None and not isinstance(value, list) and hasattr(
+          value, '__iter__') and not isinstance(
+              value, (str, bytes)) and not hasattr(value, '_asdict'):
+        # Convert other iterables (e.g. tuples) to lists.
+        unified_dict[field_name] = list(value)
+      else:
+        unified_dict[field_name] = value
+    else:
+      unified_dict[field_name] = None
+
+  return target_schema(**unified_dict)
+
+
 class YamlProviders:
   class AssertEqual(beam.PTransform):
     """Asserts that the input contains exactly the elements provided.
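Note: for intuition, a minimal stdlib-only sketch of the unification semantics
the helper above implements; `Target` and `unify` are illustrative names, not
part of the patch, and the sketch omits the iterable-to-list normalization:

    from typing import Any, NamedTuple, Optional

    class Target(NamedTuple):
      id: Optional[str]
      tags: Optional[Any]

    def unify(element_dict, target):
      # Fields the element has are copied; fields it lacks become None.
      return target(**{f: element_dict.get(f) for f in target._fields})

    print(unify({'id': '1', 'tags': ['red']}, Target))  # Target(id='1', tags=['red'])
    print(unify({'id': '2'}, Target))                   # Target(id='2', tags=None)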
@@ -932,6 +969,48 @@ def __init__(self):
       # pylint: disable=useless-parent-delegation
       super().__init__()
 
+    def _merge_schemas(self, pcolls):
+      """Merge schemas from multiple PCollections to create a unified schema.
+
+      The unified schema contains all fields from all input PCollections.
+      Every field is made optional to handle missing values, and fields
+      that appear with different types are unified to Optional[Any].
+      """
+      from apache_beam.typehints.schemas import named_fields_from_element_type
+
+      # Collect all schemas.
+      schemas = []
+      for pcoll in pcolls:
+        if hasattr(pcoll, 'element_type') and pcoll.element_type:
+          try:
+            fields = named_fields_from_element_type(pcoll.element_type)
+            schemas.append(dict(fields))
+          except (ValueError, TypeError):
+            # If we can't extract a schema, skip this PCollection.
+            continue
+
+      if not schemas:
+        return None
+
+      # Merge all field names and types.
+      all_field_names = set().union(*(s.keys() for s in schemas))
+      unified_fields = {}
+      for name in all_field_names:
+        present_types = {s[name] for s in schemas if name in s}
+        if len(present_types) > 1:
+          unified_fields[name] = Optional[Any]
+        else:
+          unified_fields[name] = Optional[present_types.pop()]
+
+      # Create the unified schema.
+      if unified_fields:
+        from apache_beam.typehints.schemas import named_fields_to_schema
+        from apache_beam.typehints.schemas import named_tuple_from_schema
+        unified_schema = named_fields_to_schema(list(unified_fields.items()))
+        return named_tuple_from_schema(unified_schema)
+
+      return None
+
     def expand(self, pcolls):
       if isinstance(pcolls, beam.PCollection):
         pipeline_arg = {}
@@ -942,7 +1021,27 @@ def expand(self, pcolls):
       else:
         pipeline_arg = {'pipeline': pcolls.pipeline}
         pcolls = ()
-      return pcolls | beam.Flatten(**pipeline_arg)
+
+      if not pcolls:
+        return pcolls | beam.Flatten(**pipeline_arg)
+
+      # Try to unify the input schemas.
+      unified_schema = self._merge_schemas(pcolls)
+
+      if unified_schema is None:
+        # No common schema could be derived; fall back to a plain flatten.
+        return pcolls | beam.Flatten(**pipeline_arg)
+
+      # Apply schema unification to each PCollection before flattening.
+      unified_pcolls = []
+      for i, pcoll in enumerate(pcolls):
+        unified_pcoll = pcoll | f'UnifySchema{i}' >> beam.Map(
+            _unify_element_with_schema,
+            target_schema=unified_schema).with_output_types(unified_schema)
+        unified_pcolls.append(unified_pcoll)
+
+      # Flatten the unified PCollections.
+      return unified_pcolls | beam.Flatten(**pipeline_arg)
 
   class WindowInto(beam.PTransform):
     # pylint: disable=line-too-long
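Note: a stdlib-only restatement of the merge rule above, for intuition; the
`merge` name is illustrative, and the real method additionally turns the
result into a Beam schema and a generated NamedTuple type:

    from typing import Any, Optional

    def merge(schemas):
      # Union of all field names; every field becomes Optional, and fields
      # whose types disagree across inputs collapse to Optional[Any].
      merged = {}
      for name in set().union(*(s.keys() for s in schemas)):
        types_seen = {s[name] for s in schemas if name in s}
        merged[name] = (
            Optional[Any] if len(types_seen) > 1 else Optional[types_seen.pop()])
      return merged

    print(merge([{'id': str, 'age': int}, {'id': str}]))
    # e.g. {'age': typing.Optional[int], 'id': typing.Optional[str]}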
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 8710fe379c37..341cd8d65f41 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -861,8 +861,8 @@ def preprocess_flattened_inputs(spec):
   def all_inputs(t):
     for key, values in t.get('input', {}).items():
       if isinstance(values, list):
-        for ix, values in enumerate(values):
-          yield f'{key}{ix}', values
+        for ix, value in enumerate(values):
+          yield f'{key}{ix}', value
       else:
         yield key, values
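Note: the rename above is a shadowing fix rather than a behavior change;
enumerate() captures the iterator before the loop body runs, so iteration was
already correct, but rebinding `values` clobbers the list for any later use.
A minimal illustration (names are illustrative):

    values = ['a', 'b']
    for ix, values in enumerate(values):  # rebinds `values` on each iteration
      pass
    print(values)  # prints 'b' -- the original list binding is gone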
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 1a99507d76d7..543f13eeff58 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -477,6 +477,216 @@ def test_composite_resource_hints(self):
         b'1000000000', proto)
 
+  def test_flatten_unifies_schemas(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {ride_id: '1', passenger_count: 1}
+                  - {ride_id: '2', passenger_count: 2}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {ride_id: '3'}
+                  - {ride_id: '4'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {ride_id: '1', passenger_count: 1}
+                  - {ride_id: '2', passenger_count: 2}
+                  - {ride_id: '3'}
+                  - {ride_id: '4'}
+          ''')
+
+  def test_flatten_unifies_optional_fields(self):
+    """Test that Flatten correctly unifies schemas with optional fields."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', age: 30}
+                  - {id: '2', name: 'Bob', age: 25}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', name: 'Charlie'}
+                  - {id: '4', name: 'Diana'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', age: 30}
+                  - {id: '2', name: 'Bob', age: 25}
+                  - {id: '3', name: 'Charlie'}
+                  - {id: '4', name: 'Diana'}
+          ''')
+
+  def test_flatten_unifies_different_types(self):
+    """Test that Flatten correctly unifies schemas with different
+    field types."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: 1, value: 100}
+                  - {id: 2, value: 200}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', value: 'text'}
+                  - {id: '4', value: 'data'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: 1, value: 100}
+                  - {id: 2, value: 200}
+                  - {id: '3', value: 'text'}
+                  - {id: '4', value: 'data'}
+          ''')
+
+  def test_flatten_unifies_list_fields(self):
+    """Test that Flatten correctly unifies schemas with list fields."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: '1', tags: ['red', 'blue']}
+                  - {id: '2', tags: ['green']}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', tags: ['yellow', 'purple', 'orange']}
+                  - {id: '4', tags: []}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: '1', tags: ['red', 'blue']}
+                  - {id: '2', tags: ['green']}
+                  - {id: '3', tags: ['yellow', 'purple', 'orange']}
+                  - {id: '4', tags: []}
+          ''')
+
+  def test_flatten_unifies_with_missing_fields(self):
+    """Test that Flatten correctly unifies schemas when some inputs have
+    missing fields."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', department: 'Engineering',
+                     salary: 75000}
+                  - {id: '2', name: 'Bob', department: 'Marketing',
+                     salary: 65000}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', name: 'Charlie', department: 'Sales'}
+                  - {id: '4', name: 'Diana'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', department: 'Engineering',
+                     salary: 75000}
+                  - {id: '2', name: 'Bob', department: 'Marketing',
+                     salary: 65000}
+                  - {id: '3', name: 'Charlie', department: 'Sales'}
+                  - {id: '4', name: 'Diana'}
+          ''')
+
+  def test_flatten_unifies_complex_mixed_schemas(self):
+    """Test that Flatten correctly unifies complex mixed
+    schemas."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: 1, name: 'Product A', price: 29.99,
+                     categories: ['electronics', 'gadgets']}
+                  - {id: 2, name: 'Product B', price: 15.50,
+                     categories: ['books']}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: 3, name: 'Product C', categories: ['clothing']}
+                  - {id: 4, name: 'Product D', price: 99.99}
+            - type: Create
+              name: Create3
+              config:
+                elements:
+                  - {id: 5, name: 'Product E', price: 5.00,
+                     categories: []}
+            - type: Flatten
+              input: [Create1, Create2, Create3]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: 1, name: 'Product A', price: 29.99,
+                     categories: ['electronics', 'gadgets']}
+                  - {id: 2, name: 'Product B', price: 15.50,
+                     categories: ['books']}
+                  - {id: 3, name: 'Product C', categories: ['clothing']}
+                  - {id: 4, name: 'Product D', price: 99.99}
+                  - {id: 5, name: 'Product E', price: 5.00,
+                     categories: []}
+          ''')
+
 
 class ErrorHandlingTest(unittest.TestCase):
   def test_error_handling_outputs(self):
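Note: for reviewers who want to poke at the behavior outside YAML, a hedged
sketch with the plain Python SDK; beam.Row gives each Create a schema'd
element type. Vanilla beam.Flatten does not reconcile the two row types; the
schema unification exercised by the tests above is what the YAML Flatten
override adds on top:

    import apache_beam as beam

    with beam.Pipeline() as p:
      a = p | 'A' >> beam.Create([beam.Row(ride_id='1', passenger_count=1)])
      b = p | 'B' >> beam.Create([beam.Row(ride_id='2')])
      # With this patch, the YAML Flatten would first map both inputs onto
      # the merged schema (passenger_count becomes Optional) and then flatten.
      (a, b) | beam.Flatten() | beam.Map(print)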