Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3e5f432
refactor: adapt boost API in base and naive strategies
cwher Jun 13, 2023
ab3a855
fix: initialize plugin after setup_distributed
cwher Jun 13, 2023
3ca33ef
fix: fix save_pretrained fn
cwher Jun 13, 2023
40cebd3
refactor: adapt boost API in DDPStrategy
cwher Jun 13, 2023
46f1a05
to: add _post_init check
cwher Jun 14, 2023
bf13531
to: fix ddp backward, modify ddp dataloader and unwrap
cwher Jun 14, 2023
76a7034
feat: adapt boost API in ColossalAIStrategy
cwher Jun 14, 2023
6414f7f
fix: call setup_distributed before use get_current_device
cwher Jun 14, 2023
476cb72
fix: fix save_model and save_optimizer
cwher Jun 14, 2023
e1a86f1
test: remove save_sharded_optimizer test
cwher Jun 14, 2023
6b435c3
style: apply formatter
cwher Jun 14, 2023
8e97777
fix: fix stage check and add comments
cwher Jun 14, 2023
4fbae16
feat: allow dict type arg in strategy.prepare
cwher Jun 14, 2023
4333407
to: temporarily remove lr_scheduler for testing
cwher Jun 14, 2023
82499ec
style: simplify init of ColossalAIStrategy
cwher Jun 15, 2023
af5dd55
fix: fix lr_scheduler in sft and rm
cwher Jun 15, 2023
7cccb8f
style: modify comments
cwher Jun 15, 2023
6632896
test: add train_prompts tests
cwher Jun 16, 2023
f598bf8
fix: fix inference only case and use in train_prompts
cwher Jun 16, 2023
92b4664
test: skip failed tests in ci
cwher Jun 16, 2023
3376450
style: fix CodeFactor check
cwher Jun 16, 2023
a9fb675
fix: do not use model.to('cpu') with GeminiPlugin
cwher Jun 16, 2023
7849182
test: enable colossalai_gemini tests
cwher Jun 16, 2023
6b54b04
test: set CUDA_VISIBLE_DEVICES in ci
cwher Jun 25, 2023
af2ff25
docs: add note
cwher Jun 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
numel = sum(p.numel() for p in model.parameters())
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
numel *= dist.get_world_size()
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
if isinstance(strategy.plugin, GeminiPlugin) and strategy.shard_init:
numel *= dist.get_world_size()
return numel


Expand Down
9 changes: 8 additions & 1 deletion applications/Chat/coati/trainer/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device


Expand Down Expand Up @@ -71,6 +71,11 @@ def __init__(self,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
if isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
"GeminiPlugin is not compatible with manual model.to('cpu')"

experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
Expand Down Expand Up @@ -105,6 +110,8 @@ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experien
def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer:
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
Expand Down
20 changes: 11 additions & 9 deletions applications/Chat/coati/trainer/rm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from datetime import datetime
from typing import List, Optional
from typing import Callable, List

import pandas as pd
import torch
import torch.distributed as dist
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .base import Trainer
from .callbacks import Callback
Expand All @@ -22,7 +21,8 @@ class RewardModelTrainer(Trainer):
Args:
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
optim (Optimizer): the optimizer to use for training
lr_scheduler (_LRScheduler): the lr scheduler to use for training
loss_fn (callable): the loss function to use for training
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
Expand All @@ -37,7 +37,8 @@ def __init__(
model,
strategy: Strategy,
optim: Optimizer,
loss_fn,
lr_scheduler: _LRScheduler,
loss_fn: Callable,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader,
Expand All @@ -53,7 +54,7 @@ def __init__(
self.model = model
self.loss_fn = loss_fn
self.optimizer = optim
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
self.scheduler = lr_scheduler

def eval_acc(self, dataloader):
dist = 0
Expand Down Expand Up @@ -116,7 +117,8 @@ def fit(self):
# eval
dist, acc = self.eval_acc(self.eval_dataloader)
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log.csv', mode='a', header=False, index=False)
epoch_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc})
Expand Down
20 changes: 8 additions & 12 deletions applications/Chat/coati/trainer/sft.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import math
import time
from typing import List, Optional
from typing import List

import torch
import torch.distributed as dist
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler

from .base import Trainer
from .callbacks import Callback
Expand Down Expand Up @@ -38,28 +36,26 @@ def __init__(
model,
strategy: Strategy,
optim: Optimizer,
lr_scheduler: _LRScheduler,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
max_epochs: int = 2,
accumulation_steps: int = 8,
callbacks: List[Callback] = [],
) -> None:
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
from colossalai.booster.plugin import GeminiPlugin
assert not isinstance(strategy.plugin, GeminiPlugin), \
"Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, callbacks=callbacks)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.model = model
self.optimizer = optim

self.accumulation_steps = accumulation_steps
num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)

