Merged
Changes from all commits
111 commits
6fc6a05
fix for async io
flybird11111 Feb 13, 2025
510ff7b
Merge branch 'hpcaitech:main' into main
flybird11111 Feb 14, 2025
3ecb500
test for upgrading transformers
flybird11111 Mar 27, 2025
40cf89d
Merge branch 'hpcaitech:main' into upgrade-transformers
flybird11111 Mar 27, 2025
0b81be7
add ci machine
flybird11111 Mar 28, 2025
6c728df
fix
flybird11111 Mar 31, 2025
43885a4
fix
flybird11111 Mar 31, 2025
837a503
fix
flybird11111 Mar 31, 2025
8c66b7c
fix
flybird11111 Mar 31, 2025
621cb93
fix
flybird11111 Mar 31, 2025
822556a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2025
4b8b67a
fix
flybird11111 Apr 1, 2025
3491a9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2025
ca91414
Update test_fp16_torch.py
flybird11111 Apr 9, 2025
397875e
Update build_on_pr.yml
flybird11111 Apr 9, 2025
28cf1e2
fix
flybird11111 Apr 9, 2025
b38d45e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
c0811d7
fix
flybird11111 Apr 9, 2025
466b61e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
a4e5ed9
fix
flybird11111 Apr 9, 2025
e92a692
Merge branch 'upgrade-transformers' of github.com:flybird11111/Coloss…
flybird11111 Apr 9, 2025
57d7b16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
0e900ac
fix
flybird11111 Apr 9, 2025
d5a3d1a
fix
flybird11111 Apr 9, 2025
603e229
fix
flybird11111 Apr 9, 2025
dce2212
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2025
25c5e42
fix
flybird11111 Apr 9, 2025
99298c6
Merge branch 'upgrade-transformers' of github.com:flybird11111/Coloss…
flybird11111 Apr 9, 2025
eaef783
fix
flybird11111 Apr 10, 2025
964f9a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
e8a3d52
fix
flybird11111 Apr 10, 2025
5c56a7f
Merge branch 'upgrade-transformers' of github.com:flybird11111/Coloss…
flybird11111 Apr 10, 2025
6997862
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
de4f7a1
fix
flybird11111 Apr 10, 2025
517bedc
Merge branch 'upgrade-transformers' of github.com:flybird11111/Coloss…
flybird11111 Apr 10, 2025
0d09c0e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
914b179
fix
flybird11111 Apr 10, 2025
c37107c
Merge branch 'upgrade-transformers' of github.com:flybird11111/Coloss…
flybird11111 Apr 10, 2025
21707a7
fix
flybird11111 Apr 10, 2025
910433f
fix
flybird11111 Apr 10, 2025
0950b07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
db4c73f
fix
flybird11111 Apr 11, 2025
fd69a82
fix
flybird11111 Apr 11, 2025
dc60efe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
a2e623d
fix
flybird11111 Apr 17, 2025
afe07a6
fiux
flybird11111 Apr 17, 2025
7af46ab
fix
flybird11111 Apr 17, 2025
52ead00
fix
flybird11111 Apr 18, 2025
0c5ed65
fix
flybird11111 Apr 18, 2025
6869827
upgrade llama
flybird11111 Apr 24, 2025
e891501
fix
flybird11111 Apr 24, 2025
d7a9eb0
Merge branch 'hpcaitech:main' into upgrade-transformers
flybird11111 Apr 24, 2025
2f615a4
fix
flybird11111 Apr 24, 2025
c6291be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2025
8497ecc
Merge pull request #6276 from flybird11111/upgrade-transformers
BurkeHulk Apr 24, 2025
5d167f2
fix
wangbluo Apr 28, 2025
885210d
fix
wangbluo Apr 28, 2025
08787f0
upgrade_bert
wangbluo May 5, 2025
5480b81
upgrade_bloom
wangbluo May 6, 2025
a4c6e18
[upgrade] upgrade gpt2 (#6291)
flybird11111 May 8, 2025
a9bb7cb
upgrade command
wangbluo May 8, 2025
0672449
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2025
e78c456
fix
wangbluo May 8, 2025
cefdfc4
add explanation
wangbluo May 8, 2025
4eced5c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2025
fe94d73
fix
wangbluo May 8, 2025
b124603
fix
wangbluo May 8, 2025
d6f3508
fix
wangbluo May 13, 2025
4fbbf47
fix
wangbluo May 13, 2025
f118146
[upgrade]Upgrade qwen2 (#6302)
flybird11111 May 13, 2025
2237531
update_bloom
wangbluo May 13, 2025
07349e0
fix
wangbluo May 14, 2025
d665d67
add explantion
wangbluo May 14, 2025
c28b3c3
Merge pull request #6305 from wangbluo/update_bert
BurkeHulk May 14, 2025
1ace29b
Merge pull request #6299 from wangbluo/upgrade_bloom
BurkeHulk May 14, 2025
0dede48
Merge branch 'upgrade_transformers' into upgrade_falcon
wangbluo May 14, 2025
89917e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2025
b032cf9
upgrade_sam
wangbluo May 14, 2025
0e9d628
add the explanation
wangbluo May 14, 2025
5374601
Merge pull request #6283 from wangbluo/upgrade_falcon
BurkeHulk May 14, 2025
2223b64
upgrade_t
wangbluo May 15, 2025
ba9fb54
fix
wangbluo May 15, 2025
10bc6af
fix
wangbluo May 15, 2025
ced6b5e
fix
wangbluo May 16, 2025
e1925b3
upgrade_gptj
wangbluo May 16, 2025
4e49f05
fix
wangbluo May 16, 2025
07fa048
fix
wangbluo May 20, 2025
efb2d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2025
2aa295e
[upgrade]upgrade opt (#6307)
flybird11111 May 21, 2025
d0e13b8
[upgrade]Upgrade mixtral (#6317)
flybird11111 May 21, 2025
04516bb
[upgrade]Upgrade vit (#6308)
flybird11111 May 21, 2025
6875a8a
[upgrade]upgrade mistral (#6296)
flybird11111 May 21, 2025
e7ce582
Merge pull request #6313 from wangbluo/upgrade_gptj
BurkeHulk May 22, 2025
33614b8
Merge pull request #6306 from wangbluo/upgrade_sam
BurkeHulk May 22, 2025
6196faa
Merge pull request #6318 from wangbluo/upgrade_t5
BurkeHulk May 22, 2025
6a29abd
Merge pull request #6298 from wangbluo/upgrade_command
BurkeHulk May 22, 2025
bad9c8a
fix
flybird11111 May 22, 2025
bafc80c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2025
e1c72fd
fix
flybird11111 May 22, 2025
4a077e5
fix falcon
wangbluo May 22, 2025
ef8084a
Merge pull request #6322 from wangbluo/fix_falcon
BurkeHulk May 22, 2025
252efa6
fix
wangbluo May 23, 2025
b7df868
Update test_shard_deepseek.py
BurkeHulk May 23, 2025
7b39848
Merge pull request #6323 from wangbluo/fix_deepseek
BurkeHulk May 23, 2025
f009d3c
Update build_on_pr.yml
flybird11111 May 23, 2025
552778f
Update requirements.txt
flybird11111 May 23, 2025
63dc73d
fix (#6327)
flybird11111 May 26, 2025
559f15a
fix (#6328)
flybird11111 May 26, 2025
17654cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 26, 2025
611c124
Update bert.py
flybird11111 May 27, 2025
ba93bba
fix (#6329)
flybird11111 May 27, 2025
4 changes: 2 additions & 2 deletions .github/workflows/build_on_pr.yml
@@ -34,7 +34,7 @@ jobs:
anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
runs-on: ubuntu-latest
runs-on: [self-hosted,ubuntu-latest]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change
cancel-in-progress: true
@@ -87,7 +87,7 @@ jobs:
name: Build and Test Colossal-AI
needs: detect
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
runs-on: ubuntu-latest
runs-on: [self-hosted,ubuntu-latest]
container:
image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --shm-size=2g --rm -v /dev/shm -v /data/scratch:/data/scratch
79 changes: 20 additions & 59 deletions colossalai/inference/modeling/models/glide_llama.py
@@ -6,19 +6,16 @@

import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaForCausalLM,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)

from colossalai.inference.spec import GlideInput
@@ -156,31 +153,29 @@ def glide_llama_model_forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
if hasattr(glide_input, "n_spec_tokens"):
position_ids = position_ids + glide_input.n_spec_tokens

# embed positions
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
if output_hidden_states:
@@ -189,9 +184,9 @@ def glide_llama_model_forward(
# GlideLlamaDecoderLayer
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -200,9 +195,6 @@ def glide_llama_model_forward(

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

@@ -212,16 +204,11 @@ def glide_llama_model_forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@@ -267,41 +254,17 @@ def __init__(self, config: GlideLlamaConfig):

self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
self._init_rope()

def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
glide_input: GlideInput = None, # Used for glimpsing main model's KV caches
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Optional[torch.Tensor]:
@@ -319,8 +282,7 @@ def forward(
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

# for RoPE
position_ids = position_ids + glide_input.n_spec_tokens
cos, sin = self.rotary_emb(query_states, position_ids)
cos, sin = position_embeddings
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
@@ -367,9 +329,10 @@ def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlam
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor = None,
position_ids: Optional[torch.LongTensor] = None,
glide_input: GlideInput = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -399,10 +362,10 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -425,9 +388,10 @@ def forward(

hidden_states = self.cross_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
position_ids=position_ids,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=True,
)
@@ -441,9 +405,6 @@ def forward(

outputs = (hidden_states,)

if use_cache:
outputs += (present_key_value,)

return outputs


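The glide_llama rewrite tracks the upstream Llama changes: the legacy tuple cache and the StaticCache branch are gone, a DynamicCache is created lazily, rotary embeddings are computed once per forward and passed down as position_embeddings, and the per-layer RoPE setup in the cross-attention module is removed. A minimal, self-contained sketch of the new cache/position handling (toy tensors; assumes a transformers release that provides DynamicCache):

# Minimal sketch of the DynamicCache-based setup the updated
# glide_llama_model_forward relies on (toy example, not the PR code itself).
import torch
from transformers.cache_utils import DynamicCache

use_cache = True
past_key_values = None                   # nothing cached yet on the first step
inputs_embeds = torch.randn(1, 8, 64)    # (batch, seq_len, hidden), dummy values

# New behaviour: create a DynamicCache lazily instead of converting legacy tuples.
if use_cache and past_key_values is None:
    past_key_values = DynamicCache()

# cache_position starts from whatever the cache already holds (0 here).
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
position_ids = cache_position.unsqueeze(0)

# The cache is mutated in place by the decoder layers, so the forward can simply
# return past_key_values instead of rebuilding a legacy tuple at the end.
print(past_seen_tokens, cache_position.tolist(), position_ids.shape)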
6 changes: 3 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -478,9 +478,9 @@ def from_native_module(
attn_oproj=attn_oproj,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
num_heads=module.config.num_attention_heads,
hidden_size=module.config.hidden_size,
num_key_value_heads=module.config.num_key_value_heads,
)

return attn_layer
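The nopadding_llama change follows the same upgrade pattern: recent transformers releases dropped the num_heads / hidden_size / num_key_value_heads attributes from LlamaAttention, so the sizes are now read from the module's config. A minimal sketch of that access pattern (illustrative only; the concrete config values are made up):

# Illustrative sketch: read head/hidden sizes from the config instead of from
# attributes that recent transformers releases removed from LlamaAttention.
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

config = LlamaConfig(hidden_size=128, num_attention_heads=4, num_key_value_heads=2)
module = LlamaAttention(config, layer_idx=0)

num_heads = module.config.num_attention_heads      # was: module.num_heads
hidden_size = module.config.hidden_size            # was: module.hidden_size
num_key_value_heads = module.config.num_key_value_heads
print(num_heads, hidden_size, num_key_value_heads)  # 4 128 2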
6 changes: 4 additions & 2 deletions colossalai/inference/spec/drafter.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from transformers.cache_utils import DynamicCache

from colossalai.utils import get_current_device

@@ -93,9 +94,8 @@ def speculate(

for _ in range(n_spec_tokens):
# update past key values
kwargs["past_key_values"] = past_key_values

outputs = self._drafter_model(input_ids, **kwargs)
outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
next_token_logits = outputs.logits[:, -1, :]

# NOTE Only use greedy search for speculating.
@@ -114,6 +114,8 @@ def speculate(
speculated_length = len(token_ids) # For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
if isinstance(past_key_values, DynamicCache):
past_key_values = past_key_values.to_legacy_cache()

out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
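The drafter change passes the cache explicitly as past_key_values and, once speculation finishes, flattens a DynamicCache back to the legacy tuple layout so that downstream consumers of DrafterOutput that still index tuples keep working. A small sketch of that round-trip (dummy tensor shapes, not the project's real call site):

# Sketch of the DynamicCache -> legacy-tuple round-trip used after speculation.
# Shapes are dummy values for illustration only.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Pretend one decoder layer stored a (bsz, num_heads, seq_len, head_dim) key/value pair.
key = torch.randn(1, 4, 5, 32)
value = torch.randn(1, 4, 5, 32)
cache.update(key, value, layer_idx=0)

if isinstance(cache, DynamicCache):
    legacy = cache.to_legacy_cache()  # tuple of (key, value) pairs, one per layer

print(type(legacy), legacy[0][0].shape)  # <class 'tuple'> torch.Size([1, 4, 5, 32])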
4 changes: 2 additions & 2 deletions colossalai/lazy/pretrained.py
@@ -69,7 +69,7 @@ def new_from_pretrained(
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@@ -286,7 +286,7 @@ def new_from_pretrained(
config.name_or_path = pretrained_model_name_or_path

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts = [no_init_weights()]

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
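The lazy-loading path drops the deprecated _fast_init flag: the diff pops it without using it and enters no_init_weights() with no arguments, matching newer transformers where the _enable parameter was removed. A minimal sketch of the resulting instantiation pattern (assuming a recent transformers release; the tiny BERT config is only for illustration):

# Minimal sketch: instantiate a model while skipping transformers' custom weight
# initialization, using the argument-free no_init_weights() context manager.
from transformers import BertConfig, BertModel
from transformers.modeling_utils import no_init_weights

config = BertConfig(num_hidden_layers=1, hidden_size=32, num_attention_heads=2, intermediate_size=64)

with no_init_weights():  # no _enable argument anymore
    # transformers' _init_weights step is skipped; parameters are expected to be
    # overwritten by a checkpoint load afterwards.
    model = BertModel(config)

print(sum(p.numel() for p in model.parameters()))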
85 changes: 84 additions & 1 deletion colossalai/shardformer/modeling/bert.py
@@ -58,7 +58,7 @@ def bert_model_forward(
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
# TODO(jianghai): add explaination of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1037,6 +1037,89 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T
return forward


# Fix the tgt_len size in sequence parallel attention:
# same as the BertSdpaSelfAttention forward in transformers v4.51.3, except that
# tgt_len is re-derived from query_layer before the final reshape.
def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
from transformers.models.bert.modeling_bert import BertSdpaSelfAttention

def forward(
self: BertSdpaSelfAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
_, _, tgt_len, _ = query_layer.shape
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs

return forward


def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
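The new helper returns a drop-in forward for BertSdpaSelfAttention that re-derives tgt_len from query_layer, so the final reshape matches the local sequence shard under sequence parallelism. Shardformer installs such replacements through its policy machinery; the snippet below only illustrates the general idea of binding a replacement forward onto the module, with a simplified stand-in patch rather than the PR's actual function (assumes a transformers release that still ships BertSdpaSelfAttention, e.g. the v4.51.x line referenced above):

# Illustration only: bind a replacement forward onto an existing attention module.
# The stand-in below just re-derives tgt_len from the local input shard; the real
# replacement is get_bert_sequence_parallel_attention_forward(shard_config).
from types import MethodType

import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertSdpaSelfAttention

def sequence_parallel_forward(self, hidden_states, attention_mask=None, **kwargs):
    # tgt_len comes from the (possibly sequence-parallel) shard of hidden_states,
    # so the final reshape matches the shard length, not the full sequence length.
    bsz, tgt_len, _ = hidden_states.size()
    (attn_output,) = BertSdpaSelfAttention.forward(self, hidden_states, attention_mask, **kwargs)
    return (attn_output.reshape(bsz, tgt_len, self.all_head_size),)

config = BertConfig(hidden_size=32, num_attention_heads=2)
attn = BertSdpaSelfAttention(config)
attn.forward = MethodType(sequence_parallel_forward, attn)

(out,) = attn(torch.randn(2, 6, 32))
print(out.shape)  # torch.Size([2, 6, 32])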