diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index f1eef47b9..0526b9dc2 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -8,6 +8,8 @@ from fast_llm.utils import Assert, compare_nested, log if typing.TYPE_CHECKING: + import torch + from fast_llm.engine.base_model.base_model import BaseModel @@ -58,6 +60,17 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: return self._serialize_value(value) +def set_model_names(model: "torch.nn.Module"): + from fast_llm.tensor import ParameterMeta + + for key, value in model.named_modules(): + value.module_name = key + for key, value in model.named_parameters(): + Assert.custom(isinstance, value, ParameterMeta) + # Rename to the parameter full name + value.tensor_name = key + + @config_class() class BaseModelConfig(ModuleConfig): """ @@ -65,17 +78,11 @@ class BaseModelConfig(ModuleConfig): """ def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel": - from fast_llm.tensor import ParameterMeta model = self.base_model_class(self, distributed_config) # Storing the global name of each module and tensor. # Done here because it needs to run right after `model.__init__()` - for key, value in model.named_modules(): - value.module_name = key - for key, value in model.named_parameters(): - Assert.custom(isinstance, value, ParameterMeta) - # Rename to the parameter full name - value.tensor_name = key + set_model_names(model) return model @property diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index a5a41f542..fbe6d3297 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -47,6 +47,7 @@ def __init__( ): self._name = name self._parameter_metas = {parameter_meta.tensor_name: parameter_meta for parameter_meta in parameter_metas} + Assert.eq(len(self._parameter_metas), len(parameter_metas)) # `set_model_names` ensure unique names. 
self._distributed_config = distributed_config self._fsdp_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.data) self._is_tied_weight_copy = is_tied_weight_copy diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 059469d94..073599479 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -11,11 +11,12 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert, div +from fast_llm.utils import div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -505,40 +506,10 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value def _preprocess_for_flash_attention(self, kwargs: dict[str, typing.Any]) -> None: - """ - Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 - cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. - Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. - If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally - also contain previous tokens from the first document in micro-sequence. - We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. - """ - if self._config.cross_document_attention: - return - device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device - - # TODO: ====== Fix (need to know how much first sequence was cropped) ====== - Assert.eq( - kwargs[AttentionKwargs.sequence_k_dim].global_size, kwargs[AttentionKwargs.sequence_q_dim].global_size - ) - - # TODO: Calculate these in batch preprocessing? 
- sequence_lengths_q = torch.tensor( - [ - 0, - *( - sequence_length - for sequence_lengths in kwargs[AttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths - ), - ], - dtype=torch.int32, - ) - max_sequence_length = sequence_lengths_q.max().item() - cu_seqlens_q = sequence_lengths_q.cumsum_(0).to(device) - max_seqlen_q = cu_seqlens_q.new_full((1,), max_sequence_length) - kwargs[AttentionKwargs.cu_seqlens_q] = cu_seqlens_q - kwargs[AttentionKwargs.cu_seqlens_k] = cu_seqlens_q - kwargs[AttentionKwargs.max_seqlen_q] = max_seqlen_q - kwargs[AttentionKwargs.max_seqlen_k] = max_seqlen_q + if not self._config.cross_document_attention: + preprocess_for_varlen( + kwargs, + kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device, + return_cu_seqlens=True, + return_max_seqlen=True, + ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 6f589eeb4..626a8fde6 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -17,15 +17,20 @@ logger = logging.getLogger(__name__) -class AttentionKwargs(BlockKwargs): - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" +class MixerKwargs(BlockKwargs): cu_seqlens_q = "cu_seqlens_q" cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" max_seqlen_k = "max_seqlen_k" + seq_idx = "seq_idx" + position_ids = "position_ids" + + +class AttentionKwargs(MixerKwargs): + rotary_freq_q = "rotary_freq_q" + rotary_freq_k = "rotary_freq_k" + attention_mask = "attention_mask" + attention_mask_value = "attention_mask_value" # TODO: Review these presents = "presents" past_key_values = "past_key_values" diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py new file mode 100644 index 000000000..a9d9936c5 --- /dev/null +++ b/fast_llm/layers/attention/preprocessing.py @@ -0,0 +1,58 @@ +import typing + +import torch + +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.utils import Assert + + +def preprocess_for_varlen( + kwargs: dict[str, typing.Any], + device: torch.device, + return_cu_seqlens: bool = False, + return_max_seqlen: bool = False, + return_seq_idx: bool = False, + return_position_ids: bool = False, +) -> None: + """ + Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 + cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. + Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. + If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally + also contain previous tokens from the first document in micro-sequence. + We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. 
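+
+    For example, with sequence_lengths [[3, 2], [4]] (flattened to [3, 2, 4]), this produces
+    cu_seqlens_q = cu_seqlens_k = [0, 3, 5, 9], max_seqlen = 4,
+    seq_idx = [0, 0, 0, 1, 1, 2, 2, 2, 2] and position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3].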
+ """ + + # TODO: ====== Fix (need to know how much first sequence was cropped) ====== + Assert.eq(kwargs[MixerKwargs.sequence_k_dim].global_size, kwargs[MixerKwargs.sequence_q_dim].global_size) + + sequence_lengths = [ + sequence_length + for sequence_lengths in kwargs[MixerKwargs.sequence_lengths] + for sequence_length in sequence_lengths + ] + if return_cu_seqlens: + cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32, device=device).cumsum( + 0, dtype=torch.int32 + ) + kwargs[MixerKwargs.cu_seqlens_q] = cu_seqlens_q + kwargs[MixerKwargs.cu_seqlens_k] = cu_seqlens_q + if return_max_seqlen: + max_seqlen_q = torch.full((1,), max(sequence_lengths), dtype=torch.int32, device=device) + kwargs[MixerKwargs.max_seqlen_q] = max_seqlen_q + kwargs[MixerKwargs.max_seqlen_k] = max_seqlen_q + if return_seq_idx: + kwargs[MixerKwargs.seq_idx] = torch.cat( + [ + torch.full((sequence_length,), i, dtype=torch.int32, device=device) + for i, sequence_length in enumerate(sequence_lengths) + ] + ) + if return_position_ids: + kwargs[MixerKwargs.position_ids] = torch.cat( + [ + torch.arange(sequence_length, dtype=torch.int32, device=device) + for i, sequence_length in enumerate(sequence_lengths) + ] + ) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 69018fd06..c336a7e99 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -34,7 +34,11 @@ def __init__( else self._forward_torch ) - def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: + if kwargs: + raise NotImplementedError( + f"Arguments {tuple(kwargs)} not implemented for torch implementation of 1d convolution." 
+ ) return self._activation.activation_fn( torch.nn.functional.conv1d( input_, diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 450591216..f0e3a1529 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig @@ -20,11 +19,6 @@ from fast_llm.tensor import ParameterMeta -class LinearAttentionKwargs(BlockKwargs): - cu_seqlens = "cu_seqlens" - seq_idx = "seq_idx" - - @config_class(dynamic_type={MixerConfig: "gdn"}) class GatedDeltaNetConfig(MixerConfig): """ @@ -179,13 +173,6 @@ def layer_class(self) -> "type[KimiDeltaAttention]": return KimiDeltaAttention - def _validate(self) -> None: - with self._set_implicit_default(): - if "activation" not in self.normalization._explicit_fields: - self.normalization.activation = "sigmoid" - - super()._validate() - @config_class() class SSMConfig(MixerConfig): @@ -334,6 +321,12 @@ class Mamba2Config(MambaBaseConfig): desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", hint=FieldHint.architecture, ) + cross_document_attention: bool = Field( + default=True, + desc="Allow for cross-document attention.", + doc="Disable to prevent attention between tokens belonging to different documents.", + hint=FieldHint.feature, + ) @property def layer_class(self) -> "type[Mamba2]": diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 096b2cceb..474108482 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -10,10 +10,12 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import GatedDeltaNetConfig, LinearAttentionKwargs +from fast_llm.layers.ssm.config import GatedDeltaNetConfig from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div @@ -293,9 +295,6 @@ def _forward( # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ - cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) - seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) - projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) if sequence_first: @@ -315,9 +314,8 @@ def _forward( mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = rearrange(mixed_qkv, "b s ... 
-> (b s) ...").unsqueeze(0) # 1 s d mixed_qkv = rearrange(mixed_qkv, "b t d -> b d t") # mixed_qkv.transpose(1, 2) - mixed_qkv = self.convolution( - mixed_qkv, seq_idx=seq_idx - ) # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 + # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 + mixed_qkv = self.convolution(mixed_qkv, seq_idx=kwargs[MixerKwargs.seq_idx].unsqueeze(0)) mixed_qkv = rearrange(mixed_qkv, "b d t -> b t d") # mixed_qkv.transpose(1, 2) query, key, value = torch.split( mixed_qkv, @@ -351,7 +349,7 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, + cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q], ) z_shape_og = z.shape @@ -368,56 +366,13 @@ def _forward( return output - def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: - """ - Creates seqlens and cu_seqlens for packed forward. - This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. - Note: padding tokens are always on the right and get their own entry in LinearAttentionKwargs.sequence_lengths --> they are treated as seperate sequence. - - Sets: - - seq_idx to [1, BS x T] tensor, where each elemnt is the sequence index of the corresponding token - - cu_seqlens to [N+1] tensor, where N is the total number of sequences in the batch, each element is the cumulative sequence length of packed sequences sofar - """ - - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - device = kwargs.get("device", None) - if sequence_lengths is None: - raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - seqlens = torch.tensor( - [ - 0, - *( - sequence_length - for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths - ), - ], - dtype=torch.int32, - ) - cu_seqlens = seqlens.cumsum_(0).to(device) - # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 - # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 - kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens - # seq_idx has to be (bs, seqlen), but bs is forced to 1 - kwargs[LinearAttentionKwargs.seq_idx] = ( - ( - torch.cat( - [ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in (torch.diff(cu_seqlens).to(torch.int32)) - ], - dim=0, - ) - .eq(0) - .cumsum(0) - - 1 - ) - .to(torch.int32) - .unsqueeze(0) - ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - self._preprocess_for_varlen(kwargs) + preprocess_for_varlen( + kwargs, + kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, + return_cu_seqlens=True, + return_seq_idx=True, + ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 323e1ad13..270ac65bf 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -9,10 
+9,12 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig from fast_llm.tensor import ParameterMeta, TensorMeta logger = logging.getLogger(__name__) @@ -229,8 +231,6 @@ def _forward( sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ - cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) - seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) # TODO: can be made more efficeint by rearranging hidden states directly and only once residual_dtype = hidden_states.dtype @@ -250,9 +250,10 @@ def _forward( # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) - q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) - k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) - v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) + seq_idx = kwargs[MixerKwargs.seq_idx].unsqueeze(0) + q = self._apply_conv(q, self.q_conv, seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) if sequence_first: @@ -281,7 +282,7 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, + cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q], ) attn_out = attn_out.to(residual_dtype) @@ -303,44 +304,10 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() - def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - device = kwargs.get("device", None) - if sequence_lengths is None: - raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - - seqlens = torch.tensor( - [ - 0, - *( - sequence_length - for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] - for sequence_length in sequence_lengths # bs - ), - ], - dtype=torch.int32, - ) - cu_seqlens = seqlens.cumsum_(0).to(device) - # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 - # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 - kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens - # seq_idx has to be (bs, seqlen), but bs is forced to 1 - kwargs[LinearAttentionKwargs.seq_idx] = ( - ( - torch.cat( - [ - 
torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in (torch.diff(cu_seqlens).to(torch.int32)) - ], - dim=0, - ) - .eq(0) - .cumsum(0) - - 1 - ) - .to(torch.int32) - .unsqueeze(0) - ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - self._preprocess_for_varlen(kwargs) + preprocess_for_varlen( + kwargs, + kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, + return_cu_seqlens=True, + return_seq_idx=True, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 616d1152f..6e0ae0c60 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,4 @@ +import inspect import logging import typing @@ -8,6 +9,8 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import MixerKwargs +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -19,8 +22,17 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa _mamba_available = True + sig = inspect.signature(selective_scan_fn) + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + if "position_indices" in sig.parameters: + _mamba_varlen_available = True + else: + _mamba_varlen_available = False + except (ImportError, RuntimeError): _mamba_available = False + _mamba_varlen_available = False logger = logging.getLogger(__name__) @@ -181,15 +193,19 @@ def _forward( # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) + convolution_kwargs = ( + {} if self._config.cross_document_attention else {"seq_idx": kwargs[MixerKwargs.seq_idx].unsqueeze(0)} + ) if self._config.repeat_kv_before_conv: x = self.convolution( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) - .flatten(1, 2) + .flatten(1, 2), + **convolution_kwargs, ) else: x = ( - self.convolution(x) + self.convolution(x, **convolution_kwargs) .unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) @@ -214,6 +230,9 @@ def _forward( self._debug(c, "c", self._bc_dims, kwargs) self._debug(dt, "dt", self._xz_dims, kwargs) + scan_kwargs = ( + {} if self._config.cross_document_attention else {"position_indices": kwargs[MixerKwargs.position_ids]} + ) y = selective_scan_fn( x, dt, @@ -224,6 +243,7 @@ def _forward( z, delta_bias=None if self.dt_proj.bias is None else self.dt_proj.bias.float(), delta_softplus=True, + **scan_kwargs, ) self._debug(y, "y", self._xz_dims, kwargs) @@ -242,3 +262,15 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Implement. 
raise NotImplementedError() + + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + if not self._config.cross_document_attention: + assert ( + _mamba_varlen_available + ), f"Varlen mamba requires custom mamba installation from `https://github.com/jxiw/varlen_mamba`" + preprocess_for_varlen( + kwargs, + kwargs[MixerKwargs.device] if MixerKwargs.device in kwargs else self._distributed.device, + return_seq_idx=True, + return_position_ids=True, + ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 259073e32..2ca61aa0e 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -160,7 +160,9 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None): scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5 threshold = max(threshold * scale, min_threshold) rms = rms_diff(x, y).item() - assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + ( + assert ( + rms <= threshold + ), f"Rms diff too big ({rms:.2e} > {threshold:.2e}, scale = {scale:.2e}) between tensors {x} and {y}" + ( "" if msg is None else f"| {msg}" ) diff --git a/setup.cfg b/setup.cfg index 58f8ea2d1..005ae5a8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,8 @@ HUGGINGFACE = # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.6.post3 + # TODO: This is required for varlen mamba, but fails to compile in nvcr.io/nvidia/pytorch:25.11-py3. + # mamba_ssm[causal-conv1d] @ git+https://github.com/jxiw/varlen_mamba.git@varlen_mamba flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@67eee20c8503cd19eeb52aa1b99821308e9260c5 GENERATION = diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 2e47fd6aa..f28c9cce2 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -129,17 +129,10 @@ def test_build_padded_token_cumsum(): Assert.all_equal(token_cumsum, expected_cumsums) -def get_test_seeds(num_seeds): - np.random.seed(42) - seeds = np.random.randint(0, num_seeds * 100, num_seeds) - return seeds.tolist() - - @pytest.mark.skipif(not _extension_available, reason="CPP Extension not available") def test_gpt_sample_padding(): - for seed in get_test_seeds(100): + for _ in range(10): vocab_size = 30 - np.random.seed(seed) num_sequences = np.random.randint(1, 20) sequence_length = np.random.randint(1, 20) doc_sizes = np.random.randint(1, 2 * sequence_length, num_sequences) @@ -167,7 +160,7 @@ def test_gpt_sample_padding(): sampling = get_sampling_data( num_samples=len(expected_samples), sequence_length=sequence_length, - seed=seed, + seed=np.random.randint(100000), shuffle=ShufflingType.disabled, truncate_documents=False, ) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index de95ca214..a23b49f8e 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -1,10 +1,11 @@ import os +import sys import tempfile +import traceback +import typing import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward @@ -12,176 +13,40 @@ from tests.utils.utils import requires_cuda -def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): - fn = combined_worker - 
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) - try: - fn(rank, dist.group.WORLD, *fn_args) - finally: - dist.destroy_process_group() - - -def _spawn_dist(world_size: int, fn, *fn_args): - """ - Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. - """ - with tempfile.NamedTemporaryFile(delete=False) as tmp: - init_method = f"file://{tmp.name}" - - try: - mp.spawn( - _mp_worker, - args=(world_size, init_method, fn_args), - nprocs=world_size, - join=True, - start_method="spawn", - ) - finally: - if os.path.exists(tmp.name): - os.remove(tmp.name) - - -def _assert_loss_and_grad(logits, loss, grad): - assert isinstance(loss, torch.Tensor) - assert loss.dim() == 0 - assert grad is None or grad.shape == logits.shape - assert torch.isfinite(loss) - if grad is not None: - assert torch.isfinite(grad).all() - - -@pytest.mark.parametrize("use_mask", [False, True]) -def test_reverse_kl_no_tp(use_mask): - torch.manual_seed(0) - batch_size, seq_len, vocab_size = 2, 3, 5 - logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True) - target = torch.randn(batch_size, seq_len, vocab_size) - loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None - - loss, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1.0, - group=None, - target_format=TargetFormat.logits, - sequence_parallel_logits=False, - ) - _assert_loss_and_grad(logits, loss, grad) - - # Manual reference: sum over vocab then average over valid tokens. - teacher_log_probs = torch.log_softmax(target, dim=-1) - student_log_probs = torch.log_softmax(logits, dim=-1) - per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - valid_tokens = loss_mask.sum() - else: - valid_tokens = logits.shape[0] * logits.shape[1] - reference = per_sample.sum() / valid_tokens - torch.testing.assert_close(loss, reference, atol=1e-6, rtol=1e-6) - - -def _vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - torch.manual_seed(0) - world_size = dist.get_world_size(group) - - batch_size, seq_len, vocab_per_rank = 2, 3, 5 - full_vocab = vocab_per_rank * world_size - full_logits = torch.randn(batch_size, seq_len, full_vocab) - full_target = torch.randn(batch_size, seq_len, full_vocab) - full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None - - start = rank * vocab_per_rank - end = start + vocab_per_rank - logits = full_logits[:, :, start:end].clone().requires_grad_(True) - target = full_target[:, :, start:end].clone() - loss_mask = full_mask.clone() if full_mask is not None else None - - loss, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=None, - group=group, - target_format=TargetFormat.logits, - sequence_parallel_logits=False, - ) - _assert_loss_and_grad(logits, loss, grad) - - if rank == 0: - ref_loss, _ = reverse_kl_forward_backward( - logits=full_logits.clone(), - target=full_target.clone(), - loss_mask=full_mask.clone() if full_mask is not None else None, - grad_output=None, - group=None, - target_format=TargetFormat.logits, - sequence_parallel_logits=False, - ) +def _get_cross_entropy_inputs( + num_columns: int, loss_masking: bool, target_format: TargetFormat +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # We want something 
moderately close to the target for the test to be meaningful + logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 + loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device="cuda") if loss_masking else None + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + logits = torch.nn.functional.one_hot(target, num_columns) + logits_var + if loss_masking: + logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) + loss_mask = None else: - ref_loss = torch.zeros_like(loss) - dist.broadcast(ref_loss, src=0, group=group) - torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) - - -def _ce_vocab_tp_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - torch.manual_seed(0) - world_size = dist.get_world_size(group) - - batch_size, seq_len, vocab_per_rank = 2, 3, 5 - full_vocab = vocab_per_rank * world_size - full_logits = torch.randn(batch_size, seq_len, full_vocab) - full_target = torch.randn(batch_size, seq_len, full_vocab) - full_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]) if use_mask else None - - start = rank * vocab_per_rank - end = start + vocab_per_rank - logits = full_logits[:, :, start:end].clone().requires_grad_(True) - target = full_target[:, :, start:end].clone() - loss_mask = full_mask.clone() if full_mask is not None else None - - loss, grad = cross_entropy_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=None, - group=group, - implementation=CrossEntropyImpl.fused, - target_format=TargetFormat.logits, - logits_scale_factor=1.0, - ) - _assert_loss_and_grad(logits, loss, grad) - - if rank == 0: - ref_loss, _ = cross_entropy_forward_backward( - logits=full_logits.clone(), - target=full_target.clone(), - loss_mask=full_mask.clone() if full_mask is not None else None, - grad_output=None, - group=None, - implementation=CrossEntropyImpl.fused, - target_format=TargetFormat.logits, - logits_scale_factor=1.0, - ) + target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") + logits = target + logits_var + if target_format == TargetFormat.probabilities: + target = torch.softmax(target, -1) + return logits, target, loss_mask + + +def _compare_cross_entropy_outputs( + loss: torch.Tensor, + ref_loss: torch.Tensor, + has_grad: bool, + grad: torch.Tensor | None, + ref_grad: torch.Tensor | None, + threshold=1e-5, +): + Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) + if has_grad: + Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) else: - ref_loss = torch.zeros_like(loss) - dist.broadcast(ref_loss, src=0, group=group) - torch.testing.assert_close(loss, ref_loss, atol=1e-6, rtol=1e-6) - - -def combined_worker(rank: int, group: dist.ProcessGroup, use_mask: bool): - _vocab_tp_worker(rank, group, use_mask) - _ce_vocab_tp_worker(rank, group, use_mask) - - -# TODO: maybe merge these tests using same parametrization -@pytest.mark.slow -@pytest.mark.parametrize("use_mask", [True, False]) -def test_distillation_losses(use_mask): - _spawn_dist(2, combined_worker, use_mask) + assert grad is None + assert ref_grad is None @requires_cuda @@ -203,21 +68,7 @@ def test_distillation_losses(use_mask): def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): # TODO: Test tensor-parallel implementation. 
assert TritonConfig.TRITON_ENABLED - # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device="cuda") if loss_masking else None - if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") - logits = (torch.nn.functional.one_hot(target, num_columns) + logits_var).requires_grad_() - if loss_masking: - logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) - loss_mask = None - else: - target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") - logits = (target + logits_var).requires_grad_() - if target_format == TargetFormat.probabilities: - target = torch.softmax(target, -1) - + logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) kwargs = { "logits": logits, "target": target, @@ -228,26 +79,131 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski } # Torch serves as the reference implementation. out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) - out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) - Assert.rms_close(out_fused, out_torch, 5e-3) - if grad_output is None: - assert grad_torch is None - assert grad_fused is None - else: - Assert.rms_close(grad_fused, grad_torch, 5e-3) + + # TODO: Why is the error so high with logit scaling? + threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 + _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) else: out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - if grad_output is None: - assert grad_triton is None - else: - Assert.rms_close(grad_triton, grad_torch, 5e-3) - Assert.rms_close(out_triton, out_torch, 5e-3) + _compare_cross_entropy_outputs( + out_triton, out_torch, grad_output is not None, grad_triton, grad_torch, threshold + ) + + +def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tensor, loss_mask: torch.Tensor | None): + # Manual reference: sum over vocab then average over valid tokens. + logits = logits.detach().requires_grad_() + per_sample = torch.nn.functional.kl_div( + torch.log_softmax(target.float(), dim=-1), + torch.log_softmax(logits.float(), dim=-1), + reduction="none", + log_target=True, + ).sum(dim=-1) + output = per_sample.mean() if loss_mask is None else (per_sample * loss_mask).sum() / loss_mask.sum() + output.backward() + return output, logits.grad + + +@pytest.mark.slow +# TODO: Support the same parameterization as above in the reference implementation. 
+@pytest.mark.parametrize("loss_masking", [False, True]) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +def test_reverse_kl(loss_masking, target_format): + logits, target, loss_mask = _get_cross_entropy_inputs(10000, loss_masking, target_format) + out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) + out, grad = reverse_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + target_format=TargetFormat.logits, + ) + # TODO: Error looks + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) -if __name__ == "__main__": - pytest.main([__file__]) +def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): + try: + torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) + fn_args[0](rank, torch.distributed.group.WORLD, *fn_args[1:]) + finally: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def _spawn_dist(world_size: int, *fn_args): + """ + Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. + """ + with tempfile.NamedTemporaryFile(delete=False) as tmp: + init_method = f"file://{tmp.name}" + + try: + torch.multiprocessing.spawn( + _mp_worker, + args=(world_size, init_method, fn_args), + nprocs=world_size, + join=True, + start_method="spawn", + ) + finally: + if os.path.exists(tmp.name): + os.remove(tmp.name) + + +def _compare_parallel_cross_entropy( + rank: int, + group: torch.distributed.ProcessGroup, + target_format: TargetFormat, + function: typing.Callable, + loss_masking: bool, +): + # Ensure all workers have the same inputs. + torch.manual_seed(0) + world_size = torch.distributed.get_world_size(group) + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + + out, grad = function( + logits=logits.chunk(world_size, 1)[rank], + target=target.chunk(world_size, 1)[rank], + loss_mask=loss_mask, + grad_output=1, + group=group, + target_format=target_format, + ) + + out_ref, grad_ref = function( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1, + target_format=target_format, + ) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) + + +def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): + success = True + for function in (cross_entropy_forward_backward, reverse_kl_forward_backward): + for target_format in (TargetFormat.logits,): + for loss_masking in [True, False]: + try: + _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) + except Exception: + print( + f" >>>>>> Failed {function.__name__}, target_format, use_mask={loss_masking}", file=sys.stderr + ) + traceback.print_exc() + success = False + if not success: + raise RuntimeError("Test failed") + + +@pytest.mark.slow +def test_distillation_losses(): + _spawn_dist(2, compare_parallel_cross_entropy) diff --git a/tests/test_attention.py b/tests/layers/test_attention.py similarity index 50% rename from tests/test_attention.py rename to tests/layers/test_attention.py index f1409b95c..508597173 100644 --- a/tests/test_attention.py +++ b/tests/layers/test_attention.py @@ -4,56 +4,11 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.attention import Attention -from fast_llm.layers.attention.config import 
AttentionConfig, AttentionImplementation, AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.utils import Assert from tests.utils.utils import requires_cuda -# TODO: ====== micro-sequence ====== -@pytest.mark.skip -def test_varlen_preprocessing(): - sequence_lengths = [[8, 13, 4, 11], [11, 16, 9]] - # First micro-sequence: - # [0...7,0...3] + [0...10,0] -> [0,8,12,23,24] - # Second micro-sequence: - # [4...12,0...2] + [1...12] -> [0,9,12,24] - # Third micro-sequence: - # [3,0...10] + [13...15, 0...8] -> [1,12,15,24] - cumulative_sequences_q = [ - torch.tensor([0, 8, 12, 23, 24], dtype=torch.int32), - torch.tensor([0, 0, 9, 12, 12, 24], dtype=torch.int32), - torch.tensor([0, 0, 0, 1, 12, 12, 15, 24], dtype=torch.int32), - ] - cumulative_sequences_k = [ - torch.tensor([0, 8, 12, 23, 24], dtype=torch.int32), - torch.tensor([0, 8, 21, 24, 35, 48], dtype=torch.int32), - torch.tensor([0, 8, 21, 25, 36, 47, 63, 72], dtype=torch.int32), - ] - micro_sequence_length = 12 - sequence_length = 36 - attention = Attention( - AttentionConfig(head_size=64, implementation=AttentionImplementation.flash, cross_document_attention=False), - DistributedConfig(compute_dtype="bfloat16"), - hidden_dim=TensorDim("", 1), - lr_scale=None, - peft=None, - ) - for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): - kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), - AttentionKwargs.sequence_k_dim: TensorDim( - BlockDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length - ), - AttentionKwargs.sequence_length: sequence_length, - AttentionKwargs.sequence_lengths: sequence_lengths, - AttentionKwargs.device: torch.device("cpu"), - } - attention.preprocess(kwargs) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) - - @requires_cuda @pytest.mark.parametrize("cross_document_attention", (True, False)) @pytest.mark.parametrize(("causal", "window_size"), ((True, None), (True, 50), (False, None))) diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py deleted file mode 100644 index dae4f52b2..000000000 --- a/tests/layers/test_gdn_equivalence.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Test numerical equivalence between Fast-LLM GDN and Apriel2 GatedDeltaNet.""" - -import pytest -import torch - -from fast_llm.config import UpdateType -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda - -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet -except ImportError: - Apriel2GatedDeltaNet = None - -try: - from fla.ops.gated_delta_rule import chunk_gated_delta_rule - - _gdn_kernel_available = True -except ImportError: - _gdn_kernel_available = False - -# Test constants -VOCAB_SIZE = 500 -HIDDEN_SIZE = 64 -SEQ_LEN = 65 -BATCH_SIZE = 2 -NUM_V_HEADS = 4 -NUM_K_HEADS = 2 -HEAD_DIM = 16 -KERNEL_SIZE = 4 - - -@pytest.mark.slow -@requires_cuda -@pytest.mark.skipif(Apriel2GatedDeltaNet is None, reason="Apriel2 GDN not available") -@pytest.mark.skipif(not _gdn_kernel_available, reason="GDN CUDA kernels not available") -def test_fast_llm_gdn_matches_apriel2_forward(): - 
"""Verify Fast-LLM GDN output matches Apriel2 GatedDeltaNet.""" - torch.manual_seed(42) - device = torch.device("cuda") - dtype = torch.bfloat16 - - # Create Apriel2 GDN layer - gdn_config = { - "value_heads": NUM_V_HEADS, - "key_heads": NUM_K_HEADS, - "key_head_dim": HEAD_DIM, - "value_head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "norm_eps": 1e-5, - } - hf_layer = Apriel2GatedDeltaNet(HIDDEN_SIZE, gdn_config, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype) - hf_layer.eval() - - # Create Fast-LLM GDN layer - config = GPTBaseModelConfig.from_dict( - { - "decoder": { - "num_blocks": 1, - "block": { - "mixer": { - "type": "gdn", - "value_heads": NUM_V_HEADS, - "key_heads": NUM_K_HEADS, - "key_head_dim": HEAD_DIM, - "value_head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "normalization": {"epsilon": 1e-5}, - } - }, - }, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "hidden_size": HIDDEN_SIZE, - }, - update_type=UpdateType.update, - ) - - model, distributed = get_base_model( - GPTModelConfig.from_dict( - { - "base_model": config, - "distributed": {}, - }, - ) - ) - fast_layer = model.decoder[0].mixer - get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype) - fast_layer.eval() - - # Copy weights: parameter names match exactly, so use load_state_dict - hf_layer.load_state_dict(fast_layer.state_dict()) - - # Verify all parameters match - hf_state = hf_layer.state_dict() - for name, fast_param in fast_layer.state_dict().items(): - assert name in hf_state, f"Parameter {name} missing in HF layer" - hf_param = hf_state[name] - if fast_param.shape != hf_param.shape: - hf_param = hf_param.reshape_as(fast_param) - Assert.all_equal(fast_param, hf_param) - - # Forward passes - hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - - hf_out = hf_layer(hidden_states)[0] - - fast_kwargs = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), - BlockKwargs.sequence_length: SEQ_LEN, - BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], - } - fast_layer.preprocess(fast_kwargs) - fast_out, _ = fast_layer(hidden_states, fast_kwargs) - - # Compare outputs - Assert.rms_close(fast_out, hf_out, 1e-5) diff --git a/tests/layers/test_kda_equivalence.py b/tests/layers/test_kda_equivalence.py deleted file mode 100644 index fb0042c45..000000000 --- a/tests/layers/test_kda_equivalence.py +++ /dev/null @@ -1,99 +0,0 @@ -"""Test numerical equivalence between Fast-LLM KDA and Apriel2 KimiDeltaAttention.""" - -import pytest -import torch - -import fast_llm.layers.ssm.kda as kda_module -from fast_llm.config import UpdateType -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda - -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import KimiDeltaAttention -except ImportError: - KimiDeltaAttention = None - -# Test constants -VOCAB_SIZE = 500 -HIDDEN_SIZE = 64 -SEQ_LEN = 65 -BATCH_SIZE = 2 -NUM_HEADS = 4 -HEAD_DIM = 16 -KERNEL_SIZE = 4 - - -@pytest.mark.slow -@requires_cuda -@pytest.mark.skipif(KimiDeltaAttention is None, reason="Apriel2 KDA not available") -@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") -def 
test_fast_llm_kda_matches_apriel2_forward(): - """Verify Fast-LLM KDA output matches Apriel2 KimiDeltaAttention.""" - torch.manual_seed(42) - device = torch.device("cuda") - dtype = torch.bfloat16 - - # Shared config - parameter names match exactly between implementations - kda_config = { - "heads": NUM_HEADS, - "head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, - "normalization": {"epsilon": 1e-5, "activation": "sigmoid"}, - } - - # Create Apriel2 KDA layer - hf_layer = KimiDeltaAttention(HIDDEN_SIZE, kda_config, layer_idx=0).to(device=device, dtype=dtype) - hf_layer.eval() - - # Create Fast-LLM KDA layer - config = GPTBaseModelConfig.from_dict( - { - "decoder": { - "num_blocks": 1, - "block": {"mixer": {"type": "kda", **kda_config}}, - }, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "hidden_size": HIDDEN_SIZE, - }, - update_type=UpdateType.update, - ) - - model, distributed = get_base_model( - GPTModelConfig.from_dict( - { - "base_model": config, - "distributed": {}, - }, - ) - ) - fast_layer = model.decoder[0].mixer - get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype) - fast_layer.eval() - - # Copy weights: parameter names match exactly, so use load_state_dict - hf_layer.load_state_dict(fast_layer.state_dict()) - - # Verify all parameters match - hf_state = hf_layer.state_dict() - for name, fast_param in fast_layer.state_dict().items(): - Assert.all_equal(fast_param, hf_state[name]) - - # Forward passes - hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - - hf_out = hf_layer(hidden_states)[0] - - fast_kwargs = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), - } - fast_layer.preprocess(fast_kwargs) - fast_out, _ = fast_layer(hidden_states, fast_kwargs) - - # Compare outputs - Assert.rms_close(fast_out, hf_out, 1e-5) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 6383e6aae..623a30d82 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -163,8 +163,6 @@ def test_lm_head( loss_masking: bool, prediction_heads: int, ): - torch.cuda.manual_seed(0) - torch.manual_seed(0) head_config = { "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, @@ -266,6 +264,8 @@ def test_lm_head( distributed, tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + # Names must be kept as-is for tied weights. + set_names=False, ) # Get reference outputs and grads diff --git a/tests/layers/test_mamba_equivalence.py b/tests/layers/test_mamba_equivalence.py deleted file mode 100644 index ccf2dba41..000000000 --- a/tests/layers/test_mamba_equivalence.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Test numerical equivalence between Fast-LLM Mamba2 and Apriel2 Mamba. - -Note: Fast-LLM's "mamba_2" type is actually a Mamba 1 variant (not the true Mamba 2 -architecture). It corresponds to the HuggingFace/Apriel Mamba implementation. 
-""" - -import pytest -import torch - -from fast_llm.config import UpdateType -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm.config import Mamba2Config # Ensures mamba_2 type is registered -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig -from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda - -# Ensure Mamba2Config is registered for dynamic type lookup -_ = Mamba2Config - -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Mamba -except ImportError: - Apriel2Mamba = None - -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn - - _mamba_kernel_available = True -except (ImportError, RuntimeError): - _mamba_kernel_available = False - -# Test constants -VOCAB_SIZE = 500 -HIDDEN_SIZE = 64 -SEQ_LEN = 65 -BATCH_SIZE = 2 -D_INNER = 128 -D_XB = 64 -D_STATE = 16 -D_CONV = 4 -DT_RANK = 4 - - -def _copy_weights(fast_layer, hf_layer): - """Copy weights from Apriel2 Mamba to Fast-LLM Mamba2.""" - with torch.no_grad(): - # Main projections - fast_layer.in_proj.weight.copy_(hf_layer.in_proj.weight) - if fast_layer.in_proj.bias is not None and hf_layer.in_proj.bias is not None: - fast_layer.in_proj.bias.copy_(hf_layer.in_proj.bias) - - # DT projections - fast_layer.dt_in_proj.weight.copy_(hf_layer.dt_in_proj.weight) - if fast_layer.dt_in_proj.bias is not None and hf_layer.dt_in_proj.bias is not None: - fast_layer.dt_in_proj.bias.copy_(hf_layer.dt_in_proj.bias) - - fast_layer.dt_proj.weight.copy_(hf_layer.dt_proj.weight) - if fast_layer.dt_proj.bias is not None and hf_layer.dt_proj.bias is not None: - fast_layer.dt_proj.bias.copy_(hf_layer.dt_proj.bias) - - # Convolution (Fast-LLM uses "convolution", Apriel2 uses "conv1d") - fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) - if fast_layer.convolution.bias is not None and hf_layer.conv1d.bias is not None: - fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) - - # SSM parameters - fast_layer.A_log.copy_(hf_layer.A_log) - fast_layer.D.copy_(hf_layer.D) - - # Output projection - fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) - if fast_layer.out_proj.bias is not None and hf_layer.out_proj.bias is not None: - fast_layer.out_proj.bias.copy_(hf_layer.out_proj.bias) - - -@pytest.mark.slow -@requires_cuda -@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available") -@pytest.mark.skipif(not _mamba_kernel_available, reason="Mamba CUDA kernels not available") -@pytest.mark.parametrize("add_linear_biases", [True, False]) -@pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) -def test_fast_llm_mamba2_matches_apriel2(add_linear_biases, repeat_kv_before_conv): - """Verify Fast-LLM Mamba2 output matches Apriel2 Mamba. - - Args: - add_linear_biases: Whether to add biases to linear layers. - repeat_kv_before_conv: Whether to repeat KV before or after convolution. - """ - torch.manual_seed(42) - device = torch.device("cuda") - dtype = torch.bfloat16 - - # Create Apriel2 Mamba layer - # Note: Apriel2 has separate conv_bias and dt_proj_bias controls. - # We align them with Fast-LLM's single add_linear_biases flag. 
- mamba_config = { - "d_inner": D_INNER, - "d_xb": D_XB, - "state_size": D_STATE, - "d_conv": D_CONV, - "dt_rank": DT_RANK, - "conv_bias": add_linear_biases, - "dt_proj_bias": add_linear_biases, - "add_linear_biases": add_linear_biases, - "repeat_kv_before_conv": repeat_kv_before_conv, - "dt_min": 0.001, - "dt_max": 0.1, - "dt_init_floor": 1e-4, - } - hf_layer = Apriel2Mamba(HIDDEN_SIZE, mamba_config, layer_idx=0, dtype=dtype).to(device=device, dtype=dtype) - hf_layer.eval() - - # Create Fast-LLM Mamba2 layer - config = GPTBaseModelConfig.from_dict( - { - "decoder": { - "num_blocks": 1, - "block": { - "mixer": { - "type": "mamba_2", - "d_inner": D_INNER, - "d_xb": D_XB, - "state_size": D_STATE, - "convolution_layer": {"kernel_size": D_CONV}, - "dt_rank": DT_RANK, - "add_linear_biases": add_linear_biases, - "repeat_kv_before_conv": repeat_kv_before_conv, - } - }, - }, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "hidden_size": HIDDEN_SIZE, - }, - update_type=UpdateType.update, - ) - - model, distributed = get_base_model( - GPTModelConfig.from_dict( - { - "base_model": config, - "distributed": {}, - }, - ) - ) - fast_layer = model.decoder[0].mixer - get_stage([fast_layer], distributed, [], {}) - fast_layer.to(device=device, dtype=dtype) - fast_layer.eval() - - # Copy weights - _copy_weights(fast_layer, hf_layer) - - # Verify key parameters match (not all names match between implementations) - Assert.all_equal(fast_layer.in_proj.weight, hf_layer.in_proj.weight) - Assert.all_equal(fast_layer.convolution.weight, hf_layer.conv1d.weight) - Assert.all_equal(fast_layer.A_log, hf_layer.A_log) - Assert.all_equal(fast_layer.D, hf_layer.D) - Assert.all_equal(fast_layer.out_proj.weight, hf_layer.out_proj.weight) - - # Forward passes - hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype, requires_grad=False) - - hf_out = hf_layer(hidden_states)[0] - - fast_kwargs = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.sequence_lengths: [[SEQ_LEN] for _ in range(BATCH_SIZE)], - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), - } - fast_layer.preprocess(fast_kwargs) - fast_out, _ = fast_layer(hidden_states, fast_kwargs) - - # Compare outputs (slightly looser tolerance for Mamba due to numerical differences) - Assert.rms_close(fast_out, hf_out, 1e-4) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py new file mode 100644 index 000000000..e6422c597 --- /dev/null +++ b/tests/layers/test_ssm.py @@ -0,0 +1,179 @@ +import pytest +import torch + +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.decoder.config import MixerConfig +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config +from fast_llm.utils import Assert +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba +from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig +from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention +from tests.utils.utils import get_stage, requires_cuda + +HIDDEN_SIZE = 16 +SEQ_LEN = 65 + + +def _compare_mixers( + fast_llm_config: MixerConfig, 
hf_layer: torch.nn.Module, param_map: dict[str, str], threshold=1e-5 +): + distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16)) + fast_llm_layer = fast_llm_config.get_layer( + distributed_config, + TensorDim("", HIDDEN_SIZE), + lr_scale=None, + peft=None, + ).eval() + get_stage([fast_llm_layer], distributed, [], {}) + hf_layer = hf_layer.to(device=distributed.device, dtype=distributed_config.compute_dtype.torch) + + with torch.no_grad(): + hf_state_dict = hf_layer.state_dict() + for name, param in fast_llm_layer.named_parameters(): + param.copy_(hf_state_dict[param_map.get(name, name)].view_as(param)) + + hf_params = hf_layer.state_dict() + for name, fast_param in fast_llm_layer.state_dict().items(): + hf_param = hf_params[param_map.get(name, name)] + Assert.rms_close_relative(fast_param, hf_param.view_as(fast_param), threshold, 1e-5, msg=name) + + hidden_states = torch.randn( + 2, + SEQ_LEN, + HIDDEN_SIZE, + device=distributed.device, + dtype=distributed_config.compute_dtype.torch, + requires_grad=False, + ) + + hf_layer.train() + hf_out = hf_layer(hidden_states) + if isinstance(hf_out, tuple): + (hf_out,) = hf_out + + sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] + fast_kwargs = { + BlockKwargs.device: distributed.device, + BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.hidden_dims: (HIDDEN_SIZE,), + BlockKwargs.sequence_q_dim: TensorDim("", SEQ_LEN), + BlockKwargs.sequence_k_dim: TensorDim("", SEQ_LEN), + } + fast_llm_layer.train() + fast_llm_layer.preprocess(fast_kwargs) + fast_out = fast_llm_layer(hidden_states, fast_kwargs) + + Assert.rms_close_relative(fast_out, hf_out, threshold, 1e-5) + + +@pytest.mark.slow +@requires_cuda +def test_gdn(): + device = torch.device("cuda") + dtype = torch.bfloat16 + + NUM_V_HEADS = 4 + NUM_K_HEADS = 2 + HEAD_DIM = 4 + KERNEL_SIZE = 4 + + config_common = { + "value_heads": NUM_V_HEADS, + "key_heads": NUM_K_HEADS, + "key_head_dim": HEAD_DIM, + "value_head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + } + + hf_layer = ( + Apriel2GatedDeltaNet(HIDDEN_SIZE, {**config_common, "norm_eps": 1e-5}, layer_idx=0, dtype=dtype) + .to(device=device, dtype=dtype) + .eval() + ) + fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) + _compare_mixers(fast_llm_config, hf_layer, {}) + + +@pytest.mark.slow +@requires_cuda +@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") +@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") +def test_kda(): + NUM_HEADS = 4 + HEAD_DIM = 4 + KERNEL_SIZE = 4 + + hf_config = AprielHybridSSMConfig( + hidden_size=HIDDEN_SIZE, + num_attention_heads=NUM_HEADS, + num_hidden_layers=1, + rms_norm_eps=1e-6, + ) + hf_config.short_conv_kernel_size = KERNEL_SIZE + hf_config.head_dim = HEAD_DIM + hf_config.num_heads = NUM_HEADS + hf_layer = KimiDeltaAttention(hf_config, layer_idx=0) + + fast_llm_config = KimiDeltaAttentionConfig( + heads=NUM_HEADS, + head_dim=HEAD_DIM, + convolution_layer={"kernel_size": KERNEL_SIZE, "activation": "silu"}, + normalization={"epsilon": 1e-6, "activation": "sigmoid"}, + ) + + param_map = { + "q_conv.weight": "q_conv1d.weight", + "k_conv.weight": "k_conv1d.weight", + "v_conv.weight": "v_conv1d.weight", + "beta_proj.weight": "b_proj.weight", + "norm.weight": "o_norm.weight", + } + 
_compare_mixers(fast_llm_config, hf_layer, param_map)
+
+
+@pytest.mark.slow
+@requires_cuda
+@pytest.mark.parametrize("add_linear_biases", [True, False])
+@pytest.mark.parametrize("repeat_kv_before_conv", [True, False])
+@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available")
+def test_mamba(add_linear_biases, repeat_kv_before_conv):
+    D_INNER = 128
+    D_XB = 64
+    D_STATE = 16
+    D_CONV = 4
+    DT_RANK = 4
+
+    config_common = {
+        "d_inner": D_INNER,
+        "d_xb": D_XB,
+        "state_size": D_STATE,
+        "dt_rank": DT_RANK,
+        "repeat_kv_before_conv": repeat_kv_before_conv,
+        "add_linear_biases": add_linear_biases,
+    }
+
+    mamba_config = {
+        "conv_bias": add_linear_biases,
+        "dt_proj_bias": add_linear_biases,
+        **config_common,
+    }
+    hf_layer = Apriel2Mamba(HIDDEN_SIZE, mamba_config, layer_idx=0)
+
+    # Create Fast-LLM Mamba2 layer
+    fast_llm_config = Mamba2Config(
+        convolution_layer={"kernel_size": D_CONV},
+        **config_common,
+    )
+
+    param_map = {
+        "convolution.weight": "conv1d.weight",
+        "convolution.bias": "conv1d.bias",
+    }
+    # TODO: This is a really high threshold.
+    _compare_mixers(fast_llm_config, hf_layer, param_map, threshold=1e-2)
diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py
new file mode 100644
index 000000000..32cd00cd2
--- /dev/null
+++ b/tests/layers/test_varlen.py
@@ -0,0 +1,104 @@
+import pytest
+import torch
+
+from fast_llm.engine.config_utils.data_type import DataType
+from fast_llm.engine.config_utils.tensor_dim import TensorDim
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.layers.attention.config import AttentionConfig
+from fast_llm.layers.block.config import BlockKwargs
+from fast_llm.layers.decoder.config import MixerConfig
+from fast_llm.layers.ssm import gdn as gdn_module
+from fast_llm.layers.ssm import kda as kda_module
+from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config
+from fast_llm.utils import Assert
+from tests.utils.utils import get_stage, requires_cuda
+
+
+# TODO: include mamba varlen
+@pytest.mark.slow
+@requires_cuda
+@pytest.mark.parametrize(
+    "config",
+    [
+        AttentionConfig(heads=4, head_groups=2, head_size=16, cross_document_attention=False),
+        pytest.param(
+            Mamba2Config(
+                d_inner=128,
+                d_xb=64,
+                state_size=16,
+                dt_rank=8,
+                cross_document_attention=False,
+            ),
+            marks=pytest.mark.skip("Mamba varlen kernel not available"),
+        ),
+        pytest.param(
+            GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16),
+            marks=pytest.mark.skipif(
+                gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available"
+            ),
+        ),
+        pytest.param(
+            KimiDeltaAttentionConfig(heads=4, head_dim=16),
+            marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"),
+        ),
+    ],
+)
+def test_mixer_varlen_stacking_equivalence(config: MixerConfig):
+    """
+    Check that each mixer's forward/backward results match with and without sequence packing.
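+
+    The packed pass runs a flattened batch with per-document sequence lengths
+    (the varlen / cu_seqlens path); the reference pass runs each document as its
+    own batch of one and concatenates the outputs. Outputs and accumulated
+    parameter gradients must match within tolerance.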
+ """ + hidden_size = 32 + hidden_dim = TensorDim("hidden", hidden_size) + distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.float16)) + mixer = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + stage = get_stage([mixer], distributed) + + batch_size = 2 # cu_seqlens path requires flattened batch + seq_len = 15 + + sequence_lengths = [[6, 9], [4, 1, 10]] + hidden_states = torch.randn( + batch_size, + seq_len, + hidden_size, + device=distributed.device, + dtype=distributed_config.compute_dtype.torch, + requires_grad=True, + ) + + kwargs = { + BlockKwargs.device: distributed.device, + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + BlockKwargs.sequence_q_dim: TensorDim("", seq_len), + BlockKwargs.sequence_k_dim: TensorDim("", seq_len), + } + + kwargs_packed = {**kwargs, BlockKwargs.sequence_lengths: sequence_lengths} + mixer.preprocess(kwargs_packed) + + out_packed, context = stage.forward(hidden_states, kwargs_packed) + stage.backward(torch.ones_like(out_packed), context) + + names, parameters = zip(*list(mixer.named_parameters())) + grads_packed = [parameter.grad_buffer.clone() for parameter in parameters] + + stage.reset_gradients() + # Run reference path separately per sequence without varlen packing, then concatenate. + out_refs = [] + for i in range(batch_size): + for seq in torch.split(hidden_states[i], sequence_lengths[i], dim=0): + kwargs_seq = {**kwargs, BlockKwargs.sequence_lengths: [[len(seq)]]} + mixer.preprocess(kwargs_seq) + out, context = stage.forward(seq.unsqueeze(0), kwargs_seq) + stage.backward(torch.ones_like(out), context) + out_refs.append(out) + out_ref = torch.cat(out_refs, dim=1).view_as(out_packed) + + Assert.rms_close_relative(out_packed, out_ref, 1e-3, 1e-4) + + for name, parameter, grad_packed in zip(names, parameters, grads_packed, strict=True): + Assert.rms_close_relative(grad_packed, parameter.grad_buffer, 1e-3, 1e-4, msg=name) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_varlen.py b/tests/test_varlen.py deleted file mode 100644 index 256da95d4..000000000 --- a/tests/test_varlen.py +++ /dev/null @@ -1,234 +0,0 @@ -import pytest -import torch - -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.decoder.config import MixerConfig -from fast_llm.layers.ssm import gdn as gdn_module -from fast_llm.layers.ssm import kda as kda_module -from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig -from fast_llm.utils import Assert - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return Distributed(config=distributed_config) - - -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # 
Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def unpack_and_padd(packed_hidden_states, cu_seqlens, package_num): - batch_size = packed_hidden_states.shape[0] - seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - hidden_dim = packed_hidden_states.shape[2] - hidden_states = torch.zeros( - package_num * batch_size, - seq_len, - hidden_dim, - dtype=packed_hidden_states.dtype, - device=packed_hidden_states.device, - ) - for j in range(batch_size): - for i in range(package_num): - line = j * package_num + i - hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ - j, cu_seqlens[i] : cu_seqlens[i + 1], : - ] - return hidden_states - - -def pack(hidden_states, cu_seqlens, batch_size): - package_num, seq_len, hidden_dim = hidden_states.shape - seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] - seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) - indices_3d = ( - torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) - ) - mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) - packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) - return packed_hidden_states - - -def generate_random_seq_len(seq_len, packages_num=2): - if packages_num < 1: - raise ValueError("packages_num must be at least 1") - - # base size of each chunk, and how many get an extra token - base, rem = divmod(seq_len, packages_num) - # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] - lengths = [base + 1 if i < rem else base for i in range(packages_num)] - assert sum(lengths) == seq_len - assert len(lengths) == packages_num - return lengths - - -def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: - """ - Materialize meta parameters on the requested device for KDA mixer layers. - """ - for name, param in module.named_parameters(): - if param.device.type != "meta": - continue - param_data = torch.empty_like(param, device=device) - param.init_parameter(param_data, distributed) - module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) - target = module - if module_path is not None: - for part in module_path.split("."): - target = getattr(target, part) - new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - new_param.grad = None - new_param.grad_buffer = torch.zeros_like(param_data) - new_param.param_grad_is_zero = True - target._parameters[param_name] = new_param - - -def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: - return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad - - -# TODO: include mamba varlen -@pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") -@pytest.mark.parametrize( - "config, sequence_first", - [ - pytest.param( - GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), - False, - marks=pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" - ), - ), - pytest.param( - GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), - True, - marks=pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" - ), - ), - pytest.param( - KimiDeltaAttentionConfig(heads=4, head_dim=16), - False, - marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), - ), - pytest.param( - KimiDeltaAttentionConfig(heads=4, head_dim=16), - True, - marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), - ), - ], -) -def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): - """ - Check that Gated Delta Net forward/backward match with and without packing. - """ - device = torch.device("cuda") - dtype = torch.float16 - hidden_size = 32 - hidden_dim = TensorDim("hidden", hidden_size) - mixer_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - mixer_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - mixer_packed.setup(distributed) - mixer_ref.setup(distributed) - _materialize_mixer_tensors(mixer_packed, distributed, device) - _materialize_mixer_tensors(mixer_ref, distributed, device) - mixer_ref.load_state_dict(mixer_packed.state_dict()) - mixer_packed.to(device=device, dtype=dtype) - mixer_ref.to(device=device, dtype=dtype) - - batch_size = 2 # cu_seqlens path requires flattened batch - seq_len = 15 - packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) - sequence_lengths = [ - generate_random_seq_len(seq_len, packages_num=packages_num[i].item()) for i in range(batch_size) - ] - - packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) - if sequence_first: - packed = packed.transpose(0, 1) - - kwargs_packed = { - BlockKwargs.device: device, - BlockKwargs.sequence_lengths: sequence_lengths, - BlockKwargs.sequence_first: sequence_first, - BlockKwargs.hidden_dims: (hidden_dim,), - } - mixer_packed.preprocess(kwargs_packed) - - kwargs_ref = { - BlockKwargs.device: device, - BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (hidden_dim,), - } - - out_packed = mixer_packed(packed, kwargs_packed) - if sequence_first: - out_packed = out_packed.transpose(0, 1) - # Run reference path separately per sequence without varlen packing, then concatenate. 
- ref_outs = [] - for b in range(batch_size): - out_batch = [] - length = sequence_lengths[b] - if sequence_first: - ref_seqs = torch.split(packed[:, b].unsqueeze(0), length, dim=1) - else: - ref_seqs = torch.split(packed[b].unsqueeze(0), length, dim=1) - for seq in ref_seqs: - kwargs_ref_seq = { - **kwargs_ref, - BlockKwargs.sequence_lengths: [seq.shape[1]], - } - out_batch.append(mixer_ref(seq, kwargs_ref_seq)) - ref_outs.append(torch.cat(out_batch, dim=1)) - out_ref = torch.cat(ref_outs, dim=0) - out_ref_packed = out_ref - - assert out_ref_packed.shape == out_packed.shape - assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) - - out_packed.sum().backward() - out_ref_packed.sum().backward() - - for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): - if param.requires_grad: - Assert.rms_close_relative(_param_grad(param), _param_grad(param_ref), 1e-3, 1e-3, msg=name) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 53373e0ca..83ed6836a 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -79,7 +79,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon SIMPLE_TESTING_CONFIG = DistributedTestingConfig( name="simple", compare=None, - config_args=["training.num_workers=2"], + config_args=[], num_gpus=1, ) @@ -87,7 +87,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="bf16", compare="simple", - config_args=["model.distributed.compute_dtype=bf16"], + # Also tests parallel data loader. + config_args=["model.distributed.compute_dtype=bf16", "training.num_workers=2"], num_gpus=1, compare_config=_bf16_compare, ), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e943dc96a..22b3b6569 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -41,6 +41,7 @@ TP_NO_STP = r"(?:^|(?<=[^s]))tp" +GRAD_ACC = r"df(?!16)|bf" class ModelTestingGroup(enum.StrEnum): @@ -621,7 +622,7 @@ def _update_and_add_testing_config( compare_factor=8, # Modes not supported with reference models and/or activation distillation. # TODO: Fix gradient accumulation and fp16, add TP support. - skip_tests=("sdp", "ms", "pp", "tp", "df", "bf", "fp16"), + skip_tests=("sdp", "ms", "pp", "tp", GRAD_ACC, "fp16"), ) _update_and_add_testing_config( @@ -815,7 +816,7 @@ def _update_and_add_testing_config( compare_factor=6.0, # Micro-sequence split and sequence-first not supported. # TODO: Gradient accumulation works but comparison is broken. - skip_tests=("sdp", "ms", "bf4", "df"), + skip_tests=("sdp", "ms", GRAD_ACC), auto_model_class=transformers.AutoModelForImageTextToText, ) @@ -1033,9 +1034,7 @@ def _update_and_add_testing_config( compare_factor=6.0, # Micro-sequence split and sequence-first not supported for Mamba. # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). - # bf2_df2 depends on df4, so must also be skipped. 
-    skip_tests=("sdp", "ms", "bf4", "df4", "bf2_df2", TP_NO_STP),
-    auto_model_class=transformers.AutoModelForImageTextToText,
+    skip_tests=("sdp", "ms", GRAD_ACC, TP_NO_STP),
 )
diff --git a/tests/utils/utils.py b/tests/utils/utils.py
index 098f0240e..3b79f7607 100644
--- a/tests/utils/utils.py
+++ b/tests/utils/utils.py
@@ -12,6 +12,7 @@
 from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier
 from fast_llm.engine.base_model.base_model import Layer
+from fast_llm.engine.base_model.config import set_model_names
 from fast_llm.engine.config_utils.logging import configure_logging
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig
@@ -42,7 +43,15 @@ def get_stage(
     distributed: Distributed,
     tied_parameter_duplicates: typing.Iterable[str] = (),
     tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None,
+    set_names: bool = True,
 ):
+
+    for layer in layers:
+        if not layer._is_setup:
+            layer.setup(distributed)
+    if set_names:
+        # Normally done in `BaseModelConfig.get_base_model`; layers built directly in tests may not have names set yet.
+        set_model_names(torch.nn.ModuleList(layers))
     # Create a fast-llm stage which allocates and initializes meta tensors correctly.
     stage = Stage(
         config=StageConfig(),