101 changes: 100 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_provider.py
@@ -756,6 +756,43 @@ def dicts_to_rows(o):
return o


def _unify_element_with_schema(element, target_schema):
"""Convert an element to match the target schema, preserving existing
fields only."""
if target_schema is None:
return element

# If element is already a named tuple, convert to dict first
if hasattr(element, '_asdict'):
element_dict = element._asdict()
elif isinstance(element, dict):
element_dict = element
else:
# This element is not a row, so it can't be unified with a
# row schema.
return element

  # Build the new element from the fields present in the original element,
  # filling schema fields that are missing from the element with None.
unified_dict = {}
for field_name in target_schema._fields:
if field_name in element_dict:
value = element_dict[field_name]
# Ensure the value matches the expected type
# This is particularly important for list fields
if value is not None and not isinstance(value, list) and hasattr(
value, '__iter__') and not isinstance(
value, (str, bytes)) and not hasattr(value, '_asdict'):
# Convert iterables to lists if needed
unified_dict[field_name] = list(value)
else:
unified_dict[field_name] = value
else:
unified_dict[field_name] = None

return target_schema(**unified_dict)

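# Illustrative behavior of _unify_element_with_schema (the `Row` schema below
# is hypothetical, not part of this module):
#
#   Row = typing.NamedTuple('Row', [('id', str), ('age', typing.Optional[int])])
#   _unify_element_with_schema({'id': '1'}, Row)  # -> Row(id='1', age=None)
#   _unify_element_with_schema(42, Row)           # -> 42 (not a row; unchanged)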

class YamlProviders:
class AssertEqual(beam.PTransform):
"""Asserts that the input contains exactly the elements provided.
@@ -932,6 +969,48 @@ def __init__(self):
# pylint: disable=useless-parent-delegation
super().__init__()

def _merge_schemas(self, pcolls):
"""Merge schemas from multiple PCollections to create a unified schema.

This function creates a unified schema that contains all fields from all
input PCollections. Fields are made optional to handle missing values.
If fields have different types, they are unified to Optional[Any].
"""
from apache_beam.typehints.schemas import named_fields_from_element_type

# Collect all schemas
schemas = []
for pcoll in pcolls:
if hasattr(pcoll, 'element_type') and pcoll.element_type:
try:
fields = named_fields_from_element_type(pcoll.element_type)
schemas.append(dict(fields))
except (ValueError, TypeError):
# If we can't extract schema, skip this PCollection
continue

if not schemas:
return None

# Merge all field names and types.
all_field_names = set().union(*(s.keys() for s in schemas))
unified_fields = {}
for name in all_field_names:
present_types = {s[name] for s in schemas if name in s}
if len(present_types) > 1:
unified_fields[name] = Optional[Any]
else:
unified_fields[name] = Optional[present_types.pop()]

# Create unified schema
if unified_fields:
from apache_beam.typehints.schemas import named_fields_to_schema
from apache_beam.typehints.schemas import named_tuple_from_schema
unified_schema = named_fields_to_schema(list(unified_fields.items()))
return named_tuple_from_schema(unified_schema)

return None

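    # Worked example for _merge_schemas (illustrative schemas): for inputs
    # whose extracted fields are
    #   {'id': int, 'value': int} and {'id': str}
    # the merged fields become
    #   {'id': Optional[Any], 'value': Optional[int]}
    # since 'id' has conflicting types while 'value' appears in only one input.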
def expand(self, pcolls):
if isinstance(pcolls, beam.PCollection):
pipeline_arg = {}
@@ -942,7 +1021,27 @@ def expand(self, pcolls):
else:
pipeline_arg = {'pipeline': pcolls.pipeline}
pcolls = ()
-      return pcolls | beam.Flatten(**pipeline_arg)

if not pcolls:
return pcolls | beam.Flatten(**pipeline_arg)

# Try to unify schemas
unified_schema = self._merge_schemas(pcolls)

if unified_schema is None:
        # No usable schemas could be extracted; fall back to a standard
        # flatten.
return pcolls | beam.Flatten(**pipeline_arg)

# Apply schema unification to each PCollection before flattening.
unified_pcolls = []
for i, pcoll in enumerate(pcolls):
unified_pcoll = pcoll | f'UnifySchema{i}' >> beam.Map(
_unify_element_with_schema,
target_schema=unified_schema).with_output_types(unified_schema)
unified_pcolls.append(unified_pcoll)

# Flatten the unified PCollections
return unified_pcolls | beam.Flatten(**pipeline_arg)

class WindowInto(beam.PTransform):
# pylint: disable=line-too-long
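The net effect of the yaml_provider.py change can be sketched with plain Beam.
In this minimal, standalone illustration (the `Unified` schema and the
`to_unified` helper are hypothetical names, not part of the PR), two
PCollections with different fields are mapped onto a shared NamedTuple before
flattening, which is what the new `expand` now does automatically:

import typing

import apache_beam as beam

Unified = typing.NamedTuple(
    'Unified', [('ride_id', str), ('passenger_count', typing.Optional[int])])


def to_unified(d):
  # Fill fields the element lacks with None, mirroring
  # _unify_element_with_schema.
  return Unified(
      ride_id=d.get('ride_id'), passenger_count=d.get('passenger_count'))


with beam.Pipeline() as p:
  a = p | 'A' >> beam.Create([{'ride_id': '1', 'passenger_count': 1}])
  b = p | 'B' >> beam.Create([{'ride_id': '2'}])
  _ = ((
      a | 'UnifyA' >> beam.Map(to_unified).with_output_types(Unified),
      b | 'UnifyB' >> beam.Map(to_unified).with_output_types(Unified))
      | beam.Flatten()
      | beam.Map(print))

Run under the DirectRunner this prints the two unified rows, with
passenger_count filled as None for the second element (output order is not
guaranteed).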
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_transform.py
@@ -861,8 +861,8 @@ def preprocess_flattened_inputs(spec):
 def all_inputs(t):
   for key, values in t.get('input', {}).items():
     if isinstance(values, list):
-      for ix, values in enumerate(values):
-        yield f'{key}{ix}', values
+      for ix, value in enumerate(values):
+        yield f'{key}{ix}', value
     else:
       yield key, values

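The yaml_transform.py change is a variable-shadowing cleanup rather than a
behavior fix: `enumerate(values)` is evaluated once, so rebinding `values`
inside the loop never broke iteration, but it made the code misleading. A
quick standalone check of what the generator yields for a list-valued input
(the input dict here is illustrative):

def all_inputs(t):
  for key, values in t.get('input', {}).items():
    if isinstance(values, list):
      for ix, value in enumerate(values):
        yield f'{key}{ix}', value
    else:
      yield key, values


print(list(all_inputs({'input': {'input': ['Create1', 'Create2']}})))
# -> [('input0', 'Create1'), ('input1', 'Create2')]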
210 changes: 210 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -477,6 +477,216 @@ def test_composite_resource_hints(self):
b'1000000000',
proto)

def test_flatten_unifies_schemas(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {ride_id: '1', passenger_count: 1}
- {ride_id: '2', passenger_count: 2}
- type: Create
name: Create2
config:
elements:
- {ride_id: '3'}
- {ride_id: '4'}
- type: Flatten
input: [Create1, Create2]
- type: AssertEqual
input: Flatten
config:
elements:
- {ride_id: '1', passenger_count: 1}
- {ride_id: '2', passenger_count: 2}
- {ride_id: '3'}
- {ride_id: '4'}
''')

def test_flatten_unifies_optional_fields(self):
"""Test that Flatten correctly unifies schemas with optional fields."""
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {id: '1', name: 'Alice', age: 30}
- {id: '2', name: 'Bob', age: 25}
- type: Create
name: Create2
config:
elements:
- {id: '3', name: 'Charlie'}
- {id: '4', name: 'Diana'}
- type: Flatten
input: [Create1, Create2]
- type: AssertEqual
input: Flatten
config:
elements:
- {id: '1', name: 'Alice', age: 30}
- {id: '2', name: 'Bob', age: 25}
- {id: '3', name: 'Charlie'}
- {id: '4', name: 'Diana'}
''')

def test_flatten_unifies_different_types(self):
"""Test that Flatten correctly unifies schemas with different
field types."""
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {id: 1, value: 100}
- {id: 2, value: 200}
- type: Create
name: Create2
config:
elements:
- {id: '3', value: 'text'}
- {id: '4', value: 'data'}
- type: Flatten
input: [Create1, Create2]
- type: AssertEqual
input: Flatten
config:
elements:
- {id: 1, value: 100}
- {id: 2, value: 200}
- {id: '3', value: 'text'}
- {id: '4', value: 'data'}
''')

def test_flatten_unifies_list_fields(self):
"""Test that Flatten correctly unifies schemas with list fields."""
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {id: '1', tags: ['red', 'blue']}
- {id: '2', tags: ['green']}
- type: Create
name: Create2
config:
elements:
- {id: '3', tags: ['yellow', 'purple', 'orange']}
- {id: '4', tags: []}
- type: Flatten
input: [Create1, Create2]
- type: AssertEqual
input: Flatten
config:
elements:
- {id: '1', tags: ['red', 'blue']}
- {id: '2', tags: ['green']}
- {id: '3', tags: ['yellow', 'purple', 'orange']}
- {id: '4', tags: []}
''')

def test_flatten_unifies_with_missing_fields(self):
"""Test that Flatten correctly unifies schemas when some inputs have
missing fields."""
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {id: '1', name: 'Alice', department: 'Engineering',
salary: 75000}
- {id: '2', name: 'Bob', department: 'Marketing',
salary: 65000}
- type: Create
name: Create2
config:
elements:
- {id: '3', name: 'Charlie', department: 'Sales'}
- {id: '4', name: 'Diana'}
- type: Flatten
input: [Create1, Create2]
- type: AssertEqual
input: Flatten
config:
elements:
- {id: '1', name: 'Alice', department: 'Engineering',
salary: 75000}
- {id: '2', name: 'Bob', department: 'Marketing',
salary: 65000}
- {id: '3', name: 'Charlie', department: 'Sales'}
- {id: '4', name: 'Diana'}
''')

def test_flatten_unifies_complex_mixed_schemas(self):
"""Test that Flatten correctly unifies complex mixed
schemas."""
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {id: 1, name: 'Product A', price: 29.99,
categories: ['electronics', 'gadgets']}
- {id: 2, name: 'Product B', price: 15.50,
categories: ['books']}
- type: Create
name: Create2
config:
elements:
- {id: 3, name: 'Product C', categories: ['clothing']}
- {id: 4, name: 'Product D', price: 99.99}
- type: Create
name: Create3
config:
elements:
- {id: 5, name: 'Product E', price: 5.00,
categories: []}
- type: Flatten
input: [Create1, Create2, Create3]
- type: AssertEqual
input: Flatten
config:
elements:
- {id: 1, name: 'Product A', price: 29.99,
categories: ['electronics', 'gadgets']}
- {id: 2, name: 'Product B', price: 15.50,
categories: ['books']}
- {id: 3, name: 'Product C', categories: ['clothing']}
- {id: 4, name: 'Product D', price: 99.99}
- {id: 5, name: 'Product E', price: 5.00,
categories: []}
''')


class ErrorHandlingTest(unittest.TestCase):
def test_error_handling_outputs(self):