From d97b215b9344ccd68c4253a7ffb419df2f49ee4c Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 10 Dec 2020 21:08:02 -0800 Subject: [PATCH 1/4] implement missing get_last_lr --- deepspeed/runtime/lr_schedules.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 5ec106c28d67..33dc2e2fb18b 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -381,6 +381,11 @@ def get_lr(self): lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr ] + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + def _update_optimizer(self, group_lrs): for param_group, lr in zip(self.optimizer.param_groups, group_lrs): param_group['lr'] = lr @@ -390,6 +395,7 @@ def step(self, batch_iteration=None): batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration self._update_optimizer(self.get_lr()) + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -628,12 +634,18 @@ def get_lr(self): return self._get_cycle_lr() return self._get_decay_lr(self.last_batch_iteration - self.total_size) + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + def step(self, batch_iteration=None): if batch_iteration is None: batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -690,12 +702,18 @@ def get_lr(self): self.delta_lrs) ] + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + def step(self, last_batch_iteration=None): if last_batch_iteration is None: last_batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = last_batch_iteration for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} From 6c508a056ec31450b7aef3a9f54044143e207301 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 11 Dec 2020 15:21:43 -0800 Subject: [PATCH 2/4] assert if _last_lr is not set yet --- deepspeed/runtime/lr_schedules.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 33dc2e2fb18b..645152b1ea1f 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -384,6 +384,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ + assert getattr(self, '_last_lr', None), "need to call step() first" return self._last_lr def _update_optimizer(self, group_lrs): @@ -637,6 +638,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ + assert getattr(self, '_last_lr', None), "need to call step() first" return self._last_lr def step(self, batch_iteration=None): @@ -705,6 +707,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ + assert getattr(self, '_last_lr', None), "need to call step() first" return self._last_lr def step(self, last_batch_iteration=None): From 0570803890e746a6f25999b8c49f851414d02178 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 11 Dec 2020 18:21:54 -0800 Subject: [PATCH 3/4] improve the assert --- deepspeed/runtime/lr_schedules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 645152b1ea1f..135df424d29d 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -384,7 +384,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ - assert getattr(self, '_last_lr', None), "need to call step() first" + assert getattr(self, '_last_lr', False), "need to call step() first" return self._last_lr def _update_optimizer(self, group_lrs): @@ -638,7 +638,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ - assert getattr(self, '_last_lr', None), "need to call step() first" + assert getattr(self, '_last_lr', False), "need to call step() first" return self._last_lr def step(self, batch_iteration=None): @@ -707,7 +707,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ - assert getattr(self, '_last_lr', None), "need to call step() first" + assert getattr(self, '_last_lr', False), "need to call step() first" return self._last_lr def step(self, last_batch_iteration=None): From 7b55777281aa26aaed3eca61ddb64b9ae169c06e Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 12 Dec 2020 10:55:54 -0800 Subject: [PATCH 4/4] really fix the assert --- deepspeed/runtime/lr_schedules.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 135df424d29d..e7e3be1e786b 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -384,7 +384,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ - assert getattr(self, '_last_lr', False), "need to call step() first" + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" return self._last_lr def _update_optimizer(self, group_lrs): @@ -638,7 +638,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ - assert getattr(self, '_last_lr', False), "need to call step() first" + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" return self._last_lr def step(self, batch_iteration=None): @@ -707,7 +707,7 @@ def get_lr(self): def get_last_lr(self): """ Return last computed learning rate by current scheduler. """ - assert getattr(self, '_last_lr', False), "need to call step() first" + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" return self._last_lr def step(self, last_batch_iteration=None):