diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 0b3f6306..e2e1789e 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -847,7 +847,7 @@ dependencies = [ [[package]] name = "sedpack_rs" -version = "0.1.3" +version = "0.1.4" dependencies = [ "criterion", "flatbuffers", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index e4be84c2..b0f718d8 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sedpack_rs" -version = "0.1.3" +version = "0.1.4" edition = "2021" description = "Rust bindings for sedpack a general ML dataset package" authors = [ diff --git a/src/sedpack/io/flatbuffer/iterate.py b/src/sedpack/io/flatbuffer/iterate.py index 8de9c7a7..97c1e699 100644 --- a/src/sedpack/io/flatbuffer/iterate.py +++ b/src/sedpack/io/flatbuffer/iterate.py @@ -26,7 +26,7 @@ from sedpack.io.compress import CompressedFile from sedpack.io.metadata import Attribute -from sedpack.io.types import ExampleT +from sedpack.io.types import AttributeValueT, ExampleT from sedpack.io.shard import IterateShardBase from sedpack.io.shard.iterate_shard_base import T from sedpack.io.utils import func_or_identity @@ -82,18 +82,16 @@ def _iterate_content(self, content: bytes) -> Iterable[ExampleT]: attribute=attribute, ) - # Copy otherwise the arrays are immutable and keep the whole - # file content from being garbage collected. - np_array = np.copy(np_array) - example_dictionary[attribute.name] = np_array yield example_dictionary @staticmethod - def decode_array(np_bytes: npt.NDArray[np.uint8], - attribute: Attribute, - batch_size: int = 0) -> npt.NDArray[np.generic]: + def decode_array( + np_bytes: npt.NDArray[np.uint8], + attribute: Attribute, + batch_size: int = 0, + ) -> AttributeValueT: """Decode an array. See `sedpack.io.shard.shard_writer_flatbuffer .ShardWriterFlatBuffer.save_numpy_vector_as_bytearray` for format description. The code tries to avoid unnecessary copies. @@ -112,6 +110,22 @@ def decode_array(np_bytes: npt.NDArray[np.uint8], Returns: the parsed np.ndarray of the correct dtype and shape. """ + match attribute.dtype: + case "str": + return np_bytes.tobytes().decode("utf-8") + case "bytes": + return np_bytes.tobytes() + case "int": + array = np.frombuffer( + buffer=np_bytes, + dtype=np.dtype("int64").newbyteorder("<"), + ) + assert array.shape == (1,), f"{array.shape = }" + return int(array[0]) + case _: + # The rest is interpreted as NumPy array. + pass + dt = np.dtype(attribute.dtype) # FlatBuffers are little-endian. There is no byteswap by # `np.frombuffer` but the array will be interpreted correctly. diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index c8ce4200..2de5fd05 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import aiofiles import numpy as np +from sedpack.io.metadata import Attribute from sedpack.io.shard import IterateShardBase from sedpack.io.shard.iterate_shard_base import T from sedpack.io.types import AttributeValueT, ExampleT @@ -32,20 +33,129 @@ class IterateShardNP(IterateShardBase[T]): """Iterate a shard saved in the npz format. """ + @staticmethod + def decode_attribute( + attribute: Attribute, + example_index: int, + prefixed_name: str, + shard_content: dict[str, list[AttributeValueT]], + ) -> AttributeValueT: + """Choose the correct way to decode the given attribute. + + Args: + + attribute (Attribute): Information about the attribute being decoded. + + example_index (int): Which example from this shard is being decoded. + + prefixed_name (str): For the case of `bytes` attributes we need to + store them in continuous array otherwise variable length would require + allowing pickling and result in a potential arbitrary code execution. + This name is the prefix-sum encoded lengths of the attribute values. + + shard_content (dict[str, list[AttributeValueT]]): The shard values. + """ + if attribute.dtype == "bytes": + return IterateShardNP.decode_bytes_attribute( + value=shard_content[attribute.name][0], + indexes=shard_content[prefixed_name], + attribute=attribute, + index=example_index, + ) + + return IterateShardNP.decode_non_bytes_attribute( + np_value=shard_content[attribute.name][example_index], + attribute=attribute, + ) + + @staticmethod + def decode_non_bytes_attribute( + np_value: AttributeValueT, + attribute: Attribute, + ) -> AttributeValueT: + match attribute.dtype: + case "str": + return str(np_value) + case "bytes": + raise ValueError("One needs to use decode_bytes_attribute") + case "int": + return int(np_value) + case _: + return np_value + + @staticmethod + def decode_bytes_attribute( + value: AttributeValueT, + indexes: list[AttributeValueT], + attribute: Attribute, + index: int, + ) -> AttributeValueT: + """Decode a bytes attribute. We are saving the byte attributes as a + continuous array across multiple examples and on the side we also save + the indexes into this array. + + Args: + + value (AttributeValueT): The NumPy array of np.uint8 containing + concatenated bytes values. + + indexes (list[AttributeValueT]): Indexes into this array. + + attribute (Attribute): The attribute description. + + index (int): Which example out of this shard to return. + """ + if attribute.dtype != "bytes": + raise ValueError("One needs to use decode_attribute") + # Help with type-checking: + my_value = np.array(value, np.uint8) + my_indexes = np.array(indexes, np.int64) + del value + del indexes + + begin: int = my_indexes[index] + end: int = my_indexes[index + 1] + return bytes(my_value[begin:end]) + def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: """Iterate a shard saved in the NumPy format npz. """ - shard_content: dict[str, list[AttributeValueT]] = np.load(file_path) + # A prefix such that prepended it creates a new name without collision + # with any attribute name. + counting_prefix: str = "len" + "_" * max( + len(attribute.name) + for attribute in self.dataset_structure.saved_data_description) + self._prefixed_names: dict[str, str] = { + attribute.name: counting_prefix + attribute.name + for attribute in self.dataset_structure.saved_data_description + } + + shard_content: dict[str, list[AttributeValueT]] = np.load( + file_path, + allow_pickle=False, + ) # A given shard contains the same number of elements for each # attribute. - elements: int = 0 - for values in shard_content.values(): - elements = len(values) - break + elements: int + first_attribute = self.dataset_structure.saved_data_description[0] + if self._prefixed_names[first_attribute.name] in shard_content: + elements = len( + shard_content[self._prefixed_names[first_attribute.name]]) - 1 + else: + elements = len(shard_content[first_attribute.name]) - for i in range(elements): - yield {name: value[i] for name, value in shard_content.items()} + for example_index in range(elements): + yield { + attribute.name: + IterateShardNP.decode_attribute( + attribute=attribute, + example_index=example_index, + prefixed_name=self._prefixed_names[attribute.name], + shard_content=shard_content, + ) + for attribute in self.dataset_structure.saved_data_description + } # TODO(issue #85) fix and test async iterator typing async def iterate_shard_async( # pylint: disable=invalid-overridden-method @@ -58,7 +168,10 @@ async def iterate_shard_async( # pylint: disable=invalid-overridden-method content_bytes: bytes = await f.read() content_io = io.BytesIO(content_bytes) - shard_content: dict[str, list[AttributeValueT]] = np.load(content_io) + shard_content: dict[str, list[AttributeValueT]] = np.load( + content_io, + allow_pickle=False, + ) # A given shard contains the same number of elements for each # attribute. @@ -67,8 +180,11 @@ async def iterate_shard_async( # pylint: disable=invalid-overridden-method elements = len(values) break - for i in range(elements): - yield {name: value[i] for name, value in shard_content.items()} + for example_index in range(elements): + yield { + name: value[example_index] + for name, value in shard_content.items() + } def process_and_list(self, shard_file: Path) -> list[T]: process_record = func_or_identity(self.process_record) diff --git a/src/sedpack/io/shard/shard_writer_flatbuffer.py b/src/sedpack/io/shard/shard_writer_flatbuffer.py index ed47cb5b..0a8ad7c1 100644 --- a/src/sedpack/io/shard/shard_writer_flatbuffer.py +++ b/src/sedpack/io/shard/shard_writer_flatbuffer.py @@ -101,45 +101,28 @@ def _write(self, values: ExampleT) -> None: self._examples.append(fbapi_Example.ExampleEnd(self._builder)) @staticmethod - def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] - builder: Builder, attribute: Attribute, - value: AttributeValueT) -> int: - """Save a given array into a FlatBuffer as bytes. This is to ensure - compatibility with types which are not supported by FlatBuffers (e.g., - np.float16). The FlatBuffers schema must mark this vector as type - bytes [byte] (see src/sedpack/io/flatbuffer/shard.fbs) since there is a - distinction of how the length is being saved. The inverse of this - function is - `sedpack.io.flatbuffer.iterate.IterateShardFlatBuffer.decode_array`. - - If we have an array of np.int32 of 10 elements the FlatBuffers library - would save the length as 10. Which is then impossible to read in Rust - since the length and itemsize (sizeof of the type) are private. Thus we - could not get the full array back. Thus we are saving the array as 40 - bytes. This function does not modify the `value`, and saves a flattened - version of it. This function also saves the exact dtype as given by - `attribute`. Bytes are being saved in little endian ("<") and - c_contiguous ("C") order, same as with FlatBuffers. Alignment is set to - `dtype.itemsize` as opposed to FlatBuffers choice of `dtype.alignment`. + def _np_to_bytes( # type: ignore[no-any-unimported] + builder: Builder, + attribute: Attribute, + value: AttributeValueT, + ) -> bytes: + """Turn a single AttributeValueT value into a sequence of bytes to be + saved in a FlatBuffer shard. Args: - builder (flatbuffers.Builder): The byte buffer being constructed. - Must be initialized. + builder (Builder): The FlatBuffer builder. - attribute (Attribute): Description of this attribute (shape and - dtype). + attribute (Attribute): Description of the attribute defining dtype and + shape. - value (AttributeValueT): The array to be saved. The shape should be - as defined in `attribute` (will be flattened). + value (AttributeValueT): The actual value which is being represented. - Returns: The offset returned by `flatbuffers.Builder.EndVector`. + Raises: ValueError if the value cannot be cast to the correct dtype with + the correct byteorder. """ - # Not sure about flatbuffers.Builder __bool__ semantics. assert builder is not None - # See `flatbuffers.builder.Builder.CreateNumpyVector`. - # Copy the value in order not to modify the original and flatten for # better saving. value_flattened = np.copy(value).flatten() @@ -178,7 +161,75 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] f" {attribute.name}") # This is going to be saved, ensure c_contiguous ordering. - byte_representation = value_np.tobytes(order="C") + return value_np.tobytes(order="C") + + @staticmethod + def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] + builder: Builder, + attribute: Attribute, + value: AttributeValueT, + ) -> int: + """Save a given array into a FlatBuffer as bytes. This is to ensure + compatibility with types which are not supported by FlatBuffers (e.g., + np.float16). The FlatBuffers schema must mark this vector as type + bytes [byte] (see src/sedpack/io/flatbuffer/shard.fbs) since there is a + distinction of how the length is being saved. The inverse of this + function is + `sedpack.io.flatbuffer.iterate.IterateShardFlatBuffer.decode_array`. + + If we have an array of np.int32 of 10 elements the FlatBuffers library + would save the length as 10. Which is then impossible to read in Rust + since the length and itemsize (sizeof of the type) are private. Thus we + could not get the full array back. Thus we are saving the array as 40 + bytes. This function does not modify the `value`, and saves a flattened + version of it. This function also saves the exact dtype as given by + `attribute`. Bytes are being saved in little endian ("<") and + c_contiguous ("C") order, same as with FlatBuffers. Alignment is set to + `dtype.itemsize` as opposed to FlatBuffers choice of `dtype.alignment`. + + Args: + + builder (flatbuffers.Builder): The byte buffer being constructed. + Must be initialized. + + attribute (Attribute): Description of this attribute (shape and + dtype). + + value (AttributeValueT): The array to be saved. The shape should be + as defined in `attribute` (will be flattened). + + Returns: The offset returned by `flatbuffers.Builder.EndVector`. + + Raises: ValueError if the dtype is unknown to NumPy. + """ + # Not sure about flatbuffers.Builder __bool__ semantics. + assert builder is not None + + # See `flatbuffers.builder.Builder.CreateNumpyVector`. + + byte_representation: bytes + alignment: int + match attribute.dtype: + case "str": + byte_representation = str(value).encode("utf-8") + alignment = 16 + case "bytes": + byte_representation = np.array(value).tobytes() + alignment = 16 + case _: + try: + _ = np.dtype(attribute.dtype) + except Exception as exc: + raise ValueError( + f"Unsupported dtype {attribute.dtype}") from exc + + byte_representation = ShardWriterFlatBuffer._np_to_bytes( + builder=builder, + attribute=attribute, + value=value, + ) + value_np = np.array(value, dtype=attribute.dtype) + alignment = value_np.dtype.itemsize # Total length of the array (in bytes). length: int = len(byte_representation) @@ -187,8 +238,7 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] builder.StartVector( elemSize=1, # Storing bytes. numElems=length, - alignment=value_np.dtype. - itemsize, # Cautious alignment of the array. + alignment=alignment, # Cautious alignment of the array. ) builder.head = int(builder.Head() - length) diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index 636a0cd8..7c8fcdf7 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -20,8 +20,9 @@ from pathlib import Path import numpy as np +from numpy import typing as npt -from sedpack.io.metadata import DatasetStructure +from sedpack.io.metadata import Attribute, DatasetStructure from sedpack.io.types import AttributeValueT, CompressionT, ExampleT from sedpack.io.shard.shard_writer_base import ShardWriterBase @@ -50,6 +51,26 @@ def __init__(self, dataset_structure: DatasetStructure, self._buffer: dict[str, list[AttributeValueT]] = {} + # A prefix such that prepended it creates a new name without collision + # with any attribute name. + self._counting_prefix: str = "len" + "_" * max( + len(attribute.name) + for attribute in dataset_structure.saved_data_description) + + def _value_to_np( + self, + attribute: Attribute, + value: AttributeValueT, + ) -> npt.NDArray[np.generic] | str: + match attribute.dtype: + case "bytes": + raise ValueError("Attributes bytes are saved extra") + case "str": + assert isinstance(value, str) + return value + case _: + return np.copy(value) + def _write(self, values: ExampleT) -> None: """Write an example on disk. Writing may be buffered. @@ -59,12 +80,42 @@ def _write(self, values: ExampleT) -> None: """ # Just buffer all values. if not self._buffer: - self._buffer = { - name: [np.copy(value)] for name, value in values.items() - } - else: - for name, value in values.items(): - self._buffer[name].append(np.copy(value)) + self._buffer = {} + + for attribute in self.dataset_structure.saved_data_description: + name = attribute.name + value = values[name] + + if attribute.dtype != "bytes": + current_values = self._buffer.get(name, []) + current_values.append( + self._value_to_np( + attribute=attribute, + value=value, + )) + self._buffer[name] = current_values + else: + # Extend and remember the length. Attributes with dtype "bytes" + # may have variable length. Handle this case. We need to avoid + # two things: + # - Having wrong length of the bytes array and ideally also + # avoid padding. + # - Using allow_pickle when saving since that could lead to code + # execution when loading malicious dataset. + # We prefix the attribute name by `len_?` such that the new name + # is unique and tells us the lengths of the byte arrays. + counts = self._buffer.get(self._counting_prefix + name, [0]) + counts.append(counts[-1] + + len(value) # type: ignore[arg-type,operator] + ) + self._buffer[self._counting_prefix + name] = counts + + byte_list: list[list[int]] + byte_list = self._buffer.get( # type: ignore[assignment] + name, [[]]) + byte_list[0].extend(list(value) # type: ignore[arg-type] + ) + self._buffer[name] = byte_list # type: ignore[assignment] def close(self) -> None: """Close the shard file(-s). @@ -73,14 +124,32 @@ def close(self) -> None: assert not self._shard_file.is_file() return - # Write the buffer into a file. + # Deal properly with "bytes" attributes. + for attribute in self.dataset_structure.saved_data_description: + if attribute.dtype != "bytes": + continue + self._buffer[attribute.name] = [ + np.array( + self._buffer[attribute.name][0], + dtype=np.uint8, + ) + ] + + # Write the buffer into a file. We should not need to allow_pickle while + # saving (the default value is True). But on GitHub actions macos-13 + # runner the tests were failing while reading. The security concern + # (code execution) should be more on the side of loading pickled data. match self.dataset_structure.compression: case "ZIP": - np.savez_compressed(str(self._shard_file), - **self._buffer) # type: ignore[arg-type] + np.savez_compressed( + str(self._shard_file), + **self._buffer, # type: ignore[arg-type] + ) case "": - np.savez(str(self._shard_file), - **self._buffer) # type: ignore[arg-type] + np.savez( + str(self._shard_file), + **self._buffer, # type: ignore[arg-type] + ) case _: # Default should never happen since ShardWriterBase checks that # the requested compression type is supported. diff --git a/src/sedpack/io/tfrec/tfdata.py b/src/sedpack/io/tfrec/tfdata.py index 2f6b6e0b..a07d6e54 100644 --- a/src/sedpack/io/tfrec/tfdata.py +++ b/src/sedpack/io/tfrec/tfdata.py @@ -70,6 +70,7 @@ def get_from_tfrecord( for attribute in saved_data_description: dtype = { "str": tf.string, + "int": tf.int64, "bytes": tf.string, "uint8": tf.int64, "int8": tf.int64, @@ -95,9 +96,13 @@ def from_tfrecord(tf_record: Any) -> Any: for attribute in saved_data_description: if attribute.dtype == "float16": rec[attribute.name] = tf.io.parse_tensor( - rec[attribute.name], tf.float16) - rec[attribute.name] = tf.ensure_shape(rec[attribute.name], - shape=attribute.shape) + rec[attribute.name], + tf.float16, + ) + rec[attribute.name] = tf.ensure_shape( + rec[attribute.name], + shape=attribute.shape, + ) return rec return from_tfrecord @@ -128,6 +133,7 @@ def to_tfrecord(saved_data_description: list[Attribute], if len(attribute_names) != len(values): raise ValueError(f"There are missing attributes. Got: {values} " f"expected: {attribute_names}") + del attribute_names # Create dictionary of features feature = {} @@ -135,16 +141,17 @@ def to_tfrecord(saved_data_description: list[Attribute], for attribute in saved_data_description: value = values[attribute.name] - # Convert the value into a NumPy type. - value = np.array(value) - # Check shape - if attribute.dtype != "bytes" and value.shape != attribute.shape: + if attribute.dtype not in [ + "bytes", + "str", + "int", + ] and value.shape != attribute.shape: raise ValueError(f"Wrong shape of {attribute.name}, expected: " f"{attribute.shape}, got: {value.shape}.") # Set feature value - if attribute.dtype in ["int8", "uint8", "int32", "int64"]: + if attribute.dtype in ["int", "int8", "uint8", "int32", "int64"]: feature[attribute.name] = int64_feature(values[attribute.name]) elif attribute.dtype == "float16": value = value.astype(dtype=np.float16) diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py new file mode 100644 index 00000000..51749716 --- /dev/null +++ b/tests/io/test_end2end_dtypes.py @@ -0,0 +1,354 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any +import random + +import numpy as np +import numpy.typing as npt +import pytest + +import sedpack +from sedpack.io import Dataset +from sedpack.io.shard_info_iterator import ShardInfoIterator +from sedpack.io import Metadata +from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT + + +def dataset_and_values_dynamic_shape( + tmpdir: str | Path, + shard_file_type: str, + compression: str, + dtypes: list[str], + items: int, +) -> (Dataset, dict[str, list[Any]]): + values: dict[str, list[Any]] = {} + ds_path = Path(tmpdir) / f"e2e_{shard_file_type}_{'_'.join(dtypes)}" + dataset_metadata = Metadata(description="Test of the lib") + + # The order should not play a role. + random.shuffle(dtypes) + + example_attributes = [ + sedpack.io.metadata.Attribute( + name=f"attribute_{dtype}", + dtype=dtype, + shape=(), + ) for dtype in dtypes + ] + + for dtype in dtypes: + values[f"attribute_{dtype}"] = [] + + match dtype: + case "int": + for _ in range(items): + # TODO larger range than just int64 + values[f"attribute_{dtype}"].append( + random.randint(-2**60, 2**60)) + case "str": + long_string = "Ḽơᶉëᶆ ȋṕšᶙṁ ḍỡḽǭᵳ ʂǐť ӓṁệẗ, ĉṓɲṩḙċťᶒțûɾ" \ + "https://arxiv.org/abs/2306.07249 ḹẩḇőꝛế" \ + "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ" \ + ":(){ :|:& };: éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ." + for _ in range(items): + begin: int = random.randint(0, len(long_string) // 2) + end: int = random.randint(begin + 1, len(long_string)) + values[f"attribute_{dtype}"].append(long_string[begin:end]) + case "bytes": + for _ in range(items): + values[f"attribute_{dtype}"].append( + np.random.randint( + 0, + 256, + size=random.randint(5, 20), + dtype=np.uint8, + ).tobytes()) + + dataset_structure = sedpack.io.metadata.DatasetStructure( + saved_data_description=example_attributes, + compression=compression, + examples_per_shard=3, + shard_file_type=shard_file_type, + ) + + # Test attribute_by_name + for attribute in example_attributes: + assert dataset_structure.attribute_by_name( + attribute_name=attribute.name) == attribute + + dataset = Dataset.create( + path=ds_path, + metadata=dataset_metadata, + dataset_structure=dataset_structure, + ) + + # Fill data in the dataset + + with dataset.filler() as filler: + for i in range(items): + filler.write_example( + values={ + name: value[i] for name, value in values.items() + }, + split=TRAIN_SPLIT, + ) + + # Check the data is correct + # Reopen the dataset + dataset = Dataset(ds_path) + dataset.check() + + return (values, dataset) + + +@pytest.fixture( + scope="module", + params=[ + { + "dtypes": ["str"], + "compression": "GZIP", + }, + { + "dtypes": ["bytes"], + "compression": "GZIP", + }, + { + "dtypes": ["int"], + "compression": "GZIP", + }, + { + "dtypes": ["str", "bytes", "int"], + "compression": "GZIP", + }, + ], +) +def values_and_dataset_tfrec(request, tmpdir_factory) -> None: + shard_file_type: str = "tfrec" + yield dataset_and_values_dynamic_shape( + tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"), + shard_file_type=shard_file_type, + compression=request.param["compression"], + dtypes=request.param["dtypes"], + items=137, + ) + # Teardown. + + +@pytest.fixture( + scope="module", + params=[ + { + "dtypes": ["str"], + "compression": "ZIP", + }, + { + "dtypes": ["bytes"], + "compression": "ZIP", + }, + { + "dtypes": ["int"], + "compression": "ZIP", + }, + { + "dtypes": ["str", "bytes", "int"], + "compression": "ZIP", + }, + ], +) +def values_and_dataset_npz(request, tmpdir_factory) -> None: + shard_file_type: str = "npz" + yield dataset_and_values_dynamic_shape( + tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"), + shard_file_type=shard_file_type, + compression=request.param["compression"], + dtypes=request.param["dtypes"], + items=137, + ) + # Teardown. + + +@pytest.fixture( + scope="module", + params=[ + { + "dtypes": ["str"], + "compression": "LZ4", + }, + { + "dtypes": ["bytes"], + "compression": "LZ4", + }, + { + "dtypes": ["int"], + "compression": "LZ4", + }, + { + "dtypes": ["str", "bytes", "int"], + "compression": "LZ4", + }, + ], +) +def values_and_dataset_fb(request, tmpdir_factory) -> None: + shard_file_type: str = "fb" + yield dataset_and_values_dynamic_shape( + tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"), + shard_file_type=shard_file_type, + compression=request.param["compression"], + dtypes=request.param["dtypes"], + items=137, + ) + # Teardown. + + +def check_iteration_of_values( + method: str, + dataset: Dataset, + values: dict[str, list[Any]], +) -> None: + match method: + case "as_tfdataset": + for i, example in enumerate( + dataset.as_tfdataset( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + batch_size=1, + )): + assert len(example) == len(values) + + # No idea how to have an actual string or bytes in TensorFlow. + # Maybe it is best to leave it as a tensor anyway since that is + # the "native" type. + + for name, returned_batch in example.items(): + assert returned_batch == values[name][i:i + 1] + case "as_numpy_iterator": + for i, example in enumerate( + dataset.as_numpy_iterator( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + )): + assert len(example) == len(values) + for name, returned_value in example.items(): + if dataset.dataset_structure.shard_file_type != "tfrec": + assert returned_value == values[name][i] + assert type(returned_value) == type(values[name][i]) + else: + if "attribute_str" == name: + assert returned_value == values[name][i].encode( + "utf-8") + else: + assert returned_value == values[name][i] + case "as_numpy_iterator_concurrent": + for i, example in enumerate( + dataset.as_numpy_iterator_concurrent( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + )): + assert len(example) == len(values) + for name, returned_value in example.items(): + if dataset.dataset_structure.shard_file_type != "tfrec": + assert returned_value == values[name][i] + assert type(returned_value) == type(values[name][i]) + else: + if "attribute_str" == name: + assert returned_value == values[name][i].encode( + "utf-8") + else: + assert returned_value == values[name][i] + case "as_numpy_iterator_rust": + for i, example in enumerate( + dataset.as_numpy_iterator_concurrent( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + )): + assert len(example) == len(values) + for name, returned_value in example.items(): + assert returned_value == values[name][i] + assert type(returned_value) == type(values[name][i]) + + # We tested everything + if i + 1 != len(next(iter(values.values()))): + raise AssertionError("Not all examples have been iterated") + + # Number of shards matches + full_iterator = ShardInfoIterator( + dataset_path=dataset.path, + dataset_info=dataset.dataset_info, + split=None, + ) + number_of_all_shards: int = full_iterator.number_of_shards() + assert number_of_all_shards == len(full_iterator) + assert number_of_all_shards == len(list(full_iterator)) + assert number_of_all_shards == sum( + ShardInfoIterator( + dataset_path=dataset.path, + dataset_info=dataset.dataset_info, + split=split, + ).number_of_shards() for split in ["train", "test", "holdout"]) + + +@pytest.mark.parametrize("method", [ + "as_tfdataset", + "as_numpy_iterator", + "as_numpy_iterator_concurrent", +]) +def test_end2end_dtypes_str_tfrec( + method: str, + values_and_dataset_tfrec, +) -> None: + values, dataset = values_and_dataset_tfrec + check_iteration_of_values( + method=method, + dataset=dataset, + values=values, + ) + + +@pytest.mark.parametrize("method", [ + "as_numpy_iterator", + "as_numpy_iterator_concurrent", +]) +def test_end2end_dtypes_str_npz( + method: str, + values_and_dataset_npz, +) -> None: + values, dataset = values_and_dataset_npz + check_iteration_of_values( + method=method, + dataset=dataset, + values=values, + ) + + +@pytest.mark.parametrize("method", [ + "as_numpy_iterator", + "as_numpy_iterator_concurrent", + "as_numpy_iterator_rust", +]) +def test_end2end_dtypes_str_fb( + method: str, + values_and_dataset_fb, +) -> None: + values, dataset = values_and_dataset_fb + check_iteration_of_values( + method=method, + dataset=dataset, + values=values, + )