Merged
80 changes: 40 additions & 40 deletions colossalai/shardformer/modeling/qwen2.py
@@ -57,6 +57,7 @@ def qwen2_model_forward(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
@@ -131,14 +132,6 @@ def qwen2_model_forward(
else:
position_ids = position_ids.view(-1, seq_length).long()

if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
@@ -152,16 +145,16 @@
is_causal=True,
)
else:
if self._attn_implementation == "flash_attention_2":
if self.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
elif self.config._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
hidden_states,
past_key_values_length,
)
else:
@@ -195,6 +188,8 @@ def qwen2_model_forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

position_embeddings = self.rotary_emb(hidden_states, position_ids)

start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
@@ -214,7 +209,7 @@
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = past_key_values[idx] if past_key_values is not None else None
past_key_values[idx] if past_key_values is not None else None

if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
@@ -225,15 +220,19 @@
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)

hidden_states = layer_outputs[0]
@@ -491,11 +490,10 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
def forward(
self: Qwen2Attention,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
@@ -519,9 +517,9 @@ def forward(
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
@@ -533,9 +531,8 @@ def forward(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
@@ -563,7 +560,7 @@ def forward(
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
@@ -605,11 +602,11 @@ def forward(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
return attn_output, None

return forward
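For orientation, here is a minimal, self-contained sketch of the rotary-embedding convention the rewritten attention forward assumes: the model builds a `(cos, sin)` pair once per forward (`position_embeddings`) and each attention layer simply applies it to its query/key states, instead of rebuilding the embedding from `position_ids` inside every attention call. The helper names and tensor shapes below are illustrative only, not ColossalAI or transformers APIs.

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: split the head dimension in half, swap and negate.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin: [batch, seq, head_dim]; unsqueeze to broadcast over the head axis.
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

batch, seq, heads, head_dim = 2, 8, 4, 16

# "Model level": build (cos, sin) once from the position ids.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)
freqs = position_ids.unsqueeze(-1).float() * inv_freq   # [batch, seq, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                  # [batch, seq, head_dim]
position_embeddings = (emb.cos(), emb.sin())

# "Layer level": every attention layer just unpacks and applies the pair.
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos, sin = position_embeddings
q, k = apply_rotary(q, k, cos, sin)
```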

@@ -627,6 +624,7 @@ def forward(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
force_sp_output_gather: bool = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
@@ -648,6 +646,9 @@ def forward(
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

seq_length_with_past = seq_length
past_key_values_length = 0

@@ -664,9 +665,6 @@ def forward(
else:
position_ids = position_ids.view(-1, seq_length).long()

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# embed positions
hidden_states = inputs_embeds

@@ -700,6 +698,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
position_embeddings = self.rotary_emb(hidden_states, position_ids)

if sp_mode in ["ring", "split_gather"]:
hidden_states = split_forward_gather_backward(
@@ -723,22 +722,23 @@ def forward(
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

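Before the policy changes below, a rough sketch of the calling convention the updated decoder loop follows (the function and its arguments are hypothetical stand-ins, not the actual ColossalAI code): `position_embeddings` and `cache_position` are produced once per forward and then threaded positionally through every layer, so the checkpointed and non-checkpointed branches can share the same argument tuple.

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_decoder_layers(layers, rotary_emb, hidden_states, attention_mask,
                       position_ids, past_key_values, output_attentions,
                       use_cache, cache_position, use_checkpoint=False):
    # Computed once and reused by every layer, mirroring the diff above.
    position_embeddings = rotary_emb(hidden_states, position_ids)
    for layer in layers:
        args = (hidden_states, attention_mask, position_ids, past_key_values,
                output_attentions, use_cache, cache_position, position_embeddings)
        if use_checkpoint:
            # The checkpointed path reuses the exact same positional argument list.
            layer_outputs = checkpoint(layer, *args, use_reentrant=False)
        else:
            layer_outputs = layer(*args)
        hidden_states = layer_outputs[0]
    return hidden_states
```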
12 changes: 4 additions & 8 deletions colossalai/shardformer/policies/qwen2.py
@@ -65,15 +65,9 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
ATTN_IMPLEMENTATION = {
"eager": Qwen2Attention,
"flash_attention_2": Qwen2FlashAttention2,
"sdpa": Qwen2SdpaAttention,
}

policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
@@ -93,7 +87,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size

policy[attn_cls] = ModulePolicyDescription(
policy[Qwen2Attention] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)

@@ -301,12 +295,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)

if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention)
self.append_or_create_method_replacement(
description={
"forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,
target_key=Qwen2Attention,
)
if self.pipeline_stage_manager is None:
# replace qwen2 model forward method
@@ -370,6 +365,7 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
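The policy now targets `Qwen2Attention` directly (presumably because recent transformers releases fold the flash-attention and SDPA variants into that single class) and keeps `rotary_emb` in the held layers of every pipeline stage, since the staged model forward above calls `self.rotary_emb(...)` regardless of which stage it runs on. A simplified, hypothetical sketch of that held-layer selection, assuming an even layer split (the real `get_held_layers`/`distribute_layers` logic also handles interleaved schedules and uneven splits):

```python
# Hypothetical sketch, not ColossalAI's actual get_held_layers implementation.
def held_layers_for_stage(module, stage, num_stages, layers_per_stage):
    held = [module.rotary_emb]              # needed on every stage: each stage's
                                            # layers consume its (cos, sin) output
    start = stage * layers_per_stage
    held.extend(module.layers[start:start + layers_per_stage])
    if stage == 0:
        held.append(module.embed_tokens)    # only the first stage embeds tokens
    if stage == num_stages - 1:
        held.append(module.norm)            # only the last stage applies the final norm
    return held
```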