self.scheduler = get_scheduler("cosine",
self.optimizer,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
self.scheduler = lr_scheduler

def fit(self, logger, use_wandb: bool = False):
if use_wandb:
Expand Down
110 changes: 66 additions & 44 deletions applications/Chat/coati/trainer/strategies/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -9,54 +9,55 @@
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin

from .sampler import DistributedSampler

ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]


class Strategy(ABC):
"""
Base class for training strategies.
"""

def __init__(self) -> None:
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
super().__init__()
# NOTE: dist must be initialized before Booster
self.setup_distributed()
self.plugin = plugin_initializer()
self.booster = Booster(plugin=self.plugin)
self._post_init()

@abstractmethod
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
def _post_init(self) -> None:
pass

@abstractmethod
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
self.booster.backward(loss, optimizer)

def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
pass
optimizer.step()

@abstractmethod
def setup_distributed(self) -> None:
pass

@abstractmethod
def setup_model(self, model: nn.Module) -> nn.Module:
pass

@abstractmethod
def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
pass

@abstractmethod
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
pass

def model_init_context(self):
return nullcontext()

def prepare(
self, *models_or_model_optim_pairs: ModelOrModelOptimPair
) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
"""Prepare models or model-optimizer-pairs based on each strategy.
def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
"""Prepare [model | (model, optimizer) | Dict] based on each strategy.
NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.

Example::
>>> # e.g., include lr_scheduler
>>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
>>> # when fine-tuning actor and critic
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
>>> # or when training reward model
Expand All @@ -65,25 +66,39 @@ def prepare(
>>> actor, critic = strategy.prepare(actor, critic)

Returns:
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
"""

rets = []
for arg in models_or_model_optim_pairs:
if isinstance(arg, tuple):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = self.setup_model(model)
optimizer = self.setup_optimizer(optimizer, model)
for arg in boost_args:
if isinstance(arg, nn.Module):
model, *_ = self.booster.boost(arg)
rets.append(model)
elif isinstance(arg, tuple):
try:
model, optimizer = arg
except ValueError:
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
model, optimizer, *_ = self.booster.boost(model=model,
optimizer=optimizer)
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(self.setup_model(model))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
# remove None values
boost_result = {
key: value
for key, value in boost_result.items() if value is not None
}
rets.append(boost_result)
else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
raise RuntimeError(f'Type {type(arg)} is not supported')

if len(rets) == 1:
return rets[0]
return rets
return rets[0] if len(rets) == 1 else rets

@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
Expand All @@ -97,23 +112,30 @@ def unwrap_model(model: nn.Module) -> nn.Module:
"""
return model

@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
pass
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
**kwargs
) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)

@abstractmethod
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
pass
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)

@abstractmethod
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
pass
def save_optimizer(self,
optimizer: Optimizer,
path: str,
only_rank0: bool = False,
**kwargs
) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

@abstractmethod
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
pass
def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
self.booster.load_optimizer(optimizer, path)

def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, 1, 0)

@abstractmethod
Expand Down
Loading