Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
13 changes: 11 additions & 2 deletions applications/Chat/coati/trainer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter
from colossalai.logging import DistributedLogger

from .base import SLTrainer
Expand Down Expand Up @@ -37,6 +37,7 @@ def __init__(
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
tensorboard_dir: str = None,
) -> None:
if accumulation_steps > 1:
assert not isinstance(strategy, GeminiStrategy), \
Expand Down Expand Up @@ -72,6 +73,10 @@ def _train(self, epoch: int):
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.tensorboard_writer:
self.tensorboard_writer.add_scalar('loss', self.total_loss / self.accumulation_steps)
self.tensorboard_writer.add_scalar('lr', self.scheduler.get_last_lr()[0])
self.tensorboard_writer.flush()
if is_rank_0() and self.use_wandb:
wandb.log({
"loss": self.total_loss / self.accumulation_steps,
Expand Down Expand Up @@ -105,7 +110,8 @@ def _before_fit(self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
use_wandb: bool = False):
use_wandb: bool = False,
tensorboard_dir: str = None):
"""
Args:
train_dataloader: the dataloader to use for training
Expand All @@ -116,9 +122,12 @@ def _before_fit(self,

self.logger = logger
self.use_wandb = use_wandb
self.tensorboard_writer = None
if use_wandb:
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
wandb.watch(self.model)
if tensorboard_dir:
self.tensorboard_writer = SummaryWriter(log_dir=tensorboard_dir) if is_rank_0() else None

self.total_loss = 0
self.no_epoch_bar = True
Expand Down
4 changes: 3 additions & 1 deletion applications/Chat/examples/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def train(args):
trainer.fit(train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
logger=logger,
use_wandb=args.use_wandb)
use_wandb=args.use_wandb,
tensorboard_dir=args.tensorboard_dir)

# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
Expand Down Expand Up @@ -206,6 +207,7 @@ def train(args):
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accumulation_steps', type=int, default=8)
parser.add_argument('--use_wandb', default=False, action='store_true')
parser.add_argument('--tensorboard_dir', type=str, default="")
parser.add_argument('--grad_checkpoint', default=False, action='store_true')
args = parser.parse_args()
train(args)
1 change: 1 addition & 0 deletions applications/Chat/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ sse_starlette
wandb
sentencepiece
gpustat
tensorboard