diff --git a/docs/source/en/model_doc/granitemoehybrid.md b/docs/source/en/model_doc/granitemoehybrid.md
index 49fe0b04ac21..92d6e3b70ac4 100644
--- a/docs/source/en/model_doc/granitemoehybrid.md
+++ b/docs/source/en/model_doc/granitemoehybrid.md
@@ -48,6 +48,32 @@ for i in output:
 
 This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
 
+## Notes
+
+- `GraniteMoeHybridForCausalLM` supports padding-free training, which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate training by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and reduce memory usage if there are examples of varying lengths, by avoiding unnecessary compute and memory overhead from padding tokens.
+
+  Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages, and the following arguments must be passed to the model in addition to `input_ids` and `labels`:
+
+  - `position_ids: torch.LongTensor`: the position index of each token in each sequence.
+  - `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
+  - Each of the [`FlashAttentionKwargs`]:
+    - `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
+    - `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
+    - `max_length_q: int`: the longest query length in the batch.
+    - `max_length_k: int`: the longest key length in the batch.
+
+  The `attention_mask` input should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above using `return_seq_idx=True` and `return_flash_attn_kwargs=True`, as shown in the examples below. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.
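+
+  For example, packing two sequences of lengths 3 and 2 yields roughly the following values (illustrative only, shown without the leading batch dimension):
+
+  ```python
+  # Values generated by the collator for two packed sequences of lengths 3 and 2.
+  position_ids = [0, 1, 2, 0, 1]   # restarts at 0 for each packed sequence
+  seq_idx = [0, 0, 0, 1, 1]        # which sequence each token belongs to
+  cu_seq_lens_q = [0, 3, 5]        # cumulative sequence lengths (cu_seq_lens_k is identical)
+  max_length_q = 3                 # longest sequence in the pack (max_length_k is identical)
+  ```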
+
+  ```python
+  from transformers import DataCollatorWithFlattening
+
+  # Example of using padding-free training
+  data_collator = DataCollatorWithFlattening(
+      return_seq_idx=True,
+      return_flash_attn_kwargs=True
+  )
+  ```
 
 ## GraniteMoeHybridConfig
 
@@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co
 ## GraniteMoeHybridForCausalLM
 
 [[autodoc]] GraniteMoeHybridForCausalLM
-    - forward
\ No newline at end of file
+    - forward
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index 5ea293e5bfff..fffe51d794bd 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -641,6 +641,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -947,6 +948,7 @@ def forward(
             output_router_logits=output_router_logits,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
 
         # Only compute necessary logits
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index ab31709f3d5f..408cb861a142 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TypedDict, Union
 
 import torch
 import torch.nn.functional as F
@@ -34,6 +34,7 @@
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -918,6 +919,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
+class GraniteFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`):
+            Maximum sequence length for the query state.
+        max_length_k (`int`):
+            Maximum sequence length for the key state.
+        seq_idx (`torch.IntTensor`):
+            Index of each packed sequence.
+ """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + class GraniteMoeHybridRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -1125,7 +1151,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1149,8 +1175,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1161,6 +1187,7 @@ def forward( cache_position=cache_position, cache_params=past_key_value, attention_mask=attention_mask, + **kwargs, ) # No attention weights for state space layers self_attn_weights = None @@ -1303,6 +1330,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1374,6 +1402,7 @@ def forward( cache_position=cache_position, output_router_logits=output_router_logits, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1706,6 +1735,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Only compute necessary logits diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 4274af1b1920..242c95b90767 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -20,10 +20,12 @@ from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast +from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging from ..bamba.configuration_bamba import BambaConfig from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache from ..granitemoeshared.modeling_granitemoeshared import ( + GraniteFlashAttentionKwargs, GraniteMoeSharedAttention, GraniteMoeSharedDecoderLayer, GraniteMoeSharedForCausalLM, @@ -84,7 +86,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -108,8 +110,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` 
             kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
+                padding-free training and/or to improve `torch.compile` performance.
         """
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -120,6 +122,7 @@ def forward(
             cache_position=cache_position,
             cache_params=past_key_value,
             attention_mask=attention_mask,
+            **kwargs,
         )
         # No attention weights for state space layers
         self_attn_weights = None
@@ -198,6 +201,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[GraniteFlashAttentionKwargs],
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -269,6 +273,7 @@ def forward(
                 cache_position=cache_position,
                 output_router_logits=output_router_logits,
                 position_embeddings=position_embeddings,
+                **kwargs,
             )
 
             hidden_states = layer_outputs[0]
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index b10369e767f3..21e9d13f7195 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, TypedDict, Union
 
 import torch
 import torch.nn.functional as F
@@ -33,6 +33,7 @@
 from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import auto_docstring, is_torch_flex_attn_available, logging
 from .configuration_granitemoeshared import GraniteMoeSharedConfig
@@ -46,6 +47,31 @@
 logger = logging.get_logger(__name__)
 
 
+class GraniteFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`):
+            Maximum sequence length for the query state.
+        max_length_k (`int`):
+            Maximum sequence length for the key state.
+        seq_idx (`torch.IntTensor`):
+            Index of each packed sequence.
+ """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + class GraniteMoeSharedMLP(nn.Module): """ MLP layer for shared experts @@ -431,7 +457,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -455,8 +481,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states @@ -593,6 +619,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -979,6 +1006,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Only compute necessary logits diff --git a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py index 29342cb6251d..630e5aa18439 100644 --- a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -13,13 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, TypedDict import torch from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache +from ...processing_utils import Unpack from ...utils import logging from ..granitemoe.modeling_granitemoe import ( GraniteMoeDecoderLayer, @@ -33,6 +34,31 @@ logger = logging.get_logger(__name__) +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training and fewer `torch.compile` graph breaks. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`) + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`torch.IntTensor): + Index of each packed sequence. 
+ """ + + cu_seq_lens_q: torch.LongTensor + cu_seq_lens_k: torch.LongTensor + max_length_q: int + max_length_k: int + seq_idx: torch.IntTensor + + class GraniteMoeSharedMLP(nn.Module): """ MLP layer for shared experts @@ -75,7 +101,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, + **kwargs: Unpack[GraniteFlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -99,8 +125,8 @@ def forward( Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. """ residual = hidden_states diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index e1d8128a2c22..cf660b7fa09d 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -551,6 +551,15 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) dummy_attention_mask = inputs_dict["attention_mask"] inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id + # Ensure inputs_dict also has labels in it, as their presence/absence can induce + # dtype conversions. This also lets us compare losses. + labels = inputs_dict["input_ids"].clone() + # Mask padding tokens + labels[~dummy_attention_mask.bool()] = -100 + # Also need to mask the first non-trivial token to match the padding-free batch. + first_nonneg_idx = (labels >= 0).int().argmax(dim=1) + labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100 + inputs_dict["labels"] = labels model = ( model_class.from_pretrained( @@ -586,6 +595,10 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id tol = torch.finfo(torch.float16).eps torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + loss_padded = res_padded.loss + loss_padfree = res_padfree.loss + torch.testing.assert_close(loss_padded, loss_padfree) + @slow @require_torch