Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 148 additions & 5 deletions src/sedpack/io/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
"""Base class for a dataset."""
import itertools
import json
import logging
from pathlib import Path
import random
import semver

from typing import Iterator, Union
from typing import Callable, Iterator, Union

import sedpack
from sedpack.io.shard_file_metadata import ShardInfo, ShardsList, ShardListInfo
Expand Down Expand Up @@ -135,6 +137,12 @@ def dataset_structure(self) -> DatasetStructure:
def dataset_structure(self, value: DatasetStructure) -> None:
self._dataset_info.dataset_structure = value

@property
def logger(self) -> logging.Logger:
    """Return the `logging.Logger` instance used by this dataset."""
    return self._logger

def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]:
"""Iterate all `ShardInfo` in the split.

Expand Down Expand Up @@ -162,10 +170,13 @@ class ShardInfoIterator:
"""Iterate shards of a dataset.
"""

def __init__(self,
split: SplitT | None,
dataset: DatasetBase,
repeat: bool = False) -> None:
def __init__(
self,
*,
split: SplitT | None,
dataset: DatasetBase,
repeat: bool = False,
) -> None:
"""Initialize shard information iteration.

Args:
Expand Down Expand Up @@ -235,3 +246,135 @@ def __next__(self) -> ShardInfo:
"""Get the next item.
"""
return next(self._iterator)


class CachedShardInfoIterator(ShardInfoIterator):
    """Iterate shards of a dataset from a cached, in-memory list.

    The full list of `ShardInfo` is materialized during construction, which
    makes it possible to filter, truncate, and shuffle the shards before and
    during iteration.
    """

    def __init__(
        self,
        *,
        split: SplitT | None,
        dataset: DatasetBase,
        repeat: bool = False,
        shards: int | None = None,
        custom_metadata_type_limit: int | None = None,
        shard_filter: Callable[[ShardInfo], bool] | None = None,
        shuffle: int = 0,
    ) -> None:
        """Initialize shard information iteration.

        Args:

          split (SplitT | None): Which split to iterate or all if set to
          None.

          dataset (DatasetBase): The dataset being iterated.

          repeat (bool): Should we cycle indefinitely?

          shards (int | None): If specified limits the dataset to the first
          `shards` shards.

          custom_metadata_type_limit (int | None): Ignored when None. If
          non-zero then limit the number of shards with different
          `custom_metadata`. Take only the first
          `custom_metadata_type_limit` shards with the concrete
          `custom_metadata`. This is best effort for different
          `custom_metadata` (`json.dumps` with `sort_keys`).

          shard_filter (Callable[[ShardInfo], bool] | None): If present this
          is a function taking the ShardInfo and returning True if the shard
          shall be used for traversal and False otherwise.

          shuffle (int): When set to 0 the iteration is deterministic
          otherwise shuffle the shards with a shuffle buffer of at least
          `shuffle` elements. Current implementation shuffles all shard
          information at once, the exact value is only tested for
          truthiness.
        """
        super().__init__(
            split=split,
            dataset=dataset,
            repeat=repeat,
        )

        self.shuffle: int = shuffle

        # Cache the full list of shard information.
        shard_list: list[ShardInfo] = list(
            ShardInfoIterator(
                split=split,
                dataset=dataset,
                repeat=False,
            ))

        # Filter which shards are used for traversal.
        if shard_filter:
            shard_list = [
                shard_info for shard_info in shard_list
                if shard_filter(shard_info)
            ]

        # Log which custom_metadata values survived the filtering. JSON with
        # sorted keys is a best-effort canonical representation of the
        # metadata dict.
        kept_metadata: set[str] = {
            json.dumps(
                s.custom_metadata,
                sort_keys=True,
            ) for s in shard_list
        }
        self.dataset.logger.info(
            "Filtered shards with custom metadata: %s from split: %s",
            kept_metadata,
            split,
        )

        # Only use a limited amount of shards for each setting of
        # custom_metadata.
        if custom_metadata_type_limit:
            counts: dict[str, int] = {}
            old_shards_list = shard_list
            shard_list = []
            for shard_info in old_shards_list:
                k: str = json.dumps(
                    shard_info.custom_metadata,
                    sort_keys=True,
                )
                counts[k] = counts.get(k, 0) + 1
                if counts[k] <= custom_metadata_type_limit:
                    shard_list.append(shard_info)
            self.dataset.logger.info("Took %s shards total", len(shard_list))

        # Limit the number of shards.
        if shards:
            shard_list = shard_list[:shards]

        # Initial shuffling.
        if shuffle:
            random.shuffle(shard_list)

        # Cached shards and the iteration cursor.
        self._index: int = -1  # Index of the last returned element.
        self._shards: list[ShardInfo] = shard_list

    def number_of_shards(self) -> int:
        """Return the number of distinct shards that are iterated. When
        repeated this method still returns a finite answer.
        """
        return len(self._shards)

    def __iter__(self) -> Iterator[ShardInfo]:
        """Return self. Note that the iteration cursor is shared: iterating
        the same instance a second time continues where the previous
        iteration stopped (when `repeat` is set the iteration never stops).
        """
        return self

    def __next__(self) -> ShardInfo:
        """Get the next item.

        Raises: StopIteration when all shards have been iterated and
        `repeat` is False, or when there are no shards at all.
        """
        # Guard the empty shard list explicitly. Without this, repeat=True
        # would wrap `_index` to 0 and raise IndexError instead of
        # StopIteration.
        if not self._shards:
            raise StopIteration

        self._index += 1

        if self._index >= len(self._shards):
            if self.repeat:
                # Wrap around; reshuffle between epochs when requested.
                self._index = 0
                if self.shuffle:
                    random.shuffle(self._shards)
            else:
                raise StopIteration

        return self._shards[self._index]
