diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 85ad7ce71fb0..ee7a62f30301 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -174,6 +174,15 @@ class PolicyLocation:
     "transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
         file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
     ),
+    "transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
+        file_name="mistral", class_name="MistralModelPolicy"
+    ),
+    "transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation(
+        file_name="mistral", class_name="MistralForCausalLMPolicy"
+    ),
+    "transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
+        file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
+    ),
 }

 _INFER_POLICY_LIST = {
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 13b9dd31345d..1a611bffeea0 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -21,6 +21,13 @@


 class BloomPolicy(Policy):
+    def __init__(self) -> None:
+        super().__init__()
+        import transformers
+        version_string = transformers.__version__
+        version_tuple = tuple(map(int, version_string.split('.')[:3]))
+        assert version_tuple <= (4, 33, 0), "The Bloom policy requires a transformers version no newer than 4.33.0."
+
     def config_sanity_check(self):
         pass
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index 0c0c6ed6d68f..660b19984103 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -19,6 +19,13 @@


 class FalconPolicy(Policy):
+    def __init__(self) -> None:
+        super().__init__()
+        import transformers
+        version_string = transformers.__version__
+        version_tuple = tuple(map(int, version_string.split('.')[:3]))
+        assert version_tuple <= (4, 33, 0), "The Falcon policy requires a transformers version no newer than 4.33.0."
+
     def config_sanity_check(self):
         pass
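The int-parsing guard above crashes on pre-release version strings such as "4.33.0rc1" (int('0rc1') raises ValueError). A minimal, more robust sketch using the `packaging` library, which is already a transformers dependency; the helper name is illustrative and not part of this PR:

    from packaging import version

    import transformers

    def assert_transformers_at_most(max_version: str = "4.33.0") -> None:
        # packaging handles pre-release suffixes such as '4.33.0rc1',
        # which int() parsing of each dot-separated component does not.
        installed = version.parse(transformers.__version__)
        assert installed <= version.parse(max_version), (
            f"transformers {installed} is too new; this policy needs <= {max_version}"
        )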
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index e1552f7c6d82..f3c7066c0515 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -133,6 +133,12 @@ class MistralModelPolicy(MistralPolicy):
     def __init__(self) -> None:
         super().__init__()

+    def module_policy(self):
+        if self.pipeline_stage_manager:
+            warnings.warn("Mistral doesn't support pipeline parallelism yet.")
+
+        return super().module_policy()
+

 class MistralForCausalLMPolicy(MistralPolicy):
     def module_policy(self):
         from transformers import MistralForCausalLM
@@ -150,6 +156,10 @@ def module_policy(self):
                 ]
             )
         }
+
+        if self.pipeline_stage_manager:
+            warnings.warn("Mistral doesn't support pipeline parallelism yet.")
+
         policy.update(new_item)
         return policy
@@ -171,5 +181,9 @@ def module_policy(self):
                 ]
             )
         }
+
+        if self.pipeline_stage_manager:
+            warnings.warn("Mistral doesn't support pipeline parallelism yet.")
+
         policy.update(new_item)
         return policy
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 5739d21a3903..9123e1c29d8d 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -22,6 +22,13 @@


 class OPTPolicy(Policy):
+    def __init__(self) -> None:
+        super().__init__()
+        import transformers
+        version_string = transformers.__version__
+        version_tuple = tuple(map(int, version_string.split('.')[:3]))
+        assert version_tuple <= (4, 33, 0), "The OPT policy requires a transformers version no newer than 4.33.0."
+
     def config_sanity_check(self):
         pass
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index d9af2461cdb8..962f348ed7ac 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -26,6 +26,13 @@


 class WhisperPolicy(Policy):
+    def __init__(self) -> None:
+        super().__init__()
+        import transformers
+        version_string = transformers.__version__
+        version_tuple = tuple(map(int, version_string.split('.')[:3]))
+        assert version_tuple <= (4, 33, 0), "The Whisper policy requires a transformers version no newer than 4.33.0."
+
     def config_sanity_check(self):
         pass
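All three Mistral policies only warn, rather than raise, when a pipeline stage manager is present, so tensor parallelism is the supported sharding mode. A hedged usage sketch of the configuration this implies, assuming colossalai's Booster/HybridParallelPlugin API:

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Shard Mistral with tensor parallelism only; pp_size > 1 would merely
    # trigger the warning above without enabling pipeline scheduling.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, *_ = booster.boost(model, optimizer, criterion)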
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index d8d2c14b65d2..a9933a18aa22 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -12,3 +12,7 @@
 from .t5 import *
 from .vit import *
 from .whisper import *
+try:
+    from .mistral import *
+except ImportError:
+    print("This version of transformers doesn't support Mistral.")
diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py
new file mode 100644
index 000000000000..8fa09424cc24
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/mistral.py
@@ -0,0 +1,79 @@
+import torch
+import transformers
+from transformers import MistralConfig
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence Mistral
+# ===============================
+
+
+def data_gen():
+    # Generated from the following code snippet:
+    #
+    # from transformers import AutoTokenizer
+    # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    # input = 'My favourite condiment is vinegar'  # last two words repeated to satisfy the length requirement
+    # tokenized_input = tokenizer([input], return_tensors="pt")
+    # input_ids = tokenized_input['input_ids']
+    # attention_mask = tokenized_input['attention_mask']
+    input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_lm():
+    # LM data gen
+    # For causal LM the labels are the output tokens; since there is no padding,
+    # `input_ids` can be reused as `labels`.
+    data = data_gen()
+    data["labels"] = data["input_ids"].clone()
+    return data
+
+
+def data_gen_for_sequence_classification():
+    # sequence classification data gen
+    data = data_gen()
+    data["labels"] = torch.tensor([1], dtype=torch.int64)
+    return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# define loss functions
+loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss(
+    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+)
+loss_fn = lambda x: x.loss
+loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+config = MistralConfig(
+    hidden_size=256,
+    intermediate_size=256,
+    num_attention_heads=64,
+    num_hidden_layers=2,
+    vocab_size=50258,
+)
+
+model_zoo.register(
+    name="transformers_mistral",
+    model_fn=lambda: transformers.MistralModel(config),
+    data_gen_fn=data_gen,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn_for_mistral_model,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_mistral_for_causal_lm",
+    model_fn=lambda: transformers.MistralForCausalLM(config),
+    data_gen_fn=data_gen_for_lm,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
+model_zoo.register(
+    name="transformers_mistral_for_sequence_classification",
+    model_fn=lambda: transformers.MistralForSequenceClassification(config),
+    data_gen_fn=data_gen_for_sequence_classification,
+    output_transform_fn=output_transform_fn,
+    loss_fn=loss_fn_for_seq_classification,
+    model_attribute=ModelAttribute(has_control_flow=True),
+)
diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py
new file mode 100644
index 000000000000..0cd9c1cd2bb4
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_mistral.py
@@ -0,0 +1,148 @@
+import os
+
+import pytest
+import torch
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_all_grad_tensors,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    get_grad_tensors_for_check,
+    run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
+)
+
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+        model_fn, loss_fn, test_config
+    )
+
+    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+    )
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # unwrap model
+    mistral_model = unwrap_model(org_model, "MistralModel", "model")
+    shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model")
+
+    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
+    col_layer_for_check = ["layers[0].self_attn.o_proj"]
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        if test_config["precision"] == "fp32":
+            atol, rtol = 5e-5, 1e-4
+        else:
+            atol, rtol = 5e-3, 5e-3
+        row_layer_grads = get_grad_tensors_for_check(
+            mistral_model, shard_mistral_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+        )
+        col_layer_grads = get_grad_tensors_for_check(
+            mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+        )
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # run the optimizer step on both models
+    org_optimizer.step()
+    sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+
+        if org_model.__class__.__name__ == "MistralModel":
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
+    if stage_manager is None or stage_manager.is_first_stage():
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-4, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        check_weight(
+            mistral_model, shard_mistral_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+        )
+
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+    ],
+)
+def run_mistral_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+def check_mistral(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_mistral_test()
+
+
+@pytest.mark.skip("This test requires transformers >= 4.35.2.")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_mistral():
+    spawn(check_mistral, 4)
+
+
+if __name__ == "__main__":
+    test_mistral()
\ No newline at end of file
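Once the entries above are in the policy table, auto-policy lookup should resolve the Mistral classes without further wiring. A hedged sketch of that lookup, assuming colossalai's get_autopolicy helper keeps its usual signature and the table is named _POLICY_LIST as shown in the first hunk:

    import transformers
    from colossalai.shardformer.policies.auto_policy import get_autopolicy

    # tiny config mirroring the one registered in the model zoo above
    config = transformers.MistralConfig(
        hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2
    )
    model = transformers.MistralForCausalLM(config)
    policy = get_autopolicy(model)  # resolves to MistralForCausalLMPolicy via _POLICY_LIST

The test itself can be launched directly (it spawns 4 processes) once the skip marker is removed:

    python tests/test_shardformer/test_model/test_shard_mistral.py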