diff --git a/pyiceberg/avro/file.py b/pyiceberg/avro/file.py index dc843f6dc0..4985c6fb60 100644 --- a/pyiceberg/avro/file.py +++ b/pyiceberg/avro/file.py @@ -38,7 +38,7 @@ from pyiceberg.avro.decoder import BinaryDecoder, new_decoder from pyiceberg.avro.encoder import BinaryEncoder from pyiceberg.avro.reader import Reader -from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve +from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve_reader, resolve_writer from pyiceberg.avro.writer import Writer from pyiceberg.io import InputFile, OutputFile, OutputStream from pyiceberg.schema import Schema @@ -172,7 +172,7 @@ def __enter__(self) -> AvroFile[D]: if not self.read_schema: self.read_schema = self.schema - self.reader = resolve(self.schema, self.read_schema, self.read_types, self.read_enums) + self.reader = resolve_reader(self.schema, self.read_schema, self.read_types, self.read_enums) return self @@ -222,18 +222,29 @@ def _read_header(self) -> AvroFileHeader: class AvroOutputFile(Generic[D]): output_file: OutputFile output_stream: OutputStream - schema: Schema + file_schema: Schema schema_name: str encoder: BinaryEncoder sync_bytes: bytes writer: Writer - def __init__(self, output_file: OutputFile, schema: Schema, schema_name: str, metadata: Dict[str, str] = EMPTY_DICT) -> None: + def __init__( + self, + output_file: OutputFile, + file_schema: Schema, + schema_name: str, + record_schema: Optional[Schema] = None, + metadata: Dict[str, str] = EMPTY_DICT, + ) -> None: self.output_file = output_file - self.schema = schema + self.file_schema = file_schema self.schema_name = schema_name self.sync_bytes = os.urandom(SYNC_SIZE) - self.writer = construct_writer(self.schema) + self.writer = ( + construct_writer(file_schema=self.file_schema) + if record_schema is None + else resolve_writer(record_schema=record_schema, file_schema=self.file_schema) + ) self.metadata = metadata def __enter__(self) -> AvroOutputFile[D]: @@ -247,7 +258,6 @@ def __enter__(self) -> AvroOutputFile[D]: self.encoder = BinaryEncoder(self.output_stream) self._write_header() - self.writer = construct_writer(self.schema) return self @@ -258,7 +268,7 @@ def __exit__( self.output_stream.close() def _write_header(self) -> None: - json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.schema, schema_name=self.schema_name)) + json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.file_schema, schema_name=self.schema_name)) meta = {**self.metadata, _SCHEMA_KEY: json_schema, _CODEC_KEY: "null"} header = AvroFileHeader(magic=MAGIC, meta=meta, sync=self.sync_bytes) construct_writer(META_SCHEMA).write(self.encoder, header) diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 8b2daeb7c7..faf4dd0501 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -53,6 +53,7 @@ BooleanWriter, DateWriter, DecimalWriter, + DefaultWriter, DoubleWriter, FixedWriter, FloatWriter, @@ -112,11 +113,12 @@ def construct_reader( Args: file_schema (Schema | IcebergType): The schema of the Avro file. + read_types (Dict[int, Callable[..., StructProtocol]]): Constructors for structs for certain field-ids Raises: NotImplementedError: If attempting to resolve an unrecognized object type. 
""" - return resolve(file_schema, file_schema, read_types) + return resolve_reader(file_schema, file_schema, read_types) def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer: @@ -128,7 +130,7 @@ def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer: Raises: NotImplementedError: If attempting to resolve an unrecognized object type. """ - return visit(file_schema, ConstructWriter()) + return visit(file_schema, CONSTRUCT_WRITER_VISITOR) class ConstructWriter(SchemaVisitorPerPrimitiveType[Writer]): @@ -138,7 +140,7 @@ def schema(self, schema: Schema, struct_result: Writer) -> Writer: return struct_result def struct(self, struct: StructType, field_results: List[Writer]) -> Writer: - return StructWriter(tuple(field_results)) + return StructWriter(tuple((pos, result) for pos, result in enumerate(field_results))) def field(self, field: NestedField, field_result: Writer) -> Writer: return field_result if field.required else OptionWriter(field_result) @@ -192,7 +194,28 @@ def visit_binary(self, binary_type: BinaryType) -> Writer: return BinaryWriter() -def resolve( +CONSTRUCT_WRITER_VISITOR = ConstructWriter() + + +def resolve_writer( + record_schema: Union[Schema, IcebergType], + file_schema: Union[Schema, IcebergType], +) -> Writer: + """Resolve the file and read schema to produce a reader. + + Args: + record_schema (Schema | IcebergType): The schema of the record in memory. + file_schema (Schema | IcebergType): The schema of the file that will be written + + Raises: + NotImplementedError: If attempting to resolve an unrecognized object type. + """ + if record_schema == file_schema: + return construct_writer(file_schema) + return visit_with_partner(file_schema, record_schema, WriteSchemaResolver(), SchemaPartnerAccessor()) # type: ignore + + +def resolve_reader( file_schema: Union[Schema, IcebergType], read_schema: Union[Schema, IcebergType], read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT, @@ -210,7 +233,7 @@ def resolve( NotImplementedError: If attempting to resolve an unrecognized object type. 
""" return visit_with_partner( - file_schema, read_schema, SchemaResolver(read_types, read_enums), SchemaPartnerAccessor() + file_schema, read_schema, ReadSchemaResolver(read_types, read_enums), SchemaPartnerAccessor() ) # type: ignore @@ -233,7 +256,95 @@ def skip(self, decoder: BinaryDecoder) -> None: pass -class SchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]): +class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]): + def schema(self, file_schema: Schema, record_schema: Optional[IcebergType], result: Writer) -> Writer: + return result + + def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], file_writers: List[Writer]) -> Writer: + if not isinstance(record_struct, StructType): + raise ResolveError(f"File/write schema are not aligned for struct, got {record_struct}") + + record_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(record_struct.fields)} + results: List[Tuple[Optional[int], Writer]] = [] + + for writer, file_field in zip(file_writers, file_schema.fields): + if file_field.field_id in record_struct_positions: + results.append((record_struct_positions[file_field.field_id], writer)) + elif file_field.required: + # There is a default value + if file_field.write_default is not None: + # The field is not in the record, but there is a write default value + results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) # type: ignore + elif file_field.required: + raise ValueError(f"Field is required, and there is no write default: {file_field}") + else: + results.append((None, writer)) + + return StructWriter(field_writers=tuple(results)) + + def field(self, file_field: NestedField, record_type: Optional[IcebergType], field_writer: Writer) -> Writer: + return field_writer if file_field.required else OptionWriter(field_writer) + + def list(self, file_list_type: ListType, file_list: Optional[IcebergType], element_writer: Writer) -> Writer: + return ListWriter(element_writer if file_list_type.element_required else OptionWriter(element_writer)) + + def map( + self, file_map_type: MapType, file_primitive: Optional[IcebergType], key_writer: Writer, value_writer: Writer + ) -> Writer: + return MapWriter(key_writer, value_writer if file_map_type.value_required else OptionWriter(value_writer)) + + def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[IcebergType]) -> Writer: + if record_primitive is not None: + # ensure that the type can be projected to the expected + if file_primitive != record_primitive: + promote(record_primitive, file_primitive) + + return super().primitive(file_primitive, file_primitive) + + def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer: + return BooleanWriter() + + def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Writer: + return IntegerWriter() + + def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Writer: + return IntegerWriter() + + def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Writer: + return FloatWriter() + + def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Writer: + return DoubleWriter() + + def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Writer: + return DecimalWriter(decimal_type.precision, decimal_type.scale) + + def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Writer: + return 
DateWriter() + + def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Writer: + return TimeWriter() + + def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Writer: + return TimestampWriter() + + def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Writer: + return TimestamptzWriter() + + def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Writer: + return StringWriter() + + def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Writer: + return UUIDWriter() + + def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Writer: + return FixedWriter(len(fixed_type)) + + def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Writer: + return BinaryWriter() + + +class ReadSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]): __slots__ = ("read_types", "read_enums", "context") read_types: Dict[int, Callable[..., StructProtocol]] read_enums: Dict[int, Callable[..., Enum]] @@ -279,7 +390,7 @@ def struct(self, struct: StructType, expected_struct: Optional[IcebergType], fie for field, result_reader in zip(struct.fields, field_readers) ] - file_fields = {field.field_id: field for field in struct.fields} + file_fields = {field.field_id for field in struct.fields} for pos, read_field in enumerate(expected_struct.fields): if read_field.field_id not in file_fields: if isinstance(read_field, NestedField) and read_field.initial_default is not None: diff --git a/pyiceberg/avro/writer.py b/pyiceberg/avro/writer.py index ad6a755614..4e3d3d476a 100644 --- a/pyiceberg/avro/writer.py +++ b/pyiceberg/avro/writer.py @@ -29,6 +29,7 @@ Any, Dict, List, + Optional, Tuple, ) from uuid import UUID @@ -39,6 +40,7 @@ from pyiceberg.utils.singleton import Singleton +@dataclass(frozen=True) class Writer(Singleton): @abstractmethod def write(self, encoder: BinaryEncoder, val: Any) -> Any: @@ -49,16 +51,13 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}()" -class NoneWriter(Writer): - def write(self, _: BinaryEncoder, __: Any) -> None: - pass - - +@dataclass(frozen=True) class BooleanWriter(Writer): def write(self, encoder: BinaryEncoder, val: bool) -> None: encoder.write_boolean(val) +@dataclass(frozen=True) class IntegerWriter(Writer): """Longs and ints are encoded the same way, and there is no long in Python.""" @@ -66,41 +65,49 @@ def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class FloatWriter(Writer): def write(self, encoder: BinaryEncoder, val: float) -> None: encoder.write_float(val) +@dataclass(frozen=True) class DoubleWriter(Writer): def write(self, encoder: BinaryEncoder, val: float) -> None: encoder.write_double(val) +@dataclass(frozen=True) class DateWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class TimeWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class TimestampWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class TimestamptzWriter(Writer): def write(self, encoder: BinaryEncoder, val: int) -> None: encoder.write_int(val) +@dataclass(frozen=True) class StringWriter(Writer): def write(self, encoder: BinaryEncoder, val: Any) -> None: encoder.write_utf8(val) +@dataclass(frozen=True) class 
UUIDWriter(Writer): def write(self, encoder: BinaryEncoder, val: UUID) -> None: encoder.write(val.bytes) @@ -124,6 +131,7 @@ def __repr__(self) -> str: return f"FixedWriter({self._len})" +@dataclass(frozen=True) class BinaryWriter(Writer): """Variable byte length writer.""" @@ -158,11 +166,12 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None: @dataclass(frozen=True) class StructWriter(Writer): - field_writers: Tuple[Writer, ...] = dataclassfield() + field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield() def write(self, encoder: BinaryEncoder, val: Record) -> None: - for writer, value in zip(self.field_writers, val.record_fields()): - writer.write(encoder, value) + for pos, writer in self.field_writers: + # When pos is None, then it is a default value + writer.write(encoder, val[pos] if pos is not None else None) def __eq__(self, other: Any) -> bool: """Implement the equality operator for this object.""" @@ -170,7 +179,7 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: """Return string representation of this object.""" - return f"StructWriter({','.join(repr(field) for field in self.field_writers)})" + return f"StructWriter(tuple(({','.join(repr(field) for field in self.field_writers)})))" def __hash__(self) -> int: """Return the hash of the writer as hash of this object.""" @@ -201,3 +210,12 @@ def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None: self.value_writer.write(encoder, v) if len(val) > 0: encoder.write_int(0) + + +@dataclass(frozen=True) +class DefaultWriter(Writer): + writer: Writer + value: Any + + def write(self, encoder: BinaryEncoder, _: Any) -> None: + self.writer.write(encoder, self.value) diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 8bdbfd3524..92ca300f6d 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -58,6 +58,7 @@ UNASSIGNED_SEQ = -1 DEFAULT_BLOCK_SIZE = 67108864 # 64 * 1024 * 1024 +DEFAULT_READ_VERSION: Literal[2] = 2 class DataFileContent(int, Enum): @@ -99,101 +100,185 @@ def __repr__(self) -> str: return f"FileFormat.{self.name}" -DATA_FILE_TYPE_V1 = StructType( - NestedField( - field_id=134, - name="content", - field_type=IntegerType(), - required=False, - doc="Contents of the file: 0=data, 1=position deletes, 2=equality deletes", - initial_default=DataFileContent.DATA, +DATA_FILE_TYPE: Dict[int, StructType] = { + 1: StructType( + NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"), + NestedField( + field_id=101, + name="file_format", + field_type=StringType(), + required=True, + doc="File format name: avro, orc, or parquet", + ), + NestedField( + field_id=102, + name="partition", + field_type=StructType(), + required=True, + doc="Partition data tuple, schema based on the partition spec", + ), + NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"), + NestedField( + field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes" + ), + NestedField( + field_id=105, + name="block_size_in_bytes", + field_type=LongType(), + required=True, + doc="Deprecated. Always write a default in v1. 
Do not write in v2.",
+            write_default=DEFAULT_BLOCK_SIZE,
+        ),
+        NestedField(
+            field_id=108,
+            name="column_sizes",
+            field_type=MapType(key_id=117, key_type=IntegerType(), value_id=118, value_type=LongType()),
+            required=False,
+            doc="Map of column id to total size on disk",
+        ),
+        NestedField(
+            field_id=109,
+            name="value_counts",
+            field_type=MapType(key_id=119, key_type=IntegerType(), value_id=120, value_type=LongType()),
+            required=False,
+            doc="Map of column id to total count, including null and NaN",
+        ),
+        NestedField(
+            field_id=110,
+            name="null_value_counts",
+            field_type=MapType(key_id=121, key_type=IntegerType(), value_id=122, value_type=LongType()),
+            required=False,
+            doc="Map of column id to null value count",
+        ),
+        NestedField(
+            field_id=137,
+            name="nan_value_counts",
+            field_type=MapType(key_id=138, key_type=IntegerType(), value_id=139, value_type=LongType()),
+            required=False,
+            doc="Map of column id to number of NaN values in the column",
+        ),
+        NestedField(
+            field_id=125,
+            name="lower_bounds",
+            field_type=MapType(key_id=126, key_type=IntegerType(), value_id=127, value_type=BinaryType()),
+            required=False,
+            doc="Map of column id to lower bound",
+        ),
+        NestedField(
+            field_id=128,
+            name="upper_bounds",
+            field_type=MapType(key_id=129, key_type=IntegerType(), value_id=130, value_type=BinaryType()),
+            required=False,
+            doc="Map of column id to upper bound",
+        ),
+        NestedField(
+            field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob"
+        ),
+        NestedField(
+            field_id=132,
+            name="split_offsets",
+            field_type=ListType(element_id=133, element_type=LongType(), element_required=True),
+            required=False,
+            doc="Splittable offsets",
+        ),
+        NestedField(field_id=140, name="sort_order_id", field_type=IntegerType(), required=False, doc="Sort order ID"),
+    ),
-    NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"),
-    NestedField(
-        field_id=101,
-        name="file_format",
-        field_type=StringType(),
-        required=True,
-        doc="File format name: avro, orc, or parquet",
+    2: StructType(
+        NestedField(
+            field_id=134,
+            name="content",
+            field_type=IntegerType(),
+            required=True,
+            doc="Contents of the file: 0=data, 1=position deletes, 2=equality deletes",
+            initial_default=DataFileContent.DATA,
+        ),
+        NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"),
+        NestedField(
+            field_id=101,
+            name="file_format",
+            field_type=StringType(),
+            required=True,
+            doc="File format name: avro, orc, or parquet",
+        ),
+        NestedField(
+            field_id=102,
+            name="partition",
+            field_type=StructType(),
+            required=True,
+            doc="Partition data tuple, schema based on the partition spec",
+        ),
+        NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"),
+        NestedField(
+            field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes"
+        ),
+        NestedField(
+            field_id=108,
+            name="column_sizes",
+            field_type=MapType(key_id=117, key_type=IntegerType(), value_id=118, value_type=LongType()),
+            required=False,
+            doc="Map of column id to total size on disk",
+        ),
+        NestedField(
+            field_id=109,
+            name="value_counts",
+            field_type=MapType(key_id=119, key_type=IntegerType(), value_id=120, value_type=LongType()),
+            required=False,
+            doc="Map of column id to total count, including null and NaN",
+        ),
+        NestedField(
+            field_id=110,
+            name="null_value_counts",
+
field_type=MapType(key_id=121, key_type=IntegerType(), value_id=122, value_type=LongType()), + required=False, + doc="Map of column id to null value count", + ), + NestedField( + field_id=137, + name="nan_value_counts", + field_type=MapType(key_id=138, key_type=IntegerType(), value_id=139, value_type=LongType()), + required=False, + doc="Map of column id to number of NaN values in the column", + ), + NestedField( + field_id=125, + name="lower_bounds", + field_type=MapType(key_id=126, key_type=IntegerType(), value_id=127, value_type=BinaryType()), + required=False, + doc="Map of column id to lower bound", + ), + NestedField( + field_id=128, + name="upper_bounds", + field_type=MapType(key_id=129, key_type=IntegerType(), value_id=130, value_type=BinaryType()), + required=False, + doc="Map of column id to upper bound", + ), + NestedField( + field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob" + ), + NestedField( + field_id=132, + name="split_offsets", + field_type=ListType(element_id=133, element_type=LongType(), element_required=True), + required=False, + doc="Splittable offsets", + ), + NestedField( + field_id=135, + name="equality_ids", + field_type=ListType(element_id=136, element_type=LongType(), element_required=True), + required=False, + doc="Field ids used to determine row equality in equality delete files.", + ), + NestedField( + field_id=140, + name="sort_order_id", + field_type=IntegerType(), + required=False, + doc="ID representing sort order for this file", + ), ), - NestedField( - field_id=102, - name="partition", - field_type=StructType(), - required=True, - doc="Partition data tuple, schema based on the partition spec", - ), - NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"), - NestedField(field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes"), - NestedField( - field_id=105, - name="block_size_in_bytes", - field_type=LongType(), - required=False, - doc="Deprecated. Always write a default in v1. 
Do not write in v2.", - ), - NestedField( - field_id=108, - name="column_sizes", - field_type=MapType(key_id=117, key_type=IntegerType(), value_id=118, value_type=LongType()), - required=False, - doc="Map of column id to total size on disk", - ), - NestedField( - field_id=109, - name="value_counts", - field_type=MapType(key_id=119, key_type=IntegerType(), value_id=120, value_type=LongType()), - required=False, - doc="Map of column id to total count, including null and NaN", - ), - NestedField( - field_id=110, - name="null_value_counts", - field_type=MapType(key_id=121, key_type=IntegerType(), value_id=122, value_type=LongType()), - required=False, - doc="Map of column id to null value count", - ), - NestedField( - field_id=137, - name="nan_value_counts", - field_type=MapType(key_id=138, key_type=IntegerType(), value_id=139, value_type=LongType()), - required=False, - doc="Map of column id to number of NaN values in the column", - ), - NestedField( - field_id=125, - name="lower_bounds", - field_type=MapType(key_id=126, key_type=IntegerType(), value_id=127, value_type=BinaryType()), - required=False, - doc="Map of column id to lower bound", - ), - NestedField( - field_id=128, - name="upper_bounds", - field_type=MapType(key_id=129, key_type=IntegerType(), value_id=130, value_type=BinaryType()), - required=False, - doc="Map of column id to upper bound", - ), - NestedField(field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob"), - NestedField( - field_id=132, - name="split_offsets", - field_type=ListType(element_id=133, element_type=LongType(), element_required=True), - required=False, - doc="Splittable offsets", - ), - NestedField( - field_id=135, - name="equality_ids", - field_type=ListType(element_id=136, element_type=LongType(), element_required=True), - required=False, - doc="Equality comparison field IDs", - ), - NestedField(field_id=140, name="sort_order_id", field_type=IntegerType(), required=False, doc="Sort order ID"), - NestedField(field_id=141, name="spec_id", field_type=IntegerType(), required=False, doc="Partition spec ID"), -) - -DATA_FILE_TYPE_V2 = StructType(*[field for field in DATA_FILE_TYPE_V1.fields if field.field_id != 105]) +} @singledispatch @@ -238,7 +323,7 @@ def data_file_with_partition(partition_type: StructType, format_version: Literal ) if field.field_id == 102 else field - for field in (DATA_FILE_TYPE_V1.fields if format_version == 1 else DATA_FILE_TYPE_V2.fields) + for field in DATA_FILE_TYPE[format_version].fields ] ) @@ -251,7 +336,6 @@ class DataFile(Record): "partition", "record_count", "file_size_in_bytes", - "block_size_in_bytes", "column_sizes", "value_counts", "null_value_counts", @@ -262,7 +346,6 @@ class DataFile(Record): "split_offsets", "equality_ids", "sort_order_id", - "spec_id", ) content: DataFileContent file_path: str @@ -270,7 +353,6 @@ class DataFile(Record): partition: Record record_count: int file_size_in_bytes: int - block_size_in_bytes: Optional[int] column_sizes: Dict[int, int] value_counts: Dict[int, int] null_value_counts: Dict[int, int] @@ -281,7 +363,6 @@ class DataFile(Record): split_offsets: Optional[List[int]] equality_ids: Optional[List[int]] sort_order_id: Optional[int] - spec_id: Optional[int] def __setattr__(self, name: str, value: Any) -> None: """Assign a key/value to a DataFile.""" @@ -290,10 +371,10 @@ def __setattr__(self, name: str, value: Any) -> None: value = FileFormat[value] super().__setattr__(name, value) - def __init__(self, format_version: Literal[1, 2] = 1, 
*data: Any, **named_data: Any) -> None: + def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None: super().__init__( *data, - **{"struct": DATA_FILE_TYPE_V1 if format_version == 1 else DATA_FILE_TYPE_V2, **named_data}, + **{"struct": DATA_FILE_TYPE[format_version], **named_data}, ) def __hash__(self) -> int: @@ -308,22 +389,29 @@ def __eq__(self, other: Any) -> bool: return self.file_path == other.file_path if isinstance(other, DataFile) else False -MANIFEST_ENTRY_SCHEMA = Schema( - NestedField(0, "status", IntegerType(), required=True), - NestedField(1, "snapshot_id", LongType(), required=False), - NestedField(3, "data_sequence_number", LongType(), required=False), - NestedField(4, "file_sequence_number", LongType(), required=False), - NestedField(2, "data_file", DATA_FILE_TYPE_V1, required=True), -) +MANIFEST_ENTRY_SCHEMAS = { + 1: Schema( + NestedField(0, "status", IntegerType(), required=True), + NestedField(1, "snapshot_id", LongType(), required=True), + NestedField(2, "data_file", DATA_FILE_TYPE[1], required=True), + ), + 2: Schema( + NestedField(0, "status", IntegerType(), required=True), + NestedField(1, "snapshot_id", LongType(), required=False), + NestedField(3, "data_sequence_number", LongType(), required=False), + NestedField(4, "file_sequence_number", LongType(), required=False), + NestedField(2, "data_file", DATA_FILE_TYPE[2], required=True), + ), +} -MANIFEST_ENTRY_SCHEMA_STRUCT = MANIFEST_ENTRY_SCHEMA.as_struct() +MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()} -def manifest_entry_schema_with_data_file(data_file: StructType) -> Schema: +def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema: return Schema( *[ NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field - for field in MANIFEST_ENTRY_SCHEMA.fields + for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields ] ) @@ -337,7 +425,7 @@ class ManifestEntry(Record): data_file: DataFile def __init__(self, *data: Any, **named_data: Any) -> None: - super().__init__(*data, **{"struct": MANIFEST_ENTRY_SCHEMA_STRUCT, **named_data}) + super().__init__(*data, **{"struct": MANIFEST_ENTRY_SCHEMAS_STRUCT[DEFAULT_READ_VERSION], **named_data}) PARTITION_FIELD_SUMMARY_TYPE = StructType( @@ -489,7 +577,7 @@ def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> List input_file = io.new_input(self.manifest_path) with AvroFile[ManifestEntry]( input_file, - MANIFEST_ENTRY_SCHEMA, + MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], read_types={-1: ManifestEntry, 2: DataFile}, read_enums={0: ManifestEntryStatus, 101: FileFormat, 134: DataFileContent}, ) as reader: @@ -603,10 +691,26 @@ def __exit__( def content(self) -> ManifestContent: ... + @property @abstractmethod - def new_writer(self) -> AvroOutputFile[ManifestEntry]: + def version(self) -> Literal[1, 2]: ... 
+ def _with_partition(self, format_version: Literal[1, 2]) -> Schema: + data_file_type = data_file_with_partition( + format_version=format_version, partition_type=self._spec.partition_type(self._schema) + ) + return manifest_entry_schema_with_data_file(format_version=format_version, data_file=data_file_type) + + def new_writer(self) -> AvroOutputFile[ManifestEntry]: + return AvroOutputFile[ManifestEntry]( + output_file=self._output_file, + file_schema=self._with_partition(self.version), + record_schema=self._with_partition(DEFAULT_READ_VERSION), + schema_name="manifest_entry", + metadata=self._meta, + ) + @abstractmethod def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: ... @@ -678,15 +782,12 @@ def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, def content(self) -> ManifestContent: return ManifestContent.DATA - def new_writer(self) -> AvroOutputFile[ManifestEntry]: - v1_data_file_type = data_file_with_partition(self._spec.partition_type(self._schema), format_version=1) - v1_manifest_entry_schema = manifest_entry_schema_with_data_file(v1_data_file_type) - return AvroOutputFile[ManifestEntry](self._output_file, v1_manifest_entry_schema, "manifest_entry", self._meta) + @property + def version(self) -> Literal[1, 2]: + return 1 def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: - wrapped_entry = ManifestEntry(*entry.record_fields()) - wrapped_entry.data_file.block_size_in_bytes = DEFAULT_BLOCK_SIZE - return wrapped_entry + return entry class ManifestWriterV2(ManifestWriter): @@ -708,10 +809,9 @@ def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, def content(self) -> ManifestContent: return ManifestContent.DATA - def new_writer(self) -> AvroOutputFile[ManifestEntry]: - v2_data_file_type = data_file_with_partition(self._spec.partition_type(self._schema), format_version=2) - v2_manifest_entry_schema = manifest_entry_schema_with_data_file(v2_data_file_type) - return AvroOutputFile[ManifestEntry](self._output_file, v2_manifest_entry_schema, "manifest_entry", self._meta) + @property + def version(self) -> Literal[1, 2]: + return 2 def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: if entry.data_sequence_number is None: @@ -719,35 +819,7 @@ def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: raise ValueError(f"Found unassigned sequence number for an entry from snapshot: {entry.snapshot_id}") if entry.status != ManifestEntryStatus.ADDED: raise ValueError("Only entries with status ADDED can have null sequence number") - # In v2, we should not write block_size_in_bytes field - wrapped_data_file_v2_debug = DataFile( - format_version=2, - content=entry.data_file.content, - file_path=entry.data_file.file_path, - file_format=entry.data_file.file_format, - partition=entry.data_file.partition, - record_count=entry.data_file.record_count, - file_size_in_bytes=entry.data_file.file_size_in_bytes, - column_sizes=entry.data_file.column_sizes, - value_counts=entry.data_file.value_counts, - null_value_counts=entry.data_file.null_value_counts, - nan_value_counts=entry.data_file.nan_value_counts, - lower_bounds=entry.data_file.lower_bounds, - upper_bounds=entry.data_file.upper_bounds, - key_metadata=entry.data_file.key_metadata, - split_offsets=entry.data_file.split_offsets, - equality_ids=entry.data_file.equality_ids, - sort_order_id=entry.data_file.sort_order_id, - spec_id=entry.data_file.spec_id, - ) - wrapped_entry = ManifestEntry( - status=entry.status, - snapshot_id=entry.snapshot_id, - 
data_sequence_number=entry.data_sequence_number, - file_sequence_number=entry.file_sequence_number, - data_file=wrapped_data_file_v2_debug, - ) - return wrapped_entry + return entry def write_manifest( @@ -775,7 +847,9 @@ def __init__(self, output_file: OutputFile, meta: Dict[str, str]): def __enter__(self) -> ManifestListWriter: """Open the writer for writing.""" - self._writer = AvroOutputFile[ManifestFile](self._output_file, MANIFEST_FILE_SCHEMA, "manifest_file", self._meta) + self._writer = AvroOutputFile[ManifestFile]( + output_file=self._output_file, file_schema=MANIFEST_FILE_SCHEMA, schema_name="manifest_file", metadata=self._meta + ) self._writer.__enter__() return self diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 12ea831f08..b8fdedea51 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -51,7 +51,7 @@ from pydantic_core.core_schema import ValidatorFunctionWrapHandler from pyiceberg.exceptions import ValidationError -from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel +from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L from pyiceberg.utils.parsing import ParseNumberFromBrackets from pyiceberg.utils.singleton import Singleton @@ -282,6 +282,7 @@ class NestedField(IcebergType): required: bool = Field(default=True) doc: Optional[str] = Field(default=None, repr=False) initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) + write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore def __init__( self, @@ -291,6 +292,7 @@ def __init__( required: bool = True, doc: Optional[str] = None, initial_default: Optional[Any] = None, + write_default: Optional[L] = None, **data: Any, ): # We need an init when we want to use positional arguments, but @@ -301,6 +303,7 @@ def __init__( data["required"] = required data["doc"] = doc data["initial-default"] = initial_default + data["write-default"] = write_default super().__init__(**data) def __str__(self) -> str: diff --git a/tests/avro/test_decoder.py b/tests/avro/test_decoder.py index fd660247cd..bbcc7394f4 100644 --- a/tests/avro/test_decoder.py +++ b/tests/avro/test_decoder.py @@ -27,7 +27,7 @@ from pyiceberg.avro.decoder import BinaryDecoder, StreamingBinaryDecoder, new_decoder from pyiceberg.avro.decoder_fast import CythonBinaryDecoder -from pyiceberg.avro.resolver import resolve +from pyiceberg.avro.resolver import resolve_reader from pyiceberg.io import InputStream from pyiceberg.types import DoubleType, FloatType @@ -194,7 +194,7 @@ def test_skip_utf8(decoder_class: Callable[[bytes], BinaryDecoder]) -> None: @pytest.mark.parametrize("decoder_class", AVAILABLE_DECODERS) def test_read_int_as_float(decoder_class: Callable[[bytes], BinaryDecoder]) -> None: decoder = decoder_class(b"\x00\x00\x9A\x41") - reader = resolve(FloatType(), DoubleType()) + reader = resolve_reader(FloatType(), DoubleType()) assert reader.read(decoder) == 19.25 diff --git a/tests/avro/test_file.py b/tests/avro/test_file.py index e9dcc7eca1..518026cc4f 100644 --- a/tests/avro/test_file.py +++ b/tests/avro/test_file.py @@ -30,7 +30,8 @@ from pyiceberg.avro.file import META_SCHEMA, AvroFileHeader from pyiceberg.io.pyarrow import PyArrowFileIO from pyiceberg.manifest import ( - MANIFEST_ENTRY_SCHEMA, + DEFAULT_BLOCK_SIZE, + MANIFEST_ENTRY_SCHEMAS, DataFile, DataFileContent, FileFormat, @@ -116,7 +117,7 @@ def todict(obj: Any) -> Any: return obj -def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: +def 
test_write_manifest_entry_with_iceberg_read_with_fastavro_v1() -> None: data_file = DataFile( content=DataFileContent.DATA, file_path="s3://some-path/some-file.parquet", @@ -124,7 +125,6 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: partition=Record(), record_count=131327, file_size_in_bytes=220669226, - block_size_in_bytes=67108864, column_sizes={1: 220661854}, value_counts={1: 131327}, null_value_counts={1: 0}, @@ -135,7 +135,6 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: split_offsets=[4, 133697593], equality_ids=[], sort_order_id=4, - spec_id=3, ) entry = ManifestEntry( status=ManifestEntryStatus.ADDED, @@ -151,7 +150,76 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: tmp_avro_file = tmpdir + "/manifest_entry.avro" with avro.AvroOutputFile[ManifestEntry]( - PyArrowFileIO().new_output(tmp_avro_file), MANIFEST_ENTRY_SCHEMA, "manifest_entry", additional_metadata + output_file=PyArrowFileIO().new_output(tmp_avro_file), + file_schema=MANIFEST_ENTRY_SCHEMAS[1], + schema_name="manifest_entry", + record_schema=MANIFEST_ENTRY_SCHEMAS[2], + metadata=additional_metadata, + ) as out: + out.write_block([entry]) + + with open(tmp_avro_file, "rb") as fo: + r = reader(fo=fo) + + for k, v in additional_metadata.items(): + assert k in r.metadata + assert v == r.metadata[k] + + it = iter(r) + + fa_entry = next(it) + + v2_entry = todict(entry) + + # These are not written in V1 + del v2_entry['data_sequence_number'] + del v2_entry['file_sequence_number'] + del v2_entry['data_file']['content'] + del v2_entry['data_file']['equality_ids'] + + # Required in V1 + v2_entry['data_file']['block_size_in_bytes'] = DEFAULT_BLOCK_SIZE + + assert v2_entry == fa_entry + + +def test_write_manifest_entry_with_iceberg_read_with_fastavro_v2() -> None: + data_file = DataFile( + content=DataFileContent.DATA, + file_path="s3://some-path/some-file.parquet", + file_format=FileFormat.PARQUET, + partition=Record(), + record_count=131327, + file_size_in_bytes=220669226, + column_sizes={1: 220661854}, + value_counts={1: 131327}, + null_value_counts={1: 0}, + nan_value_counts={}, + lower_bounds={1: b"aaaaaaaaaaaaaaaa"}, + upper_bounds={1: b"zzzzzzzzzzzzzzzz"}, + key_metadata=b"\xde\xad\xbe\xef", + split_offsets=[4, 133697593], + equality_ids=[], + sort_order_id=4, + ) + entry = ManifestEntry( + status=ManifestEntryStatus.ADDED, + snapshot_id=8638475580105682862, + data_sequence_number=0, + file_sequence_number=0, + data_file=data_file, + ) + + additional_metadata = {"foo": "bar"} + + with TemporaryDirectory() as tmpdir: + tmp_avro_file = tmpdir + "/manifest_entry.avro" + + with avro.AvroOutputFile[ManifestEntry]( + output_file=PyArrowFileIO().new_output(tmp_avro_file), + file_schema=MANIFEST_ENTRY_SCHEMAS[2], + schema_name="manifest_entry", + metadata=additional_metadata, ) as out: out.write_block([entry]) @@ -169,7 +237,8 @@ def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None: assert todict(entry) == fa_entry -def test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None: +@pytest.mark.parametrize("format_version", [1, 2]) +def test_write_manifest_entry_with_fastavro_read_with_iceberg(format_version: int) -> None: data_file = DataFile( content=DataFileContent.DATA, file_path="s3://some-path/some-file.parquet", @@ -187,8 +256,10 @@ def test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None: split_offsets=[4, 133697593], equality_ids=[], sort_order_id=4, - spec_id=3, ) + if format_version == 1: + 
data_file.block_size_in_bytes = DEFAULT_BLOCK_SIZE + entry = ManifestEntry( status=ManifestEntryStatus.ADDED, snapshot_id=8638475580105682862, @@ -200,14 +271,14 @@ def test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None: with TemporaryDirectory() as tmpdir: tmp_avro_file = tmpdir + "/manifest_entry.avro" - schema = AvroSchemaConversion().iceberg_to_avro(MANIFEST_ENTRY_SCHEMA, schema_name="manifest_entry") + schema = AvroSchemaConversion().iceberg_to_avro(MANIFEST_ENTRY_SCHEMAS[format_version], schema_name="manifest_entry") with open(tmp_avro_file, "wb") as out: writer(out, schema, [todict(entry)]) with avro.AvroFile[ManifestEntry]( PyArrowFileIO().new_input(tmp_avro_file), - MANIFEST_ENTRY_SCHEMA, + MANIFEST_ENTRY_SCHEMAS[format_version], {-1: ManifestEntry, 2: DataFile}, ) as avro_reader: it = iter(avro_reader) diff --git a/tests/avro/test_reader.py b/tests/avro/test_reader.py index a3a502bcff..48ee8911da 100644 --- a/tests/avro/test_reader.py +++ b/tests/avro/test_reader.py @@ -41,7 +41,7 @@ ) from pyiceberg.avro.resolver import construct_reader from pyiceberg.io.pyarrow import PyArrowFileIO -from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMA, DataFile, ManifestEntry +from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMAS, DataFile, ManifestEntry from pyiceberg.schema import Schema from pyiceberg.typedef import Record from pyiceberg.types import ( @@ -70,7 +70,7 @@ def test_read_header(generated_manifest_entry_file: str, iceberg_manifest_entry_schema: Schema) -> None: with AvroFile[ManifestEntry]( PyArrowFileIO().new_input(generated_manifest_entry_file), - MANIFEST_ENTRY_SCHEMA, + MANIFEST_ENTRY_SCHEMAS[2], {-1: ManifestEntry, 2: DataFile}, ) as reader: header = reader.header diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index a302294755..51d2a7d8fc 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -32,12 +32,25 @@ StringReader, StructReader, ) -from pyiceberg.avro.resolver import ResolveError, resolve +from pyiceberg.avro.resolver import ResolveError, resolve_reader, resolve_writer +from pyiceberg.avro.writer import ( + BinaryWriter, + DefaultWriter, + DoubleWriter, + IntegerWriter, + ListWriter, + MapWriter, + OptionWriter, + StringWriter, + StructWriter, +) from pyiceberg.io.pyarrow import PyArrowFileIO +from pyiceberg.manifest import MANIFEST_ENTRY_SCHEMAS from pyiceberg.schema import Schema from pyiceberg.typedef import Record from pyiceberg.types import ( BinaryType, + BooleanType, DecimalType, DoubleType, FloatType, @@ -81,7 +94,7 @@ def test_resolver() -> None: NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())), schema_id=1, ) - read_tree = resolve(write_schema, read_schema) + read_tree = resolve_reader(write_schema, read_schema) assert read_tree == StructReader( ( @@ -117,7 +130,7 @@ def test_resolver_new_required_field() -> None: ) with pytest.raises(ResolveError) as exc_info: - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) assert "2: data: required string is non-optional, and not part of the file schema" in str(exc_info.value) @@ -133,7 +146,7 @@ def test_resolver_invalid_evolution() -> None: ) with pytest.raises(ResolveError) as exc_info: - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) assert "Cannot promote long to double" in str(exc_info.value) @@ -147,7 +160,7 @@ def test_resolver_promotion_string_to_binary() -> None: NestedField(1, "id", BinaryType()), schema_id=1, ) - resolve(write_schema, read_schema) + 
resolve_reader(write_schema, read_schema) def test_resolver_promotion_binary_to_string() -> None: @@ -159,7 +172,7 @@ def test_resolver_promotion_binary_to_string() -> None: NestedField(1, "id", StringType()), schema_id=1, ) - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) def test_resolver_change_type() -> None: @@ -173,69 +186,69 @@ def test_resolver_change_type() -> None: ) with pytest.raises(ResolveError) as exc_info: - resolve(write_schema, read_schema) + resolve_reader(write_schema, read_schema) assert "File/read schema are not aligned for list, got map" in str(exc_info.value) def test_resolve_int_to_long() -> None: - assert resolve(IntegerType(), LongType()) == IntegerReader() + assert resolve_reader(IntegerType(), LongType()) == IntegerReader() def test_resolve_float_to_double() -> None: # We should still read floats, because it is encoded in 4 bytes - assert resolve(FloatType(), DoubleType()) == FloatReader() + assert resolve_reader(FloatType(), DoubleType()) == FloatReader() def test_resolve_decimal_to_decimal() -> None: # DecimalType(P, S) to DecimalType(P2, S) where P2 > P - assert resolve(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(19, 25) + assert resolve_reader(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(19, 25) def test_struct_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(StructType(), StringType()) + assert resolve_reader(StructType(), StringType()) def test_map_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(MapType(1, StringType(), 2, IntegerType()), StringType()) + assert resolve_reader(MapType(1, StringType(), 2, IntegerType()), StringType()) def test_primitive_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(IntegerType(), MapType(1, StringType(), 2, IntegerType())) + assert resolve_reader(IntegerType(), MapType(1, StringType(), 2, IntegerType())) def test_integer_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(IntegerType(), StringType()) + assert resolve_reader(IntegerType(), StringType()) def test_float_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(FloatType(), StringType()) + assert resolve_reader(FloatType(), StringType()) def test_string_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(StringType(), FloatType()) + assert resolve_reader(StringType(), FloatType()) def test_binary_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(BinaryType(), FloatType()) + assert resolve_reader(BinaryType(), FloatType()) def test_decimal_not_aligned() -> None: with pytest.raises(ResolveError): - assert resolve(DecimalType(22, 19), StringType()) + assert resolve_reader(DecimalType(22, 19), StringType()) def test_resolve_decimal_to_decimal_reduce_precision() -> None: # DecimalType(P, S) to DecimalType(P2, S) where P2 > P with pytest.raises(ResolveError) as exc_info: - _ = resolve(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25) + _ = resolve_reader(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25) assert "Cannot reduce precision from decimal(19, 25) to decimal(10, 25)" in str(exc_info.value) @@ -293,7 +306,7 @@ def test_resolver_initial_value() -> None: schema_id=2, ) - assert resolve(write_schema, read_schema) == StructReader( + assert resolve_reader(write_schema, read_schema) == StructReader( ( (None, StringReader()), # The one we skip (0, DefaultReader("vo")), @@ -301,3 +314,95 @@ def 
test_resolver_initial_value() -> None: Record, read_schema.as_struct(), ) + + +def test_resolve_writer() -> None: + actual = resolve_writer(record_schema=MANIFEST_ENTRY_SCHEMAS[2], file_schema=MANIFEST_ENTRY_SCHEMAS[1]) + expected = StructWriter( + ( + (0, IntegerWriter()), + (1, IntegerWriter()), + ( + 4, + StructWriter( + ( + (1, StringWriter()), + (2, StringWriter()), + (3, StructWriter(())), + (4, IntegerWriter()), + (5, IntegerWriter()), + (None, DefaultWriter(writer=IntegerWriter(), value=67108864)), + (6, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (7, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (8, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (9, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (10, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), + (11, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), + (12, OptionWriter(option=BinaryWriter())), + (13, OptionWriter(option=ListWriter(element_writer=IntegerWriter()))), + (15, OptionWriter(option=IntegerWriter())), + ) + ), + ), + ) + ) + + assert actual == expected + + +def test_resolve_writer_promotion() -> None: + with pytest.raises(ResolveError) as exc_info: + _ = resolve_writer( + record_schema=Schema(NestedField(field_id=1, name="floating", type=DoubleType(), required=True)), + file_schema=Schema(NestedField(field_id=1, name="floating", type=FloatType(), required=True)), + ) + + assert "Cannot promote double to float" in str(exc_info.value) + + +def test_writer_ordering() -> None: + actual = resolve_writer( + record_schema=Schema( + NestedField(field_id=1, name="str", type=StringType(), required=True), + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + ), + file_schema=Schema( + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + NestedField(field_id=1, name="str", type=StringType(), required=True), + ), + ) + + expected = StructWriter(((1, DoubleWriter()), (0, StringWriter()))) + + assert actual == expected + + +def test_writer_one_more_field() -> None: + actual = resolve_writer( + record_schema=Schema( + NestedField(field_id=3, name="bool", type=BooleanType(), required=True), + NestedField(field_id=1, name="str", type=StringType(), required=True), + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + ), + file_schema=Schema( + NestedField(field_id=2, name="dbl", type=DoubleType(), required=True), + NestedField(field_id=1, name="str", type=StringType(), required=True), + ), + ) + + expected = StructWriter(((2, DoubleWriter()), (1, StringWriter()))) + + assert actual == expected + + +def test_writer_missing_optional_in_read_schema() -> None: + actual = resolve_writer( + record_schema=Schema(), + file_schema=Schema( + NestedField(field_id=1, name="str", type=StringType(), required=False), + ), + ) + + expected = StructWriter(field_writers=((None, OptionWriter(option=StringWriter())),)) + + assert actual == expected diff --git a/tests/test_integration_manifest.py b/tests/test_integration_manifest.py index 34b20f271d..475e0d40a6 100644 --- a/tests/test_integration_manifest.py +++ b/tests/test_integration_manifest.py @@ -101,7 +101,6 @@ def test_write_sample_manifest(table_test_all_types: Table) -> None: split_offsets=entry.data_file.split_offsets, equality_ids=entry.data_file.equality_ids, 
sort_order_id=entry.data_file.sort_order_id, - spec_id=entry.data_file.spec_id, ) wrapped_entry_v2 = ManifestEntry(*entry.record_fields()) wrapped_entry_v2.data_file = wrapped_data_file_v2_debug diff --git a/tests/test_types.py b/tests/test_types.py index 249ee98a6f..6aed56c58f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -18,6 +18,7 @@ import pickle from typing import Type +import pydantic_core import pytest from pyiceberg.exceptions import ValidationError @@ -208,6 +209,11 @@ def test_nested_field() -> None: assert str(field_var) == str(eval(repr(field_var))) assert field_var == pickle.loads(pickle.dumps(field_var)) + with pytest.raises(pydantic_core.ValidationError) as exc_info: + _ = (NestedField(1, "field", StringType(), required=True, write_default=(1, "a", True)),) # type: ignore + + assert "validation errors for NestedField" in str(exc_info.value) + @pytest.mark.parametrize("input_index,input_type", non_parameterized_types) @pytest.mark.parametrize("check_index,check_type", non_parameterized_types) diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 41af844bba..4d55610c74 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=redefined-outer-name,arguments-renamed,fixme from tempfile import TemporaryDirectory -from typing import Dict +from typing import Dict, Literal import fastavro import pytest @@ -303,7 +303,9 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None: @pytest.mark.parametrize("format_version", [1, 2]) -def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: int) -> None: +def test_write_manifest( + generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2] +) -> None: io = load_file_io() snapshot = Snapshot( snapshot_id=25, @@ -327,7 +329,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest tmp_avro_file = tmpdir + "/test_write_manifest.avro" output = io.new_output(tmp_avro_file) with write_manifest( - format_version=format_version, # type: ignore + format_version=format_version, spec=test_spec, schema=test_schema, output_file=output, @@ -337,6 +339,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest writer.add_entry(entry) new_manifest = writer.to_manifest_file() with pytest.raises(RuntimeError): + # It is already closed writer.add_entry(manifest_entries[0]) expected_metadata = { @@ -345,8 +348,6 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest "partition-spec-id": str(test_spec.spec_id), "format-version": str(format_version), } - if format_version == 2: - expected_metadata["content"] = "data" _verify_metadata_with_fastavro( tmp_avro_file, expected_metadata, @@ -357,7 +358,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest assert manifest_entry.status == ManifestEntryStatus.ADDED assert manifest_entry.snapshot_id == 8744736658442914487 - assert manifest_entry.data_sequence_number == 0 if format_version == 1 else 3 + assert manifest_entry.data_sequence_number == -1 if format_version == 1 else 3 assert isinstance(manifest_entry.data_file, DataFile) data_file = manifest_entry.data_file @@ -371,10 +372,6 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest assert data_file.partition == Record(VendorID=1, tpep_pickup_datetime=1925) assert 
data_file.record_count == 19513 assert data_file.file_size_in_bytes == 388872 - if format_version == 1: - assert data_file.block_size_in_bytes == 67108864 - else: - assert data_file.block_size_in_bytes is None assert data_file.column_sizes == { 1: 53, 2: 98153, @@ -477,7 +474,7 @@ def test_write_manifest(generated_manifest_file_file_v1: str, generated_manifest @pytest.mark.parametrize("format_version", [1, 2]) def test_write_manifest_list( - generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: int + generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2] ) -> None: io = load_file_io() @@ -495,7 +492,7 @@ def test_write_manifest_list( path = tmp_dir + "/manifest-list.avro" output = io.new_output(path) with write_manifest_list( - format_version=format_version, output_file=output, snapshot_id=25, parent_snapshot_id=19, sequence_number=0 # type: ignore + format_version=format_version, output_file=output, snapshot_id=25, parent_snapshot_id=19, sequence_number=0 ) as writer: writer.add_manifests(demo_manifest_list) new_manifest_list = list(read_manifest_list(io.new_input(path)))
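
For context, a minimal usage sketch of the writer resolution introduced above. This is not part of the patch: the two schemas and the "block_size_in_bytes" default below are illustrative only, and it assumes a pyiceberg build that includes this change.

# Illustrative sketch -- assumes resolve_writer and NestedField.write_default from this patch.
from pyiceberg.avro.resolver import resolve_writer
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType

# The in-memory record carries only these two fields.
record_schema = Schema(
    NestedField(field_id=1, name="file_path", field_type=StringType(), required=True),
    NestedField(field_id=2, name="record_count", field_type=LongType(), required=True),
)

# The file schema lists the same fields in a different order and additionally
# requires a field that is absent from the record but carries a write default.
file_schema = Schema(
    NestedField(field_id=2, name="record_count", field_type=LongType(), required=True),
    NestedField(field_id=1, name="file_path", field_type=StringType(), required=True),
    NestedField(
        field_id=3,
        name="block_size_in_bytes",
        field_type=LongType(),
        required=True,
        write_default=67108864,
    ),
)

# The resolved writer follows the file schema order; each entry pairs the record
# position to read from with a writer, or uses None plus a DefaultWriter for a
# field that is emitted purely from its write default.
writer = resolve_writer(record_schema=record_schema, file_schema=file_schema)
print(writer)
# Resulting structure (conceptually):
# StructWriter(((1, IntegerWriter()), (0, StringWriter()),
#               (None, DefaultWriter(writer=IntegerWriter(), value=67108864))))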