def avro_union_type_to_beam_type(union_type: List) -> schema_pb2.FieldType:
  """Convert an Avro union type to a Beam field type.

  A two-branch union of ``"null"`` and an Avro primitive that has a
  corresponding Beam primitive becomes a nullable Beam field of that
  primitive type. Every other union is mapped to the Beam type for ``Any``.

  Args:
    union_type: the Avro union, i.e. the list of its branch schemas. A branch
      may be a type-name string or a nested schema (dict or list).

  Returns:
    the Beam type of the Avro union.
  """
  if len(union_type) == 2 and "null" in union_type:
    for avro_type in union_type:
      # Complex branches (e.g. {"type": "array", ...}) are unhashable, so a
      # direct mapping lookup would raise TypeError instead of falling back to
      # Any. Only type-name strings can name a primitive.
      if (isinstance(avro_type, str) and
          avro_type in AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES):
        return schema_pb2.FieldType(
            atomic_type=AVRO_PRIMITIVES_TO_BEAM_PRIMITIVES[avro_type],
            nullable=True)
  return schemas.typing_to_runner_api(Any)
def avro_atomic_value_to_beam_atomic_value(avro_type: str, value):
  """Convert an Avro atomic value to a Beam atomic value.

  Avro ``int`` and ``long`` values are reinterpreted from signed to unsigned
  (32 and 64 bits wide respectively) because VarInt.java expects the number
  to be unsigned when decoding it. ``None`` and values of every other Avro
  type are returned unchanged.

  Args:
    avro_type: the Avro type of the corresponding value.
    value: the Avro atomic value.

  Returns:
    the converted Beam atomic value.
  """
  if value is None:
    return value
  elif avro_type == "int":
    return ctypes.c_uint32(value).value
  elif avro_type == "long":
    return ctypes.c_uint64(value).value
  else:
    return value


def unnest_primitive_type(beam_type: schema_pb2.FieldType):
  """Convert a Beam type to Avro, unnesting primitives and nullable unions.

  Atomic Beam types convert to ``{'type': X}``; when such a type is used as a
  field, item or value type only the inner ``X`` is wanted. Complex types,
  like arrays, keep their full schema dict.

  Example: { 'type': 'string' } -> 'string'
           { 'type': 'array', 'items': 'string' }
              -> { 'type': 'array', 'items': 'string' }

  Args:
    beam_type: the Beam type to map to Avro.

  Returns:
    the converted Avro type with the primitive or union type unnested.
  """
  avro_type = beam_type_to_avro_type(beam_type)
  return avro_type['type'] if beam_type.WhichOneof(
      "type_info") == "atomic_type" else avro_type


def beam_type_to_avro_type(beam_type: schema_pb2.FieldType) -> _AvroSchemaType:
  """Convert a Beam field type to the equivalent Avro schema.

  Args:
    beam_type: the Beam type to convert.

  Returns:
    a dict with a ``'type'`` key. For atomic types the value is the Avro
    primitive name, or the union ``[primitive, 'null']`` when the Beam type
    is nullable. Arrays and iterables become Avro arrays, maps become Avro
    maps and rows become Avro records.

  Raises:
    TypeError: if a map type has non-string keys.
    ValueError: if the Beam type has no Avro equivalent handled here.
  """
  type_info = beam_type.WhichOneof("type_info")
  if type_info == "atomic_type":
    avro_primitive = BEAM_PRIMITIVES_TO_AVRO_PRIMITIVES[beam_type.atomic_type]
    avro_type = [
        avro_primitive, 'null'
    ] if beam_type.nullable else avro_primitive
    return {'type': avro_type}
  elif type_info == "array_type":
    return {
        'type': 'array',
        'items': unnest_primitive_type(beam_type.array_type.element_type)
    }
  elif type_info == "iterable_type":
    return {
        'type': 'array',
        'items': unnest_primitive_type(beam_type.iterable_type.element_type)
    }
  elif type_info == "map_type":
    if beam_type.map_type.key_type.atomic_type != schema_pb2.STRING:
      raise TypeError(
          f'Only strings allowed as map keys when converting to AVRO, '
          f'found {beam_type}')
    return {
        'type': 'map',
        # The map's value type lives in `value_type` (as the value converters
        # in this module read it); `element_type` only exists on array and
        # iterable types.
        'values': unnest_primitive_type(beam_type.map_type.value_type)
    }
  elif type_info == "row_type":
    return {
        'type': 'record',
        'name': beam_type.row_type.schema.id,
        'fields': [{
            'name': field.name, 'type': unnest_primitive_type(field.type)
        } for field in beam_type.row_type.schema.fields],
    }
  else:
    raise ValueError(f"Unconvertable type: {beam_type}")


def beam_atomic_value_to_avro_atomic_value(avro_type: str, value):
  """Convert a Beam atomic value to an Avro atomic value.

  Since numeric values are converted to unsigned in
  avro_atomic_value_to_beam_atomic_value, ``int`` and ``long`` values are
  reinterpreted back as signed 32 and 64 bit numbers here. ``None`` and
  values of every other Avro type are returned unchanged.

  Args:
    avro_type: the Avro type of the corresponding value.
    value: the Beam atomic value.

  Returns:
    the converted Avro atomic value.
  """
  if value is None:
    return value
  elif avro_type == "int":
    return ctypes.c_int32(value).value
  elif avro_type == "long":
    return ctypes.c_int64(value).value
  else:
    return value
