-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[shardformer] shardformer support opt models #4091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
FrankLeeeee
merged 4 commits into
hpcaitech:feature/shardformer
from
flybird11111:feature/shardformer
Jun 27, 2023
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
c187562
[shardformer] shardformer support opt models
flybird11111 684cb96
[shardformer] shardformer support opt models, fix
flybird11111 e3036a6
[shardformer] shardformer support opt models, fix
flybird11111 af122c5
[shardformer] shardformer support opt models, fix
flybird11111 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| from transformers.models.opt.modeling_opt import ( | ||
| OPTAttention, | ||
| OPTDecoder, | ||
| OPTDecoderLayer, | ||
| OPTForCausalLM, | ||
| OPTForSequenceClassification, | ||
| ) | ||
|
|
||
| from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row | ||
|
|
||
| from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription | ||
|
|
||
|
|
||
class OPTPolicy(Policy):
    """Base shard policy for HuggingFace OPT models.

    Tensor-parallelizes the token embedding, the decoder-layer MLP
    (fc1/fc2) and the attention projections, and optionally swaps
    LayerNorms for the fused kernel.
    """

    def preprocess(self):
        """Pad the token embedding so the vocabulary size is divisible by the
        tensor-parallel world size, then return the (possibly resized) model."""
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        remainder = vocab_size % world_size
        if remainder != 0:
            # Round the vocabulary up to the next multiple of world_size.
            self.model.resize_token_embeddings(vocab_size + world_size - remainder)
        return self.model

    def module_policy(self):
        """Return the mapping from OPT module classes to their replacement descriptions."""
        tp_size = self.shard_config.tensor_parallel_size
        config = self.model.config

        decoder_desc = ModulePolicyDescription(attribute_replacement={},
                                               param_replacement=[],
                                               sub_module_replacement=[
                                                   SubModuleReplacementDescription(
                                                       suffix="embed_tokens",
                                                       target_module=Embedding1D,
                                                   )
                                               ])

        layer_desc = ModulePolicyDescription(attribute_replacement={},
                                             param_replacement=[],
                                             sub_module_replacement=[
                                                 SubModuleReplacementDescription(
                                                     suffix="fc1",
                                                     target_module=Linear1D_Col,
                                                 ),
                                                 SubModuleReplacementDescription(
                                                     suffix="fc2",
                                                     target_module=Linear1D_Row,
                                                 )
                                             ])

        # q/k/v are split column-wise; the output projection is row-parallel
        # so the attention output is reduced back to the full hidden size.
        attn_desc = ModulePolicyDescription(
            attribute_replacement={
                "embed_dim": config.hidden_size // tp_size,
                "num_heads": config.num_attention_heads // tp_size,
            },
            param_replacement=[],
            sub_module_replacement=[
                SubModuleReplacementDescription(suffix=proj_name, target_module=proj_module)
                for proj_name, proj_module in (
                    ("q_proj", Linear1D_Col),
                    ("k_proj", Linear1D_Col),
                    ("v_proj", Linear1D_Col),
                    ("out_proj", Linear1D_Row),
                )
            ])

        base_policy = {
            OPTDecoder: decoder_desc,
            OPTDecoderLayer: layer_desc,
            OPTAttention: attn_desc,
        }

        if self.shard_config.fused_layernorm:
            # Some checkpoints may lack these norm modules, hence ignore_if_not_exist.
            decoder_desc.sub_module_replacement.append(
                SubModuleReplacementDescription(suffix="final_layer_norm",
                                                target_module=FusedLayerNorm,
                                                ignore_if_not_exist=True))
            layer_desc.sub_module_replacement.extend(
                SubModuleReplacementDescription(suffix=norm_suffix,
                                                target_module=FusedLayerNorm,
                                                ignore_if_not_exist=True)
                for norm_suffix in ("self_attn_layer_norm", "final_layer_norm"))
        return base_policy

    def new_model_class(self):
        """OPT needs no custom model class; keep the original."""
        return None

    def postprocess(self):
        """No post-sharding fixup is required; return the model unchanged."""
        return self.model
