Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,34 @@ def expand(self, pcolls):
else:
pipeline_arg = {'pipeline': pcolls.pipeline}
pcolls = ()

# Check schema compatibility for non-empty collections
# Skip validation for auto-generated flattens
# (those with "-Flatten[" in the label)
if len(pcolls) > 1 and not (hasattr(self, 'label') and self.label and
'-Flatten[' in self.label):
from apache_beam.typehints import schemas
schemas_to_check = []
for pcoll in pcolls:
if hasattr(pcoll, 'element_type') and pcoll.element_type:
try:
schema = schemas.schema_from_element_type(pcoll.element_type)
schemas_to_check.append(schema)
except TypeError:
# Skip PCollections without schema
pass

# If we have schemas to check, ensure they are all the same
if len(schemas_to_check) > 1:
first_schema = schemas_to_check[0]
for i, schema in enumerate(schemas_to_check[1:], 1):
if schema != first_schema:
raise RuntimeError(
f"Cannot flatten PCollections with different schemas. "
f"PCollection 0 has schema {first_schema}, "
f"but PCollection {i} has schema {schema}. "
"All PCollections must have the same schema.")

return pcolls | beam.Flatten(**pipeline_arg)

class WindowInto(beam.PTransform):
Expand Down
145 changes: 122 additions & 23 deletions sdks/python/apache_beam/yaml/yaml_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_chain_with_root(self):
providers=TEST_PROVIDERS)
assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131]))

def create_has_schema(self):
def test_create_has_schema(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
result = p | YamlTransform(
Expand Down Expand Up @@ -217,6 +217,104 @@ def test_implicit_flatten(self):
providers=TEST_PROVIDERS)
assert_that(result, equal_to([1, 4, 9, 10000, 40000]))

# Flattening Create1 (rows with ride_id and passenger_count) together with
# Create2 (rows with only ride_id) is expected to fail while the YamlTransform
# is applied to the pipeline, with an error message matching
# "Cannot flatten PCollections with different schemas".
def test_flatten_different_schemas_error(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# Any exception type is accepted here; only the message text is checked.
with self.assertRaisesRegex(
Exception, r"Cannot flatten PCollections with different schemas"):
# The result is discarded: the error must come from applying the
# transform itself, since that is all the assertRaisesRegex block wraps.
_ = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {'ride_id': '1', 'passenger_count': 1}
- {'ride_id': '2', 'passenger_count': 2}
- type: Create
name: Create2
config:
elements:
- {'ride_id': '3'}
- {'ride_id': '4'}
- type: Flatten
name: Flatten1
input:
- Create1
- Create2
output: Flatten1
''',
providers=TEST_PROVIDERS)

# Flattening Create1 and Create2, whose rows all carry the same two fields
# (ride_id, passenger_count), must succeed and produce all four rows.
def test_flatten_compatible_schemas_success(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# The YAML spec names Flatten1 as the composite's output, so `result` is
# what gets checked against the expected rows below.
result = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {'ride_id': '1', 'passenger_count': 1}
- {'ride_id': '2', 'passenger_count': 2}
- type: Create
name: Create2
config:
elements:
- {'ride_id': '3', 'passenger_count': 3}
- {'ride_id': '4', 'passenger_count': 4}
- type: Flatten
name: Flatten1
input:
- Create1
- Create2
output: Flatten1
''',
providers=TEST_PROVIDERS)
# This should not raise an error since the schemas are identical
# Expected contents: the two rows from Create1 followed by the two rows from
# Create2, expressed as beam.Row values.
assert_that(
result,
equal_to([
beam.Row(ride_id='1', passenger_count=1),
beam.Row(ride_id='2', passenger_count=2),
beam.Row(ride_id='3', passenger_count=3),
beam.Row(ride_id='4', passenger_count=4)
]))

def test_flatten_with_null_values_error(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# This should raise an error because null values create different schema types
# (nullable logical type vs INT64)
with self.assertRaisesRegex(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@damccorm I do not like this, but Beam treats this case as having different schemas, since passenger_count should be INT64, not None. So this validation now reports an error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh interesting - one option would be to define our own deep equality function here (basically a consistency check), or we could try to move forward with #35672

I don't think we can move forward with this change if it regresses these use cases

ValueError, r"Cannot flatten PCollections with different schemas"):
p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: Create1
config:
elements:
- {'ride_id': '1', 'passenger_count': 1}
- {'ride_id': '2', 'passenger_count': 2}
- type: Create
name: Create2
config:
elements:
- {'ride_id': '3', 'passenger_count': null}
- {'ride_id': '4', 'passenger_count': null}
- type: Flatten
name: Flatten1
input:
- Create2
- Create1
''',
providers=TEST_PROVIDERS)

def test_csv_to_json(self):
try:
import pandas as pd
Expand Down Expand Up @@ -330,28 +428,29 @@ def test_name_is_ambiguous(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
# pylint: disable=expression-not-assigned
with self.assertRaisesRegex(ValueError, r'Circular reference detected.*'):
p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: CreateData
config:
elements: [0, 1, 3, 4]
- type: PyMap
name: PyMap
config:
fn: "lambda elem: elem + 2"
input: CreateData
- type: PyMap
name: AnotherMap
config:
fn: "lambda elem: elem + 3"
input: PyMap
output: AnotherMap
''',
providers=TEST_PROVIDERS)
result = p | YamlTransform(
'''
type: composite
transforms:
- type: Create
name: CreateData
config:
elements: [0, 1, 3, 4]
- type: PyMap
name: PyMap
config:
fn: "lambda elem: elem + 2"
input: CreateData
- type: PyMap
name: AnotherMap
config:
fn: "lambda elem: elem + 3"
input: PyMap
output: AnotherMap
''',
providers=TEST_PROVIDERS)
# This should work correctly without circular reference
assert_that(result, equal_to([5, 6, 8, 9]))

def test_empty_inputs_throws_error(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
Expand Down
Loading