From 7ca2b752dbe1d64b89c2197c18bb6ca0134f5882 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 27 Nov 2023 15:27:34 +0800 Subject: [PATCH] style: polish code and use packaging.version --- colossalai/shardformer/layer/normalization.py | 2 +- colossalai/shardformer/modeling/mistral.py | 18 ++++----- colossalai/shardformer/policies/bloom.py | 8 ++-- colossalai/shardformer/policies/falcon.py | 8 ++-- colossalai/shardformer/policies/mistral.py | 15 +++++--- colossalai/shardformer/policies/opt.py | 8 ++-- colossalai/shardformer/policies/whisper.py | 8 ++-- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/mistral.py | 15 ++++---- .../test_model/test_shard_mistral.py | 38 ++++++++++++++----- 10 files changed, 74 insertions(+), 47 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 81033c429fc7..f1aeb9ab8466 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -110,7 +110,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: LazyInitContext.materialize(module) # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm - if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]: + if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]: normalized_shape = module.weight.shape[0] eps = module.variance_epsilon elementwise_affine = True diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 02bd72ea3dda..1ddb26c25d5c 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -1,14 +1,6 @@ -import warnings -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.utils import logging def get_mistral_flash_attention_forward(): @@ -29,8 +21,12 @@ def forward( assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = ( + self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + ) kv_seq_len = key_states.shape[-2] if past_key_value is not None: diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 1a611bffeea0..7fc2e6719c3e 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -24,9 +24,11 @@ class BloomPolicy(Policy): def __init__(self) -> None: super().__init__() import transformers - version_string = transformers.__version__ - version_tuple = tuple(map(int, version_string.split('.')[:3])) - assert version_tuple <= (4, 33, 0), "The Bloom model should run on a transformers version not greater than 4.33.0." 
+        from packaging.version import Version
+
+        assert Version(transformers.__version__) <= Version(
+            "4.33.0"
+        ), "The Bloom model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 660b19984103..f2eeb9d69c81 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -22,9 +22,11 @@ class FalconPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
         import transformers
-        version_string = transformers.__version__
-        version_tuple = tuple(map(int, version_string.split('.')[:3]))
-        assert version_tuple <= (4, 33, 0), "The Falcon model should run on a transformers version not greater than 4.33.0."
+        from packaging.version import Version
+
+        assert Version(transformers.__version__) <= Version(
+            "4.33.0"
+        ), "The Falcon model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index f3c7066c0515..c16aa6deab3b 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -1,10 +1,7 @@
 import warnings
-from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import Dict, Union

 import torch.nn as nn
-from torch import Tensor
-from torch.nn import Module

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

@@ -37,13 +34,16 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
-            warnings.warn("Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+            warnings.warn(
+                "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+ ) if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.num_key_value_heads": self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, } policy[MistralDecoderLayer] = ModulePolicyDescription( @@ -129,6 +129,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def postprocess(self): return self.model + class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: super().__init__() @@ -139,6 +140,7 @@ def module_policy(self): return super().module_policy() + class MistralForCausalLMPolicy(MistralPolicy): def module_policy(self): from transformers import MistralForCausalLM @@ -164,6 +166,7 @@ def module_policy(self): return policy + class MistralForSequenceClassificationPolicy(MistralPolicy): def module_policy(self): from transformers import MistralForSequenceClassification diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9123e1c29d8d..fd39eacc8c60 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -25,9 +25,11 @@ class OPTPolicy(Policy): def __init__(self) -> None: super().__init__() import transformers - version_string = transformers.__version__ - version_tuple = tuple(map(int, version_string.split('.')[:3])) - assert version_tuple <= (4, 33, 0), "The OPT model should run on a transformers version not greater than 4.33.0." + from packaging.version import Version + + assert Version(transformers.__version__) <= Version( + "4.33.0" + ), "The OPT model should run on a transformers version not greater than 4.33.0." def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 962f348ed7ac..3065ef62eb9d 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -29,9 +29,11 @@ class WhisperPolicy(Policy): def __init__(self) -> None: super().__init__() import transformers - version_string = transformers.__version__ - version_tuple = tuple(map(int, version_string.split('.')[:3])) - assert version_tuple <= (4, 33, 0), "The Whisper model should run on a transformers version not greater than 4.33.0." + from packaging.version import Version + + assert Version(transformers.__version__) <= Version( + "4.33.0" + ), "The Whisper model should run on a transformers version not greater than 4.33.0." 
def config_sanity_check(self): pass diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index a9933a18aa22..be6d92f012a9 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -12,6 +12,7 @@ from .t5 import * from .vit import * from .whisper import * + try: from .mistral import * except ImportError: diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index 8fa09424cc24..37f87585759e 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -1,14 +1,14 @@ import torch import transformers +from transformers import MistralConfig from ..registry import ModelAttribute, model_zoo -from transformers import MistralConfig - # =============================== # Register single-sentence Mistral # =============================== + def data_gen(): # Generated from following code snippet # @@ -18,10 +18,11 @@ def data_gen(): # tokenized_input = tokenizer([input], return_tensors="pt") # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64) + input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) + def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` @@ -29,12 +30,14 @@ def data_gen_for_lm(): data["labels"] = data["input_ids"].clone() return data + def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() data["labels"] = torch.tensor([1], dtype=torch.int64) return data + # define output transform function output_transform_fn = lambda x: x @@ -46,11 +49,7 @@ def data_gen_for_sequence_classification(): loss_fn_for_seq_classification = lambda output: output.logits.mean() config = MistralConfig( - hidden_size=256, - intermediate_size=256, - num_attention_heads=64, - num_hidden_layers=2, - vocab_size=50258 + hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258 ) model_zoo.register( diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 0cd9c1cd2bb4..07bc91b33b72 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -50,10 +50,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 row_layer_grads = get_grad_tensors_for_check( - mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + mistral_model, + shard_mistral_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) col_layer_grads = get_grad_tensors_for_check( - mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + mistral_model, + shard_mistral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -81,7 +95,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, 
loss_fn, else: atol, rtol = 5e-3, 5e-3 check_weight( - mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + mistral_model, + shard_mistral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) # check grads @@ -101,10 +122,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp32", }, { - "tp_size": 2, - "pp_size": 1, - "enable_all_optimization": True, - "use_lazy_init": False, + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, "precision": "fp32", }, { @@ -135,7 +156,6 @@ def check_mistral(rank, world_size, port): run_mistral_test() - @pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") @pytest.mark.dist @rerun_if_address_is_in_use() @@ -145,4 +165,4 @@ def test_mistral(): if __name__ == "__main__": - test_mistral() \ No newline at end of file + test_mistral()
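
Background on the packaging.version change: the manual parse removed by this patch, tuple(map(int, version_string.split('.')[:3])), handles plain "X.Y.Z" strings but raises ValueError on PEP 440 pre-release tags, while packaging.version.Version orders them correctly. A minimal standalone sketch, illustrative only (the version strings below are made-up examples, not versions referenced by this patch):

from packaging.version import Version

def old_check(version_string: str) -> bool:
    # The pattern removed by this patch: fine for plain "X.Y.Z" strings,
    # but "4.33.0rc1".split('.') yields ['4', '33', '0rc1'] and
    # int('0rc1') raises ValueError.
    version_tuple = tuple(map(int, version_string.split('.')[:3]))
    return version_tuple <= (4, 33, 0)

def new_check(version_string: str) -> bool:
    # Version implements PEP 440 ordering, so rc and dev builds compare sanely.
    return Version(version_string) <= Version("4.33.0")

print(new_check("4.33.0"))       # True
print(new_check("4.33.0rc1"))    # True: rc1 precedes the final 4.33.0 release
print(new_check("4.34.0.dev0"))  # False: a 4.34 dev build is already newer than 4.33.0
# old_check("4.33.0rc1") would raise ValueError

This robustness to pre-release tags is the practical motivation for switching the asserts in the bloom, falcon, opt, and whisper policies.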
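
On the skipped mistral test: the skip reason states the test needs a transformers version not less than 4.35.2. One possible follow-up, sketched here as an alternative and not something this patch does (test_mistral_conditional is a hypothetical name, and the real test also carries the dist and rerun decorators omitted here), would be to enforce that gate with pytest.mark.skipif and the same Version comparison:

import pytest
import transformers
from packaging.version import Version

# Hypothetical alternative to the unconditional skip: run the test whenever
# the installed transformers is new enough, skip otherwise.
@pytest.mark.skipif(
    Version(transformers.__version__) < Version("4.35.2"),
    reason="This test should be run on a version of transformers not less than 4.35.2.",
)
def test_mistral_conditional():
    ...  # test body unchanged from test_mistral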