98 changes: 55 additions & 43 deletions src/sedpack/io/dataset_iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import numpy as np
import tensorflow as tf

from sedpack.io.dataset_base import DatasetBase
from sedpack.io.dataset_base import CachedShardInfoIterator, DatasetBase
from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.itertools import LazyPool
from sedpack.io.itertools import round_robin, round_robin_async, shuffle_buffer
Expand Down Expand Up @@ -69,53 +69,30 @@ def shard_paths_dataset(
non-zero then limit the number of shards with different
`custom_metadata`. Take only the first `custom_metadata_type_limit`
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
different `custom_metadata` (hashed as `json.dumps`).

shard_filter (Callable[[ShardInfo], bool] | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.

Returns: A list of shards filenames.
"""
# List of all shard informations
shards_list: list[ShardInfo] = list(
self.shard_info_iterator(split=split))

# Filter which shards to use.
if shard_filter is not None:
shards_list = list(filter(shard_filter, shards_list))

kept_metadata: set[str] = {
str(s.custom_metadata) for s in shards_list
}
self._logger.info(
"Filtered shards with custom metadata: %s from split: %s",
kept_metadata,
split,
)
CachedShardInfoIterator(
split=split,
dataset=self,
repeat=False,
shards=shards,
custom_metadata_type_limit=custom_metadata_type_limit,
shard_filter=shard_filter,
shuffle=0,
))

# Check that there is still something to iterate
if not shards_list:
raise ValueError("The list of shards is empty. Try less "
"restrictive filtering.")

# Truncate the shard list
if shards:
shards_list = shards_list[:shards]

# Only use a limited amount of shards for each setting of
# custom_metadata.
if custom_metadata_type_limit:
counts: dict[tuple[tuple[str, Any], ...], int] = {}
old_shards_list = shards_list
shards_list = []
for shard_info in old_shards_list:
k = tuple(sorted(shard_info.custom_metadata.items()))
counts[k] = counts.get(k, 0) + 1
if counts[k] <= custom_metadata_type_limit:
shards_list.append(shard_info)
self._logger.info("Took %s shards total", len(shards_list))

# Full shard file paths.
shard_paths = [
str(self.path / s.file_infos[0].file_path) for s in shards_list
Expand Down Expand Up @@ -628,9 +605,10 @@ def as_numpy_iterator(

custom_metadata_type_limit (int | None): Ignored when None. If
non-zero then limit the number of shards with different
`custom_metadata`. Take only the first `custom_metadata_type_limit`
shards with the concrete `custom_metadata`. This is best effort for
different `custom_metadata` (hashed as a tuple of sorted items).
`custom_metadata`. Take only the first
`custom_metadata_type_limit` shards with the concrete
`custom_metadata`. This is best effort for different
`custom_metadata` (`json.dumps` with `sort_keys`).

shard_filter (Callable[[ShardInfo], bool] | None): If present
this is a function taking the ShardInfo and returning True if the
shard shall be used for traversal and False otherwise.
Returns: An iterator over numpy examples (unless the parameter
`process_record` returns something else). No batching is done.
"""
shard_paths_iterator: Iterable[str] = self.as_numpy_common(
shard_iterator: Iterable[ShardInfo] = CachedShardInfoIterator(
dataset=self,
split=split,
shards=shards,
custom_metadata_type_limit=custom_metadata_type_limit,
Expand All @@ -655,22 +634,55 @@ def as_numpy_iterator(
shuffle=shuffle,
)

yield from self.example_iterator_from_shard_iterator(
shard_iterator=shard_iterator,
process_record=process_record,
shuffle=shuffle,
)

def example_iterator_from_shard_iterator(
self,
*,
shard_iterator: Iterable[ShardInfo],
process_record: Callable[[ExampleT], T] | None = None,
shuffle: int = 1_000,
) -> Iterable[ExampleT] | Iterable[T]:
"""Low level iterator of examples given an iterator of shard
information.

Args:

shard_iterator (Iterable[ShardInfo]): These shards are being
iterated.

process_record (Callable[[ExampleT], T] | None): Optional
function that processes a single record.

shuffle (int): How many examples should be shuffled across shards.
When set to 0 the iteration is deterministic.
"""
shard_paths_iterator: Iterable[str] = map(
lambda shard_info: str(self.path / shard_info.file_infos[0].
file_path),
shard_iterator,
)

# Decode the files.
shard_iterator: IterateShardBase[ExampleT]
shards_iterator: IterateShardBase[ExampleT]
match self.dataset_structure.shard_file_type:
case "tfrec":
shard_iterator = IterateShardTFRec(
shards_iterator = IterateShardTFRec(
dataset_structure=self.dataset_structure,
process_record=None,
num_parallel_calls=os.cpu_count() or 1,
)
case "npz":
shard_iterator = IterateShardNP(
shards_iterator = IterateShardNP(
dataset_structure=self.dataset_structure,
process_record=None,
)
case "fb":
shard_iterator = IterateShardFlatBuffer(
shards_iterator = IterateShardFlatBuffer(
dataset_structure=self.dataset_structure,
process_record=None,
)
Expand All @@ -680,7 +692,7 @@ def as_numpy_iterator(

example_iterator = itertools.chain.from_iterable(
map(
shard_iterator.iterate_shard, # type: ignore[arg-type]
shards_iterator.iterate_shard, # type: ignore[arg-type]
shard_paths_iterator,
))

Expand Down
Loading