From 53195df15e1f2f4a292708b4c1acb86f42959f73 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 31 Jul 2024 03:51:30 +0000 Subject: [PATCH 1/5] lora support hybrid plugin --- .../booster/plugin/hybrid_parallel_plugin.py | 24 ++++++++++++++++--- tests/test_lora/test_lora.py | 4 ++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 2c8cb6ba1e93..bdc24c3a4ed1 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -30,6 +30,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy @@ -1194,7 +1195,7 @@ def support_no_sync(self) -> bool: return True def support_lora(self) -> bool: - return False + return True def control_checkpoint_io(self) -> bool: return True @@ -1422,6 +1423,23 @@ def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() def enable_lora( - self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + self, + model: Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, ) -> Module: - raise NotImplementedError + from peft import PeftModel, get_peft_model + + assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + self.lora_enabled = True + warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index b8daf775db0e..aac09771252f 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -9,7 +9,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_checkpoint_io.utils import shared_tempdir @@ -20,7 +20,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type model = model_fn() lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin()] test_configs = [ { "lora_config": lora_config, From e0720920f18406de09a4530675b148b025519ab7 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 31 Jul 2024 05:28:06 +0000 Subject: [PATCH 2/5] fix --- tests/test_lora/test_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index aac09771252f..cafecdd4e726 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -20,7 +20,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type model = model_fn() lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin()] + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)] test_configs = [ { "lora_config": lora_config, From bd03bb06be84874bffeeae4921170af74415ea16 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 31 Jul 2024 06:54:06 +0000 Subject: [PATCH 3/5] fix --- tests/test_lora/test_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index cafecdd4e726..b8daf775db0e 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -9,7 +9,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_checkpoint_io.utils import shared_tempdir @@ -20,7 +20,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type model = model_fn() lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)] + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] test_configs = [ { "lora_config": lora_config, From 652dbfdcd76396e4fb6bbd4d6d956d370654ec81 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 1 Aug 2024 06:37:03 +0000 Subject: [PATCH 4/5] fix --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bdc24c3a4ed1..d1b2392b944e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1432,6 +1432,7 @@ def enable_lora( from peft import PeftModel, get_peft_model assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + assert self.pp_size == 1 and self.tp_size == 1 self.lora_enabled = True warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") From d4b9496bbc869da139038a0ed2461f5a939936de Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 1 Aug 2024 11:09:52 +0000 Subject: [PATCH 5/5] fix --- .../checkpoint_io/hybrid_parallel_checkpoint_io.py | 14 ++++++++++++++ colossalai/shardformer/policies/auto_policy.py | 3 +++ tests/test_lora/test_lora.py | 7 +++++-- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b7097e432a1d..0310df5489b0 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -947,3 +947,17 @@ def shard_from_complete_optimizer_state( state_[k] = v.detach().clone().to(device) return state_ + + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index ae9f3603c96e..3f1d7c35383c 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -240,6 +240,9 @@ def _fullname(obj): # patch custom models which are not in transformers # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub) # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory) + if module.startswith("peft"): + klass = obj.base_model.model.__class__ + module = klass.__module__ if module.startswith("transformers_modules"): split_module = module.split(".") if len(split_module) >= 2: diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index b8daf775db0e..1ae17025d31e 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -9,7 +9,8 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_checkpoint_io.utils import shared_tempdir @@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type model = model_fn() lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) - test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()] + test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)] test_configs = [ { "lora_config": lora_config, @@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type # test fwd bwd correctness test_model = model_load + if isinstance(model_load, HybridParallelModule): + model_load = model_load.module.module model_copy = copy.deepcopy(model_load) data = data_gen_fn()