diff --git a/src/sedpack/io/dataset_base.py b/src/sedpack/io/dataset_base.py index d01cf97a..35f9f736 100644 --- a/src/sedpack/io/dataset_base.py +++ b/src/sedpack/io/dataset_base.py @@ -13,11 +13,13 @@ # limitations under the License. """Base class for a dataset.""" import itertools +import json import logging from pathlib import Path +import random import semver -from typing import Iterator, Union +from typing import Callable, Iterator, Union import sedpack from sedpack.io.shard_file_metadata import ShardInfo, ShardsList, ShardListInfo @@ -135,6 +137,12 @@ def dataset_structure(self) -> DatasetStructure: def dataset_structure(self, value: DatasetStructure) -> None: self._dataset_info.dataset_structure = value + @property + def logger(self) -> logging.Logger: + """Get the logger. + """ + return self._logger + def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]: """Iterate all `ShardInfo` in the split. @@ -162,10 +170,13 @@ class ShardInfoIterator: """Iterate shards of a dataset. """ - def __init__(self, - split: SplitT | None, - dataset: DatasetBase, - repeat: bool = False) -> None: + def __init__( + self, + *, + split: SplitT | None, + dataset: DatasetBase, + repeat: bool = False, + ) -> None: """Initialize shard information iteration. Args: @@ -235,3 +246,135 @@ def __next__(self) -> ShardInfo: """Get the next item. """ return next(self._iterator) + + +class CachedShardInfoIterator(ShardInfoIterator): + """Iterate shards of a dataset. + """ + + def __init__( + self, + *, + split: SplitT | None, + dataset: DatasetBase, + repeat: bool = False, + shards: int | None = None, + custom_metadata_type_limit: int | None = None, + shard_filter: Callable[[ShardInfo], bool] | None = None, + shuffle: int = 0, + ) -> None: + """Initialize shard information iteration. + + Args: + + split (SplitT | None): Which split to iterate or all if set to None. + + dataset (DatasetBase): The dataset being iterated. + + repeat (bool): Should we cycle indefinitely? 
+ + shards (int | None): If specified limits the dataset to the first + `shards` shards. + + custom_metadata_type_limit (int | None): Ignored when None. If + non-zero then limit the number of shards with different + `custom_metadata`. Take only the first `custom_metadata_type_limit` + shards with the concrete `custom_metadata`. This is best effort for + different `custom_metadata` (`json.dumps` with `sort_keys`). + + shard_filter (Callable[[ShardInfo], bool] | None): If present this is + a function taking the ShardInfo and returning True if the shard shall + be used for traversal and False otherwise. + + shuffle (int): When set to 0 the iteration is deterministic otherwise + shuffle the shards with a shuffle buffer of at least `shuffle` + elements. Current implementation shuffles all shard information. + """ + super().__init__( + split=split, + dataset=dataset, + repeat=repeat, + ) + + self.shuffle: int = shuffle + + # Cache the list of shards. + shard_list: list[ShardInfo] = list( + ShardInfoIterator( + split=split, + dataset=dataset, + repeat=False, + )) + + # Filter if needed. + if shard_filter: + shard_list = [ + shard_info for shard_info in shard_list + if shard_filter(shard_info) + ] + + kept_metadata: set[str] = { + json.dumps( + s.custom_metadata, + sort_keys=True, + ) for s in shard_list + } + self.dataset.logger.info( + "Filtered shards with custom metadata: %s from split: %s", + kept_metadata, + split, + ) + + # Only use a limited amount of shards for each setting of + # custom_metadata. + if custom_metadata_type_limit: + counts: dict[str, int] = {} + old_shards_list = shard_list + shard_list = [] + for shard_info in old_shards_list: + k: str = json.dumps( + shard_info.custom_metadata, + sort_keys=True, + ) + counts[k] = counts.get(k, 0) + 1 + if counts[k] <= custom_metadata_type_limit: + shard_list.append(shard_info) + self.dataset.logger.info("Took %s shards total", len(shard_list)) + + # Limit the number of shards. 
+ if shards: + shard_list = shard_list[:shards] + + # Initial shuffling. + if shuffle: + random.shuffle(shard_list) + + # Cached shards. + self._index: int = -1 # The last returned element. + self._shards: list[ShardInfo] = shard_list + + def number_of_shards(self) -> int: + """Return the number of distinct shards that are iterated. When + repeated this method still returns a finite answer. + """ + return len(self._shards) + + def __iter__(self) -> Iterator[ShardInfo]: + """Return the shard information iterator (reentrant). + """ + return self + + def __next__(self) -> ShardInfo: + """Get the next item. + """ + self._index += 1 + + if self._index >= len(self._shards): + if self.repeat: + self._index = 0 + if self.shuffle: + random.shuffle(self._shards) + else: + raise StopIteration + + return self._shards[self._index] diff --git a/src/sedpack/io/dataset_iteration.py b/src/sedpack/io/dataset_iteration.py index a0f24ffe..10fcb09c 100644 --- a/src/sedpack/io/dataset_iteration.py +++ b/src/sedpack/io/dataset_iteration.py @@ -30,7 +30,7 @@ import numpy as np import tensorflow as tf -from sedpack.io.dataset_base import DatasetBase +from sedpack.io.dataset_base import CachedShardInfoIterator, DatasetBase from sedpack.io.flatbuffer import IterateShardFlatBuffer from sedpack.io.itertools import LazyPool from sedpack.io.itertools import round_robin, round_robin_async, shuffle_buffer @@ -69,7 +69,7 @@ def shard_paths_dataset( non-zero then limit the number of shards with different `custom_metadata`. Take only the first `custom_metadata_type_limit` shards with the concrete `custom_metadata`. This is best effort for - different `custom_metadata` (hashed as a tuple of sorted items). + different `custom_metadata` (hashed as `json.dumps`). shard_filter (Callable[[ShardInfo], bool | None): If present this is a function taking the ShardInfo and returning True if the @@ -77,45 +77,22 @@ def shard_paths_dataset( Returns: A list of shards filenames. 
""" - # List of all shard informations shards_list: list[ShardInfo] = list( - self.shard_info_iterator(split=split)) - - # Filter which shards to use. - if shard_filter is not None: - shards_list = list(filter(shard_filter, shards_list)) - - kept_metadata: set[str] = { - str(s.custom_metadata) for s in shards_list - } - self._logger.info( - "Filtered shards with custom metadata: %s from split: %s", - kept_metadata, - split, - ) + CachedShardInfoIterator( + split=split, + dataset=self, + repeat=False, + shards=shards, + custom_metadata_type_limit=custom_metadata_type_limit, + shard_filter=shard_filter, + shuffle=0, + )) # Check that there is still something to iterate if not shards_list: raise ValueError("The list of shards is empty. Try less " "restrictive filtering.") - # Truncate the shard list - if shards: - shards_list = shards_list[:shards] - - # Only use a limited amount of shards for each setting of - # custom_metadata. - if custom_metadata_type_limit: - counts: dict[tuple[tuple[str, Any], ...], int] = {} - old_shards_list = shards_list - shards_list = [] - for shard_info in old_shards_list: - k = tuple(sorted(shard_info.custom_metadata.items())) - counts[k] = counts.get(k, 0) + 1 - if counts[k] <= custom_metadata_type_limit: - shards_list.append(shard_info) - self._logger.info("Took %s shards total", len(shards_list)) - # Full shard file paths. shard_paths = [ str(self.path / s.file_infos[0].file_path) for s in shards_list @@ -628,9 +605,10 @@ def as_numpy_iterator( custom_metadata_type_limit (int | None): Ignored when None. If non-zero then limit the number of shards with different - `custom_metadata`. Take only the first `custom_metadata_type_limit` - shards with the concrete `custom_metadata`. This is best effort for - different `custom_metadata` (hashed as a tuple of sorted items). + `custom_metadata`. Take only the first + `custom_metadata_type_limit` shards with the concrete + `custom_metadata`. 
This is best effort for different + `custom_metadata` (`json.dumps` with `sort_keys`). shard_filter (Callable[[ShardInfo], bool | None): If present this is a function taking the ShardInfo and returning True if the @@ -646,7 +624,8 @@ def as_numpy_iterator( Returns: An iterator over numpy examples (unless the parameter `process_record` returns something else). No batching is done. """ - shard_paths_iterator: Iterable[str] = self.as_numpy_common( + shard_iterator: Iterable[ShardInfo] = CachedShardInfoIterator( + dataset=self, split=split, shards=shards, custom_metadata_type_limit=custom_metadata_type_limit, @@ -655,22 +634,55 @@ def as_numpy_iterator( shuffle=shuffle, ) + yield from self.example_iterator_from_shard_iterator( + shard_iterator=shard_iterator, + process_record=process_record, + shuffle=shuffle, + ) + + def example_iterator_from_shard_iterator( + self, + *, + shard_iterator: Iterable[ShardInfo], + process_record: Callable[[ExampleT], T] | None = None, + shuffle: int = 1_000, + ) -> Iterable[ExampleT] | Iterable[T]: + """Low level iterator of examples given an iterator of shard + information. + + Args: + + shard_iterator (Iterable[ShardInfo]): These shards are being + iterated. + + process_record (Callable[[ExampleT], T] | None): Optional + function that processes a single record. + + shuffle (int): How many examples should be shuffled across shards. + When set to 0 the iteration is deterministic. It might be faster to iterate when shuffle is zero. + """ + shard_paths_iterator: Iterable[str] = map( + lambda shard_info: str(self.path / shard_info.file_infos[0]. + file_path), + shard_iterator, + ) + # Decode the files. 
- shard_iterator: IterateShardBase[ExampleT] + shards_iterator: IterateShardBase[ExampleT] match self.dataset_structure.shard_file_type: case "tfrec": - shard_iterator = IterateShardTFRec( + shards_iterator = IterateShardTFRec( dataset_structure=self.dataset_structure, process_record=None, num_parallel_calls=os.cpu_count() or 1, ) case "npz": - shard_iterator = IterateShardNP( + shards_iterator = IterateShardNP( dataset_structure=self.dataset_structure, process_record=None, ) case "fb": - shard_iterator = IterateShardFlatBuffer( + shards_iterator = IterateShardFlatBuffer( dataset_structure=self.dataset_structure, process_record=None, ) @@ -680,7 +692,7 @@ def as_numpy_iterator( example_iterator = itertools.chain.from_iterable( map( - shard_iterator.iterate_shard, # type: ignore[arg-type] + shards_iterator.iterate_shard, # type: ignore[arg-type] shard_paths_iterator, ))