From 62c1b771be885286288f6ec37b7f78f8d037aaca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 6 Aug 2025 22:08:39 +0200 Subject: [PATCH 1/9] Allow arbitrary shard iteration Start migration to allow customizing shard info iteration. --- src/sedpack/io/dataset_base.py | 129 ++++++++++++++++++++++++++-- src/sedpack/io/dataset_iteration.py | 81 +++++++++-------- 2 files changed, 170 insertions(+), 40 deletions(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index d01cf97a..90c20519 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -13,11 +13,13 @@ # limitations under the License. """Base class for a dataset.""" import itertools +import json import logging from pathlib import Path +import random import semver -from typing import Iterator, Union +from typing import Callable, Iterator, Union import sedpack from sedpack.io.shard_file_metadata import ShardInfo, ShardsList, ShardListInfo @@ -162,10 +164,13 @@ class ShardInfoIterator: """Iterate shards of a dataset. """ - def __init__(self, - split: SplitT | None, - dataset: DatasetBase, - repeat: bool = False) -> None: + def __init__( + self, + *, + split: SplitT | None, + dataset: DatasetBase, + repeat: bool = False, + ) -> None: """Initialize shard information iteration. Args: @@ -235,3 +240,117 @@ def __next__(self) -> ShardInfo: """Get the next item. """ return next(self._iterator) + + +class CachedShardInfoIterator(ShardInfoIterator): + """Iterate shards of a dataset. + """ + + def __init__( + self, + *, + split: SplitT | None, + dataset: DatasetBase, + repeat: bool = False, + shards: int | None = None, + custom_metadata_type_limit: int | None = None, + shard_filter: Callable[[ShardInfo], bool] | None = None, + shuffle: int = 0, + ) -> None: + """Initialize shard information iteration. + + Args: + + split (SplitT | None): Which split to iterate or all if set to None. + + dataset (DatasetBase): The dataset being iterated. 
+ + repeat (bool): Should we cycle indefinitely? + + shards (int | None): If specified limits the dataset to the first + `shards` shards. + + custom_metadata_type_limit (int | None): Ignored when None. If + non-zero then limit the number of shards with different + `custom_metadata`. Take only the first `custom_metadata_type_limit` + shards with the concrete `custom_metadata`. This is best effort for + different `custom_metadata` (string of sorted items). + + shard_filter (Callable[[ShardInfo], bool | None): If present this is + a function taking the ShardInfo and returning True if the shard shall + be used for traversal and False otherwise. + + shuffle (int): When set to 0 the iteration is deterministic. + """ + super().__init__( + split=split, + dataset=dataset, + repeat=repeat, + ) + + self.shuffle: int = shuffle + + # Cache the list of shards. + shard_list: list[ShardInfo] = list( + ShardInfoIterator( + split=split, + dataset=dataset, + repeat=False, + )) + + # Filter if needed. + if shard_filter: + shard_list = [ + shard_info for shard_info in shard_list + if shard_filter(shard_info) + ] + + # Only use a limited amount of shards for each setting of + # custom_metadata. + if custom_metadata_type_limit: + counts: dict[str, int] = {} + old_shards_list = shard_list + shard_list = [] + for shard_info in old_shards_list: + k: str = str(tuple(sorted(shard_info.custom_metadata.items()))) + counts[k] = counts.get(k, 0) + 1 + if counts[k] <= custom_metadata_type_limit: + shard_list.append(shard_info) + + # Limit the number of shards. + if shards: + shard_list = shard_list[:shards] + + # Initial shuffling. + if shuffle: + random.shuffle(shard_list) + + # Cached shards. + self._index: int = -1 # The last returned element. + self._shards: list[ShardInfo] = shard_list + + def number_of_shards(self) -> int: + """Return the number of distinct shards that are iterated. When + repeated this method still returns a finite answer. 
+ """ + return len(self._shards) + + def __iter__(self) -> Iterator[ShardInfo]: + """Return the shard information iterator (reentrant). + """ + return self + + def __next__(self) -> ShardInfo: + """Get the next item. + """ + self._index += 1 + + if self._index >= len(self._shards): + if self.repeat: + self._index = 0 + if self.shuffle: + random.shuffle(self._shards) + else: + raise StopIteration + + return self._shards[self._index] diff --git a/src/sedpack/io/dataset_iteration.py b/src/sedpack/io/dataset_iteration.py index a0f24ffe..cb0489b9 100644 --- a/src/sedpack/io/dataset_iteration.py +++ b/src/sedpack/io/dataset_iteration.py @@ -30,7 +30,7 @@ import numpy as np import tensorflow as tf -from sedpack.io.dataset_base import DatasetBase +from sedpack.io.dataset_base import CachedShardInfoIterator, DatasetBase from sedpack.io.flatbuffer import IterateShardFlatBuffer from sedpack.io.itertools import LazyPool from sedpack.io.itertools import round_robin, round_robin_async, shuffle_buffer @@ -69,7 +69,7 @@ def shard_paths_dataset( non-zero then limit the number of shards with different `custom_metadata`. Take only the first `custom_metadata_type_limit` shards with the concrete `custom_metadata`. This is best effort for - different `custom_metadata` (hashed as a tuple of sorted items). + different `custom_metadata` (hashed as `json.dumps`). shard_filter (Callable[[ShardInfo], bool | None): If present this is a function taking the ShardInfo and returning True if the @@ -77,45 +77,22 @@ def shard_paths_dataset( Returns: A list of shards filenames. """ - # List of all shard informations shards_list: list[ShardInfo] = list( - self.shard_info_iterator(split=split)) - - # Filter which shards to use. 
- if shard_filter is not None: - shards_list = list(filter(shard_filter, shards_list)) - - kept_metadata: set[str] = { - str(s.custom_metadata) for s in shards_list - } - self._logger.info( - "Filtered shards with custom metadata: %s from split: %s", - kept_metadata, - split, - ) + CachedShardInfoIterator( + split=split, + dataset=self, + repeat=False, + shards=shards, + custom_metadata_type_limit=custom_metadata_type_limit, + shard_filter=shard_filter, + shuffle=0, + )) # Check that there is still something to iterate if not shards_list: raise ValueError("The list of shards is empty. Try less " "restrictive filtering.") - # Truncate the shard list - if shards: - shards_list = shards_list[:shards] - - # Only use a limited amount of shards for each setting of - # custom_metadata. - if custom_metadata_type_limit: - counts: dict[tuple[tuple[str, Any], ...], int] = {} - old_shards_list = shards_list - shards_list = [] - for shard_info in old_shards_list: - k = tuple(sorted(shard_info.custom_metadata.items())) - counts[k] = counts.get(k, 0) + 1 - if counts[k] <= custom_metadata_type_limit: - shards_list.append(shard_info) - self._logger.info("Took %s shards total", len(shards_list)) - # Full shard file paths. shard_paths = [ str(self.path / s.file_infos[0].file_path) for s in shards_list @@ -646,7 +623,8 @@ def as_numpy_iterator( Returns: An iterator over numpy examples (unless the parameter `process_record` returns something else). No batching is done. 
""" - shard_paths_iterator: Iterable[str] = self.as_numpy_common( + shard_iterator: Iterable[ShardInfo] = CachedShardInfoIterator( + dataset=self, split=split, shards=shards, custom_metadata_type_limit=custom_metadata_type_limit, @@ -655,6 +633,39 @@ def as_numpy_iterator( shuffle=shuffle, ) + yield from self.example_iterator_from_shard_iterator( + shard_iterator=shard_iterator, + process_record=process_record, + shuffle=shuffle, + ) + + def example_iterator_from_shard_iterator( + self, + *, + shard_iterator: Iterable[ShardInfo], + process_record: Callable[[ExampleT], T] | None = None, + shuffle: int = 1_000, + ) -> Iterable[ExampleT] | Iterable[T]: + """Low level iterator of examples given an iterator of shard + information. + + Args: + + shard_iterator (Iterable[ShardInfo]): These shards are being + iterated. + + process_record (Callable[[ExampleT], T] | None): Optional + function that processes a single record. + + shuffle (int): How many examples should be shuffled across shards. + When set to 0 the iteration is deterministic. It might be faster to + """ + shard_paths_iterator: Iterable[str] = map( + lambda shard_info: str(self.path / shard_info.file_infos[0]. + file_path), + shard_iterator, + ) + # Decode the files. shard_iterator: IterateShardBase[ExampleT] match self.dataset_structure.shard_file_type: From c122b37259a7a95fcd9dabc9e566679022123820 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 6 Aug 2025 22:17:44 +0200 Subject: [PATCH 2/9] [squash] --- src/sedpack/io/dataset_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index 90c20519..4fd1f720 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -13,7 +13,6 @@ # limitations under the License. 
"""Base class for a dataset.""" import itertools -import json import logging from pathlib import Path import random From 30cd35c88b2d981fde14b181801c8aca762069aa Mon Sep 17 00:00:00 2001 From: kralka Date: Wed, 6 Aug 2025 22:25:59 +0200 Subject: [PATCH 3/9] Update src/sedpack/io/dataset_base.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/sedpack/io/dataset_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index 4fd1f720..30a749a7 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -311,7 +311,7 @@ def __init__( old_shards_list = shard_list shard_list = [] for shard_info in old_shards_list: - k: str = str(tuple(sorted(shard_info.custom_metadata.items()))) + k: str = json.dumps(shard_info.custom_metadata, sort_keys=True) counts[k] = counts.get(k, 0) + 1 if counts[k] <= custom_metadata_type_limit: shard_list.append(shard_info) From 1567f2f86c5fc95a272af2367584a359fc4a44eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 6 Aug 2025 22:31:25 +0200 Subject: [PATCH 4/9] [squash] --- src/sedpack/io/dataset_base.py | 8 ++++++-- src/sedpack/io/dataset_iteration.py | 7 ++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index 30a749a7..f2c73450 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -13,6 +13,7 @@ # limitations under the License. """Base class for a dataset.""" import itertools +import json import logging from pathlib import Path import random @@ -273,7 +274,7 @@ def __init__( non-zero then limit the number of shards with different `custom_metadata`. Take only the first `custom_metadata_type_limit` shards with the concrete `custom_metadata`. This is best effort for - different `custom_metadata` (string of sorted items). 
+ different `custom_metadata` (`json.dumps` with `sort_keys`). shard_filter (Callable[[ShardInfo], bool | None): If present this is a function taking the ShardInfo and returning True if the shard shall @@ -311,7 +312,10 @@ def __init__( old_shards_list = shard_list shard_list = [] for shard_info in old_shards_list: - k: str = json.dumps(shard_info.custom_metadata, sort_keys=True) + k: str = json.dumps( + shard_info.custom_metadata, + sort_keys=True, + ) counts[k] = counts.get(k, 0) + 1 if counts[k] <= custom_metadata_type_limit: shard_list.append(shard_info) diff --git a/src/sedpack/io/dataset_iteration.py b/src/sedpack/io/dataset_iteration.py index cb0489b9..536b846c 100644 --- a/src/sedpack/io/dataset_iteration.py +++ b/src/sedpack/io/dataset_iteration.py @@ -605,9 +605,10 @@ def as_numpy_iterator( custom_metadata_type_limit (int | None): Ignored when None. If non-zero then limit the number of shards with different - `custom_metadata`. Take only the first `custom_metadata_type_limit` - shards with the concrete `custom_metadata`. This is best effort for - different `custom_metadata` (hashed as a tuple of sorted items). + `custom_metadata`. Take only the first + `custom_metadata_type_limit` shards with the concrete + `custom_metadata`. This is best effort for different + `custom_metadata` (`json.dumps` with `sort_keys`). 
shard_filter (Callable[[ShardInfo], bool | None): If present this is a function taking the ShardInfo and returning True if the From 84d648a8f698f6e09340f3eb3aa05e0ef9b3287a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 6 Aug 2025 22:39:46 +0200 Subject: [PATCH 5/9] [squash] --- src/sedpack/io/dataset_iteration.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sedpack/io/dataset_iteration.py b/src/sedpack/io/dataset_iteration.py index 536b846c..10fcb09c 100644 --- a/src/sedpack/io/dataset_iteration.py +++ b/src/sedpack/io/dataset_iteration.py @@ -668,21 +668,21 @@ def example_iterator_from_shard_iterator( ) # Decode the files. - shard_iterator: IterateShardBase[ExampleT] + shards_iterator: IterateShardBase[ExampleT] match self.dataset_structure.shard_file_type: case "tfrec": - shard_iterator = IterateShardTFRec( + shards_iterator = IterateShardTFRec( dataset_structure=self.dataset_structure, process_record=None, num_parallel_calls=os.cpu_count() or 1, ) case "npz": - shard_iterator = IterateShardNP( + shards_iterator = IterateShardNP( dataset_structure=self.dataset_structure, process_record=None, ) case "fb": - shard_iterator = IterateShardFlatBuffer( + shards_iterator = IterateShardFlatBuffer( dataset_structure=self.dataset_structure, process_record=None, ) @@ -692,7 +692,7 @@ def example_iterator_from_shard_iterator( example_iterator = itertools.chain.from_iterable( map( - shard_iterator.iterate_shard, # type: ignore[arg-type] + shards_iterator.iterate_shard, # type: ignore[arg-type] shard_paths_iterator, )) From 8872579960260320264c0a18a61053a8a12006bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 6 Aug 2025 22:56:02 +0200 Subject: [PATCH 6/9] [squash] reintroduce logging --- src/sedpack/io/dataset_base.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index f2c73450..bd5903c9 100644 --- 
a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -137,6 +137,12 @@ def dataset_structure(self) -> DatasetStructure: def dataset_structure(self, value: DatasetStructure) -> None: self._dataset_info.dataset_structure = value + @property + def logger(self) -> logging.Logger: + """Get the logger. + """ + return self._logger + def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]: """Iterate all `ShardInfo` in the split. @@ -305,6 +311,15 @@ def __init__( if shard_filter(shard_info) ] + kept_metadata: set[str] = { + str(s.custom_metadata) for s in shards_list + } + self.dataset.logger.info( + "Filtered shards with custom metadata: %s from split: %s", + kept_metadata, + split, + ) + # Only use a limited amount of shards for each setting of # custom_metadata. if custom_metadata_type_limit: @@ -319,6 +334,7 @@ def __init__( counts[k] = counts.get(k, 0) + 1 if counts[k] <= custom_metadata_type_limit: shard_list.append(shard_info) + self.dataset.logger.info("Took %s shards total", len(shard_list)) # Limit the number of shards. 
if shards: From 4ae1a51c7ec57566f0159c4e0dcadf798fdc5b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Wed, 6 Aug 2025 21:01:20 +0000 Subject: [PATCH 7/9] [squash] fix typo --- src/sedpack/io/dataset_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index bd5903c9..ff2c0dc3 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -312,7 +312,7 @@ def __init__( ] kept_metadata: set[str] = { - str(s.custom_metadata) for s in shards_list + str(s.custom_metadata) for s in shard_list } self.dataset.logger.info( "Filtered shards with custom metadata: %s from split: %s", From 0cf252fdbe338d77ce9fe9974d0bdc10af01c201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Thu, 7 Aug 2025 11:23:13 +0200 Subject: [PATCH 8/9] [squash] shuffle docstring --- src/sedpack/io/dataset_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index ff2c0dc3..4cd21f2c 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -286,7 +286,9 @@ def __init__( a function taking the ShardInfo and returning True if the shard shall be used for traversal and False otherwise. - shuffle (int): When set to 0 the iteration is deterministic. + shuffle (int): When set to 0 the iteration is deterministic otherwise + shuffle the shards with a shuffle buffer of at least `shuffle` + elements. Current implementation shuffles all shard information. 
""" super().__init__( split=split, From 4b73da722c71a8340deee24807468435f17afd55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karel=20Kr=C3=A1l?= Date: Thu, 7 Aug 2025 11:31:12 +0200 Subject: [PATCH 9/9] [squash] json.dumps for the logging --- src/sedpack/io/dataset_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index 4cd21f2c..35f9f736 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -314,7 +314,10 @@ def __init__( ] kept_metadata: set[str] = { - str(s.custom_metadata) for s in shard_list + json.dumps( + s.custom_metadata, + sort_keys=True, + ) for s in shard_list } self.dataset.logger.info( "Filtered shards with custom metadata: %s from split: %s",