Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion applications/Chat/coati/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss

__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
__all__ = [
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
]
3 changes: 1 addition & 2 deletions applications/Chat/coati/models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .actor import Actor
from .critic import Critic
from .lm import LM
from .reward_model import RewardModel

__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
__all__ = ['Actor', 'Critic', 'RewardModel']
30 changes: 0 additions & 30 deletions applications/Chat/coati/models/base/lm.py

This file was deleted.

3 changes: 1 addition & 2 deletions applications/Chat/coati/models/bloom/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_lm import BLOOMLM
from .bloom_rm import BLOOMRM

__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
38 changes: 0 additions & 38 deletions applications/Chat/coati/models/bloom/bloom_lm.py

This file was deleted.

3 changes: 1 addition & 2 deletions applications/Chat/coati/models/gpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_lm import GPTLM
from .gpt_rm import GPTRM

__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
38 changes: 0 additions & 38 deletions applications/Chat/coati/models/gpt/gpt_lm.py

This file was deleted.

3 changes: 1 addition & 2 deletions applications/Chat/coati/models/llama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_lm import LlamaLM
from .llama_rm import LlamaRM

__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
40 changes: 0 additions & 40 deletions applications/Chat/coati/models/llama/llama_lm.py

This file was deleted.

22 changes: 18 additions & 4 deletions applications/Chat/coati/models/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,23 @@ def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
convert_to_lora_recursively(child, lora_rank)


def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module in place.

    Recursively replaces eligible layers via `convert_to_lora_recursively`,
    then freezes all non-LoRA parameters with `lora.mark_only_lora_as_trainable`.
    The input module itself is mutated and returned (no copy is made).

    Args:
        module (nn.Module): The module to convert.
        lora_rank (int): LoRA rank. If <= 0, the module is returned
            unchanged (LoRA is treated as disabled).
        lora_train_bias (str): Which bias parameters to leave trainable,
            forwarded to `lora.mark_only_lora_as_trainable`. Defaults to
            'none'. NOTE(review): presumably accepts loralib's
            'none'/'all'/'lora_only' values — confirm against the loralib API.

    Returns:
        nn.Module: The converted (same) module.
    """
    # LoRA disabled: leave the module untouched so callers can pass
    # lora_rank=0 to opt out without branching themselves.
    if lora_rank <= 0:
        return module
    convert_to_lora_recursively(module, lora_rank)
    # Freeze everything except LoRA parameters (and biases per lora_train_bias).
    lora.mark_only_lora_as_trainable(module, lora_train_bias)
    return module


class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
This class will convert all torch.nn.Linear layer to LoraLinear layer.
Expand All @@ -123,7 +140,4 @@ def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
self.lora_train_bias = lora_train_bias

def convert_to_lora(self) -> None:
if self.lora_rank <= 0:
return
convert_to_lora_recursively(self, self.lora_rank)
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)
3 changes: 1 addition & 2 deletions applications/Chat/coati/models/opt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_lm import OPTLM
from .opt_rm import OPTRM

__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
38 changes: 0 additions & 38 deletions applications/Chat/coati/models/opt/opt_lm.py

This file was deleted.

46 changes: 13 additions & 33 deletions applications/Chat/coati/trainer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,19 @@
import time
from typing import List, Optional

import loralib as lora
import torch
import torch.distributed as dist
import wandb
from coati.models.loss import GPTLMLoss
from torch import nn
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger

from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device


class SFTTrainer(Trainer):
Expand All @@ -47,19 +40,17 @@ def __init__(
optim: Optimizer,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
batch_size: int = 1,
max_epochs: int = 2,
accimulation_steps: int = 8,
callbacks: List[Callback] = [],
) -> None:
if accimulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
super().__init__(strategy, max_epochs, callbacks=callbacks)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader

self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.optimizer = strategy.setup_optimizer(optim, self.model)
(self.model, self.optimizer) = strategy.prepare((model, optim))

self.accimulation_steps = accimulation_steps
num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
Expand All @@ -86,17 +77,10 @@ def fit(self, logger, use_wandb: bool = False):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):

prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
labels = batch["labels"].to(torch.cuda.current_device())
# prompt_ids = prompt_ids.squeeze(1).cuda()
# p_mask = p_mask.squeeze(1).cuda()
# prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])

loss = outputs.loss
prompt_logits = outputs.logits

if loss >= 2.5 and is_rank_0():
logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
Expand Down Expand Up @@ -135,18 +119,14 @@ def fit(self, logger, use_wandb: bool = False):
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
labels = batch["labels"].to(torch.cuda.current_device())
# prompt_ids = prompt_ids.squeeze(1).cuda()
# p_mask = p_mask.squeeze(1).cuda()

outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
loss = outputs.loss
# prompt_logits = outputs.logits

loss_sum += loss.item()
num_seen += prompt_ids.size(0)
num_seen += batch["input_ids"].size(0)

loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
Expand Down
4 changes: 2 additions & 2 deletions applications/Chat/coati/trainer/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import torch
import torch.nn as nn
from coati.models.base import LM, Actor, Critic, RewardModel
from coati.models.base import Actor, Critic, RewardModel
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -99,7 +99,7 @@ def _unwrap_model(model: nn.Module) -> nn.Module:
Args:
model (nn.Module): an actor or a critic
"""
if isinstance(model, Actor) or isinstance(model, LM):
if isinstance(model, Actor):
return model.model
return model

Expand Down
Loading