@@ -19,6 +19,14 @@ def print_rank_0(*args, **kwargs) -> None:
print(*args, **kwargs)


def divide(x: float, y: float) -> float:
if y == 0:
return float('inf')
elif y == float('inf'):
return float('nan')
return x / y


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
@@ -29,6 +37,24 @@ def all_reduce_mean(x: float, world_size: int) -> float:
return tensor.item()


class Timer:

def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.

def start(self) -> None:
self.start_time = time()

def end(self) -> None:
assert self.start_time is not None
self.duration += time() - self.start_time
self.start_time = None

def reset(self) -> None:
self.duration = 0.
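# Note: duration accumulates across repeated start()/end() pairs until reset() is
# called, so a single Timer instance can sum a phase's time over many batches or episodes.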


class PerformanceEvaluator(Callback):
"""
Callback for evaluating the performance of the model.
@@ -58,27 +84,34 @@ def __init__(self,
self.ignore_episodes = ignore_episodes
self.disable: bool = False

self.make_experience_duration: float = 0.
self.make_experience_start_time: Optional[float] = None
self.overall_timer = Timer()
self.make_experience_timer = Timer()
self.learn_timer = Timer()
self.make_experience_num_samples: int = 0
self.make_experience_flop: int = 0
self.learn_duration: float = 0.
self.learn_start_time: Optional[float] = None
self.learn_num_samples: int = 0
self.learn_flop: int = 0

def on_episode_start(self, episode: int) -> None:
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
if self.disable:
return
self.overall_timer.start()

def on_episode_end(self, episode: int) -> None:
if self.disable:
return
self.overall_timer.end()
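# The overall timer thus accumulates whole-episode wall time (experience generation,
# learning, and everything in between), which the overall throughput below is measured against.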

def on_make_experience_start(self) -> None:
if self.disable:
return
self.make_experience_start_time = time()
self.make_experience_timer.start()

def on_make_experience_end(self, experience: Experience) -> None:
if self.disable:
return
self.make_experience_duration += time() - self.make_experience_start_time
self.make_experience_timer.end()

batch_size, seq_len = experience.sequences.shape

@@ -101,12 +134,12 @@ def on_make_experience_end(self, experience: Experience) -> None:
def on_learn_batch_start(self) -> None:
if self.disable:
return
self.learn_start_time = time()
self.learn_timer.start()

def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
if self.disable:
return
self.learn_duration += time() - self.learn_start_time
self.learn_timer.end()

batch_size, seq_len = experience.sequences.shape

@@ -118,16 +151,33 @@ def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
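# Presumably the standard transformer FLOP estimate: ~2 FLOPs per parameter per token
# for a forward pass, x3 to cover forward + backward, plus one extra forward recompute
# when gradient checkpointing is enabled (hence the 3 + 1).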

def on_fit_end(self) -> None:
avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)

avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
avg_make_experience_throughput = self.make_experience_num_samples * \
self.world_size / (avg_make_experience_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)

avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)

num_effective_samples = min(self.learn_num_samples, self.make_experience_num_samples) * self.world_size
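# Sample counts are tracked per rank, hence the world_size scaling for global throughput;
# min() presumably takes the conservative effective count in case the generate and train
# phases processed different numbers of samples.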

avg_overall_throughput = num_effective_samples / (avg_overall_duration + 1e-12)

overall_time_per_sample = divide(1, avg_overall_throughput)
make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples)
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)

print_rank_0(
f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
f'Performance summary:\n' +
f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' +
f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' +
f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' +
f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
)
print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
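
For reference, a minimal single-process sketch (toy numbers, illustrative only, not part of this change) of the arithmetic behind the new summary: per-rank sample counts are scaled to global counts, rank-averaged durations become throughputs, and per-sample times are derived from the phase durations.

# Toy walk-through of the on_fit_end arithmetic (hypothetical numbers).
world_size = 4                          # number of ranks (assumed)
make_experience_num_samples = 256       # per-rank count
learn_num_samples = 256                 # per-rank count
avg_make_experience_duration = 100.0    # seconds, averaged across ranks
avg_learn_duration = 60.0
avg_overall_duration = 180.0

make_experience_throughput = make_experience_num_samples * world_size / (avg_make_experience_duration + 1e-12)
learn_throughput = learn_num_samples * world_size / (avg_learn_duration + 1e-12)

num_effective_samples = min(learn_num_samples, make_experience_num_samples) * world_size
overall_throughput = num_effective_samples / (avg_overall_duration + 1e-12)

overall_time_per_sample = 1 / overall_throughput
make_experience_time_per_sample = avg_make_experience_duration / num_effective_samples
learn_time_per_sample = avg_learn_duration / num_effective_samples

print(f'generate: {make_experience_throughput:.2f} samples/s, '
      f'train: {learn_throughput:.2f} samples/s, '
      f'overall: {overall_throughput:.2f} samples/s')
print(f'generation: {make_experience_time_per_sample / overall_time_per_sample * 100:.2f}% of time, '
      f'learning: {learn_time_per_sample / overall_time_per_sample * 100:.2f}%')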