From bbc47d3a104e1527bec502a33d400647b48dca07 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 24 Nov 2023 15:09:38 +0800 Subject: [PATCH 1/2] [shardformer] add tests to mistral fix fix --- .../shardformer/policies/auto_policy.py | 9 ++ colossalai/shardformer/policies/mistral.py | 14 ++ requirements/requirements-test.txt | 2 +- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/mistral.py | 77 +++++++++ .../test_model/test_shard_mistral.py | 148 ++++++++++++++++++ 6 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 tests/kit/model_zoo/transformers/mistral.py create mode 100644 tests/test_shardformer/test_model/test_shard_mistral.py diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 85ad7ce71fb0..ee7a62f30301 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -174,6 +174,15 @@ class PolicyLocation: "transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation( file_name="falcon", class_name="FalconForQuestionAnsweringPolicy" ), + "transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation( + file_name="mistral", class_name="MistralModelPolicy" + ), + "transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation( + file_name="mistral", class_name="MistralForCausalLMPolicy" + ), + "transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation( + file_name="mistral", class_name="MistralForSequenceClassificationPolicy" + ), } _INFER_POLICY_LIST = { diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index e1552f7c6d82..f3c7066c0515 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -133,6 +133,12 @@ class MistralModelPolicy(MistralPolicy): def __init__(self) -> None: 
super().__init__() + def module_policy(self): + if self.pipeline_stage_manager: + warnings.warn("Mistral doesn't support pipeline parallelism now.") + + return super().module_policy() + class MistralForCausalLMPolicy(MistralPolicy): def module_policy(self): from transformers import MistralForCausalLM @@ -150,6 +156,10 @@ def module_policy(self): ] ) } + + if self.pipeline_stage_manager: + warnings.warn("Mistral doesn't support pipeline parallelism now.") + policy.update(new_item) return policy @@ -171,5 +181,9 @@ def module_policy(self): ] ) } + + if self.pipeline_stage_manager: + warnings.warn("Mistral doesn't support pipeline parallelism now.") + policy.update(new_item) return policy diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 467f83610eb0..e5551707bc2a 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,7 +4,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers==4.33.0 +transformers timm titans torchaudio diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index d8d2c14b65d2..0cc7baf61f20 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -12,3 +12,4 @@ from .t5 import * from .vit import * from .whisper import * +from .mistral import * diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py new file mode 100644 index 000000000000..8dea515b2c88 --- /dev/null +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -0,0 +1,77 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Mistral +# =============================== + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoModelForCausalLM, AutoTokenizer + # tokenizer =
AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement) + # tokenized_input = tokenizer([input], return_tensors="pt") + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss +loss_fn_for_seq_classification = lambda output: output.logits.mean() + +config = transformers.MistralConfig( + hidden_size=256, + intermediate_size=256, + num_attention_heads=64, + num_hidden_layers=2, + vocab_size=50258 +) + +model_zoo.register( + name="transformers_mistral", + model_fn=lambda: transformers.MistralModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mistral_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_mistral_for_casual_lm", + model_fn=lambda: transformers.MistralForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) 
+model_zoo.register( + name="transformers_mistral_for_sequence_classification", + model_fn=lambda: transformers.MistralForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py new file mode 100644 index 000000000000..28f6800829e9 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -0,0 +1,148 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + mistral_model = unwrap_model(org_model, "MistralModel", "model") + shard_mistral_model = 
unwrap_model(sharded_model, "MistralModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + # row_layer_for_check = ["layers[0].self_attn.q_proj"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 5e-5, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "MistralModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + 
"use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_mistral_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_mistral(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_mistral_test() + + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_mistral(): + spawn(check_mistral, 4) + + +if __name__ == "__main__": + test_mistral() \ No newline at end of file From 29f67036f1931340e1109ed4211cfa4755c59779 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 24 Nov 2023 17:25:22 +0800 Subject: [PATCH 2/2] fix fix fix fix fix --- colossalai/shardformer/policies/bloom.py | 7 +++++++ colossalai/shardformer/policies/falcon.py | 7 +++++++ colossalai/shardformer/policies/opt.py | 7 +++++++ colossalai/shardformer/policies/whisper.py | 7 +++++++ requirements/requirements-test.txt | 2 +- tests/kit/model_zoo/transformers/__init__.py | 5 ++++- tests/kit/model_zoo/transformers/mistral.py | 4 +++- tests/test_shardformer/test_model/test_shard_mistral.py | 2 +- 8 files changed, 37 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 13b9dd31345d..1a611bffeea0 100644 --- a/colossalai/shardformer/policies/bloom.py +++ 
b/colossalai/shardformer/policies/bloom.py @@ -21,6 +21,13 @@ class BloomPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + version_string = transformers.__version__ + version_tuple = tuple(map(int, version_string.split('.')[:3])) + assert version_tuple <= (4, 33, 0), "The Bloom model should run on a transformers version not greater than 4.33.0." + def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 0c0c6ed6d68f..660b19984103 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -19,6 +19,13 @@ class FalconPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + version_string = transformers.__version__ + version_tuple = tuple(map(int, version_string.split('.')[:3])) + assert version_tuple <= (4, 33, 0), "The Falcon model should run on a transformers version not greater than 4.33.0." + def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 5739d21a3903..9123e1c29d8d 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -22,6 +22,13 @@ class OPTPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + version_string = transformers.__version__ + version_tuple = tuple(map(int, version_string.split('.')[:3])) + assert version_tuple <= (4, 33, 0), "The OPT model should run on a transformers version not greater than 4.33.0." 
+ def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index d9af2461cdb8..962f348ed7ac 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -26,6 +26,13 @@ class WhisperPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + version_string = transformers.__version__ + version_tuple = tuple(map(int, version_string.split('.')[:3])) + assert version_tuple <= (4, 33, 0), "The Whisper model should run on a transformers version not greater than 4.33.0." + def config_sanity_check(self): pass diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e5551707bc2a..467f83610eb0 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -4,7 +4,7 @@ pytest coverage==7.2.3 git+https://github.com/hpcaitech/pytest-testmon torchvision -transformers +transformers==4.33.0 timm titans torchaudio diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 0cc7baf61f20..a9933a18aa22 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -12,4 +12,7 @@ from .t5 import * from .vit import * from .whisper import * -from .mistral import * +try: + from .mistral import * +except ImportError: + print("This version of transformers doesn't support mistral.") diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index 8dea515b2c88..8fa09424cc24 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -3,6 +3,8 @@ from ..registry import ModelAttribute, model_zoo +from transformers import MistralConfig + # =============================== # Register single-sentence Mistral # =============================== @@ -43,7 +45,7 @@ def 
data_gen_for_sequence_classification(): loss_fn = lambda x: x.loss loss_fn_for_seq_classification = lambda output: output.logits.mean() -config = transformers.MistralConfig( +config = MistralConfig( hidden_size=256, intermediate_size=256, num_attention_heads=64, diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 28f6800829e9..0cd9c1cd2bb4 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -40,7 +40,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model") row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] - # row_layer_for_check = ["layers[0].self_attn.q_proj"] col_layer_for_check = ["layers[0].self_attn.o_proj"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. @@ -137,6 +136,7 @@ def check_mistral(rank, world_size, port): +@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run()