From 2ab18258563d7fd81d1d13ee6bc4c22917582dfe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Dec 2025 20:25:19 -0500 Subject: [PATCH 1/8] Fix rotary 2d --- fast_llm/layers/attention/rotary/rotary.py | 14 +++++-- tests/layers/test_rotary.py | 46 ++++++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) create mode 100644 tests/layers/test_rotary.py diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 55d929f8a..258f9d8bc 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -194,11 +194,17 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: patch_positions = kwargs[VisionKwargs.patch_positions] if not hasattr(self, "_frequencies"): self._frequencies = self._config.theta ** -torch.arange( - 0, 1, 4 / self._head_size, device=kwargs[AttentionKwargs.device], dtype=torch.float64 - ) + 0, 1, 2 / self._head_size, device=kwargs[AttentionKwargs.device], dtype=torch.float64 + ).view(-1, 2) + # TODO: Pre-compute 2d frequencies? - angles = torch.outer(patch_positions.flatten(), self._frequencies).view( - len(patch_positions), self._head_size // 2 + # Equivalent to the separate outer product of height and width frequencies. + # Pre-allocate output to avoid a reshape with copy. 
+ angles = self._frequencies.new_empty(len(patch_positions), self._head_size // 2) + torch.bmm( + patch_positions.T.unsqueeze(2).to(torch.float64), + self._frequencies.T.unsqueeze(1), + out=angles.view(-1, 2, self._head_size // 4).permute(1, 0, 2), ) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) if not self._config.complex_format: diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py new file mode 100644 index 000000000..85d72b316 --- /dev/null +++ b/tests/layers/test_rotary.py @@ -0,0 +1,46 @@ +import torch +import transformers + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.rotary.config import Rotary2DConfig +from fast_llm.layers.vision.config import VisionKwargs +from fast_llm.utils import Assert +from tests.utils.utils import requires_cuda + + +@requires_cuda +def test_rotary_2d(): + """ + Compare Fast-LLM's implementation of 2d rotary embeddings with Pixtral. 
+ """ + head_dim = 16 + num_heads = 8 + + patch_positions = torch.tensor( + [[h, w] for h in range(4) for w in range(4)], + dtype=torch.int64, + device="cuda", + ) + query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device="cuda").normal_() + key = torch.empty_like(query).normal_() + + pixtral_config = transformers.PixtralVisionConfig(hidden_size=head_dim * num_heads, num_attention_heads=num_heads) + pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to("cuda") + # Convert patch positions (h, w) to Pixtral's linear position IDs + # Pixtral expects: position_id = h * max_patches_per_side + w + position_ids = ( + patch_positions[None, :, 0] * (pixtral_config.image_size // pixtral_config.patch_size) + + patch_positions[None, :, 1] + ) + output_pixtral_query, output_pixtral_key = transformers.models.pixtral.modeling_pixtral.apply_rotary_pos_emb( + query, key, *pixtral_rotary(query, position_ids), unsqueeze_dim=2 + ) + + fast_llm_rotary = Rotary2DConfig().get_layer(TensorDim("head_dim", head_dim)) + kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: "cuda"} + fast_llm_rotary.preprocess(kwargs) + output_fast_llm_query, output_fast_llm_key = fast_llm_rotary.forward(query, key, kwargs) + + Assert.rms_close(output_pixtral_query, output_fast_llm_query, 1e-5) + Assert.rms_close(output_pixtral_key, output_fast_llm_key, 1e-5) From 8305dd586b6dee97acdf95e3c59db5a4e328f848 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Dec 2025 20:29:33 -0500 Subject: [PATCH 2/8] stuff --- .../data/preparator/gpt_memmap/prepare.py | 3 -- fast_llm/data/preprocessing/abstract.py | 28 +++++++++++++++++++ fast_llm/data/preprocessing/image_patch.py | 7 +++-- fast_llm/data/preprocessing/tokenizer.py | 9 ++++-- fast_llm/data/sample/abstract.py | 3 ++ fast_llm/layers/vision/config.py | 1 - .../models/multimodal/conversion/llava.py | 2 -- 7 files changed, 42 insertions(+), 11 
deletions(-) create mode 100644 fast_llm/data/preprocessing/abstract.py diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 5606eeb98..94bab200e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -198,11 +198,9 @@ def _prepare_shard( return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - # TODO: ======= Extract so we can use elsewhere? (ex. inference) ====== text = sample[self._source_schema.text] all_spans = [] if self._source_schema.has_loss_masking_span: - # TODO: ====== What is the exact input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (SpanType.loss_masking, (begin, last + 1)) @@ -213,7 +211,6 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: all_spans.extend(loss_masking_spans) if self._source_schema.has_preference_spans: - # TODO: ===== Was `self._config.dataset.field` (bug?) ====== full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] # compute chosen span diff --git a/fast_llm/data/preprocessing/abstract.py b/fast_llm/data/preprocessing/abstract.py new file mode 100644 index 000000000..8dbaa3626 --- /dev/null +++ b/fast_llm/data/preprocessing/abstract.py @@ -0,0 +1,28 @@ +import typing + +from fast_llm.config import Config, config_class + + +@config_class(registry=True) +class PreprocessingConfig(Config): + """ + Base preprocessing configuration, with dynamic registry so configs can be saved with memmap datasets. 
+ """ + + _abstract = True + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is PreprocessingConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass, necessary for loading configs where some components could be absent. + return NullPreprocessingConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + +@config_class(dynamic_type={PreprocessingConfig: "none"}) +class NullPreprocessingConfig(PreprocessingConfig): + """ + Configuration for unspecified preprocessing. + """ + + _abstract = False diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index d6f5bf190..22ec04d68 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -4,6 +4,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div @@ -11,13 +12,15 @@ import torch -@config_class() -class ImagePatchConfig(Config): +@config_class(dynamic_type={PreprocessingConfig: "image_patch"}) +class ImagePatchConfig(PreprocessingConfig): """ Configuration for the tokenizer. The tokenizer is needed for FIM and dataset preparation. 
""" + _abstract = False + height: int = Field( default=16, desc="Height of the image patches, in pixels.", diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 70291bcaa..356407541 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -1,7 +1,8 @@ import pathlib import typing -from fast_llm.config import Config, Configurable, Field, FieldHint, config_class +from fast_llm.config import Configurable, Field, FieldHint, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -11,13 +12,15 @@ import torch -@config_class() -class TokenizerConfig(Config): +@config_class(dynamic_type={PreprocessingConfig: "tokenizer"}) +class TokenizerConfig(PreprocessingConfig): """ Configuration for the tokenizer. The tokenizer is needed for FIM and dataset preparation. """ + _abstract = False + path: pathlib.Path = Field( default=None, desc="Path to the tokenizer file.", diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 11f5d187c..0db7d1c8a 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -4,6 +4,7 @@ import typing from fast_llm.config import Config, Configurable, Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -101,6 +102,8 @@ class MemmapReaderConfig(MemmapReaderBaseConfig): # Constant strings for alignment safety. header: typing.ClassVar[bytes] footer: typing.ClassVar[bytes] + # Additional information about how the dataset was prepared. 
+ preprocessing: PreprocessingConfig = Field() @property def reader_class(self) -> "type[MemmapReader]": diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 2e0389e89..924e1c305 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -99,7 +99,6 @@ def layer_class(self) -> "type[PatchEmbeddings]": @config_class(registry=True) class VisionEncoderConfig(BlockConfig): _abstract = False - # TODO: ====== Rename to patch_embeddings? ====== embeddings: PatchEmbeddingsConfig = Field( desc="Configuration for the patch convolution layer.", hint=FieldHint.architecture, diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 9657d71b6..748f2f89e 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -79,7 +79,6 @@ def export_config(cls, config: AttentionConfig) -> dict: class PixtralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[PixtralAttentionConverter]] = PixtralAttentionConverter - # TODO: ====== MistralMLPConverter (#391 / #382) ====== mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter hf_mixer_name: typing.ClassVar[str] = "attention" @@ -225,7 +224,6 @@ def import_config(cls, config: dict) -> dict: @classmethod def export_config(cls, config: VisionEncoderConfig) -> dict: Assert.custom(isinstance, config, VisionEncoderConfig) - # TODO: ====== image_size? 
====== vision_config = safe_merge_dicts( cls.embeddings_converter_class.export_config(config.embeddings), cls.encoder_converter_class.export_config(config.encoder), From b6e38b872cfca171cb6582bd4731eff2dd2f0f10 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 4 Dec 2025 01:03:32 -0500 Subject: [PATCH 3/8] stuff --- fast_llm/data/data/abstract.py | 4 + fast_llm/data/data/gpt/data.py | 5 +- fast_llm/data/dataset/config.py | 25 ++++--- fast_llm/data/dataset/gpt/config.py | 20 ++--- fast_llm/data/dataset/gpt/legacy_memmap.py | 44 +++++------ fast_llm/data/dataset/gpt/random.py | 34 +++------ fast_llm/data/dataset/memmap.py | 28 ++++--- .../data/preparator/gpt_memmap/prepare.py | 11 +++ fast_llm/data/preprocessing/abstract.py | 12 +++ fast_llm/data/preprocessing/image_patch.py | 12 +++ fast_llm/data/preprocessing/language_model.py | 40 ++++++++++ fast_llm/data/preprocessing/tokenizer.py | 8 +- fast_llm/data/sample/abstract.py | 18 ++++- fast_llm/data/sample/language_model.py | 75 +++++++++---------- fast_llm/data/sample/patch.py | 1 + fast_llm/data/sample/range.py | 1 + fast_llm/data/sample/token.py | 1 + fast_llm/engine/training/trainer.py | 7 +- fast_llm/models/gpt/trainer.py | 21 +++++- fast_llm/models/multimodal/trainer.py | 17 ++++- tests/models/test_match_megatron.py | 5 +- 21 files changed, 261 insertions(+), 128 deletions(-) create mode 100644 fast_llm/data/preprocessing/language_model.py diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index c67dc0321..2c1902796 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,6 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config 
import BatchConfig @@ -16,6 +17,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] + _preprocessing: PreprocessingConfig _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: @@ -27,11 +29,13 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, SamplingParameters], + preprocessing: PreprocessingConfig, cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: self._distributed = distributed self._sampling_parameters = sampling_parameters + self._preprocessing = preprocessing self._cache_directory = cache_directory @property diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index de47ef761..084dadc7d 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -13,6 +13,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig @@ -47,6 +48,7 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, GPTSamplingParameters], + preprocessing: LanguageModelPreprocessingConfig, cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: @@ -54,7 +56,7 @@ def setup( Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. 
""" - super().setup(distributed, sampling_parameters, cache_directory) + super().setup(distributed, sampling_parameters, preprocessing, cache_directory) # Check and raise an error if a used dataset is not defined. for dataset_name in self._sampling_parameters.keys(): @@ -81,6 +83,7 @@ def setup( sampling = GPTSamplingData( config=self._config.sampling, parameters=sampling_parameters, + preprocessing=preprocessing, cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7611b4a31..2858d8d18 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -9,12 +9,12 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -84,6 +84,7 @@ class SamplingData: # TODO: This prevents the sampling config from being pickled in multiprocessing. distributed: "Distributed" dataset_name: str + preprocessing: PreprocessingConfig # Using a mutable rather than an int so it's shared with all copies made with `update`. 
_rank_counter: typing.Iterator[int] = itertools.count @@ -114,16 +115,16 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType] @config_class() class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build(self) -> SamplableDataset[SampleType]: + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: raise NotImplementedError() def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: - return self.build().sample(sampling) + return self.build(sampling.preprocessing).sample(sampling) @config_class() class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): - def build(self) -> "IndexedDataset[SampleType]": + def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]": raise NotImplementedError() @@ -147,10 +148,10 @@ class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), ) - def build(self) -> "ConcatenatedDataset": + def build(self, preprocessing: PreprocessingConfig) -> "ConcatenatedDataset": from fast_llm.data.dataset.indexed import ConcatenatedDataset - return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) + return ConcatenatedDataset(self.name, [dataset.build(preprocessing) for dataset in self.datasets]) @config_class(dynamic_type={SampledDatasetConfig: "slice"}) @@ -180,10 +181,10 @@ class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]) hint=FieldHint.core, ) - def build(self) -> "DatasetSlice": + def build(self, preprocessing: PreprocessingConfig) -> "DatasetSlice": from fast_llm.data.dataset.indexed import DatasetSlice - dataset = self.dataset.build() + dataset = self.dataset.build(preprocessing) size = len(dataset) return DatasetSlice[SampleType]( f"{dataset.name}_{self.begin}_{self.end}", @@ -272,7 +273,7 @@ def build_and_sample( 
@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): +class MemmapDatasetConfig[SampleType: Sample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -280,12 +281,12 @@ class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[ hint=FieldHint.core, ) - def build(self) -> "IndexedDataset[SampleType]": + def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]": name = str(self.path).replace("/", "__") if self.path.is_file(): from fast_llm.data.dataset.memmap import MemmapDataset - return MemmapDataset[SampleType](name, self.path) + return MemmapDataset[SampleType](name, self.path, preprocessing) elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): logger.warning( "Using the legacy memmap dataset format." @@ -294,6 +295,6 @@ def build(self) -> "IndexedDataset[SampleType]": ) from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset - return LegacyMemmapDataset[SampleType](name, self.path) + return LegacyMemmapDataset[SampleType](name, self.path, preprocessing) else: raise FileNotFoundError(self.path) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 175779823..4336657ce 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -8,12 +8,14 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert if 
typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset - from fast_llm.data.dataset.gpt.random import GPTRandomDataset + from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset + from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample @@ -23,9 +25,8 @@ class GPTSamplingParameters(SamplingParameters): Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - # TODO: Only used for random dataset. Remove? Or use as safety check? - vocab_size: int | None = None # TODO: ====== Get these to memmap dataset (currently ignored) ====== + vocab_size: int | None = None use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False use_images: bool = False @@ -39,10 +40,11 @@ class GPTSamplingData(SamplingData): """ parameters: GPTSamplingParameters + preprocessing: LanguageModelPreprocessingConfig @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): +class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -50,10 +52,10 @@ class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetCo hint=FieldHint.core, ) - def build(self) -> "GPTRandomDataset[SampleType]": - from fast_llm.data.dataset.gpt.random import GPTRandomDataset + def build_and_sample(self, sampling: GPTSamplingData) -> GPTRandomSampledDataset[SampleType]: + from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset - return GPTRandomDataset[SampleType](self.name) + return GPTRandomSampledDataset[SampleType](sampling, self.name) @config_class(dynamic_type={SampledDatasetConfig: "file"}) @@ -69,10 +71,10 @@ def build_and_sample(self, sampling: SamplingData) -> 
SampledDataset[SampleType] config = self._load_config() return config.build_and_sample(sampling) - def build(self) -> SamplableDataset[SampleType]: + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: config = self._load_config() assert isinstance(config, SamplableDatasetConfig) - return config.build() + return config.build(preprocessing) def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index 2a23e378b..b5bc5b7de 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -4,8 +4,8 @@ import numpy as np import torch -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -38,24 +38,27 @@ def __init__( self, name: str, prefix: pathlib.Path | str, + preprocessing: LanguageModelPreprocessingConfig, ): - self._init(name, prefix) + self._init(name, prefix, preprocessing) - def _init(self, name: str, prefix: pathlib.Path | str) -> None: + def _init(self, name: str, prefix: pathlib.Path | str, preprocessing: LanguageModelPreprocessingConfig) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_preference_spans = False + has_loss_masking_spans = False + has_preference_spans = False + assert isinstance(preprocessing, LanguageModelPreprocessingConfig) + self._preprocessing = preprocessing with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version 
= struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack(" None: self._document_sizes = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) + assert not self._preprocessing.use_image_patches # read pointers self._pointers = np.frombuffer( @@ -79,8 +83,8 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: ) # read spans - self._spans = None - if self._has_spans and self._version >= 2: + if self._preprocessing.use_loss_masking_spans: + assert has_loss_masking_spans self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -101,9 +105,8 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: ) # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: + if has_preference_spans: + assert has_preference_spans self._chosen_spans = [] self._rejected_spans = [] chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes @@ -135,11 +138,12 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: self._num_tokens = div(self._bin_buffer_mmap.size, self._dtype.itemsize) - def __getstate__(self) -> tuple[str, pathlib.Path]: - return (self._name, self._prefix) + def __getstate__(self) -> tuple[str, pathlib.Path, dict]: + return self._name, self._prefix, self._preprocessing.to_dict() - def __setstate__(self, state: tuple[str, pathlib.Path]): - self._init(*state) + def __setstate__(self, state: tuple[str, pathlib.Path, dict]): + name, prefix, preprocessing = state + self._init(name, prefix, LanguageModelPreprocessingConfig.from_dict(preprocessing)) def __del__(self): if hasattr(self, "_bin_buffer_mmap"): @@ -149,9 +153,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters 
| None = None - ) -> SampleType: + def get_document(self, index: int, begin: int = 0, end: int | None = None) -> SampleType: if end is None: end = self.get_document_size(index) sample_size = self._document_sizes[index].item() @@ -169,7 +171,7 @@ def get_document( if not self._dtype.is_signed: # Needed because torch doesn't yet support type promotion between signed and unsigned types. TODO: Remove when supported. token_ids = token_ids.to(torch.int64) - if parameters is not None and parameters.use_loss_masking_spans: + if self._preprocessing.use_loss_masking_spans: assert self._spans is not None # Convert to in range format (begin, end). sample_spans = RangeSample( @@ -178,7 +180,7 @@ def get_document( else: sample_spans = None - if parameters is not None and parameters.use_preference_loss_spans: + if self._preprocessing.use_preference_spans: if not self._has_preference_spans: raise ValueError("No preference spans found in memmap dataset.") elif self._has_preference_spans and self._chosen_spans is None: diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index f1e73c595..939b900e5 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -1,39 +1,27 @@ import numpy as np import torch -from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type -class GPTRandomDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): - """ - A dummy dataset that always returns the same random sample, for debugging purposes. 
- """ - - def __init__(self, name: str): - self._name = name - - def sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset": - return GPTRandomSampledDataset(sampling, f"{self.name}_sampled") - - @property - def name(self) -> str: - return self._name - - class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed self._parameters = sampling.parameters - assert self._parameters.vocab_size is not None - # TODO: Support? - assert not self._parameters.use_loss_masking_spans - assert not self._parameters.use_preference_loss_spans - self._dtype = get_unsigned_integer_type(self._parameters.vocab_size).torch + + assert isinstance(sampling.preprocessing, LanguageModelPreprocessingConfig) + assert not sampling.preprocessing.use_loss_masking_spans + assert not sampling.preprocessing.use_preference_spans + assert not sampling.preprocessing.use_image_patches + self._vocab_size = sampling.preprocessing.vocab_size + + self._dtype = get_unsigned_integer_type(self._vocab_size).torch def __len__(self) -> int: return self._parameters.num_samples @@ -45,7 +33,7 @@ def __getitem__(self, index: int) -> SampleType: torch.from_numpy( np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( 0, - self._parameters.vocab_size, + self._vocab_size, size=(self._parameters.sequence_length + self._parameters.extra_tokens,), ) ).to(self._dtype), diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index 4b1930dd3..4d75ca450 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -7,6 +7,7 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample 
FILE_HEADER = b"fast_llm_prepared_dataset" @@ -21,13 +22,15 @@ def __init__( self, name: str, path: pathlib.Path | str, + preprocessing: PreprocessingConfig, ): - self._init(name, path) + self._init(name, path, preprocessing) - def _init(self, name: str, path: pathlib.Path | str) -> None: + def _init(self, name: str, path: pathlib.Path | str, preprocessing: PreprocessingConfig) -> None: super().__init__() self._name = name self._path = path + self._preprocessing = preprocessing with self._path.open("rb") as stream: # Very file type. @@ -39,16 +42,19 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) ) + reader_config.preprocessing.check_compatibility(self._preprocessing) + self._memmap = np.memmap(self._path, mode="r") + # TODO: ====== Forward preprocessing config so the reader reads just what we need. self._reader = reader_config.get_reader(memoryview(self._memmap)) - def __getstate__(self) -> tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]: + def __getstate__(self) -> tuple[str, pathlib.Path, dict, MemmapIndexDatasetReaderConfig]: # We pass the reader config to force its import in data loader workers. 
- return self._name, self._path, self._reader.config + return self._name, self._path, self._preprocessing.to_dict(), self._reader.config - def __setstate__(self, state: tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]): - name, path, _ = state - self._init(name, path) + def __setstate__(self, state: tuple[str, pathlib.Path, dict, MemmapIndexDatasetReaderConfig]): + name, path, preprocessing, _ = state + self._init(name, path, PreprocessingConfig.from_dict(preprocessing)) def __del__(self): if hasattr(self, "_memmap"): @@ -81,7 +87,11 @@ def get_document_size(self, index: int) -> int: @classmethod def write_dataset( - cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter] + cls, + path: pathlib.Path, + documents: typing.Iterable[Sample], + writer_class: type[MemmapWriter], + preprocessing_config: PreprocessingConfig | None = None, ) -> MemmapIndexDatasetReaderConfig: # TODO: Match `writer_class` with `SampleType`? path.parent.mkdir(parents=True, exist_ok=True) @@ -93,7 +103,7 @@ def write_dataset( start = stream.tell() stream.seek(start + 8) # Write the data. - reader_config = writer_class.write_dataset(stream, documents) + reader_config = writer_class.write_dataset(stream, documents, preprocessing_config) # Write the reader config. 
config_offset = stream.tell() reader_config_bytes = json.dumps(reader_config.to_dict()).encode("utf-8") diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 94bab200e..d0628e08f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -27,6 +27,8 @@ from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter @@ -194,6 +196,15 @@ def _prepare_shard( for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") ), LanguageModelWriter, + LanguageModelPreprocessingConfig( + tokenizer=self._config.tokenizer, + vocab_size=self._tokenizer.vocab_size, + image_patches=( + self._config.image_patches if self._source_schema.has_images else NullPreprocessingConfig() + ), + has_loss_masking_spans=self._source_schema.has_loss_masking_span, + has_preference_spans=self._source_schema.has_preference_spans, + ), ) return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config diff --git a/fast_llm/data/preprocessing/abstract.py b/fast_llm/data/preprocessing/abstract.py index 8dbaa3626..dc8c88375 100644 --- a/fast_llm/data/preprocessing/abstract.py +++ b/fast_llm/data/preprocessing/abstract.py @@ -1,7 +1,10 @@ +import logging import typing from fast_llm.config import Config, config_class +logger = logging.getLogger(__name__) + @config_class(registry=True) class 
PreprocessingConfig(Config): @@ -18,6 +21,12 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return NullPreprocessingConfig._from_dict(default, strict) return super()._from_dict(default, strict=strict) + def check_compatibility(self, preprocessing: typing.Self) -> None: + """ + Check whether a dataset preprocessed with `self` can produce samples for a model that requires `preprocessing`. + """ + raise NotImplementedError() + @config_class(dynamic_type={PreprocessingConfig: "none"}) class NullPreprocessingConfig(PreprocessingConfig): @@ -26,3 +35,6 @@ class NullPreprocessingConfig(PreprocessingConfig): """ _abstract = False + + def check_compatibility(self, preprocessing: typing.Self) -> None: + logger.warning("Dataset preprocessing config not specified, could not check compatibility with the model.") diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index 22ec04d68..7c3d9d53b 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -58,6 +58,18 @@ class ImagePatchConfig(PreprocessingConfig): hint=FieldHint.optional, ) + def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.eq(self.height, preprocessing.height) + Assert.eq(self.width, preprocessing.width) + Assert.eq(self.do_resize, preprocessing.do_resize) + Assert.leq(self.max_image_height, preprocessing.max_image_height) + Assert.leq(self.max_image_width, preprocessing.max_image_width) + # None is used in the trainer to mark unknown value, so we can't do an equality check. TODO: Distinguish. 
+ if preprocessing.image_break_token is not None: + Assert.eq(self.image_break_token, preprocessing.image_break_token) + if preprocessing.image_end_token is not None: + Assert.eq(self.image_end_token, preprocessing.image_end_token) + @property def num_channels(self) -> int: # assume 3 channels (RGB) for all images diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py new file mode 100644 index 000000000..d4e1235ae --- /dev/null +++ b/fast_llm/data/preprocessing/language_model.py @@ -0,0 +1,40 @@ +import functools +import typing + +from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.utils import Assert + + +@config_class(dynamic_type={PreprocessingConfig: "language_model"}) +class LanguageModelPreprocessingConfig(PreprocessingConfig): + tokenizer: TokenizerConfig = Field() + # We can't easily compare tokenizers, + # and in any case the tokenizer path may no longer be valid when loading a prepared dataset, + # so we provide the vocab size and use it for compatibility checks. + vocab_size: int = Field() + image_patches: PreprocessingConfig = Field() + use_loss_masking_spans: bool = Field(default=False) + use_preference_spans: bool = Field(default=False) + + def _validate(self) -> None: + super()._validate() + Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig)) + + @functools.cached_property + def use_image_patches(self) -> bool: + return isinstance(self.image_patches, ImagePatchConfig) + + def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) + # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? 
+ Assert.geq(self.vocab_size, preprocessing.vocab_size) + if preprocessing.use_loss_masking_spans: + assert self.use_loss_masking_spans + if preprocessing.use_preference_spans: + assert self.use_preference_spans + if preprocessing.use_image_patches: + assert self.use_image_patches + self.image_patches.check_compatibility(preprocessing.image_patches) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 356407541..a0d460d4c 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -1,3 +1,4 @@ +import functools import pathlib import typing @@ -69,9 +70,12 @@ def __init__(self, config: ConfigType): self.eod_id = self.tokenizer.eos_token_id self.bod_id = self.tokenizer.bos_token_id - @property + @functools.cached_property def vocab_size(self) -> int: - return len(self.tokenizer) + out = len(self.tokenizer) + if self._config.max_vocab_size is not None: + out = min(out, self._config.max_vocab_size) + return out @property def vocab(self) -> dict[str, int]: diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 0db7d1c8a..973e29ad8 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -4,7 +4,7 @@ import typing from fast_llm.config import Config, Configurable, Field, config_class -from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -195,11 +195,16 @@ def get_document_size(self, index: int) -> int: class MemmapWriter(abc.ABC): - def __init__(self, stream: io.BufferedWriter | pathlib.Path): + def __init__( + self, stream: io.BufferedWriter | pathlib.Path, preprocessing_config: PreprocessingConfig | None = None + ): self._owns_stream = isinstance(stream, pathlib.Path) if self._owns_stream: stream = stream.open("wb") self._stream = stream + 
self._preprocessing_config = ( + NullPreprocessingConfig() if preprocessing_config is None else preprocessing_config + ) def __enter__(self): self._begin = self._stream.tell() @@ -230,8 +235,13 @@ def _get_config(self, begin: int, end: int): pass @classmethod - def write_dataset(cls, stream: io.BufferedWriter, documents: typing.Iterable[Sample]) -> MemmapReaderConfig: - with cls(stream) as writer: + def write_dataset( + cls, + stream: io.BufferedWriter, + documents: typing.Iterable[Sample], + preprocessing_config: PreprocessingConfig | None = None, + ) -> MemmapReaderConfig: + with cls(stream, preprocessing_config) as writer: for document in documents: writer.write(document) return writer.get_config() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 88ca05b95..0e1baaef8 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -6,7 +6,9 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapIndexDatasetReaderConfig, @@ -135,6 +137,11 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): rejected_spans: MemmapReaderBaseConfig = Field() image_patches: MemmapReaderBaseConfig = Field() + def _validate(self) -> None: + super()._validate() + # Dynamic type supported for backward compatibility. 
+ Assert.custom(isinstance, self.preprocessing, (LanguageModelPreprocessingConfig, NullPreprocessingConfig)) + def __len__(self) -> int: return len(self.tokens) @@ -201,9 +208,7 @@ def get_document_size(self, index: int) -> int: class LanguageModelWriter(MemmapWriter): - _has_loss_masking_spans: bool | None = None - _has_preference_spans: bool | None = None - _has_image_patches: bool | None = None + _preprocessing_config: LanguageModelPreprocessingConfig def __enter__(self): super().__enter__() @@ -214,10 +219,13 @@ def __enter__(self): self._path = pathlib.Path(self._directory.name) # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times. self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__() - self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() - self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() - self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() - self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() + if self._preprocessing_config.use_loss_masking_spans: + self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() + if self._preprocessing_config.use_preference_spans: + self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() + self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() + if self._preprocessing_config.use_image_patches: + self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() return self def write(self, document: LanguageModelSample): @@ -225,58 +233,46 @@ def write(self, document: LanguageModelSample): # Write tokens. self._token_writer.write(document.tokens) - # Ensure either all samples have loss masking spans or none of them do. 
- if self._has_loss_masking_spans is None: - self._has_loss_masking_spans = document.loss_masking_spans is not None - else: - Assert.eq(self._has_loss_masking_spans, document.loss_masking_spans is not None) - # Write loss masking spans. - if self._has_loss_masking_spans: + if self._preprocessing_config.use_loss_masking_spans: + assert document.loss_masking_spans is not None self._loss_masking_span_writer.write(document.loss_masking_spans) - # All sample must either have both chosen and rejected spans, or neither. - if self._has_preference_spans is None: - self._has_preference_spans = document.chosen_spans is not None - else: - Assert.eq(self._has_preference_spans, document.chosen_spans is not None) - Assert.eq(self._has_preference_spans, document.rejected_spans is not None) - # Write preference spans. - if self._has_preference_spans: + if self._preprocessing_config.use_preference_spans: + assert document.chosen_spans is not None + assert document.rejected_spans is not None self._chosen_spans_writer.write(document.chosen_spans) self._rejected_spans_writer.write(document.rejected_spans) - # Ensure either all samples have image patches or none of them do. 
- if self._has_image_patches is None: - self._has_image_patches = document.image_patches is not None - else: - Assert.eq(self._has_image_patches, document.image_patches is not None) - # Write image patches - if self._has_image_patches: + if self._preprocessing_config.use_image_patches: + assert document.image_patches is not None self._image_patches_writer.write(document.image_patches) def __exit__(self, exc_type, exc_val, exc_tb): self._token_writer.__exit__(exc_type, exc_val, exc_tb) - self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) - self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) - self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) - self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_loss_masking_spans: + self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_preference_spans: + self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) + self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_image_patches: + self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) if exc_type is None: # A dummy config so we can verify the begin and end offsets. 
config = self._get_config(self._begin, None) _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) - if self._has_loss_masking_spans: + if self._preprocessing_config.use_loss_masking_spans: _copy_chunked( self._path.joinpath("loss_masking_spans"), self._stream, config.loss_masking_spans.begin, config.loss_masking_spans.end, ) - if self._has_preference_spans: + if self._preprocessing_config.use_preference_spans: _copy_chunked( self._path.joinpath("chosen_spans"), self._stream, @@ -290,7 +286,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): config.rejected_spans.end, ) - if self._has_image_patches: + if self._preprocessing_config.use_image_patches: _copy_chunked( self._path.joinpath("image_patches"), self._stream, @@ -308,12 +304,12 @@ def _get_config_class(cls) -> type[LanguageModelReaderConfig]: def _get_config(self, begin: int, end: int | None): tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) offset = tokens.end - if self._has_loss_masking_spans: + if self._preprocessing_config.use_loss_masking_spans: loss_masking_spans = self._loss_masking_span_writer.get_config(offset) offset = loss_masking_spans.end else: loss_masking_spans = NullReaderConfig() - if self._has_preference_spans: + if self._preprocessing_config.use_preference_spans: chosen_spans = self._chosen_spans_writer.get_config(offset) offset = chosen_spans.end rejected_spans = self._rejected_spans_writer.get_config(offset) @@ -321,7 +317,7 @@ def _get_config(self, begin: int, end: int | None): else: chosen_spans = NullReaderConfig() rejected_spans = NullReaderConfig() - if self._has_image_patches: + if self._preprocessing_config.use_image_patches: image_patches = self._image_patches_writer.get_config(offset) offset = image_patches.end else: @@ -338,6 +334,7 @@ def _get_config(self, begin: int, end: int | None): chosen_spans=chosen_spans, rejected_spans=rejected_spans, image_patches=image_patches, + 
preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 9d27d37cd..a75684d76 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -300,4 +300,5 @@ def _get_config(self, begin: int, end: int): num_patch_groups=self._group_count_cumsum[-1], patch_shape=self._patch_shape, data_type=DataType.from_torch(self._data_type), + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index b7be4efe1..0022b3593 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -135,4 +135,5 @@ def _get_config(self, begin: int, end: int): end=end, num_documents=len(self._count_cumsum) - 1, num_ranges=self._count_cumsum[-1], + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 9fedf12b5..3f5912e5e 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -166,4 +166,5 @@ def _get_config(self, begin: int, end: int): num_documents=len(self._size_cumsum) - 1, num_tokens=self._size_cumsum[-1], data_type=DataType.from_torch(self._data_type), + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index aa4f2d570..dd106f35c 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -12,6 +12,7 @@ from fast_llm.core.distributed import allreduce_scalar, safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import 
Distributed @@ -239,6 +240,7 @@ def setup(self, distributed: Distributed, run: Run) -> None: ) for eval_sampling_params in self._evaluator_runner.get_sampling_parameters() }, + self._get_preprocessing_config(), None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, ) @@ -261,10 +263,13 @@ def _get_data(self) -> Data: pass def _get_sampling_parameters( - self, parameters: dict[str, typing.Any], _return_dict: bool = False + self, parameters: dict[str, typing.Any], *, _return_dict: bool = False ) -> SamplingParameters | dict[str, typing.Any]: return parameters if _return_dict else SamplingParameters(**parameters) + def _get_preprocessing_config(self, *, _return_dict: bool = False) -> PreprocessingConfig | dict[str, typing.Any]: + return {} if _return_dict else NullPreprocessingConfig() + @property def _consumed_samples(self) -> int: assert self._is_setup diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b8fb22ebb..1c7be33dd 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -3,6 +3,7 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -17,18 +18,30 @@ def _get_data(self) -> GPTData: ) def _get_sampling_parameters( - self, parameters: dict[str, typing.Any], _return_dict: bool = False + self, parameters: dict[str, typing.Any], *, _return_dict: bool = False ) -> GPTSamplingParameters | dict[str, typing.Any]: parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - "vocab_size": self._config.model.base_model.embeddings.vocab_size, + # "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": 
self._config.batch.sequence_length, - "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + # "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, # OK since DPO is not supported for MTP. - "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + # "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) + + def _get_preprocessing_config( + self, *, _return_dict: bool = False + ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = { + "type": "language_model", + "vocab_size": self._config.model.base_model.embeddings.vocab_size, + "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + # OK since DPO is not supported for MTP. + "use_preference_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + } + return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py index 2beee1097..cd8e09cae 100644 --- a/fast_llm/models/multimodal/trainer.py +++ b/fast_llm/models/multimodal/trainer.py @@ -1,5 +1,7 @@ import logging +import typing +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.models.gpt.trainer import GPTTrainer from fast_llm.models.multimodal.config import MultiModalTrainerConfig @@ -7,4 +9,17 @@ class MultiModalTrainer[ConfigType: MultiModalTrainerConfig](GPTTrainer[ConfigType]): - pass + def _get_preprocessing_config( + self, *, _return_dict: bool = False + ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = super()._get_preprocessing_config(_return_dict=True) + out["image_patches"] = { + "height": 
self._config.model.base_model.vision_encoder.embeddings.patch_height, + "width": self._config.model.base_model.vision_encoder.embeddings.patch_width, + # TODO: Max shape and special tokens are unspecified in the model. + "max_image_height": 2**32, + "max_image_width": 2**32, + "image_break_token": None, + "image_end_token": None, + } + return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index f3ce65966..e2cadf717 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -15,6 +15,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -122,8 +123,8 @@ class MegatronDatasetConfig[SampleType: LanguageModelSample](MemmapDatasetConfig hint=FieldHint.core, ) - def build(self) -> "LegacyMemmapDataset[SampleType]": - return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path) + def build(self, preprocessing: PreprocessingConfig) -> "LegacyMemmapDataset[SampleType]": + return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path, preprocessing) class MegatronMemmapDataset(LegacyMemmapDataset): From 350fb3df7f877549039c4e2d1ffda7fbf9f03e76 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Dec 2025 00:30:55 -0500 Subject: [PATCH 4/8] stuff --- fast_llm/data/data/gpt/data.py | 7 +- fast_llm/data/dataset/gpt/config.py | 20 +-- fast_llm/data/dataset/gpt/fim.py | 4 +- fast_llm/data/dataset/gpt/legacy_memmap.py | 19 ++- fast_llm/data/dataset/memmap.py | 5 +- 
.../data/preparator/gpt_memmap/prepare.py | 34 +++-- fast_llm/data/preprocessing/abstract.py | 4 +- fast_llm/data/preprocessing/image_patch.py | 5 + fast_llm/data/preprocessing/language_model.py | 15 +- fast_llm/data/preprocessing/tokenizer.py | 22 +-- fast_llm/data/sample/abstract.py | 17 +-- fast_llm/data/sample/language_model.py | 136 +++++++++++++++--- fast_llm/data/sample/patch.py | 27 +++- fast_llm/data/sample/range.py | 12 +- fast_llm/data/sample/token.py | 7 +- fast_llm/models/gpt/trainer.py | 10 +- tests/data/common.py | 21 ++- tests/data/test_blending.py | 18 ++- tests/data/test_concatenate.py | 8 +- tests/data/test_fim.py | 5 +- tests/data/test_image_patch.py | 13 +- tests/data/test_loss_masking_spans.py | 17 ++- tests/data/test_preference_spans.py | 17 ++- tests/data/test_preparator.py | 25 ++-- tests/data/test_random.py | 6 +- tests/data/test_sampling.py | 23 ++- tests/data/test_slice.py | 7 +- tests/models/test_match_megatron.py | 2 +- tests/utils/dataset.py | 53 +++++-- tests/utils/model_configs.py | 2 +- 30 files changed, 364 insertions(+), 197 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 084dadc7d..dbd770895 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -10,7 +10,8 @@ from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig @@ -30,7 +31,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ _datasets: dict[str, SampledDataset] - 
_sampling_parameters: dict[str, GPTSamplingParameters] + _sampling_parameters: dict[str, SamplingParameters] _is_setup: bool = False def __init__( @@ -47,7 +48,7 @@ def __init__( def setup( self, distributed: "Distributed", - sampling_parameters: dict[str, GPTSamplingParameters], + sampling_parameters: dict[str, SamplingParameters], preprocessing: LanguageModelPreprocessingConfig, cache_directory: pathlib.Path, timeout: float | None = None, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 4336657ce..fc326d366 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -7,31 +7,18 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset - from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample -@dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): - """ - Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. 
- """ - - # TODO: ====== Get these to memmap dataset (currently ignored) ====== - vocab_size: int | None = None - use_loss_masking_spans: bool = False - use_preference_loss_spans: bool = False - use_images: bool = False - - @dataclasses.dataclass(kw_only=True) class GPTSamplingData(SamplingData): """ @@ -39,7 +26,6 @@ class GPTSamplingData(SamplingData): usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. """ - parameters: GPTSamplingParameters preprocessing: LanguageModelPreprocessingConfig @@ -52,7 +38,7 @@ class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConf hint=FieldHint.core, ) - def build_and_sample(self, sampling: GPTSamplingData) -> GPTRandomSampledDataset[SampleType]: + def build_and_sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset return GPTRandomSampledDataset[SampleType](sampling, self.name) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index d36384ee5..b70fc8360 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -21,9 +21,9 @@ def __init__( dataset: SampledDataset[SampleType], sampling: GPTSamplingData, ): - if sampling.parameters.use_loss_masking_spans: + if sampling.preprocessing.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") - if sampling.parameters.use_preference_loss_spans: + if sampling.preprocessing.use_preference_spans: raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index b5bc5b7de..d29e31596 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -105,7 +105,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, 
preprocessing: LanguageMo ) # read preference spans - if has_preference_spans: + if self._preprocessing.use_preference_spans: assert has_preference_spans self._chosen_spans = [] self._rejected_spans = [] @@ -173,20 +173,17 @@ def get_document(self, index: int, begin: int = 0, end: int | None = None) -> Sa token_ids = token_ids.to(torch.int64) if self._preprocessing.use_loss_masking_spans: assert self._spans is not None - # Convert to in range format (begin, end). - sample_spans = RangeSample( - [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size - ).crop(begin, end) + if hasattr(self, "_spans"): + # Convert to in range format (begin, end). + sample_spans = RangeSample( + [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size + ).crop(begin, end) + else: + sample_spans = RangeSample([], end - begin) else: sample_spans = None if self._preprocessing.use_preference_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") # Convert to in range format (begin, end). 
chosen_spans = RangeSample( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index 4d75ca450..f80a48b0a 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -42,11 +42,8 @@ def _init(self, name: str, path: pathlib.Path | str, preprocessing: Preprocessin json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) ) - reader_config.preprocessing.check_compatibility(self._preprocessing) - self._memmap = np.memmap(self._path, mode="r") - # TODO: ====== Forward preprocessing config so the reader reads just what we need. - self._reader = reader_config.get_reader(memoryview(self._memmap)) + self._reader = reader_config.get_reader(memoryview(self._memmap), self._preprocessing) def __getstate__(self) -> tuple[str, pathlib.Path, dict, MemmapIndexDatasetReaderConfig]: # We pass the reader config to force its import in data loader workers. 
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d0628e08f..91506e4d5 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,5 +1,6 @@ import collections import enum +import functools import json import logging import math @@ -196,18 +197,22 @@ def _prepare_shard( for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") ), LanguageModelWriter, - LanguageModelPreprocessingConfig( - tokenizer=self._config.tokenizer, - vocab_size=self._tokenizer.vocab_size, - image_patches=( - self._config.image_patches if self._source_schema.has_images else NullPreprocessingConfig() - ), - has_loss_masking_spans=self._source_schema.has_loss_masking_span, - has_preference_spans=self._source_schema.has_preference_spans, - ), + self._preprocessing_config, ) return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config + @functools.cached_property + def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: + return LanguageModelPreprocessingConfig( + tokenizer=self._config.tokenizer, + vocab_size=self._tokenizer.vocab_size, + image_patches=( + self._config.image_patches if self._source_schema.has_images else NullPreprocessingConfig() + ), + use_loss_masking_spans=self._source_schema.has_loss_masking_span, + use_preference_spans=self._source_schema.has_preference_spans, + ) + def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: text = sample[self._source_schema.text] all_spans = [] @@ -385,9 +390,8 @@ def _blend_dataset_configs( } ) - @classmethod def _split_and_blend_dataset_configs( - cls, + self, dataset_configs: list[MemmapDatasetConfig[_sample_type]], reader_configs: list[MemmapIndexDatasetReaderConfig], splits: dict[str, int | float], @@ -422,14 +426,16 @@ def _split_and_blend_dataset_configs( elif split_end_in_dataset > split_begin_in_dataset: # Part of the 
dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). - dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() + dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build( + self._preprocessing_config + ) sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() Assert.eq(sizes_cumsum[-1], reader_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens) if end_index > begin_index: datasets_in_split.append( - DatasetSliceConfig[cls._sample_type].from_dict( + DatasetSliceConfig[self._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], @@ -451,7 +457,7 @@ def _split_and_blend_dataset_configs( elif len(datasets_in_split) == 1: dataset_splits[split_name] = datasets_in_split[0] else: - dataset_splits[split_name] = BlendedDatasetConfig[cls._sample_type].from_dict( + dataset_splits[split_name] = BlendedDatasetConfig[self._sample_type].from_dict( { "type": "blended", "datasets": datasets_in_split, diff --git a/fast_llm/data/preprocessing/abstract.py b/fast_llm/data/preprocessing/abstract.py index dc8c88375..ea1f910df 100644 --- a/fast_llm/data/preprocessing/abstract.py +++ b/fast_llm/data/preprocessing/abstract.py @@ -1,5 +1,6 @@ import logging import typing +import warnings from fast_llm.config import Config, config_class @@ -37,4 +38,5 @@ class NullPreprocessingConfig(PreprocessingConfig): _abstract = False def check_compatibility(self, preprocessing: typing.Self) -> None: - logger.warning("Dataset preprocessing config not specified, could not check compatibility with the model.") + if not isinstance(preprocessing, NullPreprocessingConfig): + warnings.warn(f"Preprocessing configuration not specified, could not check compatibility with the model.") diff --git 
a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index 7c3d9d53b..146c5809b 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -59,6 +59,7 @@ class ImagePatchConfig(PreprocessingConfig): ) def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.custom(isinstance, preprocessing, ImagePatchConfig) Assert.eq(self.height, preprocessing.height) Assert.eq(self.width, preprocessing.width) Assert.eq(self.do_resize, preprocessing.do_resize) @@ -75,6 +76,10 @@ def num_channels(self) -> int: # assume 3 channels (RGB) for all images return 3 + @functools.cached_property + def patch_shape(self) -> tuple[int, int, int]: + return self.num_channels, self.height, self.width + @functools.cached_property def max_patches_height(self) -> int: return div(self.max_image_height, self.height) diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index d4e1235ae..6c38c3f4e 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -1,4 +1,5 @@ import functools +import logging import typing from fast_llm.config import Field, config_class @@ -7,21 +8,25 @@ from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + @config_class(dynamic_type={PreprocessingConfig: "language_model"}) class LanguageModelPreprocessingConfig(PreprocessingConfig): - tokenizer: TokenizerConfig = Field() + _abstract = False + tokenizer: PreprocessingConfig = Field() # We can't easily compare tokenizers, # and in any case the tokenizer path may no longer be valid when loading a prepared dataset, # so we provide the vocab size and use it for compatibility checks. 
- vocab_size: int = Field() image_patches: PreprocessingConfig = Field() + vocab_size: int = Field() use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) def _validate(self) -> None: super()._validate() Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig)) + Assert.custom(isinstance, self.tokenizer, (TokenizerConfig, NullPreprocessingConfig)) @functools.cached_property def use_image_patches(self) -> bool: @@ -31,10 +36,8 @@ def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? Assert.geq(self.vocab_size, preprocessing.vocab_size) - if preprocessing.use_loss_masking_spans: - assert self.use_loss_masking_spans if preprocessing.use_preference_spans: + # Preference spans are strictly needed for DPO loss. assert self.use_preference_spans - if preprocessing.use_image_patches: - assert self.use_image_patches + if preprocessing.use_image_patches and self.use_image_patches: self.image_patches.check_compatibility(preprocessing.image_patches) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index a0d460d4c..9e11fa66c 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -39,8 +39,6 @@ class TokenizerConfig(PreprocessingConfig): ) def get_tokenizer(self) -> "Tokenizer": - from fast_llm.data.preprocessing.tokenizer import Tokenizer - return Tokenizer(self) @@ -90,14 +88,20 @@ def tokenize( ) -> "torch.Tensor": import torch - tokens = torch.tensor( - ([self.bod_id] if begin else []) - + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.eod_id] if end else []), - dtype=data_type.torch, - ) + tokens = self.tokenizer.encode(text, add_special_tokens=False) + if begin: + tokens.insert(0, self.bod_id) + if end: + 
tokens.append(self.eod_id) + if self._config.max_vocab_size is not None: - tokens %= self._config.max_vocab_size + # In some cases creating a tensor before restricting the vocab size may cause an overflow. + tokens = ( + torch.tensor(tokens, dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch) + % self._config.max_vocab_size + ).to(data_type.torch) + else: + tokens = torch.tensor(tokens, dtype=data_type.torch) return tokens def tokenize_with_spans( diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 973e29ad8..3fba789d1 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -109,8 +109,8 @@ class MemmapReaderConfig(MemmapReaderBaseConfig): def reader_class(self) -> "type[MemmapReader]": raise NotImplementedError() - def get_reader(self, buffer: memoryview) -> "MemmapReader": - return self.reader_class(self, buffer) + def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None) -> "MemmapReader": + return self.reader_class(self, buffer, model_preprocessing) @property def expected_buffer_size(self) -> int: @@ -156,16 +156,17 @@ def num_tokens(self) -> int: def reader_class(self) -> "type[MemmapIndexedDatasetReader]": raise NotImplementedError() - def get_reader( - self, - buffer: memoryview, - ) -> "MemmapIndexedDatasetReader": - return self.reader_class(self, buffer) + def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer, model_preprocessing) class MemmapReader[ConfigType: MemmapReaderConfig](Configurable[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): super().__init__(config) + # Note: This is the requirement at reading time (ex.
from the model), + # which may differ from how the dataset was actually preprocessed (`config.preprocessing`) + # Compatibility checked in `MemmapDataset`. + self._model_preprocessing = NullPreprocessingConfig if model_preprocessing is None else model_preprocessing buffer_begin = self._config.begin + len(self._config.header) buffer_end = self._config.end - len(self._config.footer) Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 0e1baaef8..1331cf82a 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -1,13 +1,15 @@ import io +import logging import pathlib import tempfile import typing +import warnings import torch from fast_llm.config import Field, config_class from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig -from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig +from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig, ImagePatchConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, @@ -18,11 +20,14 @@ NullReaderConfig, Sample, ) -from fast_llm.data.sample.patch import PatchBatch, PatchSample, PatchWriter -from fast_llm.data.sample.range import RangeBatch, RangeSample, RangeWriter +from fast_llm.data.sample.patch import EmptyPatchReader, PatchBatch, PatchReaderConfig, PatchSample, PatchWriter +from fast_llm.data.sample.range import EmptyRangeReader, RangeBatch, RangeReaderConfig, RangeSample, RangeWriter from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class LanguageModelSample(Sample): def __init__( @@ -139,8 +144,45 @@ class 
LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): def _validate(self) -> None: super()._validate() - # Dynamic type supported for backward compatibility. - Assert.custom(isinstance, self.preprocessing, (LanguageModelPreprocessingConfig, NullPreprocessingConfig)) + if isinstance(self.preprocessing, NullPreprocessingConfig): + # Address missing config, mostly for backward compatibility. + # TODO: We can't tell which dataset this comes from. + logger.warning( + f"Preprocessing configuration not specified for dataset reader, generating partial configuration from known parameters." + ) + if isinstance(self.image_patches, PatchReaderConfig): + Assert.eq(len(patch_shape := self.image_patches.patch_shape), 3) + image_patches = ImagePatchConfig(height=patch_shape[1], width=patch_shape[2]) + else: + image_patches = NullPreprocessingConfig() + self.preprocessing = LanguageModelPreprocessingConfig( + vocab_size=0, + image_patches=image_patches, + use_loss_masking_spans=isinstance(self.loss_masking_spans, RangeReaderConfig), + use_preference_spans=isinstance(self.chosen_spans, RangeReaderConfig), + ) + # TODO: Avoid duplicated information. 
+ Assert.custom( + isinstance, + self.loss_masking_spans, + RangeReaderConfig if self.preprocessing.use_loss_masking_spans else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.chosen_spans, + RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.rejected_spans, + RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, + ) + if self.preprocessing.use_image_patches: + Assert.custom(isinstance, self.image_patches, PatchReaderConfig) + Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape) + Assert.eq(self.image_patches.data_type, DataType.uint8) + else: + Assert.custom(isinstance, self.image_patches, NullReaderConfig) def __len__(self) -> int: return len(self.tokens) @@ -169,17 +211,59 @@ def _expected_buffer_size(self) -> int: class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + _model_preprocessing: LanguageModelPreprocessingConfig + + def __init__( + self, + config: ConfigType, + buffer: memoryview, + model_preprocessing: LanguageModelPreprocessingConfig | None = None, + ): + super().__init__(config, buffer, model_preprocessing) + self._config.preprocessing.check_compatibility(self._model_preprocessing) # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. self._tokens = self._config.tokens.get_reader(buffer) - self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) - self._chosen_spans = self._config.chosen_spans.get_reader(buffer) - self._rejected_spans = self._config.rejected_spans.get_reader(buffer) - self._image_patches = self._config.image_patches.get_reader(buffer) - if self._image_patches is not None: - # TODO: Make this configurable. 
+ if self._model_preprocessing.use_loss_masking_spans: + if isinstance(self._config.loss_masking_spans, NullReaderConfig): + # TODO: We can't tell which dataset this comes from. + warnings.warn( + f"The model uses loss masking spans, but the dataset does not specify any." + " Assuming empty span lists." + ) + self._loss_masking_spans = EmptyRangeReader( + RangeReaderConfig(begin=0, end=0, num_documents=0, num_ranges=0), buffer + ) + else: + self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) + + if self._model_preprocessing.use_preference_spans: + self._chosen_spans = self._config.chosen_spans.get_reader(buffer) + self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + + if self._model_preprocessing.use_image_patches: + model_image_preprocessing: ImagePatchConfig = self._model_preprocessing.image_patches + if isinstance(self._config.image_patches, NullReaderConfig): + warnings.warn( + f"The model uses image patches, but the dataset does not specify any." + " Assuming empty patch lists." + ) + self._image_patches = EmptyPatchReader( + PatchReaderConfig( + begin=0, + end=0, + num_documents=0, + num_patches=0, + num_patch_groups=0, + patch_shape=model_image_preprocessing.patch_shape, + data_type=DataType.uint8, + ), + buffer, + ) + else: + self._image_patches = self._config.image_patches.get_reader(buffer) + + # TODO: Make this configurable. (Add to `model_preprocessing`?) 
self._image_normalization_config = ImageNormalizationConfig() @property @@ -187,16 +271,28 @@ def num_tokens(self) -> int: return self._config.tokens.num_tokens def get_document(self, index: int, begin: int, end: int) -> Sample: - if self._image_patches is None: - image_patches = None - else: + if self._model_preprocessing.use_image_patches: image_patches = self._image_patches.get_document(index, begin, end) image_patches.patches = self._image_normalization_config.normalize(image_patches.patches) + else: + image_patches = None return LanguageModelSample( self._tokens.get_document(index, begin, end), - None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end), - None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end), - None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end), + ( + self._loss_masking_spans.get_document(index, begin, end) + if self._model_preprocessing.use_loss_masking_spans + else None + ), + ( + self._chosen_spans.get_document(index, begin, end) + if self._model_preprocessing.use_preference_spans + else None + ), + ( + self._rejected_spans.get_document(index, begin, end) + if self._model_preprocessing.use_preference_spans + else None + ), image_patches, ) @@ -334,7 +430,7 @@ def _get_config(self, begin: int, end: int | None): chosen_spans=chosen_spans, rejected_spans=rejected_spans, image_patches=image_patches, - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index a75684d76..9ec991cf0 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -5,6 +5,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapReader, @@ -91,6 +92,16 @@ def get_padding(self, 
size: int) -> typing.Self: [], ) + # NOTE: must be an instance method (body reads `self.patches` etc.), not a classmethod. + def get_empty(self, size: int, shape: tuple[int, ...]) -> typing.Self: + return PatchSample( + self.patches.new_empty((0, *shape[1:])), + self.token_map.new_empty(0), + self.positions.new_empty([0, len(shape) - 2]), + size, + [], + ) + class PatchBatch(Batch): def __init__( @@ -188,8 +199,8 @@ def _expected_buffer_size(self) -> int: class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) self._patches = torch.frombuffer( self._buffer, dtype=self._config.data_type.torch, @@ -248,6 +259,16 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: ) + class EmptyPatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + return PatchSample( + torch.empty(0, *self._config.patch_shape, dtype=self._config.data_type.torch), + torch.empty(0, dtype=torch.int32), + torch.empty(0, self._config.grid_dims, dtype=torch.int32), + end - begin, + ) + + class PatchWriter(MemmapWriter): def __enter__(self): super().__enter__() @@ -300,5 +321,5 @@ def _get_config(self, begin: int, end: int): num_patch_groups=self._group_count_cumsum[-1], patch_shape=self._patch_shape, data_type=DataType.from_torch(self._data_type), - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 0022b3593..f34cc1343 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -4,6 +4,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from
fast_llm.data.sample.abstract import ( Batch, MemmapReader, @@ -85,8 +86,8 @@ def _expected_buffer_size(self) -> int: class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) self._ranges = torch.frombuffer( self._buffer, dtype=torch.int32, @@ -108,6 +109,11 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) +class EmptyRangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + return RangeSample([], end - begin) + + class RangeWriter(MemmapWriter): def __enter__(self): super().__enter__() @@ -135,5 +141,5 @@ def _get_config(self, begin: int, end: int): end=end, num_documents=len(self._count_cumsum) - 1, num_ranges=self._count_cumsum[-1], - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 3f5912e5e..04898a12f 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -4,6 +4,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapIndexedDatasetReader, @@ -111,8 +112,8 @@ def _expected_buffer_size(self) -> int: class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | 
None = None): + super().__init__(config, buffer, model_preprocessing) self._tokens = torch.frombuffer( self._buffer, dtype=self._config.data_type.torch, @@ -166,5 +167,5 @@ def _get_config(self, begin: int, end: int): num_documents=len(self._size_cumsum) - 1, num_tokens=self._size_cumsum[-1], data_type=DataType.from_torch(self._data_type), - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 1c7be33dd..768d3fdd7 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -2,7 +2,7 @@ import typing from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -19,20 +19,16 @@ def _get_data(self) -> GPTData: def _get_sampling_parameters( self, parameters: dict[str, typing.Any], *, _return_dict: bool = False - ) -> GPTSamplingParameters | dict[str, typing.Any]: + ) -> SamplingParameters | dict[str, typing.Any]: parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - # "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, - # "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - # OK since DPO is not supported for MTP. 
- # "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) - return parameters if _return_dict else GPTSamplingParameters(**parameters) + return parameters if _return_dict else SamplingParameters(**parameters) def _get_preprocessing_config( self, *, _return_dict: bool = False diff --git a/tests/data/common.py b/tests/data/common.py index ac8d8023c..210749864 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,10 +8,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, ShufflingType -from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters +from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, SamplingParameters, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig @@ -25,10 +26,10 @@ def get_sampling_data( cache_directory: pathlib.Path | None = None, phase=PhaseType.training, sequence_length: int = 512, - vocab_size: int | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, + preprocessing: LanguageModelPreprocessingConfig, ) -> GPTSamplingData: # Config with convenient defaults. 
distributed = Distributed(DistributedConfig(), use_cpu=True) @@ -38,12 +39,12 @@ def get_sampling_data( gpu=gpu, shuffle=shuffle, ), - parameters=GPTSamplingParameters( + parameters=SamplingParameters( num_samples=num_samples, sequence_length=sequence_length, - vocab_size=vocab_size, truncate_documents=truncate_documents, ), + preprocessing=preprocessing, cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, @@ -65,8 +66,8 @@ def get_test_data_and_compare_samples( shuffle: ShufflingType = ShufflingType.epoch, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, - vocab_size: int | None = None, expected_samples: dict[str, list[list[int]]] | list[list[int]], + preprocessing: LanguageModelPreprocessingConfig, ) -> GPTData: distributed_config = DistributedConfig(seed=87522) distributed = Distributed(distributed_config, use_cpu=True) @@ -74,11 +75,7 @@ def get_test_data_and_compare_samples( samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} sampling_parameters = { - dataset_name: GPTSamplingParameters( - num_samples=num_samples, - sequence_length=sequence_length, - vocab_size=vocab_size, - ) + dataset_name: SamplingParameters(num_samples=num_samples, sequence_length=sequence_length) for dataset_name, num_samples in samples_per_dataset.items() } @@ -88,7 +85,7 @@ def get_test_data_and_compare_samples( assert "sampling" not in config config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) - data.setup(distributed, sampling_parameters, cache_directory) + data.setup(distributed, sampling_parameters, preprocessing, cache_directory) with NoAutoValidate(): batch_config = GPTBatchConfig(batch_size=1, sequence_length=sequence_length) batch_config.setup(distributed_config) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 88ecf2c99..5cad573ca 100644 --- a/tests/data/test_blending.py +++ 
b/tests/data/test_blending.py @@ -4,6 +4,7 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( @@ -84,7 +85,7 @@ def test_blending(probs): # Use a list of integers as a mock dataset, encoding both indexes in the sample. [list(range(i * num_samples, (i + 1) * num_samples)) for i, _ in enumerate(probs)], # noqa probs, - get_sampling_data(num_samples), + get_sampling_data(num_samples, preprocessing=LanguageModelPreprocessingConfig(vocab_size=8192)), ) probs = normalize_probabilities(probs) samples = np.array([dataset[i] for i in range(num_samples)]) @@ -106,8 +107,8 @@ def test_blending(probs): def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() - _, alt_config, _ = get_alt_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() + _, alt_config, _, _ = get_alt_test_dataset() sampled = get_dataset_config( dataset_config := { "type": "blended", @@ -115,7 +116,7 @@ def test_gpt_blended(): "weights": [0.75, 0.25], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) # Test in data. @@ -124,12 +125,15 @@ def test_gpt_blended(): 8, sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, + preprocessing=preprocessing, ) def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() + # Random dataset needs an explicit vocab size. 
+ preprocessing = preprocessing.from_dict(preprocessing, {"vocab_size": 8192}) sampled = get_dataset_config( dataset_config := { "type": "blended", @@ -140,7 +144,7 @@ def test_gpt_blended_mixed(): "weights": [0.6, 0.4], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) # Test in data. @@ -148,6 +152,6 @@ def test_gpt_blended_mixed(): {"datasets": {"training": dataset_config}}, 8, sequence_length=5, - vocab_size=8192, expected_samples=GPT_BLENDED_MIXED_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index d7e750c8b..1580842b7 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,6 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset_tokens, @@ -25,19 +26,19 @@ def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. 
- _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() dataset = get_dataset_config( dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], - ).build() + ).build(LanguageModelPreprocessingConfig(vocab_size=0)) compare_indexed_dataset_tokens( dataset, 3 * COMMON_DATASET_LENGTH, 3 * COMMON_DATASET_TOKENS, {j * COMMON_DATASET_LENGTH + i: sample for j in range(3) for i, sample in COMMON_DATASET_SAMPLES.items()}, ) - sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) + sampled = dataset.sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_CONCATENATED_SAMPLES) # Test in data. @@ -46,4 +47,5 @@ def test_gpt_concatenate(): 8, sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 0600c5258..fd1aefbd8 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -22,9 +22,9 @@ def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. 
- sampling_config = get_sampling_data(8, sequence_length=5) + sampling_config = get_sampling_data(8, sequence_length=5, preprocessing=preprocessing) sampled = get_dataset_config( dataset_config := { "type": "fim", @@ -45,4 +45,5 @@ def test_gpt_fim(): 8, sequence_length=5, expected_samples=GPT_FIM_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_image_patch.py b/tests/data/test_image_patch.py index 86fe9c70a..9ef20a8a6 100644 --- a/tests/data/test_image_patch.py +++ b/tests/data/test_image_patch.py @@ -6,7 +6,8 @@ import PIL.Image import pytest -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert @@ -123,8 +124,10 @@ def _get_image_tokens( @pytest.mark.parametrize("image_break_token", (None, 55)) @pytest.mark.parametrize("image_end_token", (None, 132)) def test_gpt_data_with_image_patches(image_break_token, image_end_token): - _, config, hf_path = get_test_dataset_with_image_patches(image_break_token, image_end_token) - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + _, config, hf_path, preprocessing = get_test_dataset_with_image_patches(image_break_token, image_end_token) + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + preprocessing + ) test_index = 2 * (image_break_token is not None) + (image_end_token is not None) hf_dataset = datasets.load_from_disk(hf_path)["train"] @@ -146,9 +149,7 @@ def test_gpt_data_with_image_patches(image_break_token, image_end_token): ) Assert.eq(hf_dataset[index]["image_positions"], DATASET_WITH_IMAGE_PATCHES_IMAGE_POSITIONS[index]) - document = dataset.get_document( - 
index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_images=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) expected_tokens = [ tokens for token_or_patches in DATASET_WITH_IMAGE_PATCHES_SAMPLES[index] diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index 443a26819..2d112a5c1 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -1,7 +1,8 @@ import datasets import pytest -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample @@ -37,8 +38,10 @@ @pytest.mark.slow def test_gpt_data_with_spans(): - _, config, hf_path = get_test_dataset_with_loss_masking_spans() - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + _, config, hf_path, preprocessing = get_test_dataset_with_loss_masking_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + preprocessing + ) hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() @@ -54,9 +57,7 @@ def test_gpt_data_with_spans(): hf_dataset[index]["text"], text_spans=[(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]], ) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) # Compare tokens and token spans. 
Assert.all_equal(document.tokens.tokens, expected_tokens) @@ -73,8 +74,6 @@ def test_gpt_data_with_spans(): for index in DATASET_WITH_SPAN_SAMPLES: Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) Assert.eq(hf_dataset[index]["loss_masking_spans"], HF_LOSS_MASKING_SPANS[index]) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_SPAN_SAMPLES[index]) Assert.eq(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py index ef18337eb..35c290670 100644 --- a/tests/data/test_preference_spans.py +++ b/tests/data/test_preference_spans.py @@ -3,7 +3,8 @@ import pytest import torch -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample @@ -39,8 +40,10 @@ @pytest.mark.slow def test_gpt_data_with_spans(): - _, config, hf_path = get_test_dataset_with_preference_spans() - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + _, config, hf_path, preprocessing = get_test_dataset_with_preference_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + preprocessing + ) hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() @@ -79,9 +82,7 @@ def test_gpt_data_with_spans(): 
(token_length_cumsum[4], token_length_cumsum[5]), ] - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) token_spans = document.chosen_spans.ranges + document.rejected_spans.ranges # Compare tokens and token spans. @@ -100,8 +101,6 @@ def test_gpt_data_with_spans(): DATASET_WITH_PREFERENCE_SPAN_TEXT[index], ) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_PREFERENCE_SPAN_SAMPLES[index]) Assert.eq(document.chosen_spans.ranges + document.rejected_spans.ranges, TOKEN_PREFERENCE_SPANS[index]) diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 729888d9c..dd4375418 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -3,10 +3,11 @@ import datasets import pytest -from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert from tests.data.common import get_dataset_config @@ -42,11 +43,11 @@ def test_common_prepared_dataset(): We already test the dataset 
preparator indirectly through the test dataset (`get_test_dataset`). Here we verify the correctness of the prepared dataset directly and check for regressions. """ - path, config, hf_path = get_common_test_dataset() - dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build() + path, config, hf_path, preprocessing = get_common_test_dataset() + dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build(preprocessing) dataset_from_shard = get_dataset_config( {"type": "memmap", "path": path / "shard_0_0.fast_llm_dataset"}, MemmapDatasetConfig - ).build() + ).build(preprocessing) hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() @@ -71,18 +72,18 @@ def test_common_prepared_dataset(): # Check some numerical values. for index in COMMON_DATASET_SAMPLES: Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) - document = dataset.get_document(index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0)) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), COMMON_DATASET_SAMPLES[index]) @pytest.mark.slow def test_preparator_sharded(): - path, config, hf_path = get_sharded_test_dataset() + path, config, hf_path, preprocessing = get_sharded_test_dataset() dataset_config = get_dataset_config(config, GPTDatasetFromFileConfig)._load_config() Assert.custom(isinstance, dataset_config, BlendedDatasetConfig) Assert.eq(dataset_config.weights, [0.33003587104248827, 0.3455874161709333, 0.3243767127865784]) - datasets_ = [dataset_config_.build() for dataset_config_ in dataset_config.datasets] + datasets_ = [dataset_config_.build(preprocessing) for dataset_config_ in dataset_config.datasets] Assert.eq([len(dataset) for dataset in datasets_], lengths := [334, 333, 333]) Assert.eq([dataset.num_tokens for dataset in datasets_], [14813, 15511, 14559]) @@ -101,7 +102,7 @@ def 
test_preparator_sharded(): @pytest.mark.slow def test_preparator_split(): - path, config, hf_path = get_split_test_dataset() + path, config, hf_path, _ = get_split_test_dataset() dataset_config = { split: get_dataset_config(split_config, GPTDatasetFromFileConfig)._load_config().to_dict() for split, split_config in config.items() @@ -125,7 +126,7 @@ def test_preparator_split(): @pytest.mark.slow def test_preparator_split_sharded(): - path, config, hf_path = get_split_sharded_test_dataset() + path, config, hf_path, _ = get_split_sharded_test_dataset() dataset_config = { split: get_dataset_config(split_config, GPTDatasetFromFileConfig)._load_config().to_dict() for split, split_config in config.items() @@ -182,7 +183,9 @@ def test_dataset_preparator_from_hub(): assert (croissant_path := output_path / "croissant.json").is_file() Assert.eq(json.load(croissant_path.open("r"))["url"], "https://huggingface.co/datasets/openai/gsm8k") - dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build() + dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build( + LanguageModelPreprocessingConfig(vocab_size=0) + ) Assert.custom(isinstance, dataset, MemmapDataset) hf_dataset = datasets.load_dataset("openai/gsm8k", "main", split="test") diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 7a31358b9..d32fb9880 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -1,4 +1,5 @@ from fast_llm.data.dataset.gpt.config import GPTRandomDatasetConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -16,8 +17,9 @@ def test_gpt_random_dataset(): # Make sure the random dataset works and check for unintended changes in behavior. 
+ preprocessing = LanguageModelPreprocessingConfig(vocab_size=8192) sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( - get_sampling_data(4, sequence_length=7, vocab_size=8192) + get_sampling_data(4, sequence_length=7, preprocessing=preprocessing) ) compare_sampled_dataset(sampled, RANDOM_DATASET_EXPECTED_SAMPLES) @@ -26,6 +28,6 @@ def test_gpt_random_dataset(): {"datasets": {"training": config}}, 4, sequence_length=7, - vocab_size=8192, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 2d102be01..d6a935c61 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,9 +2,10 @@ import pytest import torch -from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert @@ -38,10 +39,10 @@ def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() sampled = get_dataset_config( dataset_config := config, GPTDatasetFromFileConfig[LanguageModelSample] - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) # Test in data. 
@@ -50,6 +51,7 @@ def test_gpt_sampled(): 8, sequence_length=5, expected_samples=GPT_MEMMAP_SAMPLES, + preprocessing=preprocessing, ) @@ -59,7 +61,7 @@ def __init__(self, samples): self._samples = samples def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None ) -> SampleType: if end is None: end = len(self._samples[index]) @@ -98,7 +100,15 @@ def test_gpt_sample(seed, shuffle): previous_samples = None # Loop instead of parametrizing for the check below. for num_samples in (20, 10, 6, 5, 2, 1): - sampled = TEST_DATASET.sample(get_sampling_data(num_samples, sequence_length=5, seed=seed, shuffle=shuffle)) + sampled = TEST_DATASET.sample( + get_sampling_data( + num_samples, + sequence_length=5, + seed=seed, + shuffle=shuffle, + preprocessing=LanguageModelPreprocessingConfig(vocab_size=0), + ) + ) samples = validate_indexed_dataset_sampling(sampled) if previous_samples is not None and shuffle != ShufflingType.full: # Check that the sequence is independent of `num_sample`. @@ -162,6 +172,7 @@ def test_gpt_sample_padding(): seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, + preprocessing=LanguageModelPreprocessingConfig(vocab_size=vocab_size), ) if total_tokens == 0: with pytest.raises(RuntimeError): diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 224b18270..54263b8e2 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -31,15 +31,15 @@ def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. 
- _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": memmap_config, "begin": 0.025, "end": 0.1}, DatasetSliceConfig[LanguageModelSample], - ).build() + ).build(preprocessing) compare_indexed_dataset_tokens(dataset, 75, 3399, {i - 25: sample for i, sample in COMMON_DATASET_SAMPLES.items()}) - sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) + sampled = dataset.sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) validate_indexed_dataset_sampling(sampled, GPT_SLICE_VALIDATION_SAMPLES) # Test in data with multiple phases. @@ -72,4 +72,5 @@ def test_gpt_slice(): "training": GPT_SLICE_TRAINING_SAMPLES, "validation": GPT_SLICE_VALIDATION_SAMPLES, }, + preprocessing=preprocessing, ) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index e2cadf717..e29050b28 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -44,7 +44,7 @@ def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): and prefix.with_suffix(".bin").is_file() and prefix.parent.joinpath("fast_llm_config.yaml").is_file() ): - _, _, hf_path = get_common_test_dataset() + _, _, hf_path, _ = get_common_test_dataset() hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() samples = [ diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index ed3f01307..7348a79ef 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -7,7 +7,9 @@ import PIL.Image from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from 
fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.utils import padded_cumsum from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH @@ -158,7 +160,7 @@ def _get_test_dataset( path: pathlib.Path, seed: int, tokenizer_path: str = TOKENIZER_PATH, - vocab_size: int | None = None, + max_vocab_size: int | None = None, documents_per_shard: int = 10**6, num_documents: int = 1000, min_document_size: int = 5, @@ -173,7 +175,7 @@ def _get_test_dataset( min_image_size: int = 4, max_image_size: int = 32, config_only: bool = False, -) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path]: +) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: config_paths = ( [path / "fast_llm_config.yaml"] if splits is None @@ -214,7 +216,7 @@ def _get_test_dataset( "load_from_disk": True, "source_schema": source_schema, }, - "tokenizer": {"path": tokenizer_path, "max_vocab_size": vocab_size}, + "tokenizer": {"path": tokenizer_path, "max_vocab_size": max_vocab_size}, "output_path": path, "documents_per_shard": documents_per_shard, "splits": splits, @@ -231,28 +233,45 @@ def _get_test_dataset( for split, config_path in zip(splits, config_paths, strict=True) } ) - return path, config, hf_path + preprocessing = LanguageModelPreprocessingConfig( + tokenizer={"type": "tokenizer", "path": tokenizer_path, "max_vocab_size": max_vocab_size}, + image_patches=NullPreprocessingConfig() if image_patch_config is None else image_patch_config, + vocab_size=max_vocab_size or 0, + use_loss_masking_spans=max_loss_masking_spans > 0, + use_preference_spans=has_preference_spans, + ) + return path, config, hf_path, preprocessing -def get_common_test_dataset(): +def get_common_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "common_dataset", seed=1234) -def 
get_alt_test_dataset(): +def get_alt_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "other_dataset", seed=2345) -def get_sharded_test_dataset(): +def get_sharded_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "common_dataset_sharded", seed=1234, documents_per_shard=350) -def get_split_test_dataset(): +def get_split_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset( DATASET_CACHE / "common_dataset_split", seed=1234, splits={"training": 1, "validation": 1} ) -def get_split_sharded_test_dataset(): +def get_split_sharded_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset( DATASET_CACHE / "common_dataset_split_sharded", seed=1234, @@ -261,15 +280,21 @@ def get_split_sharded_test_dataset(): ) -def get_test_dataset_with_loss_masking_spans(): +def get_test_dataset_with_loss_masking_spans() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, max_loss_masking_spans=5) -def get_test_dataset_with_preference_spans(): +def get_test_dataset_with_preference_spans() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "dataset_with_preference_spans", seed=1234, has_preference_spans=True) -def get_test_dataset_with_image_patches(image_break_token: int | None = None, image_end_token: int | None = None): +def get_test_dataset_with_image_patches( + image_break_token: int | None = None, image_end_token: int | None = None +) -> tuple[pathlib.Path, dict[str, 
typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: return _get_test_dataset( DATASET_CACHE / f"dataset_with_image_patches_{image_break_token}_{image_end_token}", seed=1234, @@ -289,7 +314,7 @@ def get_model_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, - vocab_size=MODEL_TEST_VOCAB_SIZE, + max_vocab_size=MODEL_TEST_VOCAB_SIZE, splits={"training": 969, "validation": 30, "test": 1}, config_only=config_only, ) @@ -299,7 +324,7 @@ def get_multimodal_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset_multimodal", seed=1234, - vocab_size=MODEL_TEST_VOCAB_SIZE, + max_vocab_size=MODEL_TEST_VOCAB_SIZE, max_images=2, image_patch_config=ImagePatchConfig( height=4, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 752e3a8c8..186991ed5 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -95,7 +95,7 @@ class ModelTestingConfig: ) def __post_init__(self): - _, config, _ = self.get_dataset(config_only=True) + _, config, _, _ = self.get_dataset(config_only=True) self.config_dict["data"]["datasets"] = config @functools.cached_property From d27a8151b638e3ef31eb947cef66d7bd8121cb34 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Dec 2025 17:21:29 -0500 Subject: [PATCH 5/8] fix --- fast_llm/data/preprocessing/tokenizer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 9e11fa66c..2963e8e63 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -96,8 +96,11 @@ def tokenize( if self._config.max_vocab_size is not None: # In some cases creating a tensor before restricting the vocab size may cause an overflow. 
- ( - torch.tensor(tokens, dtype=torch.int64 if len(self.tokenizer) > torch.iinfo().max else data_type.torch) + tokens = ( + torch.tensor( + tokens, + dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + ) % self._config.max_vocab_size ).to(data_type.torch) else: From 5ab6cd03b28506e6847b7b7cea08f083600e9b07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Dec 2025 22:30:27 -0500 Subject: [PATCH 6/8] fixes --- fast_llm/data/preprocessing/language_model.py | 5 ++-- fast_llm/data/sample/language_model.py | 1 - fast_llm/engine/checkpoint/distributed.py | 21 ++++++++------ fast_llm/engine/multi_stage/fsdp.py | 6 ++++ fast_llm/engine/multi_stage/multi_stage.py | 22 ++++++++++---- fast_llm/models/multimodal/trainer.py | 1 + fast_llm/utils.py | 12 +++++--- tests/data/common.py | 7 ++++- tests/data/test_blending.py | 11 ++++--- tests/data/test_concatenate.py | 2 +- tests/data/test_preparator.py | 2 +- tests/data/test_sampling.py | 3 -- tests/models/distributed_test_checkpoint.py | 2 +- tests/models/test_checkpoint.py | 4 ++- tests/test_varlen.py | 9 ++---- tests/utils/dataset.py | 2 +- tests/utils/model_configs.py | 29 +++++++++---------- 17 files changed, 81 insertions(+), 58 deletions(-) diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index 6c38c3f4e..88ec8f245 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -19,7 +19,7 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig): # and in any case the tokenizer path may no longer be valid when loading a prepared dataset, # so we provide the vocab size and use it for compatibility checks. 
image_patches: PreprocessingConfig = Field() - vocab_size: int = Field() + vocab_size: int | None = Field(default=None) use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) @@ -35,7 +35,8 @@ def use_image_patches(self) -> bool: def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? - Assert.geq(self.vocab_size, preprocessing.vocab_size) + if self.vocab_size is not None and preprocessing.vocab_size is not None: + Assert.leq(self.vocab_size, preprocessing.vocab_size) if preprocessing.use_preference_spans: # Preference spans are strictly needed for DPO loss. assert self.use_preference_spans diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 1331cf82a..beadb1161 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -156,7 +156,6 @@ def _validate(self) -> None: else: image_patches = NullPreprocessingConfig() self.preprocessing = LanguageModelPreprocessingConfig( - vocab_size=0, image_patches=image_patches, use_loss_masking_spans=isinstance(self.loss_masking_spans, RangeReaderConfig), use_preference_spans=isinstance(self.chosen_spans, RangeReaderConfig), diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index c2f4d8cdd..d953ea35d 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -120,12 +120,15 @@ def _copy_shard_overlaps(self, loaded_model, loaded_shards, context): self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in loaded_shards} - for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): - for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): - counter = 
self_fsdp.copy_shard_overlaps( -                loaded_fsdp, -                self_fsdp_shards, -                loaded_fsdp_shards, -            ) -            for parameter, count in counter.items(): -                context.mark_as_loaded(count, parameter, True) +        for loaded_stage, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): +            # Skip tied weight copies to avoid duplicate loads. +            # We can't call `loaded_stage.is_tied_weight_copy` because the loaded model isn't set up. +            if loaded_stage.index in loaded_model.stages_owned: +                for self_stage, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): +                    counter = self_fsdp.copy_shard_overlaps( +                        loaded_fsdp, +                        self_fsdp_shards, +                        loaded_fsdp_shards, +                    ) +                    for parameter, count in counter.items(): +                        context.mark_as_loaded(count, parameter, True) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 827079f6e..36e8ff20d 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -1,4 +1,5 @@ import dataclasses +import logging import math import typing @@ -18,6 +19,8 @@ from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta from fast_llm.utils import Assert, clamp, padded_cumsum +logger = logging.getLogger(__name__) + class FSDP: _is_setup: bool = False @@ -276,6 +279,9 @@ def split_buffer(self, buffer: torch.Tensor) -> dict[str, torch.Tensor]: return {name: self._get_parameter_in_buffer(buffer, name) for name in self._parameter_metas} def _get_parameter_in_buffer(self, buffer: torch.Tensor, name: str) -> torch.Tensor: + logger.debug( + f"{name}, {self.get_parameter_begin_in_buffer(name)}, {self.get_parameter_end_in_buffer(name)}, {buffer.shape}, {self._parameter_metas[name]}" + ) return buffer[self.get_parameter_begin_in_buffer(name) : self.get_parameter_end_in_buffer(name)].view( self._parameter_metas[name].shape ) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index f45f93862..89be60c24 100644 
--- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -3,6 +3,7 @@ import typing import warnings +import safetensors.torch import torch from torch._C._distributed_c10d import ProcessGroup @@ -21,6 +22,7 @@ from fast_llm.utils import Assert, get_unique logger = logging.getLogger(__name__) +safetensors.torch.safe_open class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): @@ -426,6 +428,10 @@ def stages(self) -> list[Stage]: def stages_on_device(self) -> dict[int, Stage]: return self._stages_on_device + @property + def stages_owned(self) -> dict[int, Stage]: + return self._stages_owned + @property def tied_parameters(self) -> dict[str, "TiedParameter"]: return self._tied_parameters @@ -485,11 +491,17 @@ def get_state_tensor_iterator( ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) - for shard_index, (stage, shard) in enumerate(zip(self._stages_owned.values(), shard_split, strict=True)): - for name, tensor in stage._export_shard( - shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type - ): # noqa - yield name, shard_name, tensor + logger.debug( + f"{shard_name}, {self._shards[shard_name].shape}, {self._stage_weight_shard_sizes}, {self._stages_owned.values()}, {[x.shape for x in shard_split]}" + ) + for shard_index, ((stage_index, stage), shard) in enumerate( + zip(self._stages_on_device.items(), shard_split, strict=True) + ): + if stage_index in self._stages_owned: + for name, tensor in stage._export_shard( + shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type + ): # noqa + yield name, shard_name, tensor def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torch.Tensor | SafeTensorSlice): """ diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py index 
cd8e09cae..43a8f8885 100644 --- a/fast_llm/models/multimodal/trainer.py +++ b/fast_llm/models/multimodal/trainer.py @@ -14,6 +14,7 @@ def _get_preprocessing_config( ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: out = super()._get_preprocessing_config(_return_dict=True) out["image_patches"] = { + "type": "image_patch", "height": self._config.model.base_model.vision_encoder.embeddings.patch_height, "width": self._config.model.base_model.vision_encoder.embeddings.patch_width, # TODO: Max shape and special tokens are unspecified in the model. diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 83675ac74..259073e32 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -146,19 +146,23 @@ def multiple(x, y): assert x % y == 0, f"{x} not a multiple of {y}" @staticmethod - def rms_close(x, y, threshold): + def rms_close(x, y, threshold, *, msg=None): rms = rms_diff(x, y).detach().item() - assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + ( + "" if msg is None else f"| {msg}" + ) @staticmethod - def rms_close_relative(x, y, threshold, min_threshold=0): + def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None): import torch Assert.eq(x.shape, y.shape) scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5 threshold = max(threshold * scale, min_threshold) rms = rms_diff(x, y).item() - assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + ( + "" if msg is None else f"| {msg}" + ) @staticmethod def all_equal(x, *args): diff --git a/tests/data/common.py b/tests/data/common.py index 210749864..34fdba321 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -29,10 +29,12 @@ def get_sampling_data( gpu: bool = False, 
shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, - preprocessing: LanguageModelPreprocessingConfig, + preprocessing: LanguageModelPreprocessingConfig | None = None, ) -> GPTSamplingData: # Config with convenient defaults. distributed = Distributed(DistributedConfig(), use_cpu=True) + if preprocessing is None: + preprocessing = LanguageModelPreprocessingConfig() return GPTSamplingData( config=SamplingConfig( seed=seed, @@ -122,6 +124,9 @@ def compare_indexed_dataset_tokens( def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: + # Uncomment to print the current list of samples. + # for i in range(len(expected_samples)): + # print(i, sampled[i].tokens.tokens.tolist()) Assert.eq(len(sampled), len(expected_samples)) Assert.all_equal(torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]), expected_samples) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 5cad573ca..989e99b24 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,7 +4,6 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( @@ -44,12 +43,12 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, GPT_BLENDED_MIXED_SAMPLES = [ [49152, 46, 10, 819, 19, 45], - [916, 6683, 7685, 1277, 5106, 378], + [25492, 15877, 37874, 8570, 31649, 15521], [45, 69, 17, 86, 38826, 15], - [3359, 6803, 780, 4561, 669, 7878], + [3359, 20945, 33437, 32454, 42084, 45942], [15, 25, 51, 31, 32348, 64], [64, 17, 93, 78, 40, 1793], - [6920, 2218, 2921, 3963, 7606, 6904], + [15112, 36731, 47864, 35586, 33356, 37537], [1793, 1, 1746, 38, 27, 58], ] @@ -85,7 +84,7 @@ def 
test_blending(probs): # Use a list of integers as a mock dataset, encoding both indexes in the sample. [list(range(i * num_samples, (i + 1) * num_samples)) for i, _ in enumerate(probs)], # noqa probs, - get_sampling_data(num_samples, preprocessing=LanguageModelPreprocessingConfig(vocab_size=8192)), + get_sampling_data(num_samples), ) probs = normalize_probabilities(probs) samples = np.array([dataset[i] for i in range(num_samples)]) @@ -133,7 +132,7 @@ def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() # Random dataset needs an explicit vocab size. - preprocessing = preprocessing.from_dict(preprocessing, {"vocab_size": 8192}) + preprocessing = preprocessing.from_dict(preprocessing, {"vocab_size": 50000}) sampled = get_dataset_config( dataset_config := { "type": "blended", diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 1580842b7..19539cc8c 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -31,7 +31,7 @@ def test_gpt_concatenate(): dataset = get_dataset_config( dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], - ).build(LanguageModelPreprocessingConfig(vocab_size=0)) + ).build(LanguageModelPreprocessingConfig()) compare_indexed_dataset_tokens( dataset, 3 * COMMON_DATASET_LENGTH, diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index dd4375418..f4f6fab82 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -184,7 +184,7 @@ def test_dataset_preparator_from_hub(): Assert.eq(json.load(croissant_path.open("r"))["url"], "https://huggingface.co/datasets/openai/gsm8k") dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build( - LanguageModelPreprocessingConfig(vocab_size=0) + LanguageModelPreprocessingConfig() ) 
Assert.custom(isinstance, dataset, MemmapDataset) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index d6a935c61..2e47fd6aa 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -5,7 +5,6 @@ from fast_llm.data.dataset.config import SamplingParameters, ShufflingType from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert @@ -106,7 +105,6 @@ def test_gpt_sample(seed, shuffle): sequence_length=5, seed=seed, shuffle=shuffle, - preprocessing=LanguageModelPreprocessingConfig(vocab_size=0), ) ) samples = validate_indexed_dataset_sampling(sampled) @@ -172,7 +170,6 @@ def test_gpt_sample_padding(): seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, - preprocessing=LanguageModelPreprocessingConfig(vocab_size=vocab_size), ) if total_tokens == 0: with pytest.raises(RuntimeError): diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 217ecd0e1..001eb36da 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -41,7 +41,7 @@ def _test_load_and_save_parallel( mode=StageMode.inference, ) for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): - logger.info(f"Loading {save_format.name} checkpoint to {config.save_path / save_format.name}") + logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") model.save_checkpoint(CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format)) del model gc.collect() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 9acf8a9d7..bb53de29e 100644 --- 
a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -428,8 +428,10 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: return None +# We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in case of failure. +# This should still run after `test_save_and_load_in_parallel`. @requires_cuda -@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config: DistributedSaveLoadConfig, diff --git a/tests/test_varlen.py b/tests/test_varlen.py index 126a3e1e5..730bab2c9 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -8,6 +8,7 @@ from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm import gdn as gdn_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.utils import Assert @pytest.fixture @@ -207,13 +208,7 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): if param.requires_grad: - torch.testing.assert_close( - _param_grad(param), - _param_grad(param_ref), - atol=1e-3, - rtol=1e-3, - msg=f"Grad mismatch for parameter {name}", - ) + Assert.rms_close_relative(_param_grad(param), _param_grad(param_ref), 1e-3, 1e-3, msg=name) if __name__ == "__main__": diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 7348a79ef..47f254893 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -236,7 +236,7 @@ def _get_test_dataset( preprocessing = LanguageModelPreprocessingConfig( tokenizer={"type": "tokenizer", "path": tokenizer_path, "max_vocab_size": max_vocab_size}, image_patches=NullPreprocessingConfig() if 
image_patch_config is None else image_patch_config, - vocab_size=max_vocab_size or 0, + vocab_size=max_vocab_size, use_loss_masking_spans=max_loss_masking_spans > 0, use_preference_spans=has_preference_spans, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e43503137..b0a9acf36 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -584,8 +584,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 }, compare_factor=2, - # modes not supported with reference models - skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), + # Modes not supported with reference models + skip_tests=("sdp", "ms", "pp"), ) _update_and_add_testing_config( @@ -611,11 +611,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, df4, df4_sf, tp2, stp2, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, compare_factor=8, - # modes not supported with reference models - skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2", "stp2_ce4"), + # Modes not supported with reference models and/or activation distillation. + # TODO: Fix gradient accumulation and fp16, add TP support. + skip_tests=("sdp", "ms", "pp", "tp", "df", "bf", "fp16"), ) _update_and_add_testing_config( @@ -674,8 +675,8 @@ def _update_and_add_testing_config( checkpoint_format=AprielHybridSSMCheckpointFormat, # TODO: Add back generate as `normal` when stable. 
groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, # TODO: Fix and bring back to `testing_groups` ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.broken, @@ -684,7 +685,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=(r"sdp", r"ms"), + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( @@ -725,10 +726,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=( - r"sdp", - r"ms", - ), # "pp","dp", "ce","16", "bf", "df", "stp"), + skip_tests=("sdp", "ms"), ) @@ -846,15 +844,16 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + # TODO: Fix (`fast_llm/models/gpt/conversion/apriel.py:235: KeyError: 'value_head_dim'`) + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla + compare_factor=10.0, # High diff for fp16 and bf16 due to rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! - skip_tests=(r"sdp", r"ms", r"^tp2$"), + skip_tests=("sdp", "ms", r"(? 
Date: Tue, 9 Dec 2025 19:15:49 -0500 Subject: [PATCH 7/8] fix --- tests/functional/test_functional.py | 1 + tests/utils/model_configs.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 05fafe7a9..c48a0a531 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -136,6 +136,7 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. @pytest.mark.slow +@pytest.mark.skip("Dropless MoE is broken") @requires_cuda def test_dropless_mlp(): num_experts = 4 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 736e4897e..7f66bec7d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -288,8 +288,8 @@ def _update_and_add_testing_config( ], checkpoint_format=None, groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.main, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: PP checkpoint failing for tied weights. 
ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, @@ -636,12 +636,12 @@ def _update_and_add_testing_config( checkpoint_format=MixtralCheckpointFormat, # TODO: New base image broke mixtral groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.basic: ModelTestingGroupAction.broken, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, - ModelTestingGroup.megatron: ModelTestingGroupAction.normal, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + ModelTestingGroup.megatron: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, compare_factor=2.0, ) From db93bb56dd90ed02253ace0498030e4d54687561 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 9 Dec 2025 20:31:29 -0500 Subject: [PATCH 8/8] fixes --- fast_llm/data/sample/token.py | 2 +- tests/layers/test_gdn_equivalence.py | 4 ---- tests/utils/model_configs.py | 17 ++++++++++------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index b456baaa2..1bc9ef1a1 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -127,7 +127,7 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: begin_ = self._size_cumsums[index].item() # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. 
# Convert begin and end to int to avoid numpy dtype overflow when adding to begin_ - return TokenSample(self._tokens[begin_ + int(begin) : begin_ + int(end)].to(torch.int64), [end - begin]) + return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin]) def get_document_sizes(self) -> torch.Tensor: return self._size_cumsums[1:] - self._size_cumsums[:-1] diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py index dfbaab9c9..803d2eaac 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn_equivalence.py @@ -100,7 +100,3 @@ def test_fast_llm_gdn_matches_apriel2_forward(): fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 7f66bec7d..9231168aa 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -40,6 +40,9 @@ _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) +TP_NO_STP = r"(?:^|(?<=[^s]))tp" + + class ModelTestingGroup(enum.StrEnum): basic = "basic" checkpoint = "checkpoint" @@ -863,7 +866,7 @@ def _update_and_add_testing_config( compare_factor=10.0, # High diff for fp16 and bf16 due to rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! - skip_tests=("sdp", "ms", r"(?