From ffa074866a5797eb29f2270ae4894a92184a697f Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 4 Oct 2023 16:53:34 +0200 Subject: [PATCH 1/8] Construct a writer tree For V1 and V2 there are some differences that are hard to enforce without this: - `1: snapshot_id` is required for V1, optional for V2 - `105: block_size_in_bytes` needs to be written for V1, but omitted for V2 (this leverages the `write-default`). - `3: sequence_number` and `4: file_sequence_number` can be omited for V1. Everything that we read, we map it to V2. However, when writing we also want to be compliant with the V1 spec, and this is where the writer tree comes in since we construct a tree for V1 or V2. --- pyiceberg/avro/file.py | 25 +- pyiceberg/avro/resolver.py | 136 +++++++++- pyiceberg/avro/writer.py | 36 ++- pyiceberg/manifest.py | 382 +++++++++++++++++------------ pyiceberg/schema.py | 2 +- pyiceberg/types.py | 3 + tests/avro/test_decoder.py | 4 +- tests/avro/test_file.py | 89 ++++++- tests/avro/test_reader.py | 4 +- tests/avro/test_resolver.py | 85 +++++-- tests/test_integration_manifest.py | 1 - tests/utils/test_manifest.py | 21 +- 12 files changed, 566 insertions(+), 222 deletions(-) diff --git a/pyiceberg/avro/file.py b/pyiceberg/avro/file.py index dc843f6dc0..995b211927 100644 --- a/pyiceberg/avro/file.py +++ b/pyiceberg/avro/file.py @@ -38,7 +38,7 @@ from pyiceberg.avro.decoder import BinaryDecoder, new_decoder from pyiceberg.avro.encoder import BinaryEncoder from pyiceberg.avro.reader import Reader -from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve +from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve_reader, resolve_writer from pyiceberg.avro.writer import Writer from pyiceberg.io import InputFile, OutputFile, OutputStream from pyiceberg.schema import Schema @@ -172,7 +172,7 @@ def __enter__(self) -> AvroFile[D]: if not self.read_schema: self.read_schema = self.schema - self.reader = resolve(self.schema, self.read_schema, self.read_types, self.read_enums) + self.reader = resolve_reader(self.schema, self.read_schema, self.read_types, self.read_enums) return self @@ -222,18 +222,28 @@ def _read_header(self) -> AvroFileHeader: class AvroOutputFile(Generic[D]): output_file: OutputFile output_stream: OutputStream - schema: Schema + file_schema: Schema schema_name: str encoder: BinaryEncoder sync_bytes: bytes writer: Writer - def __init__(self, output_file: OutputFile, schema: Schema, schema_name: str, metadata: Dict[str, str] = EMPTY_DICT) -> None: + def __init__( + self, + output_file: OutputFile, + file_schema: Schema, + schema_name: str, + schema: Optional[Schema] = None, + metadata: Dict[str, str] = EMPTY_DICT, + ) -> None: self.output_file = output_file - self.schema = schema + self.file_schema = file_schema self.schema_name = schema_name self.sync_bytes = os.urandom(SYNC_SIZE) - self.writer = construct_writer(self.schema) + if schema is None: + self.writer = construct_writer(self.file_schema) + else: + self.writer = resolve_writer(self.file_schema, schema) self.metadata = metadata def __enter__(self) -> AvroOutputFile[D]: @@ -247,7 +257,6 @@ def __enter__(self) -> AvroOutputFile[D]: self.encoder = BinaryEncoder(self.output_stream) self._write_header() - self.writer = construct_writer(self.schema) return self @@ -258,7 +267,7 @@ def __exit__( self.output_stream.close() def _write_header(self) -> None: - json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.schema, schema_name=self.schema_name)) + json_schema = 
json.dumps(AvroSchemaConversion().iceberg_to_avro(self.file_schema, schema_name=self.schema_name)) meta = {**self.metadata, _SCHEMA_KEY: json_schema, _CODEC_KEY: "null"} header = AvroFileHeader(magic=MAGIC, meta=meta, sync=self.sync_bytes) construct_writer(META_SCHEMA).write(self.encoder, header) diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 8b2daeb7c7..905704b18f 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -53,12 +53,14 @@ BooleanWriter, DateWriter, DecimalWriter, + DefaultWriter, DoubleWriter, FixedWriter, FloatWriter, IntegerWriter, ListWriter, MapWriter, + NoneWriter, OptionWriter, StringWriter, StructWriter, @@ -112,11 +114,12 @@ def construct_reader( Args: file_schema (Schema | IcebergType): The schema of the Avro file. + read_types (Dict[int, Callable[..., StructProtocol]]): Constructors for structs for certain field-ids Raises: NotImplementedError: If attempting to resolve an unrecognized object type. """ - return resolve(file_schema, file_schema, read_types) + return resolve_reader(file_schema, file_schema, read_types) def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer: @@ -128,7 +131,7 @@ def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer: Raises: NotImplementedError: If attempting to resolve an unrecognized object type. """ - return visit(file_schema, ConstructWriter()) + return visit(file_schema, CONSTRUCT_WRITER_VISITOR) class ConstructWriter(SchemaVisitorPerPrimitiveType[Writer]): @@ -138,7 +141,7 @@ def schema(self, schema: Schema, struct_result: Writer) -> Writer: return struct_result def struct(self, struct: StructType, field_results: List[Writer]) -> Writer: - return StructWriter(tuple(field_results)) + return StructWriter(tuple((pos, result) for pos, result in enumerate(field_results))) def field(self, field: NestedField, field_result: Writer) -> Writer: return field_result if field.required else OptionWriter(field_result) @@ -192,7 +195,26 @@ def visit_binary(self, binary_type: BinaryType) -> Writer: return BinaryWriter() -def resolve( +CONSTRUCT_WRITER_VISITOR = ConstructWriter() + + +def resolve_writer( + struct_schema: Union[Schema, IcebergType], + write_schema: Union[Schema, IcebergType], +) -> Writer: + """Resolve the file and read schema to produce a reader. + + Args: + struct_schema (Schema | IcebergType): The schema of the Avro file. + write_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema. + + Raises: + NotImplementedError: If attempting to resolve an unrecognized object type. + """ + return visit_with_partner(struct_schema, write_schema, WriteSchemaResolver(), SchemaPartnerAccessor()) # type: ignore + + +def resolve_reader( file_schema: Union[Schema, IcebergType], read_schema: Union[Schema, IcebergType], read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT, @@ -210,7 +232,7 @@ def resolve( NotImplementedError: If attempting to resolve an unrecognized object type. 
""" return visit_with_partner( - file_schema, read_schema, SchemaResolver(read_types, read_enums), SchemaPartnerAccessor() + file_schema, read_schema, ReadSchemaResolver(read_types, read_enums), SchemaPartnerAccessor() ) # type: ignore @@ -233,7 +255,107 @@ def skip(self, decoder: BinaryDecoder) -> None: pass -class SchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]): +class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]): + def schema(self, schema: Schema, expected_schema: Optional[IcebergType], result: Writer) -> Writer: + return result + + def struct(self, struct: StructType, provided_struct: Optional[IcebergType], field_writers: List[Writer]) -> Writer: + if not isinstance(provided_struct, StructType): + raise ResolveError(f"File/write schema are not aligned for struct, got {provided_struct}") + + provided_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(provided_struct.fields)} + + results: List[Tuple[Optional[int], Writer]] = [] + iter(field_writers) + + for pos, write_field in enumerate(struct.fields): + if write_field.field_id in provided_struct_positions: + results.append((provided_struct_positions[write_field.field_id], field_writers[pos])) + else: + # There is a default value + if isinstance(write_field, NestedField) and write_field.write_default is not None: + # The field is not in the record, but there is a write default value + default_writer = DefaultWriter( + writer=visit(write_field.field_type, CONSTRUCT_WRITER_VISITOR), value=write_field.write_default + ) + results.append((None, default_writer)) + elif write_field.required: + raise ValueError(f"Field is required, and there is no write default: {write_field}") + else: + results.append((pos, NoneWriter())) + + return StructWriter(field_writers=tuple(results)) + + def field(self, field: NestedField, expected_field: Optional[IcebergType], field_writer: Writer) -> Writer: + return field_writer if field.required else OptionWriter(field_writer) + + def list(self, list_type: ListType, expected_list: Optional[IcebergType], element_reader: Writer) -> Writer: + if expected_list and not isinstance(expected_list, ListType): + raise ResolveError(f"File/read schema are not aligned for list, got {expected_list}") + + return ListWriter(element_reader if list_type.element_required else OptionWriter(element_reader)) + + def map(self, map_type: MapType, expected_map: Optional[IcebergType], key_reader: Writer, value_reader: Writer) -> Writer: + if expected_map and not isinstance(expected_map, MapType): + raise ResolveError(f"File/read schema are not aligned for map, got {expected_map}") + + return MapWriter(key_reader, value_reader if map_type.value_required else OptionWriter(value_reader)) + + def primitive(self, primitive: PrimitiveType, expected_primitive: Optional[IcebergType]) -> Writer: + if expected_primitive is not None: + if not isinstance(expected_primitive, PrimitiveType): + raise ResolveError(f"File/read schema are not aligned for {primitive}, got {expected_primitive}") + + # ensure that the type can be projected to the expected + if primitive != expected_primitive: + promote(primitive, expected_primitive) + + return super().primitive(primitive, expected_primitive) + + def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer: + return BooleanWriter() + + def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Writer: + return IntegerWriter() + + def visit_long(self, long_type: LongType, partner: 
Optional[IcebergType]) -> Writer: + return IntegerWriter() + + def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Writer: + return FloatWriter() + + def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Writer: + return DoubleWriter() + + def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Writer: + return DecimalWriter(decimal_type.precision, decimal_type.scale) + + def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Writer: + return DateWriter() + + def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Writer: + return TimeWriter() + + def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Writer: + return TimestampWriter() + + def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Writer: + return TimestamptzWriter() + + def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Writer: + return StringWriter() + + def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Writer: + return UUIDWriter() + + def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Writer: + return FixedWriter(len(fixed_type)) + + def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Writer: + return BinaryWriter() + + +class ReadSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]): __slots__ = ("read_types", "read_enums", "context") read_types: Dict[int, Callable[..., StructProtocol]] read_enums: Dict[int, Callable[..., Enum]] @@ -279,7 +401,7 @@ def struct(self, struct: StructType, expected_struct: Optional[IcebergType], fie for field, result_reader in zip(struct.fields, field_readers) ] - file_fields = {field.field_id: field for field in struct.fields} + file_fields = {field.field_id for field in struct.fields} for pos, read_field in enumerate(expected_struct.fields): if read_field.field_id not in file_fields: if isinstance(read_field, NestedField) and read_field.initial_default is not None: diff --git a/pyiceberg/avro/writer.py b/pyiceberg/avro/writer.py index ad6a755614..fbb3de62be 100644 --- a/pyiceberg/avro/writer.py +++ b/pyiceberg/avro/writer.py @@ -29,6 +29,7 @@ Any, Dict, List, + Optional, Tuple, ) from uuid import UUID @@ -39,6 +40,7 @@ from pyiceberg.utils.singleton import Singleton +@dataclass(frozen=True) class Writer(Singleton): @abstractmethod def write(self, encoder: BinaryEncoder, val: Any) -> Any: @@ -49,16 +51,19 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}()" +@dataclass(frozen=True) class NoneWriter(Writer): - def write(self, _: BinaryEncoder, __: Any) -> None: - pass + def write(self, encoder: BinaryEncoder, __: Any) -> None: + encoder.write_int(0) +@dataclass(frozen=True) class BooleanWriter(Writer): def write(self, encoder: BinaryEncoder, val: bool) -> None: encoder.write_boolean(val) +@dataclass(frozen=True) class IntegerWriter(Writer): """Longs and ints are encoded the same way, and there is no long in Python.""" @@ -66,41 +71,49 @@ def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class FloatWriter(Writer): def write(self, encoder: BinaryEncoder, val: float) -> None: encoder.write_float(val) +@dataclass(frozen=True) class DoubleWriter(Writer): def write(self, encoder: BinaryEncoder, val: float) -> None: encoder.write_double(val) +@dataclass(frozen=True) class DateWriter(Writer): def 
write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class TimeWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class TimestampWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class TimestamptzWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class StringWriter(Writer): def write(self, encoder: BinaryEncoder, val: Any) -> None: encoder.write_utf8(val) +@dataclass(frozen=True) class UUIDWriter(Writer): def write(self, encoder: BinaryEncoder, val: UUID) -> None: encoder.write(val.bytes) @@ -124,6 +137,7 @@ def __repr__(self) -> str: return f"FixedWriter({self._len})" +@dataclass(frozen=True) class BinaryWriter(Writer): """Variable byte length writer.""" @@ -158,11 +172,12 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None: @dataclass(frozen=True) class StructWriter(Writer): - field_writers: Tuple[Writer, ...] = dataclassfield() + field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield() def write(self, encoder: BinaryEncoder, val: Record) -> None: - for writer, value in zip(self.field_writers, val.record_fields()): - writer.write(encoder, value) + for pos, writer in self.field_writers: + # When pos is None, then it is a default value + writer.write(encoder, val[pos] if pos is not None else None) def __eq__(self, other: Any) -> bool: """Implement the equality operator for this object.""" @@ -170,7 +185,7 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: """Return string representation of this object.""" - return f"StructWriter({','.join(repr(field) for field in self.field_writers)})" + return f"StructWriter(tuple(({','.join(repr(field) for field in self.field_writers)})))" def __hash__(self) -> int: """Return the hash of the writer as hash of this object.""" @@ -201,3 +216,12 @@ def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None: self.value_writer.write(encoder, v) if len(val) > 0: encoder.write_int(0) + + +@dataclass(frozen=True) +class DefaultWriter(Writer): + writer: Writer + value: Any + + def write(self, encoder: BinaryEncoder, _: Any) -> None: + self.writer.write(encoder, self.value) diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 8bdbfd3524..efeceac7e7 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -58,6 +58,7 @@ UNASSIGNED_SEQ = -1 DEFAULT_BLOCK_SIZE = 67108864 # 64 * 1024 * 1024 +DEFAULT_READ_VERSION: Literal[2] = 2 class DataFileContent(int, Enum): @@ -99,101 +100,185 @@ def __repr__(self) -> str: return f"FileFormat.{self.name}" -DATA_FILE_TYPE_V1 = StructType( - NestedField( - field_id=134, - name="content", - field_type=IntegerType(), - required=False, - doc="Contents of the file: 0=data, 1=position deletes, 2=equality deletes", - initial_default=DataFileContent.DATA, +DATA_FILE_TYPE: Dict[int, StructType] = { + 1: StructType( + NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"), + NestedField( + field_id=101, + name="file_format", + field_type=StringType(), + required=True, + doc="File format name: avro, orc, or parquet", + ), + NestedField( + field_id=102, + name="partition", + field_type=StructType(), + required=True, + doc="Partition data tuple, schema based on the partition spec", + ), + NestedField(field_id=103, name="record_count", 
field_type=LongType(), required=True, doc="Number of records in the file"), + NestedField( + field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes" + ), + NestedField( + field_id=105, + name="block_size_in_bytes", + field_type=LongType(), + required=True, + doc="Deprecated. Always write a default in v1. Do not write in v2.", + write_default=DEFAULT_BLOCK_SIZE, + ), + NestedField( + field_id=108, + name="column_sizes", + field_type=MapType(key_id=117, key_type=IntegerType(), value_id=118, value_type=LongType()), + required=False, + doc="Map of column id to total size on disk", + ), + NestedField( + field_id=109, + name="value_counts", + field_type=MapType(key_id=119, key_type=IntegerType(), value_id=120, value_type=LongType()), + required=False, + doc="Map of column id to total count, including null and NaN", + ), + NestedField( + field_id=110, + name="null_value_counts", + field_type=MapType(key_id=121, key_type=IntegerType(), value_id=122, value_type=LongType()), + required=False, + doc="Map of column id to null value count", + ), + NestedField( + field_id=137, + name="nan_value_counts", + field_type=MapType(key_id=138, key_type=IntegerType(), value_id=139, value_type=LongType()), + required=False, + doc="Map of column id to number of NaN values in the column", + ), + NestedField( + field_id=125, + name="lower_bounds", + field_type=MapType(key_id=126, key_type=IntegerType(), value_id=127, value_type=BinaryType()), + required=False, + doc="Map of column id to lower bound", + ), + NestedField( + field_id=128, + name="upper_bounds", + field_type=MapType(key_id=129, key_type=IntegerType(), value_id=130, value_type=BinaryType()), + required=False, + doc="Map of column id to upper bound", + ), + NestedField( + field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob" + ), + NestedField( + field_id=132, + name="split_offsets", + field_type=ListType(element_id=133, element_type=LongType(), element_required=True), + required=False, + doc="Splittable offsets", + ), + NestedField(field_id=140, name="sort_order_id", field_type=IntegerType(), required=False, doc="Sort order ID"), ), - NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"), - NestedField( - field_id=101, - name="file_format", - field_type=StringType(), - required=True, - doc="File format name: avro, orc, or parquet", + 2: StructType( + NestedField( + field_id=134, + name="content", + field_type=IntegerType(), + required=True, + doc="File format name: avro, orc, or parquet", + initial_default=DataFileContent.DATA, + ), + NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"), + NestedField( + field_id=101, + name="file_format", + field_type=StringType(), + required=True, + doc="File format name: avro, orc, or parquet", + ), + NestedField( + field_id=102, + name="partition", + field_type=StructType(), + required=True, + doc="Partition data tuple, schema based on the partition spec", + ), + NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"), + NestedField( + field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes" + ), + NestedField( + field_id=108, + name="column_sizes", + field_type=MapType(key_id=117, key_type=IntegerType(), value_id=118, value_type=LongType()), + required=False, + 
doc="Map of column id to total size on disk", + ), + NestedField( + field_id=109, + name="value_counts", + field_type=MapType(key_id=119, key_type=IntegerType(), value_id=120, value_type=LongType()), + required=False, + doc="Map of column id to total count, including null and NaN", + ), + NestedField( + field_id=110, + name="null_value_counts", + field_type=MapType(key_id=121, key_type=IntegerType(), value_id=122, value_type=LongType()), + required=False, + doc="Map of column id to null value count", + ), + NestedField( + field_id=137, + name="nan_value_counts", + field_type=MapType(key_id=138, key_type=IntegerType(), value_id=139, value_type=LongType()), + required=False, + doc="Map of column id to number of NaN values in the column", + ), + NestedField( + field_id=125, + name="lower_bounds", + field_type=MapType(key_id=126, key_type=IntegerType(), value_id=127, value_type=BinaryType()), + required=False, + doc="Map of column id to lower bound", + ), + NestedField( + field_id=128, + name="upper_bounds", + field_type=MapType(key_id=129, key_type=IntegerType(), value_id=130, value_type=BinaryType()), + required=False, + doc="Map of column id to upper bound", + ), + NestedField( + field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob" + ), + NestedField( + field_id=132, + name="split_offsets", + field_type=ListType(element_id=133, element_type=LongType(), element_required=True), + required=False, + doc="Splittable offsets", + ), + NestedField( + field_id=135, + name="equality_ids", + field_type=ListType(element_id=136, element_type=LongType(), element_required=True), + required=False, + doc="Field ids used to determine row equality in equality delete files.", + ), + NestedField( + field_id=140, + name="sort_order_id", + field_type=IntegerType(), + required=False, + doc=" ID representing sort order for this file", + ), ), - NestedField( - field_id=102, - name="partition", - field_type=StructType(), - required=True, - doc="Partition data tuple, schema based on the partition spec", - ), - NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"), - NestedField(field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes"), - NestedField( - field_id=105, - name="block_size_in_bytes", - field_type=LongType(), - required=False, - doc="Deprecated. Always write a default in v1. 
Do not write in v2.", - ), - NestedField( - field_id=108, - name="column_sizes", - field_type=MapType(key_id=117, key_type=IntegerType(), value_id=118, value_type=LongType()), - required=False, - doc="Map of column id to total size on disk", - ), - NestedField( - field_id=109, - name="value_counts", - field_type=MapType(key_id=119, key_type=IntegerType(), value_id=120, value_type=LongType()), - required=False, - doc="Map of column id to total count, including null and NaN", - ), - NestedField( - field_id=110, - name="null_value_counts", - field_type=MapType(key_id=121, key_type=IntegerType(), value_id=122, value_type=LongType()), - required=False, - doc="Map of column id to null value count", - ), - NestedField( - field_id=137, - name="nan_value_counts", - field_type=MapType(key_id=138, key_type=IntegerType(), value_id=139, value_type=LongType()), - required=False, - doc="Map of column id to number of NaN values in the column", - ), - NestedField( - field_id=125, - name="lower_bounds", - field_type=MapType(key_id=126, key_type=IntegerType(), value_id=127, value_type=BinaryType()), - required=False, - doc="Map of column id to lower bound", - ), - NestedField( - field_id=128, - name="upper_bounds", - field_type=MapType(key_id=129, key_type=IntegerType(), value_id=130, value_type=BinaryType()), - required=False, - doc="Map of column id to upper bound", - ), - NestedField(field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob"), - NestedField( - field_id=132, - name="split_offsets", - field_type=ListType(element_id=133, element_type=LongType(), element_required=True), - required=False, - doc="Splittable offsets", - ), - NestedField( - field_id=135, - name="equality_ids", - field_type=ListType(element_id=136, element_type=LongType(), element_required=True), - required=False, - doc="Equality comparison field IDs", - ), - NestedField(field_id=140, name="sort_order_id", field_type=IntegerType(), required=False, doc="Sort order ID"), - NestedField(field_id=141, name="spec_id", field_type=IntegerType(), required=False, doc="Partition spec ID"), -) - -DATA_FILE_TYPE_V2 = StructType(*[field for field in DATA_FILE_TYPE_V1.fields if field.field_id != 105]) +} @singledispatch @@ -238,7 +323,7 @@ def data_file_with_partition(partition_type: StructType, format_version: Literal ) if field.field_id == 102 else field - for field in (DATA_FILE_TYPE_V1.fields if format_version == 1 else DATA_FILE_TYPE_V2.fields) + for field in DATA_FILE_TYPE[format_version].fields ] ) @@ -251,7 +336,6 @@ class DataFile(Record): "partition", "record_count", "file_size_in_bytes", - "block_size_in_bytes", "column_sizes", "value_counts", "null_value_counts", @@ -262,7 +346,6 @@ class DataFile(Record): "split_offsets", "equality_ids", "sort_order_id", - "spec_id", ) content: DataFileContent file_path: str @@ -270,7 +353,6 @@ class DataFile(Record): partition: Record record_count: int file_size_in_bytes: int - block_size_in_bytes: Optional[int] column_sizes: Dict[int, int] value_counts: Dict[int, int] null_value_counts: Dict[int, int] @@ -281,7 +363,6 @@ class DataFile(Record): split_offsets: Optional[List[int]] equality_ids: Optional[List[int]] sort_order_id: Optional[int] - spec_id: Optional[int] def __setattr__(self, name: str, value: Any) -> None: """Assign a key/value to a DataFile.""" @@ -290,10 +371,10 @@ def __setattr__(self, name: str, value: Any) -> None: value = FileFormat[value] super().__setattr__(name, value) - def __init__(self, format_version: Literal[1, 2] = 1, 
*data: Any, **named_data: Any) -> None: + def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None: super().__init__( *data, - **{"struct": DATA_FILE_TYPE_V1 if format_version == 1 else DATA_FILE_TYPE_V2, **named_data}, + **{"struct": DATA_FILE_TYPE[format_version], **named_data}, ) def __hash__(self) -> int: @@ -308,22 +389,29 @@ def __eq__(self, other: Any) -> bool: return self.file_path == other.file_path if isinstance(other, DataFile) else False -MANIFEST_ENTRY_SCHEMA = Schema( - NestedField(0, "status", IntegerType(), required=True), - NestedField(1, "snapshot_id", LongType(), required=False), - NestedField(3, "data_sequence_number", LongType(), required=False), - NestedField(4, "file_sequence_number", LongType(), required=False), - NestedField(2, "data_file", DATA_FILE_TYPE_V1, required=True), -) +MANIFEST_ENTRY_SCHEMAS = { + 1: Schema( + NestedField(0, "status", IntegerType(), required=True), + NestedField(1, "snapshot_id", LongType(), required=True), + NestedField(2, "data_file", DATA_FILE_TYPE[1], required=True), + ), + 2: Schema( + NestedField(0, "status", IntegerType(), required=True), + NestedField(1, "snapshot_id", LongType(), required=False), + NestedField(3, "data_sequence_number", LongType(), required=False), + NestedField(4, "file_sequence_number", LongType(), required=False), + NestedField(2, "data_file", DATA_FILE_TYPE[2], required=True), + ), +} -MANIFEST_ENTRY_SCHEMA_STRUCT = MANIFEST_ENTRY_SCHEMA.as_struct() +MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()} -def manifest_entry_schema_with_data_file(data_file: StructType) -> Schema: +def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema: return Schema( *[ NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field - for field in MANIFEST_ENTRY_SCHEMA.fields + for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields ] ) @@ -337,7 +425,7 @@ class ManifestEntry(Record): data_file: DataFile def __init__(self, *data: Any, **named_data: Any) -> None: - super().__init__(*data, **{"struct": MANIFEST_ENTRY_SCHEMA_STRUCT, **named_data}) + super().__init__(*data, **{"struct": MANIFEST_ENTRY_SCHEMAS_STRUCT[DEFAULT_READ_VERSION], **named_data}) PARTITION_FIELD_SUMMARY_TYPE = StructType( @@ -489,7 +577,7 @@ def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> List input_file = io.new_input(self.manifest_path) with AvroFile[ManifestEntry]( input_file, - MANIFEST_ENTRY_SCHEMA, + MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], read_types={-1: ManifestEntry, 2: DataFile}, read_enums={0: ManifestEntryStatus, 101: FileFormat, 134: DataFileContent}, ) as reader: @@ -603,10 +691,26 @@ def __exit__( def content(self) -> ManifestContent: ... + @property @abstractmethod - def new_writer(self) -> AvroOutputFile[ManifestEntry]: + def version(self) -> Literal[1, 2]: ... 
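# --- Editor's illustrative sketch (not part of the patch) --------------------------
# With the shared `new_writer` introduced below, the same V2-shaped ManifestEntry
# records can be written as either format version; the resolved writer tree handles
# the per-version differences (required snapshot_id and the block_size_in_bytes
# write-default for V1, sequence numbers and content for V2). The variables `io`,
# `partition_spec`, `table_schema`, and `entries` are assumed to exist.
from pyiceberg.manifest import write_manifest

with write_manifest(
    format_version=1,  # or 2; selects ManifestWriterV1 or ManifestWriterV2
    spec=partition_spec,
    schema=table_schema,
    output_file=io.new_output("s3://bucket/metadata/manifest.avro"),
    snapshot_id=8638475580105682862,
) as writer:
    for entry in entries:  # ManifestEntry records in the V2 in-memory layout
        writer.add_entry(entry)
# ------------------------------------------------------------------------------------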
+ def _with_partition(self, format_version: Literal[1, 2]) -> Schema: + data_file_type = data_file_with_partition( + format_version=format_version, partition_type=self._spec.partition_type(self._schema) + ) + return manifest_entry_schema_with_data_file(format_version=format_version, data_file=data_file_type) + + def new_writer(self) -> AvroOutputFile[ManifestEntry]: + return AvroOutputFile[ManifestEntry]( + output_file=self._output_file, + file_schema=self._with_partition(self.version), + schema=self._with_partition(DEFAULT_READ_VERSION) if self.version != DEFAULT_READ_VERSION else None, + schema_name="manifest_entry", + metadata=self._meta, + ) + @abstractmethod def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: ... @@ -678,15 +782,12 @@ def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, def content(self) -> ManifestContent: return ManifestContent.DATA - def new_writer(self) -> AvroOutputFile[ManifestEntry]: - v1_data_file_type = data_file_with_partition(self._spec.partition_type(self._schema), format_version=1) - v1_manifest_entry_schema = manifest_entry_schema_with_data_file(v1_data_file_type) - return AvroOutputFile[ManifestEntry](self._output_file, v1_manifest_entry_schema, "manifest_entry", self._meta) + @property + def version(self) -> Literal[1, 2]: + return 1 def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: - wrapped_entry = ManifestEntry(*entry.record_fields()) - wrapped_entry.data_file.block_size_in_bytes = DEFAULT_BLOCK_SIZE - return wrapped_entry + return entry class ManifestWriterV2(ManifestWriter): @@ -708,10 +809,9 @@ def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, def content(self) -> ManifestContent: return ManifestContent.DATA - def new_writer(self) -> AvroOutputFile[ManifestEntry]: - v2_data_file_type = data_file_with_partition(self._spec.partition_type(self._schema), format_version=2) - v2_manifest_entry_schema = manifest_entry_schema_with_data_file(v2_data_file_type) - return AvroOutputFile[ManifestEntry](self._output_file, v2_manifest_entry_schema, "manifest_entry", self._meta) + @property + def version(self) -> Literal[1, 2]: + return 2 def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: if entry.data_sequence_number is None: @@ -719,35 +819,7 @@ def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: raise ValueError(f"Found unassigned sequence number for an entry from snapshot: {entry.snapshot_id}") if entry.status != ManifestEntryStatus.ADDED: raise ValueError("Only entries with status ADDED can have null sequence number") - # In v2, we should not write block_size_in_bytes field - wrapped_data_file_v2_debug = DataFile( - format_version=2, - content=entry.data_file.content, - file_path=entry.data_file.file_path, - file_format=entry.data_file.file_format, - partition=entry.data_file.partition, - record_count=entry.data_file.record_count, - file_size_in_bytes=entry.data_file.file_size_in_bytes, - column_sizes=entry.data_file.column_sizes, - value_counts=entry.data_file.value_counts, - null_value_counts=entry.data_file.null_value_counts, - nan_value_counts=entry.data_file.nan_value_counts, - lower_bounds=entry.data_file.lower_bounds, - upper_bounds=entry.data_file.upper_bounds, - key_metadata=entry.data_file.key_metadata, - split_offsets=entry.data_file.split_offsets, - equality_ids=entry.data_file.equality_ids, - sort_order_id=entry.data_file.sort_order_id, - spec_id=entry.data_file.spec_id, - ) - wrapped_entry = ManifestEntry( - status=entry.status, - 
snapshot_id=entry.snapshot_id, - data_sequence_number=entry.data_sequence_number, - file_sequence_number=entry.file_sequence_number, - data_file=wrapped_data_file_v2_debug, - ) - return wrapped_entry + return entry def write_manifest( @@ -775,7 +847,9 @@ def __init__(self, output_file: OutputFile, meta: Dict[str, str]): def __enter__(self) -> ManifestListWriter: """Open the writer for writing.""" - self._writer = AvroOutputFile[ManifestFile](self._output_file, MANIFEST_FILE_SCHEMA, "manifest_file", self._meta) + self._writer = AvroOutputFile[ManifestFile]( + output_file=self._output_file, file_schema=MANIFEST_FILE_SCHEMA, schema_name="manifest_file", metadata=self._meta + ) self._writer.__enter__() return self diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index 28101809c7..dd795a3076 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1143,7 +1143,7 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]): """A schema visitor for generating a field ID to accessor index. Example: - >>> from pyiceberg.schema import Schema + >>> from pyiceberg.file_schema import Schema >>> from pyiceberg.types import * >>> schema = Schema( ... NestedField(field_id=2, name="id", field_type=IntegerType(), required=False), diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 12ea831f08..715959602a 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -282,6 +282,7 @@ class NestedField(IcebergType): required: bool = Field(default=True) doc: Optional[str] = Field(default=None, repr=False) initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) + write_default: Optional[Any] = Field(alias="write-default", default=None, repr=False) def __init__( self, @@ -291,6 +292,7 @@ def __init__( required: bool = True, doc: Optional[str] = None, initial_default: Optional[Any] = None, + write_default: Optional[Any] = None, **data: Any, ): # We need an init when we want to use positional arguments, but @@ -301,6 +303,7 @@ def __init__( data["required"] = required data["doc"] = doc data["initial-default"] = initial_default + data["write-default"] = write_default super().__init__(**data) def __str__(self) -> str: diff --git a/tests/avro/test_decoder.py b/tests/avro/test_decoder.py index fd660247cd..bbcc7394f4 100644 --- a/tests/avro/test_decoder.py +++ b/tests/avro/test_decoder.py @@ -27,7 +27,7 @@ from pyiceberg.avro.decoder import BinaryDecoder, StreamingBinaryDecoder, new_decoder from pyiceberg.avro.decoder_fast import CythonBinaryDecoder -from pyiceberg.avro.resolver import resolve +from pyiceberg.avro.resolver import resolve_reader from pyiceberg.io import InputStream from pyiceberg.types import DoubleType, FloatType @@ -194,7 +194,7 @@ def test_skip_utf8(decoder_class: Callable[[bytes], BinaryDecoder]) -> None: @pytest.mark.parametrize("decoder_class", AVAILABLE_DECODERS) def test_read_int_as_float(decoder_class: Callable[[bytes], BinaryDecoder]) -> None: decoder = decoder_class(b"\x00\x00\x9A\x41") - reader = resolve(FloatType(), DoubleType()) + reader = resolve_reader(FloatType(), DoubleType()) assert reader.read(decoder) == 19.25 diff --git a/tests/avro/test_file.py b/tests/avro/test_file.py index e9dcc7eca1..15de4f0704 100644 --- a/tests/avro/test_file.py +++ b/tests/avro/test_file.py @@ -30,7 +30,8 @@ from pyiceberg.avro.file import META_SCHEMA, AvroFileHeader from pyiceberg.io.pyarrow import PyArrowFileIO from pyiceberg.manifest import ( - MANIFEST_ENTRY_SCHEMA, + DEFAULT_BLOCK_SIZE, + MANIFEST_ENTRY_SCHEMAS, DataFile, DataFileContent, 
FileFormat, @@ -116,7 +117,7 @@ def todict(obj: Any) -> Any: return obj -def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: +def test_write_manifest_entry_with_iceberg_read_with_fastavro_v1() -> None: data_file = DataFile( content=DataFileContent.DATA, file_path="s3://some-path/some-file.parquet", @@ -124,7 +125,6 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: partition=Record(), record_count=131327, file_size_in_bytes=220669226, - block_size_in_bytes=67108864, column_sizes={1: 220661854}, value_counts={1: 131327}, null_value_counts={1: 0}, @@ -135,7 +135,6 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: split_offsets=[4, 133697593], equality_ids=[], sort_order_id=4, - spec_id=3, ) entry = ManifestEntry( status=ManifestEntryStatus.ADDED, @@ -151,7 +150,76 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: tmp_avro_file = tmpdir + "/manifest_entry.avro" with avro.AvroOutputFile[ManifestEntry]( - PyArrowFileIO().new_output(tmp_avro_file), MANIFEST_ENTRY_SCHEMA, "manifest_entry", additional_metadata + output_file=PyArrowFileIO().new_output(tmp_avro_file), + file_schema=MANIFEST_ENTRY_SCHEMAS[1], + schema_name="manifest_entry", + schema=MANIFEST_ENTRY_SCHEMAS[2], + metadata=additional_metadata, + ) as out: + out.write_block([entry]) + + with open(tmp_avro_file, "rb") as fo: + r = reader(fo=fo) + + for k, v in additional_metadata.items(): + assert k in r.metadata + assert v == r.metadata[k] + + it = iter(r) + + fa_entry = next(it) + + v2_entry = todict(entry) + + # These are not written in V1 + del v2_entry['data_sequence_number'] + del v2_entry['file_sequence_number'] + del v2_entry['data_file']['content'] + del v2_entry['data_file']['equality_ids'] + + # Required in V1 + v2_entry['data_file']['block_size_in_bytes'] = DEFAULT_BLOCK_SIZE + + assert v2_entry == fa_entry + + +def test_write_manifest_entry_with_iceberg_read_with_fastavro_v2() -> None: + data_file = DataFile( + content=DataFileContent.DATA, + file_path="s3://some-path/some-file.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=131327, + file_size_in_bytes=220669226, + column_sizes={1: 220661854}, + value_counts={1: 131327}, + null_value_counts={1: 0}, + nan_value_counts={}, + lower_bounds={1: b"aaaaaaaaaaaaaaaa"}, + upper_bounds={1: b"zzzzzzzzzzzzzzzz"}, + key_metadata=b"\xde\xad\xbe\xef", + split_offsets=[4, 133697593], + equality_ids=[], + sort_order_id=4, + ) + entry = ManifestEntry( + status=ManifestEntryStatus.ADDED, + snapshot_id=8638475580105682862, + data_sequence_number=0, + file_sequence_number=0, + data_file=data_file, + ) + + additional_metadata = {"foo": "bar"} + + with TemporaryDirectory() as tmpdir: + tmp_avro_file = tmpdir + "/manifest_entry.avro" + + with avro.AvroOutputFile[ManifestEntry]( + output_file=PyArrowFileIO().new_output(tmp_avro_file), + file_schema=MANIFEST_ENTRY_SCHEMAS[2], + schema_name="manifest_entry", + metadata=additional_metadata, ) as out: out.write_block([entry]) @@ -169,7 +237,8 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: assert todict(entry) == fa_entry -def test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None: +@pytest.mark.parametrize("format_version", [1, 2]) +def test_write_manifest_entry_with_fastavro_read_with_iceberg(format_version: int) -> None: data_file = DataFile( content=DataFileContent.DATA, file_path="s3://some-path/some-file.parquet", @@ -187,8 +256,10 @@ def 
test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None: split_offsets=[4, 133697593], equality_ids=[], sort_order_id=4, - spec_id=3, ) + if format_version == 1: + data_file.block_size_in_bytes = DEFAULT_BLOCK_SIZE + entry = ManifestEntry( status=ManifestEntryStatus.ADDED, snapshot_id=8638475580105682862, @@ -200,14 +271,14 @@ def test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None: with TemporaryDirectory() as tmpdir: tmp_avro_file = tmpdir + "/manifest_entry.avro" - schema = AvroSchemaConversion().iceberg_to_avro(MANIFEST_ENTRY_SCHEMA, schema_name="manifest_entry") + schema = AvroSchemaConversion().iceberg_to_avro(MANIFEST_ENTRY_SCHEMAS[format_version], schema_name="manifest_entry") with open(tmp_avro_file, "wb") as out: writer(out, schema, [todict(entry)]) with avro.AvroFile[ManifestEntry]( PyArrowFileIO().new_input(tmp_avro_file), - MANIFEST_ENTRY_SCHEMA, + MANIFEST_ENTRY_SCHEMAS[format_version], {-1: ManifestEntry, 2: DataFile}, ) as avro_reader: it = iter(avro_reader) diff --git a/tests/avro/test_reader.py b/tests/avro/test_reader.py index a3a502bcff..48ee8911da 100644 --- a/tests/avro/test_reader.py +++ b/tests/avro/test_reader.py @@ -41,7 +41,7 @@ ) from pyiceberg.avro.resolver import construct_reader from pyiceberg.io.pyarrow import PyArrowFileIO -from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMA, DataFile, ManifestEntry +from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMAS, DataFile, ManifestEntry from pyiceberg.schema import Schema from pyiceberg.typedef import Record from pyiceberg.types import ( @@ -70,7 +70,7 @@ def test_read_header(generated_manifest_entry_file: str, iceberg_manifest_entry_schema: Schema) -> None: with AvroFile[ManifestEntry]( PyArrowFileIO().new_input(generated_manifest_entry_file), - MANIFEST_ENTRY_SCHEMA, + MANIFEST_ENTRY_SCHEMAS[2], {-1: ManifestEntry, 2: DataFile}, ) as reader: header = reader.header diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index a302294755..d170b9ae67 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -32,8 +32,19 @@ StringReader, StructReader, ) -from pyiceberg.avro.resolver import ResolveError, resolve +from pyiceberg.avro.resolver import ResolveError, resolve_reader, resolve_writer +from pyiceberg.avro.writer import ( + BinaryWriter, + DefaultWriter, + IntegerWriter, + ListWriter, + MapWriter, + OptionWriter, + StringWriter, + StructWriter, +) from pyiceberg.io.pyarrow import PyArrowFileIO +from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMAS from pyiceberg.schema import Schema from pyiceberg.typedef import Record from pyiceberg.types import ( @@ -81,7 +92,7 @@ def test_resolver() -> None: NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())), schema_id=1, ) - read_tree = resolve(write_schema, read_schema) + read_tree = resolve_reader(write_schema, read_schema) assert read_tree == StructReader( ( @@ -117,7 +128,7 @@ def test_resolver_new_required_field() -> None: ) with pytest.raises(ResolveError) as exc_info: - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) assert "2: data: required string is non-optional, and not part of the file schema" in str(exc_info.value) @@ -133,7 +144,7 @@ def test_resolver_invalid_evolution() -> None: ) with pytest.raises(ResolveError) as exc_info: - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) assert "Cannot promote long to double" in str(exc_info.value) @@ -147,7 +158,7 @@ def test_resolver_promotion_string_to_binary() -> None: 
NestedField(1, "id", BinaryType()), schema_id=1, ) - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) def test_resolver_promotion_binary_to_string() -> None: @@ -159,7 +170,7 @@ def test_resolver_promotion_binary_to_string() -> None: NestedField(1, "id", StringType()), schema_id=1, ) - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) def test_resolver_change_type() -> None: @@ -173,69 +184,69 @@ def test_resolver_change_type() -> None: ) with pytest.raises(ResolveError) as exc_info: - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) assert "File/read schema are not aligned for list, got map" in str(exc_info.value) def test_resolve_int_to_long() -> None: - assert resolve(IntegerType(), LongType()) == IntegerReader() + assert resolve_reader(IntegerType(), LongType()) == IntegerReader() def test_resolve_float_to_double() -> None: # We should still read floats, because it is encoded in 4 bytes - assert resolve(FloatType(), DoubleType()) == FloatReader() + assert resolve_reader(FloatType(), DoubleType()) == FloatReader() def test_resolve_decimal_to_decimal() -> None: # DecimalType(P, S) to DecimalType(P2, S) where P2 > P - assert resolve(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(19, 25) + assert resolve_reader(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(19, 25) def test_struct_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(StructType(), StringType()) + assert resolve_reader(StructType(), StringType()) def test_map_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(MapType(1, StringType(), 2, IntegerType()), StringType()) + assert resolve_reader(MapType(1, StringType(), 2, IntegerType()), StringType()) def test_primitive_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(IntegerType(), MapType(1, StringType(), 2, IntegerType())) + assert resolve_reader(IntegerType(), MapType(1, StringType(), 2, IntegerType())) def test_integer_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(IntegerType(), StringType()) + assert resolve_reader(IntegerType(), StringType()) def test_float_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(FloatType(), StringType()) + assert resolve_reader(FloatType(), StringType()) def test_string_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(StringType(), FloatType()) + assert resolve_reader(StringType(), FloatType()) def test_binary_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(BinaryType(), FloatType()) + assert resolve_reader(BinaryType(), FloatType()) def test_decimal_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(DecimalType(22, 19), StringType()) + assert resolve_reader(DecimalType(22, 19), StringType()) def test_resolve_decimal_to_decimal_reduce_precision() -> None: # DecimalType(P, S) to DecimalType(P2, S) where P2 > P with pytest.raises(ResolveError) as exc_info: - _ = resolve(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25) + _ = resolve_reader(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25) assert "Cannot reduce precision from decimal(19, 25) to decimal(10, 25)" in str(exc_info.value) @@ -293,7 +304,7 @@ def test_resolver_initial_value() -> None: schema_id=2, ) - assert resolve(write_schema, read_schema) == StructReader( + assert resolve_reader(write_schema, read_schema) == StructReader( ( (None, StringReader()), 
# The one we skip (0, DefaultReader("vo")), @@ -301,3 +312,37 @@ def test_resolver_initial_value() -> None: Record, read_schema.as_struct(), ) + + +def test_resolve_writer() -> None: + actual = resolve_writer(MANIFEST_ENTRY_SCHEMAS[1], MANIFEST_ENTRY_SCHEMAS[2]) + expected = StructWriter( + ( + (0, IntegerWriter()), + (1, IntegerWriter()), + ( + 4, + StructWriter( + ( + (1, StringWriter()), + (2, StringWriter()), + (3, StructWriter(())), + (4, IntegerWriter()), + (5, IntegerWriter()), + (None, DefaultWriter(writer=IntegerWriter(), value=67108864)), + (6, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (7, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (8, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (9, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (10, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), + (11, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), + (12, OptionWriter(option=BinaryWriter())), + (13, OptionWriter(option=ListWriter(element_writer=IntegerWriter()))), + (15, OptionWriter(option=IntegerWriter())), + ) + ), + ), + ) + ) + + assert actual == expected diff --git a/tests/test_integration_manifest.py b/tests/test_integration_manifest.py index 34b20f271d..475e0d40a6 100644 --- a/tests/test_integration_manifest.py +++ b/tests/test_integration_manifest.py @@ -101,7 +101,6 @@ def test_write_sample_manifest(table_test_all_types: Table) -> None: split_offsets=entry.data_file.split_offsets, equality_ids=entry.data_file.equality_ids, sort_order_id=entry.data_file.sort_order_id, - spec_id=entry.data_file.spec_id, ) wrapped_entry_v2 = ManifestEntry(*entry.record_fields()) wrapped_entry_v2.data_file = wrapped_data_file_v2_debug diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 41af844bba..4d55610c74 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -16,7 +16,7 @@ # under the License. 
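# --- Editor's illustrative sketch (not part of the patch) --------------------------
# The (position, writer) pairs produced by resolve_writer, as in test_resolve_writer
# above, are consumed by StructWriter.write: each value is looked up by position in
# the in-memory record, and a position of None means a DefaultWriter supplies the
# value (e.g. the V1-only block_size_in_bytes). Class names come from this PR; the
# record layout below is assumed for the example.
from io import BytesIO

from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.writer import DefaultWriter, IntegerWriter, StringWriter, StructWriter
from pyiceberg.typedef import Record

tree = StructWriter((
    (1, StringWriter()),    # file field 0 <- record position 1
    (0, IntegerWriter()),   # file field 1 <- record position 0
    (None, DefaultWriter(writer=IntegerWriter(), value=67108864)),  # no record value needed
))

buffer = BytesIO()
record = Record(record_count=19513, file_path="s3://bucket/data/file.parquet")
tree.write(BinaryEncoder(buffer), record)
# buffer now holds the UTF-8 file path followed by the two zig-zag encoded longs.
# ------------------------------------------------------------------------------------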
# pylint: disable=redefined-outer-name,arguments-renamed,fixme from tempfile import TemporaryDirectory -from typing import Dict +from typing import Dict, Literal import fastavro import pytest @@ -303,7 +303,9 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None: @pytest.mark.parametrize("format_version", [1, 2]) -def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: int) -> None: +def test_write_manifest( + generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2] +) -> None: io = load_file_io() snapshot = Snapshot( snapshot_id=25, @@ -327,7 +329,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest tmp_avro_file = tmpdir + "/test_write_manifest.avro" output = io.new_output(tmp_avro_file) with write_manifest( - format_version=format_version, # type: ignore + format_version=format_version, spec=test_spec, schema=test_schema, output_file=output, @@ -337,6 +339,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest writer.add_entry(entry) new_manifest = writer.to_manifest_file() with pytest.raises(RuntimeError): + # It is already closed writer.add_entry(manifest_entries[0]) expected_metadata = { @@ -345,8 +348,6 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest "partition-spec-id": str(test_spec.spec_id), "format-version": str(format_version), } - if format_version == 2: - expected_metadata["content"] = "data" _verify_metadata_with_fastavro( tmp_avro_file, expected_metadata, @@ -357,7 +358,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest assert manifest_entry.status == ManifestEntryStatus.ADDED assert manifest_entry.snapshot_id == 8744736658442914487 - assert manifest_entry.data_sequence_number == 0 if format_version == 1 else 3 + assert manifest_entry.data_sequence_number == -1 if format_version == 1 else 3 assert isinstance(manifest_entry.data_file, DataFile) data_file = manifest_entry.data_file @@ -371,10 +372,6 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest assert data_file.partition == Record(VendorID=1, tpep_pickup_datetime=1925) assert data_file.record_count == 19513 assert data_file.file_size_in_bytes == 388872 - if format_version == 1: - assert data_file.block_size_in_bytes == 67108864 - else: - assert data_file.block_size_in_bytes is None assert data_file.column_sizes == { 1: 53, 2: 98153, @@ -477,7 +474,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest @pytest.mark.parametrize("format_version", [1, 2]) def test_write_manifest_list( - generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: int + generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2] ) -> None: io = load_file_io() @@ -495,7 +492,7 @@ def test_write_manifest_list( path = tmp_dir + "/manifest-list.avro" output = io.new_output(path) with write_manifest_list( - format_version=format_version, output_file=output, snapshot_id=25, parent_snapshot_id=19, sequence_number=0 # type: ignore + format_version=format_version, output_file=output, snapshot_id=25, parent_snapshot_id=19, sequence_number=0 ) as writer: writer.add_manifests(demo_manifest_list) new_manifest_list = list(read_manifest_list(io.new_input(path))) From 0fce5b5b9c05d0a4fc9a853db37800eeb6655dc6 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong 
Date: Thu, 5 Oct 2023 15:40:35 +0200 Subject: [PATCH 2/8] Remove unrelated change --- pyiceberg/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index dd795a3076..28101809c7 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1143,7 +1143,7 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]): """A schema visitor for generating a field ID to accessor index. Example: - >>> from pyiceberg.file_schema import Schema + >>> from pyiceberg.schema import Schema >>> from pyiceberg.types import * >>> schema = Schema( ... NestedField(field_id=2, name="id", field_type=IntegerType(), required=False), From af82b85785f3cdfda6ca830c37b3517a9bacb2a5 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Sun, 8 Oct 2023 23:19:25 +0200 Subject: [PATCH 3/8] First part of the comments --- pyiceberg/avro/file.py | 11 +++--- pyiceberg/avro/resolver.py | 69 +++++++++++++++---------------------- pyiceberg/manifest.py | 4 +-- tests/avro/test_file.py | 2 +- tests/avro/test_resolver.py | 49 +++++++++++++++++++++++++- 5 files changed, 85 insertions(+), 50 deletions(-) diff --git a/pyiceberg/avro/file.py b/pyiceberg/avro/file.py index 995b211927..a96fb3347d 100644 --- a/pyiceberg/avro/file.py +++ b/pyiceberg/avro/file.py @@ -233,17 +233,18 @@ def __init__( output_file: OutputFile, file_schema: Schema, schema_name: str, - schema: Optional[Schema] = None, + data_schema: Optional[Schema] = None, metadata: Dict[str, str] = EMPTY_DICT, ) -> None: self.output_file = output_file self.file_schema = file_schema self.schema_name = schema_name self.sync_bytes = os.urandom(SYNC_SIZE) - if schema is None: - self.writer = construct_writer(self.file_schema) - else: - self.writer = resolve_writer(self.file_schema, schema) + self.writer = ( + construct_writer(file_schema=self.file_schema) + if data_schema is None + else resolve_writer(data_schema=data_schema, write_schema=self.file_schema) + ) self.metadata = metadata def __enter__(self) -> AvroOutputFile[D]: diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 905704b18f..96cbd3bc8f 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -60,7 +60,6 @@ IntegerWriter, ListWriter, MapWriter, - NoneWriter, OptionWriter, StringWriter, StructWriter, @@ -199,19 +198,21 @@ def visit_binary(self, binary_type: BinaryType) -> Writer: def resolve_writer( - struct_schema: Union[Schema, IcebergType], + data_schema: Union[Schema, IcebergType], write_schema: Union[Schema, IcebergType], ) -> Writer: """Resolve the file and read schema to produce a reader. Args: - struct_schema (Schema | IcebergType): The schema of the Avro file. + data_schema (Schema | IcebergType): The schema of the Avro file. write_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema. Raises: NotImplementedError: If attempting to resolve an unrecognized object type. 
""" - return visit_with_partner(struct_schema, write_schema, WriteSchemaResolver(), SchemaPartnerAccessor()) # type: ignore + if write_schema == data_schema: + return construct_writer(write_schema) + return visit_with_partner(write_schema, data_schema, WriteSchemaResolver(), SchemaPartnerAccessor()) # type: ignore def resolve_reader( @@ -256,61 +257,47 @@ def skip(self, decoder: BinaryDecoder) -> None: class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]): - def schema(self, schema: Schema, expected_schema: Optional[IcebergType], result: Writer) -> Writer: + def schema(self, write_schema: Schema, data_schema: Optional[IcebergType], result: Writer) -> Writer: return result - def struct(self, struct: StructType, provided_struct: Optional[IcebergType], field_writers: List[Writer]) -> Writer: - if not isinstance(provided_struct, StructType): - raise ResolveError(f"File/write schema are not aligned for struct, got {provided_struct}") - - provided_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(provided_struct.fields)} + def struct(self, write_schema: StructType, data_struct: Optional[IcebergType], field_writers: List[Writer]) -> Writer: + if not isinstance(data_struct, StructType): + raise ResolveError(f"File/write schema are not aligned for struct, got {data_struct}") + data_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(data_struct.fields)} results: List[Tuple[Optional[int], Writer]] = [] - iter(field_writers) - for pos, write_field in enumerate(struct.fields): - if write_field.field_id in provided_struct_positions: - results.append((provided_struct_positions[write_field.field_id], field_writers[pos])) + for writer, write_field in zip(field_writers, write_schema.fields): + if write_field.field_id in data_positions: + results.append((data_positions[write_field.field_id], writer)) else: # There is a default value - if isinstance(write_field, NestedField) and write_field.write_default is not None: + if write_field.write_default is not None: # The field is not in the record, but there is a write default value - default_writer = DefaultWriter( - writer=visit(write_field.field_type, CONSTRUCT_WRITER_VISITOR), value=write_field.write_default - ) - results.append((None, default_writer)) + results.append((None, DefaultWriter(writer=writer, value=write_field.write_default))) elif write_field.required: raise ValueError(f"Field is required, and there is no write default: {write_field}") - else: - results.append((pos, NoneWriter())) return StructWriter(field_writers=tuple(results)) - def field(self, field: NestedField, expected_field: Optional[IcebergType], field_writer: Writer) -> Writer: - return field_writer if field.required else OptionWriter(field_writer) - - def list(self, list_type: ListType, expected_list: Optional[IcebergType], element_reader: Writer) -> Writer: - if expected_list and not isinstance(expected_list, ListType): - raise ResolveError(f"File/read schema are not aligned for list, got {expected_list}") - - return ListWriter(element_reader if list_type.element_required else OptionWriter(element_reader)) + def field(self, write_field: NestedField, data_type: Optional[IcebergType], field_writer: Writer) -> Writer: + return field_writer if write_field.required else OptionWriter(field_writer) - def map(self, map_type: MapType, expected_map: Optional[IcebergType], key_reader: Writer, value_reader: Writer) -> Writer: - if expected_map and not isinstance(expected_map, MapType): - raise 
ResolveError(f"File/read schema are not aligned for map, got {expected_map}") + def list(self, write_list_type: ListType, write_list: Optional[IcebergType], element_reader: Writer) -> Writer: + return ListWriter(element_reader if write_list_type.element_required else OptionWriter(element_reader)) - return MapWriter(key_reader, value_reader if map_type.value_required else OptionWriter(value_reader)) - - def primitive(self, primitive: PrimitiveType, expected_primitive: Optional[IcebergType]) -> Writer: - if expected_primitive is not None: - if not isinstance(expected_primitive, PrimitiveType): - raise ResolveError(f"File/read schema are not aligned for {primitive}, got {expected_primitive}") + def map( + self, write_map_type: MapType, write_primitive: Optional[IcebergType], key_reader: Writer, value_reader: Writer + ) -> Writer: + return MapWriter(key_reader, value_reader if write_map_type.value_required else OptionWriter(value_reader)) + def primitive(self, write_primitive: PrimitiveType, data_primitive: Optional[IcebergType]) -> Writer: + if data_primitive is not None: # ensure that the type can be projected to the expected - if primitive != expected_primitive: - promote(primitive, expected_primitive) + if write_primitive != data_primitive: + promote(data_primitive, write_primitive) - return super().primitive(primitive, expected_primitive) + return super().primitive(write_primitive, write_primitive) def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer: return BooleanWriter() diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index efeceac7e7..15b2b6d24f 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -275,7 +275,7 @@ def __repr__(self) -> str: name="sort_order_id", field_type=IntegerType(), required=False, - doc=" ID representing sort order for this file", + doc="ID representing sort order for this file", ), ), } @@ -706,7 +706,7 @@ def new_writer(self) -> AvroOutputFile[ManifestEntry]: return AvroOutputFile[ManifestEntry]( output_file=self._output_file, file_schema=self._with_partition(self.version), - schema=self._with_partition(DEFAULT_READ_VERSION) if self.version != DEFAULT_READ_VERSION else None, + data_schema=self._with_partition(DEFAULT_READ_VERSION), schema_name="manifest_entry", metadata=self._meta, ) diff --git a/tests/avro/test_file.py b/tests/avro/test_file.py index 15de4f0704..9c0b754098 100644 --- a/tests/avro/test_file.py +++ b/tests/avro/test_file.py @@ -153,7 +153,7 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro_v1() -> None: output_file=PyArrowFileIO().new_output(tmp_avro_file), file_schema=MANIFEST_ENTRY_SCHEMAS[1], schema_name="manifest_entry", - schema=MANIFEST_ENTRY_SCHEMAS[2], + data_schema=MANIFEST_ENTRY_SCHEMAS[2], metadata=additional_metadata, ) as out: out.write_block([entry]) diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index d170b9ae67..5bdf76e0e2 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -36,6 +36,7 @@ from pyiceberg.avro.writer import ( BinaryWriter, DefaultWriter, + DoubleWriter, IntegerWriter, ListWriter, MapWriter, @@ -49,6 +50,7 @@ from pyiceberg.typedef import Record from pyiceberg.types import ( BinaryType, + BooleanType, DecimalType, DoubleType, FloatType, @@ -315,7 +317,7 @@ def test_resolver_initial_value() -> None: def test_resolve_writer() -> None: - actual = resolve_writer(MANIFEST_ENTRY_SCHEMAS[1], MANIFEST_ENTRY_SCHEMAS[2]) + actual = resolve_writer(data_schema=MANIFEST_ENTRY_SCHEMAS[2], 
write_schema=MANIFEST_ENTRY_SCHEMAS[1]) expected = StructWriter( ( (0, IntegerWriter()), @@ -346,3 +348,48 @@ def test_resolve_writer() -> None: ) assert actual == expected + + +def test_resolve_writer_promotion() -> None: + with pytest.raises(ResolveError) as exc_info: + _ = resolve_writer( + data_schema=Schema(NestedField(field_id=1, name="floating", type=DoubleType(), required=True)), + write_schema=Schema(NestedField(field_id=1, name="floating", type=FloatType(), required=True)), + ) + + assert "Cannot promote double to float" in str(exc_info.value) + + +def test_writer_ordering() -> None: + actual = resolve_writer( + data_schema=Schema( + NestedField(field_id=1, name="str", type=StringType(), required=True), + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + ), + write_schema=Schema( + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + NestedField(field_id=1, name="str", type=StringType(), required=True), + ), + ) + + expected = StructWriter(((1, DoubleWriter()), (0, StringWriter()))) + + assert actual == expected + + +def test_writer_one_more_field() -> None: + actual = resolve_writer( + data_schema=Schema( + NestedField(field_id=3, name="bool", type=BooleanType(), required=True), + NestedField(field_id=1, name="str", type=StringType(), required=True), + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + ), + write_schema=Schema( + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + NestedField(field_id=1, name="str", type=StringType(), required=True), + ), + ) + + expected = StructWriter(((2, DoubleWriter()), (1, StringWriter()))) + + assert actual == expected From 28d1fb05a2de16d6874d9c90cbff07b839f57495 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Sun, 8 Oct 2023 23:39:48 +0200 Subject: [PATCH 4/8] Less is more --- pyiceberg/avro/writer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pyiceberg/avro/writer.py b/pyiceberg/avro/writer.py index fbb3de62be..4e3d3d476a 100644 --- a/pyiceberg/avro/writer.py +++ b/pyiceberg/avro/writer.py @@ -51,12 +51,6 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}()" -@dataclass(frozen=True) -class NoneWriter(Writer): - def write(self, encoder: BinaryEncoder, __: Any) -> None: - encoder.write_int(0) - - @dataclass(frozen=True) class BooleanWriter(Writer): def write(self, encoder: BinaryEncoder, val: bool) -> None: From cae519a9abc6734ba5e123551f81debbdec72c89 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 9 Oct 2023 22:46:25 +0200 Subject: [PATCH 5/8] Fix writer-default --- pyiceberg/avro/resolver.py | 2 +- pyiceberg/types.py | 6 +++--- tests/test_types.py | 6 ++++++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 96cbd3bc8f..b5d58eb42e 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -274,7 +274,7 @@ def struct(self, write_schema: StructType, data_struct: Optional[IcebergType], f # There is a default value if write_field.write_default is not None: # The field is not in the record, but there is a write default value - results.append((None, DefaultWriter(writer=writer, value=write_field.write_default))) + results.append((None, DefaultWriter(writer=writer, value=write_field.write_default))) # type: ignore elif write_field.required: raise ValueError(f"Field is required, and there is no write default: {write_field}") diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 715959602a..b8fdedea51 100644 --- 
a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -51,7 +51,7 @@ from pydantic_core.core_schema import ValidatorFunctionWrapHandler from pyiceberg.exceptions import ValidationError -from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel +from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L from pyiceberg.utils.parsing import ParseNumberFromBrackets from pyiceberg.utils.singleton import Singleton @@ -282,7 +282,7 @@ class NestedField(IcebergType): required: bool = Field(default=True) doc: Optional[str] = Field(default=None, repr=False) initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) - write_default: Optional[Any] = Field(alias="write-default", default=None, repr=False) + write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore def __init__( self, @@ -292,7 +292,7 @@ def __init__( required: bool = True, doc: Optional[str] = None, initial_default: Optional[Any] = None, - write_default: Optional[Any] = None, + write_default: Optional[L] = None, **data: Any, ): # We need an init when we want to use positional arguments, but diff --git a/tests/test_types.py b/tests/test_types.py index 249ee98a6f..6aed56c58f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -18,6 +18,7 @@ import pickle from typing import Type +import pydantic_core import pytest from pyiceberg.exceptions import ValidationError @@ -208,6 +209,11 @@ def test_nested_field() -> None: assert str(field_var) == str(eval(repr(field_var))) assert field_var == pickle.loads(pickle.dumps(field_var)) + with pytest.raises(pydantic_core.ValidationError) as exc_info: + _ = (NestedField(1, "field", StringType(), required=True, write_default=(1, "a", True)),) # type: ignore + + assert "validation errors for NestedField" in str(exc_info.value) + @pytest.mark.parametrize("input_index,input_type", non_parameterized_types) @pytest.mark.parametrize("check_index,check_type", non_parameterized_types) From ee58a1621154f09d47996f42c1478490ef69dac1 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 10 Oct 2023 14:03:55 +0200 Subject: [PATCH 6/8] Comments, thanks Ryan! 
--- pyiceberg/avro/file.py | 6 ++-- pyiceberg/avro/resolver.py | 60 ++++++++++++++++++------------------- pyiceberg/manifest.py | 2 +- tests/avro/test_file.py | 2 +- tests/avro/test_resolver.py | 14 ++++----- 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/pyiceberg/avro/file.py b/pyiceberg/avro/file.py index a96fb3347d..4985c6fb60 100644 --- a/pyiceberg/avro/file.py +++ b/pyiceberg/avro/file.py @@ -233,7 +233,7 @@ def __init__( output_file: OutputFile, file_schema: Schema, schema_name: str, - data_schema: Optional[Schema] = None, + record_schema: Optional[Schema] = None, metadata: Dict[str, str] = EMPTY_DICT, ) -> None: self.output_file = output_file @@ -242,8 +242,8 @@ def __init__( self.sync_bytes = os.urandom(SYNC_SIZE) self.writer = ( construct_writer(file_schema=self.file_schema) - if data_schema is None - else resolve_writer(data_schema=data_schema, write_schema=self.file_schema) + if record_schema is None + else resolve_writer(record_schema=record_schema, file_schema=self.file_schema) ) self.metadata = metadata diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index b5d58eb42e..4f486d7336 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -198,21 +198,21 @@ def visit_binary(self, binary_type: BinaryType) -> Writer: def resolve_writer( - data_schema: Union[Schema, IcebergType], - write_schema: Union[Schema, IcebergType], + record_schema: Union[Schema, IcebergType], + file_schema: Union[Schema, IcebergType], ) -> Writer: """Resolve the file and read schema to produce a reader. Args: - data_schema (Schema | IcebergType): The schema of the Avro file. - write_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema. + record_schema (Schema | IcebergType): The schema of the record in memory. + file_schema (Schema | IcebergType): The schema of the file that will be written Raises: NotImplementedError: If attempting to resolve an unrecognized object type. 
""" - if write_schema == data_schema: - return construct_writer(write_schema) - return visit_with_partner(write_schema, data_schema, WriteSchemaResolver(), SchemaPartnerAccessor()) # type: ignore + if record_schema == file_schema: + return construct_writer(file_schema) + return visit_with_partner(file_schema, record_schema, WriteSchemaResolver(), SchemaPartnerAccessor()) # type: ignore def resolve_reader( @@ -257,47 +257,47 @@ def skip(self, decoder: BinaryDecoder) -> None: class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]): - def schema(self, write_schema: Schema, data_schema: Optional[IcebergType], result: Writer) -> Writer: + def schema(self, file_schema: Schema, record_schema: Optional[IcebergType], result: Writer) -> Writer: return result - def struct(self, write_schema: StructType, data_struct: Optional[IcebergType], field_writers: List[Writer]) -> Writer: - if not isinstance(data_struct, StructType): - raise ResolveError(f"File/write schema are not aligned for struct, got {data_struct}") + def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], file_writers: List[Writer]) -> Writer: + if not isinstance(record_struct, StructType): + raise ResolveError(f"File/write schema are not aligned for struct, got {record_struct}") - data_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(data_struct.fields)} + record_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(record_struct.fields)} results: List[Tuple[Optional[int], Writer]] = [] - for writer, write_field in zip(field_writers, write_schema.fields): - if write_field.field_id in data_positions: - results.append((data_positions[write_field.field_id], writer)) + for writer, file_field in zip(file_writers, file_schema.fields): + if file_field.field_id in record_struct_positions: + results.append((record_struct_positions[file_field.field_id], writer)) else: # There is a default value - if write_field.write_default is not None: + if file_field.write_default is not None: # The field is not in the record, but there is a write default value - results.append((None, DefaultWriter(writer=writer, value=write_field.write_default))) # type: ignore - elif write_field.required: - raise ValueError(f"Field is required, and there is no write default: {write_field}") + results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) # type: ignore + elif file_field.required: + raise ValueError(f"Field is required, and there is no write default: {file_field}") return StructWriter(field_writers=tuple(results)) - def field(self, write_field: NestedField, data_type: Optional[IcebergType], field_writer: Writer) -> Writer: - return field_writer if write_field.required else OptionWriter(field_writer) + def field(self, file_field: NestedField, record_type: Optional[IcebergType], field_writer: Writer) -> Writer: + return field_writer if file_field.required else OptionWriter(field_writer) - def list(self, write_list_type: ListType, write_list: Optional[IcebergType], element_reader: Writer) -> Writer: - return ListWriter(element_reader if write_list_type.element_required else OptionWriter(element_reader)) + def list(self, file_list_type: ListType, file_list: Optional[IcebergType], element_writer: Writer) -> Writer: + return ListWriter(element_writer if file_list_type.element_required else OptionWriter(element_writer)) def map( - self, write_map_type: MapType, write_primitive: Optional[IcebergType], key_reader: Writer, value_reader: Writer + 
self, file_map_type: MapType, file_primitive: Optional[IcebergType], key_writer: Writer, value_writer: Writer ) -> Writer: - return MapWriter(key_reader, value_reader if write_map_type.value_required else OptionWriter(value_reader)) + return MapWriter(key_writer, value_writer if file_map_type.value_required else OptionWriter(value_writer)) - def primitive(self, write_primitive: PrimitiveType, data_primitive: Optional[IcebergType]) -> Writer: - if data_primitive is not None: + def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[IcebergType]) -> Writer: + if record_primitive is not None: # ensure that the type can be projected to the expected - if write_primitive != data_primitive: - promote(data_primitive, write_primitive) + if file_primitive != record_primitive: + promote(record_primitive, file_primitive) - return super().primitive(write_primitive, write_primitive) + return super().primitive(file_primitive, file_primitive) def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer: return BooleanWriter() diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 15b2b6d24f..92ca300f6d 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -706,7 +706,7 @@ def new_writer(self) -> AvroOutputFile[ManifestEntry]: return AvroOutputFile[ManifestEntry]( output_file=self._output_file, file_schema=self._with_partition(self.version), - data_schema=self._with_partition(DEFAULT_READ_VERSION), + record_schema=self._with_partition(DEFAULT_READ_VERSION), schema_name="manifest_entry", metadata=self._meta, ) diff --git a/tests/avro/test_file.py b/tests/avro/test_file.py index 9c0b754098..518026cc4f 100644 --- a/tests/avro/test_file.py +++ b/tests/avro/test_file.py @@ -153,7 +153,7 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro_v1() -> None: output_file=PyArrowFileIO().new_output(tmp_avro_file), file_schema=MANIFEST_ENTRY_SCHEMAS[1], schema_name="manifest_entry", - data_schema=MANIFEST_ENTRY_SCHEMAS[2], + record_schema=MANIFEST_ENTRY_SCHEMAS[2], metadata=additional_metadata, ) as out: out.write_block([entry]) diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index 5bdf76e0e2..0aababfcda 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -317,7 +317,7 @@ def test_resolver_initial_value() -> None: def test_resolve_writer() -> None: - actual = resolve_writer(data_schema=MANIFEST_ENTRY_SCHEMAS[2], write_schema=MANIFEST_ENTRY_SCHEMAS[1]) + actual = resolve_writer(record_schema=MANIFEST_ENTRY_SCHEMAS[2], file_schema=MANIFEST_ENTRY_SCHEMAS[1]) expected = StructWriter( ( (0, IntegerWriter()), @@ -353,8 +353,8 @@ def test_resolve_writer() -> None: def test_resolve_writer_promotion() -> None: with pytest.raises(ResolveError) as exc_info: _ = resolve_writer( - data_schema=Schema(NestedField(field_id=1, name="floating", type=DoubleType(), required=True)), - write_schema=Schema(NestedField(field_id=1, name="floating", type=FloatType(), required=True)), + record_schema=Schema(NestedField(field_id=1, name="floating", type=DoubleType(), required=True)), + file_schema=Schema(NestedField(field_id=1, name="floating", type=FloatType(), required=True)), ) assert "Cannot promote double to float" in str(exc_info.value) @@ -362,11 +362,11 @@ def test_resolve_writer_promotion() -> None: def test_writer_ordering() -> None: actual = resolve_writer( - data_schema=Schema( + record_schema=Schema( NestedField(field_id=1, name="str", type=StringType(), required=True), NestedField(field_id=2, 
name="dbl", type=DoubleType(), required=True), ), - write_schema=Schema( + file_schema=Schema( NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), NestedField(field_id=1, name="str", type=StringType(), required=True), ), @@ -379,12 +379,12 @@ def test_writer_ordering() -> None: def test_writer_one_more_field() -> None: actual = resolve_writer( - data_schema=Schema( + record_schema=Schema( NestedField(field_id=3, name="bool", type=BooleanType(), required=True), NestedField(field_id=1, name="str", type=StringType(), required=True), NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), ), - write_schema=Schema( + file_schema=Schema( NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), NestedField(field_id=1, name="str", type=StringType(), required=True), ), From 31ffbf30bd995a729ab6eec0d088b6c8b6aae95e Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Tue, 10 Oct 2023 18:09:15 +0200 Subject: [PATCH 7/8] Commit the changes --- pyiceberg/avro/resolver.py | 4 +++- tests/avro/test_resolver.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 4f486d7336..586144eab7 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -270,13 +270,15 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], for writer, file_field in zip(file_writers, file_schema.fields): if file_field.field_id in record_struct_positions: results.append((record_struct_positions[file_field.field_id], writer)) - else: + elif file_field.required: # There is a default value if file_field.write_default is not None: # The field is not in the record, but there is a write default value results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) # type: ignore elif file_field.required: raise ValueError(f"Field is required, and there is no write default: {file_field}") + else: + results.append((None, OptionWriter(option=writer))) return StructWriter(field_writers=tuple(results)) diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index 0aababfcda..5734bbc2f8 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -393,3 +393,16 @@ def test_writer_one_more_field() -> None: expected = StructWriter(((2, DoubleWriter()), (1, StringWriter()))) assert actual == expected + + +def test_writer_missing_optional_in_read_schema() -> None: + actual = resolve_writer( + record_schema=Schema(), + file_schema=Schema( + NestedField(field_id=1, name="str", type=StringType(), required=False), + ), + ) + + expected = StructWriter(field_writers=((None, OptionWriter(option=OptionWriter(option=StringWriter()))),)) + + assert actual == expected From 1faa7702b9614d01a04fd723bf6f3f40b0934e56 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Wed, 11 Oct 2023 10:10:20 +0200 Subject: [PATCH 8/8] Fix option writer --- pyiceberg/avro/resolver.py | 2 +- tests/avro/test_resolver.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 586144eab7..faf4dd0501 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -278,7 +278,7 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], elif file_field.required: raise ValueError(f"Field is required, and there is no write default: {file_field}") else: - results.append((None, OptionWriter(option=writer))) + results.append((None, writer)) return 
StructWriter(field_writers=tuple(results)) diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index 5734bbc2f8..51d2a7d8fc 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -403,6 +403,6 @@ def test_writer_missing_optional_in_read_schema() -> None: ), ) - expected = StructWriter(field_writers=((None, OptionWriter(option=OptionWriter(option=StringWriter()))),)) + expected = StructWriter(field_writers=((None, OptionWriter(option=StringWriter())),)) assert actual == expected
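
For illustration (not part of the patch series itself): a minimal sketch of how the resolver introduced above is expected to be used, mirroring test_writer_ordering from tests/avro/test_resolver.py. The record is kept in memory in one field order while the file schema declares another; the resolved StructWriter emits fields in file-schema order, reading each value from its position in the in-memory record. All names below come from the patched modules; nothing here is new API.

from pyiceberg.avro.resolver import resolve_writer
from pyiceberg.avro.writer import DoubleWriter, StringWriter, StructWriter
from pyiceberg.schema import Schema
from pyiceberg.types import DoubleType, NestedField, StringType

# Schema of the records held in memory (e.g. the V2 representation).
record_schema = Schema(
    NestedField(field_id=1, name="str", type=StringType(), required=True),
    NestedField(field_id=2, name="dbl", type=DoubleType(), required=True),
)

# Schema the Avro file should be written with (same fields, reordered).
file_schema = Schema(
    NestedField(field_id=2, name="dbl", type=DoubleType(), required=True),
    NestedField(field_id=1, name="str", type=StringType(), required=True),
)

writer = resolve_writer(record_schema=record_schema, file_schema=file_schema)

# Fields are written in file-schema order; each tuple carries the position of
# the value inside the in-memory record.
assert writer == StructWriter(((1, DoubleWriter()), (0, StringWriter())))

The same mechanism is what ManifestWriter relies on after this series: records stay in the V2 model (MANIFEST_ENTRY_SCHEMAS[2]) while the file is written against the version-specific schema (e.g. MANIFEST_ENTRY_SCHEMAS[1]), with write defaults such as block_size_in_bytes filled in by DefaultWriter.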