From efed8fd474a5e0cc04ce898a76746bc5ea170781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 27 Aug 2025 19:26:26 +0200 Subject: [PATCH 01/17] Improve saving and loading of str attributes --- src/sedpack/io/flatbuffer/iterate.py | 5 + .../io/shard/shard_writer_flatbuffer.py | 113 +++++++---- src/sedpack/io/tfrec/tfdata.py | 21 +- tests/io/test_end2end_dtypes.py | 190 ++++++++++++++++++ 4 files changed, 285 insertions(+), 44 deletions(-) create mode 100644 tests/io/test_end2end_dtypes.py diff --git a/src/sedpack/io/flatbuffer/iterate.py b/src/sedpack/io/flatbuffer/iterate.py index 8de9c7a7..e2635df4 100644 --- a/src/sedpack/io/flatbuffer/iterate.py +++ b/src/sedpack/io/flatbuffer/iterate.py @@ -112,6 +112,11 @@ def decode_array(np_bytes: npt.NDArray[np.uint8], Returns: the parsed np.ndarray of the correct dtype and shape. """ + if attribute.dtype == "str": + return np_bytes.tobytes().decode("utf-8") + if attribute.dtype == "bytes": + return np_bytes.tobytes() + dt = np.dtype(attribute.dtype) # FlatBuffers are little-endian. There is no byteswap by # `np.frombuffer` but the array will be interpreted correctly. diff --git a/src/sedpack/io/shard/shard_writer_flatbuffer.py b/src/sedpack/io/shard/shard_writer_flatbuffer.py index ed47cb5b..091708c9 100644 --- a/src/sedpack/io/shard/shard_writer_flatbuffer.py +++ b/src/sedpack/io/shard/shard_writer_flatbuffer.py @@ -101,45 +101,15 @@ def _write(self, values: ExampleT) -> None: self._examples.append(fbapi_Example.ExampleEnd(self._builder)) @staticmethod - def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] - builder: Builder, attribute: Attribute, - value: AttributeValueT) -> int: - """Save a given array into a FlatBuffer as bytes. This is to ensure - compatibility with types which are not supported by FlatBuffers (e.g., - np.float16). The FlatBuffers schema must mark this vector as type - bytes [byte] (see src/sedpack/io/flatbuffer/shard.fbs) since there is a - distinction of how the length is being saved. The inverse of this - function is - `sedpack.io.flatbuffer.iterate.IterateShardFlatBuffer.decode_array`. - - If we have an array of np.int32 of 10 elements the FlatBuffers library - would save the length as 10. Which is then impossible to read in Rust - since the length and itemsize (sizeof of the type) are private. Thus we - could not get the full array back. Thus we are saving the array as 40 - bytes. This function does not modify the `value`, and saves a flattened - version of it. This function also saves the exact dtype as given by - `attribute`. Bytes are being saved in little endian ("<") and - c_contiguous ("C") order, same as with FlatBuffers. Alignment is set to - `dtype.itemsize` as opposed to FlatBuffers choice of `dtype.alignment`. - - Args: - - builder (flatbuffers.Builder): The byte buffer being constructed. - Must be initialized. - - attribute (Attribute): Description of this attribute (shape and - dtype). - - value (AttributeValueT): The array to be saved. The shape should be - as defined in `attribute` (will be flattened). - - Returns: The offset returned by `flatbuffers.Builder.EndVector`. + def _np_to_bytes( + builder: Builder, + attribute: Attribute, + value: AttributeValueT, + ) -> bytes: + """ """ - # Not sure about flatbuffers.Builder __bool__ semantics. assert builder is not None - # See `flatbuffers.builder.Builder.CreateNumpyVector`. - # Copy the value in order not to modify the original and flatten for # better saving. value_flattened = np.copy(value).flatten() @@ -178,7 +148,73 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] f" {attribute.name}") # This is going to be saved, ensure c_contiguous ordering. - byte_representation = value_np.tobytes(order="C") + return value_np.tobytes(order="C") + + @staticmethod + def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] + builder: Builder, + attribute: Attribute, + value: AttributeValueT, + ) -> int: + """Save a given array into a FlatBuffer as bytes. This is to ensure + compatibility with types which are not supported by FlatBuffers (e.g., + np.float16). The FlatBuffers schema must mark this vector as type + bytes [byte] (see src/sedpack/io/flatbuffer/shard.fbs) since there is a + distinction of how the length is being saved. The inverse of this + function is + `sedpack.io.flatbuffer.iterate.IterateShardFlatBuffer.decode_array`. + + If we have an array of np.int32 of 10 elements the FlatBuffers library + would save the length as 10. Which is then impossible to read in Rust + since the length and itemsize (sizeof of the type) are private. Thus we + could not get the full array back. Thus we are saving the array as 40 + bytes. This function does not modify the `value`, and saves a flattened + version of it. This function also saves the exact dtype as given by + `attribute`. Bytes are being saved in little endian ("<") and + c_contiguous ("C") order, same as with FlatBuffers. Alignment is set to + `dtype.itemsize` as opposed to FlatBuffers choice of `dtype.alignment`. + + Args: + + builder (flatbuffers.Builder): The byte buffer being constructed. + Must be initialized. + + attribute (Attribute): Description of this attribute (shape and + dtype). + + value (AttributeValueT): The array to be saved. The shape should be + as defined in `attribute` (will be flattened). + + Returns: The offset returned by `flatbuffers.Builder.EndVector`. + """ + # Not sure about flatbuffers.Builder __bool__ semantics. + assert builder is not None + + # See `flatbuffers.builder.Builder.CreateNumpyVector`. + + byte_representation: bytess + alignment: int + match attribute.dtype: + case "str": + byte_representation = value.encode("utf-8") + alignment = 16 + case "bytes": + byte_representation = bytes(value) + alignment = 16 + case _: + try: + dt = np.dtype(attribute.dtype) + except: + raise ValueError(f"Unsupported dtype {attribute.dtype}") + + byte_representation = ShardWriterFlatBuffer._np_to_bytes( + builder=builder, + attribute=attribute, + value=value, + ) + value_flattened = np.copy(value).flatten() + value_np = np.array(value_flattened, dtype=attribute.dtype) + alignment = value_np.dtype.itemsize # Total length of the array (in bytes). length: int = len(byte_representation) @@ -187,8 +223,7 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] builder.StartVector( elemSize=1, # Storing bytes. numElems=length, - alignment=value_np.dtype. - itemsize, # Cautious alignment of the array. + alignment=alignment, # Cautious alignment of the array. ) builder.head = int(builder.Head() - length) diff --git a/src/sedpack/io/tfrec/tfdata.py b/src/sedpack/io/tfrec/tfdata.py index 2f6b6e0b..f64baf1e 100644 --- a/src/sedpack/io/tfrec/tfdata.py +++ b/src/sedpack/io/tfrec/tfdata.py @@ -93,11 +93,22 @@ def get_from_tfrecord( def from_tfrecord(tf_record: Any) -> Any: rec = tf.io.parse_single_example(tf_record, tf_features) for attribute in saved_data_description: - if attribute.dtype == "float16": - rec[attribute.name] = tf.io.parse_tensor( - rec[attribute.name], tf.float16) - rec[attribute.name] = tf.ensure_shape(rec[attribute.name], - shape=attribute.shape) + match attribute.dtype: + case "str": + pass + #rec[attribute.name] = rec[attribute.name].decode("utf-8") + case "float16": + rec[attribute.name] = tf.io.parse_tensor( + rec[attribute.name], + tf.float16, + ) + rec[attribute.name] = tf.ensure_shape( + rec[attribute.name], + shape=attribute.shape, + ) + case _: + # Nothing extra needs to be done. + pass return rec return from_tfrecord diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py new file mode 100644 index 00000000..7ba1e34b --- /dev/null +++ b/tests/io/test_end2end_dtypes.py @@ -0,0 +1,190 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Union + +import numpy as np +import numpy.typing as npt +import pytest + +import sedpack +from sedpack.io import Dataset +from sedpack.io.shard_info_iterator import ShardInfoIterator +from sedpack.io import Metadata +from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT + + +def end2end_str( + tmpdir: Union[str, Path], + method: str, + shard_file_type: ShardFileTypeT, + compression: CompressionT, +) -> None: + array_of_values = [ + "https://arxiv.org/abs/2306.07249", + "Ḽơᶉëᶆ ȋṕšᶙṁ ḍỡḽǭᵳ ʂǐť ӓṁệẗ, ĉṓɲṩḙċťᶒțûɾ", + "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ ḹẩḇőꝛế", + "éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ.", + ] + + tiny_experiment_path: Path = Path(tmpdir) / "e2e_str_experiment" + + # Create a dataset + + dataset_metadata = Metadata(description="Test of the lib") + + example_attributes = [ + sedpack.io.metadata.Attribute( + name="strange_strings", + dtype="str", + shape=(), + ), + ] + + dataset_structure = sedpack.io.metadata.DatasetStructure( + saved_data_description=example_attributes, + compression=compression, + examples_per_shard=3, + shard_file_type=shard_file_type, + ) + + # Test attribute_by_name + for attribute in example_attributes: + assert dataset_structure.attribute_by_name( + attribute_name=attribute.name) == attribute + + dataset = Dataset.create( + path=tiny_experiment_path, + metadata=dataset_metadata, + dataset_structure=dataset_structure, + ) + + # Fill data in the dataset + + with dataset.filler() as filler: + for attribute_value in array_of_values: + filler.write_example( + values={"strange_strings": attribute_value}, + split=TRAIN_SPLIT, + ) + + # Check the data is correct + # Reopen the dataset + dataset = Dataset(tiny_experiment_path) + dataset.check() + + match method: + case "as_tfdataset": + for i, example in enumerate( + dataset.as_tfdataset( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + batch_size=1, + )): + assert type(example["strange_strings"]) == type( + array_of_values[i:i+1]) + assert example["strange_strings"] == array_of_values[i:i + 1] + case "as_numpy_iterator": + for i, example in enumerate( + dataset.as_numpy_iterator( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + )): + assert type(example["strange_strings"]) == type( + array_of_values[i]) + assert example["strange_strings"] == array_of_values[i] + case "as_numpy_iterator_concurrent": + for i, example in enumerate( + dataset.as_numpy_iterator_concurrent( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + )): + assert type(example["strange_strings"]) == type( + array_of_values[i]) + assert example["strange_strings"] == array_of_values[i] + + # We tested everything + assert i + 1 == len(array_of_values), "Not all examples have been iterated" + + # Number of shards matches + full_iterator = ShardInfoIterator( + dataset_path=dataset.path, + dataset_info=dataset.dataset_info, + split=None, + ) + number_of_all_shards: int = full_iterator.number_of_shards() + assert number_of_all_shards == len(full_iterator) + assert number_of_all_shards == len(list(full_iterator)) + assert number_of_all_shards == sum( + ShardInfoIterator( + dataset_path=dataset.path, + dataset_info=dataset.dataset_info, + split=split, + ).number_of_shards() for split in ["train", "test", "holdout"]) + + +# TODO common fixture tfrec_dataset +@pytest.mark.parametrize("method", [ + "as_tfdataset", + "as_numpy_iterator", + "as_numpy_iterator_concurrent", +]) +def test_end2end_dtypes_str_tfrec( + method: str, + tmpdir: Union[str, Path], +) -> None: + end2end_str( + tmpdir=tmpdir, + method=method, + shard_file_type="tfrec", + compression="GZIP", + ) + + +# TODO common fixture npz_dataset +@pytest.mark.parametrize("method", [ + "as_numpy_iterator", + "as_numpy_iterator_concurrent", +]) +def test_end2end_dtypes_str_npz( + method: str, + tmpdir: Union[str, Path], +) -> None: + end2end_str( + tmpdir=tmpdir, + method=method, + shard_file_type="npz", + compression="ZIP", + ) + + +# TODO common fixture fb_dataset +@pytest.mark.parametrize("method", [ + "as_numpy_iterator", + "as_numpy_iterator_concurrent", +]) +def test_end2end_dtypes_str_fb( + method: str, + tmpdir: Union[str, Path], +) -> None: + end2end_str( + tmpdir=tmpdir, + method=method, + shard_file_type="fb", + compression="LZ4", + ) From 8478126e6dbee45078709a8ffeb85cf03977cefc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 27 Aug 2025 22:07:59 +0200 Subject: [PATCH 02/17] [squash] fix nits --- src/sedpack/io/shard/shard_writer_flatbuffer.py | 2 +- tests/io/test_end2end_dtypes.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sedpack/io/shard/shard_writer_flatbuffer.py b/src/sedpack/io/shard/shard_writer_flatbuffer.py index 091708c9..101c1e85 100644 --- a/src/sedpack/io/shard/shard_writer_flatbuffer.py +++ b/src/sedpack/io/shard/shard_writer_flatbuffer.py @@ -192,7 +192,7 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] # See `flatbuffers.builder.Builder.CreateNumpyVector`. - byte_representation: bytess + byte_representation: bytes alignment: int match attribute.dtype: case "str": diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py index 7ba1e34b..adc62f95 100644 --- a/tests/io/test_end2end_dtypes.py +++ b/tests/io/test_end2end_dtypes.py @@ -95,7 +95,7 @@ def end2end_str( batch_size=1, )): assert type(example["strange_strings"]) == type( - array_of_values[i:i+1]) + array_of_values[i:i + 1]) assert example["strange_strings"] == array_of_values[i:i + 1] case "as_numpy_iterator": for i, example in enumerate( From e6fbb9d19ca1eff854af965c0843bfa2d39ca58e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Thu, 28 Aug 2025 08:35:10 +0200 Subject: [PATCH 03/17] [squash] mostly ok as_tfdataset should also have strings? or is bytearray representation of string ok here? Copy issues should be fine since Rust is allocating the memory and passing ownership to Python. Only difference maybe str(value) vs value.decode("utf-8")? bytearray needs the same need a dataset fixture for efficiency --- src/sedpack/io/flatbuffer/iterate.py | 4 ---- src/sedpack/io/npz/iterate_npz.py | 24 +++++++++++++++++++++- src/sedpack/io/tfrec/read.py | 30 +++++++++++++++++++++++++--- tests/io/test_end2end_dtypes.py | 4 ++-- 4 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/sedpack/io/flatbuffer/iterate.py b/src/sedpack/io/flatbuffer/iterate.py index e2635df4..972c20ce 100644 --- a/src/sedpack/io/flatbuffer/iterate.py +++ b/src/sedpack/io/flatbuffer/iterate.py @@ -82,10 +82,6 @@ def _iterate_content(self, content: bytes) -> Iterable[ExampleT]: attribute=attribute, ) - # Copy otherwise the arrays are immutable and keep the whole - # file content from being garbage collected. - np_array = np.copy(np_array) - example_dictionary[attribute.name] = np_array yield example_dictionary diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index c8ce4200..6357689c 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -21,7 +21,9 @@ import aiofiles import numpy as np +import numpy.typing as npt +from sedpack.io.metadata import Attribute from sedpack.io.shard import IterateShardBase from sedpack.io.shard.iterate_shard_base import T from sedpack.io.types import AttributeValueT, ExampleT @@ -32,6 +34,17 @@ class IterateShardNP(IterateShardBase[T]): """Iterate a shard saved in the npz format. """ + @staticmethod + def decode_attribute(np_value: npt.NDArray[np.generic], + attribute: Attribute) -> AttributeValueT: + match attribute.dtype: + case "str": + return str(np_value) + case "bytes": + return np_value.tobytes() + case _: + return np_value + def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: """Iterate a shard saved in the NumPy format npz. """ @@ -45,7 +58,16 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: break for i in range(elements): - yield {name: value[i] for name, value in shard_content.items()} + yield { + name: + IterateShardNP.decode_attribute( + np_value=value[i], + attribute=attribute, + ) for (name, value), attribute in zip( + shard_content.items(), + self.dataset_structure.saved_data_description, + ) + } # TODO(issue #85) fix and test async iterator typing async def iterate_shard_async( # pylint: disable=invalid-overridden-method diff --git a/src/sedpack/io/tfrec/read.py b/src/sedpack/io/tfrec/read.py index 8450f963..e336a2a9 100644 --- a/src/sedpack/io/tfrec/read.py +++ b/src/sedpack/io/tfrec/read.py @@ -19,13 +19,15 @@ from pathlib import Path from typing import Any, AsyncIterator, Callable, Iterable +import numpy as np +import numpy.typing as npt import tensorflow as tf -from sedpack.io.metadata import DatasetStructure +from sedpack.io.metadata import Attribute, DatasetStructure from sedpack.io.shard import IterateShardBase from sedpack.io.shard.iterate_shard_base import T from sedpack.io.tfrec.tfdata import get_from_tfrecord -from sedpack.io.types import ExampleT +from sedpack.io.types import AttributeValueT, ExampleT from sedpack.io.utils import func_or_identity @@ -45,6 +47,17 @@ def __init__( self.from_tfrecord: Callable[[Any], Any] | None = None self.num_parallel_calls: int = num_parallel_calls + @staticmethod + def decode_attribute(value: npt.NDArray[np.generic], + attribute: Attribute) -> AttributeValueT: + match attribute.dtype: + case "str": + return value.decode("utf-8") + case "bytes": + return value.tobytes() + case _: + return value + def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: """Iterate a shard saved in the TFRec format """ @@ -65,7 +78,18 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: num_parallel_calls=self.num_parallel_calls, ) - yield from tf_dataset_examples.as_numpy_iterator() # type: ignore[misc] + #yield from tf_dataset_examples.as_numpy_iterator() # type: ignore[misc] + for example in tf_dataset_examples.as_numpy_iterator(): + yield { + name: + IterateShardTFRec.decode_attribute( + value=value, + attribute=attribute, + ) for (name, value), attribute in zip( + example.items(), + self.dataset_structure.saved_data_description, + ) + } # TODO(issue #85) fix and test async iterator typing async def iterate_shard_async( # pylint: disable=invalid-overridden-method diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py index adc62f95..25722f06 100644 --- a/tests/io/test_end2end_dtypes.py +++ b/tests/io/test_end2end_dtypes.py @@ -94,8 +94,8 @@ def end2end_str( repeat=False, batch_size=1, )): - assert type(example["strange_strings"]) == type( - array_of_values[i:i + 1]) + #assert type(example["strange_strings"]) == type( + # array_of_values[i:i+1]) assert example["strange_strings"] == array_of_values[i:i + 1] case "as_numpy_iterator": for i, example in enumerate( From 039639472dd26a90cf772da84bcd6e1e80ad61dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Thu, 28 Aug 2025 11:08:18 +0200 Subject: [PATCH 04/17] [squash] fix workflows --- src/sedpack/io/flatbuffer/iterate.py | 10 ++++--- src/sedpack/io/npz/iterate_npz.py | 5 ++-- .../io/shard/shard_writer_flatbuffer.py | 30 ++++++++++++++----- src/sedpack/io/tfrec/read.py | 6 ++-- tests/io/test_end2end_dtypes.py | 7 +++-- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/src/sedpack/io/flatbuffer/iterate.py b/src/sedpack/io/flatbuffer/iterate.py index 972c20ce..8ca46f96 100644 --- a/src/sedpack/io/flatbuffer/iterate.py +++ b/src/sedpack/io/flatbuffer/iterate.py @@ -26,7 +26,7 @@ from sedpack.io.compress import CompressedFile from sedpack.io.metadata import Attribute -from sedpack.io.types import ExampleT +from sedpack.io.types import AttributeValueT, ExampleT from sedpack.io.shard import IterateShardBase from sedpack.io.shard.iterate_shard_base import T from sedpack.io.utils import func_or_identity @@ -87,9 +87,11 @@ def _iterate_content(self, content: bytes) -> Iterable[ExampleT]: yield example_dictionary @staticmethod - def decode_array(np_bytes: npt.NDArray[np.uint8], - attribute: Attribute, - batch_size: int = 0) -> npt.NDArray[np.generic]: + def decode_array( + np_bytes: npt.NDArray[np.uint8], + attribute: Attribute, + batch_size: int = 0, + ) -> AttributeValueT: """Decode an array. See `sedpack.io.shard.shard_writer_flatbuffer .ShardWriterFlatBuffer.save_numpy_vector_as_bytearray` for format description. The code tries to avoid unnecessary copies. diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index 6357689c..e88367f9 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -21,7 +21,6 @@ import aiofiles import numpy as np -import numpy.typing as npt from sedpack.io.metadata import Attribute from sedpack.io.shard import IterateShardBase @@ -35,13 +34,13 @@ class IterateShardNP(IterateShardBase[T]): """ @staticmethod - def decode_attribute(np_value: npt.NDArray[np.generic], + def decode_attribute(np_value: AttributeValueT, attribute: Attribute) -> AttributeValueT: match attribute.dtype: case "str": return str(np_value) case "bytes": - return np_value.tobytes() + return bytes(np.array(np_value)) case _: return np_value diff --git a/src/sedpack/io/shard/shard_writer_flatbuffer.py b/src/sedpack/io/shard/shard_writer_flatbuffer.py index 101c1e85..2bebb3b1 100644 --- a/src/sedpack/io/shard/shard_writer_flatbuffer.py +++ b/src/sedpack/io/shard/shard_writer_flatbuffer.py @@ -101,12 +101,25 @@ def _write(self, values: ExampleT) -> None: self._examples.append(fbapi_Example.ExampleEnd(self._builder)) @staticmethod - def _np_to_bytes( + def _np_to_bytes( # type: ignore[no-any-unimported] builder: Builder, attribute: Attribute, value: AttributeValueT, ) -> bytes: - """ + """Turn a single AttributeValueT value into a sequence of bytes to be + saved in a FlatBuffer shard. + + Args: + + builder (Builder): The FlatBuffer builder. + + attribute (Attribute): Description of the attribute defining dtype and + shape. + + value (AttributeValueT): The actual value which is being represented. + + Raises: ValueError if the value cannot be cast to the correct dtype with + the correct byteorder. """ assert builder is not None @@ -186,6 +199,8 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] as defined in `attribute` (will be flattened). Returns: The offset returned by `flatbuffers.Builder.EndVector`. + + Raises: ValueError if the dtype is unknown to NumPy. """ # Not sure about flatbuffers.Builder __bool__ semantics. assert builder is not None @@ -196,16 +211,17 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] alignment: int match attribute.dtype: case "str": - byte_representation = value.encode("utf-8") + byte_representation = str(value).encode("utf-8") alignment = 16 case "bytes": - byte_representation = bytes(value) + byte_representation = np.array(value).tobytes() alignment = 16 case _: try: - dt = np.dtype(attribute.dtype) - except: - raise ValueError(f"Unsupported dtype {attribute.dtype}") + _ = np.dtype(attribute.dtype) + except Exception as exc: + raise ValueError( + f"Unsupported dtype {attribute.dtype}") from exc byte_representation = ShardWriterFlatBuffer._np_to_bytes( builder=builder, diff --git a/src/sedpack/io/tfrec/read.py b/src/sedpack/io/tfrec/read.py index e336a2a9..25eb563e 100644 --- a/src/sedpack/io/tfrec/read.py +++ b/src/sedpack/io/tfrec/read.py @@ -52,7 +52,7 @@ def decode_attribute(value: npt.NDArray[np.generic], attribute: Attribute) -> AttributeValueT: match attribute.dtype: case "str": - return value.decode("utf-8") + return bytes(value).decode("utf-8") case "bytes": return value.tobytes() case _: @@ -78,7 +78,6 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: num_parallel_calls=self.num_parallel_calls, ) - #yield from tf_dataset_examples.as_numpy_iterator() # type: ignore[misc] for example in tf_dataset_examples.as_numpy_iterator(): yield { name: @@ -86,7 +85,8 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: value=value, attribute=attribute, ) for (name, value), attribute in zip( - example.items(), + # `example` is a dictionary, mypy does not know that. + example.items(), # type: ignore[attr-defined] self.dataset_structure.saved_data_description, ) } diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py index 25722f06..a2c6f2ca 100644 --- a/tests/io/test_end2end_dtypes.py +++ b/tests/io/test_end2end_dtypes.py @@ -94,8 +94,11 @@ def end2end_str( repeat=False, batch_size=1, )): - #assert type(example["strange_strings"]) == type( - # array_of_values[i:i+1]) + # No idea how to have an actual string in TensorFlow. Maybe it + # is best to leave it as a tensor anyway since that is the + # "native" type. + #assert type(example["strange_strings"][0]) == type( + # array_of_values[i]) assert example["strange_strings"] == array_of_values[i:i + 1] case "as_numpy_iterator": for i, example in enumerate( From 3b86d4dfcd2f5037ba2e4616cc7ff10f6d5a5ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Thu, 28 Aug 2025 14:47:32 +0200 Subject: [PATCH 05/17] [squash] parametrized tests - not all of these pass - large int is still turned into a NumPy value --- tests/io/test_end2end_dtypes.py | 258 +++++++++++++++++++++++++------- 1 file changed, 202 insertions(+), 56 deletions(-) diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py index a2c6f2ca..19a1080c 100644 --- a/tests/io/test_end2end_dtypes.py +++ b/tests/io/test_end2end_dtypes.py @@ -13,7 +13,8 @@ # limitations under the License. from pathlib import Path -from typing import Union +from typing import Any +import random import numpy as np import numpy.typing as npt @@ -26,33 +27,55 @@ from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT -def end2end_str( - tmpdir: Union[str, Path], - method: str, - shard_file_type: ShardFileTypeT, - compression: CompressionT, -) -> None: - array_of_values = [ - "https://arxiv.org/abs/2306.07249", - "Ḽơᶉëᶆ ȋṕšᶙṁ ḍỡḽǭᵳ ʂǐť ӓṁệẗ, ĉṓɲṩḙċťᶒțûɾ", - "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ ḹẩḇőꝛế", - "éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ.", - ] - - tiny_experiment_path: Path = Path(tmpdir) / "e2e_str_experiment" - - # Create a dataset - +def dataset_and_values_dynamic_shape( + tmpdir: str | Path, + shard_file_type: str, + compression: str, + dtypes: list[str], + items: int, +) -> (Dataset, dict[str, list[Any]]): + values: dict[str, list[Any]] = {} + ds_path = Path(tmpdir) / f"e2e_{shard_file_type}_{'_'.join(dtypes)}" dataset_metadata = Metadata(description="Test of the lib") + # The order should not play a role. + random.shuffle(dtypes) + example_attributes = [ sedpack.io.metadata.Attribute( - name="strange_strings", - dtype="str", + name=f"attribute_{dtype}", + dtype=dtype, shape=(), - ), + ) for dtype in dtypes ] + for dtype in dtypes: + values[f"attribute_{dtype}"] = [] + + match dtype: + case "int": + for _ in range(items): + # TODO larger range than just int64 + values[f"attribute_{dtype}"].append( + random.randint(-60**2, 60**2)) + case "str": + long_string = "Ḽơᶉëᶆ ȋṕšᶙṁ ḍỡḽǭᵳ ʂǐť ӓṁệẗ, ĉṓɲṩḙċťᶒțûɾ" \ + "https://arxiv.org/abs/2306.07249 ḹẩḇőꝛế" \ + "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ" \ + ":(){ :|:& };: éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ." + for _ in range(items): + begin: int = random.randint(0, len(long_string) // 2) + end: int = random.randint(begin + 1, len(long_string)) + values[f"attribute_{dtype}"].append(long_string[begin:end]) + case "bytes": + for _ in range(items): + values[f"attribute_{dtype}"].append( + np.random.randint( + 0, + 256, + size=random.randint(5, 20), + ).tobytes()) + dataset_structure = sedpack.io.metadata.DatasetStructure( saved_data_description=example_attributes, compression=compression, @@ -66,7 +89,7 @@ def end2end_str( attribute_name=attribute.name) == attribute dataset = Dataset.create( - path=tiny_experiment_path, + path=ds_path, metadata=dataset_metadata, dataset_structure=dataset_structure, ) @@ -74,17 +97,126 @@ def end2end_str( # Fill data in the dataset with dataset.filler() as filler: - for attribute_value in array_of_values: + for i in range(items): filler.write_example( - values={"strange_strings": attribute_value}, + values={ + name: value[i] for name, value in values.items() + }, split=TRAIN_SPLIT, ) # Check the data is correct # Reopen the dataset - dataset = Dataset(tiny_experiment_path) + dataset = Dataset(ds_path) dataset.check() + return (values, dataset) + + +@pytest.fixture( + scope="module", + params=[ + { + "dtypes": ["str"], + "compression": "GZIP", + }, + #{ + # "dtypes": ["bytes"], + # "compression": "GZIP", + #}, + #{ + # "dtypes": ["int"], + # "compression": "GZIP", + #}, + #{ + # "dtypes": ["str", "bytes", "int"], + # "compression": "GZIP", + #}, + ], +) +def values_and_dataset_tfrec(request, tmpdir_factory) -> None: + shard_file_type: str = "tfrec" + yield dataset_and_values_dynamic_shape( + tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"), + shard_file_type=shard_file_type, + compression=request.param["compression"], + dtypes=request.param["dtypes"], + items=137, + ) + # Teardown. + + +@pytest.fixture( + scope="module", + params=[ + { + "dtypes": ["str"], + "compression": "ZIP", + }, + { + "dtypes": ["bytes"], + "compression": "ZIP", + }, + #{ + # "dtypes": ["int"], + # "compression": "ZIP", + #}, + #{ + # "dtypes": ["str", "bytes", "int"], + # "compression": "ZIP", + #}, + ], +) +def values_and_dataset_npz(request, tmpdir_factory) -> None: + shard_file_type: str = "npz" + yield dataset_and_values_dynamic_shape( + tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"), + shard_file_type=shard_file_type, + compression=request.param["compression"], + dtypes=request.param["dtypes"], + items=137, + ) + # Teardown. + + +@pytest.fixture( + scope="module", + params=[ + { + "dtypes": ["str"], + "compression": "LZ4", + }, + #{ + # "dtypes": ["bytes"], + # "compression": "LZ4", + #}, + #{ + # "dtypes": ["int"], + # "compression": "LZ4", + #}, + #{ + # "dtypes": ["str", "bytes", "int"], + # "compression": "LZ4", + #}, + ], +) +def values_and_dataset_fb(request, tmpdir_factory) -> None: + shard_file_type: str = "fb" + yield dataset_and_values_dynamic_shape( + tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"), + shard_file_type=shard_file_type, + compression=request.param["compression"], + dtypes=request.param["dtypes"], + items=137, + ) + # Teardown. + + +def check_iteration_of_values( + method: str, + dataset: Dataset, + values: dict[str, list[Any]], +) -> None: match method: case "as_tfdataset": for i, example in enumerate( @@ -94,12 +226,14 @@ def end2end_str( repeat=False, batch_size=1, )): - # No idea how to have an actual string in TensorFlow. Maybe it - # is best to leave it as a tensor anyway since that is the - # "native" type. - #assert type(example["strange_strings"][0]) == type( - # array_of_values[i]) - assert example["strange_strings"] == array_of_values[i:i + 1] + assert len(example) == len(values) + + # No idea how to have an actual string or bytes in TensorFlow. + # Maybe it is best to leave it as a tensor anyway since that is + # the "native" type. + + for name, returned_batch in example.items(): + assert returned_batch == values[name][i:i + 1] case "as_numpy_iterator": for i, example in enumerate( dataset.as_numpy_iterator( @@ -107,9 +241,10 @@ def end2end_str( shuffle=0, repeat=False, )): - assert type(example["strange_strings"]) == type( - array_of_values[i]) - assert example["strange_strings"] == array_of_values[i] + assert len(example) == len(values) + for name, returned_value in example.items(): + assert type(returned_value) == type(values[name][i]) + assert returned_value == values[name][i] case "as_numpy_iterator_concurrent": for i, example in enumerate( dataset.as_numpy_iterator_concurrent( @@ -117,12 +252,25 @@ def end2end_str( shuffle=0, repeat=False, )): - assert type(example["strange_strings"]) == type( - array_of_values[i]) - assert example["strange_strings"] == array_of_values[i] + assert len(example) == len(values) + for name, returned_value in example.items(): + assert type(returned_value) == type(values[name][i]) + assert returned_value == values[name][i] + case "as_numpy_iterator_rust": + for i, example in enumerate( + dataset.as_numpy_iterator_concurrent( + split=TRAIN_SPLIT, + shuffle=0, + repeat=False, + )): + assert len(example) == len(values) + for name, returned_value in example.items(): + assert type(returned_value) == type(values[name][i]) + assert returned_value == values[name][i] # We tested everything - assert i + 1 == len(array_of_values), "Not all examples have been iterated" + if i + 1 != len(next(iter(values.values()))): + raise AssertionError("Not all examples have been iterated") # Number of shards matches full_iterator = ShardInfoIterator( @@ -141,7 +289,6 @@ def end2end_str( ).number_of_shards() for split in ["train", "test", "holdout"]) -# TODO common fixture tfrec_dataset @pytest.mark.parametrize("method", [ "as_tfdataset", "as_numpy_iterator", @@ -149,45 +296,44 @@ def end2end_str( ]) def test_end2end_dtypes_str_tfrec( method: str, - tmpdir: Union[str, Path], + values_and_dataset_tfrec, ) -> None: - end2end_str( - tmpdir=tmpdir, + values, dataset = values_and_dataset_tfrec + check_iteration_of_values( method=method, - shard_file_type="tfrec", - compression="GZIP", + dataset=dataset, + values=values, ) -# TODO common fixture npz_dataset @pytest.mark.parametrize("method", [ "as_numpy_iterator", "as_numpy_iterator_concurrent", ]) def test_end2end_dtypes_str_npz( method: str, - tmpdir: Union[str, Path], + values_and_dataset_npz, ) -> None: - end2end_str( - tmpdir=tmpdir, + values, dataset = values_and_dataset_npz + check_iteration_of_values( method=method, - shard_file_type="npz", - compression="ZIP", + dataset=dataset, + values=values, ) -# TODO common fixture fb_dataset @pytest.mark.parametrize("method", [ "as_numpy_iterator", "as_numpy_iterator_concurrent", + "as_numpy_iterator_rust", ]) def test_end2end_dtypes_str_fb( method: str, - tmpdir: Union[str, Path], + values_and_dataset_fb, ) -> None: - end2end_str( - tmpdir=tmpdir, + values, dataset = values_and_dataset_fb + check_iteration_of_values( method=method, - shard_file_type="fb", - compression="LZ4", + dataset=dataset, + values=values, ) From be17d60fd562905f6bdfc9365ee4851be393303a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 11:35:29 +0200 Subject: [PATCH 06/17] [squash] --- src/sedpack/io/flatbuffer/iterate.py | 18 ++++- src/sedpack/io/npz/iterate_npz.py | 86 ++++++++++++++++++---- src/sedpack/io/shard/shard_writer_np.py | 90 ++++++++++++++++++++--- src/sedpack/io/tfrec/read.py | 30 +------- src/sedpack/io/tfrec/tfdata.py | 37 +++++----- tests/io/test_end2end_dtypes.py | 97 ++++++++++++++----------- 6 files changed, 242 insertions(+), 116 deletions(-) diff --git a/src/sedpack/io/flatbuffer/iterate.py b/src/sedpack/io/flatbuffer/iterate.py index 8ca46f96..d5d16954 100644 --- a/src/sedpack/io/flatbuffer/iterate.py +++ b/src/sedpack/io/flatbuffer/iterate.py @@ -110,10 +110,20 @@ def decode_array( Returns: the parsed np.ndarray of the correct dtype and shape. """ - if attribute.dtype == "str": - return np_bytes.tobytes().decode("utf-8") - if attribute.dtype == "bytes": - return np_bytes.tobytes() + match attribute.dtype: + case "str": + return np_bytes.tobytes().decode("utf-8") + case "bytes": + return np_bytes.tobytes() + case "int": + return int( + np.frombuffer( + buffer=np_bytes, + dtype=np.dtype("int64").newbyteorder("<"), + )) + case _: + # The rest is interpreted as NumPy array. + pass dt = np.dtype(attribute.dtype) # FlatBuffers are little-endian. There is no byteswap by diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index e88367f9..d3ccf772 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -34,38 +34,93 @@ class IterateShardNP(IterateShardBase[T]): """ @staticmethod - def decode_attribute(np_value: AttributeValueT, - attribute: Attribute) -> AttributeValueT: + def decode_attribute( + np_value: AttributeValueT, + attribute: Attribute, + ) -> AttributeValueT: match attribute.dtype: case "str": return str(np_value) case "bytes": - return bytes(np.array(np_value)) + raise ValueError("One needs to use decode_bytes_attribute") + case "int": + return int(np_value) case _: return np_value + @staticmethod + def decode_bytes_attribute( + value: AttributeValueT, + indexes: list[AttributeValueT], + attribute: Attribute, + index: int, + ) -> AttributeValueT: + """Decode a bytes attribute. We are saving the byte attributes as a + continuous array across multiple examples and on the side we also save + the indexes into this array. + + Args: + + value (AttributeValueT): The NumPy array of np.uint8 containing + concatenated bytes values. + + indexes (list[AttributeValueT]): Indexes into this array. + + attribute (Attribute): The attribute description. + + index (int): Which example out of this shard to return. + """ + if attribute.dtype != "bytes": + raise ValueError("One needs to use decode_attribute") + # Help with type-checking: + my_value = np.array(value, np.uint8) + my_indexes = np.array(indexes, np.int64) + del value + del indexes + + begin: int = my_indexes[index] + end: int = my_indexes[index + 1] + return bytes(my_value[begin:end]) + def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: """Iterate a shard saved in the NumPy format npz. """ - shard_content: dict[str, list[AttributeValueT]] = np.load(file_path) + # A prefix such that prepended it creates a new name without collision + # with any attribute name. + self._counting_prefix: str = "len" + "_" * max( + len(attribute.name) + for attribute in self.dataset_structure.saved_data_description) + + shard_content: dict[str, list[AttributeValueT]] = np.load( + file_path, + allow_pickle=False, + ) # A given shard contains the same number of elements for each # attribute. - elements: int = 0 - for values in shard_content.values(): - elements = len(values) - break + elements: int + for attribute in self.dataset_structure.saved_data_description: + if self._counting_prefix + attribute.name in shard_content: + elements = len( + shard_content[self._counting_prefix + attribute.name]) - 1 + else: + elements = len(shard_content[attribute.name]) for i in range(elements): yield { - name: + attribute.name: IterateShardNP.decode_attribute( - np_value=value[i], + np_value=shard_content[attribute.name][i], + attribute=attribute, + ) if attribute.dtype != "bytes" else + IterateShardNP.decode_bytes_attribute( + value=shard_content[attribute.name][0], + indexes=shard_content[self._counting_prefix + + attribute.name], attribute=attribute, - ) for (name, value), attribute in zip( - shard_content.items(), - self.dataset_structure.saved_data_description, + index=i, ) + for attribute in self.dataset_structure.saved_data_description } # TODO(issue #85) fix and test async iterator typing @@ -79,7 +134,10 @@ async def iterate_shard_async( # pylint: disable=invalid-overridden-method content_bytes: bytes = await f.read() content_io = io.BytesIO(content_bytes) - shard_content: dict[str, list[AttributeValueT]] = np.load(content_io) + shard_content: dict[str, list[AttributeValueT]] = np.load( + content_io, + allow_pickle=False, + ) # A given shard contains the same number of elements for each # attribute. diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index 636a0cd8..a3759109 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -20,8 +20,9 @@ from pathlib import Path import numpy as np +from numpy import typing as npt -from sedpack.io.metadata import DatasetStructure +from sedpack.io.metadata import Attribute, DatasetStructure from sedpack.io.types import AttributeValueT, CompressionT, ExampleT from sedpack.io.shard.shard_writer_base import ShardWriterBase @@ -50,6 +51,26 @@ def __init__(self, dataset_structure: DatasetStructure, self._buffer: dict[str, list[AttributeValueT]] = {} + # A prefix such that prepended it creates a new name without collision + # with any attribute name. + self._counting_prefix: str = "len" + "_" * max( + len(attribute.name) + for attribute in dataset_structure.saved_data_description) + + def _value_to_np( + self, + attribute: Attribute, + value: AttributeValueT, + ) -> npt.NDArray[np.generic] | str: + match attribute.dtype: + case "bytes": + raise ValueError("Attributes bytes are saved extra") + case "str": + assert isinstance(value, str) + return value + case _: + return np.copy(value) + def _write(self, values: ExampleT) -> None: """Write an example on disk. Writing may be buffered. @@ -59,12 +80,42 @@ def _write(self, values: ExampleT) -> None: """ # Just buffer all values. if not self._buffer: - self._buffer = { - name: [np.copy(value)] for name, value in values.items() - } - else: - for name, value in values.items(): - self._buffer[name].append(np.copy(value)) + self._buffer = {} + + for (name, value), attribute in zip( + values.items(), + self.dataset_structure.saved_data_description, + ): + if attribute.dtype != "bytes": + current_values = self._buffer.get(name, []) + current_values.append( + self._value_to_np( + attribute=attribute, + value=value, + )) + self._buffer[name] = current_values + else: + # Extend and remember the length. Attributes with dtype "bytes" + # may have variable length. Handle this case. We need to avoid + # two thigs: + # - Having wrong length of the bytes array and ideally also + # avoid padding. + # - Using allow_pickle when saving since that could lead to code + # execution when loading malicious dataset. + # We prefix the attribute name by `len_?` such that the new name + # is unique and tells us the lengths of the byte arrays. + counts = self._buffer.get(self._counting_prefix + name, [0]) + counts.append(counts[-1] + + len(value) # type: ignore[arg-type,operator] + ) + self._buffer[self._counting_prefix + name] = counts + + byte_list: list[list[int]] + byte_list = self._buffer.get( # type: ignore[assignment] + name, [[]]) + byte_list[0].extend(list(value) # type: ignore[arg-type] + ) + self._buffer[name] = byte_list # type: ignore[assignment] def close(self) -> None: """Close the shard file(-s). @@ -73,14 +124,31 @@ def close(self) -> None: assert not self._shard_file.is_file() return + # Deal properly with "bytes" attributes. + for attribute in self.dataset_structure.saved_data_description: + if attribute.dtype != "bytes": + continue + self._buffer[attribute.name] = [ + np.array( + self._buffer[attribute.name][0], + dtype=np.uint8, + ) + ] + # Write the buffer into a file. match self.dataset_structure.compression: case "ZIP": - np.savez_compressed(str(self._shard_file), - **self._buffer) # type: ignore[arg-type] + np.savez_compressed( + str(self._shard_file), + allow_pickle=False, + **self._buffer, # type: ignore[arg-type] + ) case "": - np.savez(str(self._shard_file), - **self._buffer) # type: ignore[arg-type] + np.savez( + str(self._shard_file), + allow_pickle=False, + **self._buffer, # type: ignore[arg-type] + ) case _: # Default should never happen since ShardWriterBase checks that # the requested compression type is supported. diff --git a/src/sedpack/io/tfrec/read.py b/src/sedpack/io/tfrec/read.py index 25eb563e..8450f963 100644 --- a/src/sedpack/io/tfrec/read.py +++ b/src/sedpack/io/tfrec/read.py @@ -19,15 +19,13 @@ from pathlib import Path from typing import Any, AsyncIterator, Callable, Iterable -import numpy as np -import numpy.typing as npt import tensorflow as tf -from sedpack.io.metadata import Attribute, DatasetStructure +from sedpack.io.metadata import DatasetStructure from sedpack.io.shard import IterateShardBase from sedpack.io.shard.iterate_shard_base import T from sedpack.io.tfrec.tfdata import get_from_tfrecord -from sedpack.io.types import AttributeValueT, ExampleT +from sedpack.io.types import ExampleT from sedpack.io.utils import func_or_identity @@ -47,17 +45,6 @@ def __init__( self.from_tfrecord: Callable[[Any], Any] | None = None self.num_parallel_calls: int = num_parallel_calls - @staticmethod - def decode_attribute(value: npt.NDArray[np.generic], - attribute: Attribute) -> AttributeValueT: - match attribute.dtype: - case "str": - return bytes(value).decode("utf-8") - case "bytes": - return value.tobytes() - case _: - return value - def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: """Iterate a shard saved in the TFRec format """ @@ -78,18 +65,7 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: num_parallel_calls=self.num_parallel_calls, ) - for example in tf_dataset_examples.as_numpy_iterator(): - yield { - name: - IterateShardTFRec.decode_attribute( - value=value, - attribute=attribute, - ) for (name, value), attribute in zip( - # `example` is a dictionary, mypy does not know that. - example.items(), # type: ignore[attr-defined] - self.dataset_structure.saved_data_description, - ) - } + yield from tf_dataset_examples.as_numpy_iterator() # type: ignore[misc] # TODO(issue #85) fix and test async iterator typing async def iterate_shard_async( # pylint: disable=invalid-overridden-method diff --git a/src/sedpack/io/tfrec/tfdata.py b/src/sedpack/io/tfrec/tfdata.py index f64baf1e..543a5950 100644 --- a/src/sedpack/io/tfrec/tfdata.py +++ b/src/sedpack/io/tfrec/tfdata.py @@ -70,6 +70,7 @@ def get_from_tfrecord( for attribute in saved_data_description: dtype = { "str": tf.string, + "int": tf.int64, "bytes": tf.string, "uint8": tf.int64, "int8": tf.int64, @@ -93,22 +94,15 @@ def get_from_tfrecord( def from_tfrecord(tf_record: Any) -> Any: rec = tf.io.parse_single_example(tf_record, tf_features) for attribute in saved_data_description: - match attribute.dtype: - case "str": - pass - #rec[attribute.name] = rec[attribute.name].decode("utf-8") - case "float16": - rec[attribute.name] = tf.io.parse_tensor( - rec[attribute.name], - tf.float16, - ) - rec[attribute.name] = tf.ensure_shape( - rec[attribute.name], - shape=attribute.shape, - ) - case _: - # Nothing extra needs to be done. - pass + if attribute.dtype == "float16": + rec[attribute.name] = tf.io.parse_tensor( + rec[attribute.name], + tf.float16, + ) + rec[attribute.name] = tf.ensure_shape( + rec[attribute.name], + shape=attribute.shape, + ) return rec return from_tfrecord @@ -139,6 +133,7 @@ def to_tfrecord(saved_data_description: list[Attribute], if len(attribute_names) != len(values): raise ValueError(f"There are missing attributes. Got: {values} " f"expected: {attribute_names}") + del attribute_names # Create dictionary of features feature = {} @@ -147,15 +142,19 @@ def to_tfrecord(saved_data_description: list[Attribute], value = values[attribute.name] # Convert the value into a NumPy type. - value = np.array(value) + #value = np.array(value) # Check shape - if attribute.dtype != "bytes" and value.shape != attribute.shape: + if attribute.dtype not in [ + "bytes", + "str", + "int", + ] and value.shape != attribute.shape: raise ValueError(f"Wrong shape of {attribute.name}, expected: " f"{attribute.shape}, got: {value.shape}.") # Set feature value - if attribute.dtype in ["int8", "uint8", "int32", "int64"]: + if attribute.dtype in ["int", "int8", "uint8", "int32", "int64"]: feature[attribute.name] = int64_feature(values[attribute.name]) elif attribute.dtype == "float16": value = value.astype(dtype=np.float16) diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py index 19a1080c..51749716 100644 --- a/tests/io/test_end2end_dtypes.py +++ b/tests/io/test_end2end_dtypes.py @@ -57,12 +57,12 @@ def dataset_and_values_dynamic_shape( for _ in range(items): # TODO larger range than just int64 values[f"attribute_{dtype}"].append( - random.randint(-60**2, 60**2)) + random.randint(-2**60, 2**60)) case "str": long_string = "Ḽơᶉëᶆ ȋṕšᶙṁ ḍỡḽǭᵳ ʂǐť ӓṁệẗ, ĉṓɲṩḙċťᶒțûɾ" \ - "https://arxiv.org/abs/2306.07249 ḹẩḇőꝛế" \ - "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ" \ - ":(){ :|:& };: éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ." + "https://arxiv.org/abs/2306.07249 ḹẩḇőꝛế" \ + "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ" \ + ":(){ :|:& };: éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ." for _ in range(items): begin: int = random.randint(0, len(long_string) // 2) end: int = random.randint(begin + 1, len(long_string)) @@ -74,6 +74,7 @@ def dataset_and_values_dynamic_shape( 0, 256, size=random.randint(5, 20), + dtype=np.uint8, ).tobytes()) dataset_structure = sedpack.io.metadata.DatasetStructure( @@ -120,18 +121,18 @@ def dataset_and_values_dynamic_shape( "dtypes": ["str"], "compression": "GZIP", }, - #{ - # "dtypes": ["bytes"], - # "compression": "GZIP", - #}, - #{ - # "dtypes": ["int"], - # "compression": "GZIP", - #}, - #{ - # "dtypes": ["str", "bytes", "int"], - # "compression": "GZIP", - #}, + { + "dtypes": ["bytes"], + "compression": "GZIP", + }, + { + "dtypes": ["int"], + "compression": "GZIP", + }, + { + "dtypes": ["str", "bytes", "int"], + "compression": "GZIP", + }, ], ) def values_and_dataset_tfrec(request, tmpdir_factory) -> None: @@ -157,14 +158,14 @@ def values_and_dataset_tfrec(request, tmpdir_factory) -> None: "dtypes": ["bytes"], "compression": "ZIP", }, - #{ - # "dtypes": ["int"], - # "compression": "ZIP", - #}, - #{ - # "dtypes": ["str", "bytes", "int"], - # "compression": "ZIP", - #}, + { + "dtypes": ["int"], + "compression": "ZIP", + }, + { + "dtypes": ["str", "bytes", "int"], + "compression": "ZIP", + }, ], ) def values_and_dataset_npz(request, tmpdir_factory) -> None: @@ -186,18 +187,18 @@ def values_and_dataset_npz(request, tmpdir_factory) -> None: "dtypes": ["str"], "compression": "LZ4", }, - #{ - # "dtypes": ["bytes"], - # "compression": "LZ4", - #}, - #{ - # "dtypes": ["int"], - # "compression": "LZ4", - #}, - #{ - # "dtypes": ["str", "bytes", "int"], - # "compression": "LZ4", - #}, + { + "dtypes": ["bytes"], + "compression": "LZ4", + }, + { + "dtypes": ["int"], + "compression": "LZ4", + }, + { + "dtypes": ["str", "bytes", "int"], + "compression": "LZ4", + }, ], ) def values_and_dataset_fb(request, tmpdir_factory) -> None: @@ -243,8 +244,15 @@ def check_iteration_of_values( )): assert len(example) == len(values) for name, returned_value in example.items(): - assert type(returned_value) == type(values[name][i]) - assert returned_value == values[name][i] + if dataset.dataset_structure.shard_file_type != "tfrec": + assert returned_value == values[name][i] + assert type(returned_value) == type(values[name][i]) + else: + if "attribute_str" == name: + assert returned_value == values[name][i].encode( + "utf-8") + else: + assert returned_value == values[name][i] case "as_numpy_iterator_concurrent": for i, example in enumerate( dataset.as_numpy_iterator_concurrent( @@ -254,8 +262,15 @@ def check_iteration_of_values( )): assert len(example) == len(values) for name, returned_value in example.items(): - assert type(returned_value) == type(values[name][i]) - assert returned_value == values[name][i] + if dataset.dataset_structure.shard_file_type != "tfrec": + assert returned_value == values[name][i] + assert type(returned_value) == type(values[name][i]) + else: + if "attribute_str" == name: + assert returned_value == values[name][i].encode( + "utf-8") + else: + assert returned_value == values[name][i] case "as_numpy_iterator_rust": for i, example in enumerate( dataset.as_numpy_iterator_concurrent( @@ -265,8 +280,8 @@ def check_iteration_of_values( )): assert len(example) == len(values) for name, returned_value in example.items(): - assert type(returned_value) == type(values[name][i]) assert returned_value == values[name][i] + assert type(returned_value) == type(values[name][i]) # We tested everything if i + 1 != len(next(iter(values.values()))): From ddec48e37201cd8705515720acfb9a42d455d260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 11:37:10 +0200 Subject: [PATCH 07/17] [squash] typo --- src/sedpack/io/shard/shard_writer_np.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index a3759109..1a959711 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -97,7 +97,7 @@ def _write(self, values: ExampleT) -> None: else: # Extend and remember the length. Attributes with dtype "bytes" # may have variable length. Handle this case. We need to avoid - # two thigs: + # two things: # - Having wrong length of the bytes array and ideally also # avoid padding. # - Using allow_pickle when saving since that could lead to code From 19b0d6f0c6883f2de83d569955372e75d653bab0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 11:58:56 +0200 Subject: [PATCH 08/17] [squash] fix deprecated warning --- src/sedpack/io/flatbuffer/iterate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/sedpack/io/flatbuffer/iterate.py b/src/sedpack/io/flatbuffer/iterate.py index d5d16954..97c1e699 100644 --- a/src/sedpack/io/flatbuffer/iterate.py +++ b/src/sedpack/io/flatbuffer/iterate.py @@ -116,11 +116,12 @@ def decode_array( case "bytes": return np_bytes.tobytes() case "int": - return int( - np.frombuffer( - buffer=np_bytes, - dtype=np.dtype("int64").newbyteorder("<"), - )) + array = np.frombuffer( + buffer=np_bytes, + dtype=np.dtype("int64").newbyteorder("<"), + ) + assert array.shape == (1,), f"{array.shape = }" + return int(array[0]) case _: # The rest is interpreted as NumPy array. pass From 552e01f6fdb3dcadc5c49595bc3bcdbb525ba10c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 12:09:21 +0200 Subject: [PATCH 09/17] [drop] Try allowing pickle when saving but not when loading --- src/sedpack/io/shard/shard_writer_np.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index 1a959711..13a154c9 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -140,13 +140,13 @@ def close(self) -> None: case "ZIP": np.savez_compressed( str(self._shard_file), - allow_pickle=False, + #allow_pickle=False, **self._buffer, # type: ignore[arg-type] ) case "": np.savez( str(self._shard_file), - allow_pickle=False, + #allow_pickle=False, **self._buffer, # type: ignore[arg-type] ) case _: From 6ee51c1f6e14bf1484c32034bf88b86da86884e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 12:23:53 +0200 Subject: [PATCH 10/17] [squash] add pickle explanation --- src/sedpack/io/shard/shard_writer_np.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index 13a154c9..b1e599b8 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -135,17 +135,22 @@ def close(self) -> None: ) ] - # Write the buffer into a file. + # Write the buffer into a file. We should not need to allow_pickle while + # saving (the default value is True). But on GitHub actions macos-13 + # runner the tests were failing while reading. The security concern + # (code execution) should be more on the side of loading pickled data. match self.dataset_structure.compression: case "ZIP": np.savez_compressed( str(self._shard_file), + # See comment above. #allow_pickle=False, **self._buffer, # type: ignore[arg-type] ) case "": np.savez( str(self._shard_file), + # See comment above. #allow_pickle=False, **self._buffer, # type: ignore[arg-type] ) From f8d2ff89878f03398258f98e4bd73be48c3c4c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 13:21:53 +0200 Subject: [PATCH 11/17] [squash] nit allow pickle --- src/sedpack/io/shard/shard_writer_np.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index b1e599b8..9551fd9a 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -143,15 +143,11 @@ def close(self) -> None: case "ZIP": np.savez_compressed( str(self._shard_file), - # See comment above. - #allow_pickle=False, **self._buffer, # type: ignore[arg-type] ) case "": np.savez( str(self._shard_file), - # See comment above. - #allow_pickle=False, **self._buffer, # type: ignore[arg-type] ) case _: From e834ea699bb29ae08f9c66eb03d82fc5e4afa2fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 13:58:24 +0200 Subject: [PATCH 12/17] [squash] NIT about zip (it is better this way, yes but not for the reasons mentioned) --- src/sedpack/io/shard/shard_writer_np.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index 9551fd9a..7c8fcdf7 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -82,10 +82,10 @@ def _write(self, values: ExampleT) -> None: if not self._buffer: self._buffer = {} - for (name, value), attribute in zip( - values.items(), - self.dataset_structure.saved_data_description, - ): + for attribute in self.dataset_structure.saved_data_description: + name = attribute.name + value = values[name] + if attribute.dtype != "bytes": current_values = self._buffer.get(name, []) current_values.append( From 5956646ce944e0f7532afca667a09c71f0bbb622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 14:12:46 +0200 Subject: [PATCH 13/17] [squash] fix nits --- src/sedpack/io/npz/iterate_npz.py | 12 ++++++------ src/sedpack/io/shard/shard_writer_flatbuffer.py | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index d3ccf772..ead3bd47 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -99,12 +99,12 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: # A given shard contains the same number of elements for each # attribute. elements: int - for attribute in self.dataset_structure.saved_data_description: - if self._counting_prefix + attribute.name in shard_content: - elements = len( - shard_content[self._counting_prefix + attribute.name]) - 1 - else: - elements = len(shard_content[attribute.name]) + first_attribute = self.dataset_structure.saved_data_description[0] + if self._counting_prefix + first_attribute.name in shard_content: + elements = len( + shard_content[self._counting_prefix + first_attribute.name]) - 1 + else: + elements = len(shard_content[first_attribute.name]) for i in range(elements): yield { diff --git a/src/sedpack/io/shard/shard_writer_flatbuffer.py b/src/sedpack/io/shard/shard_writer_flatbuffer.py index 2bebb3b1..0a8ad7c1 100644 --- a/src/sedpack/io/shard/shard_writer_flatbuffer.py +++ b/src/sedpack/io/shard/shard_writer_flatbuffer.py @@ -228,8 +228,7 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] attribute=attribute, value=value, ) - value_flattened = np.copy(value).flatten() - value_np = np.array(value_flattened, dtype=attribute.dtype) + value_np = np.array(value, dtype=attribute.dtype) alignment = value_np.dtype.itemsize # Total length of the array (in bytes). From 8e9cee831d0ecb67bb6143741aba39df3eb2c813 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 14:25:03 +0200 Subject: [PATCH 14/17] [squash] remove commented out code --- src/sedpack/io/tfrec/tfdata.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sedpack/io/tfrec/tfdata.py b/src/sedpack/io/tfrec/tfdata.py index 543a5950..a07d6e54 100644 --- a/src/sedpack/io/tfrec/tfdata.py +++ b/src/sedpack/io/tfrec/tfdata.py @@ -141,9 +141,6 @@ def to_tfrecord(saved_data_description: list[Attribute], for attribute in saved_data_description: value = values[attribute.name] - # Convert the value into a NumPy type. - #value = np.array(value) - # Check shape if attribute.dtype not in [ "bytes", From e8ba34cacd8918d2541d4504088e89aa1a53749e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Fri, 29 Aug 2025 14:31:08 +0200 Subject: [PATCH 15/17] [squash] also bump package version to fix possible incompatibilities --- rust/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 06836768..8cd75648 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sedpack_rs" -version = "0.1.3" +version = "0.1.4" edition = "2021" description = "Rust bindings for sedpack a general ML dataset package" authors = [ From 7e03cdc3c9fdd41ad8a85441fb0f7b3028d3409d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Mon, 8 Sep 2025 17:56:14 +0200 Subject: [PATCH 16/17] [squash] fix incomprehensible comprehension bump sedpack version (this PR fixes major problem) --- rust/Cargo.lock | 2 +- src/sedpack/io/npz/iterate_npz.py | 56 ++++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 0b3f6306..e2e1789e 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -847,7 +847,7 @@ dependencies = [ [[package]] name = "sedpack_rs" -version = "0.1.3" +version = "0.1.4" dependencies = [ "criterion", "flatbuffers", diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index ead3bd47..cb994448 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,6 +35,40 @@ class IterateShardNP(IterateShardBase[T]): @staticmethod def decode_attribute( + attribute: Attribute, + example_index: int, + counting_prefix: str, + shard_content: dict[str, list[AttributeValueT]], + ) -> AttributeValueT: + """Choose the correct way to decode the given attribute. + + Args: + + attribute (Attribute): Information about the attribute being decoded. + + example_index (int): Which example from this shard is being decoded. + + counting_prefix (str): For the case of `bytes` attributes we need to + store them in continuous array otherwise variable length would require + allowing pickling and result in a potential arbitrary code execution. + + shard_content (dict[str, list[AttributeValueT]]): The shard values. + """ + if attribute.dtype == "bytes": + return IterateShardNP.decode_bytes_attribute( + value=shard_content[attribute.name][0], + indexes=shard_content[counting_prefix + attribute.name], + attribute=attribute, + index=example_index, + ) + + return IterateShardNP.decode_non_bytes_attribute( + np_value=shard_content[attribute.name][example_index], + attribute=attribute, + ) + + @staticmethod + def decode_non_bytes_attribute( np_value: AttributeValueT, attribute: Attribute, ) -> AttributeValueT: @@ -106,19 +140,14 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: else: elements = len(shard_content[first_attribute.name]) - for i in range(elements): + for example_index in range(elements): yield { attribute.name: IterateShardNP.decode_attribute( - np_value=shard_content[attribute.name][i], attribute=attribute, - ) if attribute.dtype != "bytes" else - IterateShardNP.decode_bytes_attribute( - value=shard_content[attribute.name][0], - indexes=shard_content[self._counting_prefix + - attribute.name], - attribute=attribute, - index=i, + example_index=example_index, + counting_prefix=self._counting_prefix, + shard_content=shard_content, ) for attribute in self.dataset_structure.saved_data_description } @@ -146,8 +175,11 @@ async def iterate_shard_async( # pylint: disable=invalid-overridden-method elements = len(values) break - for i in range(elements): - yield {name: value[i] for name, value in shard_content.items()} + for example_index in range(elements): + yield { + name: value[example_index] + for name, value in shard_content.items() + } def process_and_list(self, shard_file: Path) -> list[T]: process_record = func_or_identity(self.process_record) From 52b8f780de6f6813a3e435146805f08d0444be04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Mon, 8 Sep 2025 18:06:49 +0200 Subject: [PATCH 17/17] [squash] avoid creating new strings --- src/sedpack/io/npz/iterate_npz.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/sedpack/io/npz/iterate_npz.py b/src/sedpack/io/npz/iterate_npz.py index cb994448..2de5fd05 100644 --- a/src/sedpack/io/npz/iterate_npz.py +++ b/src/sedpack/io/npz/iterate_npz.py @@ -37,7 +37,7 @@ class IterateShardNP(IterateShardBase[T]): def decode_attribute( attribute: Attribute, example_index: int, - counting_prefix: str, + prefixed_name: str, shard_content: dict[str, list[AttributeValueT]], ) -> AttributeValueT: """Choose the correct way to decode the given attribute. @@ -48,16 +48,17 @@ def decode_attribute( example_index (int): Which example from this shard is being decoded. - counting_prefix (str): For the case of `bytes` attributes we need to + prefixed_name (str): For the case of `bytes` attributes we need to store them in continuous array otherwise variable length would require allowing pickling and result in a potential arbitrary code execution. + This name is the prefix-sum encoded lengths of the attribute values. shard_content (dict[str, list[AttributeValueT]]): The shard values. """ if attribute.dtype == "bytes": return IterateShardNP.decode_bytes_attribute( value=shard_content[attribute.name][0], - indexes=shard_content[counting_prefix + attribute.name], + indexes=shard_content[prefixed_name], attribute=attribute, index=example_index, ) @@ -121,9 +122,13 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: """ # A prefix such that prepended it creates a new name without collision # with any attribute name. - self._counting_prefix: str = "len" + "_" * max( + counting_prefix: str = "len" + "_" * max( len(attribute.name) for attribute in self.dataset_structure.saved_data_description) + self._prefixed_names: dict[str, str] = { + attribute.name: counting_prefix + attribute.name + for attribute in self.dataset_structure.saved_data_description + } shard_content: dict[str, list[AttributeValueT]] = np.load( file_path, @@ -134,9 +139,9 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: # attribute. elements: int first_attribute = self.dataset_structure.saved_data_description[0] - if self._counting_prefix + first_attribute.name in shard_content: + if self._prefixed_names[first_attribute.name] in shard_content: elements = len( - shard_content[self._counting_prefix + first_attribute.name]) - 1 + shard_content[self._prefixed_names[first_attribute.name]]) - 1 else: elements = len(shard_content[first_attribute.name]) @@ -146,7 +151,7 @@ def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]: IterateShardNP.decode_attribute( attribute=attribute, example_index=example_index, - counting_prefix=self._counting_prefix, + prefixed_name=self._prefixed_names[attribute.name], shard_content=shard_content, ) for attribute in self.dataset_structure.saved_data_description