2 changes: 1 addition & 1 deletion colossalai/shardformer/modeling/command.py
@@ -116,7 +116,7 @@ def command_model_forward(
 # for the other stages, hidden_states is the output of the previous stage
 if shard_config.enable_flash_attention:
     # in this case, attention_mask is a dict rather than a tensor
-    mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+    mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
     attention_mask = ColoAttention.prepare_attn_kwargs(
         mask_shape,
         hidden_states.dtype,
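A note on the shape convention this fix relies on: with a KV cache, the attention mask is (batch, 1, query_length, key_length), where the query axis covers only the tokens processed in this forward pass (seq_length) and the key axis also spans the cached tokens (seq_length_with_past). Below is a minimal plain-PyTorch sketch of that convention; it is illustrative only, not the ColoAttention.prepare_attn_kwargs implementation, and build_causal_mask is a hypothetical helper.

import torch

def build_causal_mask(batch_size, seq_length, past_length, dtype=torch.float32):
    # Illustrative causal mask for decoding with a KV cache:
    # query axis = new tokens only, key axis = cached + new tokens.
    kv_length = past_length + seq_length
    q_pos = torch.arange(past_length, kv_length).view(seq_length, 1)  # absolute positions of the new tokens
    k_pos = torch.arange(kv_length).view(1, kv_length)                # absolute positions of all keys
    allowed = k_pos <= q_pos                                          # each query sees keys at or before its position
    mask = torch.zeros(seq_length, kv_length, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    # broadcast to (batch, 1, q_len, kv_len), matching the mask_shape in the diff above
    return mask.expand(batch_size, 1, seq_length, kv_length)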
2 changes: 1 addition & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -643,7 +643,7 @@ def forward(

 # in this case, attention_mask is a dict rather than a tensor
 if shard_config.enable_flash_attention:
-    mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
+    mask_shape = (inputs_embeds.shape[0], 1, seq_len, past_seen_tokens + seq_len)
     attention_mask = ColoAttention.prepare_attn_kwargs(
         mask_shape,
         inputs_embeds.dtype,
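The llama change follows the same convention, with past_seen_tokens + seq_len as the key length. Using the hypothetical build_causal_mask sketch above, the shapes the corrected code would produce look like this (prefill vs. a single-token decode step; the numbers are arbitrary):

# prefill: no cached tokens, square causal mask
prefill_mask = build_causal_mask(batch_size=2, seq_length=8, past_length=0)
assert prefill_mask.shape == (2, 1, 8, 8)

# decode: one new token attending over 16 cached tokens plus itself
decode_mask = build_causal_mask(batch_size=2, seq_length=1, past_length=16)
assert decode_mask.shape == (2, 1, 1, 17)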
2 changes: 1 addition & 1 deletion colossalai/shardformer/modeling/mistral.py
@@ -91,7 +91,7 @@ def mistral_model_forward(

 if shard_config.enable_flash_attention:
     # in this case, attention_mask is a dict rather than a tensor
-    mask_shape = (batch_size, 1, seq_length, seq_length)
+    mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length)
     attention_mask = ColoAttention.prepare_attn_kwargs(
         mask_shape,
         hidden_states.dtype,
8 changes: 6 additions & 2 deletions colossalai/shardformer/modeling/qwen2.py
@@ -136,7 +136,7 @@ def qwen2_model_forward(
 # for the other stages, hidden_states is the output of the previous stage
 if shard_config.enable_flash_attention:
     # in this case, attention_mask is a dict rather than a tensor
-    mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+    mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
     attention_mask = ColoAttention.prepare_attn_kwargs(
         mask_shape,
         hidden_states.dtype,
@@ -651,6 +651,10 @@ def forward(
 seq_length_with_past = seq_length
 past_key_values_length = 0
 
+if past_key_values is not None:
+    past_key_values_length = past_key_values[0][0].shape[2]
+    seq_length_with_past = seq_length_with_past + past_key_values_length
+
 if position_ids is None:
     device = input_ids.device if input_ids is not None else inputs_embeds.device
     position_ids = torch.arange(
@@ -668,7 +672,7 @@

 if shard_config.enable_flash_attention:
     # in this case, attention_mask is a dict rather than a tensor
-    mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+    mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
     attention_mask = ColoAttention.prepare_attn_kwargs(
         mask_shape,
         hidden_states.dtype,
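The qwen2 hunks also restore the cached-length bookkeeping that the corrected mask shape depends on. Assuming the legacy tuple cache layout (a tuple over layers of (key_states, value_states), each shaped (batch, num_heads, cached_seq_len, head_dim), which is the standard Hugging Face legacy-cache convention), a standalone sketch of that computation would be:

def cache_lengths(seq_length, past_key_values=None):
    # past_key_values: legacy tuple cache, one (key, value) pair per layer,
    # each tensor shaped (batch, num_heads, cached_seq_len, head_dim)
    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]  # cached_seq_len
    seq_length_with_past = seq_length + past_key_values_length
    return past_key_values_length, seq_length_with_past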