146 changes: 145 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_provider.py
@@ -932,6 +932,129 @@ def __init__(self):
    # pylint: disable=useless-parent-delegation
    super().__init__()

  def _unify_field_types(self, existing_type, field_type):
    """Unify two field types, handling Optional and List types."""
    # Extract inner types from Optional if needed
    existing_inner = (
        existing_type.__args__[0] if hasattr(existing_type, '__args__') and
        len(existing_type.__args__) == 1 else existing_type)
    field_inner = (
        field_type.__args__[0] if hasattr(field_type, '__args__') and
        len(field_type.__args__) == 1 else field_type)

    # Handle type unification more carefully
    if existing_inner == Any or field_inner == Any:
      return Optional[Any]
    elif existing_inner == field_inner:
      return Optional[existing_inner]
Comment on lines +938 to +949
Contributor

Will this logic unify Iterable[str], str to Optional[str] since Dict also has args of length 1? I think we want to actually check if the outer type is Optional

Contributor Author

The current way is to prioritize List, so it should be Optional[Iterable[str]]. _unify_element_with_schema does this conversion.
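For reference, a minimal sketch of the coercion being referred to here (the value is hypothetical; the actual logic lives in `_unify_element_with_schema` later in this diff):

```python
# Non-list, non-string iterables (e.g. tuples) are converted to lists so that
# they line up with a List-typed field in the unified schema.
value = (1, 2, 3)  # hypothetical upstream value
if (value is not None and not isinstance(value, list) and
    hasattr(value, '__iter__') and not isinstance(value, (str, bytes))):
  value = list(value)
assert value == [1, 2, 3]
```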

Contributor

I see - could you add a comment here explaining that this function expects all iterables to already be coerced to lists? I agree we do that correctly below, but it is probably a little brittle to rely on this without being explicit about the API

    else:
      # Check for list types and prioritize them over other types
      from apache_beam.typehints import typehints as th
      existing_is_list = (
          hasattr(existing_inner, '__origin__') and
          existing_inner.__origin__ in (list, th.List))
      field_is_list = (
          hasattr(field_inner, '__origin__') and
          field_inner.__origin__ in (list, th.List))

      if existing_is_list and field_is_list:
        # Both are list types, unify their element types
        existing_elem = existing_inner.__args__[
            0] if existing_inner.__args__ else Any
        field_elem = field_inner.__args__[0] if field_inner.__args__ else Any
        if existing_elem == field_elem:
          return Optional[th.List[existing_elem]]
        else:
          return Optional[th.List[Any]]
      elif existing_is_list:
        # Existing is list, keep it as list type
        return Optional[existing_inner]
      elif field_is_list:
        # New field is list, use list type
        return Optional[field_inner]
Comment on lines +969 to +974
Contributor

Why is this not just Optional[Union[existing_inner, field_inner]]? Isn't either a list or a single element valid?

Contributor

Ah, I see we're trying to avoid Union types. Probably this just needs to be encoded as Any then, right?

Contributor Author

That is what line 978 does.

Contributor

Right, but if I'm unifying List[int] and int, right now it unifies to Optional[List[int]], right? But that isn't right if I'm flattening {foo: 1} and {foo: [1,2,3]}
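A minimal, self-contained illustration of that concern (the rows are hypothetical):

```python
# Two rows being flattened together; under the list-priority rule the unified
# field type would be Optional[List[int]], but the first row still carries a
# bare int and is never coerced to a list.
rows = [{'foo': 1}, {'foo': [1, 2, 3]}]
assert not isinstance(rows[0]['foo'], list)  # violates a List[int] field type
```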

      else:
        # Neither is a list, use Any to avoid unsupported Union
        # types in schema translation
        return Optional[Any]
Comment on lines +935 to +978
Contributor

critical

The current implementation of _unify_field_types has several issues that can lead to incorrect schema unification and potential runtime errors:

  1. The logic to extract inner types from Optional is incorrect. It uses len(type.__args__) == 1, which is false for typing.Optional[T] (which is an alias for Union[T, NoneType], having 2 type arguments).
  2. The same logic incorrectly treats list-like types such as th.List[T] as their element type T, because th.List[T].__args__ has a length of 1.
  3. The *_is_list checks are unreliable because they operate on these incorrectly "unwrapped" types.

This can lead to incorrect schema inference, for example, treating a list field as a primitive, or failing to correctly unify list types.
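For reference, the standard `typing` introspection behavior behind points 1 and 2 (a stdlib-only check, independent of the PR code):

```python
from typing import List, Optional, Union, get_args, get_origin

# Optional[T] is Union[T, None], so it has two type arguments, not one.
assert get_args(Optional[int]) == (int, type(None))
assert get_origin(Optional[int]) is Union

# List[T] has exactly one type argument, so a len(__args__) == 1 test
# unwraps list types as well as Optional ones.
assert get_args(List[int]) == (int,)
assert get_origin(List[int]) is list
```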

I suggest replacing the method with a more robust implementation using typing.get_origin and typing.get_args for type introspection. This will correctly handle Optional types and list types.

from typing import get_origin, get_args

    def _unify_field_types(self, existing_type, field_type):
      """Unify two field types, handling Optional and List types."""

      existing_origin = get_origin(existing_type) or existing_type
      field_origin = get_origin(field_type) or field_type

      existing_inner = get_args(existing_type)[0] if existing_origin is Optional else existing_type
      field_inner = get_args(field_type)[0] if field_origin is Optional else field_type

      if existing_inner == Any or field_inner == Any:
        return Optional[Any]
      elif existing_inner == field_inner:
        return Optional[existing_inner]
      else:
        existing_is_list = existing_origin in (list, th.List)
        field_is_list = field_origin in (list, th.List)

        if existing_is_list and field_is_list:
          existing_elem = get_args(existing_inner)[0] if get_args(existing_inner) else Any
          field_elem = get_args(field_inner)[0] if get_args(field_inner) else Any
          if existing_elem == field_elem:
            return Optional[th.List[existing_elem]]
          else:
            return Optional[th.List[Any]]
        elif existing_is_list:
          return Optional[existing_inner]
        elif field_is_list:
          return Optional[field_inner]
        else:
          return Optional[Any]


  def _merge_schemas(self, pcolls):
    """Merge schemas from multiple PCollections to create a unified schema.

    This function creates a unified schema that contains all fields from all
    input PCollections. Fields are made optional to handle missing values.
    """
    from apache_beam.typehints.schemas import named_fields_from_element_type

    # Collect all schemas
    schemas = []
    for pcoll in pcolls:
      if hasattr(pcoll, 'element_type') and pcoll.element_type:
        try:
          fields = named_fields_from_element_type(pcoll.element_type)
          schemas.append(dict(fields))
        except (ValueError, TypeError):
          # If we can't extract schema, skip this PCollection
          continue

    if not schemas:
      return None

    # Merge all field names and types, making them optional
    all_fields = {}
    for schema in schemas:
Contributor

Should we check that there aren't conflicting types here (e.g. pcoll1 wants 'foo': int, pcoll2 wants 'foo': str)?

Contributor

Ideally this would yield a Union type.

Contributor Author

I have tests below that validate this works with _unify_field_types by treating conflicting types as Optional[Any], to simplify the logic for Flatten, since the Union could otherwise become very long (e.g., Optional[Union[int, str, list, ...]]). Handling the nested structures would probably be very hard.

Contributor

Do we actually need to handle nested structures? Could we just say given:

pcoll1: {'foo': TypeA}
pcoll2: {'foo': TypeB}

outPcoll: {'foo': Union[TypeA, TypeB]}

and ignore the nested representations?

Contributor Author

_unify_field_types does not use Union, for now at least. Whenever two types differ, it uses Optional[Any]. I have some concerns about how accurately we need to infer the schema (e.g., stop at the list level as you suggest, or just do the simplest thing, as my PR does). I also think we should support explicitly specifying the schema, in which case there would be no reason for us to unify the schemas with our own rules.
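To make that rule concrete, a minimal sketch of the merge described here (the field names and the loop are illustrative, not the PR's code, and the list-priority special case is omitted):

```python
from typing import Any, Optional

schema_a = {'id': int, 'tags': list}
schema_b = {'id': int, 'tags': str, 'extra': float}

merged = {}
for schema in (schema_a, schema_b):
  for name, typ in schema.items():
    if name not in merged:
      merged[name] = Optional[typ]   # every field becomes optional
    elif merged[name] != Optional[typ]:
      merged[name] = Optional[Any]   # conflicting types collapse to Any

assert merged == {
    'id': Optional[int], 'tags': Optional[Any], 'extra': Optional[float]
}
```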

      for field_name, field_type in schema.items():
        if field_name in all_fields:
          # If field exists with different type, use Union
          existing_type = all_fields[field_name]
          if existing_type != field_type:
            all_fields[field_name] = self._unify_field_types(
                existing_type, field_type)
        else:
          # Make field optional since not all PCollections may have it
          all_fields[field_name] = Optional[field_type]
Contributor

Could we keep track of when one of these schema difference conditions is hit and warn?
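One possible shape for that warning (purely illustrative; the helper name and message are not part of the PR):

```python
import logging

_LOGGER = logging.getLogger(__name__)


def _warn_on_field_type_conflict(field_name, existing_type, field_type):
  """Hypothetical helper: log when two Flatten inputs disagree on a field."""
  if existing_type != field_type:
    _LOGGER.warning(
        'Flatten inputs declare different types for field %r (%s vs %s); '
        'the unified schema will widen this field.',
        field_name,
        existing_type,
        field_type)
```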


    # Create unified schema
    if all_fields:
      from apache_beam.typehints.schemas import named_fields_to_schema
      from apache_beam.typehints.schemas import named_tuple_from_schema
      unified_schema = named_fields_to_schema(list(all_fields.items()))
      return named_tuple_from_schema(unified_schema)

    return None

  def _unify_element_with_schema(self, element, target_schema):
    """Convert an element to match the target schema, preserving existing
    fields only."""
    if target_schema is None:
      return element

    # If element is already a named tuple, convert to dict first
    if hasattr(element, '_asdict'):
      element_dict = element._asdict()
    elif isinstance(element, dict):
      element_dict = element
    else:
      return element

    # Create new element with only the fields that exist in the original
    # element plus None for fields that are expected but missing
    unified_dict = {}
    for field_name in target_schema._fields:
      if field_name in element_dict:
        value = element_dict[field_name]
        # Ensure the value matches the expected type
        # This is particularly important for list fields
        if value is not None and not isinstance(value, list) and hasattr(
            value, '__iter__') and not isinstance(value, (str, bytes)):
          # Convert iterables to lists if needed
          unified_dict[field_name] = list(value)
        else:
          unified_dict[field_name] = value
      else:
        unified_dict[field_name] = None

    return target_schema(**unified_dict)

  def expand(self, pcolls):
    if isinstance(pcolls, beam.PCollection):
      pipeline_arg = {}
@@ -942,7 +1065,28 @@ def expand(self, pcolls):
    else:
      pipeline_arg = {'pipeline': pcolls.pipeline}
      pcolls = ()
    return pcolls | beam.Flatten(**pipeline_arg)

    if not pcolls:
      return pcolls | beam.Flatten(**pipeline_arg)

    # Try to unify schemas
    unified_schema = self._merge_schemas(pcolls)

    if unified_schema is None:
      # No schema unification needed, use standard flatten
      return pcolls | beam.Flatten(**pipeline_arg)

    # Apply schema unification to each PCollection
    unified_pcolls = []
    for i, pcoll in enumerate(pcolls):
      unified_pcoll = pcoll | f'UnifySchema{i}' >> beam.Map(
          lambda element, schema=unified_schema: self.
          _unify_element_with_schema(element, schema)).with_output_types(
              unified_schema)
      unified_pcolls.append(unified_pcoll)

    # Flatten the unified PCollections
    return unified_pcolls | beam.Flatten(**pipeline_arg)

class WindowInto(beam.PTransform):
  # pylint: disable=line-too-long
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_transform.py
@@ -861,8 +861,8 @@ def preprocess_flattened_inputs(spec):
  def all_inputs(t):
    for key, values in t.get('input', {}).items():
      if isinstance(values, list):
        for ix, values in enumerate(values):
          yield f'{key}{ix}', values
        for ix, value in enumerate(values):
          yield f'{key}{ix}', value
      else:
        yield key, values
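A small, self-contained sketch of what the fixed generator yields for a hypothetical spec (the spec itself is illustrative). The rename only avoids shadowing `values` inside the inner loop; the yielded pairs are unchanged.

```python
def all_inputs(t):
  for key, values in t.get('input', {}).items():
    if isinstance(values, list):
      for ix, value in enumerate(values):
        yield f'{key}{ix}', value
    else:
      yield key, values


spec = {'input': {'input': ['pc1', 'pc2'], 'side': 'pc3'}}
assert dict(all_inputs(spec)) == {
    'input0': 'pc1', 'input1': 'pc2', 'side': 'pc3'
}
```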
