14 changes: 11 additions & 3 deletions colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,15 @@
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm', 'FusedRMSNorm'
"Embedding1D",
"VocabParallelEmbedding1D",
"Linear1D_Col",
"Linear1D_Row",
'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row',
'DropoutForParallelInput',
'DropoutForReplicatedInput',
"cross_entropy_1d",
'FusedLayerNorm',
'FusedRMSNorm',
]
88 changes: 88 additions & 0 deletions colossalai/shardformer/modeling/opt.py
@@ -0,0 +1,88 @@
from typing import Optional, Tuple

import torch

__all__ = ['opt_flash_attention_forward']


def opt_flash_attention_forward(
    self,
    hidden_states: torch.Tensor,
    key_value_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    layer_head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel"""

    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None
    bsz, tgt_len, _ = hidden_states.size()

    attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
    # get query proj
    # query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
    query_states = self.q_proj(hidden_states).view(*attention_input_shape)
    # get key, value proj
    if is_cross_attention and past_key_value is not None:
        # reuse k, v, cross_attentions
        key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
        value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
    elif is_cross_attention:
        # cross_attentions
        key_states = self.k_proj(key_value_states).view(*attention_input_shape)
        value_states = self.v_proj(key_value_states).view(*attention_input_shape)
    elif past_key_value is not None:
        # reuse k, v, self_attention
        key_states = self.k_proj(hidden_states).view(*attention_input_shape)
        value_states = self.v_proj(hidden_states).view(*attention_input_shape)
        key_states = torch.cat([past_key_value[0], key_states], dim=1)
        value_states = torch.cat([past_key_value[1], value_states], dim=1)
    else:
        # self_attention
        key_states = self.k_proj(hidden_states).view(*attention_input_shape)
        value_states = self.v_proj(hidden_states).view(*attention_input_shape)

    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states. Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_states, value_states)

    src_len = key_states.size(1)
    if layer_head_mask is not None:
        if layer_head_mask.size() != (self.num_heads,):
            raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                             f" {layer_head_mask.size()}")
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
        attention_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).contiguous()

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError:
        raise ImportError("xformers is not installed. Please install it to use flash attention.")

    attn_output = me_attention(query_states,
                               key_states,
                               value_states,
                               attn_bias=attention_mask,
                               p=self.dropout,
                               scale=self.scaling)

    attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)

    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states` because `attn_output` can
    # be partitioned across GPUs when using tensor parallelism.
    attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

    attn_output = self.out_proj(attn_output)
    return attn_output, None, past_key_value
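The forward above keeps the query, key, and value projections in `(batch, seq_len, num_heads, head_dim)` layout because that is the layout `xformers.ops.memory_efficient_attention` consumes, rather than the `(batch, num_heads, seq_len, head_dim)` layout used by the stock HuggingFace attention. A minimal sketch of the underlying call, assuming xformers and a CUDA device are available (all tensor sizes are made up for illustration):

```python
# Minimal sketch (not part of this PR): calling xformers' memory-efficient attention
# directly with the (batch, seq_len, num_heads, head_dim) layout that
# opt_flash_attention_forward builds above. Assumes xformers and a CUDA GPU are available.
import torch
from xformers.ops import memory_efficient_attention

bsz, seq_len, num_heads, head_dim = 2, 128, 12, 64  # illustrative sizes only

q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)

# The output comes back in the same (batch, seq_len, num_heads, head_dim) layout,
# so it can be reshaped straight to (batch, seq_len, embed_dim) as the forward does.
out = memory_efficient_attention(q, k, v, attn_bias=None, p=0.0, scale=head_dim**-0.5)
print(out.shape)  # torch.Size([2, 128, 12, 64])
```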
9 changes: 8 additions & 1 deletion colossalai/shardformer/policies/opt.py
@@ -1,6 +1,7 @@
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_, setattr_
from ..modeling.opt import opt_flash_attention_forward
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -89,6 +90,12 @@ def module_policy(self):
policy=policy,
target_key=OPTDecoderLayer)

# use flash attention
if self.shard_config.enable_flash_attention:
policy[OPTAttention] = ModulePolicyDescription(method_replacement={
'forward': opt_flash_attention_forward,
})

return policy

def postprocess(self):
@@ -107,12 +114,12 @@ def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM

policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=OPTForCausalLM)

return policy

def postprocess(self):
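When `enable_flash_attention` is set, the policy above maps `OPTAttention` to a `ModulePolicyDescription` whose `method_replacement` swaps the module's `forward` for `opt_flash_attention_forward`. A minimal sketch of what such a replacement amounts to, assuming the `OPTAttention(embed_dim, num_heads, ...)` constructor of the transformers version this PR targets (an illustration of the idea, not ShardFormer's actual rebinding code):

```python
# Hypothetical sketch of the idea behind `method_replacement`: bind the standalone
# function as the instance's `forward`, so `self` inside opt_flash_attention_forward
# refers to the OPTAttention module. Not ShardFormer's actual implementation.
from types import MethodType

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTAttention

from colossalai.shardformer.modeling.opt import opt_flash_attention_forward

config = OPTConfig()
attn = OPTAttention(embed_dim=config.hidden_size,
                    num_heads=config.num_attention_heads,
                    dropout=config.attention_dropout,
                    is_decoder=True)

# After rebinding, any call to attn(...) goes through the xformers-backed path.
attn.forward = MethodType(opt_flash_attention_forward, attn)
```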
3 changes: 2 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -21,6 +21,7 @@ class ShardConfig:
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
enable_flash_attention: bool = False

# TODO: add support for tensor parallel
# pipeline_parallel_size: int
@@ -39,7 +40,6 @@ def __post_init__(self):
else:
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()
@@ -50,3 +50,4 @@ def _turn_on_all_optimization(self):
"""
# you can add all the optimization flag here
self.enable_fused_normalization = True
self.enable_flash_attention = True
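With the new `enable_flash_attention` flag on `ShardConfig` (switched on automatically by `enable_all_optimization`), the feature is driven entirely by the config handed to `ShardFormer`. A short usage sketch modelled on the test helper below; `build_opt_model` is a hypothetical stand-in for any function returning an OPT model, and distributed state is assumed to already be initialized via `colossalai.launch`, as in the tests:

```python
# Usage sketch modelled on tests/test_shardformer/test_model/_utils.py.
# `build_opt_model` is a hypothetical placeholder; colossalai.launch(...) is
# assumed to have initialized the distributed environment already.
import copy

from colossalai.shardformer import ShardConfig, ShardFormer

org_model = build_opt_model().cuda()

shard_config = ShardConfig(enable_fused_normalization=True,
                           enable_tensor_parallelism=True,
                           enable_flash_attention=True)    # new flag added by this PR
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(copy.deepcopy(org_model)).cuda()
```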
8 changes: 6 additions & 2 deletions tests/test_shardformer/test_model/_utils.py
@@ -3,13 +3,17 @@
from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
def build_model(model_fn,
enable_fused_normalization=True,
enable_tensor_parallelism=True,
enable_flash_attention=False):
# create new model
org_model = model_fn().cuda()

# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism)
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model_copy).cuda()
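`build_model` now threads the flash-attention flag through to `ShardConfig`, so a test can request an xformers-backed sharded model with one extra argument. A hedged sketch of a forward-equivalence check between the two returned models (this is not the PR's `check_forward_backward`; `model_fn` and `data_gen_fn` are assumed to come from the model zoo registry used by the tests):

```python
# Hedged sketch of comparing the original and sharded models returned by build_model.
# `model_fn` and `data_gen_fn` are assumed to come from the model zoo registry;
# the output attribute to compare depends on the concrete OPT variant.
import torch

org_model, sharded_model = build_model(model_fn,
                                        enable_fused_normalization=True,
                                        enable_tensor_parallelism=True,
                                        enable_flash_attention=True)

inputs = {k: v.cuda() for k, v in data_gen_fn().items()}
with torch.no_grad():
    org_out = org_model(**inputs)
    shard_out = sharded_model(**inputs)

# Flash attention computes in reduced precision internally, so use a loose tolerance.
torch.testing.assert_close(org_out.last_hidden_state,
                           shard_out.last_hidden_state,
                           rtol=1e-3, atol=1e-3)
```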
9 changes: 6 additions & 3 deletions tests/test_shardformer/test_model/test_shard_opt.py
@@ -9,6 +9,7 @@
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
@@ -71,18 +72,20 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
def run_opt_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()


def check_OPTModel(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_test()
run_opt_test()


@pytest.mark.dist