From a18fcd1d1ec5deb6264d7cc6451e9e4b6040a0c7 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 3 Aug 2023 17:55:09 +0800 Subject: [PATCH 1/8] fix remaining t5 bugs/rewrite t5 tests --- colossalai/shardformer/modeling/t5.py | 89 +++++------ colossalai/shardformer/policies/t5.py | 51 ++----- tests/kit/model_zoo/transformers/gpt.py | 3 + .../test_model/test_shard_gpt2.py | 2 + .../test_model/test_shard_t5.py | 143 ++++++++++++------ .../test_model/test_shard_t5_pipeline.py | 8 +- 6 files changed, 152 insertions(+), 144 deletions(-) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 0b3486e87c7e..eec0a60b9f87 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -261,6 +261,7 @@ def t5_model_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, stage_index: Optional[List[int]] = None, @@ -303,7 +304,6 @@ def t5_model_forward( decoder_head_mask = head_mask in_decoder = stage_manager.stage >= decoder_starting_stage - # Stage is in encoder, directly return the output of t5_stack_forward if not in_decoder: encoder_outputs = T5PipelineForwards.t5_stack_forward( @@ -323,25 +323,18 @@ def t5_model_forward( decoder_starting_stage=decoder_starting_stage) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_outputs': encoder_outputs} + return {'encoder_hidden_states': encoder_outputs[0]} else: return encoder_outputs at_last_decoder_stage = stage_manager.is_last_stage() at_first_decoder_stage = stage_manager.stage == decoder_starting_stage - if encoder_outputs is None: - raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.") - - encoder_hidden_states = encoder_outputs[0] - if return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") - # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in. if not at_first_decoder_stage and hidden_states is None: raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") @@ -360,6 +353,7 @@ def t5_model_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + stage_manager=stage_manager, hidden_states=hidden_states, position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, @@ -368,22 +362,19 @@ def t5_model_forward( # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: - decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states return decoder_outputs if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) + return decoder_outputs + encoder_hidden_states + else: + return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states) @staticmethod def t5_for_conditional_generation_forward( @@ -406,6 +397,7 @@ def t5_for_conditional_generation_forward( return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, stage_index: Optional[List[int]] = None, @@ -468,28 +460,25 @@ def t5_for_conditional_generation_forward( decoder_starting_stage=decoder_starting_stage) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_outputs': encoder_outputs} + return {'encoder_hidden_states': encoder_outputs[0]} else: return encoder_outputs at_last_decoder_stage = stage_manager.is_last_stage() at_first_decoder_stage = stage_manager.stage == decoder_starting_stage - if encoder_outputs is None: - raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.") + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") - encoder_hidden_states = encoder_outputs[0] - if return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in. if not at_first_decoder_stage and hidden_states is None: raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + # Decode decoder_outputs = T5PipelineForwards.t5_stack_forward( self.decoder, @@ -505,6 +494,7 @@ def t5_for_conditional_generation_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + stage_manager=stage_manager, hidden_states=hidden_states, position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, @@ -513,7 +503,8 @@ def t5_for_conditional_generation_forward( # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: - decoder_outputs['encoder_outputs'] = encoder_outputs # encoder_outputs should be passed to the next stage + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states return decoder_outputs sequence_output = decoder_outputs[0] @@ -533,20 +524,16 @@ def t5_for_conditional_generation_forward( loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) + return Seq2SeqLMOutput(loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states) @staticmethod def t5_encoder_model_forward( diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 5e78ae9093fa..2ef52c214c6b 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -260,7 +260,7 @@ def get_held_layers(self) -> List[nn.Module]: model = self.model encoder = self.model.encoder - decoder = self.model.__dict__.get('decoder', None) + decoder = getattr(self.model, 'decoder', None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 @@ -300,7 +300,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli stage_manager = self.pipeline_stage_manager encoder = self.model.encoder - decoder = self.model.__dict__.get('decoder', None) + decoder = getattr(self.model, 'decoder', None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 @@ -355,15 +355,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}] return [] - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]} - for k, v in binding_map.items(): - src = getattr_(self.model, k) - for dst in v: - setattr_(self.model, dst, src) - return self.model - class T5ForConditionalGenerationPolicy(T5BasePolicy): @@ -409,28 +400,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: stage_manager.num_stages) shared_params = [] + shared_embedding = {} if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): - shared_params.append({ - 0: module.shared.weight, - decoder_starting_stage: module.decoder.embed_tokens.weight - }) + shared_embedding[0] = module.shared.weight + shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight + if id(module.lm_head.weight) == id(module.shared.weight): - shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight}) - return shared_params - return [] + shared_embedding[0] = module.shared.weight + shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight - def postprocess(self): - super().postprocess() - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = { - "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] - } - for k, v in binding_map.items(): - src = getattr_(self.model, k) - for dst in v: - setattr_(self.model, dst, src) + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) - return self.model + return shared_params + + return [] class T5EncoderPolicy(T5BasePolicy): @@ -462,12 +446,3 @@ def get_held_layers(self) -> List[nn.Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] - - def postprocess(self): - if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: - binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]} - for k, v in binding_map.items(): - src = getattr_(self.model, k) - for dst in v: - setattr_(self.model, dst, src) - return self.model diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 5c3eb4438bc8..c9006d81001e 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -79,6 +79,9 @@ def data_gen_for_sequence_classification(): config_for_token_classification = copy.deepcopy(config) config_for_token_classification.num_labels = 2 +config_for_token_classification = copy.deepcopy(config) +config_for_token_classification.num_labels = 2 + # register the following models model_zoo.register(name='transformers_gpt', model_fn=lambda: transformers.GPT2Model(config), diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index f7213d8c50b4..456048556b8e 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -48,6 +48,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if org_model.__class__.__name__ == 'GPT2Model': gpt2 = org_model sharded_gpt2 = sharded_model.unwrap() + gpt2 = org_model + sharded_gpt2 = sharded_model.unwrap() else: gpt2 = org_model.transformer sharded_gpt2 = sharded_model.unwrap().transformer diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 22f04c879879..a4a4b0045d09 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,60 +1,103 @@ -import os - import pytest import torch import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - # the value "past_key_values" is sharded, so we ignore - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5) - - # do backward - org_loss.backward() - shard_loss.backward() - - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - - # check grad - col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared'] - row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias'] - check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) - - # check weights are tied - if hasattr(org_model, 'lm_head'): - assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr() - assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr() - - -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('use_lazy_init', [False, True]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention, - enable_jit_fused): +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) + + print(org_model.__class__.__name__, test_config['tp_size'], test_config['pp_size']) + + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + + if org_model.__class__.__name__ != 'T5ForConditionalGeneration': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3) + + check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3) + + # unwrap model + t5 = org_model + sharded_t5 = sharded_model.unwrap() + + # TODO: Check shared embedding layers + row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] + + # check weights and gradients + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize('test_config', [{'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, 'use_lazy_init': False}]) +# 'tp_size': 2, +# 'pp_size': 2, +# 'num_microbatches': 2, +# 'enable_fused_normalization': False, +# 'use_lazy_init': True +# }, { +# 'tp_size': 4, +# 'pp_size': 1, +# 'enable_fused_normalization': True, +# 'use_lazy_init': False +# }, { +# 'tp_size': 1, +# 'pp_size': 4, +# 'num_microbatches': 4, +# 'use_lazy_init': False +# } +@clear_cache_before_run() +def run_t5_test(test_config): + + # TODO: add plugin_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + # TODO: add test_config for flash attention & jit operator after supporting + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused, use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + if name != 'transformers_t5_encoder_model': + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -68,7 +111,7 @@ def check_t5(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_t5(): - spawn(check_t5, 2) + spawn(check_t5, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py index 7f3a5f2ea40b..dd1e99194d32 100644 --- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py @@ -43,7 +43,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ position_bias_shape = (batch_size, num_heads, seq_len, seq_len) num_encoder_layers = len(sharded_model.encoder.block) - decoder = sharded_model.__dict__.get('decoder', None) + decoder = getattr(sharded_model, 'decoder', None) num_decoder_layers = len(decoder.block) if decoder else 0 _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE) @@ -62,8 +62,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ inputs['position_bias'] = position_bias inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias if in_decoder: - encoder_output_states = torch.zeros(*hidden_state_shape).cuda() - inputs['encoder_outputs'] = (encoder_output_states,) + inputs['encoder_hidden_states'] = torch.zeros(*hidden_state_shape).cuda() sharded_model.train() output = sharded_model(**inputs) @@ -72,8 +71,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_ assert output.loss is not None else: if name != 'transformers_t5_encoder_model' and not in_decoder: - output = output['encoder_outputs'] - assert output[0].shape == hidden_state_shape + assert output['encoder_hidden_states'].shape == hidden_state_shape else: assert output['hidden_states'].shape == hidden_state_shape # position_bias information should be passed in T5 From a847d0be7d9bc5ad3ecc09b000fe362c9033c795 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 7 Aug 2023 14:02:47 +0800 Subject: [PATCH 2/8] fix multi-tensor communication in pipeline --- colossalai/pipeline/p2p.py | 6 +- colossalai/pipeline/schedule/one_f_one_b.py | 11 ++- colossalai/shardformer/modeling/t5.py | 6 +- .../test_model/test_shard_t5.py | 8 +- .../test_model/test_shard_t5_pipeline.py | 99 ------------------- 5 files changed, 24 insertions(+), 106 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_shard_t5_pipeline.py diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index f741b8363f13..af7a00b5c720 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -3,6 +3,7 @@ import io import pickle +import re from typing import Any, List, Optional, Union import torch @@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - if b'cuda' in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() - buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + # There might be more than one output tensors during forward + for cuda_str in re.finditer(b'cuda', buf_array): + pos = cuda_str.start() + buf_array[pos + 5] = 48 + device_index buf = bytes(buf_array) io_bytes = io.BytesIO(buf) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index d907d53edcde..ade3cf456fe3 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -107,8 +107,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], if output_obj_grad is None: optimizer.backward(output_obj) else: - for k, grad in output_obj_grad.items(): - optimizer.backward_by_grad(output_obj[k], grad) + if "backward_tensor_keys" not in output_obj: + for k, grad in output_obj_grad.items(): + optimizer.backward_by_grad(output_obj[k], grad) + else: + for k, grad in output_obj_grad.items(): + output_obj[k].grad = grad + for k in output_obj["backward_tensor_keys"]: + tensor_to_backward = output_obj[k] + optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad) # Collect the grad of the input_obj. input_obj_grad = None diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index eec0a60b9f87..d622da452366 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -238,7 +238,8 @@ def custom_forward(*inputs): return { 'hidden_states': hidden_states, 'position_bias': position_bias, - 'encoder_decoder_position_bias': encoder_decoder_position_bias + 'encoder_decoder_position_bias': encoder_decoder_position_bias, + 'backward_tensor_keys': ['hidden_states'] } @staticmethod @@ -264,6 +265,7 @@ def t5_model_forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: @@ -400,6 +402,7 @@ def t5_for_conditional_generation_forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: @@ -549,6 +552,7 @@ def t5_encoder_model_forward( hidden_states: Optional[torch.FloatTensor] = None, position_bias: Optional[torch.Tensor] = None, encoder_decoder_position_bias: Optional[torch.Tensor] = None, + backward_tensor_keys: Optional[List[str]] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index a4a4b0045d09..be2b3ce5aec2 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -21,7 +21,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - print(org_model.__class__.__name__, test_config['tp_size'], test_config['pp_size']) + # print(org_model.__class__.__name__, test_config['tp_size'], test_config['pp_size']) org_loss, org_output, sharded_loss, sharded_output = \ run_forward_backward_with_hybrid_plugin( @@ -48,7 +48,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, t5 = org_model sharded_t5 = sharded_model.unwrap() - # TODO: Check shared embedding layers row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] # check weights and gradients @@ -93,8 +92,11 @@ def run_t5_test(test_config): test_config['precision'] = 'float' # Do not use fp16/bf16 in testing for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - if name != 'transformers_t5_encoder_model': + + # skip 4-stage pp test for t5_encoder + if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model': continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py deleted file mode 100644 index dd1e99194d32..000000000000 --- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py +++ /dev/null @@ -1,99 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.t5 import T5BasePolicy -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_t5.py -def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - input_ids = inputs['input_ids'] - - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - - batch_size, seq_len = input_ids.shape - hidden_size = sharded_model.config.d_model - num_heads = sharded_model.config.num_heads - hidden_state_shape = (batch_size, seq_len, hidden_size) - position_bias_shape = (batch_size, num_heads, seq_len, seq_len) - - num_encoder_layers = len(sharded_model.encoder.block) - decoder = getattr(sharded_model, 'decoder', None) - num_decoder_layers = len(decoder.block) if decoder else 0 - - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE) - stage = stage_manager.stage - at_first_stage = (stage == 0) or (stage == decoder_starting_stage) - at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) - in_decoder = stage >= decoder_starting_stage - - if not at_first_stage: - # change inputs if not the first stage - hidden_states = torch.zeros(*hidden_state_shape).cuda() - position_bias = torch.zeros(*position_bias_shape).cuda() - encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - inputs['position_bias'] = position_bias - inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias - if in_decoder: - inputs['encoder_hidden_states'] = torch.zeros(*hidden_state_shape).cuda() - - sharded_model.train() - output = sharded_model(**inputs) - if at_last_stage: - if name == 'transformers_t5_for_conditional_generation' and in_decoder: - assert output.loss is not None - else: - if name != 'transformers_t5_encoder_model' and not in_decoder: - assert output['encoder_hidden_states'].shape == hidden_state_shape - else: - assert output['hidden_states'].shape == hidden_state_shape - # position_bias information should be passed in T5 - assert output['position_bias'].shape == position_bias_shape - if in_decoder: - assert output['encoder_decoder_position_bias'].shape == position_bias_shape - - torch.cuda.empty_cache() - - -def check_t5(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_t5_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_t5(): - spawn(check_t5, 4) - - -if __name__ == "__main__": - test_t5() From eeb59fa1ca8e290d340758986b2927086d9391c5 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 7 Aug 2023 14:09:22 +0800 Subject: [PATCH 3/8] rearrange test_config --- .../test_model/test_shard_gpt2.py | 6 +-- .../test_model/test_shard_t5.py | 39 +++++++++++-------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 456048556b8e..883b0eb7d024 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -71,15 +71,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() @parameterize('test_config', [{ - 'tp_size': 1, + 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, + 'enable_fused_normalization': True, 'use_lazy_init': True }, { - 'tp_size': 2, + 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, - 'enable_fused_normalization': False, 'use_lazy_init': False }, { 'tp_size': 4, diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index be2b3ce5aec2..19443f74b0ca 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -63,23 +63,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, 'use_lazy_init': False}]) -# 'tp_size': 2, -# 'pp_size': 2, -# 'num_microbatches': 2, -# 'enable_fused_normalization': False, -# 'use_lazy_init': True -# }, { -# 'tp_size': 4, -# 'pp_size': 1, -# 'enable_fused_normalization': True, -# 'use_lazy_init': False -# }, { -# 'tp_size': 1, -# 'pp_size': 4, -# 'num_microbatches': 4, -# 'use_lazy_init': False -# } +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_fused_normalization': True, + 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False +}]) @clear_cache_before_run() def run_t5_test(test_config): From 654ebdbb466140aebc6ff6339597c05969515342 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 7 Aug 2023 17:05:53 +0800 Subject: [PATCH 4/8] fix keyerror in sync_shared_params --- .../booster/plugin/hybrid_parallel_plugin.py | 6 ++- colossalai/pipeline/schedule/_utils.py | 2 +- tests/kit/model_zoo/transformers/gpt.py | 3 -- .../test_model/test_shard_gpt2.py | 3 +- .../test_model/test_shard_t5.py | 43 ++++++++++++------- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a22bdb7199bb..42942aaeb89d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -50,8 +50,10 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp def sync_shared_params(self): for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): - param = shared_param[self.stage_manager.stage] - dist.all_reduce(param.grad, group=group) + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() def no_sync(self) -> Iterator[None]: # no sync grads across data parallel diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 045c86e40e63..3ed9239272f1 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -86,7 +86,7 @@ def retain_grad(x: Any) -> None: Args: x (Any): Object to be called. """ - if isinstance(x, torch.Tensor): + if isinstance(x, torch.Tensor) and x.requires_grad: x.retain_grad() diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index c9006d81001e..5c3eb4438bc8 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -79,9 +79,6 @@ def data_gen_for_sequence_classification(): config_for_token_classification = copy.deepcopy(config) config_for_token_classification.num_labels = 2 -config_for_token_classification = copy.deepcopy(config) -config_for_token_classification.num_labels = 2 - # register the following models model_zoo.register(name='transformers_gpt', model_fn=lambda: transformers.GPT2Model(config), diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 883b0eb7d024..1882bf7822cc 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -48,8 +48,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if org_model.__class__.__name__ == 'GPT2Model': gpt2 = org_model sharded_gpt2 = sharded_model.unwrap() - gpt2 = org_model - sharded_gpt2 = sharded_model.unwrap() else: gpt2 = org_model.transformer sharded_gpt2 = sharded_model.unwrap().transformer @@ -70,6 +68,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() + @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 19443f74b0ca..ee76f93653bd 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -21,7 +21,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - # print(org_model.__class__.__name__, test_config['tp_size'], test_config['pp_size']) + print(org_model.__class__.__name__, test_config['tp_size'], test_config['pp_size']) org_loss, org_output, sharded_loss, sharded_output = \ run_forward_backward_with_hybrid_plugin( @@ -69,21 +69,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'num_microbatches': 2, 'enable_fused_normalization': True, 'use_lazy_init': True -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_fused_normalization': True, - 'use_lazy_init': False -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'use_lazy_init': False }]) @clear_cache_before_run() def run_t5_test(test_config): @@ -91,6 +76,32 @@ def run_t5_test(test_config): # TODO: add plugin_config for TP+DP after supporting & debugging it # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + # Completed: + # { + # 'tp_size': 1, + # 'pp_size': 2, + # 'num_microbatches': 4, + # 'use_lazy_init': False + # }, { + # 'tp_size': 4, + # 'pp_size': 1, + # 'enable_fused_normalization': True, + # 'use_lazy_init': False + # } + + # { + # 'tp_size': 2, + # 'pp_size': 2, + # 'num_microbatches': 2, + # 'enable_fused_normalization': True, + # 'use_lazy_init': True + # }, { + # 'tp_size': 1, + # 'pp_size': 4, + # 'num_microbatches': 4, + # 'use_lazy_init': False + # } + # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') From a449d0815fab1827e62203a11e91bf12ef01f9cc Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 8 Aug 2023 11:14:03 +0800 Subject: [PATCH 5/8] fix get_held_layers & Randomnizer, complete t5 tests --- colossalai/shardformer/layer/utils.py | 7 +++ colossalai/shardformer/policies/t5.py | 15 +++++-- .../test_model/test_shard_t5.py | 43 ++++++++----------- 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index f2ac6563c46f..09cb7bfe1407 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -122,6 +122,13 @@ def increment_index(): """ Randomizer._INDEX += 1 + @staticmethod + def reset_index(): + """ + Reset the index to zero. + """ + Randomizer._INDEX = 0 + @staticmethod def is_randomizer_index_synchronized(process_group: ProcessGroup = None): """ diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 2ef52c214c6b..b54c66444ab3 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -276,20 +276,27 @@ def get_held_layers(self) -> List[nn.Module]: if stage_manager.is_first_stage(): held_layers.append(model.shared) held_layers.append(encoder.embed_tokens) - held_layers.append(encoder.dropout) if stage_manager.stage == decoder_starting_stage - 1: held_layers.append(encoder.final_layer_norm) held_layers.append(encoder.dropout) - held_layers.extend(encoder.block[start_idx:end_idx]) + for i in range(start_idx, end_idx): + block = encoder.block[i] + held_layers.extend(block.layer) + held_layers.append(block.layer[0].SelfAttention) + held_layers.append(block.layer[1].DenseReluDense) else: # current stage is in t5's decoder if stage_manager.stage == decoder_starting_stage: held_layers.append(decoder.embed_tokens) - held_layers.append(decoder.dropout) if stage_manager.is_last_stage(): held_layers.append(decoder.final_layer_norm) held_layers.append(decoder.dropout) - held_layers.extend(decoder.block[start_idx:end_idx]) + for i in range(start_idx, end_idx): + block = decoder.block[i] + held_layers.extend(block.layer) + held_layers.append(block.layer[0].SelfAttention) + held_layers.append(block.layer[1].EncDecAttention) + held_layers.append(block.layer[2].DenseReluDense) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index ee76f93653bd..a2fd2b935559 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -3,6 +3,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -69,6 +70,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'num_microbatches': 2, 'enable_fused_normalization': True, 'use_lazy_init': True +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False }]) @clear_cache_before_run() def run_t5_test(test_config): @@ -76,32 +92,6 @@ def run_t5_test(test_config): # TODO: add plugin_config for TP+DP after supporting & debugging it # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} - # Completed: - # { - # 'tp_size': 1, - # 'pp_size': 2, - # 'num_microbatches': 4, - # 'use_lazy_init': False - # }, { - # 'tp_size': 4, - # 'pp_size': 1, - # 'enable_fused_normalization': True, - # 'use_lazy_init': False - # } - - # { - # 'tp_size': 2, - # 'pp_size': 2, - # 'num_microbatches': 2, - # 'enable_fused_normalization': True, - # 'use_lazy_init': True - # }, { - # 'tp_size': 1, - # 'pp_size': 4, - # 'num_microbatches': 4, - # 'use_lazy_init': False - # } - # TODO: add test_config for flash attention & jit operator after supporting sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') @@ -116,6 +106,7 @@ def run_t5_test(test_config): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() From d8bdf5a6df40d92ca90316bc5b38c7f9078f17e6 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 8 Aug 2023 11:36:23 +0800 Subject: [PATCH 6/8] erase printing --- tests/test_shardformer/test_model/test_shard_t5.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index a2fd2b935559..d807ffa06296 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -22,8 +22,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - print(org_model.__class__.__name__, test_config['tp_size'], test_config['pp_size']) - org_loss, org_output, sharded_loss, sharded_output = \ run_forward_backward_with_hybrid_plugin( org_model, From c5cf2ef24a7b3441c81632ee2db7b2bd5b66e2b0 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 8 Aug 2023 16:24:12 +0800 Subject: [PATCH 7/8] fix get_held_layers through modifying _release_unheld_layers --- colossalai/shardformer/policies/t5.py | 15 ++++----------- colossalai/shardformer/shard/sharder.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index b54c66444ab3..2ef52c214c6b 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -276,27 +276,20 @@ def get_held_layers(self) -> List[nn.Module]: if stage_manager.is_first_stage(): held_layers.append(model.shared) held_layers.append(encoder.embed_tokens) + held_layers.append(encoder.dropout) if stage_manager.stage == decoder_starting_stage - 1: held_layers.append(encoder.final_layer_norm) held_layers.append(encoder.dropout) - for i in range(start_idx, end_idx): - block = encoder.block[i] - held_layers.extend(block.layer) - held_layers.append(block.layer[0].SelfAttention) - held_layers.append(block.layer[1].DenseReluDense) + held_layers.extend(encoder.block[start_idx:end_idx]) else: # current stage is in t5's decoder if stage_manager.stage == decoder_starting_stage: held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.dropout) if stage_manager.is_last_stage(): held_layers.append(decoder.final_layer_norm) held_layers.append(decoder.dropout) - for i in range(start_idx, end_idx): - block = decoder.block[i] - held_layers.extend(block.layer) - held_layers.append(block.layer[0].SelfAttention) - held_layers.append(block.layer[1].EncDecAttention) - held_layers.append(block.layer[2].DenseReluDense) + held_layers.extend(decoder.block[start_idx:end_idx]) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ae8cd8c6e553..8ee991141eaf 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -198,6 +198,18 @@ def _replace_sub_module( setattr_(org_layer, suffix, replace_layer) + def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]: + + def collect_sub_modules(module: nn.Module): + recursive_held_layers.append(module) + for name, child in module.named_children(): + collect_sub_modules(child) + + recursive_held_layers = [] + for module in held_layers: + collect_sub_modules(module) + return recursive_held_layers + def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: r""" Release the unheld layers in the model @@ -205,7 +217,7 @@ def _release_unheld_layers(self) -> Optional[Set[nn.Module]]: if self.shard_config and self.shard_config.pipeline_stage_manager: held_layers = self.policy.get_held_layers() set_tensors_to_none(self.model, exclude=set(held_layers)) - return set(held_layers) + return set(self._get_recursive_held_layers(held_layers)) return None def _materialize(self) -> None: From 1ff4679c57d21c3c5e53fb95f7da24c60417174a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 8 Aug 2023 17:14:31 +0800 Subject: [PATCH 8/8] fix _get_recursive_held_layers bug --- colossalai/shardformer/shard/sharder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 8ee991141eaf..0ed745a1fc4a 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -201,6 +201,8 @@ def _replace_sub_module( def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]: def collect_sub_modules(module: nn.Module): + if module is None: + return recursive_held_layers.append(module) for name, child in module.named_children(): collect_sub_modules(child)