diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 6e0170c04ea7..1de1506159ef 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -29,6 +29,7 @@
 import traceback
 import types
 import typing
+from collections import defaultdict
 from itertools import dropwhile
 
 from apache_beam import coders
@@ -3962,9 +3963,41 @@ def to_runner_api_parameter(self, context):
   def infer_output_type(self, unused_input_type):
     if not self.values:
       return typehints.Any
-    return typehints.Union[[
-        trivial_inference.instance_to_type(v) for v in self.values
-    ]]
+
+    # No field data - fall back to the default Union of instance types.
+    if not hasattr(self.values[0], 'as_dict'):
+      return typehints.Union[[
+          trivial_inference.instance_to_type(v) for v in self.values
+      ]]
+
+    first_fields = self.values[0].as_dict().keys()
+
+    # Collect the set of inferred types observed for each field.
+    field_types_by_field = defaultdict(set)
+    for row in self.values:
+      row_dict = row.as_dict()
+      for field in first_fields:
+        field_types_by_field[field].add(
+            trivial_inference.instance_to_type(row_dict.get(field)))
+
+    # Determine the appropriate type for each field.
+    final_fields = []
+    for field in first_fields:
+      field_types = field_types_by_field[field]
+      non_none_types = {t for t in field_types if t is not type(None)}
+
+      if len(non_none_types) > 1:
+        final_type = typehints.Union[tuple(non_none_types)]
+      elif len(non_none_types) == 1 and len(field_types) == 1:
+        final_type = non_none_types.pop()
+      elif len(non_none_types) == 1 and len(field_types) == 2:
+        final_type = typehints.Optional[non_none_types.pop()]
+      else:
+        raise TypeError('No types found for field %s' % field)
+
+      final_fields.append((field, final_type))
+
+    return row_type.RowTypeConstraint.from_fields(final_fields)
 
   def get_output_type(self):
     return (
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 542544bce3c1..57fa21517349 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -31,6 +31,7 @@
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms.window import FixedWindows
 from apache_beam.typehints import TypeCheckError
+from apache_beam.typehints import row_type
 from apache_beam.typehints import typehints
 
 RETURN_NONE_PARTIAL_WARNING = "No iterator is returned"
@@ -322,6 +323,36 @@ def test_typecheck_with_default(self):
         | beam.Map(lambda s: s.upper()).with_input_types(str))
 
 
+class CreateInferOutputSchemaTest(unittest.TestCase):
+  def test_multiple_types_for_field(self):
+    output_type = beam.Create([beam.Row(a=1),
+                               beam.Row(a='foo')]).infer_output_type(None)
+    self.assertEqual(
+        output_type,
+        row_type.RowTypeConstraint.from_fields([
+            ('a', typehints.Union[int, str])
+        ]))
+
+  def test_single_type_for_field(self):
+    output_type = beam.Create([beam.Row(a=1),
+                               beam.Row(a=2)]).infer_output_type(None)
+    self.assertEqual(
+        output_type, row_type.RowTypeConstraint.from_fields([('a', int)]))
+
+  def test_optional_type_for_field(self):
+    output_type = beam.Create([beam.Row(a=1),
+                               beam.Row(a=None)]).infer_output_type(None)
+    self.assertEqual(
+        output_type,
+        row_type.RowTypeConstraint.from_fields([('a', typehints.Optional[int])
+                                                ]))
+
+  def test_none_type_for_field_raises_error(self):
+    with self.assertRaisesRegex(TypeError,
+                                "No types found for field a"):
+      beam.Create([beam.Row(a=None), beam.Row(a=None)]).infer_output_type(None)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
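Note: with the core.py change above, Create.infer_output_type now produces a per-field row schema for beam.Row elements instead of a single Union over whole instances. Below is a minimal sketch of the resulting behavior, assuming only what the new unit tests assert; the label/rank field names are illustrative, not part of the change.

import apache_beam as beam
from apache_beam.typehints import row_type
from apache_beam.typehints import typehints

rows = [
    beam.Row(label='4a', rank=3),
    beam.Row(label='5a', rank=None),  # rank explicitly null in one row
]
# Types are collected per field: label is always str; rank is int in one
# row and None in the other, so it is inferred as Optional[int].
hint = beam.Create(rows).infer_output_type(None)
assert hint == row_type.RowTypeConstraint.from_fields([
    ('label', str),
    ('rank', typehints.Optional[int]),
])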
diff --git a/sdks/python/apache_beam/yaml/extended_tests/databases/bigquery.yaml b/sdks/python/apache_beam/yaml/extended_tests/databases/bigquery.yaml
index f5ab31b3855b..d0357e098bf3 100644
--- a/sdks/python/apache_beam/yaml/extended_tests/databases/bigquery.yaml
+++ b/sdks/python/apache_beam/yaml/extended_tests/databases/bigquery.yaml
@@ -16,11 +16,20 @@
 #
 
 fixtures:
-  - name: BQ_TABLE
+  - name: BQ_TABLE_0
     type: "apache_beam.yaml.integration_tests.temp_bigquery_table"
     config:
       project: "apache-beam-testing"
-  - name: TEMP_DIR
+  - name: TEMP_DIR_0
+    # Need distributed filesystem to be able to read and write from a container.
+    type: "apache_beam.yaml.integration_tests.gcs_temp_dir"
+    config:
+      bucket: "gs://temp-storage-for-end-to-end-tests/temp-it"
+  - name: BQ_TABLE_1
+    type: "apache_beam.yaml.integration_tests.temp_bigquery_table"
+    config:
+      project: "apache-beam-testing"
+  - name: TEMP_DIR_1
     # Need distributed filesystem to be able to read and write from a container.
     type: "apache_beam.yaml.integration_tests.gcs_temp_dir"
     config:
@@ -38,17 +47,17 @@ pipelines:
               - {label: "389a", rank: 2}
         - type: WriteToBigQuery
           config:
-            table: "{BQ_TABLE}"
+            table: "{BQ_TABLE_0}"
     options:
       project: "apache-beam-testing"
-      temp_location: "{TEMP_DIR}"
+      temp_location: "{TEMP_DIR_0}"
 
   - pipeline:
       type: chain
       transforms:
         - type: ReadFromBigQuery
           config:
-            table: "{BQ_TABLE}"
+            table: "{BQ_TABLE_0}"
         - type: AssertEqual
           config:
             elements:
@@ -57,14 +66,14 @@
               - {label: "389a", rank: 2}
     options:
       project: "apache-beam-testing"
-      temp_location: "{TEMP_DIR}"
+      temp_location: "{TEMP_DIR_0}"
 
   - pipeline:
       type: chain
       transforms:
         - type: ReadFromBigQuery
           config:
-            table: "{BQ_TABLE}"
+            table: "{BQ_TABLE_0}"
             fields: ["label"]
             row_restriction: "rank > 0"
         - type: AssertEqual
@@ -74,4 +83,58 @@
               - {label: "389a"}
     options:
       project: "apache-beam-testing"
-      temp_location: "{TEMP_DIR}"
+      temp_location: "{TEMP_DIR_0}"
+
+  # ----------------------------------------------------------------------------
+
+  # New write to verify row restriction based on timestamp and nullability
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - {label: "4a", rank: 3, timestamp: "2024-07-14 00:00:00 UTC"}
+              - {label: "5a", rank: 4}
+              - {label: "6a", rank: 5, timestamp: "2024-07-14T02:00:00.123Z"}
+        - type: WriteToBigQuery
+          config:
+            table: "{BQ_TABLE_1}"
+
+  # New read from BQ to verify a row restriction on the nullable field that
+  # filters out the record whose timestamp is null
+  - pipeline:
+      type: chain
+      transforms:
+        - type: ReadFromBigQuery
+          config:
+            table: "{BQ_TABLE_1}"
+            fields: ["label", "rank", "timestamp"]
+            row_restriction: "TIMESTAMP(timestamp) <= TIMESTAMP_SUB('2025-07-14 04:00:00', INTERVAL 4 HOUR)"
+        - type: AssertEqual
+          config:
+            elements:
+              - {label: "4a", rank: 3, timestamp: "2024-07-14 00:00:00 UTC"}
+              - {label: "6a", rank: 5, timestamp: "2024-07-14T02:00:00.123Z"}
+    options:
+      project: "apache-beam-testing"
+      temp_location: "{TEMP_DIR_1}"
+
+  # New read from BQ to verify a row restriction on the nullable field that
+  # keeps only the record whose timestamp is null
+  - pipeline:
+      type: chain
+      transforms:
+        - type: ReadFromBigQuery
+          config:
+            table: "{BQ_TABLE_1}"
+            fields: ["timestamp", "label", "rank"]
+            row_restriction: "timestamp is NULL"
+        - type: AssertEqual
+          config:
+            elements:
+              - {label: "5a", rank: 4}
+    options:
+      project: "apache-beam-testing"
+      temp_location: "{TEMP_DIR_1}"
+
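Note on the two new read pipelines: TIMESTAMP_SUB('2025-07-14 04:00:00', INTERVAL 4 HOUR) is 2025-07-14 00:00:00, so both written timestamps satisfy the first restriction; the "5a" record is excluded only because its timestamp is NULL and a SQL comparison against NULL does not evaluate to TRUE. The second restriction selects exactly that record. A rough local analogue in plain Python (not BigQuery) of how the two restrictions partition the test data:

rows = [
    {'label': '4a', 'rank': 3, 'timestamp': '2024-07-14 00:00:00 UTC'},
    {'label': '5a', 'rank': 4, 'timestamp': None},  # written with no timestamp
    {'label': '6a', 'rank': 5, 'timestamp': '2024-07-14T02:00:00.123Z'},
]
# First read: a comparison predicate never matches NULL, keeping "4a" and "6a".
with_timestamp = [r for r in rows if r['timestamp'] is not None]
# Second read: "timestamp is NULL" keeps only "5a".
null_timestamp = [r for r in rows if r['timestamp'] is None]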
diff --git a/sdks/python/apache_beam/yaml/tests/create.yaml b/sdks/python/apache_beam/yaml/tests/create.yaml
index 723d8a888c26..30f276671874 100644
--- a/sdks/python/apache_beam/yaml/tests/create.yaml
+++ b/sdks/python/apache_beam/yaml/tests/create.yaml
@@ -81,3 +81,37 @@ pipelines:
               - {sdk: MapReduce, year: 2004}
               - {sdk: Flume}
               - {sdk: MillWheel, year: 2008}
+
+  # Simple Create with an explicit null value
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - {sdk: MapReduce, year: 2004}
+              - {sdk: Flume, year: null}
+              - {sdk: MillWheel, year: 2008}
+        - type: AssertEqual
+          config:
+            elements:
+              - {sdk: MapReduce, year: 2004}
+              - {sdk: Flume, year: null}
+              - {sdk: MillWheel, year: 2008}
+
+  # Simple Create with explicit null values for the entire record
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - {sdk: MapReduce, year: 2004}
+              - {sdk: null, year: null}
+              - {sdk: MillWheel, year: 2008}
+        - type: AssertEqual
+          config:
+            elements:
+              - {sdk: MapReduce, year: 2004}
+              - {}
+              - {sdk: MillWheel, year: 2008}
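For local experimentation, the create.yaml cases above can also be run outside the test harness by embedding the same spec in a Python pipeline. A sketch under the assumption that YamlTransform accepts this chain spec directly; the spec body is copied verbatim from the first new test case:

import apache_beam as beam
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline() as p:
  p | YamlTransform('''
      type: chain
      transforms:
        - type: Create
          config:
            elements:
              - {sdk: MapReduce, year: 2004}
              - {sdk: Flume, year: null}
              - {sdk: MillWheel, year: 2008}
        - type: AssertEqual
          config:
            elements:
              - {sdk: MapReduce, year: 2004}
              - {sdk: Flume, year: null}
              - {sdk: MillWheel, year: 2008}
      ''')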