diff --git a/src/sedpack/io/dataset_filler.py b/src/sedpack/io/dataset_filler.py index 52ee104a..e20bf196 100644 --- a/src/sedpack/io/dataset_filler.py +++ b/src/sedpack/io/dataset_filler.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Google LLC +# Copyright 2023-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ number.""" from __future__ import annotations +import concurrent.futures import dataclasses from pathlib import Path from types import TracebackType @@ -31,6 +32,13 @@ from sedpack.io.dataset_writing import DatasetWriting +def _close_shard(shard: Shard) -> ShardInfo: + """Helper function to close a shard. This can be pickled and thus sent to + another process. + """ + return shard.close() + + @dataclasses.dataclass class ShardProgress: """Internal information about a shard progress. Since we do not want to @@ -47,11 +55,15 @@ class DatasetFillerContext: group_number. """ - def __init__(self, - dataset_root_path: Path, - dataset_structure: DatasetStructure, - relative_path_from_split: Path, - write_updates: bool = True) -> None: + def __init__( + self, + dataset_root_path: Path, + dataset_structure: DatasetStructure, + relative_path_from_split: Path, + *, + concurrent_pool: concurrent.futures.Executor, + write_updates: bool = True, + ) -> None: """Initialize a dataset filler context which writes examples and automatically opens and closes shards (except possibly not closing the last shard which is closed by `DatasetFiller`). @@ -67,6 +79,8 @@ def __init__(self, `dataset_root_path / split / relative_path_from_split`. May not contain "..". + concurrent_pool (concurrent.futures.Executor): Shard file writing. + write_updates (bool): Whether to save progress of written shard info files. Defaults to `True`. In theory when set to `False` writing could be faster (not writing the info file after each shard). But @@ -84,6 +98,7 @@ def __init__(self, self._dataset_root_path: Path = dataset_root_path self._dataset_structure: DatasetStructure = dataset_structure self._relative_path_from_split: Path = relative_path_from_split + self._concurrent_pool: concurrent.futures.Executor = concurrent_pool self._write_updates: bool = write_updates # Constants. @@ -94,6 +109,8 @@ def __init__(self, # Cumulated shard infos. self._shards_lists: dict[SplitT, ShardsList] = {} + self._future_shard_infos: dict[ + SplitT, list[concurrent.futures.Future[ShardInfo]]] = {} # Random path generator. self._path_generator = PathGenerator() @@ -103,6 +120,34 @@ def shard_lists(self) -> dict[SplitT, ShardsList]: """Return information about all shards written by this DatasetFillerContext. """ + if not self._future_shard_infos: + return self._shards_lists + + for split, shard_info_futures in self._future_shard_infos.items(): + for future_info in shard_info_futures: + shard_info: ShardInfo = future_info.result() + + # Remember this shard info. + if split not in self._shards_lists: + # Load if exists. + self._shards_lists[split] = ShardsList.load_or_create( + dataset_root_path=self._dataset_root_path, + relative_path_self=shard_info.file_infos[0].file_path. + parent / "shards_list.json", + ) + self._shards_lists[split].shard_files.append(shard_info) + self._shards_lists[ + split].number_of_examples += shard_info.number_of_examples + + # Write down which shard files have been saved. + if self._write_updates: + for shards_list in self._shards_lists.values(): + shards_list.write_config( + dataset_root_path=self._dataset_root_path, + hashes=(), # We forget these now. + ) + + self._future_shard_infos = {} return self._shards_lists def _get_new_shard(self, split: SplitT) -> Shard: @@ -184,26 +229,13 @@ def write_example(self, def close_shard(self, shard: Shard, split: SplitT) -> None: """Close shard. Called automatically by DatasetFiller.__exit__.""" # Finish writing the shard. - shard_info: ShardInfo = shard.close() - - # Remember this shard info. - if split not in self._shards_lists: - # Load if exists. - self._shards_lists[split] = ShardsList.load_or_create( - dataset_root_path=self._dataset_root_path, - relative_path_self=shard_info.file_infos[0].file_path.parent / - "shards_list.json", - ) - self._shards_lists[split].shard_files.append(shard_info) - self._shards_lists[ - split].number_of_examples += shard_info.number_of_examples - - # Write down which shard files have been saved. - if self._write_updates: - self._shards_lists[split].write_config( - dataset_root_path=self._dataset_root_path, - hashes=(), # We forget these now. - ) + if split not in self._future_shard_infos: + self._future_shard_infos[split] = [] + self._future_shard_infos[split].append( + self._concurrent_pool.submit( + _close_shard, + shard=shard, + )) class DatasetFiller: @@ -223,6 +255,7 @@ class DatasetFiller: def __init__(self, dataset: DatasetWriting, + concurrency: int = 1, relative_path_from_split: Path = Path("."), auto_update_dataset: bool = True) -> None: """Context manager for writing examples into a dataset. @@ -232,6 +265,10 @@ def __init__(self, dataset (DatasetWriting): Dataset object representing the dataset we are filling. + concurrency (int): Setting to a positive integer allows writing shard + files in parallel. Defaults to 1 (sequential writes in another thread + or process). Thread is used for "tfrec" otherwise process. + relative_path_from_split (Path): New shards are created inside `dataset_root_path / split / relative_path_from_split` or children. It will be created if it does not exist yet. Useful for @@ -244,11 +281,20 @@ def __init__(self, `DatasetFiller` objects it is advised to set this parameter to `False` since this is not thread safe. """ + self._concurrent_pool: concurrent.futures.Executor + if dataset.dataset_structure.shard_file_type == "tfrec": + self._concurrent_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=1) + else: + self._concurrent_pool = concurrent.futures.ProcessPoolExecutor( + max_workers=max(1, concurrency),) + self._dataset_filler_context: DatasetFillerContext self._dataset_filler_context = DatasetFillerContext( dataset_root_path=dataset.path, dataset_structure=dataset.dataset_structure, relative_path_from_split=relative_path_from_split, + concurrent_pool=self._concurrent_pool, ) self._auto_update_dataset: bool = auto_update_dataset self._dataset: DatasetWriting = dataset @@ -286,6 +332,13 @@ def __exit__(self, exc_type: Type[BaseException] | None, split=split, ) + # Wait to finish writing of all files (after closing all shards but + # before updating metadata files). + self._concurrent_pool.shutdown( + wait=True, + cancel_futures=False, + ) + # Write updated info files. self._update_infos() diff --git a/src/sedpack/io/dataset_writing.py b/src/sedpack/io/dataset_writing.py index e9f30084..5aab8d0f 100644 --- a/src/sedpack/io/dataset_writing.py +++ b/src/sedpack/io/dataset_writing.py @@ -140,9 +140,14 @@ def write_multiprocessing( # pylint: disable=too-many-arguments # Return user-defined results. return results - def filler(self) -> DatasetFiller: + def filler(self, concurrency: int = 1) -> DatasetFiller: """Return a dataset filler context manager for writing examples. + Args: + + concurrency (int): Setting to a positive integer allows writing shard + files in parallel. Defaults to 1 (sequential writes). + Example use: # Context manager properly opens and closes shards. with dataset.filler() as dataset_filler: @@ -153,7 +158,7 @@ def filler(self) -> DatasetFiller: split=split, ) """ - return DatasetFiller(self) + return DatasetFiller(dataset=self, concurrency=concurrency) def write_config( self, @@ -279,7 +284,8 @@ def check( if real_hashes != file_info.hash_checksums: raise ValueError( f"Hash checksum miss-match in {file_info.file_path}" - ) + f"got: {real_hashes} but expected " + f"{file_info.hash_checksums}") # We want to get results back. diff --git a/src/sedpack/io/shard/shard.py b/src/sedpack/io/shard/shard.py index 407dd691..a6522c6c 100644 --- a/src/sedpack/io/shard/shard.py +++ b/src/sedpack/io/shard/shard.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Google LLC +# Copyright 2023-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ from pathlib import Path -import sedpack from sedpack.io.metadata import DatasetStructure from sedpack.io.shard_file_metadata import ShardInfo from sedpack.io.types import ExampleT @@ -65,16 +64,16 @@ def write(self, values: ExampleT) -> None: self.shard_info.number_of_examples += 1 def close(self) -> ShardInfo: - """Close shard and return statistics.""" + """Close shard and return statistics. + """ if self._shard_writer is None: raise ValueError("Closing a shard which has not been open.") - self._shard_writer.close() + hash_checksums: tuple[str, ...] = self._shard_writer.close() self._shard_writer = None # Compute sha256 checksum. - self.shard_info.file_infos[ - 0].hash_checksums = self._compute_file_hash_checksums() + self.shard_info.file_infos[0].hash_checksums = hash_checksums # Return shard info. return self.shard_info @@ -83,16 +82,3 @@ def _get_full_path(self) -> Path: """Return full path to the shard file. """ return self._dataset_path / self.shard_info.file_infos[0].file_path - - def _compute_file_hash_checksums(self) -> tuple[str, ...]: - """Compute hash checksums of the shard file(-s). - - TODO This method should return a list of checksums defined by the user - in `self.dataset_structure`. - """ - # Compute sha256 checksum. - shard_path = self._get_full_path() - return sedpack.io.utils.hash_checksums( - file_path=shard_path, - hashes=self.dataset_structure.hash_checksum_algorithms, - ) diff --git a/src/sedpack/io/shard/shard_writer_base.py b/src/sedpack/io/shard/shard_writer_base.py index b59ae68b..64f5476f 100644 --- a/src/sedpack/io/shard/shard_writer_base.py +++ b/src/sedpack/io/shard/shard_writer_base.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import numpy as np +import sedpack from sedpack.io.metadata import DatasetStructure from sedpack.io.types import ExampleT, CompressionT @@ -81,7 +82,7 @@ def _write(self, values: ExampleT) -> None: """ @abstractmethod - def close(self) -> None: + def close(self) -> tuple[str, ...]: """Close the shard file(-s). """ @@ -90,3 +91,10 @@ def close(self) -> None: def supported_compressions() -> list[CompressionT]: """Return a list of supported compression types. """ + + def _compute_file_hash_checksums(self) -> tuple[str, ...]: + """Compute hash checksums of the shard file(-s). """ + return sedpack.io.utils.hash_checksums( + file_path=self._shard_file, + hashes=self.dataset_structure.hash_checksum_algorithms, + ) diff --git a/src/sedpack/io/shard/shard_writer_flatbuffer.py b/src/sedpack/io/shard/shard_writer_flatbuffer.py index 0a8ad7c1..cb86919d 100644 --- a/src/sedpack/io/shard/shard_writer_flatbuffer.py +++ b/src/sedpack/io/shard/shard_writer_flatbuffer.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ from sedpack.io.metadata import Attribute, DatasetStructure from sedpack.io.types import AttributeValueT, CompressionT, ExampleT from sedpack.io.shard.shard_writer_base import ShardWriterBase +from sedpack.io.utils import hash_checksums_from_bytes # Autogenerated from src/sedpack/io/flatbuffer/shard.fbs import sedpack.io.flatbuffer.shardfile.Attribute as fbapi_Attribute @@ -253,13 +254,16 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported] # Return the vector offset. return builder.EndVector() # type: ignore[no-any-return] - def close(self) -> None: + def close(self) -> tuple[str, ...]: """Close the shard file(-s). """ if not self._examples: # Nothing to save. assert not self._shard_file.is_file() - return + return hash_checksums_from_bytes( + file_content=b"", + hashes=self.dataset_structure.hash_checksum_algorithms, + ) if not self._builder: raise ValueError("Attempting to close a closed shard") @@ -278,13 +282,21 @@ def close(self) -> None: # Finish the builder. self._builder.Finish(shard) + # Hash check-sums must be computed from already compressed file. + file_content: bytes = CompressedFile( + self.dataset_structure.compression).compress( + bytes(self._builder.Output())) + self._builder = None + hash_checksums: tuple[str, ...] = hash_checksums_from_bytes( + file_content=file_content, + hashes=self.dataset_structure.hash_checksum_algorithms, + ) + # Write the buffer into a file. with open(self._shard_file, "wb") as file: - compressor = CompressedFile(self.dataset_structure.compression) - file.write(compressor.compress(bytes(self._builder.Output()))) + file.write(file_content) - self._builder = None - assert self._shard_file.is_file() + return hash_checksums @staticmethod def supported_compressions() -> list[CompressionT]: diff --git a/src/sedpack/io/shard/shard_writer_np.py b/src/sedpack/io/shard/shard_writer_np.py index 7c8fcdf7..fa600e4c 100644 --- a/src/sedpack/io/shard/shard_writer_np.py +++ b/src/sedpack/io/shard/shard_writer_np.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,8 +23,8 @@ from numpy import typing as npt from sedpack.io.metadata import Attribute, DatasetStructure -from sedpack.io.types import AttributeValueT, CompressionT, ExampleT from sedpack.io.shard.shard_writer_base import ShardWriterBase +from sedpack.io.types import AttributeValueT, CompressionT, ExampleT class ShardWriterNP(ShardWriterBase): @@ -117,12 +117,12 @@ def _write(self, values: ExampleT) -> None: ) self._buffer[name] = byte_list # type: ignore[assignment] - def close(self) -> None: - """Close the shard file(-s). + def close(self) -> tuple[str, ...]: + """Close the shard file and return hash check-sums. """ if not self._buffer: assert not self._shard_file.is_file() - return + return () # Deal properly with "bytes" attributes. for attribute in self.dataset_structure.saved_data_description: @@ -160,6 +160,7 @@ def close(self) -> None: self._buffer = {} assert self._shard_file.is_file() + return self._compute_file_hash_checksums() @staticmethod def supported_compressions() -> list[CompressionT]: diff --git a/src/sedpack/io/shard/shard_writer_tfrec.py b/src/sedpack/io/shard/shard_writer_tfrec.py index a51b4b45..0df0aa7f 100644 --- a/src/sedpack/io/shard/shard_writer_tfrec.py +++ b/src/sedpack/io/shard/shard_writer_tfrec.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -80,13 +80,14 @@ def _write(self, values: ExampleT) -> None: ) self._tf_shard_writer.write(example) - def close(self) -> None: + def close(self) -> tuple[str, ...]: """Close the shard file(-s). """ if not self._tf_shard_writer: raise ValueError("Trying to close a shard that was not open") self._tf_shard_writer.close() self._tf_shard_writer = None + return self._compute_file_hash_checksums() @staticmethod def supported_compressions() -> list[CompressionT]: diff --git a/src/sedpack/io/utils.py b/src/sedpack/io/utils.py index 0590b377..8067955d 100644 --- a/src/sedpack/io/utils.py +++ b/src/sedpack/io/utils.py @@ -60,6 +60,35 @@ def _get_hash_function(name: HashChecksumT) -> HashProtocol: return hashlib.new(name) +def hash_checksums_from_bytes( + file_content: bytes, + hashes: tuple[HashChecksumT, ...], +) -> tuple[str, ...]: + """Compute the hex-encoded hash checksums. An alternative to + `hash_checksums` to avoid reading the file again. + + Args: + + file_content (bytes): The whole file content. + + hashes (tuple[HashChecksumT, ...]): A tuple of hash algorithm names to + be computed. + + Returns: hex-encoded hash checksums of the file in the order given by + `hashes`. + """ + # Actual hash functions, same order as hashes. + hash_functions = tuple( + _get_hash_function(hash_name) for hash_name in hashes) + + # Update all hashes. + for hash_function in hash_functions: + hash_function.update(file_content) + + # Hex-encoded results, same order as hashes. + return tuple(hash_function.hexdigest() for hash_function in hash_functions) + + def hash_checksums(file_path: Path, hashes: tuple[HashChecksumT, ...]) -> tuple[str, ...]: """Compute the hex-encoded hash checksums. diff --git a/tests/io/iteration/test_rust_batched_generator.py b/tests/io/iteration/test_rust_batched_generator.py index bfcd83a2..7fd98f15 100644 --- a/tests/io/iteration/test_rust_batched_generator.py +++ b/tests/io/iteration/test_rust_batched_generator.py @@ -107,7 +107,7 @@ def dataset_and_values(tmpdir_factory) -> None: # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for i in range(data_points): filler.write_example( values={ diff --git a/tests/io/test_bytes.py b/tests/io/test_bytes.py index ce57f66f..39cbb606 100644 --- a/tests/io/test_bytes.py +++ b/tests/io/test_bytes.py @@ -62,7 +62,7 @@ def test_attribute_bytes(tmpdir: Union[str, Path]) -> None: # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for attribute_value in array_of_values: filler.write_example( values={"attribute_name": attribute_value}, diff --git a/tests/io/test_check_consistency.py b/tests/io/test_check_consistency.py index 3a4165b7..e2c519dc 100644 --- a/tests/io/test_check_consistency.py +++ b/tests/io/test_check_consistency.py @@ -61,7 +61,7 @@ def get_dataset(tmp_path: Union[str, Path], shard_file_type: ShardFileTypeT, # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for attribute_value in array_of_values: filler.write_example( values={"attribute_name": attribute_value}, diff --git a/tests/io/test_custom_metadata_type_limit.py b/tests/io/test_custom_metadata_type_limit.py index c422e2cc..2d2633a4 100644 --- a/tests/io/test_custom_metadata_type_limit.py +++ b/tests/io/test_custom_metadata_type_limit.py @@ -61,7 +61,7 @@ def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike, method: str, # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for attribute_value in array_of_values: filler.write_example( values={"attribute_name": attribute_value}, diff --git a/tests/io/test_end2end.py b/tests/io/test_end2end.py index 3aa54e08..8dd5268d 100644 --- a/tests/io/test_end2end.py +++ b/tests/io/test_end2end.py @@ -64,7 +64,7 @@ def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike, method: str, # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for attribute_value in array_of_values: filler.write_example( values={"attribute_name": attribute_value}, diff --git a/tests/io/test_end2end_dtypes.py b/tests/io/test_end2end_dtypes.py index 6d7944f0..6e11efda 100644 --- a/tests/io/test_end2end_dtypes.py +++ b/tests/io/test_end2end_dtypes.py @@ -97,7 +97,7 @@ def dataset_and_values_dynamic_shape( # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for i in range(items): filler.write_example( values={ diff --git a/tests/io/test_end2end_shuffled.py b/tests/io/test_end2end_shuffled.py index 67b5e506..84b1cda2 100644 --- a/tests/io/test_end2end_shuffled.py +++ b/tests/io/test_end2end_shuffled.py @@ -58,7 +58,7 @@ def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike, method: str, # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for attribute_value in array_of_values: filler.write_example( values={"attribute_name": attribute_value}, diff --git a/tests/io/test_hash_checksums.py b/tests/io/test_hash_checksums.py index f2e78757..c64272e9 100644 --- a/tests/io/test_hash_checksums.py +++ b/tests/io/test_hash_checksums.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2024-2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,10 @@ from pathlib import Path -from sedpack.io.utils import hash_checksums +import numpy as np +import pytest + +from sedpack.io.utils import hash_checksums, hash_checksums_from_bytes def test_compress_gzip_write(tmpdir: str | Path) -> None: @@ -31,3 +34,39 @@ def test_compress_gzip_write(tmpdir: str | Path) -> None: "b7f783baed8297f0db917462184ff4f08e69c2d5e5f79a942600f9725f58ce1f29c18139bf80b06c0fff2bdd34738452ecf40c488c22a7e3d80cdf6f9c1c0d47", "9203b0c4439fd1e6ae5878866337b7c532acd6d9260150c80318e8ab8c27ce330189f8df94fb890df1d298ff360627e1", ) + + +@pytest.mark.parametrize("size", [1, 3, 7, 8, 13, 17, 33, 67]) +def test_equivalent(size: int, tmp_path: Path) -> None: + file_content = np.random.randint( + 0, + 256, + size=size, + dtype=np.uint8, + ).tobytes() + hashes = ( + "md5", + "sha256", + "sha512", + "sha384", + "xxh32", + "xxh64", + "xxh128", + ) + + tmp_path = tmp_path / "shard_file.extension" + + with open(tmp_path, "wb") as f: + f.write(file_content) + + # Redundant, but read again. + with open(tmp_path, "rb") as f: + assert f.read() == file_content + + assert hash_checksums_from_bytes( + file_content=file_content, + hashes=hashes, + ) == hash_checksums( + file_path=tmp_path, + hashes=hashes, + ) diff --git a/tests/io/test_rust_iter.py b/tests/io/test_rust_iter.py index c227cac6..f694ab88 100644 --- a/tests/io/test_rust_iter.py +++ b/tests/io/test_rust_iter.py @@ -61,7 +61,7 @@ def end2end(tmpdir: Union[str, Path], dtype: npt.DTypeLike, # Fill data in the dataset - with dataset.filler() as filler: + with dataset.filler(concurrency=np.random.randint(0, 4),) as filler: for attribute_value in array_of_values: filler.write_example( values={"attribute_name": attribute_value},