Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions applications/Chat/coati/trainer/strategies/ddp.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from typing import Optional

import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import Actor
from coati.models.base import LM, Actor, RewardModel
from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .base import Strategy
from .naive import NaiveStrategy
Expand Down Expand Up @@ -72,17 +75,32 @@ def _unwrap_actor(actor: Actor) -> nn.Module:
model: DDP = Strategy._unwrap_actor(actor)
return model.module

def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    """Save ``model`` (and optionally its tokenizer) to ``path``.

    Args:
        model: the module to checkpoint. LoRA layers are merged into the
            base weights before saving so the checkpoint loads without
            LoRA-aware code.
        path: destination. For HF-style models this is a directory passed
            to ``save_pretrained``; otherwise it is a ``torch.save`` target.
        only_rank0: when True, only the distributed rank-0 process saves;
            every other rank returns immediately.
        tokenizer: if given, saved alongside a ``save_pretrained`` model.
    """
    # Filter non-zero ranks up front; everything below runs on rank 0 only
    # (or on all ranks when only_rank0 is False). The previous version
    # repeated this check before each torch.save, which was dead code.
    if only_rank0 and dist.get_rank() != 0:
        return None

    # Fold LoRA low-rank updates back into the base weights.
    for module in model.modules():
        if isinstance(module, LoraLinear):
            module.merge_weights = True
            module.eval()

    if isinstance(model, RewardModel):
        # Reward models have no HF save_pretrained; dump a raw state dict.
        torch.save(model.state_dict(), path)
    else:
        try:
            if isinstance(model, LM):
                model = model.model
            # NOTE(review): unlike the previous implementation this no longer
            # unwraps a DDP wrapper (`model.model.module`) — assumes callers
            # pass the unwrapped module here; confirm at the call sites.
            model.save_pretrained(path)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)
        except AttributeError:
            # Not an HF PreTrainedModel: fall back to a raw state dict.
            torch.save(model.state_dict(), path)

def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
Expand Down
27 changes: 23 additions & 4 deletions applications/Chat/coati/trainer/strategies/naive.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from coati.replay_buffer import ReplayBuffer
from coati.models.base import LM, RewardModel
from coati.models.lora import LoraLinear
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .base import Strategy

Expand Down Expand Up @@ -38,9 +41,25 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)

def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    """Persist ``model`` (and optionally ``tokenizer``) to ``path``.

    LoRA layers are put in eval mode with ``merge_weights`` set so their
    low-rank updates are folded into the base weights before saving.
    Reward models are stored as a raw ``state_dict``; other models are
    saved via HF ``save_pretrained`` when available, with a ``state_dict``
    fallback. ``only_rank0`` is accepted for interface parity and unused
    in the single-process strategy.
    """
    # Merge LoRA adapters into the base weights before serialization.
    lora_layers = [m for m in model.modules() if isinstance(m, LoraLinear)]
    for layer in lora_layers:
        layer.merge_weights = True
        layer.eval()

    if isinstance(model, RewardModel):
        torch.save(model.state_dict(), path)
        return

    # LM wraps the underlying HF model in `.model`.
    target = model.model if isinstance(model, LM) else model
    try:
        target.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)
    except AttributeError:
        # No save_pretrained available: fall back to a raw state dict.
        torch.save(target.state_dict(), path)

def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
unwrapped_model = self._unwrap_model(model)
Expand Down