From 16923494917f40e121fb5586f4922b97fea6caee Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 25 Apr 2023 15:16:47 +0800
Subject: [PATCH] [chat] refactor performance evaluator

---
 .../Chat/coati/ray/callbacks/__init__.py      |   9 +
 applications/Chat/coati/ray/callbacks/base.py |  66 ++++++
 .../ray/callbacks/performance_evaluator.py    | 212 ++++++++++++++++++
 .../Chat/coati/ray/detached_trainer_base.py   |  44 ++--
 .../Chat/coati/ray/detached_trainer_ppo.py    |   6 +-
 .../Chat/coati/ray/experience_maker_holder.py |  36 ++-
 .../callbacks/performance_evaluator.py        |  89 --------
 7 files changed, 347 insertions(+), 115 deletions(-)
 create mode 100644 applications/Chat/coati/ray/callbacks/__init__.py
 create mode 100644 applications/Chat/coati/ray/callbacks/base.py
 create mode 100644 applications/Chat/coati/ray/callbacks/performance_evaluator.py

diff --git a/applications/Chat/coati/ray/callbacks/__init__.py b/applications/Chat/coati/ray/callbacks/__init__.py
new file mode 100644
index 000000000000..5f5e488f383e
--- /dev/null
+++ b/applications/Chat/coati/ray/callbacks/__init__.py
@@ -0,0 +1,9 @@
+from .base import MakerCallback, TrainerCallback
+from .performance_evaluator import ExperienceMakerPerformanceEvaluator, TrainerPerformanceEvaluator
+
+__all__ = [
+    "TrainerCallback",
+    "MakerCallback",
+    "ExperienceMakerPerformanceEvaluator",
+    "TrainerPerformanceEvaluator",
+]
diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py
new file mode 100644
index 000000000000..3306150a41ff
--- /dev/null
+++ b/applications/Chat/coati/ray/callbacks/base.py
@@ -0,0 +1,66 @@
+from abc import ABC
+
+from coati.experience_maker import Experience
+
+
+class TrainerCallback(ABC):
+    """
+    Base callback class. It defines the interface for callbacks.
+    """
+
+    def on_fit_start(self) -> None:
+        pass
+
+    def on_fit_end(self) -> None:
+        pass
+
+    def on_episode_start(self, episode: int) -> None:
+        pass
+
+    def on_episode_end(self, episode: int) -> None:
+        pass
+
+    def on_epoch_start(self, epoch: int) -> None:
+        pass
+
+    def on_epoch_end(self, epoch: int) -> None:
+        pass
+
+    def on_batch_start(self) -> None:
+        pass
+
+    def on_batch_end(self, metrics: dict, experience: Experience) -> None:
+        pass
+
+    def on_update_start(self) -> None:
+        pass
+
+    def on_update_end(self) -> None:
+        pass
+
+
+class MakerCallback(ABC):
+
+    def on_loop_start(self) -> None:
+        pass
+
+    def on_loop_end(self) -> None:
+        pass
+
+    def on_make_experience_start(self) -> None:
+        pass
+
+    def on_make_experience_end(self, experience: Experience) -> None:
+        pass
+
+    def on_send_start(self) -> None:
+        pass
+
+    def on_send_end(self) -> None:
+        pass
+
+    def on_batch_start(self) -> None:
+        pass
+
+    def on_batch_end(self) -> None:
+        pass
diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
new file mode 100644
index 000000000000..cd3517609e7a
--- /dev/null
+++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
@@ -0,0 +1,212 @@
+from time import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from coati.experience_maker import Experience
+
+from .base import MakerCallback, TrainerCallback
+
+
+def get_world_size() -> int:
+    if dist.is_initialized():
+        return dist.get_world_size()
+    return 1
+
+
+def print_rank_0(*args, **kwargs) -> None:
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        print(*args, **kwargs)
+
+
+@torch.no_grad()
+def all_reduce_mean(x: float, world_size: int) -> float:
+    if world_size == 1:
+        return x
+    tensor = torch.tensor([x], device=torch.cuda.current_device())
+    dist.all_reduce(tensor)
+    tensor = tensor / world_size
+    return tensor.item()
+
+
+class Timer:
+
+    def __init__(self) -> None:
+        self.start_time: Optional[float] = None
+        self.duration: float = 0.
+
+    def start(self) -> None:
+        self.start_time = time()
+
+    def end(self) -> None:
+        self.duration += time() - self.start_time
+
+    def reset(self) -> None:
+        self.duration = 0.
+
+
+class ExperienceMakerPerformanceEvaluator(MakerCallback):
+
+    def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
+                 reward_model_num_params: int) -> None:
+        super().__init__()
+        self.world_size = get_world_size()
+        self.actor_num_params = actor_num_params
+        self.critic_num_params = critic_num_params
+        self.initial_model_num_params = initial_model_num_params
+        self.reward_model_num_params = reward_model_num_params
+
+        self.batch_timer = Timer()
+        self.send_timer = Timer()
+        self.make_experience_timer = Timer()
+        self.total_samples: int = 0
+        self.make_experience_flop: int = 0
+
+        print_rank_0(
+            f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
+        )
+
+    def on_make_experience_start(self) -> None:
+        self.make_experience_timer.start()
+
+    def on_make_experience_end(self, experience: Experience) -> None:
+        self.make_experience_timer.end()
+
+        batch_size, seq_len = experience.sequences.shape
+
+        self.total_samples += batch_size
+
+        # actor generate
+        num_actions = experience.action_mask.size(1)
+        input_len = seq_len - num_actions
+        total_seq_len = (input_len + seq_len - 1) * num_actions / 2
+        self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
+        # actor forward
+        self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
+        # critic forward
+        self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
+        # initial model forward
+        self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
+        # reward model forward
+        self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
+
+    def on_send_start(self) -> None:
+        self.send_timer.start()
+
+    def on_send_end(self) -> None:
+        self.send_timer.end()
+
+    def on_batch_start(self) -> None:
+        self.batch_timer.start()
+
+    def on_batch_end(self) -> None:
+        self.batch_timer.end()
+
+    def on_loop_end(self) -> None:
+        avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
+        avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
+        avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)
+
+        avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
+        avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
+        avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
+        avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
+            (self.total_samples * self.world_size)
+        avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
+
+        print_rank_0(
+            'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
+            f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
+            f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
+            f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            +
+            f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+        )
+
+
+class TrainerPerformanceEvaluator(TrainerCallback):
+
+    def __init__(self,
+                 actor_num_params: int,
+                 critic_num_params: int,
+                 enable_grad_checkpoint: bool = False,
+                 ignore_first_episodes: int = 1) -> None:
+        super().__init__()
+        self.world_size = get_world_size()
+        self.actor_num_params = actor_num_params
+        self.critic_num_params = critic_num_params
+        self.enable_grad_checkpoint = enable_grad_checkpoint
+        self.ignore_first_episodes = ignore_first_episodes
+        self.ignore_this_episode = False
+
+        self.episode_timer = Timer()
+        self.batch_timer = Timer()
+        self.update_timer = Timer()
+        self.total_samples: int = 0
+        self.learn_flop: int = 0
+
+        print_rank_0(
+            f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
+        )
+
+    def on_episode_start(self, episodes: int) -> None:
+        self.ignore_this_episode = episodes < self.ignore_first_episodes
+        if self.ignore_this_episode:
+            return
+        self.episode_timer.start()
+
+    def on_episode_end(self, episodes: int) -> None:
+        if self.ignore_this_episode:
+            return
+        self.episode_timer.end()
+
+    def on_batch_start(self) -> None:
+        if self.ignore_this_episode:
+            return
+        self.batch_timer.start()
+
+    def on_batch_end(self, metrics: dict, experience: Experience) -> None:
+        if self.ignore_this_episode:
+            return
+        self.batch_timer.end()
+
+        batch_size, seq_len = experience.sequences.shape
+
+        self.total_samples += batch_size
+
+        # actor forward-backward, 3 means forward(1) + backward(2)
+        self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
+        # critic forward-backward
+        self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
+
+    def on_update_start(self) -> None:
+        if self.ignore_this_episode:
+            return
+        self.update_timer.start()
+
+    def on_update_end(self) -> None:
+        if self.ignore_this_episode:
+            return
+        self.update_timer.end()
+
+    def on_fit_end(self) -> None:
+        if self.total_samples == 0:
+            print_rank_0('No samples are collected, skip trainer performance evaluation')
+            return
+        avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
+        avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
+        avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)
+
+        avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
+        avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
+        avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
+        avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
+        avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
+
+        print_rank_0(
+            'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
+            f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
+            f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            +
+            f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+        )
diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py
index a4f666dc5714..ac2d35e9da19 100644
--- a/applications/Chat/coati/ray/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/detached_trainer_base.py
@@ -6,10 +6,10 @@
 import torch
 from coati.experience_maker import Experience
 from coati.replay_buffer.utils import BufferItem
-from coati.trainer.callbacks import Callback
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from .callbacks import TrainerCallback
 from .detached_replay_buffer import DetachedReplayBuffer
 from .utils import is_rank_0
 
@@ -35,7 +35,7 @@ def __init__(self,
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
                  dataloader_pin_memory: bool = True,
-                 callbacks: List[Callback] = [],
+                 callbacks: List[TrainerCallback] = [],
                  debug: bool = False) -> None:
         super().__init__()
         self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
@@ -68,7 +68,9 @@ def _learn(self, update_steps: int, train_epochs: int) -> None:
         data = []
         # warmup
         pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
+        self._on_epoch_start(0)
         self._learn_epoch(pbar, data)
+        self._on_epoch_end(0)
         # item is already a batch
         dataloader = DataLoader(data,
                                 batch_size=1,
@@ -77,7 +79,9 @@ def _learn(self, update_steps: int, train_epochs: int) -> None:
                                 collate_fn=lambda x: x[0])
         for epoch in range(1, train_epochs):
             pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
+            self._on_epoch_start(epoch)
             self._learn_epoch(pbar, data)
+            self._on_epoch_end(epoch)
 
     def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
         is_warmup = len(data) == 0
@@ -87,9 +91,9 @@ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
             # sample a batch and then train to avoid waiting
             experience = x if not is_warmup else self._buffer_sample()
             experience.to_device(torch.cuda.current_device())
-            self._on_learn_batch_start()
+            self._on_batch_start()
             metrics = self.training_step(experience)
-            self._on_learn_batch_end(metrics, experience)
+            self._on_batch_end(metrics, experience)
 
             if self._debug:
                 print("[trainer] step over")
@@ -100,11 +104,14 @@ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
 
     def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
         self._on_fit_start()
-        for _ in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
+        for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
+            self._on_episode_start(i)
             self._learn(update_steps, train_epochs)
+            self._on_update_start()
             self._update_remote_makers()
+            self._on_update_end()
+            self._on_episode_end(i)
         self._on_fit_end()
-        self._on_finish()
 
     @ray.method(concurrency_group="buffer_length")
     def buffer_get_length(self):
@@ -147,23 +154,26 @@ def _on_episode_end(self, episode: int) -> None:
         for callback in self.callbacks:
             callback.on_episode_end(episode)
 
-    def _on_learn_epoch_start(self, epoch: int) -> None:
+    def _on_epoch_start(self, epoch: int) -> None:
         for callback in self.callbacks:
-            callback.on_learn_epoch_start(epoch)
+            callback.on_epoch_start(epoch)
 
-    def _on_learn_epoch_end(self, epoch: int) -> None:
+    def _on_epoch_end(self, epoch: int) -> None:
         for callback in self.callbacks:
-            callback.on_learn_epoch_end(epoch)
+            callback.on_epoch_end(epoch)
 
-    def _on_learn_batch_start(self) -> None:
+    def _on_batch_start(self) -> None:
         for callback in self.callbacks:
-            callback.on_learn_batch_start()
+            callback.on_batch_start()
 
-    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+    def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
         for callback in self.callbacks:
-            callback.on_learn_batch_end(metrics, experience)
+            callback.on_batch_end(metrics, experience)
 
-    def _on_finish(self) -> None:
+    def _on_update_start(self) -> None:
         for callback in self.callbacks:
-            if hasattr(callback, 'on_finish'):
-                callback.on_finish()
+            callback.on_update_start()
+
+    def _on_update_end(self) -> None:
+        for callback in self.callbacks:
+            callback.on_update_end()
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py
index c5459c4d96d1..347df3d84589 100644
--- a/applications/Chat/coati/ray/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -6,12 +6,12 @@
 from coati.models.base import Actor, Critic
 from coati.models.loss import PolicyLoss, ValueLoss
 from coati.trainer.callbacks import Callback
-from coati.trainer.callbacks.performance_evaluator import TrainerPerformaceEvaluator
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
 from torch.optim import Adam
 
 from colossalai.nn.optimizer import HybridAdam
 
+from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
 from .detached_trainer_base import DetachedTrainer
 from .utils import (
     get_actor_from_args,
@@ -64,7 +64,7 @@ def __init__(
             eps_clip: float = 0.2,
             value_clip: float = 0.4,
             dataloader_pin_memory: bool = True,
-            callbacks: List[Callback] = [],
+            callbacks: List[TrainerCallback] = [],
             eval_performance: bool = False,
             debug: bool = False,
     ) -> None:
@@ -80,7 +80,7 @@ def __init__(
         if eval_performance:
             actor_numel = get_model_numel(self.actor)
             critic_numel = get_model_numel(self.critic)
-            evaluator = TrainerPerformaceEvaluator(actor_numel, critic_numel)
+            evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
             callbacks = callbacks + [evaluator]
 
         if isinstance(self.strategy, ColossalAIStrategy):
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py
index 4616c01bdf0f..996996400064 100644
--- a/applications/Chat/coati/ray/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/experience_maker_holder.py
@@ -12,13 +12,13 @@
 from coati.models.base import Actor, Critic, RewardModel
 from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.trainer.callbacks import Callback
-from coati.trainer.callbacks.performance_evaluator import ExperienceMakerPerformanceEvaluator
 from coati.trainer.strategies import Strategy
 from coati.trainer.strategies.sampler import DistributedSampler
 from ray.exceptions import GetTimeoutError
 from torch import Tensor
 from tqdm import tqdm
 
+from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
 from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env
 
 
@@ -42,7 +42,7 @@ def __init__(
             sync_models_from_trainers: bool = False,
             buffer_cpu_offload: bool = True,
             kl_coef: float = 0.1,
-            callbacks: List[Callback] = [],
+            callbacks: List[MakerCallback] = [],
             eval_performance: bool = False,
             debug: bool = False,
             **generate_kwargs):
@@ -122,13 +122,17 @@ def _send_items(self, experience: Experience) -> None:
             target_trainer.buffer_extend.remote(items_per_trainer[i])
 
     def _inference_step(self, batch) -> None:
+        self._on_batch_start()
         with self._model_visit_lock:
             self._on_make_experience_start()
             experience = self._make_experience(batch)
             self._on_make_experience_end(experience)
+        self._on_send_start()
         if self.buffer_cpu_offload:
             experience.to_device('cpu')
         self._send_items(experience)
+        self._on_send_end()
+        self._on_batch_end()
 
     def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
         """Working loop of the experience maker.
@@ -139,6 +143,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
             num_steps (int, optional): Iterate the dataloader for number if steps. If this value > 0, num_epochs will be ignored. Defaults to 0.
         """
         self._get_ready()
+        self._on_loop_start()
         dataloader = dataloader_fn()
         if num_steps > 0:
             # ignore num epochs
@@ -156,7 +161,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
             for batch in dataloader:
                 self._inference_step(batch)
                 pbar.update()
-        self._on_finish()
+        self._on_loop_end()
 
     @ray.method(concurrency_group="model_io")
     def update_experience_maker(self,
@@ -205,10 +210,29 @@ def _on_make_experience_end(self, experience: Experience) -> None:
         for callback in self.callbacks:
             callback.on_make_experience_end(experience)
 
-    def _on_finish(self) -> None:
+    def _on_loop_start(self) -> None:
         for callback in self.callbacks:
-            if hasattr(callback, 'on_finish'):
-                callback.on_finish()
+            callback.on_loop_start()
+
+    def _on_loop_end(self) -> None:
+        for callback in self.callbacks:
+            callback.on_loop_end()
+
+    def _on_send_start(self) -> None:
+        for callback in self.callbacks:
+            callback.on_send_start()
+
+    def _on_send_end(self) -> None:
+        for callback in self.callbacks:
+            callback.on_send_end()
+
+    def _on_batch_start(self) -> None:
+        for callback in self.callbacks:
+            callback.on_batch_start()
+
+    def _on_batch_end(self) -> None:
+        for callback in self.callbacks:
+            callback.on_batch_end()
 
 
 def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
index 0aebd2bf6280..5ca44a52d6e7 100644
--- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
@@ -29,95 +29,6 @@ def all_reduce_mean(x: float, world_size: int) -> float:
     return tensor.item()
 
 
-class ExperienceMakerPerformanceEvaluator(Callback):
-
-    def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
-                 reward_model_num_params: int) -> None:
-        super().__init__()
-        self.world_size = get_world_size()
-        self.actor_num_params = actor_num_params
-        self.critic_num_params = critic_num_params
-        self.initial_model_num_params = initial_model_num_params
-        self.reward_model_num_params = reward_model_num_params
-
-        self.make_experience_duration: float = 0.
-        self.make_experience_start_time: Optional[float] = None
-        self.make_experience_num_samples: int = 0
-        self.make_experience_flop: int = 0
-
-    def on_make_experience_start(self) -> None:
-        self.make_experience_start_time = time()
-
-    def on_make_experience_end(self, experience: Experience) -> None:
-        self.make_experience_duration += time() - self.make_experience_start_time
-
-        batch_size, seq_len = experience.sequences.shape
-
-        self.make_experience_num_samples += batch_size
-
-        # actor generate
-        num_actions = experience.action_mask.size(1)
-        input_len = seq_len - num_actions
-        total_seq_len = (input_len + seq_len - 1) * num_actions / 2
-        self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
-        # actor forward
-        self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
-        # critic forward
-        self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
-        # initial model forward
-        self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
-        # reward model forward
-        self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
-
-    def on_finish(self) -> None:
-        avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
-
-        avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
-        avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
-
-        print_rank_0(
-            f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
-        )
-
-
-class TrainerPerformaceEvaluator(Callback):
-
-    def __init__(self, actor_num_params: int, critic_num_params: int, enable_grad_checkpoint: bool = False) -> None:
-        super().__init__()
-        self.world_size = get_world_size()
-        self.actor_num_params = actor_num_params
-        self.critic_num_params = critic_num_params
-        self.enable_grad_checkpoint = enable_grad_checkpoint
-
-        self.learn_duration: float = 0.
-        self.learn_start_time: Optional[float] = None
-        self.learn_num_samples: int = 0
-        self.learn_flop: int = 0
-
-    def on_learn_batch_start(self) -> None:
-        self.learn_start_time = time()
-
-    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
-        self.learn_duration += time() - self.learn_start_time
-
-        batch_size, seq_len = experience.sequences.shape
-
-        self.learn_num_samples += batch_size
-
-        # actor forward-backward, 3 means forward(1) + backward(2)
-        self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
-        # critic forward-backward
-        self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
-
-    def on_finish(self) -> None:
-        avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
-
-        avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
-        avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
-
-        print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
-
-
 class PerformanceEvaluator(Callback):
     """
     Callback for valuate the performance of the model.
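
Note on the new hook order (illustrative; not part of the patch): per fit() call,
the trainer-side hooks now fire as on_fit_start -> [on_episode_start ->
[on_epoch_start -> [on_batch_start -> on_batch_end]* -> on_epoch_end]* ->
on_update_start -> on_update_end -> on_episode_end]* -> on_fit_end. The sketch
below drives TrainerPerformanceEvaluator through that sequence by hand on CPU,
for a single pseudo-episode. It assumes the patched coati package is importable;
ToyExperience is a hypothetical stand-in for coati.experience_maker.Experience
(the evaluator only reads .sequences.shape).

    import torch

    from coati.ray.callbacks import TrainerPerformanceEvaluator


    class ToyExperience:
        # hypothetical stand-in: only .sequences is read by the evaluator
        def __init__(self, batch_size: int, seq_len: int) -> None:
            self.sequences = torch.zeros(batch_size, seq_len, dtype=torch.long)


    # ignore_first_episodes=0 so episode 0 is measured instead of skipped as warmup
    evaluator = TrainerPerformanceEvaluator(actor_num_params=7 * 10**9,
                                            critic_num_params=10**9,
                                            ignore_first_episodes=0)
    evaluator.on_fit_start()
    evaluator.on_episode_start(0)
    evaluator.on_epoch_start(0)
    for _ in range(4):    # pretend update_steps == 4
        evaluator.on_batch_start()
        evaluator.on_batch_end(metrics={}, experience=ToyExperience(8, 512))
    evaluator.on_epoch_end(0)
    evaluator.on_update_start()    # _update_remote_makers() would run here
    evaluator.on_update_end()
    evaluator.on_episode_end(0)
    evaluator.on_fit_end()    # prints the Learning Performance Summary

Since torch.distributed is not initialized here, get_world_size() returns 1 and
all_reduce_mean() short-circuits, so the sketch needs no GPU or process group.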
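Note on the FLOP bookkeeping the evaluators share (worked example, derived from
the code above): generation is booked per decoding step at that step's context
length, which forms an arithmetic series. With input_len = seq_len - num_actions,
the per-sample sum of context lengths is input_len + (input_len + 1) + ... +
(seq_len - 1) = (input_len + seq_len - 1) * num_actions / 2, the total_seq_len
term. For instance, input_len = 100 and num_actions = 10 (so seq_len = 110) give
100 + 101 + ... + 109 = (100 + 109) * 10 / 2 = 1045, and the booked generate cost
is 2 * actor_num_params * batch_size * 1045 FLOPs. The trainer-side factor
2 * (3 + int(enable_grad_checkpoint)) likewise counts the usual 2 FLOPs per
parameter per token, times 1x for the forward pass and 2x for the backward pass,
plus one extra forward when activations are recomputed under gradient
checkpointing.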