481 changes: 435 additions & 46 deletions colossalai/shardformer/modeling/mistral.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions colossalai/shardformer/policies/falcon.py
@@ -137,6 +137,7 @@ def module_policy(self):

if self.shard_config.enable_flash_attention:
warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")

return policy

def postprocess(self):
1 change: 1 addition & 0 deletions colossalai/shardformer/policies/gpt2.py
@@ -48,6 +48,7 @@ def module_policy(self):
policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
1 change: 1 addition & 0 deletions colossalai/shardformer/policies/gptj.py
@@ -43,6 +43,7 @@ def module_policy(self):
policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
152 changes: 142 additions & 10 deletions colossalai/shardformer/policies/mistral.py
@@ -1,8 +1,10 @@
import warnings
from functools import partial
from typing import Callable, Dict, Union
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import (
FusedRMSNorm,
@@ -14,7 +16,11 @@
VocabParallelLMHead1D,
)

from ..modeling.mistral import MistralForwards, get_mistral_flash_attention_forward
from ..modeling.mistral import (
MistralForwards,
get_mistral_flash_attention_forward,
get_mistral_model_forward_for_flash_attn,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
@@ -45,6 +51,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
@@ -145,16 +152,83 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
# replace mistral model forward method
self.append_or_create_method_replacement(
description={
"forward": get_mistral_model_forward_for_flash_attn(self.shard_config),
},
policy=policy,
target_key=MistralModel,
)

return policy

def postprocess(self):
return self.model

def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
method_replacement = {"forward": partial(new_forward)}
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager is None:
return

stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MistralModel":
module = self.model
else:
module = self.model.model

if stage_manager.is_interleave:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
}

else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
)
}

self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
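Note: below is a minimal, self-contained sketch of the replacement pattern `set_pipeline_forward` uses — illustrative names only, not the actual ColossalAI API. The pipeline metadata is bound into the new forward with `functools.partial`, so callers keep the plain HuggingFace call signature while only the current stage's layers execute.

```python
from functools import partial

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(4))


def staged_forward(self: nn.Module, x: torch.Tensor, stage_index=(0, 2)) -> torch.Tensor:
    # run only the slice of layers owned by this pipeline stage
    start, end = stage_index
    for layer in list(self.layers)[start:end]:
        x = layer(x)
    return x


model = TinyModel()
# same spirit as append_or_create_method_replacement: install a forward with bound metadata
model.forward = partial(staged_forward, model, stage_index=(0, 2))
out = model(torch.randn(2, 4))  # callers still use the ordinary signature
```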

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None

if self.model.__class__.__name__ == "MistralModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager

held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(module.norm)

else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
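For reference, a hedged sketch of the even layer split implied by `distribute_layers` / `get_stage_index` above; the real stage manager may use a different heuristic, but this is the shape of the computation that decides which decoder blocks (plus `embed_tokens` on the first stage and `norm` on the last) a rank holds.

```python
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # spread layers as evenly as possible; earlier stages absorb the remainder
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if s < rem else 0) for s in range(num_stages)]


def stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


per_stage = distribute_layers(32, 4)   # e.g. 32 Mistral decoder layers over 4 stages
print(per_stage)                       # [8, 8, 8, 8]
print(stage_index(per_stage, 2))       # (16, 24): stage 2 holds layers 16..23
```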


class MistralModelPolicy(MistralPolicy):
def __init__(self) -> None:
@@ -164,17 +238,28 @@ def module_policy(self):
policy = super().module_policy()
from transformers.models.mistral.modeling_mistral import MistralModel

self.set_forward(model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy
)

return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in mistral model"""
return []


class MistralForCausalLMPolicy(MistralPolicy):
def module_policy(self):
from transformers import MistralForCausalLM

policy = super().module_policy()
if self.pipeline_stage_manager:
warnings.warn("Mistral doesn't support pipeline parallelism now.")

if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@@ -207,8 +292,38 @@ def module_policy(self):

policy.update(new_item)

if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy
)

return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
mistral_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: mistral_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
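A hedged illustration of why `get_shared_params` compares `id()`s: when `embed_tokens` and `lm_head` are weight-tied, the first and last pipeline stages hold the very same tensor and must keep its gradient synchronized, which is what the returned `{stage: parameter}` mapping encodes.

```python
import torch.nn as nn

embed = nn.Embedding(50258, 256)
lm_head = nn.Linear(256, 50258, bias=False)
lm_head.weight = embed.weight  # weight tying, as tie_word_embeddings configs do

assert id(embed.weight) == id(lm_head.weight)
# stage 0 (embeddings) and the last stage (LM head) share one parameter
shared = [{0: embed.weight, 3: lm_head.weight}]  # assuming 4 pipeline stages
```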


class MistralForSequenceClassificationPolicy(MistralPolicy):
def module_policy(self):
@@ -227,9 +342,26 @@ def module_policy(self):
]
)
}
policy.update(new_item)

if self.pipeline_stage_manager:
warnings.warn("Mistral doesn't support pipeline parallelism now.")
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MistralForSequenceClassification,
new_forward=MistralForwards.mistral_for_sequence_classification_forward,
policy=policy,
)

policy.update(new_item)
return policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
return []
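With these policy changes, Mistral should be shardable with tensor and pipeline parallelism together. A hedged usage sketch — the plugin arguments are assumptions mirroring the test configs further down, and it presumes the distributed backend has already been launched (e.g. via `colossalai.launch_from_torch`):

```python
import torch
from transformers import MistralForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2, pp_size=2, num_microbatches=2, precision="fp16", initial_scale=1
)
booster = Booster(plugin=plugin)

model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer, *_ = booster.boost(model, optimizer)
```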
1 change: 1 addition & 0 deletions colossalai/shardformer/policies/opt.py
@@ -58,6 +58,7 @@ def module_policy(self):
policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -16,4 +16,4 @@ ray
sentencepiece
google
protobuf
transformers==4.36.0
transformers==4.36.2
3 changes: 3 additions & 0 deletions tests/kit/model_zoo/transformers/mistral.py
@@ -52,6 +52,9 @@ def data_gen_for_sequence_classification():
hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
)

if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id

model_zoo.register(
name="transformers_mistral",
model_fn=lambda: transformers.MistralModel(config),
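A hedged note on the `pad_token_id` lines above: `MistralForSequenceClassification` pools the last non-padding token of each sequence, so batched classification inputs need a valid pad id, and reusing `eos_token_id` is the usual convention. Roughly:

```python
import transformers

config = transformers.MistralConfig(
    hidden_size=256, intermediate_size=256, num_attention_heads=64,
    num_hidden_layers=2, vocab_size=50258,
)
# mirror the test-zoo fix so sequence-classification pooling has a pad token to key off
if getattr(config, "pad_token_id", None) is None:
    config.pad_token_id = config.eos_token_id
```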
21 changes: 19 additions & 2 deletions tests/test_shardformer/test_model/test_shard_mistral.py
@@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
@@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
@@ -156,7 +174,6 @@ def check_mistral(rank, world_size, port):
run_mistral_test()


@pytest.mark.skip("something wrong with pipeline parallelism")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
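For context, the re-enabled test is typically driven by spawning one process per GPU; a hedged sketch of the launcher (`check_mistral` comes from the diff, while the `spawn` helper is assumed to come from `colossalai.testing`):

```python
import pytest

from colossalai.testing import rerun_if_address_is_in_use, spawn


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mistral():
    # 4 ranks are enough to cover the tp/pp combinations parameterized above
    spawn(check_mistral, 4)
```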