25 changes: 22 additions & 3 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -30,6 +30,7 @@
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
@@ -1194,7 +1195,7 @@ def support_no_sync(self) -> bool:
return True

def support_lora(self) -> bool:
-        return False
+        return True

def control_checkpoint_io(self) -> bool:
return True
@@ -1422,6 +1423,24 @@ def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

def enable_lora(
-        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> Module:
-        raise NotImplementedError
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
+        assert self.pp_size == 1 and self.tp_size == 1
+        self.lora_enabled = True
+        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+
+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
+        if pretrained_dir is None:
+            peft_model = get_peft_model(model, lora_config)
+        else:
+            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+        return peft_model
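
For reviewers, a minimal usage sketch of the new path, assuming the `Booster` front end delegates `enable_lora` to this plugin the same way the test below drives it (the toy model, optimizer, and hyperparameters are placeholders, not part of this PR):

```python
import torch
import torch.nn as nn
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# assumes the distributed environment is already up, e.g. via colossalai.launch_from_torch
plugin = HybridParallelPlugin(tp_size=1, pp_size=1)  # the LoRA path asserts tp_size == pp_size == 1
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["0", "2"])

# enable_lora must run before boost(), per the assert in the diff above
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```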
14 changes: 14 additions & 0 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -947,3 +947,17 @@ def shard_from_complete_optimizer_state(
state_[k] = v.detach().clone().to(device)

return state_
+
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
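
The matching save path, sketched under the assumption that `Booster.save_lora_as_pretrained` forwards to this checkpoint IO method (the checkpoint directory and `base_model` are placeholders):

```python
# continues the sketch above; writes only the adapter weights, in PEFT's
# standard pretrained-directory layout, since it calls peft's save_pretrained
booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)

# a later run can then restore the adapters by passing the directory
# to enable_lora instead of a fresh LoraConfig
model = booster.enable_lora(base_model, pretrained_dir="./lora_ckpt")  # base_model: placeholder
```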
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -240,6 +240,9 @@ def _fullname(obj):
# patch custom models which are not in transformers
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+    if module.startswith("peft"):
+        klass = obj.base_model.model.__class__
+        module = klass.__module__
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
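
This works because of PEFT's wrapper chain: a `PeftModel` holds a `LoraModel` at `.base_model`, whose `.model` attribute is the original transformers module, so the policy lookup can key off the inner class instead of the `peft` wrapper. A quick illustration:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

print(type(peft_model).__module__)                   # "peft.peft_model" -> starts with "peft"
print(type(peft_model.base_model.model).__module__)  # "transformers.models.gpt2.modeling_gpt2"
```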
7 changes: 5 additions & 2 deletions tests/test_lora/test_lora.py
@@ -9,7 +9,8 @@

import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

-    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
+    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
test_configs = [
{
"lora_config": lora_config,
@@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type

# test fwd bwd correctness
test_model = model_load
+    if isinstance(model_load, HybridParallelModule):
+        model_load = model_load.module.module
model_copy = copy.deepcopy(model_load)

data = data_gen_fn()