applications/Chat/coati/trainer/base.py (22 additions, 31 deletions)

--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
         optim (Optimizer): the optimizer to use for training
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 max_epochs: int,
-                 model: nn.Module,
-                 optimizer: Optimizer,
-                 ) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        max_epochs: int,
+        model: nn.Module,
+        optimizer: Optimizer,
+    ) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ def _before_fit(self):
 
     def fit(self, *args, **kwargs):
         self._before_fit(*args, **kwargs)
-        for epoch in tqdm.trange(self.max_epochs,
-                                 desc="Epochs",
-                                 disable=not is_rank_0() or self.no_epoch_bar
-                                 ):
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
             self._train(epoch)
             self._eval(epoch)
 
@@ -75,8 +73,7 @@ def __init__(self,
                  buffer: NaiveReplayBuffer,
                  sample_buffer: bool,
                  dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []
-                 ) -> None:
+                 callbacks: List[Callback] = []) -> None:
         super().__init__()
         self.strategy = strategy
         self.buffer = buffer
@@ -138,7 +135,7 @@ def _make_experience(self, collect_step: int):
     @abstractmethod
     def _learn(self, update_step: int):
         """
-        Implement this method to learn from experience, either 
+        Implement this method to learn from experience, either
         sample from buffer or transform buffer into dataloader.
         """
         raise NotImplementedError()
@@ -154,13 +151,14 @@ def _update_phase(self, update_step: int):
         self._learn(update_step)
         self._on_learn_epoch_end(update_step)
 
-    def fit(self,
-            prompt_dataloader: DataLoader,
-            pretrain_dataloader: DataLoader,
-            num_episodes: int,
-            num_collect_steps: int,
-            num_update_steps: int,
-            ):
+    def fit(
+        self,
+        prompt_dataloader: DataLoader,
+        pretrain_dataloader: DataLoader,
+        num_episodes: int,
+        num_collect_steps: int,
+        num_update_steps: int,
+    ):
         """
         The main training loop of on-policy rl trainers.
 
@@ -175,23 +173,16 @@ def fit(self,
         self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
 
         with self._fit_ctx():
-            for episode in tqdm.trange(num_episodes,
-                                       desc="Episodes",
-                                       disable=not is_rank_0()):
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                 with self._episode_ctx(episode):
-                    for collect_step in tqdm.trange(num_collect_steps,
-                                                    desc="Collect steps",
-                                                    disable=not is_rank_0()):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                         self._collect_phase(collect_step)
                     if not self.sample_buffer:
                         # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                         # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                         # I only call strategy.setup_dataloader() to setup dataloader.
-                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
-                                                                         self.dataloader_pin_memory)
-                    for update_step in tqdm.trange(num_update_steps,
-                                                   desc="Update steps",
-                                                   disable=not is_rank_0()):
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                         self._update_phase(update_step)
                 # NOTE: this is for on-policy algorithms
                 self.buffer.clear()
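
A note on the `disable=not is_rank_0()` argument that every `tqdm.trange` call above passes: in distributed training only the main process should render progress bars, while all ranks still execute the loop body. A minimal sketch of that pattern, assuming a vanilla torch.distributed setup; the `is_rank_0` helper below is a stand-in for coati's utility, not its actual implementation:

    # Rank-0-only progress bars: every rank iterates, only rank 0 draws the bar.
    # `is_rank_0` is a hypothetical stand-in for coati's helper.
    import tqdm
    import torch.distributed as dist

    def is_rank_0() -> bool:
        # Single-process runs count as rank 0; otherwise ask torch.distributed.
        return not dist.is_initialized() or dist.get_rank() == 0

    for epoch in tqdm.trange(3, desc="Epochs", disable=not is_rank_0()):
        pass  # one epoch of training/evaluation would run here

Keeping all ranks iterating (rather than skipping the loop off rank 0) matters because every process must still hit the same collective operations inside the loop; only the bar rendering is suppressed.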
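
The `buffer.clear()` at the end of each episode is what the `# NOTE: this is for on-policy algorithms` comment refers to. A schematic sketch of the collect/update/clear cycle that `fit` implements; `collect_experience` and `learn_from` are hypothetical stand-ins, not coati APIs:

    # Schematic of one on-policy episode: collect, update, then discard experience.
    # `collect_experience` and `learn_from` are hypothetical stand-ins.
    from typing import Callable, List

    def run_episode(
        buffer: List,                              # stands in for NaiveReplayBuffer
        collect_experience: Callable[[], object],
        learn_from: Callable[[List], None],
        num_collect_steps: int,
        num_update_steps: int,
    ) -> None:
        for _ in range(num_collect_steps):   # collect phase: fill the buffer
            buffer.append(collect_experience())
        for _ in range(num_update_steps):    # update phase: learn from the buffer
            learn_from(buffer)
        buffer.clear()                       # on-policy: stale experience is discarded

Clearing the buffer between episodes is what keeps the algorithm on-policy: the next collect phase regenerates experience under the freshly updated policy, so updates never consume samples drawn from an older one.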