fast_llm/data/data/abstract.py (+4, -0)

@@ -5,6 +5,7 @@
 from fast_llm.config import Configurable
 from fast_llm.data.data.config import DataConfig
 from fast_llm.data.dataset.config import SamplingParameters
+from fast_llm.data.preprocessing.abstract import PreprocessingConfig
 from fast_llm.data.sample.abstract import Batch
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.schedule.config import BatchConfig
@@ -16,6 +17,7 @@
 class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC):
     _distributed: "Distributed"
     _sampling_parameters: dict[str, SamplingParameters]
+    _preprocessing: PreprocessingConfig
     _cache_directory: pathlib.Path | None

     def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None:
@@ -27,11 +29,13 @@ def setup(
         self,
         distributed: "Distributed",
         sampling_parameters: dict[str, SamplingParameters],
+        preprocessing: PreprocessingConfig,
         cache_directory: pathlib.Path,
         timeout: float | None = None,
     ) -> None:
         self._distributed = distributed
         self._sampling_parameters = sampling_parameters
+        self._preprocessing = preprocessing
         self._cache_directory = cache_directory

     @property
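For context: `Data.setup` gains a required `preprocessing` argument, so every caller must now thread a `PreprocessingConfig` through alongside the sampling parameters. A minimal sketch of an updated call site (the trainer-side names below are illustrative, not part of this diff):

    # Hypothetical caller; only the `setup` signature is defined by this PR.
    data.setup(
        distributed=distributed,
        sampling_parameters={"training": training_parameters},
        preprocessing=preprocessing_config,  # new required argument
        cache_directory=experiment_directory / "dataset_cache",
    )
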
fast_llm/data/data/gpt/data.py (+8, -4)

@@ -10,9 +10,11 @@
 from fast_llm.data.data.abstract import Data
 from fast_llm.data.data.gpt.config import GPTDataConfig
 from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
+from fast_llm.data.dataset.config import SamplingParameters
+from fast_llm.data.dataset.gpt.config import GPTSamplingData
 from fast_llm.data.dataset.monitor import DatasetMonitor
 from fast_llm.data.iterator import SampledDatasetIterator
+from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
 from fast_llm.data.sample.language_model import LanguageModelBatch
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.engine.distributed.config import DistributedConfig
@@ -29,7 +31,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
     """

     _datasets: dict[str, SampledDataset]
-    _sampling_parameters: dict[str, GPTSamplingParameters]
+    _sampling_parameters: dict[str, SamplingParameters]
     _is_setup: bool = False

     def __init__(
@@ -46,15 +48,16 @@ def __init__(
     def setup(
         self,
         distributed: "Distributed",
-        sampling_parameters: dict[str, GPTSamplingParameters],
+        sampling_parameters: dict[str, SamplingParameters],
+        preprocessing: LanguageModelPreprocessingConfig,
         cache_directory: pathlib.Path,
         timeout: float | None = None,
     ) -> None:
         """
         Load the datasets, and prepare or load the samplings.
         This may take a while and a significant amount of cpu memory.
         """
-        super().setup(distributed, sampling_parameters, cache_directory)
+        super().setup(distributed, sampling_parameters, preprocessing, cache_directory)

         # Check and raise an error if a used dataset is not defined.
         for dataset_name in self._sampling_parameters.keys():
@@ -81,6 +84,7 @@ def setup(
             sampling = GPTSamplingData(
                 config=self._config.sampling,
                 parameters=sampling_parameters,
+                preprocessing=preprocessing,
                 cache_directory=self._cache_directory,
                 distributed=distributed,
                 dataset_name=dataset_name,
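Note the variance here: the base class accepts any `PreprocessingConfig`, while `GPTData.setup` requires the language-model-specific subclass, which it then forwards into every `GPTSamplingData` it creates. A sketch of a call that satisfies both signatures (the constructor argument is an assumption; the flag itself is evidenced by the FIM change below):

    # Hypothetical: LanguageModelPreprocessingConfig is assumed to subclass
    # PreprocessingConfig, so this call also satisfies the base Data.setup contract.
    preprocessing = LanguageModelPreprocessingConfig(use_loss_masking_spans=True)
    gpt_data.setup(distributed, sampling_parameters, preprocessing, cache_directory)
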
fast_llm/data/dataset/config.py (+13, -12)

@@ -9,12 +9,12 @@

 from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class
 from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
+from fast_llm.data.preprocessing.abstract import PreprocessingConfig
 from fast_llm.data.sample.abstract import Sample
 from fast_llm.utils import Assert, normalize_probabilities

 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset
-    from fast_llm.data.sample.language_model import LanguageModelSample
     from fast_llm.engine.distributed.distributed import Distributed

 logger = logging.getLogger(__name__)
@@ -84,6 +84,7 @@ class SamplingData:
     # TODO: This prevents the sampling config from being pickled in multiprocessing.
     distributed: "Distributed"
     dataset_name: str
+    preprocessing: PreprocessingConfig
     # Using a mutable rather than an int so it's shared with all copies made with `update`.
     _rank_counter: typing.Iterator[int] = itertools.count

@@ -114,16 +115,16 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]

 @config_class()
 class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]):
-    def build(self) -> SamplableDataset[SampleType]:
+    def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]:
         raise NotImplementedError()

     def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]:
-        return self.build().sample(sampling)
+        return self.build(sampling.preprocessing).sample(sampling)


 @config_class()
 class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]):
-    def build(self) -> "IndexedDataset[SampleType]":
+    def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]":
         raise NotImplementedError()


@@ -147,10 +148,10 @@ class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl
         valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)),
     )

-    def build(self) -> "ConcatenatedDataset":
+    def build(self, preprocessing: PreprocessingConfig) -> "ConcatenatedDataset":
         from fast_llm.data.dataset.indexed import ConcatenatedDataset

-        return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets])
+        return ConcatenatedDataset(self.name, [dataset.build(preprocessing) for dataset in self.datasets])


 @config_class(dynamic_type={SampledDatasetConfig: "slice"})
@@ -180,10 +181,10 @@ class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType])
         hint=FieldHint.core,
     )

-    def build(self) -> "DatasetSlice":
+    def build(self, preprocessing: PreprocessingConfig) -> "DatasetSlice":
         from fast_llm.data.dataset.indexed import DatasetSlice

-        dataset = self.dataset.build()
+        dataset = self.dataset.build(preprocessing)
         size = len(dataset)
         return DatasetSlice[SampleType](
             f"{dataset.name}_{self.begin}_{self.end}",
@@ -272,20 +273,20 @@ def build_and_sample(


 @config_class(dynamic_type={SampledDatasetConfig: "memmap"})
-class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]):
+class MemmapDatasetConfig[SampleType: Sample](IndexedDatasetConfig[SampleType]):
     _abstract: typing.ClassVar[bool] = False
     path: pathlib.Path = Field(
         default=None,
         desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
         hint=FieldHint.core,
     )

-    def build(self) -> "IndexedDataset[SampleType]":
+    def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]":
         name = str(self.path).replace("/", "__")
         if self.path.is_file():
             from fast_llm.data.dataset.memmap import MemmapDataset

-            return MemmapDataset[SampleType](name, self.path)
+            return MemmapDataset[SampleType](name, self.path, preprocessing)
         elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file():
             logger.warning(
                 "Using the legacy memmap dataset format."
@@ -294,6 +295,6 @@ def build(self) -> "IndexedDataset[SampleType]":
             )
             from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset

-            return LegacyMemmapDataset[SampleType](name, self.path)
+            return LegacyMemmapDataset[SampleType](name, self.path, preprocessing)
         else:
             raise FileNotFoundError(self.path)
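The pattern across this file: `build` acquires a `PreprocessingConfig` parameter everywhere, wrapper configs (`concatenated`, `slice`, and `file` below) merely forward it to their children, and only the leaf datasets (`MemmapDataset`, `LegacyMemmapDataset`) actually consume it. A sketch of a custom leaf config under the new contract (`MyDataset` and the `"my_dataset"` type key are hypothetical):

    # Hypothetical leaf dataset config following the new build() signature.
    @config_class(dynamic_type={SampledDatasetConfig: "my_dataset"})
    class MyDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]):
        _abstract: typing.ClassVar[bool] = False

        def build(self, preprocessing: PreprocessingConfig) -> "MyDataset[SampleType]":
            # The dataset, rather than the sampling parameters, now owns preprocessing.
            return MyDataset[SampleType]("my_dataset", preprocessing)
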
fast_llm/data/dataset/gpt/config.py (+11, -23)

@@ -7,53 +7,41 @@

 from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
 from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset
-from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters
+from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData
+from fast_llm.data.preprocessing.abstract import PreprocessingConfig
+from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
 from fast_llm.data.preprocessing.tokenizer import TokenizerConfig
 from fast_llm.utils import Assert

 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.gpt.fim import GPTFimDataset
-    from fast_llm.data.dataset.gpt.random import GPTRandomDataset
+    from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset
     from fast_llm.data.sample.language_model import LanguageModelSample


-@dataclasses.dataclass(kw_only=True)
-class GPTSamplingParameters(SamplingParameters):
-    """
-    Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model.
-    """
-
-    # TODO: Only used for random dataset. Remove? Or use as safety check?
-    vocab_size: int | None = None
-    # TODO: ====== Get these to memmap dataset (currently ignored) ======
-    use_loss_masking_spans: bool = False
-    use_preference_loss_spans: bool = False
-    use_images: bool = False
-
-
 @dataclasses.dataclass(kw_only=True)
 class GPTSamplingData(SamplingData):
     """
     Holds all the necessary information for sampling, including dataset-dependent ones (`GPTSamplingConfig`),
     usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`.
     """

-    parameters: GPTSamplingParameters
+    preprocessing: LanguageModelPreprocessingConfig


 @config_class(dynamic_type={SampledDatasetConfig: "random"})
-class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]):
+class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]):
     _abstract: typing.ClassVar[bool] = False
     name: str = Field(
         default="dummy",
         desc="The name of the dataset.",
         hint=FieldHint.core,
     )

-    def build(self) -> "GPTRandomDataset[SampleType]":
-        from fast_llm.data.dataset.gpt.random import GPTRandomDataset
+    def build_and_sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset[SampleType]":
+        from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset

-        return GPTRandomDataset[SampleType](self.name)
+        return GPTRandomSampledDataset[SampleType](sampling, self.name)


 @config_class(dynamic_type={SampledDatasetConfig: "file"})
@@ -69,10 +57,10 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]
         config = self._load_config()
         return config.build_and_sample(sampling)

-    def build(self) -> SamplableDataset[SampleType]:
+    def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]:
         config = self._load_config()
         assert isinstance(config, SamplableDatasetConfig)
-        return config.build()
+        return config.build(preprocessing)

     def _load_config(self) -> SampledDatasetConfig[SampleType]:
         assert self.path.is_file(), f"File {self.path} does not exist."
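Two changes stand out here. First, `GPTSamplingParameters` is deleted: its usage flags (`use_loss_masking_spans`, `use_preference_loss_spans`, `use_images`) move to the preprocessing config, and `GPTSamplingData.preprocessing` narrows the base field to `LanguageModelPreprocessingConfig`. Second, the random dataset becomes directly sampled: `GPTRandomDatasetConfig` now extends `SampledDatasetConfig` and produces a `GPTRandomSampledDataset` in one step, since random data has no build product to share between samplings. A before/after sketch (variable names are illustrative):

    config = GPTRandomDatasetConfig(name="dummy")

    # Before: build a samplable dataset, then sample it.
    #   dataset = config.build().sample(sampling)

    # After: build and sample in a single step, driven by the sampling data.
    dataset = config.build_and_sample(sampling)
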
fast_llm/data/dataset/gpt/fim.py (+2, -2)

@@ -21,9 +21,9 @@ def __init__(
         dataset: SampledDataset[SampleType],
         sampling: GPTSamplingData,
     ):
-        if sampling.parameters.use_loss_masking_spans:
+        if sampling.preprocessing.use_loss_masking_spans:
             raise NotImplementedError("FIM is currently not compatible with loss masking.")
-        if sampling.parameters.use_preference_loss_spans:
+        if sampling.preprocessing.use_preference_spans:
             raise NotImplementedError("FIM is currently not compatible with preference loss masking.")
         self._config = config
         self._dataset = dataset
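Beyond the attribute move from `sampling.parameters` to `sampling.preprocessing`, note the flag rename: `use_preference_loss_spans` becomes `use_preference_spans`, so the preprocessing config presumably exposes it under the new name. A sketch of the configuration FIM now expects (constructor fields are assumed):

    # Hypothetical: both flags must be off for FIM to be applicable.
    preprocessing = LanguageModelPreprocessingConfig(
        use_loss_masking_spans=False,  # unchanged name, new home
        use_preference_spans=False,    # renamed from use_preference_loss_spans
    )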