|
|
||
|
|
||
class OPTModelPolicy(OPTPolicy):
    """Policy for the bare ``OPTModel``.

    The base :class:`OPTPolicy` behavior is sufficient, so no overrides are
    needed; the previous no-op ``__init__`` (which only called ``super()``)
    has been removed as redundant boilerplate.
    """
|
|
||
|
|
||
class OPTForCausalLMPolicy(OPTPolicy):
    """Policy for ``OPTForCausalLM``: additionally shards the LM head."""

    def module_policy(self):
        """Extend the base policy with a column-parallel ``lm_head``.

        ``gather_output=True`` collects the split logits so callers still see
        the full vocabulary dimension.
        """
        policy = super().module_policy()
        lm_head_replacement = SubModuleReplacementDescription(suffix="lm_head",
                                                              target_module=Linear1D_Col,
                                                              kwargs=dict(gather_output=True))
        policy[OPTForCausalLM] = ModulePolicyDescription(attribute_replacement={},
                                                         param_replacement=[],
                                                         sub_module_replacement=[lm_head_replacement])
        return policy
|
|
||
|
|
||
class OPTForSequenceClassificationPolicy(OPTPolicy):
    """Policy for ``OPTForSequenceClassification``.

    Inherits the full base policy unchanged; the redundant pass-through
    ``__init__`` has been dropped.
    """
|
|
||
|
|
||
class OPTForQuestionAnsweringPolicy(OPTPolicy):
    """Policy for ``OPTForQuestionAnswering``.

    Inherits the full base policy unchanged; the redundant pass-through
    ``__init__`` has been dropped.
    """
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| import copy | ||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| import colossalai | ||
| from colossalai.logging import disable_existing_loggers | ||
| from colossalai.testing import ( | ||
| assert_hf_output_close, | ||
| check_state_dict_equal, | ||
| clear_cache_before_run, | ||
| rerun_if_address_is_in_use, | ||
| spawn, | ||
| ) | ||
| from tests.kit.model_zoo import model_zoo | ||
| from tests.test_shardformer.test_model._utils import build_model, run_forward | ||
|
|
||
# Silence transformers' advisory warnings so test output stays readable.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
|
||
|
|
||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Run forward/backward on the original and sharded models and compare
    outputs, losses, and the gathered ``q_proj`` weight gradient.

    Must be called from within an initialized ``torch.distributed`` process
    group (each rank holds one shard of the column-parallel weight).
    """
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)

    # run backward
    org_loss.backward()
    shard_loss.backward()

    # Unwrap task heads (e.g. OPTForCausalLM.model) down to the inner OPT model.
    if hasattr(org_model, 'model'):
        opt_model = org_model.model
        shard_opt_model = sharded_model.model
    else:
        opt_model = org_model
        shard_opt_model = sharded_model

    # check grad: q_proj is column-parallel, so each rank holds a row-shard.
    org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
    shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad

    # all_gather fills shard_grad_list in place; its return value is NOT the
    # gathered tensor, so it must not overwrite `shard_grad` (the original code
    # did, which also corrupted the assertion message below). Use the actual
    # world size instead of a hard-coded 4, and zeros_like to keep dtype/device.
    world_size = torch.distributed.get_world_size()
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
|
||
|
|
||
def check_OPTModel(rank, world_size, port):
    """Spawned worker: initialize the colossalai distributed context, then run
    the forward/backward check on every registered OPT model variant."""
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    registry = model_zoo.get_sub_registry('transformers_opt')
    for entry_name, entry in registry.items():
        model_fn, data_gen_fn, output_transform_fn, loss_fn, _ = entry
        original, sharded = build_model(world_size, model_fn)
        check_forward_backward(original, sharded, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()
|
|
||
|
|
||
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_OPTModel():
    # Spawn 4 worker processes; the sharding checks assume tensor-parallel size 4.
    spawn(check_OPTModel, 4)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.