diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py index e4d0a970740d..12d4a7fda5ad 100644 --- a/applications/Chat/coati/trainer/sft.py +++ b/applications/Chat/coati/trainer/sft.py @@ -8,7 +8,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader - +from torch.utils.tensorboard import SummaryWriter from colossalai.logging import DistributedLogger from .base import SLTrainer @@ -37,6 +37,7 @@ def __init__( lr_scheduler: _LRScheduler, max_epochs: int = 2, accumulation_steps: int = 8, + tensorboard_dir: str = None, ) -> None: if accumulation_steps > 1: assert not isinstance(strategy, GeminiStrategy), \ @@ -72,6 +73,10 @@ def _train(self, epoch: int): self.strategy.optimizer_step(self.optimizer) self.optimizer.zero_grad() self.scheduler.step() + if is_rank_0() and self.tensorboard_writer: + self.tensorboard_writer.add_scalar('loss', self.total_loss / self.accumulation_steps) + self.tensorboard_writer.add_scalar('lr', self.scheduler.get_last_lr()[0]) + self.tensorboard_writer.flush() if is_rank_0() and self.use_wandb: wandb.log({ "loss": self.total_loss / self.accumulation_steps, @@ -105,7 +110,8 @@ def _before_fit(self, train_dataloader: DataLoader, eval_dataloader: Optional[DataLoader] = None, logger: Optional[DistributedLogger] = None, - use_wandb: bool = False): + use_wandb: bool = False, + tensorboard_dir: str = None): """ Args: train_dataloader: the dataloader to use for training @@ -116,9 +122,12 @@ def _before_fit(self, self.logger = logger self.use_wandb = use_wandb + self.tensorboard_writer = None if use_wandb: wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) wandb.watch(self.model) + if tensorboard_dir: + self.tensorboard_writer = SummaryWriter(log_dir=tensorboard_dir) if is_rank_0() else None self.total_loss = 0 self.no_epoch_bar = True diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index f068ea2bf5de..118c34461b31 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -175,7 +175,8 @@ def train(args): trainer.fit(train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, - use_wandb=args.use_wandb) + use_wandb=args.use_wandb, + tensorboard_dir=args.tensorboard_dir) # save model checkpoint after fitting on only rank0 strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) @@ -206,6 +207,7 @@ def train(args): parser.add_argument('--lr', type=float, default=5e-6) parser.add_argument('--accumulation_steps', type=int, default=8) parser.add_argument('--use_wandb', default=False, action='store_true') + parser.add_argument('--tensorboard_dir', type=str, default="") parser.add_argument('--grad_checkpoint', default=False, action='store_true') args = parser.parse_args() train(args) diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt index e5f5ca0932a8..afb0dddb644f 100644 --- a/applications/Chat/requirements.txt +++ b/applications/Chat/requirements.txt @@ -11,3 +11,4 @@ sse_starlette wandb sentencepiece gpustat +tensorboard