39 changes: 36 additions & 3 deletions sdks/python/apache_beam/transforms/core.py
@@ -29,6 +29,7 @@
import traceback
import types
import typing
from collections import defaultdict
from itertools import dropwhile

from apache_beam import coders
@@ -3962,9 +3963,41 @@ def to_runner_api_parameter(self, context):
def infer_output_type(self, unused_input_type):
if not self.values:
return typehints.Any

# No field data - just use default Union.
if not hasattr(self.values[0], 'as_dict'):
return typehints.Union[[
trivial_inference.instance_to_type(v) for v in self.values
]]

first_fields = self.values[0].as_dict().keys()

# Save field types for each field
field_types_by_field = defaultdict(set)
for row in self.values:
row_dict = row.as_dict()
for field in first_fields:
field_types_by_field[field].add(
trivial_inference.instance_to_type(row_dict.get(field)))

# Determine the appropriate type for each field
final_fields = []
for field in first_fields:
field_types = field_types_by_field[field]
non_none_types = {t for t in field_types if t is not type(None)}

if len(non_none_types) > 1:
final_type = typehints.Union[tuple(non_none_types)]
elif len(non_none_types) == 1 and len(field_types) == 1:
final_type = non_none_types.pop()
elif len(non_none_types) == 1 and len(field_types) == 2:
final_type = typehints.Optional[non_none_types.pop()]
else:
raise TypeError("No types found for field %s" % field)

final_fields.append((field, final_type))

return row_type.RowTypeConstraint.from_fields(final_fields)

def get_output_type(self):
return (
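To make the field-resolution branches above easier to scan, here is a small standalone sketch, not part of the diff, that restates them with plain Python typing; the helper name and its set-based signature are illustrative only. Note that, as in the diff, the multi-type branch unions only the non-None types.

# Hypothetical helper (resolve_field_type is not a Beam API) mirroring
# the per-field branches in infer_output_type above.
from typing import Optional, Set, Union


def resolve_field_type(field_types: Set[type]):
  """Collapse the set of types observed for one field into a single hint."""
  non_none = {t for t in field_types if t is not type(None)}
  if len(non_none) > 1:
    # Multiple concrete types: Union of the non-None types. As in the
    # diff, an observed None is dropped in this branch.
    return Union[tuple(non_none)]
  if len(non_none) == 1 and len(field_types) == 1:
    # Exactly one type was seen and it is not None: use it directly.
    return non_none.pop()
  if len(non_none) == 1 and len(field_types) == 2:
    # One concrete type plus None: the field is Optional.
    return Optional[non_none.pop()]
  # Only None (or nothing at all) was observed for this field.
  raise TypeError("No types found for field")


assert resolve_field_type({int, str}) == Union[int, str]
assert resolve_field_type({int}) == int
assert resolve_field_type({int, type(None)}) == Optional[int]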
31 changes: 31 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
@@ -31,6 +31,7 @@
from apache_beam.testing.util import equal_to
from apache_beam.transforms.window import FixedWindows
from apache_beam.typehints import TypeCheckError
from apache_beam.typehints import row_type
from apache_beam.typehints import typehints

RETURN_NONE_PARTIAL_WARNING = "No iterator is returned"
@@ -322,6 +323,36 @@ def test_typecheck_with_default(self):
| beam.Map(lambda s: s.upper()).with_input_types(str))


class CreateInferOutputSchemaTest(unittest.TestCase):
def test_multiple_types_for_field(self):
output_type = beam.Create([beam.Row(a=1),
beam.Row(a='foo')]).infer_output_type(None)
self.assertEqual(
output_type,
row_type.RowTypeConstraint.from_fields([
('a', typehints.Union[int, str])
]))

def test_single_type_for_field(self):
output_type = beam.Create([beam.Row(a=1),
beam.Row(a=2)]).infer_output_type(None)
self.assertEqual(
output_type, row_type.RowTypeConstraint.from_fields([('a', int)]))

def test_optional_type_for_field(self):
output_type = beam.Create([beam.Row(a=1),
beam.Row(a=None)]).infer_output_type(None)
self.assertEqual(
output_type,
row_type.RowTypeConstraint.from_fields([('a', typehints.Optional[int])
]))

def test_none_type_for_field_raises_error(self):
with self.assertRaisesRegex(TypeError,
"No types found for field a"):
beam.Create([beam.Row(a=None), beam.Row(a=None)]).infer_output_type(None)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
@@ -16,11 +16,20 @@
#

fixtures:
- name: BQ_TABLE_0
type: "apache_beam.yaml.integration_tests.temp_bigquery_table"
config:
project: "apache-beam-testing"
- name: TEMP_DIR_0
# Need distributed filesystem to be able to read and write from a container.
type: "apache_beam.yaml.integration_tests.gcs_temp_dir"
config:
bucket: "gs://temp-storage-for-end-to-end-tests/temp-it"
- name: BQ_TABLE_1
type: "apache_beam.yaml.integration_tests.temp_bigquery_table"
config:
project: "apache-beam-testing"
- name: TEMP_DIR_1
# Need distributed filesystem to be able to read and write from a container.
type: "apache_beam.yaml.integration_tests.gcs_temp_dir"
config:
@@ -38,17 +47,17 @@ pipelines:
- {label: "389a", rank: 2}
- type: WriteToBigQuery
config:
table: "{BQ_TABLE}"
table: "{BQ_TABLE_0}"
options:
project: "apache-beam-testing"
temp_location: "{TEMP_DIR}"
temp_location: "{TEMP_DIR_0}"

- pipeline:
type: chain
transforms:
- type: ReadFromBigQuery
config:
table: "{BQ_TABLE}"
table: "{BQ_TABLE_0}"
- type: AssertEqual
config:
elements:
@@ -57,14 +66,14 @@
- {label: "389a", rank: 2}
options:
project: "apache-beam-testing"
temp_location: "{TEMP_DIR}"
temp_location: "{TEMP_DIR_0}"

- pipeline:
type: chain
transforms:
- type: ReadFromBigQuery
config:
table: "{BQ_TABLE}"
table: "{BQ_TABLE_0}"
fields: ["label"]
row_restriction: "rank > 0"
- type: AssertEqual
@@ -74,4 +83,58 @@
- {label: "389a"}
options:
project: "apache-beam-testing"
temp_location: "{TEMP_DIR}"
temp_location: "{TEMP_DIR_0}"

# ----------------------------------------------------------------------------

# New write to verify row restriction based on Timestamp and nullability
- pipeline:
type: chain
transforms:
- type: Create
config:
elements:
- {label: "4a", rank: 3, timestamp: "2024-07-14 00:00:00 UTC"}
- {label: "5a", rank: 4}
- {label: "6a", rank: 5, timestamp: "2024-07-14T02:00:00.123Z"}
- type: WriteToBigQuery
config:
table: "{BQ_TABLE_1}"

# New read from BQ to verify row restriction with nullable field and filter
# out nullable record
- pipeline:
type: chain
transforms:
- type: ReadFromBigQuery
config:
table: "{BQ_TABLE_1}"
fields: ["label","rank","timestamp"]
row_restriction: "TIMESTAMP(timestamp) <= TIMESTAMP_SUB('2025-07-14 04:00:00', INTERVAL 4 HOUR)"
- type: AssertEqual
config:
elements:
- {label: "4a", rank: 3, timestamp: "2024-07-14 00:00:00 UTC"}
- {label: "6a", rank: 5,timestamp: "2024-07-14T02:00:00.123Z"}
options:
project: "apache-beam-testing"
temp_location: "{TEMP_DIR_1}"

# New read from BQ to verify row restriction with nullable field and keep
# nullable record
- pipeline:
type: chain
transforms:
- type: ReadFromBigQuery
config:
table: "{BQ_TABLE_1}"
fields: ["timestamp", "label", "rank"]
row_restriction: "timestamp is NULL"
- type: AssertEqual
config:
elements:
- {label: "5a", rank: 4}
options:
project: "apache-beam-testing"
temp_location: "{TEMP_DIR_1}"

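An aside, not part of the test file: the cutoff in the row_restriction above can be verified with ordinary datetime arithmetic, which shows why both timestamped rows pass the first filter while the NULL-timestamp row is excluded.

# Check the TIMESTAMP_SUB cutoff used in the row_restriction above.
from datetime import datetime, timedelta

cutoff = datetime(2025, 7, 14, 4, 0, 0) - timedelta(hours=4)
assert cutoff == datetime(2025, 7, 14, 0, 0, 0)
# Both labeled rows are stamped 2024-07-14, well before the cutoff, so
# they pass; the "5a" row has a NULL timestamp, and NULL <= cutoff is
# not TRUE in BigQuery SQL, so it is filtered out.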
34 changes: 34 additions & 0 deletions sdks/python/apache_beam/yaml/tests/create.yaml
@@ -81,3 +81,37 @@ pipelines:
- {sdk: MapReduce, year: 2004}
- {sdk: Flume}
- {sdk: MillWheel, year: 2008}

# Simple Create with explicit null value
- pipeline:
type: chain
transforms:
- type: Create
config:
elements:
- {sdk: MapReduce, year: 2004}
> Contributor:
> @damccorm this is related to what we discussed. We already have ValidateWithSchema for YAML, which can validate a schema, but we have no way to define one. For example, here `year` could be NULL, so the schema Beam infers is Optional[Any], which WriteToBigQuery cannot accept since it does not allow the Any type. We should add something like ValidateWithSchema that lets users define the output schema.
>
> @derrickaw (Collaborator, Author), Jul 30, 2025:
> Discussed with @damccorm offline, and he thought adding an output schema to Create and most other transforms, with bad records going to error handling, was a good idea. I will open an issue for this (#35742). Thanks.
- {sdk: Flume, year: null}
- {sdk: MillWheel, year: 2008}
- type: AssertEqual
config:
elements:
- {sdk: MapReduce, year: 2004}
- {sdk: Flume, year: null}
- {sdk: MillWheel, year: 2008}

# Simple Create with explicit null values for the entire record
- pipeline:
type: chain
transforms:
- type: Create
config:
elements:
- {sdk: MapReduce, year: 2004}
- {sdk: null, year: null}
- {sdk: MillWheel, year: 2008}
- type: AssertEqual
config:
elements:
- {sdk: MapReduce, year: 2004}
- {}
- {sdk: MillWheel, year: 2008}
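A minimal usage sketch, assuming a Beam build that includes the core.py change above, of the schema now inferred for elements like the first pipeline's, where one record carries a null year:

# Sketch only: requires Apache Beam with the inference change above.
import apache_beam as beam

rows = [
    beam.Row(sdk='MapReduce', year=2004),
    beam.Row(sdk='Flume', year=None),
    beam.Row(sdk='MillWheel', year=2008),
]
# Expected: Row(sdk=str, year=Optional[int]) instead of a flat Union.
print(beam.Create(rows).infer_output_type(None))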