5 changes: 4 additions & 1 deletion colossalai/pipeline/p2p.py
@@ -64,7 +64,10 @@ def _broadcast_object_list(object_list: List[Any],
     my_rank = dist.get_rank()
     # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        if torch.__version__ >= "1.13.0":
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
+        else:
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
         object_sizes_tensor = torch.cat(size_list)
     else:
         object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
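A note on the version gate above: `torch.__version__ >= "1.13.0"` is a lexicographic string comparison, so a hypothetical `"1.9.0"` would sort *after* `"1.13.0"` and wrongly take the `device=` branch. A minimal sketch of a sturdier check, assuming `packaging` (which ships alongside pip) is an acceptable dependency; the helper name is illustrative:

```python
import torch
from packaging import version

def object_to_tensor_takes_device() -> bool:
    # The PR gates on torch >= 1.13.0, where c10d._object_to_tensor
    # accepts a `device` argument. Strip any local suffix such as
    # "+cu118" before parsing so the comparison stays numeric.
    return version.parse(torch.__version__.split("+")[0]) >= version.parse("1.13.0")
```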
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -205,7 +205,6 @@ def forward_backward_step(self,
         # the backward pass.
         input_obj = input_objs.pop(0)
         output_obj = output_objs.pop(0)
-
         input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)

         if last_iteration:
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -42,6 +42,8 @@ class PolicyLocation:
         PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
     "transformers.models.bert.modeling_bert.BertForMultipleChoice":
         PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
+    "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
+        PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),

     # LLaMA
     "transformers.models.llama.modeling_llama.LlamaModel":
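For context on what this registration buys: the mapping is keyed by the model's fully qualified class name and resolved lazily by import. The sketch below shows the idea under that assumption; `resolve_policy` and `_POLICY_LIST` here are illustrative stand-ins, not the actual ColossalAI API.

```python
import importlib
from dataclasses import dataclass

@dataclass
class PolicyLocation:
    file_name: str   # module under colossalai.shardformer.policies
    class_name: str  # policy class inside that module

# Illustrative registry entry matching the line added in this PR.
_POLICY_LIST = {
    "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
        PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
}

def resolve_policy(model) -> type:
    # Key the lookup by the model's fully qualified class name.
    key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[key]
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)
```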
832 changes: 731 additions & 101 deletions colossalai/shardformer/policies/bert.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions colossalai/shardformer/policies/llama.py
@@ -212,11 +212,13 @@ def get_held_layers(self) -> List[Module]:
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
         llama_model = self.model.model
         if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
             # tie weights
-            return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
+            return [{
+                0: llama_model.embed_tokens.weight,
+                self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
+            }]
         return []
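The return format here is worth spelling out: each dict maps a pipeline stage index to the parameter held on that stage, so when the input embedding and the LM head are tied, stage 0 and the last stage can keep their gradients in sync. A rough sketch of how such a mapping could be consumed — illustrative only, not ColossalAI's actual synchronization code, and `stage_to_group` is a hypothetical helper:

```python
from typing import Callable, Dict, List, Tuple

import torch
import torch.distributed as dist

def sync_shared_grads(shared_params: List[Dict[int, torch.Tensor]],
                      my_stage: int,
                      stage_to_group: Callable[[Tuple[int, ...]], dist.ProcessGroup]) -> None:
    for mapping in shared_params:
        if my_stage not in mapping:
            continue  # this rank holds no copy of the tied weight
        grad = mapping[my_stage].grad
        if grad is not None:
            # all-reduce across exactly the stages that hold a tied copy
            dist.all_reduce(grad, group=stage_to_group(tuple(sorted(mapping))))
```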
2 changes: 1 addition & 1 deletion tests/kit/model_zoo/torchrec/__init__.py
@@ -1 +1 @@
-from .torchrec import *
+#from .torchrec import *
17 changes: 17 additions & 0 deletions tests/kit/model_zoo/transformers/bert.py
@@ -87,6 +87,17 @@ def data_gen_for_mcq():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


+def data_gen_for_qa():
+    # generate data for question answering
+    # no labels are needed; start and end positions are used instead
+    data = data_gen()
+    start_positions = torch.tensor([0], dtype=torch.int64)
+    data['start_positions'] = start_positions
+    end_positions = torch.tensor([1], dtype=torch.int64)
+    data['end_positions'] = end_positions
+    return data
+
+
 # define output transform function
 output_transform_fn = lambda x: x
@@ -150,3 +161,9 @@ def data_gen_for_mcq():
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_question_answering',
+                   model_fn=lambda: transformers.BertForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_qa,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
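As a sanity check on the generator: when `start_positions`/`end_positions` are present, `BertForQuestionAnswering` computes its loss internally, which is why no `labels` key is needed. A small standalone sketch (the tiny config values are arbitrary, chosen only to keep the example fast):

```python
import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2,
                                 num_attention_heads=4, intermediate_size=256)
model = transformers.BertForQuestionAnswering(config)
batch = dict(input_ids=torch.tensor([[101, 7592, 102]]),
             attention_mask=torch.ones(1, 3, dtype=torch.long),
             start_positions=torch.tensor([0], dtype=torch.int64),
             end_positions=torch.tensor([1], dtype=torch.int64))
outputs = model(**batch)
print(outputs.loss)                # scalar cross-entropy over the start/end spans
print(outputs.start_logits.shape)  # torch.Size([1, 3])
```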
19 changes: 12 additions & 7 deletions tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py
@@ -7,6 +7,7 @@
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -35,25 +36,29 @@ def check_bert_for_pretraining_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)
+    layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)

     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_for_pretraining_forward(self=model,
-                                              input_ids=x,
-                                              attention_mask=attention_mask,
-                                              stage_manager=stage_manager)
-        print(output['hidden_states'].shape)
+        output = bert_for_pretraining_forward(
+            self=model,
+            input_ids=x,
+            attention_mask=attention_mask,
+            stage_manager=stage_manager,
+            stage_index=stage_index,
+        )
         assert output['hidden_states'].shape == (2, 3, 768)

     else:
         attention_mask = torch.ones((2, 3))
         output = bert_for_pretraining_forward(self=model,
                                               hidden_states=hidden_states,
                                               attention_mask=attention_mask,
-                                              stage_manager=stage_manager)
-        print(output[0].shape)
+                                              stage_manager=stage_manager,
+                                              stage_index=stage_index)
         assert output[0].shape == (2, 3, 30522)
         # assert output[1].shape == (2, 768)
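The two `Policy` helpers used by these updated tests split the encoder layers across pipeline stages and give each stage its `[start, end)` slice. A sketch of the semantics, assuming even division with any remainder pushed to later stages (the exact ColossalAI balancing may differ):

```python
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # e.g. 12 layers over 2 stages -> [6, 6]; 13 over 2 -> [6, 7] under this assumption
    quotient, remainder = divmod(num_layers, num_stages)
    return [quotient + (1 if stage >= num_stages - remainder else 0)
            for stage in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # half-open [start, end) range of layer indices owned by `stage`
    start = sum(layers_per_stage[:stage])
    return (start, start + layers_per_stage[stage])

assert distribute_layers(12, 2) == [6, 6]
assert get_stage_index([6, 6], 1) == (6, 12)
```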
(file header not rendered — next diff: the BertLMHeadModel pipeline test)
@@ -7,12 +7,13 @@
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lmhead_forward
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy, bert_lm_head_model_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn


-def check_bert_lmhead_forward():
+def check_bert_lm_head_model_forward():
     configuration = BertConfig()
     model = BertLMHeadModel(configuration)
     DP_DIM, PP_DIM = 0, 1
@@ -35,24 +36,28 @@ def check_bert_lmhead_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)

+    layers_per_stage = Policy.distribute_layers(len(model.bert.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_lmhead_forward(self=model,
-                                     input_ids=x,
-                                     attention_mask=attention_mask,
-                                     stage_manager=stage_manager)
-
+        output = bert_lm_head_model_forward(self=model,
+                                            input_ids=x,
+                                            attention_mask=attention_mask,
+                                            stage_manager=stage_manager,
+                                            stage_index=stage_index)
         print(output['hidden_states'].shape)
         assert output['hidden_states'].shape == (2, 3, 768)

     else:
         attention_mask = torch.ones((2, 3))
-        output = bert_lmhead_forward(self=model,
-                                     hidden_states=hidden_states,
-                                     attention_mask=attention_mask,
-                                     stage_manager=stage_manager)
+        output = bert_lm_head_model_forward(self=model,
+                                            hidden_states=hidden_states,
+                                            attention_mask=attention_mask,
+                                            stage_manager=stage_manager,
+                                            stage_index=stage_index)
         print(output[0].shape)
         assert output[0].shape == (2, 3, 30522)

@@ -93,7 +98,7 @@ def check_bert_lmhead_policy():

 def run_dist_model(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bert_lmhead_forward()
+    check_bert_lm_head_model_forward()


 def run_dist_policy(rank, world_size, port):
@@ -103,7 +108,7 @@ def run_dist_policy(rank, world_size, port):

 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_bert_lmhead_forward():
+def test_bert_lm_head_model_forward():
     spawn(run_dist_model, 4)

@@ -115,5 +120,5 @@ def test_bert_lmhead_policy():

 if __name__ == "__main__":
     """test the bert for pretraining model forward and bert for pretraining model policy"""
-    test_bert_lmhead_forward()
+    test_bert_lm_head_model_forward()
     test_bert_lmhead_policy()
16 changes: 11 additions & 5 deletions tests/test_pipeline/test_policy/test_bert_model.py
@@ -6,12 +6,14 @@
 import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.policies.bert import BertModelPolicy, bert_model_forward
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_bert_model_forward():
+    # this test may crash for internet reasons
     model = BertModel.from_pretrained('bert-base-uncased')
     DP_DIM, PP_DIM = 0, 1
     DP_SIZE, PP_SIZE = 2, 2
@@ -34,20 +36,25 @@ def check_bert_model_forward():
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
     # print(rank)
-
+    layers_per_stage = Policy.distribute_layers(len(model.encoder.layer), 2)
+    stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
     x = torch.randint(0, 1000, (2, 3))
     hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
     if stage_manager.stage == 0:
         attention_mask = torch.ones_like(x)
-        output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-        print(output['hidden_states'].shape)
+        output = bert_model_forward(self=model,
+                                    input_ids=x,
+                                    attention_mask=attention_mask,
+                                    stage_manager=stage_manager,
+                                    stage_index=stage_index)
         assert output['hidden_states'].shape == (2, 3, 768)
     else:
         attention_mask = torch.ones((2, 3))
         output = bert_model_forward(self=model,
                                     hidden_states=hidden_states,
                                     attention_mask=attention_mask,
-                                    stage_manager=stage_manager)
+                                    stage_manager=stage_manager,
+                                    stage_index=stage_index)
         print(output[0].shape)
         assert output[0].shape == (2, 3, 768)

@@ -112,4 +119,3 @@ def test_bert_model_policy():
     """test the bert model forward and bert model policy"""
     #test_bert_model_forward()
     test_bert_model_policy()
-    # this test need config to run
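On the `# this test may crash for internet reasons` caveat: `from_pretrained('bert-base-uncased')` hits the Hugging Face Hub unless the weights are already cached. One hedged way to make the test hermetic, assuming the checkpoint is present in the local cache:

```python
import os

# Fail fast with a clear error instead of attempting a download mid-test.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')  # served from the local cache
```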
1 change: 0 additions & 1 deletion tests/test_shardformer/test_model/_utils.py
@@ -49,7 +49,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     # prepare input
     data = data_gen_fn()
     data = {k: v.cuda() for k, v in data.items()}
-
     # switch to train mode
     original_model.train()
     sharded_model.train()