diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 42087565013c..07a9971dd6da 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -932,6 +932,129 @@ def __init__(self):
     # pylint: disable=useless-parent-delegation
     super().__init__()
 
+  def _unify_field_types(self, existing_type, field_type):
+    """Unify two field types, handling Optional and List types."""
+    # Extract inner types from Optional if needed
+    existing_inner = (
+        existing_type.__args__[0] if hasattr(existing_type, '__args__') and
+        len(existing_type.__args__) == 1 else existing_type)
+    field_inner = (
+        field_type.__args__[0] if hasattr(field_type, '__args__') and
+        len(field_type.__args__) == 1 else field_type)
+
+    # Handle type unification more carefully
+    if existing_inner == Any or field_inner == Any:
+      return Optional[Any]
+    elif existing_inner == field_inner:
+      return Optional[existing_inner]
+    else:
+      # Check for list types and prioritize them over other types
+      from apache_beam.typehints import typehints as th
+      existing_is_list = (
+          hasattr(existing_inner, '__origin__') and
+          existing_inner.__origin__ in (list, th.List))
+      field_is_list = (
+          hasattr(field_inner, '__origin__') and
+          field_inner.__origin__ in (list, th.List))
+
+      if existing_is_list and field_is_list:
+        # Both are list types, unify their element types
+        existing_elem = existing_inner.__args__[
+            0] if existing_inner.__args__ else Any
+        field_elem = field_inner.__args__[0] if field_inner.__args__ else Any
+        if existing_elem == field_elem:
+          return Optional[th.List[existing_elem]]
+        else:
+          return Optional[th.List[Any]]
+      elif existing_is_list:
+        # Existing is list, keep it as list type
+        return Optional[existing_inner]
+      elif field_is_list:
+        # New field is list, use list type
+        return Optional[field_inner]
+      else:
+        # Neither is a list, use Any to avoid unsupported Union
+        # types in schema translation
+        return Optional[Any]
+
+  def _merge_schemas(self, pcolls):
+    """Merge schemas from multiple PCollections to create a unified schema.
+
+    This function creates a unified schema that contains all fields from all
+    input PCollections. Fields are made optional to handle missing values.
+ """ + from apache_beam.typehints.schemas import named_fields_from_element_type + + # Collect all schemas + schemas = [] + for pcoll in pcolls: + if hasattr(pcoll, 'element_type') and pcoll.element_type: + try: + fields = named_fields_from_element_type(pcoll.element_type) + schemas.append(dict(fields)) + except (ValueError, TypeError): + # If we can't extract schema, skip this PCollection + continue + + if not schemas: + return None + + # Merge all field names and types, making them optional + all_fields = {} + for schema in schemas: + for field_name, field_type in schema.items(): + if field_name in all_fields: + # If field exists with different type, use Union + existing_type = all_fields[field_name] + if existing_type != field_type: + all_fields[field_name] = self._unify_field_types( + existing_type, field_type) + else: + # Make field optional since not all PCollections may have it + all_fields[field_name] = Optional[field_type] + + # Create unified schema + if all_fields: + from apache_beam.typehints.schemas import named_fields_to_schema + from apache_beam.typehints.schemas import named_tuple_from_schema + unified_schema = named_fields_to_schema(list(all_fields.items())) + return named_tuple_from_schema(unified_schema) + + return None + + def _unify_element_with_schema(self, element, target_schema): + """Convert an element to match the target schema, preserving existing + fields only.""" + if target_schema is None: + return element + + # If element is already a named tuple, convert to dict first + if hasattr(element, '_asdict'): + element_dict = element._asdict() + elif isinstance(element, dict): + element_dict = element + else: + return element + + # Create new element with only the fields that exist in the original + # element plus None for fields that are expected but missing + unified_dict = {} + for field_name in target_schema._fields: + if field_name in element_dict: + value = element_dict[field_name] + # Ensure the value matches the expected type + # This is particularly important for list fields + if value is not None and not isinstance(value, list) and hasattr( + value, '__iter__') and not isinstance(value, (str, bytes)): + # Convert iterables to lists if needed + unified_dict[field_name] = list(value) + else: + unified_dict[field_name] = value + else: + unified_dict[field_name] = None + + return target_schema(**unified_dict) + def expand(self, pcolls): if isinstance(pcolls, beam.PCollection): pipeline_arg = {} @@ -942,7 +1065,28 @@ def expand(self, pcolls): else: pipeline_arg = {'pipeline': pcolls.pipeline} pcolls = () - return pcolls | beam.Flatten(**pipeline_arg) + + if not pcolls: + return pcolls | beam.Flatten(**pipeline_arg) + + # Try to unify schemas + unified_schema = self._merge_schemas(pcolls) + + if unified_schema is None: + # No schema unification needed, use standard flatten + return pcolls | beam.Flatten(**pipeline_arg) + + # Apply schema unification to each PCollection + unified_pcolls = [] + for i, pcoll in enumerate(pcolls): + unified_pcoll = pcoll | f'UnifySchema{i}' >> beam.Map( + lambda element, schema=unified_schema: self. 
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 8710fe379c37..341cd8d65f41 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -861,8 +861,8 @@ def preprocess_flattened_inputs(spec):
   def all_inputs(t):
     for key, values in t.get('input', {}).items():
       if isinstance(values, list):
-        for ix, values in enumerate(values):
-          yield f'{key}{ix}', values
+        for ix, value in enumerate(values):
+          yield f'{key}{ix}', value
       else:
         yield key, values
 
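The yaml_transform.py hunk is a pure rename: the inner loop rebound `values`, shadowing the dict value from the enclosing loop. Behavior was unchanged, since the outer loop rebinds `values` on each items() step, but the shadowing was easy to misread. A standalone sketch of the fixed generator with a made-up input spec:

    def all_inputs(t):
      # Expand list-valued inputs into indexed keys: 'left' -> 'left0', ...
      for key, values in t.get('input', {}).items():
        if isinstance(values, list):
          for ix, value in enumerate(values):  # distinct name, no shadowing
            yield f'{key}{ix}', value
        else:
          yield key, values


    print(dict(all_inputs({'input': {'left': ['Create1', 'Create2']}})))
    # {'left0': 'Create1', 'left1': 'Create2'}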
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 1a99507d76d7..ee98581346c8 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -477,6 +477,238 @@ def test_composite_resource_hints(self):
         b'1000000000',
         proto)
 
+  def test_flatten_unifies_schemas(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {ride_id: '1', passenger_count: 1}
+                  - {ride_id: '2', passenger_count: 2}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {ride_id: '3'}
+                  - {ride_id: '4'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {ride_id: '1', passenger_count: 1}
+                  - {ride_id: '2', passenger_count: 2}
+                  - {ride_id: '3'}
+                  - {ride_id: '4'}
+          ''')
+
+  def test_flatten_unifies_optional_fields(self):
+    """Test that Flatten correctly unifies schemas with optional fields."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', age: 30}
+                  - {id: '2', name: 'Bob', age: 25}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', name: 'Charlie'}
+                  - {id: '4', name: 'Diana'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', age: 30}
+                  - {id: '2', name: 'Bob', age: 25}
+                  - {id: '3', name: 'Charlie'}
+                  - {id: '4', name: 'Diana'}
+          ''')
+
+  def test_flatten_unifies_different_types(self):
+    """Test that Flatten correctly unifies schemas with different
+    field types."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: 1, value: 100}
+                  - {id: 2, value: 200}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', value: 'text'}
+                  - {id: '4', value: 'data'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: 1, value: 100}
+                  - {id: 2, value: 200}
+                  - {id: '3', value: 'text'}
+                  - {id: '4', value: 'data'}
+          ''')
+
+  def test_flatten_unifies_list_fields(self):
+    """Test that Flatten correctly unifies schemas with list fields."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: '1', tags: ['red', 'blue']}
+                  - {id: '2', tags: ['green']}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', tags: ['yellow', 'purple', 'orange']}
+                  - {id: '4', tags: []}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: '1', tags: ['red', 'blue']}
+                  - {id: '2', tags: ['green']}
+                  - {id: '3', tags: ['yellow', 'purple', 'orange']}
+                  - {id: '4', tags: []}
+          ''')
+
+  def test_flatten_unifies_with_missing_fields(self):
+    """Test that Flatten correctly unifies schemas when some inputs have
+    missing fields."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      _ = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', department: 'Engineering',
+                     salary: 75000}
+                  - {id: '2', name: 'Bob', department: 'Marketing',
+                     salary: 65000}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: '3', name: 'Charlie', department: 'Sales'}
+                  - {id: '4', name: 'Diana'}
+            - type: Flatten
+              input: [Create1, Create2]
+            - type: AssertEqual
+              input: Flatten
+              config:
+                elements:
+                  - {id: '1', name: 'Alice', department: 'Engineering',
+                     salary: 75000}
+                  - {id: '2', name: 'Bob', department: 'Marketing',
+                     salary: 65000}
+                  - {id: '3', name: 'Charlie', department: 'Sales'}
+                  - {id: '4', name: 'Diana'}
+          ''')
+
+  def test_flatten_unifies_complex_mixed_schemas(self):
+    """Test that Flatten correctly unifies complex mixed
+    schemas."""
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | YamlTransform(
+          '''
+          type: composite
+          transforms:
+            - type: Create
+              name: Create1
+              config:
+                elements:
+                  - {id: 1, name: 'Product A', price: 29.99,
+                     categories: ['electronics', 'gadgets']}
+                  - {id: 2, name: 'Product B', price: 15.50,
+                     categories: ['books']}
+            - type: Create
+              name: Create2
+              config:
+                elements:
+                  - {id: 3, name: 'Product C', categories: ['clothing']}
+                  - {id: 4, name: 'Product D', price: 99.99}
+            - type: Create
+              name: Create3
+              config:
+                elements:
+                  - {id: 5, name: 'Product E', price: 5.00,
+                     categories: []}
+            - type: Flatten
+              input: [Create1, Create2, Create3]
+          output: Flatten
+          ''')
+
+      # Verify that the result contains all expected elements
+      # with proper schema unification
+      def check_result(actual):
+        expected_ids = {1, 2, 3, 4, 5}
+        actual_ids = {
+            getattr(row, 'id', row.get('id') if hasattr(row, 'get') else None)
+            for row in actual
+        }
+        assert actual_ids == expected_ids, (
+            f"Expected IDs {expected_ids}, got {actual_ids}")
+
+        # Check that all rows have required fields
+        for row in actual:
+          row_id = getattr(
+              row, 'id', row.get('id') if hasattr(row, 'get') else None)
+          name = getattr(
+              row, 'name', row.get('name') if hasattr(row, 'get') else None)
+          assert row_id is not None, f"Missing id field in row {row}"
+          assert name is not None, f"Missing name field in row {row}"
+          # Optional fields should be present but may be None/empty
+          price = getattr(
+              row, 'price', row.get('price') if hasattr(row, 'get') else None)
+          categories = getattr(
+              row,
+              'categories',
+              row.get('categories') if hasattr(row, 'get') else None)
+          assert price is not None or row_id == 3, \
+              f"Missing price field in row {row}"
+          assert categories is not None or row_id == 4, \
+              f"Missing categories field in row {row}"
+
+      assert_that(result, check_result)
+
 
 class ErrorHandlingTest(unittest.TestCase):
   def test_error_handling_outputs(self):
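The new behavior can also be exercised outside the test suite. The following driver mirrors test_flatten_unifies_schemas; it is an illustrative sketch, and the Map(print) step is an addition for inspection, not something the tests do:

    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.yaml.yaml_transform import YamlTransform

    with beam.Pipeline(options=PipelineOptions(
        pickle_library='cloudpickle')) as p:
      result = p | YamlTransform(
          '''
          type: composite
          transforms:
            - type: Create
              name: Create1
              config:
                elements:
                  - {ride_id: '1', passenger_count: 1}
            - type: Create
              name: Create2
              config:
                elements:
                  - {ride_id: '2'}
            - type: Flatten
              input: [Create1, Create2]
          output: Flatten
          ''')
      # With the unified schema, the second row prints with
      # passenger_count=None instead of failing to flatten.
      _ = result | beam.Map(print)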