Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 79 additions & 26 deletions src/sedpack/io/dataset_filler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Google LLC
# Copyright 2023-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
number."""

from __future__ import annotations
import concurrent.futures
import dataclasses
from pathlib import Path
from types import TracebackType
Expand All @@ -31,6 +32,13 @@
from sedpack.io.dataset_writing import DatasetWriting


def _close_shard(shard: Shard) -> ShardInfo:
    """Finalize the given shard and return its resulting metadata.

    Lives at module level (rather than as a method or lambda) so that it is
    picklable and can therefore be submitted to a `ProcessPoolExecutor`.
    """
    shard_info: ShardInfo = shard.close()
    return shard_info


@dataclasses.dataclass
class ShardProgress:
"""Internal information about a shard progress. Since we do not want to
Expand All @@ -47,11 +55,15 @@ class DatasetFillerContext:
group_number.
"""

def __init__(self,
dataset_root_path: Path,
dataset_structure: DatasetStructure,
relative_path_from_split: Path,
write_updates: bool = True) -> None:
def __init__(
self,
dataset_root_path: Path,
dataset_structure: DatasetStructure,
relative_path_from_split: Path,
*,
concurrent_pool: concurrent.futures.Executor,
write_updates: bool = True,
) -> None:
"""Initialize a dataset filler context which writes examples and
automatically opens and closes shards (except possibly not closing the
last shard which is closed by `DatasetFiller`).
Expand All @@ -67,6 +79,8 @@ def __init__(self,
`dataset_root_path / split / relative_path_from_split`. May not
contain "..".

concurrent_pool (concurrent.futures.Executor): Shard file writing.

write_updates (bool): Whether to save progress of written shard info
files. Defaults to `True`. In theory when set to `False` writing
could be faster (not writing the info file after each shard). But
Expand All @@ -84,6 +98,7 @@ def __init__(self,
self._dataset_root_path: Path = dataset_root_path
self._dataset_structure: DatasetStructure = dataset_structure
self._relative_path_from_split: Path = relative_path_from_split
self._concurrent_pool: concurrent.futures.Executor = concurrent_pool
self._write_updates: bool = write_updates

# Constants.
Expand All @@ -94,6 +109,8 @@ def __init__(self,

# Cumulated shard infos.
self._shards_lists: dict[SplitT, ShardsList] = {}
self._future_shard_infos: dict[
SplitT, list[concurrent.futures.Future[ShardInfo]]] = {}

# Random path generator.
self._path_generator = PathGenerator()
Expand All @@ -103,6 +120,34 @@ def shard_lists(self) -> dict[SplitT, ShardsList]:
"""Return information about all shards written by this
DatasetFillerContext.
"""
if not self._future_shard_infos:
return self._shards_lists

for split, shard_info_futures in self._future_shard_infos.items():
for future_info in shard_info_futures:
shard_info: ShardInfo = future_info.result()

# Remember this shard info.
if split not in self._shards_lists:
# Load if exists.
self._shards_lists[split] = ShardsList.load_or_create(
dataset_root_path=self._dataset_root_path,
relative_path_self=shard_info.file_infos[0].file_path.
parent / "shards_list.json",
)
self._shards_lists[split].shard_files.append(shard_info)
self._shards_lists[
split].number_of_examples += shard_info.number_of_examples

# Write down which shard files have been saved.
if self._write_updates:
for shards_list in self._shards_lists.values():
shards_list.write_config(
dataset_root_path=self._dataset_root_path,
hashes=(), # We forget these now.
)

self._future_shard_infos = {}
return self._shards_lists

def _get_new_shard(self, split: SplitT) -> Shard:
Expand Down Expand Up @@ -184,26 +229,13 @@ def write_example(self,
def close_shard(self, shard: Shard, split: SplitT) -> None:
"""Close shard. Called automatically by DatasetFiller.__exit__."""
# Finish writing the shard.
shard_info: ShardInfo = shard.close()

# Remember this shard info.
if split not in self._shards_lists:
# Load if exists.
self._shards_lists[split] = ShardsList.load_or_create(
dataset_root_path=self._dataset_root_path,
relative_path_self=shard_info.file_infos[0].file_path.parent /
"shards_list.json",
)
self._shards_lists[split].shard_files.append(shard_info)
self._shards_lists[
split].number_of_examples += shard_info.number_of_examples

# Write down which shard files have been saved.
if self._write_updates:
self._shards_lists[split].write_config(
dataset_root_path=self._dataset_root_path,
hashes=(), # We forget these now.
)
if split not in self._future_shard_infos:
self._future_shard_infos[split] = []
self._future_shard_infos[split].append(
self._concurrent_pool.submit(
_close_shard,
shard=shard,
))


class DatasetFiller:
Expand All @@ -223,6 +255,7 @@ class DatasetFiller:

def __init__(self,
dataset: DatasetWriting,
concurrency: int = 1,
relative_path_from_split: Path = Path("."),
auto_update_dataset: bool = True) -> None:
"""Context manager for writing examples into a dataset.
Expand All @@ -232,6 +265,10 @@ def __init__(self,
dataset (DatasetWriting): Dataset object representing the dataset we
are filling.

concurrency (int): Setting to a positive integer allows writing shard
files in parallel. Defaults to 1 (sequential writes in another thread
or process). Thread is used for "tfrec" otherwise process.

relative_path_from_split (Path): New shards are created inside
`dataset_root_path / split / relative_path_from_split` or children.
It will be created if it does not exist yet. Useful for
Expand All @@ -244,11 +281,20 @@ def __init__(self,
`DatasetFiller` objects it is advised to set this parameter to
`False` since this is not thread safe.
"""
self._concurrent_pool: concurrent.futures.Executor
if dataset.dataset_structure.shard_file_type == "tfrec":
self._concurrent_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=1)
else:
self._concurrent_pool = concurrent.futures.ProcessPoolExecutor(
max_workers=max(1, concurrency),)

self._dataset_filler_context: DatasetFillerContext
self._dataset_filler_context = DatasetFillerContext(
dataset_root_path=dataset.path,
dataset_structure=dataset.dataset_structure,
relative_path_from_split=relative_path_from_split,
concurrent_pool=self._concurrent_pool,
)
self._auto_update_dataset: bool = auto_update_dataset
self._dataset: DatasetWriting = dataset
Expand Down Expand Up @@ -286,6 +332,13 @@ def __exit__(self, exc_type: Type[BaseException] | None,
split=split,
)

# Wait to finish writing of all files (after closing all shards but
# before updating metadata files).
self._concurrent_pool.shutdown(
wait=True,
cancel_futures=False,
)

# Write updated info files.
self._update_infos()

Expand Down
12 changes: 9 additions & 3 deletions src/sedpack/io/dataset_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,14 @@ def write_multiprocessing( # pylint: disable=too-many-arguments
# Return user-defined results.
return results

def filler(self) -> DatasetFiller:
def filler(self, concurrency: int = 1) -> DatasetFiller:
"""Return a dataset filler context manager for writing examples.

Args:

concurrency (int): Setting to a positive integer allows writing shard
files in parallel. Defaults to 1 (sequential writes).

Example use:
# Context manager properly opens and closes shards.
with dataset.filler() as dataset_filler:
Expand All @@ -153,7 +158,7 @@ def filler(self) -> DatasetFiller:
split=split,
)
"""
return DatasetFiller(self)
return DatasetFiller(dataset=self, concurrency=concurrency)

def write_config(
self,
Expand Down Expand Up @@ -279,7 +284,8 @@ def check(
if real_hashes != file_info.hash_checksums:
raise ValueError(
f"Hash checksum miss-match in {file_info.file_path}"
)
f"got: {real_hashes} but expected "
f"{file_info.hash_checksums}")


# We want to get results back.
Expand Down
24 changes: 5 additions & 19 deletions src/sedpack/io/shard/shard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Google LLC
# Copyright 2023-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,6 @@

from pathlib import Path

import sedpack
from sedpack.io.metadata import DatasetStructure
from sedpack.io.shard_file_metadata import ShardInfo
from sedpack.io.types import ExampleT
Expand Down Expand Up @@ -65,16 +64,16 @@ def write(self, values: ExampleT) -> None:
self.shard_info.number_of_examples += 1

def close(self) -> ShardInfo:
"""Close shard and return statistics."""
"""Close shard and return statistics.
"""
if self._shard_writer is None:
raise ValueError("Closing a shard which has not been open.")

self._shard_writer.close()
hash_checksums: tuple[str, ...] = self._shard_writer.close()
self._shard_writer = None

# Compute sha256 checksum.
self.shard_info.file_infos[
0].hash_checksums = self._compute_file_hash_checksums()
self.shard_info.file_infos[0].hash_checksums = hash_checksums

# Return shard info.
return self.shard_info
Expand All @@ -83,16 +82,3 @@ def _get_full_path(self) -> Path:
"""Return full path to the shard file.
"""
return self._dataset_path / self.shard_info.file_infos[0].file_path

def _compute_file_hash_checksums(self) -> tuple[str, ...]:
"""Compute hash checksums of the shard file(-s).

TODO This method should return a list of checksums defined by the user
in `self.dataset_structure`.
"""
# Compute sha256 checksum.
shard_path = self._get_full_path()
return sedpack.io.utils.hash_checksums(
file_path=shard_path,
hashes=self.dataset_structure.hash_checksum_algorithms,
)
12 changes: 10 additions & 2 deletions src/sedpack/io/shard/shard_writer_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC
# Copyright 2024-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@

import numpy as np

import sedpack
from sedpack.io.metadata import DatasetStructure
from sedpack.io.types import ExampleT, CompressionT

Expand Down Expand Up @@ -81,7 +82,7 @@ def _write(self, values: ExampleT) -> None:
"""

@abstractmethod
def close(self) -> None:
def close(self) -> tuple[str, ...]:
"""Close the shard file(-s).
"""

Expand All @@ -90,3 +91,10 @@ def close(self) -> None:
def supported_compressions() -> list[CompressionT]:
"""Return a list of supported compression types.
"""

def _compute_file_hash_checksums(self) -> tuple[str, ...]:
    """Compute hash checksums of the shard file(-s).

    Reads the shard file at `self._shard_file` and returns one checksum
    string per algorithm listed in
    `self.dataset_structure.hash_checksum_algorithms`, in that order.
    """
    return sedpack.io.utils.hash_checksums(
        file_path=self._shard_file,
        hashes=self.dataset_structure.hash_checksum_algorithms,
    )
26 changes: 19 additions & 7 deletions src/sedpack/io/shard/shard_writer_flatbuffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Google LLC
# Copyright 2024-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -26,6 +26,7 @@
from sedpack.io.metadata import Attribute, DatasetStructure
from sedpack.io.types import AttributeValueT, CompressionT, ExampleT
from sedpack.io.shard.shard_writer_base import ShardWriterBase
from sedpack.io.utils import hash_checksums_from_bytes

# Autogenerated from src/sedpack/io/flatbuffer/shard.fbs
import sedpack.io.flatbuffer.shardfile.Attribute as fbapi_Attribute
Expand Down Expand Up @@ -253,13 +254,16 @@ def save_numpy_vector_as_bytearray( # type: ignore[no-any-unimported]
# Return the vector offset.
return builder.EndVector() # type: ignore[no-any-return]

def close(self) -> None:
def close(self) -> tuple[str, ...]:
"""Close the shard file(-s).
"""
if not self._examples:
# Nothing to save.
assert not self._shard_file.is_file()
return
return hash_checksums_from_bytes(
file_content=b"",
hashes=self.dataset_structure.hash_checksum_algorithms,
)

if not self._builder:
raise ValueError("Attempting to close a closed shard")
Expand All @@ -278,13 +282,21 @@ def close(self) -> None:
# Finish the builder.
self._builder.Finish(shard)

# Hash check-sums must be computed from already compressed file.
file_content: bytes = CompressedFile(
self.dataset_structure.compression).compress(
bytes(self._builder.Output()))
self._builder = None
hash_checksums: tuple[str, ...] = hash_checksums_from_bytes(
file_content=file_content,
hashes=self.dataset_structure.hash_checksum_algorithms,
)

# Write the buffer into a file.
with open(self._shard_file, "wb") as file:
compressor = CompressedFile(self.dataset_structure.compression)
file.write(compressor.compress(bytes(self._builder.Output())))
file.write(file_content)

self._builder = None
assert self._shard_file.is_file()
return hash_checksums

@staticmethod
def supported_compressions() -> list[CompressionT]:
Expand Down
Loading
Loading