57 changes: 57 additions & 0 deletions colossalai/booster/booster.py
@@ -8,6 +8,14 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

SUPPORT_PEFT = False
try:
import peft

SUPPORT_PEFT = True
except ImportError:
pass

import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer)

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
) -> nn.Module:
"""
Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, LoRA configs and weights are loaded from that directory.
LoRA in ColossalAI is implemented with the Hugging Face peft library, so the arguments for LoRA configuration are the same as those of peft.

Args:
model (nn.Module): The model to be appended with LoRA modules.
pretrained_dir (str, optional): The path to the pretrained directory; it can be a local directory
or the model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
When set to None, new LoRA configs and weights are created for the model using the passed-in lora_config. Defaults to None.
lora_config (peft.LoraConfig, optional): The LoraConfig passed to peft. Defaults to None.
"""
if not SUPPORT_PEFT:
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")

assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
if pretrained_dir is None:
assert (
lora_config is not None
), "Please provide configuration for Lora when pretrained directory path isn't passed in."
assert isinstance(
lora_config, peft.LoraConfig
), "The passed in configuration should be an instance of peft.LoraConfig."
if lora_config is None:
assert (
pretrained_dir is not None
), "Please provide pretrained directory path if not passing in lora configuration."
return self.plugin.enable_lora(model, pretrained_dir, lora_config)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.

@@ -323,3 +363,20 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
checkpoint (str): Path to the checkpoint. It must be a local file path.
"""
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)

def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.

Args:
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
"""
if not SUPPORT_PEFT:
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
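For reference, a minimal sketch of the training flow these two Booster methods enable, assuming TorchDDPPlugin (the only plugin reporting LoRA support in this change); the model, hyperparameters, and checkpoint path below are illustrative placeholders, not part of the diff:

# LoRA workflow sketch: enable adapters before boosting, train, then save only the adapters.
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Distributed initialization (colossalai.launch_from_torch) is assumed to have run already.
booster = Booster(plugin=TorchDDPPlugin())

model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

# 1. Inject LoRA adapters BEFORE boosting; the plugin asserts the model is not yet wrapped.
model = booster.enable_lora(model, lora_config=lora_config)

# 2. Boost as usual; only the LoRA parameters require gradients.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop: forward, booster.backward(loss, optimizer), optimizer.step() ...

# 3. Persist only the adapters plus the adapter config in the peft "pretrained" layout.
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter", use_safetensors=True)

# Later, the saved adapters can be re-attached instead of creating a fresh config:
# model = booster.enable_lora(base_model, pretrained_dir="ckpt/lora_adapter")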
10 changes: 9 additions & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -360,6 +360,9 @@ def __init__(
def support_no_sync(self) -> bool:
return False

def support_lora(self) -> bool:
return False

def control_precision(self) -> bool:
return True

@@ -408,3 +411,8 @@ def get_checkpoint_io(self) -> CheckpointIO:

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError
10 changes: 9 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -3,7 +3,7 @@
from contextlib import nullcontext
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union

import numpy as np
import torch
@@ -753,6 +753,9 @@ def control_precision(self) -> bool:
def support_no_sync(self) -> bool:
return False

def support_lora(self) -> bool:
return False

def control_checkpoint_io(self) -> bool:
return True

@@ -891,3 +894,8 @@ def get_checkpoint_io(self) -> CheckpointIO:

def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError

def enable_lora(
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
raise NotImplementedError
10 changes: 9 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,7 +3,7 @@
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -295,6 +295,9 @@ def __init__(
def support_no_sync(self) -> bool:
return self.stage == 1

def support_lora(self) -> bool:
return False

def control_precision(self) -> bool:
return True

@@ -336,3 +339,8 @@ def get_checkpoint_io(self) -> CheckpointIO:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.no_sync()

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError
12 changes: 11 additions & 1 deletion colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch.nn as nn
from torch.optim import Optimizer
@@ -33,6 +33,10 @@ def control_device(self) -> bool:
def support_no_sync(self) -> bool:
pass

@abstractmethod
def support_lora(self) -> bool:
pass

@abstractmethod
def configure(
self,
@@ -63,6 +67,12 @@ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[Non
Context manager to disable gradient synchronization.
"""

@abstractmethod
def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
"""
Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
"""

@abstractmethod
def prepare_dataloader(
self,
32 changes: 31 additions & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -116,6 +116,22 @@ def load_sharded_optimizer(
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)

def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to the checkpoint directory.
"""
from peft import PeftModel

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)


class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
@@ -173,6 +189,9 @@ def __init__(
def support_no_sync(self) -> bool:
return True

def support_lora(self) -> bool:
return True

def control_precision(self) -> bool:
return False

@@ -216,3 +235,14 @@ def get_checkpoint_io(self) -> CheckpointIO:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
from peft import PeftModel, get_peft_model

assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
if pretrained_dir is None:
return get_peft_model(model, lora_config)
else:
return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
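The two branches above map directly onto the peft API; a standalone sketch (the model name and directory are placeholders) of what each branch produces:

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

pretrained_dir = None  # or a directory previously written by save_lora_as_pretrained
base = AutoModelForCausalLM.from_pretrained("gpt2")

if pretrained_dir is None:
    # Fresh, randomly initialised adapters are built from the given LoraConfig.
    peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", r=8))
else:
    # Previously saved adapters are loaded and kept trainable so training can resume.
    peft_model = PeftModel.from_pretrained(base, pretrained_dir, is_trainable=True)

peft_model.print_trainable_parameters()  # only the LoRA matrices require grad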
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,6 +1,6 @@
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -190,7 +190,10 @@ def __init__(
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

def support_no_sync(self) -> bool:
False
return False

def support_lora(self) -> bool:
return False

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
@@ -235,3 +238,8 @@ def control_checkpoint_io(self) -> bool:

def get_checkpoint_io(self) -> CheckpointIO:
return TorchFSDPCheckpointIO()

def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError
17 changes: 17 additions & 0 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -327,3 +327,20 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
state_dict = torch.load(checkpoint)
lr_scheduler.load_state_dict(state_dict)

# ================================================================================
# Abstract method for lora saving implementation.
# ================================================================================

@abstractmethod
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.

Args:
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
"""
3 changes: 3 additions & 0 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -228,3 +228,6 @@ def load_sharded_model(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)

def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
raise NotImplementedError
3 changes: 2 additions & 1 deletion requirements/requirements-test.txt
@@ -7,7 +7,7 @@ torchvision
transformers==4.33.0
timm
titans
torchaudio
torchaudio>=0.13.1
torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
torchrec==0.2.0
contexttimer
@@ -18,4 +18,5 @@ SentencePiece
ninja
flash_attn==2.0.5
datasets
peft
#auto-gptq now not support torch1.12