28 changes: 27 additions & 1 deletion docs/source/en/model_doc/granitemoehybrid.md
@@ -48,6 +48,32 @@ for i in output:

This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).

## Notes

- `GraniteMoeHybridForCausalLM` supports padding-free training, which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate training by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on the model and data distribution) and reduce memory usage when examples have varying lengths, since no compute or memory is wasted on padding tokens.

Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages, and the following arguments must be passed to the model in addition to `input_ids` and `labels` (a toy example of these inputs follows the list below).

- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
- Each of the [`FlashAttentionKwargs`]:
- `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
- `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
- `max_length_q: int`: the longest query length in the batch.
- `max_length_k: int`: the longest key length in the batch.
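
For intuition, here is a minimal sketch of what these inputs look like when two sequences of lengths 3 and 2 are packed into a single row. The token values are illustrative placeholders, not tied to any tokenizer.

```python
import torch

# Two sequences of lengths 3 and 2, packed into one flattened row.
input_ids = torch.tensor([[11, 12, 13, 21, 22]])
position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # position index restarts at each sequence
seq_idx = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)

# Cumulative sequence lengths mark the boundaries of each packed sequence.
cu_seq_lens_q = cu_seq_lens_k = torch.tensor([0, 3, 5])
max_length_q = max_length_k = 3  # length of the longest packed sequence
```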

`attention_mask` should not be provided in this mode. [`DataCollatorWithFlattening`] generates all of the additional arguments above when instantiated with `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.

```python
from transformers import DataCollatorWithFlattening

# Example padding-free data collator. Note that DataCollatorWithFlattening takes
# no tokenizer argument; it only flattens already-tokenized examples.
data_collator = DataCollatorWithFlattening(
    return_seq_idx=True,
    return_flash_attn_kwargs=True,
)
```
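
Putting it together, below is a minimal end-to-end training sketch. The checkpoint name and `tokenized_dataset` are placeholder assumptions; padding-free training also assumes the model is loaded with `attn_implementation="flash_attention_2"`.

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithFlattening,
    Trainer,
    TrainingArguments,
)

# Placeholder checkpoint; substitute any GraniteMoeHybrid model.
checkpoint = "ibm-granite/granite-4.0-tiny-preview"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # required for padding-free training
)

# The collator flattens each batch and emits position_ids, seq_idx,
# and the FlashAttention kwargs described above.
data_collator = DataCollatorWithFlattening(
    return_seq_idx=True,
    return_flash_attn_kwargs=True,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="granite-padding-free", bf16=True),
    train_dataset=tokenized_dataset,  # assumed: a pre-tokenized dataset with input_ids/labels
    data_collator=data_collator,
)
trainer.train()
```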

## GraniteMoeHybridConfig

@@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co
## GraniteMoeHybridForCausalLM

[[autodoc]] GraniteMoeHybridForCausalLM
- forward
2 changes: 2 additions & 0 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -641,6 +641,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -947,6 +948,7 @@ def forward(
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

# Only compute necessary logits
src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TypedDict, Union

import torch
import torch.nn.functional as F
@@ -34,6 +34,7 @@
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
@@ -918,6 +919,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query states.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key states.
max_length_q (`int`):
Maximum sequence length for the query states.
max_length_k (`int`):
Maximum sequence length for the key states.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor


class GraniteMoeHybridRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -1125,7 +1151,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -1149,8 +1175,8 @@
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or to improve `torch.compile` performance.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -1161,6 +1187,7 @@
cache_position=cache_position,
cache_params=past_key_value,
attention_mask=attention_mask,
**kwargs,
)
# No attention weights for state space layers
self_attn_weights = None
@@ -1303,6 +1330,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1374,6 +1402,7 @@
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = layer_outputs[0]
@@ -1706,6 +1735,7 @@ def forward(
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

# Only compute necessary logits
src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -20,10 +20,12 @@

from ...cache_utils import Cache
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ..bamba.configuration_bamba import BambaConfig
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
from ..granitemoeshared.modeling_granitemoeshared import (
GraniteFlashAttentionKwargs,
GraniteMoeSharedAttention,
GraniteMoeSharedDecoderLayer,
GraniteMoeSharedForCausalLM,
@@ -84,7 +86,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -108,8 +110,8 @@
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or to improve `torch.compile` performance.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -120,6 +122,7 @@
cache_position=cache_position,
cache_params=past_key_value,
attention_mask=attention_mask,
**kwargs,
)
# No attention weights for state space layers
self_attn_weights = None
@@ -198,6 +201,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -269,6 +273,7 @@ def forward(
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = layer_outputs[0]
src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
from typing import Callable, Optional, TypedDict, Union

import torch
import torch.nn.functional as F
@@ -33,6 +33,7 @@
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_granitemoeshared import GraniteMoeSharedConfig

@@ -46,6 +47,31 @@
logger = logging.get_logger(__name__)


class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query states.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key states.
max_length_q (`int`):
Maximum sequence length for the query states.
max_length_k (`int`):
Maximum sequence length for the key states.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor


class GraniteMoeSharedMLP(nn.Module):
"""
MLP layer for shared experts
@@ -431,7 +457,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -455,8 +481,8 @@
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or to improve `torch.compile` performance.
"""
residual = hidden_states

@@ -593,6 +619,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -979,6 +1006,7 @@ def forward(
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

# Only compute necessary logits
src/transformers/models/granitemoeshared/modular_granitemoeshared.py
@@ -13,13 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, TypedDict

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...processing_utils import Unpack
from ...utils import logging
from ..granitemoe.modeling_granitemoe import (
GraniteMoeDecoderLayer,
@@ -33,6 +34,31 @@
logger = logging.get_logger(__name__)


class GraniteFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query states.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key states.
max_length_q (`int`):
Maximum sequence length for the query states.
max_length_k (`int`):
Maximum sequence length for the key states.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor


class GraniteMoeSharedMLP(nn.Module):
"""
MLP layer for shared experts
@@ -75,7 +101,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
**kwargs: Unpack[GraniteFlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -99,8 +125,8 @@
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `GraniteFlashAttentionKwargs` for
padding-free training and/or to improve `torch.compile` performance.
"""
residual = hidden_states

13 changes: 13 additions & 0 deletions tests/models/bamba/test_modeling_bamba.py
@@ -551,6 +551,15 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
# dtype conversions. This also lets us compare losses.
labels = inputs_dict["input_ids"].clone()
# Mask padding tokens
labels[~dummy_attention_mask.bool()] = -100
# Also need to mask the first non-trivial token to match the padding-free batch.
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
inputs_dict["labels"] = labels

model = (
model_class.from_pretrained(
@@ -586,6 +595,10 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_id
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)

loss_padded = res_padded.loss
loss_padfree = res_padfree.loss
torch.testing.assert_close(loss_padded, loss_padfree)


@slow
@require_torch