4 changes: 2 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -308,7 +308,7 @@ def __init__(self,
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'

if enable_sequence_parallelism:
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'

self.tp_size = tp_size
self.pp_size = pp_size
@@ -422,7 +422,7 @@ def configure(
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.dp_size > 1, "Data parallel size should be greater than 1 when using ZeRO."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
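To see both corrected assertions in context, here is a minimal configuration sketch that satisfies them; the parameter names follow the HybridParallelPlugin constructor touched above, while the concrete values are illustrative assumptions:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Sequence parallelism requires tensor parallelism (tp_size > 1), and the ZeRO
# branch additionally expects dp_size > 1 together with fp16/bf16 precision.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=1,
                              zero_stage=1,
                              precision='fp16',
                              enable_sequence_parallelism=True)
booster = Booster(plugin=plugin)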
820 changes: 820 additions & 0 deletions colossalai/shardformer/modeling/gptj.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -74,6 +74,16 @@ class PolicyLocation:
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),

# GPTJ
"transformers.models.gptj.modeling_gptj.GPTJModel":
PolicyLocation(file_name="gptj", class_name="GPTJModelPolicy"),
"transformers.models.gptj.modeling_gptj.GPTJForCausalLM":
PolicyLocation(file_name="gptj", class_name="GPTJForCausalLMPolicy"),
"transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering":
PolicyLocation(file_name="gptj", class_name="GPTJForQuestionAnsweringPolicy"),
"transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification":
PolicyLocation(file_name="gptj", class_name="GPTJForSequenceClassificationPolicy"),

# ViT
"transformers.models.vit.modeling_vit.ViTModel":
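With these entries in the table, GPT-J model classes resolve to the new policies automatically. A minimal usage sketch, assuming the ShardConfig/ShardFormer entry points from colossalai.shardformer and a deliberately tiny GPTJConfig:

from transformers import GPTJConfig, GPTJForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

# Tiny config purely for illustration; a real checkpoint would be loaded with from_pretrained.
model = GPTJForCausalLM(GPTJConfig(n_layer=2, n_head=4, n_embd=256))
shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
# Without an explicit policy argument, the registrations above map GPTJForCausalLM
# to GPTJForCausalLMPolicy.
sharded_model, shared_params = shard_former.optimize(model)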
317 changes: 317 additions & 0 deletions colossalai/shardformer/policies/gptj.py
@@ -0,0 +1,317 @@
from functools import partial
from typing import Callable, Dict, List

from torch import Tensor, nn

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward, gptj_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'GPTJPolicy', 'GPTJModelPolicy', 'GPTJForCausalLMPolicy', 'GPTJForSequenceClassificationPolicy',
'GPTJForQuestionAnsweringPolicy'
]


class GPTJPolicy(Policy):

def config_sanity_check(self):
pass

def preprocess(self):
# reshape the embedding layer
r"""
Reshape the token embedding layer so that the vocabulary size is divisible by the tensor parallel world size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
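# Illustrative arithmetic for the resize above: vocab_size=50257 with
# tensor_parallel_size=4 is padded to new_vocab_size=50260, giving each rank an
# equal 12565-row slice of the embedding table.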

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel

policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPTJModel] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
])

policy[GPTJBlock] = ModulePolicyDescription(
attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.rotary_dim": self.model.config.rotary_dim // self.shard_config.tensor_parallel_size,
"attn.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
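# Illustrative numbers for GPT-J-6B with tensor_parallel_size=4: embed_dim
# 4096 -> 1024, rotary_dim 64 -> 16, num_attention_heads 16 -> 4 per rank.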
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.k_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attn.q_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attn.v_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="attn.out_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="mlp.fc_out",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])

# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
policy=policy,
target_key=GPTJModel)

self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
)
],
policy=policy,
target_key=GPTJBlock)

if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_gptj_flash_attention_forward(),
},
policy=policy,
target_key=GPTJAttention)

if self.shard_config.enable_sequence_parallelism:
policy[GPTJModel].method_replacement = {"forward": gptj_sequence_parallel_forward_fn(self.shard_config)}

return policy

def postprocess(self):
return self.model

def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None

if self.model.__class__.__name__ == 'GPTJModel':
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.wte)
# GPT-J has no wpe positional-embedding module; rotary position embeddings are applied inside attention.
held_layers.append(module.drop)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
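# Illustrative split: with GPT-J-6B's 28 blocks and 4 pipeline stages, each stage
# holds 7 consecutive blocks; the first stage additionally holds wte and drop, and
# the last stage additionally holds ln_f.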

def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'GPTJModel':
module = self.model
else:
module = self.model.transformer

layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

# GPTJModel
class GPTJModelPolicy(GPTJPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJModel

policy = super().module_policy()

if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPTJModel,
new_forward=GPTJPipelineForwards.gptj_model_forward,
policy=policy)
return policy

def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2Model."""
return []

# GPTJForCausalLM
class GPTJForCausalLMPolicy(GPTJPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPTJForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
policy.update(addon_module)
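# lm_head is column-sharded; gather_output=True gathers the partial outputs so every
# rank sees the full vocabulary logits for the unsharded loss computation.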

if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPTJForCausalLM,
new_forward=GPTJPipelineForwards.gptj_causallm_model_forward,
policy=policy)
return policy

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""The weights of wte and lm_head are shared when the token embeddings are tied."""
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
return []

# GPTJForSequenceClassification
class GPTJForSequenceClassificationPolicy(GPTJPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJForSequenceClassification

policy = super().module_policy()

if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPTJForSequenceClassification,
new_forward=GPTJPipelineForwards.gptj_for_sequence_classification_forward,
policy=policy)
return policy

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPTJForSequenceClassification."""
return []

# GPTJForQuestionAnswering
class GPTJForQuestionAnsweringPolicy(GPTJPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJForQuestionAnswering

policy = super().module_policy()

if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPTJForQuestionAnswering,
new_forward=GPTJPipelineForwards.gptj_for_question_answering_forward,
policy=policy)
return policy

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in GPT2ForQuestionAnswering."""
return []
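A policy can also be passed to ShardFormer explicitly instead of relying on the auto-policy table. A small sketch, assuming optimize accepts a policy instance and reusing shard_former and model from the usage example above:

from colossalai.shardformer.policies.gptj import GPTJForCausalLMPolicy

sharded_model, shared_params = shard_former.optimize(model, policy=GPTJForCausalLMPolicy())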
4 changes: 1 addition & 3 deletions examples/language/bert/finetune.py
@@ -273,7 +273,6 @@ def main():
"weight_decay": 0.0,
},
]

optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

# lr scheduler
@@ -308,10 +307,9 @@ def _criterion(outputs, inputs):
data_builder.eval_splits, booster, coordinator)

if coordinator.is_master():
print(results)
if args.target_f1 is not None and 'f1' in results:
assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'


if __name__ == '__main__':
main()
main()
2 changes: 1 addition & 1 deletion examples/language/bert/requirements.txt
@@ -6,4 +6,4 @@ tqdm
transformers
scipy
scikit-learn
ptflops
ptflops
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -4,6 +4,7 @@
from .bloom import *
from .chatglm2 import *
from .gpt import *
from .gptj import *
from .llama import *
from .opt import *
from .sam import *