From 2ab18258563d7fd81d1d13ee6bc4c22917582dfe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Dec 2025 20:25:19 -0500 Subject: [PATCH 01/14] Fix rotary 2d --- fast_llm/layers/attention/rotary/rotary.py | 14 +++++-- tests/layers/test_rotary.py | 46 ++++++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) create mode 100644 tests/layers/test_rotary.py diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 55d929f8a..258f9d8bc 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -194,11 +194,17 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: patch_positions = kwargs[VisionKwargs.patch_positions] if not hasattr(self, "_frequencies"): self._frequencies = self._config.theta ** -torch.arange( - 0, 1, 4 / self._head_size, device=kwargs[AttentionKwargs.device], dtype=torch.float64 - ) + 0, 1, 2 / self._head_size, device=kwargs[AttentionKwargs.device], dtype=torch.float64 + ).view(-1, 2) + # TODO: Pre-compute 2d frequencies? - angles = torch.outer(patch_positions.flatten(), self._frequencies).view( - len(patch_positions), self._head_size // 2 + # Equivalent to the separate outer product of height and width frequencies. + # Pre-allocate output to avoid a reshape with copy. + angles = self._frequencies.new_empty(len(patch_positions), self._head_size // 2) + torch.bmm( + patch_positions.T.unsqueeze(2).to(torch.float64), + self._frequencies.T.unsqueeze(1), + out=angles.view(-1, 2, self._head_size // 4).permute(1, 0, 2), ) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) if not self._config.complex_format: diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py new file mode 100644 index 000000000..85d72b316 --- /dev/null +++ b/tests/layers/test_rotary.py @@ -0,0 +1,46 @@ +import torch +import transformers + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.rotary.config import Rotary2DConfig +from fast_llm.layers.vision.config import VisionKwargs +from fast_llm.utils import Assert +from tests.utils.utils import requires_cuda + + +@requires_cuda +def test_rotary_2d(): + """ + Compare Fast-LLM's implementation of 2d rotary embeddings with Pixtral. + """ + head_dim = 16 + num_heads = 8 + + patch_positions = torch.tensor( + [[h, w] for h in range(4) for w in range(4)], + dtype=torch.int64, + device="cuda", + ) + query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device="cuda").normal_() + key = torch.empty_like(query).normal_() + + pixtral_config = transformers.PixtralVisionConfig(hidden_size=head_dim * num_heads, num_attention_heads=num_heads) + pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to("cuda") + # Convert patch positions (h, w) to Pixtral's linear position IDs + # Pixtral expects: position_id = h * max_patches_per_side + w + position_ids = ( + patch_positions[None, :, 0] * (pixtral_config.image_size // pixtral_config.patch_size) + + patch_positions[None, :, 1] + ) + output_pixtral_query, output_pixtral_key = transformers.models.pixtral.modeling_pixtral.apply_rotary_pos_emb( + query, key, *pixtral_rotary(query, position_ids), unsqueeze_dim=2 + ) + + fast_llm_rotary = Rotary2DConfig().get_layer(TensorDim("head_dim", head_dim)) + kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: "cuda"} + fast_llm_rotary.preprocess(kwargs) + output_fast_llm_query, output_fast_llm_key = fast_llm_rotary.forward(query, key, kwargs) + + Assert.rms_close(output_pixtral_query, output_fast_llm_query, 1e-5) + Assert.rms_close(output_pixtral_key, output_fast_llm_key, 1e-5) From 8305dd586b6dee97acdf95e3c59db5a4e328f848 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 3 Dec 2025 20:29:33 -0500 Subject: [PATCH 02/14] stuff --- .../data/preparator/gpt_memmap/prepare.py | 3 -- fast_llm/data/preprocessing/abstract.py | 28 +++++++++++++++++++ fast_llm/data/preprocessing/image_patch.py | 7 +++-- fast_llm/data/preprocessing/tokenizer.py | 9 ++++-- fast_llm/data/sample/abstract.py | 3 ++ fast_llm/layers/vision/config.py | 1 - .../models/multimodal/conversion/llava.py | 2 -- 7 files changed, 42 insertions(+), 11 deletions(-) create mode 100644 fast_llm/data/preprocessing/abstract.py diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 5606eeb98..94bab200e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -198,11 +198,9 @@ def _prepare_shard( return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - # TODO: ======= Extract so we can use elsewhere? (ex. inference) ====== text = sample[self._source_schema.text] all_spans = [] if self._source_schema.has_loss_masking_span: - # TODO: ====== What is the exact input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (SpanType.loss_masking, (begin, last + 1)) @@ -213,7 +211,6 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: all_spans.extend(loss_masking_spans) if self._source_schema.has_preference_spans: - # TODO: ===== Was `self._config.dataset.field` (bug?) ====== full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] # compute chosen span diff --git a/fast_llm/data/preprocessing/abstract.py b/fast_llm/data/preprocessing/abstract.py new file mode 100644 index 000000000..8dbaa3626 --- /dev/null +++ b/fast_llm/data/preprocessing/abstract.py @@ -0,0 +1,28 @@ +import typing + +from fast_llm.config import Config, config_class + + +@config_class(registry=True) +class PreprocessingConfig(Config): + """ + Base preprocessing configuration, with dynamic registry so configs can be saved with memmap datasets. + """ + + _abstract = True + + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is PreprocessingConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass, necessary for loading configs where some components could be absent. + return NullPreprocessingConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + +@config_class(dynamic_type={PreprocessingConfig: "none"}) +class NullPreprocessingConfig(PreprocessingConfig): + """ + Configuration for unspecified preprocessing. + """ + + _abstract = False diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index d6f5bf190..22ec04d68 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -4,6 +4,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div @@ -11,13 +12,15 @@ import torch -@config_class() -class ImagePatchConfig(Config): +@config_class(dynamic_type={PreprocessingConfig: "image_patch"}) +class ImagePatchConfig(PreprocessingConfig): """ Configuration for the tokenizer. The tokenizer is needed for FIM and dataset preparation. """ + _abstract = False + height: int = Field( default=16, desc="Height of the image patches, in pixels.", diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 70291bcaa..356407541 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -1,7 +1,8 @@ import pathlib import typing -from fast_llm.config import Config, Configurable, Field, FieldHint, config_class +from fast_llm.config import Configurable, Field, FieldHint, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -11,13 +12,15 @@ import torch -@config_class() -class TokenizerConfig(Config): +@config_class(dynamic_type={PreprocessingConfig: "tokenizer"}) +class TokenizerConfig(PreprocessingConfig): """ Configuration for the tokenizer. The tokenizer is needed for FIM and dataset preparation. """ + _abstract = False + path: pathlib.Path = Field( default=None, desc="Path to the tokenizer file.", diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 11f5d187c..0db7d1c8a 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -4,6 +4,7 @@ import typing from fast_llm.config import Config, Configurable, Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -101,6 +102,8 @@ class MemmapReaderConfig(MemmapReaderBaseConfig): # Constant strings for alignment safety. header: typing.ClassVar[bytes] footer: typing.ClassVar[bytes] + # Additional information about how the dataset was prepared. + preprocessing: PreprocessingConfig = Field() @property def reader_class(self) -> "type[MemmapReader]": diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py index 2e0389e89..924e1c305 100644 --- a/fast_llm/layers/vision/config.py +++ b/fast_llm/layers/vision/config.py @@ -99,7 +99,6 @@ def layer_class(self) -> "type[PatchEmbeddings]": @config_class(registry=True) class VisionEncoderConfig(BlockConfig): _abstract = False - # TODO: ====== Rename to patch_embeddings? ====== embeddings: PatchEmbeddingsConfig = Field( desc="Configuration for the patch convolution layer.", hint=FieldHint.architecture, diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 9657d71b6..748f2f89e 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -79,7 +79,6 @@ def export_config(cls, config: AttentionConfig) -> dict: class PixtralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[PixtralAttentionConverter]] = PixtralAttentionConverter - # TODO: ====== MistralMLPConverter (#391 / #382) ====== mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter normalization_converter_class: typing.ClassVar[type[PixtralNormalizationConverter]] = PixtralNormalizationConverter hf_mixer_name: typing.ClassVar[str] = "attention" @@ -225,7 +224,6 @@ def import_config(cls, config: dict) -> dict: @classmethod def export_config(cls, config: VisionEncoderConfig) -> dict: Assert.custom(isinstance, config, VisionEncoderConfig) - # TODO: ====== image_size? ====== vision_config = safe_merge_dicts( cls.embeddings_converter_class.export_config(config.embeddings), cls.encoder_converter_class.export_config(config.encoder), From b6e38b872cfca171cb6582bd4731eff2dd2f0f10 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 4 Dec 2025 01:03:32 -0500 Subject: [PATCH 03/14] stuff --- fast_llm/data/data/abstract.py | 4 + fast_llm/data/data/gpt/data.py | 5 +- fast_llm/data/dataset/config.py | 25 ++++--- fast_llm/data/dataset/gpt/config.py | 20 ++--- fast_llm/data/dataset/gpt/legacy_memmap.py | 44 +++++------ fast_llm/data/dataset/gpt/random.py | 34 +++------ fast_llm/data/dataset/memmap.py | 28 ++++--- .../data/preparator/gpt_memmap/prepare.py | 11 +++ fast_llm/data/preprocessing/abstract.py | 12 +++ fast_llm/data/preprocessing/image_patch.py | 12 +++ fast_llm/data/preprocessing/language_model.py | 40 ++++++++++ fast_llm/data/preprocessing/tokenizer.py | 8 +- fast_llm/data/sample/abstract.py | 18 ++++- fast_llm/data/sample/language_model.py | 75 +++++++++---------- fast_llm/data/sample/patch.py | 1 + fast_llm/data/sample/range.py | 1 + fast_llm/data/sample/token.py | 1 + fast_llm/engine/training/trainer.py | 7 +- fast_llm/models/gpt/trainer.py | 21 +++++- fast_llm/models/multimodal/trainer.py | 17 ++++- tests/models/test_match_megatron.py | 5 +- 21 files changed, 261 insertions(+), 128 deletions(-) create mode 100644 fast_llm/data/preprocessing/language_model.py diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index c67dc0321..2c1902796 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,6 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig @@ -16,6 +17,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" _sampling_parameters: dict[str, SamplingParameters] + _preprocessing: PreprocessingConfig _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: @@ -27,11 +29,13 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, SamplingParameters], + preprocessing: PreprocessingConfig, cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: self._distributed = distributed self._sampling_parameters = sampling_parameters + self._preprocessing = preprocessing self._cache_directory = cache_directory @property diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index de47ef761..084dadc7d 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -13,6 +13,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig @@ -47,6 +48,7 @@ def setup( self, distributed: "Distributed", sampling_parameters: dict[str, GPTSamplingParameters], + preprocessing: LanguageModelPreprocessingConfig, cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: @@ -54,7 +56,7 @@ def setup( Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. """ - super().setup(distributed, sampling_parameters, cache_directory) + super().setup(distributed, sampling_parameters, preprocessing, cache_directory) # Check and raise an error if a used dataset is not defined. for dataset_name in self._sampling_parameters.keys(): @@ -81,6 +83,7 @@ def setup( sampling = GPTSamplingData( config=self._config.sampling, parameters=sampling_parameters, + preprocessing=preprocessing, cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7611b4a31..2858d8d18 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -9,12 +9,12 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -84,6 +84,7 @@ class SamplingData: # TODO: This prevents the sampling config from being pickled in multiprocessing. distributed: "Distributed" dataset_name: str + preprocessing: PreprocessingConfig # Using a mutable rather than an int so it's shared with all copies made with `update`. _rank_counter: typing.Iterator[int] = itertools.count @@ -114,16 +115,16 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType] @config_class() class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): - def build(self) -> SamplableDataset[SampleType]: + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: raise NotImplementedError() def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: - return self.build().sample(sampling) + return self.build(sampling.preprocessing).sample(sampling) @config_class() class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): - def build(self) -> "IndexedDataset[SampleType]": + def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]": raise NotImplementedError() @@ -147,10 +148,10 @@ class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)), ) - def build(self) -> "ConcatenatedDataset": + def build(self, preprocessing: PreprocessingConfig) -> "ConcatenatedDataset": from fast_llm.data.dataset.indexed import ConcatenatedDataset - return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) + return ConcatenatedDataset(self.name, [dataset.build(preprocessing) for dataset in self.datasets]) @config_class(dynamic_type={SampledDatasetConfig: "slice"}) @@ -180,10 +181,10 @@ class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]) hint=FieldHint.core, ) - def build(self) -> "DatasetSlice": + def build(self, preprocessing: PreprocessingConfig) -> "DatasetSlice": from fast_llm.data.dataset.indexed import DatasetSlice - dataset = self.dataset.build() + dataset = self.dataset.build(preprocessing) size = len(dataset) return DatasetSlice[SampleType]( f"{dataset.name}_{self.begin}_{self.end}", @@ -272,7 +273,7 @@ def build_and_sample( @config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): +class MemmapDatasetConfig[SampleType: Sample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -280,12 +281,12 @@ class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[ hint=FieldHint.core, ) - def build(self) -> "IndexedDataset[SampleType]": + def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleType]": name = str(self.path).replace("/", "__") if self.path.is_file(): from fast_llm.data.dataset.memmap import MemmapDataset - return MemmapDataset[SampleType](name, self.path) + return MemmapDataset[SampleType](name, self.path, preprocessing) elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): logger.warning( "Using the legacy memmap dataset format." @@ -294,6 +295,6 @@ def build(self) -> "IndexedDataset[SampleType]": ) from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset - return LegacyMemmapDataset[SampleType](name, self.path) + return LegacyMemmapDataset[SampleType](name, self.path, preprocessing) else: raise FileNotFoundError(self.path) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 175779823..4336657ce 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -8,12 +8,14 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset - from fast_llm.data.dataset.gpt.random import GPTRandomDataset + from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset + from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample @@ -23,9 +25,8 @@ class GPTSamplingParameters(SamplingParameters): Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - # TODO: Only used for random dataset. Remove? Or use as safety check? - vocab_size: int | None = None # TODO: ====== Get these to memmap dataset (currently ignored) ====== + vocab_size: int | None = None use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False use_images: bool = False @@ -39,10 +40,11 @@ class GPTSamplingData(SamplingData): """ parameters: GPTSamplingParameters + preprocessing: LanguageModelPreprocessingConfig @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): +class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -50,10 +52,10 @@ class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetCo hint=FieldHint.core, ) - def build(self) -> "GPTRandomDataset[SampleType]": - from fast_llm.data.dataset.gpt.random import GPTRandomDataset + def build_and_sample(self, sampling: GPTSamplingData) -> GPTRandomSampledDataset[SampleType]: + from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset - return GPTRandomDataset[SampleType](self.name) + return GPTRandomSampledDataset[SampleType](sampling, self.name) @config_class(dynamic_type={SampledDatasetConfig: "file"}) @@ -69,10 +71,10 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType] config = self._load_config() return config.build_and_sample(sampling) - def build(self) -> SamplableDataset[SampleType]: + def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleType]: config = self._load_config() assert isinstance(config, SamplableDatasetConfig) - return config.build() + return config.build(preprocessing) def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index 2a23e378b..b5bc5b7de 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -4,8 +4,8 @@ import numpy as np import torch -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -38,24 +38,27 @@ def __init__( self, name: str, prefix: pathlib.Path | str, + preprocessing: LanguageModelPreprocessingConfig, ): - self._init(name, prefix) + self._init(name, prefix, preprocessing) - def _init(self, name: str, prefix: pathlib.Path | str) -> None: + def _init(self, name: str, prefix: pathlib.Path | str, preprocessing: LanguageModelPreprocessingConfig) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_preference_spans = False + has_loss_masking_spans = False + has_preference_spans = False + assert isinstance(preprocessing, LanguageModelPreprocessingConfig) + self._preprocessing = preprocessing with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack(" None: self._document_sizes = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) + assert not self._preprocessing.use_image_patches # read pointers self._pointers = np.frombuffer( @@ -79,8 +83,8 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: ) # read spans - self._spans = None - if self._has_spans and self._version >= 2: + if self._preprocessing.use_loss_masking_spans: + assert has_loss_masking_spans self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -101,9 +105,8 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: ) # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: + if has_preference_spans: + assert has_preference_spans self._chosen_spans = [] self._rejected_spans = [] chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes @@ -135,11 +138,12 @@ def _init(self, name: str, prefix: pathlib.Path | str) -> None: self._num_tokens = div(self._bin_buffer_mmap.size, self._dtype.itemsize) - def __getstate__(self) -> tuple[str, pathlib.Path]: - return (self._name, self._prefix) + def __getstate__(self) -> tuple[str, pathlib.Path, dict]: + return self._name, self._prefix, self._preprocessing.to_dict() - def __setstate__(self, state: tuple[str, pathlib.Path]): - self._init(*state) + def __setstate__(self, state: tuple[str, pathlib.Path, dict]): + name, prefix, preprocessing = state + self._init(name, prefix, LanguageModelPreprocessingConfig.from_dict(preprocessing)) def __del__(self): if hasattr(self, "_bin_buffer_mmap"): @@ -149,9 +153,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None - ) -> SampleType: + def get_document(self, index: int, begin: int = 0, end: int | None = None) -> SampleType: if end is None: end = self.get_document_size(index) sample_size = self._document_sizes[index].item() @@ -169,7 +171,7 @@ def get_document( if not self._dtype.is_signed: # Needed because torch doesn't yet support type promotion between signed and unsigned types. TODO: Remove when supported. token_ids = token_ids.to(torch.int64) - if parameters is not None and parameters.use_loss_masking_spans: + if self._preprocessing.use_loss_masking_spans: assert self._spans is not None # Convert to in range format (begin, end). sample_spans = RangeSample( @@ -178,7 +180,7 @@ def get_document( else: sample_spans = None - if parameters is not None and parameters.use_preference_loss_spans: + if self._preprocessing.use_preference_spans: if not self._has_preference_spans: raise ValueError("No preference spans found in memmap dataset.") elif self._has_preference_spans and self._chosen_spans is None: diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index f1e73c595..939b900e5 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -1,39 +1,27 @@ import numpy as np import torch -from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type -class GPTRandomDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): - """ - A dummy dataset that always returns the same random sample, for debugging purposes. - """ - - def __init__(self, name: str): - self._name = name - - def sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset": - return GPTRandomSampledDataset(sampling, f"{self.name}_sampled") - - @property - def name(self) -> str: - return self._name - - class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed self._parameters = sampling.parameters - assert self._parameters.vocab_size is not None - # TODO: Support? - assert not self._parameters.use_loss_masking_spans - assert not self._parameters.use_preference_loss_spans - self._dtype = get_unsigned_integer_type(self._parameters.vocab_size).torch + + assert isinstance(sampling.preprocessing, LanguageModelPreprocessingConfig) + assert not sampling.preprocessing.use_loss_masking_spans + assert not sampling.preprocessing.use_preference_spans + assert not sampling.preprocessing.use_image_patches + self._vocab_size = sampling.preprocessing.vocab_size + + self._dtype = get_unsigned_integer_type(self._vocab_size).torch def __len__(self) -> int: return self._parameters.num_samples @@ -45,7 +33,7 @@ def __getitem__(self, index: int) -> SampleType: torch.from_numpy( np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( 0, - self._parameters.vocab_size, + self._vocab_size, size=(self._parameters.sequence_length + self._parameters.extra_tokens,), ) ).to(self._dtype), diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index 4b1930dd3..4d75ca450 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -7,6 +7,7 @@ from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig, MemmapWriter, Sample FILE_HEADER = b"fast_llm_prepared_dataset" @@ -21,13 +22,15 @@ def __init__( self, name: str, path: pathlib.Path | str, + preprocessing: PreprocessingConfig, ): - self._init(name, path) + self._init(name, path, preprocessing) - def _init(self, name: str, path: pathlib.Path | str) -> None: + def _init(self, name: str, path: pathlib.Path | str, preprocessing: PreprocessingConfig) -> None: super().__init__() self._name = name self._path = path + self._preprocessing = preprocessing with self._path.open("rb") as stream: # Very file type. @@ -39,16 +42,19 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) ) + reader_config.preprocessing.check_compatibility(self._preprocessing) + self._memmap = np.memmap(self._path, mode="r") + # TODO: ====== Forward preprocessing config so the reader reads just what we need. self._reader = reader_config.get_reader(memoryview(self._memmap)) - def __getstate__(self) -> tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]: + def __getstate__(self) -> tuple[str, pathlib.Path, dict, MemmapIndexDatasetReaderConfig]: # We pass the reader config to force its import in data loader workers. - return self._name, self._path, self._reader.config + return self._name, self._path, self._preprocessing.to_dict(), self._reader.config - def __setstate__(self, state: tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]): - name, path, _ = state - self._init(name, path) + def __setstate__(self, state: tuple[str, pathlib.Path, dict, MemmapIndexDatasetReaderConfig]): + name, path, preprocessing, _ = state + self._init(name, path, PreprocessingConfig.from_dict(preprocessing)) def __del__(self): if hasattr(self, "_memmap"): @@ -81,7 +87,11 @@ def get_document_size(self, index: int) -> int: @classmethod def write_dataset( - cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter] + cls, + path: pathlib.Path, + documents: typing.Iterable[Sample], + writer_class: type[MemmapWriter], + preprocessing_config: PreprocessingConfig | None = None, ) -> MemmapIndexDatasetReaderConfig: # TODO: Match `writer_class` with `SampleType`? path.parent.mkdir(parents=True, exist_ok=True) @@ -93,7 +103,7 @@ def write_dataset( start = stream.tell() stream.seek(start + 8) # Write the data. - reader_config = writer_class.write_dataset(stream, documents) + reader_config = writer_class.write_dataset(stream, documents, preprocessing_config) # Write the reader config. config_offset = stream.tell() reader_config_bytes = json.dumps(reader_config.to_dict()).encode("utf-8") diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 94bab200e..d0628e08f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -27,6 +27,8 @@ from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import Tokenizer from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter @@ -194,6 +196,15 @@ def _prepare_shard( for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") ), LanguageModelWriter, + LanguageModelPreprocessingConfig( + tokenizer=self._config.tokenizer, + vocab_size=self._tokenizer.vocab_size, + image_patches=( + self._config.image_patches if self._source_schema.has_images else NullPreprocessingConfig() + ), + has_loss_masking_spans=self._source_schema.has_loss_masking_span, + has_preference_spans=self._source_schema.has_preference_spans, + ), ) return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config diff --git a/fast_llm/data/preprocessing/abstract.py b/fast_llm/data/preprocessing/abstract.py index 8dbaa3626..dc8c88375 100644 --- a/fast_llm/data/preprocessing/abstract.py +++ b/fast_llm/data/preprocessing/abstract.py @@ -1,7 +1,10 @@ +import logging import typing from fast_llm.config import Config, config_class +logger = logging.getLogger(__name__) + @config_class(registry=True) class PreprocessingConfig(Config): @@ -18,6 +21,12 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return NullPreprocessingConfig._from_dict(default, strict) return super()._from_dict(default, strict=strict) + def check_compatibility(self, preprocessing: typing.Self) -> None: + """ + Check whether a dataset preprocessed with `self` can produce samples for a model that requires `preprocessing`. + """ + raise NotImplementedError() + @config_class(dynamic_type={PreprocessingConfig: "none"}) class NullPreprocessingConfig(PreprocessingConfig): @@ -26,3 +35,6 @@ class NullPreprocessingConfig(PreprocessingConfig): """ _abstract = False + + def check_compatibility(self, preprocessing: typing.Self) -> None: + logger.warning("Dataset preprocessing config not specified, could not check compatibility with the model.") diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index 22ec04d68..7c3d9d53b 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -58,6 +58,18 @@ class ImagePatchConfig(PreprocessingConfig): hint=FieldHint.optional, ) + def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.eq(self.height, preprocessing.height) + Assert.eq(self.width, preprocessing.width) + Assert.eq(self.do_resize, preprocessing.do_resize) + Assert.leq(self.max_image_height, preprocessing.max_image_height) + Assert.leq(self.max_image_width, preprocessing.max_image_width) + # None is used in the trainer to mark unknown value, so we can't do an equality check. TODO: Distinguish. + if preprocessing.image_break_token is not None: + Assert.eq(self.image_break_token, preprocessing.image_break_token) + if preprocessing.image_end_token is not None: + Assert.eq(self.image_end_token, preprocessing.image_end_token) + @property def num_channels(self) -> int: # assume 3 channels (RGB) for all images diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py new file mode 100644 index 000000000..d4e1235ae --- /dev/null +++ b/fast_llm/data/preprocessing/language_model.py @@ -0,0 +1,40 @@ +import functools +import typing + +from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.utils import Assert + + +@config_class(dynamic_type={PreprocessingConfig: "language_model"}) +class LanguageModelPreprocessingConfig(PreprocessingConfig): + tokenizer: TokenizerConfig = Field() + # We can't easily compare tokenizers, + # and in any case the tokenizer path may no longer be valid when loading a prepared dataset, + # so we provide the vocab size and use it for compatibility checks. + vocab_size: int = Field() + image_patches: PreprocessingConfig = Field() + use_loss_masking_spans: bool = Field(default=False) + use_preference_spans: bool = Field(default=False) + + def _validate(self) -> None: + super()._validate() + Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig)) + + @functools.cached_property + def use_image_patches(self) -> bool: + return isinstance(self.image_patches, ImagePatchConfig) + + def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) + # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? + Assert.geq(self.vocab_size, preprocessing.vocab_size) + if preprocessing.use_loss_masking_spans: + assert self.use_loss_masking_spans + if preprocessing.use_preference_spans: + assert self.use_preference_spans + if preprocessing.use_image_patches: + assert self.use_image_patches + self.image_patches.check_compatibility(preprocessing.image_patches) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 356407541..a0d460d4c 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -1,3 +1,4 @@ +import functools import pathlib import typing @@ -69,9 +70,12 @@ def __init__(self, config: ConfigType): self.eod_id = self.tokenizer.eos_token_id self.bod_id = self.tokenizer.bos_token_id - @property + @functools.cached_property def vocab_size(self) -> int: - return len(self.tokenizer) + out = len(self.tokenizer) + if self._config.max_vocab_size is not None: + out = min(out, self._config.max_vocab_size) + return out @property def vocab(self) -> dict[str, int]: diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 0db7d1c8a..973e29ad8 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -4,7 +4,7 @@ import typing from fast_llm.config import Config, Configurable, Field, config_class -from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -195,11 +195,16 @@ def get_document_size(self, index: int) -> int: class MemmapWriter(abc.ABC): - def __init__(self, stream: io.BufferedWriter | pathlib.Path): + def __init__( + self, stream: io.BufferedWriter | pathlib.Path, preprocessing_config: PreprocessingConfig | None = None + ): self._owns_stream = isinstance(stream, pathlib.Path) if self._owns_stream: stream = stream.open("wb") self._stream = stream + self._preprocessing_config = ( + NullPreprocessingConfig() if preprocessing_config is None else preprocessing_config + ) def __enter__(self): self._begin = self._stream.tell() @@ -230,8 +235,13 @@ def _get_config(self, begin: int, end: int): pass @classmethod - def write_dataset(cls, stream: io.BufferedWriter, documents: typing.Iterable[Sample]) -> MemmapReaderConfig: - with cls(stream) as writer: + def write_dataset( + cls, + stream: io.BufferedWriter, + documents: typing.Iterable[Sample], + preprocessing_config: PreprocessingConfig | None = None, + ) -> MemmapReaderConfig: + with cls(stream, preprocessing_config) as writer: for document in documents: writer.write(document) return writer.get_config() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 88ca05b95..0e1baaef8 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -6,7 +6,9 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapIndexDatasetReaderConfig, @@ -135,6 +137,11 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): rejected_spans: MemmapReaderBaseConfig = Field() image_patches: MemmapReaderBaseConfig = Field() + def _validate(self) -> None: + super()._validate() + # Dynamic type supported for backward compatibility. + Assert.custom(isinstance, self.preprocessing, (LanguageModelPreprocessingConfig, NullPreprocessingConfig)) + def __len__(self) -> int: return len(self.tokens) @@ -201,9 +208,7 @@ def get_document_size(self, index: int) -> int: class LanguageModelWriter(MemmapWriter): - _has_loss_masking_spans: bool | None = None - _has_preference_spans: bool | None = None - _has_image_patches: bool | None = None + _preprocessing_config: LanguageModelPreprocessingConfig def __enter__(self): super().__enter__() @@ -214,10 +219,13 @@ def __enter__(self): self._path = pathlib.Path(self._directory.name) # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times. self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__() - self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() - self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() - self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() - self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() + if self._preprocessing_config.use_loss_masking_spans: + self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() + if self._preprocessing_config.use_preference_spans: + self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() + self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() + if self._preprocessing_config.use_image_patches: + self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() return self def write(self, document: LanguageModelSample): @@ -225,58 +233,46 @@ def write(self, document: LanguageModelSample): # Write tokens. self._token_writer.write(document.tokens) - # Ensure either all samples have loss masking spans or none of them do. - if self._has_loss_masking_spans is None: - self._has_loss_masking_spans = document.loss_masking_spans is not None - else: - Assert.eq(self._has_loss_masking_spans, document.loss_masking_spans is not None) - # Write loss masking spans. - if self._has_loss_masking_spans: + if self._preprocessing_config.use_loss_masking_spans: + assert document.loss_masking_spans is not None self._loss_masking_span_writer.write(document.loss_masking_spans) - # All sample must either have both chosen and rejected spans, or neither. - if self._has_preference_spans is None: - self._has_preference_spans = document.chosen_spans is not None - else: - Assert.eq(self._has_preference_spans, document.chosen_spans is not None) - Assert.eq(self._has_preference_spans, document.rejected_spans is not None) - # Write preference spans. - if self._has_preference_spans: + if self._preprocessing_config.use_preference_spans: + assert document.chosen_spans is not None + assert document.rejected_spans is not None self._chosen_spans_writer.write(document.chosen_spans) self._rejected_spans_writer.write(document.rejected_spans) - # Ensure either all samples have image patches or none of them do. - if self._has_image_patches is None: - self._has_image_patches = document.image_patches is not None - else: - Assert.eq(self._has_image_patches, document.image_patches is not None) - # Write image patches - if self._has_image_patches: + if self._preprocessing_config.use_image_patches: + assert document.image_patches is not None self._image_patches_writer.write(document.image_patches) def __exit__(self, exc_type, exc_val, exc_tb): self._token_writer.__exit__(exc_type, exc_val, exc_tb) - self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) - self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) - self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) - self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_loss_masking_spans: + self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_preference_spans: + self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) + self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_image_patches: + self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) if exc_type is None: # A dummy config so we can verify the begin and end offsets. config = self._get_config(self._begin, None) _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) - if self._has_loss_masking_spans: + if self._preprocessing_config.use_loss_masking_spans: _copy_chunked( self._path.joinpath("loss_masking_spans"), self._stream, config.loss_masking_spans.begin, config.loss_masking_spans.end, ) - if self._has_preference_spans: + if self._preprocessing_config.use_preference_spans: _copy_chunked( self._path.joinpath("chosen_spans"), self._stream, @@ -290,7 +286,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): config.rejected_spans.end, ) - if self._has_image_patches: + if self._preprocessing_config.use_image_patches: _copy_chunked( self._path.joinpath("image_patches"), self._stream, @@ -308,12 +304,12 @@ def _get_config_class(cls) -> type[LanguageModelReaderConfig]: def _get_config(self, begin: int, end: int | None): tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) offset = tokens.end - if self._has_loss_masking_spans: + if self._preprocessing_config.use_loss_masking_spans: loss_masking_spans = self._loss_masking_span_writer.get_config(offset) offset = loss_masking_spans.end else: loss_masking_spans = NullReaderConfig() - if self._has_preference_spans: + if self._preprocessing_config.use_preference_spans: chosen_spans = self._chosen_spans_writer.get_config(offset) offset = chosen_spans.end rejected_spans = self._rejected_spans_writer.get_config(offset) @@ -321,7 +317,7 @@ def _get_config(self, begin: int, end: int | None): else: chosen_spans = NullReaderConfig() rejected_spans = NullReaderConfig() - if self._has_image_patches: + if self._preprocessing_config.use_image_patches: image_patches = self._image_patches_writer.get_config(offset) offset = image_patches.end else: @@ -338,6 +334,7 @@ def _get_config(self, begin: int, end: int | None): chosen_spans=chosen_spans, rejected_spans=rejected_spans, image_patches=image_patches, + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 9d27d37cd..a75684d76 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -300,4 +300,5 @@ def _get_config(self, begin: int, end: int): num_patch_groups=self._group_count_cumsum[-1], patch_shape=self._patch_shape, data_type=DataType.from_torch(self._data_type), + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index b7be4efe1..0022b3593 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -135,4 +135,5 @@ def _get_config(self, begin: int, end: int): end=end, num_documents=len(self._count_cumsum) - 1, num_ranges=self._count_cumsum[-1], + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 9fedf12b5..3f5912e5e 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -166,4 +166,5 @@ def _get_config(self, begin: int, end: int): num_documents=len(self._size_cumsum) - 1, num_tokens=self._size_cumsum[-1], data_type=DataType.from_torch(self._data_type), + preprocessing_config=self._preprocessing_config, ) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index aa4f2d570..dd106f35c 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -12,6 +12,7 @@ from fast_llm.core.distributed import allreduce_scalar, safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig, PreprocessingConfig from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -239,6 +240,7 @@ def setup(self, distributed: Distributed, run: Run) -> None: ) for eval_sampling_params in self._evaluator_runner.get_sampling_parameters() }, + self._get_preprocessing_config(), None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, ) @@ -261,10 +263,13 @@ def _get_data(self) -> Data: pass def _get_sampling_parameters( - self, parameters: dict[str, typing.Any], _return_dict: bool = False + self, parameters: dict[str, typing.Any], *, _return_dict: bool = False ) -> SamplingParameters | dict[str, typing.Any]: return parameters if _return_dict else SamplingParameters(**parameters) + def _get_preprocessing_config(self, *, _return_dict: bool = False) -> PreprocessingConfig | dict[str, typing.Any]: + return {} if _return_dict else NullPreprocessingConfig() + @property def _consumed_samples(self) -> int: assert self._is_setup diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b8fb22ebb..1c7be33dd 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -3,6 +3,7 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -17,18 +18,30 @@ def _get_data(self) -> GPTData: ) def _get_sampling_parameters( - self, parameters: dict[str, typing.Any], _return_dict: bool = False + self, parameters: dict[str, typing.Any], *, _return_dict: bool = False ) -> GPTSamplingParameters | dict[str, typing.Any]: parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - "vocab_size": self._config.model.base_model.embeddings.vocab_size, + # "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, - "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + # "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, # OK since DPO is not supported for MTP. - "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + # "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) + + def _get_preprocessing_config( + self, *, _return_dict: bool = False + ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = { + "type": "language_model", + "vocab_size": self._config.model.base_model.embeddings.vocab_size, + "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + # OK since DPO is not supported for MTP. + "use_preference_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + } + return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py index 2beee1097..cd8e09cae 100644 --- a/fast_llm/models/multimodal/trainer.py +++ b/fast_llm/models/multimodal/trainer.py @@ -1,5 +1,7 @@ import logging +import typing +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.models.gpt.trainer import GPTTrainer from fast_llm.models.multimodal.config import MultiModalTrainerConfig @@ -7,4 +9,17 @@ class MultiModalTrainer[ConfigType: MultiModalTrainerConfig](GPTTrainer[ConfigType]): - pass + def _get_preprocessing_config( + self, *, _return_dict: bool = False + ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = super()._get_preprocessing_config(_return_dict=True) + out["image_patches"] = { + "height": self._config.model.base_model.vision_encoder.embeddings.patch_height, + "width": self._config.model.base_model.vision_encoder.embeddings.patch_width, + # TODO: Max shape and special tokens are unspecified in the model. + "max_image_height": 2**32, + "max_image_width": 2**32, + "image_break_token": None, + "image_end_token": None, + } + return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index f3ce65966..e2cadf717 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -15,6 +15,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -122,8 +123,8 @@ class MegatronDatasetConfig[SampleType: LanguageModelSample](MemmapDatasetConfig hint=FieldHint.core, ) - def build(self) -> "LegacyMemmapDataset[SampleType]": - return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path) + def build(self, preprocessing: PreprocessingConfig) -> "LegacyMemmapDataset[SampleType]": + return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path, preprocessing) class MegatronMemmapDataset(LegacyMemmapDataset): From 350fb3df7f877549039c4e2d1ffda7fbf9f03e76 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Dec 2025 00:30:55 -0500 Subject: [PATCH 04/14] stuff --- fast_llm/data/data/gpt/data.py | 7 +- fast_llm/data/dataset/gpt/config.py | 20 +-- fast_llm/data/dataset/gpt/fim.py | 4 +- fast_llm/data/dataset/gpt/legacy_memmap.py | 19 ++- fast_llm/data/dataset/memmap.py | 5 +- .../data/preparator/gpt_memmap/prepare.py | 34 +++-- fast_llm/data/preprocessing/abstract.py | 4 +- fast_llm/data/preprocessing/image_patch.py | 5 + fast_llm/data/preprocessing/language_model.py | 15 +- fast_llm/data/preprocessing/tokenizer.py | 22 +-- fast_llm/data/sample/abstract.py | 17 +-- fast_llm/data/sample/language_model.py | 136 +++++++++++++++--- fast_llm/data/sample/patch.py | 27 +++- fast_llm/data/sample/range.py | 12 +- fast_llm/data/sample/token.py | 7 +- fast_llm/models/gpt/trainer.py | 10 +- tests/data/common.py | 21 ++- tests/data/test_blending.py | 18 ++- tests/data/test_concatenate.py | 8 +- tests/data/test_fim.py | 5 +- tests/data/test_image_patch.py | 13 +- tests/data/test_loss_masking_spans.py | 17 ++- tests/data/test_preference_spans.py | 17 ++- tests/data/test_preparator.py | 25 ++-- tests/data/test_random.py | 6 +- tests/data/test_sampling.py | 23 ++- tests/data/test_slice.py | 7 +- tests/models/test_match_megatron.py | 2 +- tests/utils/dataset.py | 53 +++++-- tests/utils/model_configs.py | 2 +- 30 files changed, 364 insertions(+), 197 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 084dadc7d..dbd770895 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -10,7 +10,8 @@ from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig @@ -30,7 +31,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ _datasets: dict[str, SampledDataset] - _sampling_parameters: dict[str, GPTSamplingParameters] + _sampling_parameters: dict[str, SamplingParameters] _is_setup: bool = False def __init__( @@ -47,7 +48,7 @@ def __init__( def setup( self, distributed: "Distributed", - sampling_parameters: dict[str, GPTSamplingParameters], + sampling_parameters: dict[str, SamplingParameters], preprocessing: LanguageModelPreprocessingConfig, cache_directory: pathlib.Path, timeout: float | None = None, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 4336657ce..fc326d366 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -7,31 +7,18 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset - from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample -@dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): - """ - Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. - """ - - # TODO: ====== Get these to memmap dataset (currently ignored) ====== - vocab_size: int | None = None - use_loss_masking_spans: bool = False - use_preference_loss_spans: bool = False - use_images: bool = False - - @dataclasses.dataclass(kw_only=True) class GPTSamplingData(SamplingData): """ @@ -39,7 +26,6 @@ class GPTSamplingData(SamplingData): usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. """ - parameters: GPTSamplingParameters preprocessing: LanguageModelPreprocessingConfig @@ -52,7 +38,7 @@ class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConf hint=FieldHint.core, ) - def build_and_sample(self, sampling: GPTSamplingData) -> GPTRandomSampledDataset[SampleType]: + def build_and_sample(self, sampling: GPTSamplingData) -> "GPTRandomSampledDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomSampledDataset return GPTRandomSampledDataset[SampleType](sampling, self.name) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index d36384ee5..b70fc8360 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -21,9 +21,9 @@ def __init__( dataset: SampledDataset[SampleType], sampling: GPTSamplingData, ): - if sampling.parameters.use_loss_masking_spans: + if sampling.preprocessing.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") - if sampling.parameters.use_preference_loss_spans: + if sampling.preprocessing.use_preference_spans: raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index b5bc5b7de..d29e31596 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -105,7 +105,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, preprocessing: LanguageMo ) # read preference spans - if has_preference_spans: + if self._preprocessing.use_preference_spans: assert has_preference_spans self._chosen_spans = [] self._rejected_spans = [] @@ -173,20 +173,17 @@ def get_document(self, index: int, begin: int = 0, end: int | None = None) -> Sa token_ids = token_ids.to(torch.int64) if self._preprocessing.use_loss_masking_spans: assert self._spans is not None - # Convert to in range format (begin, end). - sample_spans = RangeSample( - [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size - ).crop(begin, end) + if hasattr(self, "_spans"): + # Convert to in range format (begin, end). + sample_spans = RangeSample( + [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size + ).crop(begin, end) + else: + sample_spans = RangeSample([], end - begin) else: sample_spans = None if self._preprocessing.use_preference_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") # Convert to in range format (begin, end). chosen_spans = RangeSample( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index 4d75ca450..f80a48b0a 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -42,11 +42,8 @@ def _init(self, name: str, path: pathlib.Path | str, preprocessing: Preprocessin json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) ) - reader_config.preprocessing.check_compatibility(self._preprocessing) - self._memmap = np.memmap(self._path, mode="r") - # TODO: ====== Forward preprocessing config so the reader reads just what we need. - self._reader = reader_config.get_reader(memoryview(self._memmap)) + self._reader = reader_config.get_reader(memoryview(self._memmap), self._preprocessing) def __getstate__(self) -> tuple[str, pathlib.Path, dict, MemmapIndexDatasetReaderConfig]: # We pass the reader config to force its import in data loader workers. diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d0628e08f..91506e4d5 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,5 +1,6 @@ import collections import enum +import functools import json import logging import math @@ -196,18 +197,22 @@ def _prepare_shard( for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") ), LanguageModelWriter, - LanguageModelPreprocessingConfig( - tokenizer=self._config.tokenizer, - vocab_size=self._tokenizer.vocab_size, - image_patches=( - self._config.image_patches if self._source_schema.has_images else NullPreprocessingConfig() - ), - has_loss_masking_spans=self._source_schema.has_loss_masking_span, - has_preference_spans=self._source_schema.has_preference_spans, - ), + self._preprocessing_config, ) return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config + @functools.cached_property + def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: + return LanguageModelPreprocessingConfig( + tokenizer=self._config.tokenizer, + vocab_size=self._tokenizer.vocab_size, + image_patches=( + self._config.image_patches if self._source_schema.has_images else NullPreprocessingConfig() + ), + use_loss_masking_spans=self._source_schema.has_loss_masking_span, + use_preference_spans=self._source_schema.has_preference_spans, + ) + def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: text = sample[self._source_schema.text] all_spans = [] @@ -385,9 +390,8 @@ def _blend_dataset_configs( } ) - @classmethod def _split_and_blend_dataset_configs( - cls, + self, dataset_configs: list[MemmapDatasetConfig[_sample_type]], reader_configs: list[MemmapIndexDatasetReaderConfig], splits: dict[str, int | float], @@ -422,14 +426,16 @@ def _split_and_blend_dataset_configs( elif split_end_in_dataset > split_begin_in_dataset: # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). - dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() + dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build( + self._preprocessing_config + ) sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() Assert.eq(sizes_cumsum[-1], reader_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens) if end_index > begin_index: datasets_in_split.append( - DatasetSliceConfig[cls._sample_type].from_dict( + DatasetSliceConfig[self._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], @@ -451,7 +457,7 @@ def _split_and_blend_dataset_configs( elif len(datasets_in_split) == 1: dataset_splits[split_name] = datasets_in_split[0] else: - dataset_splits[split_name] = BlendedDatasetConfig[cls._sample_type].from_dict( + dataset_splits[split_name] = BlendedDatasetConfig[self._sample_type].from_dict( { "type": "blended", "datasets": datasets_in_split, diff --git a/fast_llm/data/preprocessing/abstract.py b/fast_llm/data/preprocessing/abstract.py index dc8c88375..ea1f910df 100644 --- a/fast_llm/data/preprocessing/abstract.py +++ b/fast_llm/data/preprocessing/abstract.py @@ -1,5 +1,6 @@ import logging import typing +import warnings from fast_llm.config import Config, config_class @@ -37,4 +38,5 @@ class NullPreprocessingConfig(PreprocessingConfig): _abstract = False def check_compatibility(self, preprocessing: typing.Self) -> None: - logger.warning("Dataset preprocessing config not specified, could not check compatibility with the model.") + if not isinstance(preprocessing, NullPreprocessingConfig): + warnings.warn(f"Preprocessing configuration not specified, could not check compatibility with the model.") diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py index 7c3d9d53b..146c5809b 100644 --- a/fast_llm/data/preprocessing/image_patch.py +++ b/fast_llm/data/preprocessing/image_patch.py @@ -59,6 +59,7 @@ class ImagePatchConfig(PreprocessingConfig): ) def check_compatibility(self, preprocessing: typing.Self) -> None: + Assert.custom(isinstance, preprocessing, ImagePatchConfig) Assert.eq(self.height, preprocessing.height) Assert.eq(self.width, preprocessing.width) Assert.eq(self.do_resize, preprocessing.do_resize) @@ -75,6 +76,10 @@ def num_channels(self) -> int: # assume 3 channels (RGB) for all images return 3 + @functools.cached_property + def patch_shape(self) -> tuple[int, int, int]: + return self.num_channels, self.height, self.width + @functools.cached_property def max_patches_height(self) -> int: return div(self.max_image_height, self.height) diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index d4e1235ae..6c38c3f4e 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -1,4 +1,5 @@ import functools +import logging import typing from fast_llm.config import Field, config_class @@ -7,21 +8,25 @@ from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + @config_class(dynamic_type={PreprocessingConfig: "language_model"}) class LanguageModelPreprocessingConfig(PreprocessingConfig): - tokenizer: TokenizerConfig = Field() + _abstract = False + tokenizer: PreprocessingConfig = Field() # We can't easily compare tokenizers, # and in any case the tokenizer path may no longer be valid when loading a prepared dataset, # so we provide the vocab size and use it for compatibility checks. - vocab_size: int = Field() image_patches: PreprocessingConfig = Field() + vocab_size: int = Field() use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) def _validate(self) -> None: super()._validate() Assert.custom(isinstance, self.image_patches, (ImagePatchConfig, NullPreprocessingConfig)) + Assert.custom(isinstance, self.tokenizer, (TokenizerConfig, NullPreprocessingConfig)) @functools.cached_property def use_image_patches(self) -> bool: @@ -31,10 +36,8 @@ def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? Assert.geq(self.vocab_size, preprocessing.vocab_size) - if preprocessing.use_loss_masking_spans: - assert self.use_loss_masking_spans if preprocessing.use_preference_spans: + # Preference spans are strictly needed for DPO loss. assert self.use_preference_spans - if preprocessing.use_image_patches: - assert self.use_image_patches + if preprocessing.use_image_patches and self.use_image_patches: self.image_patches.check_compatibility(preprocessing.image_patches) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index a0d460d4c..9e11fa66c 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -39,8 +39,6 @@ class TokenizerConfig(PreprocessingConfig): ) def get_tokenizer(self) -> "Tokenizer": - from fast_llm.data.preprocessing.tokenizer import Tokenizer - return Tokenizer(self) @@ -90,14 +88,20 @@ def tokenize( ) -> "torch.Tensor": import torch - tokens = torch.tensor( - ([self.bod_id] if begin else []) - + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.eod_id] if end else []), - dtype=data_type.torch, - ) + tokens = self.tokenizer.encode(text, add_special_tokens=False) + if begin: + tokens.insert(0, self.bod_id) + if end: + tokens.append(self.eod_id) + if self._config.max_vocab_size is not None: - tokens %= self._config.max_vocab_size + # In some cases creating a tensor before restricting the vocab size may cause an overflow. + ( + torch.tensor(tokens, dtype=torch.int64 if len(self.tokenizer) > torch.iinfo().max else data_type.torch) + % self._config.max_vocab_size + ).to(data_type.torch) + else: + tokens = torch.tensor(tokens, dtype=data_type.torch) return tokens def tokenize_with_spans( diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 973e29ad8..3fba789d1 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -109,8 +109,8 @@ class MemmapReaderConfig(MemmapReaderBaseConfig): def reader_class(self) -> "type[MemmapReader]": raise NotImplementedError() - def get_reader(self, buffer: memoryview) -> "MemmapReader": - return self.reader_class(self, buffer) + def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None) -> "MemmapReader": + return self.reader_class(self, buffer, model_preprocessing) @property def expected_buffer_size(self) -> int: @@ -156,16 +156,17 @@ def num_tokens(self) -> int: def reader_class(self) -> "type[MemmapIndexedDatasetReader]": raise NotImplementedError() - def get_reader( - self, - buffer: memoryview, - ) -> "MemmapIndexedDatasetReader": - return self.reader_class(self, buffer) + def get_reader(self, buffer: memoryview, model_preprocessing: PreprocessingConfig) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer, model_preprocessing) class MemmapReader[ConfigType: MemmapReaderConfig](Configurable[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): super().__init__(config) + # Note: This is the requirement at reading time (ex. from the model), + # which may differ from how the dataset was actually preprocessed (`config.preprocessing`) + # Compatibility checked in `MemmapDataset`. + self._model_preprocessing = NullPreprocessingConfig if model_preprocessing is None else model_preprocessing buffer_begin = self._config.begin + len(self._config.header) buffer_end = self._config.end - len(self._config.footer) Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 0e1baaef8..1331cf82a 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -1,13 +1,15 @@ import io +import logging import pathlib import tempfile import typing +import warnings import torch from fast_llm.config import Field, config_class from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig -from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig +from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig, ImagePatchConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, @@ -18,11 +20,14 @@ NullReaderConfig, Sample, ) -from fast_llm.data.sample.patch import PatchBatch, PatchSample, PatchWriter -from fast_llm.data.sample.range import RangeBatch, RangeSample, RangeWriter +from fast_llm.data.sample.patch import EmptyPatchReader, PatchBatch, PatchReaderConfig, PatchSample, PatchWriter +from fast_llm.data.sample.range import EmptyRangeReader, RangeBatch, RangeReaderConfig, RangeSample, RangeWriter from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class LanguageModelSample(Sample): def __init__( @@ -139,8 +144,45 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): def _validate(self) -> None: super()._validate() - # Dynamic type supported for backward compatibility. - Assert.custom(isinstance, self.preprocessing, (LanguageModelPreprocessingConfig, NullPreprocessingConfig)) + if isinstance(self.preprocessing, NullPreprocessingConfig): + # Address missing config, mostly for backward compatibility. + # TODO: We can't tell which dataset this comes from. + logger.warning( + f"Preprocessing configuration not specified for dataset reader, generating partial configuration from known parameters." + ) + if isinstance(self.image_patches, PatchReaderConfig): + Assert.eq(len(patch_shape := self.image_patches.patch_shape), 3) + image_patches = ImagePatchConfig(height=patch_shape[1], width=patch_shape[2]) + else: + image_patches = NullPreprocessingConfig() + self.preprocessing = LanguageModelPreprocessingConfig( + vocab_size=0, + image_patches=image_patches, + use_loss_masking_spans=isinstance(self.loss_masking_spans, RangeReaderConfig), + use_preference_spans=isinstance(self.chosen_spans, RangeReaderConfig), + ) + # TODO: Avoid duplicated information. + Assert.custom( + isinstance, + self.loss_masking_spans, + RangeReaderConfig if self.preprocessing.use_loss_masking_spans else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.chosen_spans, + RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.rejected_spans, + RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, + ) + if self.preprocessing.use_image_patches: + Assert.custom(isinstance, self.image_patches, PatchReaderConfig) + Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape) + Assert.eq(self.image_patches.data_type, DataType.uint8) + else: + Assert.custom(isinstance, self.image_patches, NullReaderConfig) def __len__(self) -> int: return len(self.tokens) @@ -169,17 +211,59 @@ def _expected_buffer_size(self) -> int: class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + _model_preprocessing: LanguageModelPreprocessingConfig + + def __init__( + self, + config: ConfigType, + buffer: memoryview, + model_preprocessing: LanguageModelPreprocessingConfig | None = None, + ): + super().__init__(config, buffer, model_preprocessing) + self._config.preprocessing.check_compatibility(self._model_preprocessing) # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. self._tokens = self._config.tokens.get_reader(buffer) - self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) - self._chosen_spans = self._config.chosen_spans.get_reader(buffer) - self._rejected_spans = self._config.rejected_spans.get_reader(buffer) - self._image_patches = self._config.image_patches.get_reader(buffer) - if self._image_patches is not None: - # TODO: Make this configurable. + if self._model_preprocessing.use_loss_masking_spans: + if isinstance(self._config.loss_masking_spans, NullReaderConfig): + # TODO: We can't tell which dataset this comes from. + warnings.warn( + f"The model uses loss masking spans, but the dataset does not specify any." + " Assuming empty span lists." + ) + self._loss_masking_spans = EmptyRangeReader( + RangeReaderConfig(begin=0, end=0, num_documents=0, num_ranges=0), buffer + ) + else: + self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) + + if self._model_preprocessing.use_preference_spans: + self._chosen_spans = self._config.chosen_spans.get_reader(buffer) + self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + + if self._model_preprocessing.use_image_patches: + model_image_preprocessing: ImagePatchConfig = self._model_preprocessing.image_patches + if isinstance(self._config.image_patches, NullReaderConfig): + warnings.warn( + f"The model uses image patches, but the dataset does not specify any." + " Assuming empty patch lists." + ) + self._image_patches = EmptyPatchReader( + PatchReaderConfig( + begin=0, + end=0, + num_documents=0, + num_patches=0, + num_patch_groups=0, + patch_shape=model_image_preprocessing.patch_shape, + data_type=DataType.uint8, + ), + buffer, + ) + else: + self._image_patches = self._config.image_patches.get_reader(buffer) + + # TODO: Make this configurable. (Add to `model_preprocessing`?) self._image_normalization_config = ImageNormalizationConfig() @property @@ -187,16 +271,28 @@ def num_tokens(self) -> int: return self._config.tokens.num_tokens def get_document(self, index: int, begin: int, end: int) -> Sample: - if self._image_patches is None: - image_patches = None - else: + if self._model_preprocessing.use_image_patches: image_patches = self._image_patches.get_document(index, begin, end) image_patches.patches = self._image_normalization_config.normalize(image_patches.patches) + else: + image_patches = None return LanguageModelSample( self._tokens.get_document(index, begin, end), - None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end), - None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end), - None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end), + ( + self._loss_masking_spans.get_document(index, begin, end) + if self._model_preprocessing.use_loss_masking_spans + else None + ), + ( + self._chosen_spans.get_document(index, begin, end) + if self._model_preprocessing.use_preference_spans + else None + ), + ( + self._rejected_spans.get_document(index, begin, end) + if self._model_preprocessing.use_preference_spans + else None + ), image_patches, ) @@ -334,7 +430,7 @@ def _get_config(self, begin: int, end: int | None): chosen_spans=chosen_spans, rejected_spans=rejected_spans, image_patches=image_patches, - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index a75684d76..9ec991cf0 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -5,6 +5,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapReader, @@ -91,6 +92,16 @@ def get_padding(self, size: int) -> typing.Self: [], ) + @classmethod + def get_empty(cls, size: int, shape: tuple[int, ...]) -> typing.Self: + return PatchSample( + self.patches.new_empty((0, *shape[1:])), + self.token_map.new_empty(0), + self.positions.new_empty([0, len(shape) - 2]), + size, + [], + ) + class PatchBatch(Batch): def __init__( @@ -188,8 +199,8 @@ def _expected_buffer_size(self) -> int: class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) self._patches = torch.frombuffer( self._buffer, dtype=self._config.data_type.torch, @@ -248,6 +259,16 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: ) +class EmptyPatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + return PatchSample( + torch.empty(0, *self._config.patch_shape, dtype=self._config.data_type.torch), + torch.empty(0, dtype=torch.int32), + torch.empty(0, self._config.grid_dims, dtype=torch.int32), + end - begin, + ) + + class PatchWriter(MemmapWriter): def __enter__(self): super().__enter__() @@ -300,5 +321,5 @@ def _get_config(self, begin: int, end: int): num_patch_groups=self._group_count_cumsum[-1], patch_shape=self._patch_shape, data_type=DataType.from_torch(self._data_type), - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 0022b3593..f34cc1343 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -4,6 +4,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapReader, @@ -85,8 +86,8 @@ def _expected_buffer_size(self) -> int: class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) self._ranges = torch.frombuffer( self._buffer, dtype=torch.int32, @@ -108,6 +109,11 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) +class EmptyRangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + return RangeSample([], end - begin) + + class RangeWriter(MemmapWriter): def __enter__(self): super().__enter__() @@ -135,5 +141,5 @@ def _get_config(self, begin: int, end: int): end=end, num_documents=len(self._count_cumsum) - 1, num_ranges=self._count_cumsum[-1], - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 3f5912e5e..04898a12f 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -4,6 +4,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, MemmapIndexedDatasetReader, @@ -111,8 +112,8 @@ def _expected_buffer_size(self) -> int: class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): - def __init__(self, config: ConfigType, buffer: memoryview): - super().__init__(config, buffer) + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) self._tokens = torch.frombuffer( self._buffer, dtype=self._config.data_type.torch, @@ -166,5 +167,5 @@ def _get_config(self, begin: int, end: int): num_documents=len(self._size_cumsum) - 1, num_tokens=self._size_cumsum[-1], data_type=DataType.from_torch(self._data_type), - preprocessing_config=self._preprocessing_config, + preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 1c7be33dd..768d3fdd7 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -2,7 +2,7 @@ import typing from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -19,20 +19,16 @@ def _get_data(self) -> GPTData: def _get_sampling_parameters( self, parameters: dict[str, typing.Any], *, _return_dict: bool = False - ) -> GPTSamplingParameters | dict[str, typing.Any]: + ) -> SamplingParameters | dict[str, typing.Any]: parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - # "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, - # "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - # OK since DPO is not supported for MTP. - # "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) - return parameters if _return_dict else GPTSamplingParameters(**parameters) + return parameters if _return_dict else SamplingParameters(**parameters) def _get_preprocessing_config( self, *, _return_dict: bool = False diff --git a/tests/data/common.py b/tests/data/common.py index ac8d8023c..210749864 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,10 +8,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, ShufflingType -from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters +from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, SamplingParameters, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig @@ -25,10 +26,10 @@ def get_sampling_data( cache_directory: pathlib.Path | None = None, phase=PhaseType.training, sequence_length: int = 512, - vocab_size: int | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, + preprocessing: LanguageModelPreprocessingConfig, ) -> GPTSamplingData: # Config with convenient defaults. distributed = Distributed(DistributedConfig(), use_cpu=True) @@ -38,12 +39,12 @@ def get_sampling_data( gpu=gpu, shuffle=shuffle, ), - parameters=GPTSamplingParameters( + parameters=SamplingParameters( num_samples=num_samples, sequence_length=sequence_length, - vocab_size=vocab_size, truncate_documents=truncate_documents, ), + preprocessing=preprocessing, cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, @@ -65,8 +66,8 @@ def get_test_data_and_compare_samples( shuffle: ShufflingType = ShufflingType.epoch, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, - vocab_size: int | None = None, expected_samples: dict[str, list[list[int]]] | list[list[int]], + preprocessing: LanguageModelPreprocessingConfig, ) -> GPTData: distributed_config = DistributedConfig(seed=87522) distributed = Distributed(distributed_config, use_cpu=True) @@ -74,11 +75,7 @@ def get_test_data_and_compare_samples( samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} sampling_parameters = { - dataset_name: GPTSamplingParameters( - num_samples=num_samples, - sequence_length=sequence_length, - vocab_size=vocab_size, - ) + dataset_name: SamplingParameters(num_samples=num_samples, sequence_length=sequence_length) for dataset_name, num_samples in samples_per_dataset.items() } @@ -88,7 +85,7 @@ def get_test_data_and_compare_samples( assert "sampling" not in config config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) - data.setup(distributed, sampling_parameters, cache_directory) + data.setup(distributed, sampling_parameters, preprocessing, cache_directory) with NoAutoValidate(): batch_config = GPTBatchConfig(batch_size=1, sequence_length=sequence_length) batch_config.setup(distributed_config) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 88ecf2c99..5cad573ca 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,6 +4,7 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( @@ -84,7 +85,7 @@ def test_blending(probs): # Use a list of integers as a mock dataset, encoding both indexes in the sample. [list(range(i * num_samples, (i + 1) * num_samples)) for i, _ in enumerate(probs)], # noqa probs, - get_sampling_data(num_samples), + get_sampling_data(num_samples, preprocessing=LanguageModelPreprocessingConfig(vocab_size=8192)), ) probs = normalize_probabilities(probs) samples = np.array([dataset[i] for i in range(num_samples)]) @@ -106,8 +107,8 @@ def test_blending(probs): def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() - _, alt_config, _ = get_alt_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() + _, alt_config, _, _ = get_alt_test_dataset() sampled = get_dataset_config( dataset_config := { "type": "blended", @@ -115,7 +116,7 @@ def test_gpt_blended(): "weights": [0.75, 0.25], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) # Test in data. @@ -124,12 +125,15 @@ def test_gpt_blended(): 8, sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, + preprocessing=preprocessing, ) def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() + # Random dataset needs an explicit vocab size. + preprocessing = preprocessing.from_dict(preprocessing, {"vocab_size": 8192}) sampled = get_dataset_config( dataset_config := { "type": "blended", @@ -140,7 +144,7 @@ def test_gpt_blended_mixed(): "weights": [0.6, 0.4], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) # Test in data. @@ -148,6 +152,6 @@ def test_gpt_blended_mixed(): {"datasets": {"training": dataset_config}}, 8, sequence_length=5, - vocab_size=8192, expected_samples=GPT_BLENDED_MIXED_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index d7e750c8b..1580842b7 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,6 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset_tokens, @@ -25,19 +26,19 @@ def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() dataset = get_dataset_config( dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], - ).build() + ).build(LanguageModelPreprocessingConfig(vocab_size=0)) compare_indexed_dataset_tokens( dataset, 3 * COMMON_DATASET_LENGTH, 3 * COMMON_DATASET_TOKENS, {j * COMMON_DATASET_LENGTH + i: sample for j in range(3) for i, sample in COMMON_DATASET_SAMPLES.items()}, ) - sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) + sampled = dataset.sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) compare_sampled_dataset(sampled, GPT_CONCATENATED_SAMPLES) # Test in data. @@ -46,4 +47,5 @@ def test_gpt_concatenate(): 8, sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 0600c5258..fd1aefbd8 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -22,9 +22,9 @@ def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. - sampling_config = get_sampling_data(8, sequence_length=5) + sampling_config = get_sampling_data(8, sequence_length=5, preprocessing=preprocessing) sampled = get_dataset_config( dataset_config := { "type": "fim", @@ -45,4 +45,5 @@ def test_gpt_fim(): 8, sequence_length=5, expected_samples=GPT_FIM_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_image_patch.py b/tests/data/test_image_patch.py index 86fe9c70a..9ef20a8a6 100644 --- a/tests/data/test_image_patch.py +++ b/tests/data/test_image_patch.py @@ -6,7 +6,8 @@ import PIL.Image import pytest -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert @@ -123,8 +124,10 @@ def _get_image_tokens( @pytest.mark.parametrize("image_break_token", (None, 55)) @pytest.mark.parametrize("image_end_token", (None, 132)) def test_gpt_data_with_image_patches(image_break_token, image_end_token): - _, config, hf_path = get_test_dataset_with_image_patches(image_break_token, image_end_token) - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + _, config, hf_path, preprocessing = get_test_dataset_with_image_patches(image_break_token, image_end_token) + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + preprocessing + ) test_index = 2 * (image_break_token is not None) + (image_end_token is not None) hf_dataset = datasets.load_from_disk(hf_path)["train"] @@ -146,9 +149,7 @@ def test_gpt_data_with_image_patches(image_break_token, image_end_token): ) Assert.eq(hf_dataset[index]["image_positions"], DATASET_WITH_IMAGE_PATCHES_IMAGE_POSITIONS[index]) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_images=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) expected_tokens = [ tokens for token_or_patches in DATASET_WITH_IMAGE_PATCHES_SAMPLES[index] diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index 443a26819..2d112a5c1 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -1,7 +1,8 @@ import datasets import pytest -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample @@ -37,8 +38,10 @@ @pytest.mark.slow def test_gpt_data_with_spans(): - _, config, hf_path = get_test_dataset_with_loss_masking_spans() - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + _, config, hf_path, preprocessing = get_test_dataset_with_loss_masking_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + preprocessing + ) hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() @@ -54,9 +57,7 @@ def test_gpt_data_with_spans(): hf_dataset[index]["text"], text_spans=[(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]], ) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) # Compare tokens and token spans. Assert.all_equal(document.tokens.tokens, expected_tokens) @@ -73,8 +74,6 @@ def test_gpt_data_with_spans(): for index in DATASET_WITH_SPAN_SAMPLES: Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) Assert.eq(hf_dataset[index]["loss_masking_spans"], HF_LOSS_MASKING_SPANS[index]) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_SPAN_SAMPLES[index]) Assert.eq(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py index ef18337eb..35c290670 100644 --- a/tests/data/test_preference_spans.py +++ b/tests/data/test_preference_spans.py @@ -3,7 +3,8 @@ import pytest import torch -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample @@ -39,8 +40,10 @@ @pytest.mark.slow def test_gpt_data_with_spans(): - _, config, hf_path = get_test_dataset_with_preference_spans() - dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + _, config, hf_path, preprocessing = get_test_dataset_with_preference_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build( + preprocessing + ) hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() @@ -79,9 +82,7 @@ def test_gpt_data_with_spans(): (token_length_cumsum[4], token_length_cumsum[5]), ] - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) token_spans = document.chosen_spans.ranges + document.rejected_spans.ranges # Compare tokens and token spans. @@ -100,8 +101,6 @@ def test_gpt_data_with_spans(): DATASET_WITH_PREFERENCE_SPAN_TEXT[index], ) - document = dataset.get_document( - index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) - ) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_PREFERENCE_SPAN_SAMPLES[index]) Assert.eq(document.chosen_spans.ranges + document.rejected_spans.ranges, TOKEN_PREFERENCE_SPANS[index]) diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 729888d9c..dd4375418 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -3,10 +3,11 @@ import datasets import pytest -from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import BlendedDatasetConfig, MemmapDatasetConfig, SamplingParameters +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.utils import Assert from tests.data.common import get_dataset_config @@ -42,11 +43,11 @@ def test_common_prepared_dataset(): We already test the dataset preparator indirectly through the test dataset (`get_test_dataset`). Here we verify the correctness of the prepared dataset directly and check for regressions. """ - path, config, hf_path = get_common_test_dataset() - dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build() + path, config, hf_path, preprocessing = get_common_test_dataset() + dataset = get_dataset_config(config, GPTDatasetFromFileConfig).build(preprocessing) dataset_from_shard = get_dataset_config( {"type": "memmap", "path": path / "shard_0_0.fast_llm_dataset"}, MemmapDatasetConfig - ).build() + ).build(preprocessing) hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() @@ -71,18 +72,18 @@ def test_common_prepared_dataset(): # Check some numerical values. for index in COMMON_DATASET_SAMPLES: Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) - document = dataset.get_document(index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0)) + document = dataset.get_document(index, parameters=SamplingParameters(num_samples=0, sequence_length=0)) Assert.eq(document.tokens.tokens.tolist(), COMMON_DATASET_SAMPLES[index]) @pytest.mark.slow def test_preparator_sharded(): - path, config, hf_path = get_sharded_test_dataset() + path, config, hf_path, preprocessing = get_sharded_test_dataset() dataset_config = get_dataset_config(config, GPTDatasetFromFileConfig)._load_config() Assert.custom(isinstance, dataset_config, BlendedDatasetConfig) Assert.eq(dataset_config.weights, [0.33003587104248827, 0.3455874161709333, 0.3243767127865784]) - datasets_ = [dataset_config_.build() for dataset_config_ in dataset_config.datasets] + datasets_ = [dataset_config_.build(preprocessing) for dataset_config_ in dataset_config.datasets] Assert.eq([len(dataset) for dataset in datasets_], lengths := [334, 333, 333]) Assert.eq([dataset.num_tokens for dataset in datasets_], [14813, 15511, 14559]) @@ -101,7 +102,7 @@ def test_preparator_sharded(): @pytest.mark.slow def test_preparator_split(): - path, config, hf_path = get_split_test_dataset() + path, config, hf_path, _ = get_split_test_dataset() dataset_config = { split: get_dataset_config(split_config, GPTDatasetFromFileConfig)._load_config().to_dict() for split, split_config in config.items() @@ -125,7 +126,7 @@ def test_preparator_split(): @pytest.mark.slow def test_preparator_split_sharded(): - path, config, hf_path = get_split_sharded_test_dataset() + path, config, hf_path, _ = get_split_sharded_test_dataset() dataset_config = { split: get_dataset_config(split_config, GPTDatasetFromFileConfig)._load_config().to_dict() for split, split_config in config.items() @@ -182,7 +183,9 @@ def test_dataset_preparator_from_hub(): assert (croissant_path := output_path / "croissant.json").is_file() Assert.eq(json.load(croissant_path.open("r"))["url"], "https://huggingface.co/datasets/openai/gsm8k") - dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build() + dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build( + LanguageModelPreprocessingConfig(vocab_size=0) + ) Assert.custom(isinstance, dataset, MemmapDataset) hf_dataset = datasets.load_dataset("openai/gsm8k", "main", split="test") diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 7a31358b9..d32fb9880 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -1,4 +1,5 @@ from fast_llm.data.dataset.gpt.config import GPTRandomDatasetConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -16,8 +17,9 @@ def test_gpt_random_dataset(): # Make sure the random dataset works and check for unintended changes in behavior. + preprocessing = LanguageModelPreprocessingConfig(vocab_size=8192) sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( - get_sampling_data(4, sequence_length=7, vocab_size=8192) + get_sampling_data(4, sequence_length=7, preprocessing=preprocessing) ) compare_sampled_dataset(sampled, RANDOM_DATASET_EXPECTED_SAMPLES) @@ -26,6 +28,6 @@ def test_gpt_random_dataset(): {"datasets": {"training": config}}, 4, sequence_length=7, - vocab_size=8192, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, + preprocessing=preprocessing, ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 2d102be01..d6a935c61 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,9 +2,10 @@ import pytest import torch -from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import SamplingParameters, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert @@ -38,10 +39,10 @@ def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() sampled = get_dataset_config( dataset_config := config, GPTDatasetFromFileConfig[LanguageModelSample] - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) # Test in data. @@ -50,6 +51,7 @@ def test_gpt_sampled(): 8, sequence_length=5, expected_samples=GPT_MEMMAP_SAMPLES, + preprocessing=preprocessing, ) @@ -59,7 +61,7 @@ def __init__(self, samples): self._samples = samples def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None ) -> SampleType: if end is None: end = len(self._samples[index]) @@ -98,7 +100,15 @@ def test_gpt_sample(seed, shuffle): previous_samples = None # Loop instead of parametrizing for the check below. for num_samples in (20, 10, 6, 5, 2, 1): - sampled = TEST_DATASET.sample(get_sampling_data(num_samples, sequence_length=5, seed=seed, shuffle=shuffle)) + sampled = TEST_DATASET.sample( + get_sampling_data( + num_samples, + sequence_length=5, + seed=seed, + shuffle=shuffle, + preprocessing=LanguageModelPreprocessingConfig(vocab_size=0), + ) + ) samples = validate_indexed_dataset_sampling(sampled) if previous_samples is not None and shuffle != ShufflingType.full: # Check that the sequence is independent of `num_sample`. @@ -162,6 +172,7 @@ def test_gpt_sample_padding(): seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, + preprocessing=LanguageModelPreprocessingConfig(vocab_size=vocab_size), ) if total_tokens == 0: with pytest.raises(RuntimeError): diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 224b18270..54263b8e2 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -31,15 +31,15 @@ def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. - _, config, _ = get_common_test_dataset() + _, config, _, preprocessing = get_common_test_dataset() memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": memmap_config, "begin": 0.025, "end": 0.1}, DatasetSliceConfig[LanguageModelSample], - ).build() + ).build(preprocessing) compare_indexed_dataset_tokens(dataset, 75, 3399, {i - 25: sample for i, sample in COMMON_DATASET_SAMPLES.items()}) - sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) + sampled = dataset.sample(get_sampling_data(8, sequence_length=5, preprocessing=preprocessing)) validate_indexed_dataset_sampling(sampled, GPT_SLICE_VALIDATION_SAMPLES) # Test in data with multiple phases. @@ -72,4 +72,5 @@ def test_gpt_slice(): "training": GPT_SLICE_TRAINING_SAMPLES, "validation": GPT_SLICE_VALIDATION_SAMPLES, }, + preprocessing=preprocessing, ) diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index e2cadf717..e29050b28 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -44,7 +44,7 @@ def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): and prefix.with_suffix(".bin").is_file() and prefix.parent.joinpath("fast_llm_config.yaml").is_file() ): - _, _, hf_path = get_common_test_dataset() + _, _, hf_path, _ = get_common_test_dataset() hf_dataset = datasets.load_from_disk(hf_path)["train"] tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() samples = [ diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index ed3f01307..7348a79ef 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -7,7 +7,9 @@ import PIL.Image from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImagePatchConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.utils import padded_cumsum from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH @@ -158,7 +160,7 @@ def _get_test_dataset( path: pathlib.Path, seed: int, tokenizer_path: str = TOKENIZER_PATH, - vocab_size: int | None = None, + max_vocab_size: int | None = None, documents_per_shard: int = 10**6, num_documents: int = 1000, min_document_size: int = 5, @@ -173,7 +175,7 @@ def _get_test_dataset( min_image_size: int = 4, max_image_size: int = 32, config_only: bool = False, -) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path]: +) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: config_paths = ( [path / "fast_llm_config.yaml"] if splits is None @@ -214,7 +216,7 @@ def _get_test_dataset( "load_from_disk": True, "source_schema": source_schema, }, - "tokenizer": {"path": tokenizer_path, "max_vocab_size": vocab_size}, + "tokenizer": {"path": tokenizer_path, "max_vocab_size": max_vocab_size}, "output_path": path, "documents_per_shard": documents_per_shard, "splits": splits, @@ -231,28 +233,45 @@ def _get_test_dataset( for split, config_path in zip(splits, config_paths, strict=True) } ) - return path, config, hf_path + preprocessing = LanguageModelPreprocessingConfig( + tokenizer={"type": "tokenizer", "path": tokenizer_path, "max_vocab_size": max_vocab_size}, + image_patches=NullPreprocessingConfig() if image_patch_config is None else image_patch_config, + vocab_size=max_vocab_size or 0, + use_loss_masking_spans=max_loss_masking_spans > 0, + use_preference_spans=has_preference_spans, + ) + return path, config, hf_path, preprocessing -def get_common_test_dataset(): +def get_common_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "common_dataset", seed=1234) -def get_alt_test_dataset(): +def get_alt_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "other_dataset", seed=2345) -def get_sharded_test_dataset(): +def get_sharded_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "common_dataset_sharded", seed=1234, documents_per_shard=350) -def get_split_test_dataset(): +def get_split_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset( DATASET_CACHE / "common_dataset_split", seed=1234, splits={"training": 1, "validation": 1} ) -def get_split_sharded_test_dataset(): +def get_split_sharded_test_dataset() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset( DATASET_CACHE / "common_dataset_split_sharded", seed=1234, @@ -261,15 +280,21 @@ def get_split_sharded_test_dataset(): ) -def get_test_dataset_with_loss_masking_spans(): +def get_test_dataset_with_loss_masking_spans() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, max_loss_masking_spans=5) -def get_test_dataset_with_preference_spans(): +def get_test_dataset_with_preference_spans() -> ( + tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig] +): return _get_test_dataset(DATASET_CACHE / "dataset_with_preference_spans", seed=1234, has_preference_spans=True) -def get_test_dataset_with_image_patches(image_break_token: int | None = None, image_end_token: int | None = None): +def get_test_dataset_with_image_patches( + image_break_token: int | None = None, image_end_token: int | None = None +) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: return _get_test_dataset( DATASET_CACHE / f"dataset_with_image_patches_{image_break_token}_{image_end_token}", seed=1234, @@ -289,7 +314,7 @@ def get_model_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, - vocab_size=MODEL_TEST_VOCAB_SIZE, + max_vocab_size=MODEL_TEST_VOCAB_SIZE, splits={"training": 969, "validation": 30, "test": 1}, config_only=config_only, ) @@ -299,7 +324,7 @@ def get_multimodal_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset_multimodal", seed=1234, - vocab_size=MODEL_TEST_VOCAB_SIZE, + max_vocab_size=MODEL_TEST_VOCAB_SIZE, max_images=2, image_patch_config=ImagePatchConfig( height=4, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 752e3a8c8..186991ed5 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -95,7 +95,7 @@ class ModelTestingConfig: ) def __post_init__(self): - _, config, _ = self.get_dataset(config_only=True) + _, config, _, _ = self.get_dataset(config_only=True) self.config_dict["data"]["datasets"] = config @functools.cached_property From d27a8151b638e3ef31eb947cef66d7bd8121cb34 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Dec 2025 17:21:29 -0500 Subject: [PATCH 05/14] fix --- fast_llm/data/preprocessing/tokenizer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 9e11fa66c..2963e8e63 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -96,8 +96,11 @@ def tokenize( if self._config.max_vocab_size is not None: # In some cases creating a tensor before restricting the vocab size may cause an overflow. - ( - torch.tensor(tokens, dtype=torch.int64 if len(self.tokenizer) > torch.iinfo().max else data_type.torch) + tokens = ( + torch.tensor( + tokens, + dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + ) % self._config.max_vocab_size ).to(data_type.torch) else: From 5ab6cd03b28506e6847b7b7cea08f083600e9b07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 5 Dec 2025 22:30:27 -0500 Subject: [PATCH 06/14] fixes --- fast_llm/data/preprocessing/language_model.py | 5 ++-- fast_llm/data/sample/language_model.py | 1 - fast_llm/engine/checkpoint/distributed.py | 21 ++++++++------ fast_llm/engine/multi_stage/fsdp.py | 6 ++++ fast_llm/engine/multi_stage/multi_stage.py | 22 ++++++++++---- fast_llm/models/multimodal/trainer.py | 1 + fast_llm/utils.py | 12 +++++--- tests/data/common.py | 7 ++++- tests/data/test_blending.py | 11 ++++--- tests/data/test_concatenate.py | 2 +- tests/data/test_preparator.py | 2 +- tests/data/test_sampling.py | 3 -- tests/models/distributed_test_checkpoint.py | 2 +- tests/models/test_checkpoint.py | 4 ++- tests/test_varlen.py | 9 ++---- tests/utils/dataset.py | 2 +- tests/utils/model_configs.py | 29 +++++++++---------- 17 files changed, 81 insertions(+), 58 deletions(-) diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index 6c38c3f4e..88ec8f245 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -19,7 +19,7 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig): # and in any case the tokenizer path may no longer be valid when loading a prepared dataset, # so we provide the vocab size and use it for compatibility checks. image_patches: PreprocessingConfig = Field() - vocab_size: int = Field() + vocab_size: int | None = Field(default=None) use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) @@ -35,7 +35,8 @@ def use_image_patches(self) -> bool: def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? - Assert.geq(self.vocab_size, preprocessing.vocab_size) + if self.vocab_size is not None and preprocessing.vocab_size is not None: + Assert.leq(self.vocab_size, preprocessing.vocab_size) if preprocessing.use_preference_spans: # Preference spans are strictly needed for DPO loss. assert self.use_preference_spans diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 1331cf82a..beadb1161 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -156,7 +156,6 @@ def _validate(self) -> None: else: image_patches = NullPreprocessingConfig() self.preprocessing = LanguageModelPreprocessingConfig( - vocab_size=0, image_patches=image_patches, use_loss_masking_spans=isinstance(self.loss_masking_spans, RangeReaderConfig), use_preference_spans=isinstance(self.chosen_spans, RangeReaderConfig), diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index c2f4d8cdd..d953ea35d 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -120,12 +120,15 @@ def _copy_shard_overlaps(self, loaded_model, loaded_shards, context): self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in loaded_shards} - for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): - for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): - counter = self_fsdp.copy_shard_overlaps( - loaded_fsdp, - self_fsdp_shards, - loaded_fsdp_shards, - ) - for parameter, count in counter.items(): - context.mark_as_loaded(count, parameter, True) + for loaded_stage, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards): + # Skip tied weight copies to avoid duplicate loads. + # We can't call `loaded_stage.is_tied_weight_copy` because the loaded model isn't setup. + if not loaded_stage.index not in loaded_model.stages_owned: + for self_stage, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards): + counter = self_fsdp.copy_shard_overlaps( + loaded_fsdp, + self_fsdp_shards, + loaded_fsdp_shards, + ) + for parameter, count in counter.items(): + context.mark_as_loaded(count, parameter, True) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 827079f6e..36e8ff20d 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -1,4 +1,5 @@ import dataclasses +import logging import math import typing @@ -18,6 +19,8 @@ from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta from fast_llm.utils import Assert, clamp, padded_cumsum +logger = logging.getLogger(__name__) + class FSDP: _is_setup: bool = False @@ -276,6 +279,9 @@ def split_buffer(self, buffer: torch.Tensor) -> dict[str, torch.Tensor]: return {name: self._get_parameter_in_buffer(buffer, name) for name in self._parameter_metas} def _get_parameter_in_buffer(self, buffer: torch.Tensor, name: str) -> torch.Tensor: + logger.info( + f"{name}, {self.get_parameter_begin_in_buffer(name)}, {self.get_parameter_end_in_buffer(name)}, {buffer.shape}, {self._parameter_metas[name]}" + ) return buffer[self.get_parameter_begin_in_buffer(name) : self.get_parameter_end_in_buffer(name)].view( self._parameter_metas[name].shape ) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index f45f93862..89be60c24 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -3,6 +3,7 @@ import typing import warnings +import safetensors.torch import torch from torch._C._distributed_c10d import ProcessGroup @@ -21,6 +22,7 @@ from fast_llm.utils import Assert, get_unique logger = logging.getLogger(__name__) +safetensors.torch.safe_open class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): @@ -426,6 +428,10 @@ def stages(self) -> list[Stage]: def stages_on_device(self) -> dict[int, Stage]: return self._stages_on_device + @property + def stages_owned(self) -> dict[int, Stage]: + return self._stages_owned + @property def tied_parameters(self) -> dict[str, "TiedParameter"]: return self._tied_parameters @@ -485,11 +491,17 @@ def get_state_tensor_iterator( ) -> typing.Generator[tuple[str, str, torch.Tensor], None, None]: for shard_name in shard_names: shard_split = self._shards[shard_name].split(self._stage_weight_shard_sizes, 0) - for shard_index, (stage, shard) in enumerate(zip(self._stages_owned.values(), shard_split, strict=True)): - for name, tensor in stage._export_shard( - shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type - ): # noqa - yield name, shard_name, tensor + logger.info( + f"{shard_name}, {self._shards[shard_name].shape}, {self._stage_weight_shard_sizes}, {self._stages_owned.values}, {[x.shape for x in shard_split]}" + ) + for shard_index, ((stage_index, stage), shard) in enumerate( + zip(self._stages_on_device.items(), shard_split, strict=True) + ): + if stage_index in self._stages_owned: + for name, tensor in stage._export_shard( + shard.split(self._fsdp_weight_shard_sizes[shard_index]), data_type=data_type + ): # noqa + yield name, shard_name, tensor def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torch.Tensor | SafeTensorSlice): """ diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py index cd8e09cae..43a8f8885 100644 --- a/fast_llm/models/multimodal/trainer.py +++ b/fast_llm/models/multimodal/trainer.py @@ -14,6 +14,7 @@ def _get_preprocessing_config( ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: out = super()._get_preprocessing_config(_return_dict=True) out["image_patches"] = { + "type": "image_patch", "height": self._config.model.base_model.vision_encoder.embeddings.patch_height, "width": self._config.model.base_model.vision_encoder.embeddings.patch_width, # TODO: Max shape and special tokens are unspecified in the model. diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 83675ac74..259073e32 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -146,19 +146,23 @@ def multiple(x, y): assert x % y == 0, f"{x} not a multiple of {y}" @staticmethod - def rms_close(x, y, threshold): + def rms_close(x, y, threshold, *, msg=None): rms = rms_diff(x, y).detach().item() - assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + ( + "" if msg is None else f"| {msg}" + ) @staticmethod - def rms_close_relative(x, y, threshold, min_threshold=0): + def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None): import torch Assert.eq(x.shape, y.shape) scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5 threshold = max(threshold * scale, min_threshold) rms = rms_diff(x, y).item() - assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + ( + "" if msg is None else f"| {msg}" + ) @staticmethod def all_equal(x, *args): diff --git a/tests/data/common.py b/tests/data/common.py index 210749864..34fdba321 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -29,10 +29,12 @@ def get_sampling_data( gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, - preprocessing: LanguageModelPreprocessingConfig, + preprocessing: LanguageModelPreprocessingConfig | None = None, ) -> GPTSamplingData: # Config with convenient defaults. distributed = Distributed(DistributedConfig(), use_cpu=True) + if preprocessing is None: + preprocessing = LanguageModelPreprocessingConfig() return GPTSamplingData( config=SamplingConfig( seed=seed, @@ -122,6 +124,9 @@ def compare_indexed_dataset_tokens( def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: + # Uncomment to print the current list of samples. + # for i in range(len(expected_samples)): + # print(i, sampled[i].tokens.tokens.tolist()) Assert.eq(len(sampled), len(expected_samples)) Assert.all_equal(torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]), expected_samples) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 5cad573ca..989e99b24 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,7 +4,6 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( @@ -44,12 +43,12 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, GPT_BLENDED_MIXED_SAMPLES = [ [49152, 46, 10, 819, 19, 45], - [916, 6683, 7685, 1277, 5106, 378], + [25492, 15877, 37874, 8570, 31649, 15521], [45, 69, 17, 86, 38826, 15], - [3359, 6803, 780, 4561, 669, 7878], + [3359, 20945, 33437, 32454, 42084, 45942], [15, 25, 51, 31, 32348, 64], [64, 17, 93, 78, 40, 1793], - [6920, 2218, 2921, 3963, 7606, 6904], + [15112, 36731, 47864, 35586, 33356, 37537], [1793, 1, 1746, 38, 27, 58], ] @@ -85,7 +84,7 @@ def test_blending(probs): # Use a list of integers as a mock dataset, encoding both indexes in the sample. [list(range(i * num_samples, (i + 1) * num_samples)) for i, _ in enumerate(probs)], # noqa probs, - get_sampling_data(num_samples, preprocessing=LanguageModelPreprocessingConfig(vocab_size=8192)), + get_sampling_data(num_samples), ) probs = normalize_probabilities(probs) samples = np.array([dataset[i] for i in range(num_samples)]) @@ -133,7 +132,7 @@ def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. _, config, _, preprocessing = get_common_test_dataset() # Random dataset needs an explicit vocab size. - preprocessing = preprocessing.from_dict(preprocessing, {"vocab_size": 8192}) + preprocessing = preprocessing.from_dict(preprocessing, {"vocab_size": 50000}) sampled = get_dataset_config( dataset_config := { "type": "blended", diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 1580842b7..19539cc8c 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -31,7 +31,7 @@ def test_gpt_concatenate(): dataset = get_dataset_config( dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], - ).build(LanguageModelPreprocessingConfig(vocab_size=0)) + ).build(LanguageModelPreprocessingConfig()) compare_indexed_dataset_tokens( dataset, 3 * COMMON_DATASET_LENGTH, diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index dd4375418..f4f6fab82 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -184,7 +184,7 @@ def test_dataset_preparator_from_hub(): Assert.eq(json.load(croissant_path.open("r"))["url"], "https://huggingface.co/datasets/openai/gsm8k") dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build( - LanguageModelPreprocessingConfig(vocab_size=0) + LanguageModelPreprocessingConfig() ) Assert.custom(isinstance, dataset, MemmapDataset) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index d6a935c61..2e47fd6aa 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -5,7 +5,6 @@ from fast_llm.data.dataset.config import SamplingParameters, ShufflingType from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert @@ -106,7 +105,6 @@ def test_gpt_sample(seed, shuffle): sequence_length=5, seed=seed, shuffle=shuffle, - preprocessing=LanguageModelPreprocessingConfig(vocab_size=0), ) ) samples = validate_indexed_dataset_sampling(sampled) @@ -172,7 +170,6 @@ def test_gpt_sample_padding(): seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, - preprocessing=LanguageModelPreprocessingConfig(vocab_size=vocab_size), ) if total_tokens == 0: with pytest.raises(RuntimeError): diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py index 217ecd0e1..001eb36da 100644 --- a/tests/models/distributed_test_checkpoint.py +++ b/tests/models/distributed_test_checkpoint.py @@ -41,7 +41,7 @@ def _test_load_and_save_parallel( mode=StageMode.inference, ) for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): - logger.info(f"Loading {save_format.name} checkpoint to {config.save_path / save_format.name}") + logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") model.save_checkpoint(CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format)) del model gc.collect() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 9acf8a9d7..bb53de29e 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -428,8 +428,10 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: return None +# We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in cas of failure. +# This should still run after `test_save_and_load_in_parallel` @requires_cuda -@pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) +@pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config: DistributedSaveLoadConfig, diff --git a/tests/test_varlen.py b/tests/test_varlen.py index 126a3e1e5..730bab2c9 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -8,6 +8,7 @@ from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm import gdn as gdn_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.utils import Assert @pytest.fixture @@ -207,13 +208,7 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): if param.requires_grad: - torch.testing.assert_close( - _param_grad(param), - _param_grad(param_ref), - atol=1e-3, - rtol=1e-3, - msg=f"Grad mismatch for parameter {name}", - ) + Assert.rms_close_relative(_param_grad(param), _param_grad(param_ref), 1e-3, 1e-3, msg=name) if __name__ == "__main__": diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 7348a79ef..47f254893 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -236,7 +236,7 @@ def _get_test_dataset( preprocessing = LanguageModelPreprocessingConfig( tokenizer={"type": "tokenizer", "path": tokenizer_path, "max_vocab_size": max_vocab_size}, image_patches=NullPreprocessingConfig() if image_patch_config is None else image_patch_config, - vocab_size=max_vocab_size or 0, + vocab_size=max_vocab_size, use_loss_masking_spans=max_loss_masking_spans > 0, use_preference_spans=has_preference_spans, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e43503137..b0a9acf36 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -584,8 +584,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 }, compare_factor=2, - # modes not supported with reference models - skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), + # Modes not supported with reference models + skip_tests=("sdp", "ms", "pp"), ) _update_and_add_testing_config( @@ -611,11 +611,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, df4, df4_sf, tp2, stp2, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, compare_factor=8, - # modes not supported with reference models - skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2", "stp2_ce4"), + # Modes not supported with reference models and/or activation distillation. + # TODO: Fix gradient accumulation and fp16, add TP support. + skip_tests=("sdp", "ms", "pp", "tp", "df", "bf", "fp16"), ) _update_and_add_testing_config( @@ -674,8 +675,8 @@ def _update_and_add_testing_config( checkpoint_format=AprielHybridSSMCheckpointFormat, # TODO: Add back generate as `normal` when stable. groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.basic: ModelTestingGroupAction.unimportant, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, # TODO: Fix and bring back to `testing_groups` ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, ModelTestingGroup.generate: ModelTestingGroupAction.broken, @@ -684,7 +685,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=(r"sdp", r"ms"), + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( @@ -725,10 +726,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=( - r"sdp", - r"ms", - ), # "pp","dp", "ce","16", "bf", "df", "stp"), + skip_tests=("sdp", "ms"), ) @@ -846,15 +844,16 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + # TODO: Fix (`fast_llm/models/gpt/conversion/apriel.py:235: KeyError: 'value_head_dim'`) + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla + compare_factor=10.0, # High diff for fp16 and bf16 due to rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! - skip_tests=(r"sdp", r"ms", r"^tp2$"), + skip_tests=("sdp", "ms", r"(? Date: Mon, 8 Dec 2025 22:55:36 -0500 Subject: [PATCH 07/14] stuff --- fast_llm/engine/base_model/config.py | 21 +- fast_llm/engine/multi_stage/fsdp.py | 1 + fast_llm/layers/attention/attention.py | 48 +--- fast_llm/layers/attention/config.py | 15 +- fast_llm/layers/attention/preprocessing.py | 58 +++++ fast_llm/layers/common/linear/convolution.py | 6 +- fast_llm/layers/ssm/config.py | 19 +- fast_llm/layers/ssm/gdn.py | 69 +----- fast_llm/layers/ssm/kda.py | 61 ++--- fast_llm/layers/ssm/mamba2.py | 36 ++- fast_llm/utils.py | 4 +- setup.cfg | 2 +- tests/data/test_sampling.py | 11 +- tests/functional/test_cross_entropy.py | 10 +- tests/functional/test_functional.py | 1 + tests/{ => layers}/test_attention.py | 0 .../{test_gdn_equivalence.py => test_gdn.py} | 6 +- .../{test_kda_equivalence.py => test_kda.py} | 14 +- tests/layers/test_lm_head.py | 2 - tests/layers/test_varlen.py | 97 ++++++++ tests/test_varlen.py | 234 ------------------ tests/utils/distributed_configs.py | 5 +- tests/utils/model_configs.py | 14 +- tests/utils/utils.py | 7 + 24 files changed, 299 insertions(+), 442 deletions(-) create mode 100644 fast_llm/layers/attention/preprocessing.py rename tests/{ => layers}/test_attention.py (100%) rename tests/layers/{test_gdn_equivalence.py => test_gdn.py} (95%) rename tests/layers/{test_kda_equivalence.py => test_kda.py} (92%) create mode 100644 tests/layers/test_varlen.py delete mode 100644 tests/test_varlen.py diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index f1eef47b9..0526b9dc2 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -8,6 +8,8 @@ from fast_llm.utils import Assert, compare_nested, log if typing.TYPE_CHECKING: + import torch + from fast_llm.engine.base_model.base_model import BaseModel @@ -58,6 +60,17 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: return self._serialize_value(value) +def set_model_names(model: "torch.nn.Module"): + from fast_llm.tensor import ParameterMeta + + for key, value in model.named_modules(): + value.module_name = key + for key, value in model.named_parameters(): + Assert.custom(isinstance, value, ParameterMeta) + # Rename to the parameter full name + value.tensor_name = key + + @config_class() class BaseModelConfig(ModuleConfig): """ @@ -65,17 +78,11 @@ class BaseModelConfig(ModuleConfig): """ def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel": - from fast_llm.tensor import ParameterMeta model = self.base_model_class(self, distributed_config) # Storing the global name of each module and tensor. # Done here because it needs to run right after `model.__init__()` - for key, value in model.named_modules(): - value.module_name = key - for key, value in model.named_parameters(): - Assert.custom(isinstance, value, ParameterMeta) - # Rename to the parameter full name - value.tensor_name = key + set_model_names(model) return model @property diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index a5a41f542..fbe6d3297 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -47,6 +47,7 @@ def __init__( ): self._name = name self._parameter_metas = {parameter_meta.tensor_name: parameter_meta for parameter_meta in parameter_metas} + Assert.eq(len(self._parameter_metas), len(parameter_metas)) # `set_model_names` ensure unique names. self._distributed_config = distributed_config self._fsdp_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.data) self._is_tied_weight_copy = is_tied_weight_copy diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 059469d94..3724ee413 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -11,11 +11,12 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert, div +from fast_llm.utils import div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -360,6 +361,7 @@ def _forward( key_value = key_value.transpose(0, 1).contiguous() key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) + print("AAAAA", input_.shape, query.shape, key.shape) query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) @@ -505,40 +507,10 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value def _preprocess_for_flash_attention(self, kwargs: dict[str, typing.Any]) -> None: - """ - Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 - cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. - Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. - If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally - also contain previous tokens from the first document in micro-sequence. - We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. - """ - if self._config.cross_document_attention: - return - device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device - - # TODO: ====== Fix (need to know how much first sequence was cropped) ====== - Assert.eq( - kwargs[AttentionKwargs.sequence_k_dim].global_size, kwargs[AttentionKwargs.sequence_q_dim].global_size - ) - - # TODO: Calculate these in batch preprocessing? - sequence_lengths_q = torch.tensor( - [ - 0, - *( - sequence_length - for sequence_lengths in kwargs[AttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths - ), - ], - dtype=torch.int32, - ) - max_sequence_length = sequence_lengths_q.max().item() - cu_seqlens_q = sequence_lengths_q.cumsum_(0).to(device) - max_seqlen_q = cu_seqlens_q.new_full((1,), max_sequence_length) - kwargs[AttentionKwargs.cu_seqlens_q] = cu_seqlens_q - kwargs[AttentionKwargs.cu_seqlens_k] = cu_seqlens_q - kwargs[AttentionKwargs.max_seqlen_q] = max_seqlen_q - kwargs[AttentionKwargs.max_seqlen_k] = max_seqlen_q + if not self._config.cross_document_attention: + preprocess_for_varlen( + kwargs, + kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device, + return_cu_seqlens=True, + return_max_seqlen=True, + ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 6f589eeb4..626a8fde6 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -17,15 +17,20 @@ logger = logging.getLogger(__name__) -class AttentionKwargs(BlockKwargs): - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" +class MixerKwargs(BlockKwargs): cu_seqlens_q = "cu_seqlens_q" cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" max_seqlen_k = "max_seqlen_k" + seq_idx = "seq_idx" + position_ids = "position_ids" + + +class AttentionKwargs(MixerKwargs): + rotary_freq_q = "rotary_freq_q" + rotary_freq_k = "rotary_freq_k" + attention_mask = "attention_mask" + attention_mask_value = "attention_mask_value" # TODO: Review these presents = "presents" past_key_values = "past_key_values" diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py new file mode 100644 index 000000000..a9d9936c5 --- /dev/null +++ b/fast_llm/layers/attention/preprocessing.py @@ -0,0 +1,58 @@ +import typing + +import torch + +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.utils import Assert + + +def preprocess_for_varlen( + kwargs: dict[str, typing.Any], + device: torch.device, + return_cu_seqlens: bool = False, + return_max_seqlen: bool = False, + return_seq_idx: bool = False, + return_position_ids: bool = False, +) -> None: + """ + Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 + cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. + Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. + If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally + also contain previous tokens from the first document in micro-sequence. + We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. + """ + + # TODO: ====== Fix (need to know how much first sequence was cropped) ====== + Assert.eq(kwargs[MixerKwargs.sequence_k_dim].global_size, kwargs[MixerKwargs.sequence_q_dim].global_size) + + sequence_lengths = [ + sequence_length + for sequence_lengths in kwargs[MixerKwargs.sequence_lengths] + for sequence_length in sequence_lengths + ] + if return_cu_seqlens: + cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32, device=device).cumsum( + 0, dtype=torch.int32 + ) + kwargs[MixerKwargs.cu_seqlens_q] = cu_seqlens_q + kwargs[MixerKwargs.cu_seqlens_k] = cu_seqlens_q + if return_max_seqlen: + max_seqlen_q = torch.full((1,), max(sequence_lengths), dtype=torch.int32, device=device) + kwargs[MixerKwargs.max_seqlen_q] = max_seqlen_q + kwargs[MixerKwargs.max_seqlen_k] = max_seqlen_q + if return_seq_idx: + kwargs[MixerKwargs.seq_idx] = torch.cat( + [ + torch.full((sequence_length,), i, dtype=torch.int32, device=device) + for i, sequence_length in enumerate(sequence_lengths) + ] + ) + if return_position_ids: + kwargs[MixerKwargs.position_ids] = torch.cat( + [ + torch.arange(sequence_length, dtype=torch.int32, device=device) + for i, sequence_length in enumerate(sequence_lengths) + ] + ) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 69018fd06..c336a7e99 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -34,7 +34,11 @@ def __init__( else self._forward_torch ) - def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: + if kwargs: + raise NotImplementedError( + f"Arguments {tuple(kwargs)} not implemented for torch implementation of 1d convolution." + ) return self._activation.activation_fn( torch.nn.functional.conv1d( input_, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 450591216..f0e3a1529 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig @@ -20,11 +19,6 @@ from fast_llm.tensor import ParameterMeta -class LinearAttentionKwargs(BlockKwargs): - cu_seqlens = "cu_seqlens" - seq_idx = "seq_idx" - - @config_class(dynamic_type={MixerConfig: "gdn"}) class GatedDeltaNetConfig(MixerConfig): """ @@ -179,13 +173,6 @@ def layer_class(self) -> "type[KimiDeltaAttention]": return KimiDeltaAttention - def _validate(self) -> None: - with self._set_implicit_default(): - if "activation" not in self.normalization._explicit_fields: - self.normalization.activation = "sigmoid" - - super()._validate() - @config_class() class SSMConfig(MixerConfig): @@ -334,6 +321,12 @@ class Mamba2Config(MambaBaseConfig): desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", hint=FieldHint.architecture, ) + cross_document_attention: bool = Field( + default=True, + desc="Allow for cross-document attention.", + doc="Disable to prevent attention between tokens belonging to different documents.", + hint=FieldHint.feature, + ) @property def layer_class(self) -> "type[Mamba2]": diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 40f15837c..e59f3fc03 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -10,10 +10,12 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import GatedDeltaNetConfig, LinearAttentionKwargs +from fast_llm.layers.ssm.config import GatedDeltaNetConfig from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div @@ -325,9 +327,6 @@ def _forward( # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ - cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) - seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) - projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) if sequence_first: @@ -347,9 +346,8 @@ def _forward( mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = rearrange(mixed_qkv, "b s ... -> (b s) ...").unsqueeze(0) # 1 s d mixed_qkv = rearrange(mixed_qkv, "b t d -> b d t") # mixed_qkv.transpose(1, 2) - mixed_qkv = self.convolution( - mixed_qkv, seq_idx=seq_idx - ) # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 + # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 + mixed_qkv = self.convolution(mixed_qkv, seq_idx=kwargs[MixerKwargs.seq_idx].unsqueeze(0)) mixed_qkv = rearrange(mixed_qkv, "b d t -> b t d") # mixed_qkv.transpose(1, 2) query, key, value = torch.split( mixed_qkv, @@ -383,7 +381,7 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, + cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q], ) z_shape_og = z.shape @@ -400,56 +398,13 @@ def _forward( return output - def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: - """ - Creates seqlens and cu_seqlens for packed forward. - This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. - Note: padding tokens are always on the right and get their own entry in LinearAttentionKwargs.sequence_lengths --> they are treated as seperate sequence. - - Sets: - - seq_idx to [1, BS x T] tensor, where each elemnt is the sequence index of the corresponding token - - cu_seqlens to [N+1] tensor, where N is the total number of sequences in the batch, each element is the cumulative sequence length of packed sequences sofar - """ - - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - device = kwargs.get("device", None) - if sequence_lengths is None: - raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - seqlens = torch.tensor( - [ - 0, - *( - sequence_length - for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths - ), - ], - dtype=torch.int32, - ) - cu_seqlens = seqlens.cumsum_(0).to(device) - # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 - # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 - kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens - # seq_idx has to be (bs, seqlen), but bs is forced to 1 - kwargs[LinearAttentionKwargs.seq_idx] = ( - ( - torch.cat( - [ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in (torch.diff(cu_seqlens).to(torch.int32)) - ], - dim=0, - ) - .eq(0) - .cumsum(0) - - 1 - ) - .to(torch.int32) - .unsqueeze(0) - ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - self._preprocess_for_varlen(kwargs) + preprocess_for_varlen( + kwargs, + kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, + return_cu_seqlens=True, + return_seq_idx=True, + ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 323e1ad13..270ac65bf 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -9,10 +9,12 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig from fast_llm.tensor import ParameterMeta, TensorMeta logger = logging.getLogger(__name__) @@ -229,8 +231,6 @@ def _forward( sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ - cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) - seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) # TODO: can be made more efficeint by rearranging hidden states directly and only once residual_dtype = hidden_states.dtype @@ -250,9 +250,10 @@ def _forward( # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) - q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) - k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) - v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) + seq_idx = kwargs[MixerKwargs.seq_idx].unsqueeze(0) + q = self._apply_conv(q, self.q_conv, seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) if sequence_first: @@ -281,7 +282,7 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, + cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q], ) attn_out = attn_out.to(residual_dtype) @@ -303,44 +304,10 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() - def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - device = kwargs.get("device", None) - if sequence_lengths is None: - raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - - seqlens = torch.tensor( - [ - 0, - *( - sequence_length - for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths # bs - ), - ], - dtype=torch.int32, - ) - cu_seqlens = seqlens.cumsum_(0).to(device) - # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 - # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 - kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens - # seq_idx has to be (bs, seqlen), but bs is forced to 1 - kwargs[LinearAttentionKwargs.seq_idx] = ( - ( - torch.cat( - [ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in (torch.diff(cu_seqlens).to(torch.int32)) - ], - dim=0, - ) - .eq(0) - .cumsum(0) - - 1 - ) - .to(torch.int32) - .unsqueeze(0) - ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - self._preprocess_for_varlen(kwargs) + preprocess_for_varlen( + kwargs, + kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, + return_cu_seqlens=True, + return_seq_idx=True, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 616d1152f..6e0ae0c60 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,4 @@ +import inspect import logging import typing @@ -8,6 +9,8 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -19,8 +22,17 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa _mamba_available = True + sig = inspect.signature(selective_scan_fn) + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + if "position_indices" in sig.parameters: + _mamba_varlen_available = True + else: + _mamba_varlen_available = False + except (ImportError, RuntimeError): _mamba_available = False + _mamba_varlen_available = False logger = logging.getLogger(__name__) @@ -181,15 +193,19 @@ def _forward( # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) + convolution_kwargs = ( + {} if self._config.cross_document_attention else {"seq_idx": kwargs[MixerKwargs.seq_idx].unsqueeze(0)} + ) if self._config.repeat_kv_before_conv: x = self.convolution( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) - .flatten(1, 2) + .flatten(1, 2), + **convolution_kwargs, ) else: x = ( - self.convolution(x) + self.convolution(x, **convolution_kwargs) .unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) @@ -214,6 +230,9 @@ def _forward( self._debug(c, "c", self._bc_dims, kwargs) self._debug(dt, "dt", self._xz_dims, kwargs) + scan_kwargs = ( + {} if self._config.cross_document_attention else {"position_indices": kwargs[MixerKwargs.position_ids]} + ) y = selective_scan_fn( x, dt, @@ -224,6 +243,7 @@ def _forward( z, delta_bias=None if self.dt_proj.bias is None else self.dt_proj.bias.float(), delta_softplus=True, + **scan_kwargs, ) self._debug(y, "y", self._xz_dims, kwargs) @@ -242,3 +262,15 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Implement. raise NotImplementedError() + + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + if not self._config.cross_document_attention: + assert ( + _mamba_varlen_available + ), f"Varlen mamba requires custom mamba installation from `https://github.com/jxiw/varlen_mamba`" + preprocess_for_varlen( + kwargs, + kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, + return_seq_idx=True, + return_position_ids=True, + ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 259073e32..2ca61aa0e 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -160,7 +160,9 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None): scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5 threshold = max(threshold * scale, min_threshold) rms = rms_diff(x, y).item() - assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + ( + assert ( + rms <= threshold + ), f"Rms diff too big ({rms:.2e} > {threshold:.2e}, scale = {scale:.2e}) between tensors {x} and {y}" + ( "" if msg is None else f"| {msg}" ) diff --git a/setup.cfg b/setup.cfg index 58f8ea2d1..7bf88b4dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.6.post3 + mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5 GENERATION = diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 2e47fd6aa..f28c9cce2 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -129,17 +129,10 @@ def test_build_padded_token_cumsum(): Assert.all_equal(token_cumsum, expected_cumsums) -def get_test_seeds(num_seeds): - np.random.seed(42) - seeds = np.random.randint(0, num_seeds * 100, num_seeds) - return seeds.tolist() - - @pytest.mark.skipif(not _extension_available, reason="CPP Extension not available") def test_gpt_sample_padding(): - for seed in get_test_seeds(100): + for _ in range(10): vocab_size = 30 - np.random.seed(seed) num_sequences = np.random.randint(1, 20) sequence_length = np.random.randint(1, 20) doc_sizes = np.random.randint(1, 2 * sequence_length, num_sequences) @@ -167,7 +160,7 @@ def test_gpt_sample_padding(): sampling = get_sampling_data( num_samples=len(expected_samples), sequence_length=sequence_length, - seed=seed, + seed=np.random.randint(100000), shuffle=ShufflingType.disabled, truncate_documents=False, ) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index de95ca214..088250885 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -50,9 +50,9 @@ def _assert_loss_and_grad(logits, loss, grad): assert torch.isfinite(grad).all() +@pytest.mark.slow @pytest.mark.parametrize("use_mask", [False, True]) def test_reverse_kl_no_tp(use_mask): - torch.manual_seed(0) batch_size, seq_len, vocab_size = 2, 3, 5 logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True) target = torch.randn(batch_size, seq_len, vocab_size) @@ -81,11 +81,10 @@ def test_reverse_kl_no_tp(use_mask): else: valid_tokens = logits.shape[0] * logits.shape[1] reference = per_sample.sum() / valid_tokens - torch.testing.assert_close(loss, reference, atol=1e-6, rtol=1e-6) + Assert.rms_close_relative(loss, reference, 1e-6, 1e-6) def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - torch.manual_seed(0) world_size = dist.get_world_size(group) batch_size, seq_len, vocab_per_rank = 2, 3, 5 @@ -124,11 +123,10 @@ def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): else: ref_loss = torch.zeros_like(loss) dist.broadcast(ref_loss, src=0, group=group) - torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) + Assert.rms_close_relative(loss, ref_loss, 1e-6, 1e-6) def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - torch.manual_seed(0) world_size = dist.get_world_size(group) batch_size, seq_len, vocab_per_rank = 2, 3, 5 @@ -169,7 +167,7 @@ def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): else: ref_loss = torch.zeros_like(loss) dist.broadcast(ref_loss, src=0, group=group) - torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) + Assert.rms_close_relative(loss, ref_loss, 1e-6, 1e-6) def combined_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 05fafe7a9..c48a0a531 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -136,6 +136,7 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. @pytest.mark.slow +@pytest.mark.skip("Dropless MoE is broken") @requires_cuda def test_dropless_mlp(): num_experts = 4 diff --git a/tests/test_attention.py b/tests/layers/test_attention.py similarity index 100% rename from tests/test_attention.py rename to tests/layers/test_attention.py diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn.py similarity index 95% rename from tests/layers/test_gdn_equivalence.py rename to tests/layers/test_gdn.py index 4af68ea16..9b4435d6c 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn.py @@ -4,6 +4,7 @@ from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.utils import Assert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -19,7 +20,6 @@ @pytest.mark.slow @requires_cuda def test_fast_llm_gdn_matches_qwen3_next_forward(): - torch.manual_seed(0) device = torch.device("cuda") dtype = torch.bfloat16 @@ -97,7 +97,7 @@ def test_fast_llm_gdn_matches_qwen3_next_forward(): } hf_state_dict = hf_layer.gdn.state_dict() for k, p in fast_layer.state_dict().items(): - torch.testing.assert_close(p, hf_state_dict[param_map[k]], atol=1e-5, rtol=1e-5) + Assert.rms_close_relative(p, hf_state_dict[param_map[k]], 1e-5, 1e-5) # need to monkey patch the hf implementation with our fix_query_key_value_ordering due to the layout differences hf_layer.gdn.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering @@ -118,4 +118,4 @@ def test_fast_llm_gdn_matches_qwen3_next_forward(): fast_layer.preprocess(fast_kwargs) fast_out, _ = fast_layer(hidden_states, fast_kwargs) - torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + Assert.rms_close_relative(fast_out, hf_out, 1e-5, 1e-5) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda.py similarity index 92% rename from tests/layers/test_kda_equivalence.py rename to tests/layers/test_kda.py index 8745236d4..477eefaa4 100644 --- a/tests/layers/test_kda_equivalence.py +++ b/tests/layers/test_kda.py @@ -5,6 +5,7 @@ from fast_llm.config import UpdateType from fast_llm.layers.block.config import BlockKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda try: @@ -26,7 +27,6 @@ @pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") def test_fast_llm_kda_matches_apriel_forward(): - torch.manual_seed(0) device = torch.device("cuda") dtype = torch.bfloat16 @@ -113,13 +113,13 @@ def test_fast_llm_kda_matches_apriel_forward(): "dt_bias": "dt_bias", "norm.weight": "o_norm.weight", } - for fast_name, hf_name in param_map.items(): - fast_param = fast_layer.state_dict()[fast_name] - hf_param = hf_layer.state_dict()[hf_name] + hf_params = hf_layer.state_dict() + for fast_name, fast_param in fast_layer.state_dict().items(): + hf_param = hf_params[param_map[fast_name]] if fast_param.shape != hf_param.shape: + Assert.eq(fast_param.numel(), hf_param.numel(), msg=fast_name) hf_param = hf_param.reshape_as(fast_param) - print(f"Comparing parameter {fast_name} with shape {fast_param.shape}") - torch.testing.assert_close(fast_param, hf_param, atol=1e-5, rtol=1e-5) + Assert.rms_close_relative(fast_param, hf_param, 1e-5, 1e-5, msg=fast_name) hidden_states = torch.randn(2, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) hf_layer.training = True @@ -135,4 +135,4 @@ def test_fast_llm_kda_matches_apriel_forward(): fast_layer.preprocess(fast_kwargs) fast_out, _ = fast_layer(hidden_states, fast_kwargs) - torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + Assert.rms_close_relative(fast_out, hf_out, 1e-5, 1e-5) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 6383e6aae..108261d4b 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -163,8 +163,6 @@ def test_lm_head( loss_masking: bool, prediction_heads: int, ): - torch.cuda.manual_seed(0) - torch.manual_seed(0) head_config = { "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py new file mode 100644 index 000000000..b231aedae --- /dev/null +++ b/tests/layers/test_varlen.py @@ -0,0 +1,97 @@ +import pytest +import torch + +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.decoder.config import MixerConfig +from fast_llm.layers.ssm import gdn as gdn_module +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config +from fast_llm.utils import Assert +from tests.utils.utils import get_stage, requires_cuda + + +# TODO: include mamba varlen +@pytest.mark.slow +@requires_cuda +@pytest.mark.parametrize( + "config", + [ + AttentionConfig(heads=4, head_groups=2, head_size=16, cross_document_attention=False), + Mamba2Config(d_inner=128, d_xb=64, state_size=16, dt_rank=8, cross_document_attention=False), + pytest.param( + GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), + marks=pytest.mark.skipif( + gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" + ), + ), + pytest.param( + KimiDeltaAttentionConfig(heads=4, head_dim=16), + marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), + ), + ], +) +def test_mixer_varlen_stacking_equivalence(config: MixerConfig): + """ + Check that Gated Delta Net forward/backward match with and without packing. + """ + hidden_size = 32 + hidden_dim = TensorDim("hidden", hidden_size) + distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.float16)) + mixer = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + stage = get_stage([mixer], distributed) + + batch_size = 2 # cu_seqlens path requires flattened batch + seq_len = 15 + + sequence_lengths = [[6, 9], [4, 1, 10]] + hidden_states = torch.randn( + batch_size, + seq_len, + hidden_size, + device=distributed.device, + dtype=distributed_config.compute_dtype.torch, + requires_grad=True, + ) + + kwargs = { + BlockKwargs.device: distributed.device, + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + BlockKwargs.sequence_q_dim: TensorDim("", seq_len), + BlockKwargs.sequence_k_dim: TensorDim("", seq_len), + } + + kwargs_packed = {**kwargs, BlockKwargs.sequence_lengths: sequence_lengths} + mixer.preprocess(kwargs_packed) + + out_packed, context = stage.forward(hidden_states, kwargs_packed) + stage.backward(torch.ones_like(out_packed), context) + + names, parameters = zip(*list(mixer.named_parameters())) + grads_packed = [parameter.grad_buffer.clone() for parameter in parameters] + + stage.reset_gradients() + # Run reference path separately per sequence without varlen packing, then concatenate. + out_refs = [] + for i in range(batch_size): + for seq in torch.split(hidden_states[i], sequence_lengths[i], dim=0): + kwargs_seq = {**kwargs, BlockKwargs.sequence_lengths: [[len(seq)]]} + mixer.preprocess(kwargs_seq) + out, context = stage.forward(seq.unsqueeze(0), kwargs_seq) + stage.backward(torch.ones_like(out), context) + out_refs.append(out) + out_ref = torch.cat(out_refs, dim=1).view_as(out_packed) + + Assert.rms_close_relative(out_packed, out_ref, 1e-3, 1e-4) + + for name, parameter, grad_packed in zip(names, parameters, grads_packed, strict=True): + Assert.rms_close_relative(grad_packed, parameter.grad_buffer, 1e-3, 1e-4, msg=name) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_varlen.py b/tests/test_varlen.py deleted file mode 100644 index 256da95d4..000000000 --- a/tests/test_varlen.py +++ /dev/null @@ -1,234 +0,0 @@ -import pytest -import torch - -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.decoder.config import MixerConfig -from fast_llm.layers.ssm import gdn as gdn_module -from fast_llm.layers.ssm import kda as kda_module -from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig -from fast_llm.utils import Assert - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def unpack_and_padd(packed_hidden_states, cu_seqlens, package_num): - batch_size = packed_hidden_states.shape[0] - seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - hidden_dim = packed_hidden_states.shape[2] - hidden_states = torch.zeros( - package_num * batch_size, - seq_len, - hidden_dim, - dtype=packed_hidden_states.dtype, - device=packed_hidden_states.device, - ) - for j in range(batch_size): - for i in range(package_num): - line = j * package_num + i - hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ - j, cu_seqlens[i] : cu_seqlens[i + 1], : - ] - return hidden_states - - -def pack(hidden_states, cu_seqlens, batch_size): - package_num, seq_len, hidden_dim = hidden_states.shape - seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] - seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) - indices_3d = ( - torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) - ) - mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) - packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) - return packed_hidden_states - - -def generate_random_seq_len(seq_len, packages_num=2): - if packages_num < 1: - raise ValueError("packages_num must be at least 1") - - # base size of each chunk, and how many get an extra token - base, rem = divmod(seq_len, packages_num) - # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] - lengths = [base + 1 if i < rem else base for i in range(packages_num)] - assert sum(lengths) == seq_len - assert len(lengths) == packages_num - return lengths - - -def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: - """ - Materialize meta parameters on the requested device for KDA mixer layers. - """ - for name, param in module.named_parameters(): - if param.device.type != "meta": - continue - param_data = torch.empty_like(param, device=device) - param.init_parameter(param_data, distributed) - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - target = module - if module_path is not None: - for part in module_path.split("."): - target = getattr(target, part) - new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - new_param.grad = None - new_param.grad_buffer = torch.zeros_like(param_data) - new_param.param_grad_is_zero = True - target._parameters[param_name] = new_param - - -def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: - return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad - - -# TODO: include mamba varlen -@pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") -@pytest.mark.parametrize( - "config, sequence_first", - [ - pytest.param( - GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), - False, - marks=pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" - ), - ), - pytest.param( - GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), - True, - marks=pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" - ), - ), - pytest.param( - KimiDeltaAttentionConfig(heads=4, head_dim=16), - False, - marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), - ), - pytest.param( - KimiDeltaAttentionConfig(heads=4, head_dim=16), - True, - marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), - ), - ], -) -def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): - """ - Check that Gated Delta Net forward/backward match with and without packing. - """ - device = torch.device("cuda") - dtype = torch.float16 - hidden_size = 32 - hidden_dim = TensorDim("hidden", hidden_size) - mixer_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - mixer_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - mixer_packed.setup(distributed) - mixer_ref.setup(distributed) - _materialize_mixer_tensors(mixer_packed, distributed, device) - _materialize_mixer_tensors(mixer_ref, distributed, device) - mixer_ref.load_state_dict(mixer_packed.state_dict()) - mixer_packed.to(device=device, dtype=dtype) - mixer_ref.to(device=device, dtype=dtype) - - batch_size = 2 # cu_seqlens path requires flattened batch - seq_len = 15 - packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) - sequence_lengths = [ - generate_random_seq_len(seq_len, packages_num=packages_num[i].item()) for i in range(batch_size) - ] - - packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) - if sequence_first: - packed = packed.transpose(0, 1) - - kwargs_packed = { - BlockKwargs.device: device, - BlockKwargs.sequence_lengths: sequence_lengths, - BlockKwargs.sequence_first: sequence_first, - BlockKwargs.hidden_dims: (hidden_dim,), - } - mixer_packed.preprocess(kwargs_packed) - - kwargs_ref = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (hidden_dim,), - } - - out_packed = mixer_packed(packed, kwargs_packed) - if sequence_first: - out_packed = out_packed.transpose(0, 1) - # Run reference path separately per sequence without varlen packing, then concatenate. - ref_outs = [] - for b in range(batch_size): - out_batch = [] - length = sequence_lengths[b] - if sequence_first: - ref_seqs = torch.split(packed[:, b].unsqueeze(0), length, dim=1) - else: - ref_seqs = torch.split(packed[b].unsqueeze(0), length, dim=1) - for seq in ref_seqs: - kwargs_ref_seq = { - **kwargs_ref, - BlockKwargs.sequence_lengths: [seq.shape[1]], - } - out_batch.append(mixer_ref(seq, kwargs_ref_seq)) - ref_outs.append(torch.cat(out_batch, dim=1)) - out_ref = torch.cat(ref_outs, dim=0) - out_ref_packed = out_ref - - assert out_ref_packed.shape == out_packed.shape - assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) - - out_packed.sum().backward() - out_ref_packed.sum().backward() - - for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): - if param.requires_grad: - Assert.rms_close_relative(_param_grad(param), _param_grad(param_ref), 1e-3, 1e-3, msg=name) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index fac595905..dd414d901 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -79,7 +79,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon SIMPLE_TESTING_CONFIG = DistributedTestingConfig( name="simple", compare=None, - config_args=["training.num_workers=2"], + config_args=[], num_gpus=1, ) @@ -87,7 +87,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="bf16", compare="simple", - config_args=["model.distributed.compute_dtype=bf16"], + # Also tests parallel data loader. + config_args=["model.distributed.compute_dtype=bf16", "training.num_workers=2"], num_gpus=1, compare_config=_bf16_compare, ), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 736e4897e..7f66bec7d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -288,8 +288,8 @@ def _update_and_add_testing_config( ], checkpoint_format=None, groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.main, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: PP checkpoint failing for tied weights. ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, @@ -636,12 +636,12 @@ def _update_and_add_testing_config( checkpoint_format=MixtralCheckpointFormat, # TODO: New base image broke mixtral groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.basic: ModelTestingGroupAction.broken, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, - ModelTestingGroup.megatron: ModelTestingGroupAction.normal, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + ModelTestingGroup.megatron: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, compare_factor=2.0, ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 098f0240e..5ff238756 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -12,6 +12,7 @@ from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import set_model_names from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig @@ -43,6 +44,12 @@ def get_stage( tied_parameter_duplicates: typing.Iterable[str] = (), tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, ): + + for layer in layers: + if not layer._is_setup: + layer.setup(distributed) + # Normally called in `BaseModelConfig.get_base_model`, but may be missing here. + set_model_names(torch.nn.ModuleList(layers)) # Create a fast-llm stage which allocates and initializes meta tensors correctly. stage = Stage( config=StageConfig(), From 916af7a163f856d8831de2133d9e4b9c53a6bd20 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 9 Dec 2025 18:40:50 -0500 Subject: [PATCH 08/14] cleanup --- fast_llm/layers/attention/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 3724ee413..073599479 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -361,7 +361,6 @@ def _forward( key_value = key_value.transpose(0, 1).contiguous() key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) - print("AAAAA", input_.shape, query.shape, key.shape) query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) From bd7a8e6797cc5e4f8ba204bb109fcd8b364e399e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 9 Dec 2025 19:03:09 -0500 Subject: [PATCH 09/14] cleanup --- tests/layers/test_gdn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/layers/test_gdn.py b/tests/layers/test_gdn.py index 60dd538e4..2845fc404 100644 --- a/tests/layers/test_gdn.py +++ b/tests/layers/test_gdn.py @@ -84,7 +84,7 @@ def test_fast_llm_gdn_matches_apriel2_forward(): hidden_states = torch.randn(1, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) hf_state_dict = hf_layer.state_dict() for k, p in fast_layer.state_dict().items(): - Assert.rms_close_relative(p, hf_state_dict[param_map[k]], 1e-5, 1e-5) + Assert.rms_close_relative(p, hf_state_dict[k], 1e-5, 1e-5) hf_out = hf_layer(hidden_states)[0] From 660fecc3093cf21e8a71b624ef9b6cb8e56b59f9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 9 Dec 2025 19:15:49 -0500 Subject: [PATCH 10/14] fix --- tests/functional/test_functional.py | 1 + tests/utils/model_configs.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 05fafe7a9..c48a0a531 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -136,6 +136,7 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. @pytest.mark.slow +@pytest.mark.skip("Dropless MoE is broken") @requires_cuda def test_dropless_mlp(): num_experts = 4 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 736e4897e..7f66bec7d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -288,8 +288,8 @@ def _update_and_add_testing_config( ], checkpoint_format=None, groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.main, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.main, + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, # TODO: PP checkpoint failing for tied weights. ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, @@ -636,12 +636,12 @@ def _update_and_add_testing_config( checkpoint_format=MixtralCheckpointFormat, # TODO: New base image broke mixtral groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.basic: ModelTestingGroupAction.broken, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.broken, + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, - ModelTestingGroup.megatron: ModelTestingGroupAction.normal, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + ModelTestingGroup.megatron: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.broken, }, compare_factor=2.0, ) From db93bb56dd90ed02253ace0498030e4d54687561 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 9 Dec 2025 20:31:29 -0500 Subject: [PATCH 11/14] fixes --- fast_llm/data/sample/token.py | 2 +- tests/layers/test_gdn_equivalence.py | 4 ---- tests/utils/model_configs.py | 17 ++++++++++------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index b456baaa2..1bc9ef1a1 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -127,7 +127,7 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: begin_ = self._size_cumsums[index].item() # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. # Convert begin and end to int to avoid numpy dtype overflow when adding to begin_ - return TokenSample(self._tokens[begin_ + int(begin) : begin_ + int(end)].to(torch.int64), [end - begin]) + return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin]) def get_document_sizes(self) -> torch.Tensor: return self._size_cumsums[1:] - self._size_cumsums[:-1] diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py index dfbaab9c9..803d2eaac 100644 --- a/tests/layers/test_gdn_equivalence.py +++ b/tests/layers/test_gdn_equivalence.py @@ -100,7 +100,3 @@ def test_fast_llm_gdn_matches_apriel2_forward(): fast_out, _ = fast_layer(hidden_states, fast_kwargs) torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 7f66bec7d..9231168aa 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -40,6 +40,9 @@ _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) +TP_NO_STP = r"(?:^|(?<=[^s]))tp" + + class ModelTestingGroup(enum.StrEnum): basic = "basic" checkpoint = "checkpoint" @@ -863,7 +866,7 @@ def _update_and_add_testing_config( compare_factor=10.0, # High diff for fp16 and bf16 due to rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! - skip_tests=("sdp", "ms", r"(? Date: Wed, 10 Dec 2025 19:03:53 -0500 Subject: [PATCH 12/14] misc --- tests/functional/test_cross_entropy.py | 340 +++++++++++-------------- tests/layers/test_attention.py | 47 +--- tests/layers/test_gdn.py | 102 -------- tests/layers/test_kda.py | 138 ---------- tests/layers/test_lm_head.py | 2 + tests/layers/test_ssm.py | 132 ++++++++++ tests/utils/model_configs.py | 7 +- tests/utils/utils.py | 6 +- 8 files changed, 290 insertions(+), 484 deletions(-) delete mode 100644 tests/layers/test_gdn.py delete mode 100644 tests/layers/test_kda.py create mode 100644 tests/layers/test_ssm.py diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 088250885..3c6facafe 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -1,10 +1,10 @@ import os import tempfile +import traceback +import typing import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward @@ -12,174 +12,39 @@ from tests.utils.utils import requires_cuda -def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): - fn = combined_worker - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) - try: - fn(rank, dist.group.WORLD, *fn_args) - finally: - dist.destroy_process_group() - - -def _spawn_dist(world_size: int, fn, *fn_args): - """ - Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. - """ - with tempfile.NamedTemporaryFile(delete=False) as tmp: - init_method = f"file://{tmp.name}" - - try: - mp.spawn( - _mp_worker, - args=(world_size, init_method, fn_args), - nprocs=world_size, - join=True, - start_method="spawn", - ) - finally: - if os.path.exists(tmp.name): - os.remove(tmp.name) - - -def _assert_loss_and_grad(logits, loss, grad): - assert isinstance(loss, torch.Tensor) - assert loss.dim() == 0 - assert grad is None or grad.shape == logits.shape - assert torch.isfinite(loss) - if grad is not None: - assert torch.isfinite(grad).all() - - -@pytest.mark.slow -@pytest.mark.parametrize("use_mask", [False, True]) -def test_reverse_kl_no_tp(use_mask): - batch_size, seq_len, vocab_size = 2, 3, 5 - logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True) - target = torch.randn(batch_size, seq_len, vocab_size) - loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None - - loss, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1.0, - group=None, - target_format=TargetFormat.logits, - sequence_parallel_logits=False, - ) - _assert_loss_and_grad(logits, loss, grad) - - # Manual reference: sum over vocab then average over valid tokens. - teacher_log_probs = torch.log_softmax(target, dim=-1) - student_log_probs = torch.log_softmax(logits, dim=-1) - per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - valid_tokens = loss_mask.sum() +def _get_cross_entropy_inputs( + num_columns: int, loss_masking: bool, target_format: TargetFormat +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn(256, num_columns, device="cuda") / 3 + loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device="cuda") if loss_masking else None + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + logits = torch.nn.functional.one_hot(target, num_columns) + logits_var + if loss_masking: + logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) + loss_mask = None else: - valid_tokens = logits.shape[0] * logits.shape[1] - reference = per_sample.sum() / valid_tokens - Assert.rms_close_relative(loss, reference, 1e-6, 1e-6) - - -def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - world_size = dist.get_world_size(group) - - batch_size, seq_len, vocab_per_rank = 2, 3, 5 - full_vocab = vocab_per_rank * world_size - full_logits = torch.randn(batch_size, seq_len, full_vocab) - full_target = torch.randn(batch_size, seq_len, full_vocab) - full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None + target = torch.randn(256, num_columns, device="cuda") + logits = target + logits_var + if target_format == TargetFormat.probabilities: + target = torch.softmax(target, -1) + return logits, target, loss_mask - start = rank * vocab_per_rank - end = start + vocab_per_rank - logits = full_logits[:, :, start:end].clone().requires_grad_(True) - target = full_target[:, :, start:end].clone() - loss_mask = full_mask.clone() if full_mask is not None else None - loss, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=None, - group=group, - target_format=TargetFormat.logits, - sequence_parallel_logits=False, - ) - _assert_loss_and_grad(logits, loss, grad) - - if rank == 0: - ref_loss, _ = reverse_kl_forward_backward( - logits=full_logits.clone(), - target=full_target.clone(), - loss_mask=full_mask.clone() if full_mask is not None else None, - grad_output=None, - group=None, - target_format=TargetFormat.logits, - sequence_parallel_logits=False, - ) - else: - ref_loss = torch.zeros_like(loss) - dist.broadcast(ref_loss, src=0, group=group) +def _compare_cross_entropy_outputs( + loss: torch.Tensor, + ref_loss: torch.Tensor, + has_grad: bool, + grad: torch.Tensor | None, + ref_grad: torch.Tensor | None, +): Assert.rms_close_relative(loss, ref_loss, 1e-6, 1e-6) - - -def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - world_size = dist.get_world_size(group) - - batch_size, seq_len, vocab_per_rank = 2, 3, 5 - full_vocab = vocab_per_rank * world_size - full_logits = torch.randn(batch_size, seq_len, full_vocab) - full_target = torch.randn(batch_size, seq_len, full_vocab) - full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None - - start = rank * vocab_per_rank - end = start + vocab_per_rank - logits = full_logits[:, :, start:end].clone().requires_grad_(True) - target = full_target[:, :, start:end].clone() - loss_mask = full_mask.clone() if full_mask is not None else None - - loss, grad = cross_entropy_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=None, - group=group, - implementation=CrossEntropyImpl.fused, - target_format=TargetFormat.logits, - logits_scale_factor=1.0, - ) - _assert_loss_and_grad(logits, loss, grad) - - if rank == 0: - ref_loss, _ = cross_entropy_forward_backward( - logits=full_logits.clone(), - target=full_target.clone(), - loss_mask=full_mask.clone() if full_mask is not None else None, - grad_output=None, - group=None, - implementation=CrossEntropyImpl.fused, - target_format=TargetFormat.logits, - logits_scale_factor=1.0, - ) + if has_grad: + Assert.rms_close_relative(grad, ref_grad, 1e-6, 1e-6) else: - ref_loss = torch.zeros_like(loss) - dist.broadcast(ref_loss, src=0, group=group) - Assert.rms_close_relative(loss, ref_loss, 1e-6, 1e-6) - - -def combined_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - _vocab_tp_worker(rank, group, use_mask) - _ce_vocab_tp_worker(rank, group, use_mask) - - -# TODO: maybe merge these tests using same parametrization -@pytest.mark.slow -@pytest.mark.parametrize("use_mask", [True, False]) -def test_distillation_losses(use_mask): - _spawn_dist(2, combined_worker, use_mask) + assert grad is None + assert ref_grad is None @requires_cuda @@ -201,21 +66,7 @@ def test_distillation_losses(use_mask): def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): # TODO: Test tensor-parallel implementation. assert TritonConfig.TRITON_ENABLED - # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device="cuda") if loss_masking else None - if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") - logits = (torch.nn.functional.one_hot(target, num_columns) + logits_var).requires_grad_() - if loss_masking: - logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) - loss_mask = None - else: - target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") - logits = (target + logits_var).requires_grad_() - if target_format == TargetFormat.probabilities: - target = torch.softmax(target, -1) - + logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) kwargs = { "logits": logits, "target": target, @@ -226,26 +77,129 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski } # Torch serves as the reference implementation. out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) - out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) + Assert.rms_close(out_fused, out_torch, 5e-3) - if grad_output is None: - assert grad_torch is None - assert grad_fused is None - else: - Assert.rms_close(grad_fused, grad_torch, 5e-3) + _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch) if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) else: out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - if grad_output is None: - assert grad_triton is None - else: - Assert.rms_close(grad_triton, grad_torch, 5e-3) - Assert.rms_close(out_triton, out_torch, 5e-3) + _compare_cross_entropy_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch) + + +def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tensor, loss_mask: torch.Tensor): + # Manual reference: sum over vocab then average over valid tokens. + logits = logits.detach().requires_grad_() + teacher_log_probs = torch.log_softmax(target, dim=-1) + student_log_probs = torch.log_softmax(logits, dim=-1) + per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + valid_tokens = loss_mask.sum() + else: + valid_tokens = logits.shape[0] * logits.shape[1] + output = per_sample.sum() / valid_tokens + output.backward() + return output, logits.grad + + +@pytest.mark.slow +# TODO: Support the same parameterization as above in the reference implementation. +@pytest.mark.parametrize("loss_masking", [False, True]) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +def test_reverse_kl(loss_masking, target_format): + logits, target, loss_mask = _get_cross_entropy_inputs(10000, loss_masking, target_format) + out, grad = reverse_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + target_format=TargetFormat.logits, + ) + out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref) + + +def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): + try: + torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) + fn_args[0](rank, torch.distributed.group.WORLD, *fn_args[1:]) + finally: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def _spawn_dist(world_size: int, *fn_args): + """ + Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. + """ + with tempfile.NamedTemporaryFile(delete=False) as tmp: + init_method = f"file://{tmp.name}" + + try: + torch.multiprocessing.spawn( + _mp_worker, + args=(world_size, init_method, fn_args), + nprocs=world_size, + join=True, + start_method="spawn", + ) + finally: + if os.path.exists(tmp.name): + os.remove(tmp.name) + +def _compare_parallel_cross_entropy( + rank: int, + group: torch.distributed.ProcessGroup, + target_format: TargetFormat, + function: typing.Callable, + loss_masking: bool, +): + # Ensure all workers have the same inputs. + torch.manual_seed(0) + world_size = torch.distributed.get_world_size(group) + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + + out, grad = function( + logits=logits.chunk(world_size, 1)[rank], + target=target.chunk(world_size, 1)[rank], + loss_mask=loss_mask, + grad_output=1, + group=group, + target_format=target_format, + ) -if __name__ == "__main__": - pytest.main([__file__]) + out_ref, grad_ref = function( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1, + target_format=target_format, + ) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref) + + +def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): + success = True + for function in (cross_entropy_forward_backward, reverse_kl_forward_backward): + for target_format in (TargetFormat.logits,): + for loss_masking in [True, False]: + try: + _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) + except Exception: + print(f"{function}, target_format, use_mask={loss_masking}") + traceback.print_exc() + success = False + if not success: + raise RuntimeError("Test failed") + + +@pytest.mark.slow +def test_distillation_losses(): + _spawn_dist(2, compare_parallel_cross_entropy) diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index f1409b95c..508597173 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -4,56 +4,11 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.attention import Attention -from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.utils import Assert from tests.utils.utils import requires_cuda -# TODO: ====== micro-sequence ====== -@pytest.mark.skip -def test_varlen_preprocessing(): - sequence_lengths = [[8, 13, 4, 11], [11, 16, 9]] - # First micro-sequence: - # [0...7,0...3] + [0...10,0] -> [0,8,12,23,24] - # Second micro-sequence: - # [4...12,0...2] + [1...12] -> [0,9,12,24] - # Third micro-sequence: - # [3,0...10] + [13...15, 0...8] -> [1,12,15,24] - cumulative_sequences_q = [ - torch.tensor([0, 8, 12, 23, 24], dtype=torch.int32), - torch.tensor([0, 0, 9, 12, 12, 24], dtype=torch.int32), - torch.tensor([0, 0, 0, 1, 12, 12, 15, 24], dtype=torch.int32), - ] - cumulative_sequences_k = [ - torch.tensor([0, 8, 12, 23, 24], dtype=torch.int32), - torch.tensor([0, 8, 21, 24, 35, 48], dtype=torch.int32), - torch.tensor([0, 8, 21, 25, 36, 47, 63, 72], dtype=torch.int32), - ] - micro_sequence_length = 12 - sequence_length = 36 - attention = Attention( - AttentionConfig(head_size=64, implementation=AttentionImplementation.flash, cross_document_attention=False), - DistributedConfig(compute_dtype="bfloat16"), - hidden_dim=TensorDim("", 1), - lr_scale=None, - peft=None, - ) - for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): - kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), - AttentionKwargs.sequence_k_dim: TensorDim( - BlockDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length - ), - AttentionKwargs.sequence_length: sequence_length, - AttentionKwargs.sequence_lengths: sequence_lengths, - AttentionKwargs.device: torch.device("cpu"), - } - attention.preprocess(kwargs) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) - - @requires_cuda @pytest.mark.parametrize("cross_document_attention", (True, False)) @pytest.mark.parametrize(("causal", "window_size"), ((True, None), (True, 50), (False, None))) diff --git a/tests/layers/test_gdn.py b/tests/layers/test_gdn.py deleted file mode 100644 index 2845fc404..000000000 --- a/tests/layers/test_gdn.py +++ /dev/null @@ -1,102 +0,0 @@ -import pytest -import torch - -from fast_llm.config import UpdateType -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.utils import Assert -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet -from tests.utils.utils import get_base_model, get_stage, requires_cuda - -VOCAB_SIZE = 500 -HIDDEN_SIZE = 16 -SEQ_LEN = 65 -NUM_V_HEADS = 4 -NUM_K_HEADS = 2 -HEAD_DIM = 4 -KERNEL_SIZE = 4 - - -@pytest.mark.slow -@requires_cuda -def test_fast_llm_gdn_matches_apriel2_forward(): - device = torch.device("cuda") - dtype = torch.bfloat16 - - config_gdn = { - "value_heads": NUM_V_HEADS, - "key_heads": NUM_K_HEADS, - "key_head_dim": HEAD_DIM, - "value_head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "norm_eps": 1e-5, - } - - hf_layer = ( - Apriel2GatedDeltaNet(HIDDEN_SIZE, config_gdn, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype).eval() - ) - - config = GPTBaseModelConfig.from_dict( - { - "decoder": { - "num_blocks": 1, - "block": { - "mixer": { - "type": "gdn", - "value_heads": NUM_V_HEADS, - "key_heads": NUM_K_HEADS, - "key_head_dim": HEAD_DIM, - "value_head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "normalization": {"epsilon": 1e-5}, - } - }, - }, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "hidden_size": HIDDEN_SIZE, - }, - update_type=UpdateType.update, - ) - - model, distributed = get_base_model( - GPTModelConfig.from_dict( - { - "base_model": config, - "distributed": {}, - }, - ) - ) - fast_layer = model.decoder[0].mixer - get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype).eval() - - with torch.no_grad(): - fast_layer.in_proj_qkvz.weight.copy_(hf_layer.in_proj_qkvz.weight) - fast_layer.in_proj_ba.weight.copy_(hf_layer.in_proj_ba.weight) - fast_layer.convolution.weight.copy_(hf_layer.convolution.weight) - if fast_layer.convolution.bias is not None and hf_layer.convolution.bias is not None: - fast_layer.convolution.bias.copy_(hf_layer.convolution.bias) - fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) - fast_layer.A_log.copy_(hf_layer.A_log) - fast_layer.dt_bias.copy_(hf_layer.dt_bias) - fast_layer.norm.weight.copy_(hf_layer.norm.weight) - - hidden_states = torch.randn(1, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - hf_state_dict = hf_layer.state_dict() - for k, p in fast_layer.state_dict().items(): - Assert.rms_close_relative(p, hf_state_dict[k], 1e-5, 1e-5) - - hf_out = hf_layer(hidden_states)[0] - - sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] - fast_kwargs = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), - BlockKwargs.sequence_length: SEQ_LEN, - BlockKwargs.sequence_lengths: sequence_lengths, - } - fast_layer.preprocess(fast_kwargs) - fast_out, _ = fast_layer(hidden_states, fast_kwargs) - - Assert.rms_close_relative(fast_out, hf_out, 1e-5, 1e-5) diff --git a/tests/layers/test_kda.py b/tests/layers/test_kda.py deleted file mode 100644 index 477eefaa4..000000000 --- a/tests/layers/test_kda.py +++ /dev/null @@ -1,138 +0,0 @@ -import pytest -import torch - -import fast_llm.layers.ssm.kda as kda_module -from fast_llm.config import UpdateType -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda - -try: - from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig - from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention -except ImportError: - AprielHybridSSMConfig, KimiDeltaAttention = None, None - -VOCAB_SIZE = 500 -HIDDEN_SIZE = 16 -SEQ_LEN = 65 -NUM_HEADS = 4 -HEAD_DIM = 4 -KERNEL_SIZE = 4 - - -@pytest.mark.slow -@requires_cuda -@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") -@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") -def test_fast_llm_kda_matches_apriel_forward(): - device = torch.device("cuda") - dtype = torch.bfloat16 - - hf_config = AprielHybridSSMConfig( - hidden_size=HIDDEN_SIZE, - num_attention_heads=NUM_HEADS, - num_hidden_layers=1, - rms_norm_eps=1e-6, - ) - hf_config.short_conv_kernel_size = KERNEL_SIZE - hf_config.head_dim = HEAD_DIM - hf_config.num_heads = NUM_HEADS - hf_layer = KimiDeltaAttention(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() - - config = GPTBaseModelConfig.from_dict( - { - "decoder": { - "num_blocks": 1, - "block": { - "mixer": { - "type": "kda", - "heads": NUM_HEADS, - "head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "normalization": {"epsilon": hf_config.rms_norm_eps, "activation": "sigmoid"}, - } - }, - }, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "hidden_size": HIDDEN_SIZE, - }, - update_type=UpdateType.update, - ) - - model, distributed = get_base_model( - GPTModelConfig.from_dict( - { - "base_model": config, - "distributed": {}, - }, - ) - ) - fast_layer = model.decoder[0].mixer - get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype).eval() - - with torch.no_grad(): - fast_layer.q_proj.weight.copy_(hf_layer.q_proj.weight) - fast_layer.k_proj.weight.copy_(hf_layer.k_proj.weight) - fast_layer.v_proj.weight.copy_(hf_layer.v_proj.weight) - fast_layer.q_conv.weight.copy_(hf_layer.q_conv1d.weight) - fast_layer.k_conv.weight.copy_(hf_layer.k_conv1d.weight) - fast_layer.v_conv.weight.copy_(hf_layer.v_conv1d.weight) - if fast_layer.q_conv.bias is not None and hf_layer.q_conv1d.bias is not None: - fast_layer.q_conv.bias.copy_(hf_layer.q_conv1d.bias) - if fast_layer.k_conv.bias is not None and hf_layer.k_conv1d.bias is not None: - fast_layer.k_conv.bias.copy_(hf_layer.k_conv1d.bias) - if fast_layer.v_conv.bias is not None and hf_layer.v_conv1d.bias is not None: - fast_layer.v_conv.bias.copy_(hf_layer.v_conv1d.bias) - fast_layer.f_a_proj.weight.copy_(hf_layer.f_a_proj.weight) - fast_layer.f_b_proj.weight.copy_(hf_layer.f_b_proj.weight) - fast_layer.g_a_proj.weight.copy_(hf_layer.g_a_proj.weight) - fast_layer.g_b_proj.weight.copy_(hf_layer.g_b_proj.weight) - fast_layer.beta_proj.weight.copy_(hf_layer.b_proj.weight) - fast_layer.o_proj.weight.copy_(hf_layer.o_proj.weight) - fast_layer.A_log.copy_(hf_layer.A_log.reshape_as(fast_layer.A_log)) - fast_layer.dt_bias.copy_(hf_layer.dt_bias.reshape_as(fast_layer.dt_bias)) - fast_layer.norm.weight.copy_(hf_layer.o_norm.weight) - - param_map = { - "q_proj.weight": "q_proj.weight", - "k_proj.weight": "k_proj.weight", - "v_proj.weight": "v_proj.weight", - "q_conv.weight": "q_conv1d.weight", - "k_conv.weight": "k_conv1d.weight", - "v_conv.weight": "v_conv1d.weight", - "f_a_proj.weight": "f_a_proj.weight", - "f_b_proj.weight": "f_b_proj.weight", - "g_a_proj.weight": "g_a_proj.weight", - "g_b_proj.weight": "g_b_proj.weight", - "beta_proj.weight": "b_proj.weight", - "o_proj.weight": "o_proj.weight", - "A_log": "A_log", - "dt_bias": "dt_bias", - "norm.weight": "o_norm.weight", - } - hf_params = hf_layer.state_dict() - for fast_name, fast_param in fast_layer.state_dict().items(): - hf_param = hf_params[param_map[fast_name]] - if fast_param.shape != hf_param.shape: - Assert.eq(fast_param.numel(), hf_param.numel(), msg=fast_name) - hf_param = hf_param.reshape_as(fast_param) - Assert.rms_close_relative(fast_param, hf_param, 1e-5, 1e-5, msg=fast_name) - - hidden_states = torch.randn(2, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - hf_layer.training = True - hf_out = hf_layer(hidden_states) - - sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] - fast_kwargs = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.sequence_lengths: sequence_lengths, - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), - } - fast_layer.preprocess(fast_kwargs) - fast_out, _ = fast_layer(hidden_states, fast_kwargs) - - Assert.rms_close_relative(fast_out, hf_out, 1e-5, 1e-5) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 108261d4b..623a30d82 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -264,6 +264,8 @@ def test_lm_head( distributed, tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + # Names must be kept as-is for tied weights. + set_names=False, ) # Get reference outputs and grads diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py new file mode 100644 index 000000000..a1d0df52b --- /dev/null +++ b/tests/layers/test_ssm.py @@ -0,0 +1,132 @@ +import pytest +import torch + +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.decoder.config import MixerConfig +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig +from fast_llm.utils import Assert +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet +from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig +from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention +from tests.utils.utils import get_stage, requires_cuda + +HIDDEN_SIZE = 16 +SEQ_LEN = 65 +NUM_HEADS = 4 +NUM_V_HEADS = 4 +NUM_K_HEADS = 2 +HEAD_DIM = 4 +KERNEL_SIZE = 4 + + +def _compare_mixers(fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, param_map: dict[str, str]): + distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16)) + fast_llm_layer = fast_llm_config.get_layer( + distributed_config, + TensorDim("", HIDDEN_SIZE), + lr_scale=None, + peft=None, + ).eval() + get_stage([fast_llm_layer], distributed, [], {}) + hf_layer = hf_layer.to(device=distributed.device, dtype=distributed_config.compute_dtype.torch) + + with torch.no_grad(): + hf_state_dict = hf_layer.state_dict() + for name, param in fast_llm_layer.named_parameters(): + param.copy_(hf_state_dict[param_map.get(name, name)].view_as(param)) + + hf_params = hf_layer.state_dict() + for name, fast_param in fast_llm_layer.state_dict().items(): + hf_param = hf_params[param_map.get(name, name)] + Assert.rms_close_relative(fast_param, hf_param.view_as(fast_param), 1e-5, 1e-5, msg=name) + + hidden_states = torch.randn( + 2, + SEQ_LEN, + HIDDEN_SIZE, + device=distributed.device, + dtype=distributed_config.compute_dtype.torch, + requires_grad=False, + ) + + hf_layer.train() + hf_out = hf_layer(hidden_states) + if isinstance(hf_out, tuple): + (hf_out,) = hf_out + + sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] + fast_kwargs = { + BlockKwargs.device: distributed.device, + BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), + BlockKwargs.sequence_q_dim: TensorDim("", SEQ_LEN), + BlockKwargs.sequence_k_dim: TensorDim("", SEQ_LEN), + } + fast_llm_layer.train() + fast_llm_layer.preprocess(fast_kwargs) + fast_out = fast_llm_layer(hidden_states, fast_kwargs) + print("AAAA", fast_out.shape, [x.shape for x in hf_out]) + + Assert.rms_close_relative(fast_out, hf_out, 1e-5, 1e-5) + + +@pytest.mark.slow +@requires_cuda +def test_gdn(): + device = torch.device("cuda") + dtype = torch.bfloat16 + + config_common = { + "value_heads": NUM_V_HEADS, + "key_heads": NUM_K_HEADS, + "key_head_dim": HEAD_DIM, + "value_head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + } + + hf_layer = ( + Apriel2GatedDeltaNet(HIDDEN_SIZE, {**config_common, "norm_eps": 1e-5}, layer_idx=0, dtype=dtype) + .to(device=device, dtype=dtype) + .eval() + ) + fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) + _compare_mixers(fast_llm_config, hf_layer, {}) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") +@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") +def test_kda(): + hf_config = AprielHybridSSMConfig( + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, + num_hidden_layers=1, + rms_norm_eps=1e-6, + ) + hf_config.short_conv_kernel_size = KERNEL_SIZE + hf_config.head_dim = HEAD_DIM + hf_config.num_heads = NUM_HEADS + hf_layer = KimiDeltaAttention(hf_config, layer_idx=0) + + fast_llm_config = KimiDeltaAttentionConfig( + heads=NUM_HEADS, + head_dim=HEAD_DIM, + convolution_layer={"kernel_size": KERNEL_SIZE, "activation": "silu"}, + normalization={"epsilon": 1e-6, "activation": "sigmoid"}, + ) + + param_map = { + "q_conv.weight": "q_conv1d.weight", + "k_conv.weight": "k_conv1d.weight", + "v_conv.weight": "v_conv1d.weight", + "beta_proj.weight": "b_proj.weight", + "norm.weight": "o_norm.weight", + } + _compare_mixers(fast_llm_config, hf_layer, param_map) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 9231168aa..63f977471 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -41,6 +41,7 @@ TP_NO_STP = r"(?:^|(?<=[^s]))tp" +GRAD_ACC = r"df(?!16)|bf" class ModelTestingGroup(enum.StrEnum): @@ -619,7 +620,7 @@ def _update_and_add_testing_config( compare_factor=8, # Modes not supported with reference models and/or activation distillation. # TODO: Fix gradient accumulation and fp16, add TP support. - skip_tests=("sdp", "ms", "pp", "tp", "df", "bf", "fp16"), + skip_tests=("sdp", "ms", "pp", "tp", GRAD_ACC, "fp16"), ) _update_and_add_testing_config( @@ -813,7 +814,7 @@ def _update_and_add_testing_config( compare_factor=6.0, # Micro-sequence split and sequence-first not supported. # TODO: Gradient accumulation works but comparison is broken. - skip_tests=("sdp", "ms", "bf4", "df"), + skip_tests=("sdp", "ms", GRAD_ACC), auto_model_class=transformers.AutoModelForImageTextToText, ) @@ -1018,7 +1019,7 @@ def _update_and_add_testing_config( compare_factor=6.0, # Micro-sequence split and sequence-first not supported for Mamba. # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). - skip_tests=("sdp", "ms", "bf4", "df4", TP_NO_STP), + skip_tests=("sdp", "ms", GRAD_ACC, TP_NO_STP), ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 5ff238756..3b79f7607 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -43,13 +43,15 @@ def get_stage( distributed: Distributed, tied_parameter_duplicates: typing.Iterable[str] = (), tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, + set_names: bool = True, ): for layer in layers: if not layer._is_setup: layer.setup(distributed) - # Normally called in `BaseModelConfig.get_base_model`, but may be missing here. - set_model_names(torch.nn.ModuleList(layers)) + if set_names: + # Normally called in `BaseModelConfig.get_base_model`, but may be missing here. + set_model_names(torch.nn.ModuleList(layers)) # Create a fast-llm stage which allocates and initializes meta tensors correctly. stage = Stage( config=StageConfig(), From e5fe8b2db76c6fe349adadfad9820a3e02037b5b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 10 Dec 2025 23:07:57 -0500 Subject: [PATCH 13/14] stuff --- tests/functional/test_cross_entropy.py | 1 - tests/layers/test_mamba_equivalence.py | 175 ------------------------- tests/layers/test_ssm.py | 69 ++++++++-- tests/layers/test_varlen.py | 9 +- 4 files changed, 66 insertions(+), 188 deletions(-) delete mode 100644 tests/layers/test_mamba_equivalence.py diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 3c6facafe..4524e515e 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -79,7 +79,6 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) - Assert.rms_close(out_fused, out_torch, 5e-3) _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch) if num_columns > 65536: diff --git a/tests/layers/test_mamba_equivalence.py b/tests/layers/test_mamba_equivalence.py deleted file mode 100644 index ccf2dba41..000000000 --- a/tests/layers/test_mamba_equivalence.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Test numerical equivalence between Fast-LLM Mamba2 and Apriel2 Mamba. - -Note: Fast-LLM's "mamba_2" type is actually a Mamba 1 variant (not the true Mamba 2 -architecture). It corresponds to the HuggingFace/Apriel Mamba implementation. -""" - -import pytest -import torch - -from fast_llm.config import UpdateType -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm.config import Mamba2Config # Ensures mamba_2 type is registered -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda - -# Ensure Mamba2Config is registered for dynamic type lookup -_ = Mamba2Config - -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Mamba -except ImportError: - Apriel2Mamba = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn - - _mamba_kernel_available = True -except (ImportError, RuntimeError): - _mamba_kernel_available = False - -# Test constants -VOCAB_SIZE = 500 -HIDDEN_SIZE = 64 -SEQ_LEN = 65 -BATCH_SIZE = 2 -D_INNER = 128 -D_XB = 64 -D_STATE = 16 -D_CONV = 4 -DT_RANK = 4 - - -def _copy_weights(fast_layer, hf_layer): - """Copy weights from Apriel2 Mamba to Fast-LLM Mamba2.""" - with torch.no_grad(): - # Main projections - fast_layer.in_proj.weight.copy_(hf_layer.in_proj.weight) - if fast_layer.in_proj.bias is not None and hf_layer.in_proj.bias is not None: - fast_layer.in_proj.bias.copy_(hf_layer.in_proj.bias) - - # DT projections - fast_layer.dt_in_proj.weight.copy_(hf_layer.dt_in_proj.weight) - if fast_layer.dt_in_proj.bias is not None and hf_layer.dt_in_proj.bias is not None: - fast_layer.dt_in_proj.bias.copy_(hf_layer.dt_in_proj.bias) - - fast_layer.dt_proj.weight.copy_(hf_layer.dt_proj.weight) - if fast_layer.dt_proj.bias is not None and hf_layer.dt_proj.bias is not None: - fast_layer.dt_proj.bias.copy_(hf_layer.dt_proj.bias) - - # Convolution (Fast-LLM uses "convolution", Apriel2 uses "conv1d") - fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) - if fast_layer.convolution.bias is not None and hf_layer.conv1d.bias is not None: - fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) - - # SSM parameters - fast_layer.A_log.copy_(hf_layer.A_log) - fast_layer.D.copy_(hf_layer.D) - - # Output projection - fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) - if fast_layer.out_proj.bias is not None and hf_layer.out_proj.bias is not None: - fast_layer.out_proj.bias.copy_(hf_layer.out_proj.bias) - - -@pytest.mark.slow -@requires_cuda -@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available") -@pytest.mark.skipif(not _mamba_kernel_available, reason="Mamba CUDA kernels not available") -@pytest.mark.parametrize("add_linear_biases", [True, False]) -@pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) -def test_fast_llm_mamba2_matches_apriel2(add_linear_biases, repeat_kv_before_conv): - """Verify Fast-LLM Mamba2 output matches Apriel2 Mamba. - - Args: - add_linear_biases: Whether to add biases to linear layers. - repeat_kv_before_conv: Whether to repeat KV before or after convolution. - """ - torch.manual_seed(42) - device = torch.device("cuda") - dtype = torch.bfloat16 - - # Create Apriel2 Mamba layer - # Note: Apriel2 has separate conv_bias and dt_proj_bias controls. - # We align them with Fast-LLM's single add_linear_biases flag. - mamba_config = { - "d_inner": D_INNER, - "d_xb": D_XB, - "state_size": D_STATE, - "d_conv": D_CONV, - "dt_rank": DT_RANK, - "conv_bias": add_linear_biases, - "dt_proj_bias": add_linear_biases, - "add_linear_biases": add_linear_biases, - "repeat_kv_before_conv": repeat_kv_before_conv, - "dt_min": 0.001, - "dt_max": 0.1, - "dt_init_floor": 1e-4, - } - hf_layer = Apriel2Mamba(HIDDEN_SIZE, mamba_config, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype) - hf_layer.eval() - - # Create Fast-LLM Mamba2 layer - config = GPTBaseModelConfig.from_dict( - { - "decoder": { - "num_blocks": 1, - "block": { - "mixer": { - "type": "mamba_2", - "d_inner": D_INNER, - "d_xb": D_XB, - "state_size": D_STATE, - "convolution_layer": {"kernel_size": D_CONV}, - "dt_rank": DT_RANK, - "add_linear_biases": add_linear_biases, - "repeat_kv_before_conv": repeat_kv_before_conv, - } - }, - }, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "hidden_size": HIDDEN_SIZE, - }, - update_type=UpdateType.update, - ) - - model, distributed = get_base_model( - GPTModelConfig.from_dict( - { - "base_model": config, - "distributed": {}, - }, - ) - ) - fast_layer = model.decoder[0].mixer - get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype) - fast_layer.eval() - - # Copy weights - _copy_weights(fast_layer, hf_layer) - - # Verify key parameters match (not all names match between implementations) - Assert.all_equal(fast_layer.in_proj.weight, hf_layer.in_proj.weight) - Assert.all_equal(fast_layer.convolution.weight, hf_layer.conv1d.weight) - Assert.all_equal(fast_layer.A_log, hf_layer.A_log) - Assert.all_equal(fast_layer.D, hf_layer.D) - Assert.all_equal(fast_layer.out_proj.weight, hf_layer.out_proj.weight) - - # Forward passes - hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - - hf_out = hf_layer(hidden_states)[0] - - fast_kwargs = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), - } - fast_layer.preprocess(fast_kwargs) - fast_out, _ = fast_layer(hidden_states, fast_kwargs) - - # Compare outputs (slightly looser tolerance for Mamba due to numerical differences) - Assert.rms_close(fast_out, hf_out, 1e-4) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index a1d0df52b..e6422c597 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -8,23 +8,20 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm import kda as kda_module -from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config from fast_llm.utils import Assert -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention from tests.utils.utils import get_stage, requires_cuda HIDDEN_SIZE = 16 SEQ_LEN = 65 -NUM_HEADS = 4 -NUM_V_HEADS = 4 -NUM_K_HEADS = 2 -HEAD_DIM = 4 -KERNEL_SIZE = 4 -def _compare_mixers(fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, param_map: dict[str, str]): +def _compare_mixers( + fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, param_map: dict[str, str], threshold=1e-5 +): distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16)) fast_llm_layer = fast_llm_config.get_layer( distributed_config, @@ -43,7 +40,7 @@ def _compare_mixers(fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, par hf_params = hf_layer.state_dict() for name, fast_param in fast_llm_layer.state_dict().items(): hf_param = hf_params[param_map.get(name, name)] - Assert.rms_close_relative(fast_param, hf_param.view_as(fast_param), 1e-5, 1e-5, msg=name) + Assert.rms_close_relative(fast_param, hf_param.view_as(fast_param), threshold, 1e-5, msg=name) hidden_states = torch.randn( 2, @@ -71,9 +68,8 @@ def _compare_mixers(fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, par fast_llm_layer.train() fast_llm_layer.preprocess(fast_kwargs) fast_out = fast_llm_layer(hidden_states, fast_kwargs) - print("AAAA", fast_out.shape, [x.shape for x in hf_out]) - Assert.rms_close_relative(fast_out, hf_out, 1e-5, 1e-5) + Assert.rms_close_relative(fast_out, hf_out, threshold, 1e-5) @pytest.mark.slow @@ -82,6 +78,11 @@ def test_gdn(): device = torch.device("cuda") dtype = torch.bfloat16 + NUM_V_HEADS = 4 + NUM_K_HEADS = 2 + HEAD_DIM = 4 + KERNEL_SIZE = 4 + config_common = { "value_heads": NUM_V_HEADS, "key_heads": NUM_K_HEADS, @@ -104,6 +105,10 @@ def test_gdn(): @pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") def test_kda(): + NUM_HEADS = 4 + HEAD_DIM = 4 + KERNEL_SIZE = 4 + hf_config = AprielHybridSSMConfig( hidden_size=HIDDEN_SIZE, num_attention_heads=NUM_HEADS, @@ -130,3 +135,45 @@ def test_kda(): "norm.weight": "o_norm.weight", } _compare_mixers(fast_llm_config, hf_layer, param_map) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.parametrize("add_linear_biases", [True, False]) +@pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) +@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available") +def test_mamba(add_linear_biases, repeat_kv_before_conv): + D_INNER = 128 + D_XB = 64 + D_STATE = 16 + D_CONV = 4 + DT_RANK = 4 + + config_common = { + "d_inner": D_INNER, + "d_xb": D_XB, + "state_size": D_STATE, + "dt_rank": DT_RANK, + "repeat_kv_before_conv": repeat_kv_before_conv, + "add_linear_biases": add_linear_biases, + } + + mamba_config = { + "conv_bias": add_linear_biases, + "dt_proj_bias": add_linear_biases, + **config_common, + } + hf_layer = Apriel2Mamba(HIDDEN_SIZE, mamba_config, layer_idx=0) + + # Create Fast-LLM Mamba2 layer + fast_llm_config = Mamba2Config( + convolution_layer={"kernel_size": D_CONV}, + **config_common, + ) + + param_map = { + "convolution.weight": "conv1d.weight", + "convolution.bias": "conv1d.bias", + } + # TODO: This is a really high threshold. + _compare_mixers(fast_llm_config, hf_layer, param_map, threshold=1e-2) diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index b231aedae..32cd00cd2 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -22,7 +22,14 @@ "config", [ AttentionConfig(heads=4, head_groups=2, head_size=16, cross_document_attention=False), - Mamba2Config(d_inner=128, d_xb=64, state_size=16, dt_rank=8, cross_document_attention=False), + Mamba2Config( + d_inner=128, + d_xb=64, + state_size=16, + dt_rank=8, + cross_document_attention=False, + marks=pytest.mark.skip("Mamba varlen kernel not available"), + ), pytest.param( GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), marks=pytest.mark.skipif( From 68f457b1f23947a6a4dcc39eead567cf44f2a262 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 11 Dec 2025 01:28:57 -0500 Subject: [PATCH 14/14] fixes --- tests/functional/test_cross_entropy.py | 45 ++++++++++++++------------ 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 4524e515e..a23b49f8e 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -1,4 +1,5 @@ import os +import sys import tempfile import traceback import typing @@ -16,7 +17,7 @@ def _get_cross_entropy_inputs( num_columns: int, loss_masking: bool, target_format: TargetFormat ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, device="cuda") / 3 + logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device="cuda") if loss_masking else None if target_format == TargetFormat.labels: target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") @@ -25,7 +26,7 @@ def _get_cross_entropy_inputs( logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) loss_mask = None else: - target = torch.randn(256, num_columns, device="cuda") + target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") logits = target + logits_var if target_format == TargetFormat.probabilities: target = torch.softmax(target, -1) @@ -38,10 +39,11 @@ def _compare_cross_entropy_outputs( has_grad: bool, grad: torch.Tensor | None, ref_grad: torch.Tensor | None, + threshold=1e-5, ): - Assert.rms_close_relative(loss, ref_loss, 1e-6, 1e-6) + Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) if has_grad: - Assert.rms_close_relative(grad, ref_grad, 1e-6, 1e-6) + Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) else: assert grad is None assert ref_grad is None @@ -79,30 +81,30 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) - _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch) + # TODO: Why is the error so high with logit scaling? + threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 + _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) else: out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - _compare_cross_entropy_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch) + _compare_cross_entropy_outputs( + out_triton, out_torch, grad_output is not None, grad_triton, grad_torch, threshold + ) -def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tensor, loss_mask: torch.Tensor): +def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tensor, loss_mask: torch.Tensor | None): # Manual reference: sum over vocab then average over valid tokens. logits = logits.detach().requires_grad_() - teacher_log_probs = torch.log_softmax(target, dim=-1) - student_log_probs = torch.log_softmax(logits, dim=-1) per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True + torch.log_softmax(target.float(), dim=-1), + torch.log_softmax(logits.float(), dim=-1), + reduction="none", + log_target=True, ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - valid_tokens = loss_mask.sum() - else: - valid_tokens = logits.shape[0] * logits.shape[1] - output = per_sample.sum() / valid_tokens + output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum() output.backward() return output, logits.grad @@ -113,6 +115,7 @@ def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tenso @pytest.mark.parametrize("target_format", (TargetFormat.logits,)) def test_reverse_kl(loss_masking, target_format): logits, target, loss_mask = _get_cross_entropy_inputs(10000, loss_masking, target_format) + out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) out, grad = reverse_kl_forward_backward( logits=logits, target=target, @@ -120,8 +123,8 @@ def test_reverse_kl(loss_masking, target_format): grad_output=1.0, target_format=TargetFormat.logits, ) - out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref) + # TODO: Error looks + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): @@ -181,7 +184,7 @@ def _compare_parallel_cross_entropy( grad_output=1, target_format=target_format, ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): @@ -192,7 +195,9 @@ def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGr try: _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) except Exception: - print(f"{function}, target_format, use_mask={loss_masking}") + print( + f" >>>>>> Failed {function.__name__}, target_format, use_mask={loss_masking}", file=sys.stderr + ) traceback.print_exc() success = False if not success: