diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py
index fa8b56f916dc..82ae9a50ace4 100644
--- a/sdks/python/apache_beam/io/parquetio.py
+++ b/sdks/python/apache_beam/io/parquetio.py
@@ -119,7 +119,12 @@
   def process(self, row, w=DoFn.WindowParam, pane=DoFn.PaneInfoParam):
     # reorder the data in columnar format.
     for i, n in enumerate(self._schema.names):
-      self._buffer[i].append(row[n])
+      # Handle missing nullable fields by using None as default value
+      field = self._schema.field(i)
+      if field.nullable and n not in row:
+        self._buffer[i].append(None)
+      else:
+        self._buffer[i].append(row[n])
 
   def finish_bundle(self):
     if len(self._buffer[0]) > 0:
diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py
index 9371705a1fa3..78d1db4cc7c2 100644
--- a/sdks/python/apache_beam/io/parquetio_test.py
+++ b/sdks/python/apache_beam/io/parquetio_test.py
@@ -59,12 +59,11 @@
 try:
   import pyarrow as pa
   import pyarrow.parquet as pq
+  ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
 except ImportError:
   pa = None
-  pl = None
   pq = None
-
-ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
+  ARROW_MAJOR_VERSION = 0
 
 
 @unittest.skipIf(pa is None, "PyArrow is not installed.")
@@ -422,6 +421,76 @@ def test_schema_read_write(self):
           | Map(stable_repr))
     assert_that(readback, equal_to([stable_repr(r) for r in rows]))
 
+  def test_write_with_nullable_fields_missing_data(self):
+    """Test WriteToParquet with nullable fields where some fields are missing.
+
+    This test addresses the bug reported in:
+    https://github.com/apache/beam/issues/35791
+    where WriteToParquet fails with a KeyError if any nullable
+    field is missing in the data.
+    """
+    # Define PyArrow schema with all fields nullable
+    schema = pa.schema([
+        pa.field("id", pa.int64(), nullable=True),
+        pa.field("name", pa.string(), nullable=True),
+        pa.field("age", pa.int64(), nullable=True),
+        pa.field("email", pa.string(), nullable=True),
+    ])
+
+    # Sample data with missing nullable fields
+    data = [
+        {
+            'id': 1, 'name': 'Alice', 'age': 30
+        },  # missing 'email'
+        {
+            'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
+        },  # all fields present
+        {
+            'id': 3, 'name': 'Charlie', 'age': None, 'email': None
+        },  # explicit None values
+        {
+            'id': 4, 'name': 'David'
+        },  # missing 'age' and 'email'
+    ]
+
+    with TemporaryDirectory() as tmp_dirname:
+      path = os.path.join(tmp_dirname, 'nullable_test')
+
+      # Write data with missing nullable fields - this should not raise KeyError
+      with TestPipeline() as p:
+        _ = (
+            p
+            | Create(data)
+            | WriteToParquet(
+                path, schema, num_shards=1, shard_name_template=''))
+
+      # Read back and verify the data
+      with TestPipeline() as p:
+        readback = (
+            p
+            | ReadFromParquet(path + '*')
+            | Map(json.dumps, sort_keys=True))
+
+        # Expected data should have None for missing nullable fields
+        expected_data = [
+            {
+                'id': 1, 'name': 'Alice', 'age': 30, 'email': None
+            },
+            {
+                'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
+            },
+            {
+                'id': 3, 'name': 'Charlie', 'age': None, 'email': None
+            },
+            {
+                'id': 4, 'name': 'David', 'age': None, 'email': None
+            },
+        ]
+
+        assert_that(
+            readback,
+            equal_to([json.dumps(r, sort_keys=True) for r in expected_data]))
+
   def test_batched_read(self):
     with TemporaryDirectory() as tmp_dirname:
       path = os.path.join(tmp_dirname + "tmp_filename")
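
For context, a minimal standalone sketch of the columnar buffering that the first hunk changes, kept outside the patch. It assumes only that pyarrow is installed; the schema, rows, and buffer names below are illustrative stand-ins, not Beam internals.

    import pyarrow as pa

    # Schema with nullable fields, mirroring the shape used in the new test.
    schema = pa.schema([
        pa.field('id', pa.int64(), nullable=True),
        pa.field('email', pa.string(), nullable=True),
    ])

    # The second row is missing the nullable 'email' field entirely.
    rows = [{'id': 1, 'email': 'a@example.com'}, {'id': 2}]

    # Reorder rows into columnar buffers, as process() does above,
    # substituting None for absent nullable fields instead of raising KeyError.
    buffers = [[] for _ in schema.names]
    for row in rows:
      for i, n in enumerate(schema.names):
        if schema.field(i).nullable and n not in row:
          buffers[i].append(None)
        else:
          buffers[i].append(row[n])

    table = pa.Table.from_arrays(
        [pa.array(col, type=schema.field(i).type)
         for i, col in enumerate(buffers)],
        schema=schema)
    print(table.to_pydict())  # {'id': [1, 2], 'email': ['a@example.com', None]}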