beam_type.row_type.schema.fields } return lambda value: { diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index c54ac40711b1..c95fbb612592 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -20,31 +20,41 @@ import logging import math import os +import pytest import tempfile import unittest -from typing import List +from typing import List, Any +import fastavro import hamcrest as hc from fastavro.schema import parse_schema from fastavro import writer import apache_beam as beam -from apache_beam import Create +from apache_beam import Create, schema_pb2 from apache_beam.io import avroio from apache_beam.io import filebasedsource from apache_beam.io import iobase from apache_beam.io import source_test_utils from apache_beam.io.avroio import _FastAvroSource # For testing +from apache_beam.io.avroio import avro_schema_to_beam_schema # For testing +from apache_beam.io.avroio import beam_schema_to_avro_schema # For testing +from apache_beam.io.avroio import avro_atomic_value_to_beam_atomic_value # For testing +from apache_beam.io.avroio import avro_union_type_to_beam_type # For testing +from apache_beam.io.avroio import beam_atomic_value_to_avro_atomic_value # For testing from apache_beam.io.avroio import _create_avro_sink # For testing from apache_beam.io.filesystems import FileSystems +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher +from apache_beam.transforms.sql import SqlTransform from apache_beam.transforms.userstate import CombiningValueStateSpec from apache_beam.utils.timestamp import Timestamp +from apache_beam.typehints import schemas # Import snappy optionally; some 
tests will be skipped when import fails. try: @@ -160,6 +170,76 @@ def test_schema_read_write(self): | beam.Map(stable_repr)) assert_that(readback, equal_to([stable_repr(r) for r in rows])) + @pytest.mark.xlang_sql_expansion_service + @unittest.skipIf( + TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is + None, + "Must be run with a runner that supports staging java artifacts.") + def test_avro_schema_to_beam_schema_with_nullable_atomic_fields(self): + records = [] + records.extend(self.RECORDS) + records.append({ + 'name': 'Bruce', 'favorite_number': None, 'favorite_color': None + }) + with tempfile.TemporaryDirectory() as tmp_dirname_input: + input_path = os.path.join(tmp_dirname_input, 'tmp_filename.avro') + parsed_schema = fastavro.parse_schema(json.loads(self.SCHEMA_STRING)) + with open(input_path, 'wb') as tmp_avro_file: + fastavro.writer(tmp_avro_file, parsed_schema, records) + + with tempfile.TemporaryDirectory() as tmp_dirname_output: + + with TestPipeline() as p: + _ = ( + p + | avroio.ReadFromAvro(input_path, as_rows=True) + | SqlTransform("SELECT * FROM PCOLLECTION") + | avroio.WriteToAvro(tmp_dirname_output)) + with TestPipeline() as p: + readback = (p | avroio.ReadFromAvro(tmp_dirname_output + "*")) + assert_that(readback, equal_to(records)) + + def test_avro_atomic_value_to_beam_atomic_value(self): + input_outputs = [('int', 1, 1), ('int', -1, 0xffffffff), + ('int', None, None), ('long', 1, 1), + ('long', -1, 0xffffffffffffffff), ('long', None, None), + ('string', 'foo', 'foo')] + for test_avro_type, test_value, expected_value in input_outputs: + actual_value = avro_atomic_value_to_beam_atomic_value( + test_avro_type, test_value) + hc.assert_that(actual_value, hc.equal_to(expected_value)) + + def test_beam_atomic_value_to_avro_atomic_value(self): + input_outputs = [('int', 1, 1), ('int', 0xffffffff, -1), + ('int', None, None), ('long', 1, 1), + ('long', 0xffffffffffffffff, -1), ('long', None, None), + ('string', 'foo', 'foo')] 
+ for test_avro_type, test_value, expected_value in input_outputs: + actual_value = beam_atomic_value_to_avro_atomic_value( + test_avro_type, test_value) + hc.assert_that(actual_value, hc.equal_to(expected_value)) + + def test_avro_union_type_to_beam_type_with_nullable_long(self): + union_type = ['null', 'long'] + beam_type = avro_union_type_to_beam_type(union_type) + expected_beam_type = schema_pb2.FieldType( + atomic_type=schema_pb2.INT64, nullable=True) + hc.assert_that(beam_type, hc.equal_to(expected_beam_type)) + + def test_avro_union_type_to_beam_type_with_string_long(self): + union_type = ['string', 'long'] + beam_type = avro_union_type_to_beam_type(union_type) + expected_beam_type = schemas.typing_to_runner_api(Any) + hc.assert_that(beam_type, hc.equal_to(expected_beam_type)) + + def test_avro_schema_to_beam_and_back(self): + avro_schema = fastavro.parse_schema(json.loads(self.SCHEMA_STRING)) + beam_schema = avro_schema_to_beam_schema(avro_schema) + converted_avro_schema = beam_schema_to_avro_schema(beam_schema) + expected_fields = json.loads(self.SCHEMA_STRING)["fields"] + hc.assert_that( + converted_avro_schema["fields"], hc.equal_to(expected_fields)) + def test_read_without_splitting(self): file_name = self._write_data() expected_result = self.RECORDS