Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from chatgpt.models.lora import LoraLinear
from torch.optim import Optimizer


from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
Expand Down Expand Up @@ -143,7 +147,7 @@ def _unwrap_actor(actor: Actor) -> nn.Module:
return model.module
return model

def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self._unwrap_model(model)
# TODO : better way to get torch model from gemini model
# to get torch model from gemini model
Expand All @@ -159,10 +163,16 @@ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> N
module.merge_weights=True
module.eval()
# get state_dict and save
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)

if not isinstance(self.model, PreTrainedModel):
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
else:
self.model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)

def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0:
Expand Down
2 changes: 1 addition & 1 deletion applications/ChatGPT/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ tqdm
datasets
loralib
colossalai>=0.2.4
torch
torch==1.12.1
langchain