2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/base_policy.py
@@ -191,7 +191,7 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:

# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
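
The one-line fix above changes where the `remainder` extra layers are centred: indexing by `num_stages // 2` keeps `start_position` inside `layers_per_stage`, which has only `num_stages` entries, whereas `num_layers // 2` could run past the end for deep models on few stages. A minimal, self-contained sketch of the corrected logic (assuming the even `num_layers // num_stages` split that the surrounding function computes before this hunk):

```python
from typing import List

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # even split first; the remainder is then centred on the middle stages
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    if remainder > 0:
        # num_stages // 2 stays within the num_stages-long list,
        # unlike the old num_layers // 2 starting index
        start_position = num_stages // 2 - remainder // 2
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage

assert distribute_layers(11, 4) == [2, 3, 3, 3]
assert distribute_layers(24, 5) == [5, 5, 5, 5, 4]
```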
156 changes: 147 additions & 9 deletions colossalai/shardformer/policies/bert.py
@@ -13,6 +13,8 @@
CausalLMOutputWithCrossAttentions,
)
from transformers.models.bert.modeling_bert import (
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForPreTrainingOutput,
BertLMHeadModel,
@@ -135,7 +137,6 @@ def module_policy(self):
],
policy=policy,
target_key=BertLayer)

# handle embedding layer
self.append_or_create_submodule_replacement(
description=[SubModuleReplacementDescription(
@@ -144,6 +145,7 @@ def module_policy(self):
)],
policy=policy,
target_key=BertEmbeddings)

return policy

def add_lm_head_policy(self, base_policy):
@@ -177,6 +179,15 @@ class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
if self.pipeline_stage_manager:
# set None as default
module_policy[BertModel] = ModulePolicyDescription(
method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)})
return module_policy

def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
module = self.model
@@ -444,6 +455,13 @@ def bert_model_forward(
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
else:
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape
@@ -466,14 +484,6 @@ def bert_model_forward(
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
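
The `token_type_ids` default above now runs only in the first-stage branch, the only place where `input_ids`/`inputs_embeds` (and hence `seq_length` and `device`) are available; later stages start from `hidden_states` instead. A tiny stand-alone reproduction of what that default does (`registered_ids` stands in for the model's `embeddings.token_type_ids` buffer):

```python
import torch

batch_size, seq_length = 2, 3
registered_ids = torch.zeros(1, 512, dtype=torch.long)    # stand-in for embeddings.token_type_ids
token_type_ids = registered_ids[:, :seq_length].expand(batch_size, seq_length)
assert token_type_ids.shape == (2, 3) and int(token_type_ids.sum()) == 0
```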
@@ -778,3 +788,131 @@ def bert_lmhead_forward(self: BertLMHeadModel,
hidden_states = outputs.get('hidden_states')
# intermediate stages always return a dict
return {'hidden_states': hidden_states}


def bert_for_masked_lm_forward(
self: BertForMaskedLM,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
):
#-> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
pass


def bert_for_next_sentence_prediction_forward(
self: BertForNextSentencePrediction,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
**kwargs,
):
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see `input_ids` docstring). Indices should be in `[0, 1]`:

- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.

Returns:

Example:

```python
>>> from transformers import AutoTokenizer, BertForNextSentencePrediction
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
```
"""

if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
" `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output)

next_sentence_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
# intermediate stages always return a dict
return {'hidden_states': hidden_states}
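
Across the forwards added above, the per-stage contract is the same: the first stage consumes `input_ids`/`attention_mask` and builds embeddings, intermediate stages receive and return `{'hidden_states': ...}`, and only the last stage runs the task head and builds the usual HF-style output. A toy, framework-free walkthrough of that contract (the stand-in stage manager below is illustrative, not the ColossalAI `PipelineStageManager` API):

```python
import torch

class ToyStageManager:
    def __init__(self, stage: int, num_stages: int = 2):
        self.stage, self.num_stages = stage, num_stages
    def is_first_stage(self) -> bool:
        return self.stage == 0
    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1

def toy_stage_forward(input_ids=None, hidden_states=None, stage_manager=None):
    if stage_manager.is_first_stage():
        # first stage: turn token ids into hidden states (stand-in for the embeddings)
        hidden_states = torch.randn(*input_ids.shape, 128)
    hidden_states = hidden_states + 1.0    # stand-in for this stage's share of encoder layers
    if stage_manager.is_last_stage():
        return (hidden_states,)                 # last stage: task output
    return {'hidden_states': hidden_states}     # intermediate stages hand a dict to the next stage

first = toy_stage_forward(input_ids=torch.randint(0, 1000, (2, 3)), stage_manager=ToyStageManager(0))
last = toy_stage_forward(hidden_states=first['hidden_states'], stage_manager=ToyStageManager(1))
assert first['hidden_states'].shape == (2, 3, 128) and last[0].shape == (2, 3, 128)
```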
4 changes: 3 additions & 1 deletion colossalai/shardformer/shard/sharder.py
@@ -1,3 +1,4 @@
from types import MethodType
from typing import Any, Callable, Dict, List, Union

import torch.nn as nn
@@ -134,7 +135,8 @@ def _replace_param(
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
for method_name, new_method in method_replacement.items():
# bind the new method to the module
setattr(module, method_name, new_method.__get__(module, module.__class__))
bound_method = MethodType(new_method, module)
setattr(module, method_name, bound_method)

def _replace_sub_module(
self,
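
A note on why the binding in `_replace_method` changed: the replacements the policies register are often `functools.partial` objects (e.g. `partial(bert_model_forward, stage_manager=...)` above), and partials are plain callables rather than descriptors, so the old `new_method.__get__(module, module.__class__)` lookup raises `AttributeError` for them, while `types.MethodType` binds any callable to the instance. A small self-contained check (toy class and function names, not ColossalAI code):

```python
from functools import partial
from types import MethodType

class Toy:
    def forward(self, x):
        return x

def replacement_forward(self, x, scale=1):
    return x * scale

toy = Toy()
repl = partial(replacement_forward, scale=3)

assert not hasattr(repl, '__get__')       # partial objects don't implement the descriptor protocol
toy.forward = MethodType(repl, toy)       # but MethodType can still bind them to the instance
assert toy.forward(2) == 6                # calls replacement_forward(toy, 2, scale=3)
```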
23 changes: 23 additions & 0 deletions tests/test_shardformer/test_model/_utils.py
@@ -2,6 +2,7 @@
from contextlib import nullcontext

from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer


@@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
return org_model.cuda(), sharded_model.cuda()


def build_pipeline_model(model_fn,
stage_manager=None,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
use_lazy_init: bool = False):
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)

# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()
85 changes: 85 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bert_pipeline.py
@@ -0,0 +1,85 @@
import pytest
import torch

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
pass


@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bert
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
x = torch.randint(0, 1000, (2, 3)).cuda()
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name == 'transformers_bert':
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)

if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
# print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
# print(output[0].shape)
assert output[0].shape == (2, 3, 128)

torch.cuda.empty_cache()


def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 4)


if __name__ == "__main__":
test_bert()
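
For reference, the hard-coded `RANK_TO_COORDINATE` and `PP_RANKS_IN_GROUP` tables in the test follow from a row-major 2 x 2 (DP x PP) layout; a quick plain-Python consistency check (no distributed init needed, and assuming that row-major convention):

```python
DP_SIZE, PP_SIZE = 2, 2
world_size = DP_SIZE * PP_SIZE

# rank -> (dp coordinate, pp coordinate), filling the PP axis fastest
rank_to_coordinate = {rank: divmod(rank, PP_SIZE) for rank in range(world_size)}
assert rank_to_coordinate == {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}

# ranks that share a DP coordinate form one pipeline group
pp_ranks_in_group = {rank: [r for r in range(world_size) if r // PP_SIZE == rank // PP_SIZE]
                     for rank in range(world_size)}
assert pp_ranks_in_group == {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}
```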