From b40cb35e2b39017a24c6b71532bdcd883be3c1e1 Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 00:29:09 -0500 Subject: [PATCH 1/7] Support for complex LR scheduler configuration --- basics/base_task.py | 29 +++++++++------------------ utils/__init__.py | 48 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index f82e1442c..cc6429c43 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -265,27 +265,24 @@ def on_validation_epoch_end(self): # noinspection PyMethodMayBeStatic def build_scheduler(self, optimizer): - from utils import build_object_from_config + from utils import build_lr_scheduler_from_config scheduler_args = hparams['lr_scheduler_args'] assert scheduler_args['scheduler_cls'] != '' - scheduler = build_object_from_config( - scheduler_args['scheduler_cls'], - optimizer, - **scheduler_args - ) + scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) return scheduler # noinspection PyMethodMayBeStatic def build_optimizer(self, model): - from utils import build_object_from_config + from utils import build_object_from_class_name optimizer_args = hparams['optimizer_args'] assert optimizer_args['optimizer_cls'] != '' if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args: optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2']) - optimizer = build_object_from_config( + optimizer = build_object_from_class_name( optimizer_args['optimizer_cls'], + torch.optim.Optimizer, filter(lambda p: p.requires_grad, model.parameters()), **optimizer_args ) @@ -478,21 +475,13 @@ def on_load_checkpoint(self, checkpoint): if checkpoint.get('lr_schedulers', None): assert checkpoint.get('optimizer_states', False) - schedulers = checkpoint['lr_schedulers'] - assert len(schedulers) == 1 # only support one scheduler - scheduler = schedulers[0] - for k, v in scheduler_args.items(): - if k in scheduler and scheduler[k] != v: - rank_zero_info(f'| Overriding scheduler parameter {k} from checkpoint: {scheduler[k]} -> {v}') - scheduler[k] = v - scheduler['base_lrs'] = [group['initial_lr'] for group in checkpoint['optimizer_states'][0]['param_groups']] - new_lrs = simulate_lr_scheduler( + assert len(checkpoint['lr_schedulers']) == 1 # only support one scheduler + checkpoint['lr_schedulers'][0] = simulate_lr_scheduler( optimizer_args, scheduler_args, - last_epoch=scheduler['last_epoch'], + step_count=checkpoint['global_step'], num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups']) ) - scheduler['_last_lr'] = new_lrs - for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs): + for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], checkpoint['lr_schedulers'][0]['_last_lr']): if param_group['lr'] != new_lr: rank_zero_info( f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}') diff --git a/utils/__init__.py b/utils/__init__.py index 4c88f1598..8e494bedc 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -260,29 +260,57 @@ def num_params(model, print_out=True, model_name="model"): return parameters -def build_object_from_config(cls_str, *args, **kwargs): +def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs): import importlib pkg = ".".join(cls_str.split(".")[:-1]) cls_name = cls_str.split(".")[-1] cls_type = getattr(importlib.import_module(pkg), cls_name) + if parent_cls is not None: + assert issubclass(cls_type, parent_cls), f'| {cls_type} is not subclass of {parent_cls}.' return cls_type(*args, **filter_kwargs(kwargs, cls_type)) +def build_lr_scheduler_from_config(optimizer, scheduler_args): + def helper(params): + if isinstance(params, list): + return [helper(s) for s in params] + elif isinstance(params, dict): + resolved = {k: helper(v) for k, v in params.items()} + resolved['optimizer'] = optimizer + if 'cls' in resolved: + obj = build_object_from_class_name( + resolved['cls'], + torch.optim.lr_scheduler.LRScheduler, + **resolved + ) + if not hasattr(obj, 'last_epoch'): + obj.last_epoch = -1 + return obj + return resolved + else: + return params + resolved = helper(scheduler_args) + resolved['optimizer'] = optimizer + return build_object_from_class_name( + scheduler_args['scheduler_cls'], + torch.optim.lr_scheduler.LRScheduler, + **resolved + ) -def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_param_groups=1): - optimizer = build_object_from_config( +def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1): + optimizer = build_object_from_class_name( optimizer_args['optimizer_cls'], + torch.optim.Optimizer, [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)], **optimizer_args ) - scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch, - **scheduler_args) - - if hasattr(scheduler, '_get_closed_form_lr'): - return scheduler._get_closed_form_lr() - else: - return scheduler.get_lr() + scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) + scheduler._initial_step() + optimizer._step_count = 1 + for _ in range(step_count): + scheduler.step() + return scheduler.state_dict() def remove_suffix(string: str, suffix: str): From 16ccf55447b20dbd5de4dd2e1e2a56349a330086 Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 00:38:17 -0500 Subject: [PATCH 2/7] Pass in optimizer only if cls exists --- utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/__init__.py b/utils/__init__.py index 8e494bedc..2785a31e5 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -277,8 +277,8 @@ def helper(params): return [helper(s) for s in params] elif isinstance(params, dict): resolved = {k: helper(v) for k, v in params.items()} - resolved['optimizer'] = optimizer if 'cls' in resolved: + resolved['optimizer'] = optimizer obj = build_object_from_class_name( resolved['cls'], torch.optim.lr_scheduler.LRScheduler, From 733f7f8f052b0634aab07d5f45ef8fbe220ccb2d Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 00:29:09 -0500 Subject: [PATCH 3/7] Support for complex LR scheduler configuration --- basics/base_task.py | 29 +++++++++------------------ utils/__init__.py | 48 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 45a3bd8d3..1f39ae56d 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -289,27 +289,24 @@ def on_validation_epoch_end(self): # noinspection PyMethodMayBeStatic def build_scheduler(self, optimizer): - from utils import build_object_from_config + from utils import build_lr_scheduler_from_config scheduler_args = hparams['lr_scheduler_args'] assert scheduler_args['scheduler_cls'] != '' - scheduler = build_object_from_config( - scheduler_args['scheduler_cls'], - optimizer, - **scheduler_args - ) + scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) return scheduler # noinspection PyMethodMayBeStatic def build_optimizer(self, model): - from utils import build_object_from_config + from utils import build_object_from_class_name optimizer_args = hparams['optimizer_args'] assert optimizer_args['optimizer_cls'] != '' if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args: optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2']) - optimizer = build_object_from_config( + optimizer = build_object_from_class_name( optimizer_args['optimizer_cls'], + torch.optim.Optimizer, model.parameters(), **optimizer_args ) @@ -502,21 +499,13 @@ def on_load_checkpoint(self, checkpoint): if checkpoint.get('lr_schedulers', None): assert checkpoint.get('optimizer_states', False) - schedulers = checkpoint['lr_schedulers'] - assert len(schedulers) == 1 # only support one scheduler - scheduler = schedulers[0] - for k, v in scheduler_args.items(): - if k in scheduler and scheduler[k] != v: - rank_zero_info(f'| Overriding scheduler parameter {k} from checkpoint: {scheduler[k]} -> {v}') - scheduler[k] = v - scheduler['base_lrs'] = [group['initial_lr'] for group in checkpoint['optimizer_states'][0]['param_groups']] - new_lrs = simulate_lr_scheduler( + assert len(checkpoint['lr_schedulers']) == 1 # only support one scheduler + checkpoint['lr_schedulers'][0] = simulate_lr_scheduler( optimizer_args, scheduler_args, - last_epoch=scheduler['last_epoch'], + step_count=checkpoint['global_step'], num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups']) ) - scheduler['_last_lr'] = new_lrs - for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs): + for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], checkpoint['lr_schedulers'][0]['_last_lr']): if param_group['lr'] != new_lr: rank_zero_info( f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}') diff --git a/utils/__init__.py b/utils/__init__.py index 4c88f1598..8e494bedc 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -260,29 +260,57 @@ def num_params(model, print_out=True, model_name="model"): return parameters -def build_object_from_config(cls_str, *args, **kwargs): +def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs): import importlib pkg = ".".join(cls_str.split(".")[:-1]) cls_name = cls_str.split(".")[-1] cls_type = getattr(importlib.import_module(pkg), cls_name) + if parent_cls is not None: + assert issubclass(cls_type, parent_cls), f'| {cls_type} is not subclass of {parent_cls}.' return cls_type(*args, **filter_kwargs(kwargs, cls_type)) +def build_lr_scheduler_from_config(optimizer, scheduler_args): + def helper(params): + if isinstance(params, list): + return [helper(s) for s in params] + elif isinstance(params, dict): + resolved = {k: helper(v) for k, v in params.items()} + resolved['optimizer'] = optimizer + if 'cls' in resolved: + obj = build_object_from_class_name( + resolved['cls'], + torch.optim.lr_scheduler.LRScheduler, + **resolved + ) + if not hasattr(obj, 'last_epoch'): + obj.last_epoch = -1 + return obj + return resolved + else: + return params + resolved = helper(scheduler_args) + resolved['optimizer'] = optimizer + return build_object_from_class_name( + scheduler_args['scheduler_cls'], + torch.optim.lr_scheduler.LRScheduler, + **resolved + ) -def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_param_groups=1): - optimizer = build_object_from_config( +def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1): + optimizer = build_object_from_class_name( optimizer_args['optimizer_cls'], + torch.optim.Optimizer, [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)], **optimizer_args ) - scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch, - **scheduler_args) - - if hasattr(scheduler, '_get_closed_form_lr'): - return scheduler._get_closed_form_lr() - else: - return scheduler.get_lr() + scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) + scheduler._initial_step() + optimizer._step_count = 1 + for _ in range(step_count): + scheduler.step() + return scheduler.state_dict() def remove_suffix(string: str, suffix: str): From b5e58b0b0a0c1726ce4d8d2578cdb3f90f3b59ad Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 00:38:17 -0500 Subject: [PATCH 4/7] Pass in optimizer only if cls exists --- utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/__init__.py b/utils/__init__.py index 8e494bedc..2785a31e5 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -277,8 +277,8 @@ def helper(params): return [helper(s) for s in params] elif isinstance(params, dict): resolved = {k: helper(v) for k, v in params.items()} - resolved['optimizer'] = optimizer if 'cls' in resolved: + resolved['optimizer'] = optimizer obj = build_object_from_class_name( resolved['cls'], torch.optim.lr_scheduler.LRScheduler, From 7c70cdcf20a1d966542ec8121ecefccc0552b9df Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 16:27:33 -0500 Subject: [PATCH 5/7] Done --- utils/__init__.py | 12 ++++++++---- utils/training_utils.py | 4 +++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/utils/__init__.py b/utils/__init__.py index 2785a31e5..df052d373 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -271,6 +271,7 @@ def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs): return cls_type(*args, **filter_kwargs(kwargs, cls_type)) + def build_lr_scheduler_from_config(optimizer, scheduler_args): def helper(params): if isinstance(params, list): @@ -278,14 +279,17 @@ def helper(params): elif isinstance(params, dict): resolved = {k: helper(v) for k, v in params.items()} if 'cls' in resolved: + if ( + resolved["cls"] == "torch.optim.lr_scheduler.ChainedScheduler" + and scheduler_args["scheduler_cls"] == "torch.optim.lr_scheduler.SequentialLR" + ): + raise ValueError(f"ChainedScheduler cannot be part of a SequentialLR.") resolved['optimizer'] = optimizer obj = build_object_from_class_name( resolved['cls'], torch.optim.lr_scheduler.LRScheduler, **resolved ) - if not hasattr(obj, 'last_epoch'): - obj.last_epoch = -1 return obj return resolved else: @@ -298,6 +302,7 @@ def helper(params): **resolved ) + def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1): optimizer = build_object_from_class_name( optimizer_args['optimizer_cls'], @@ -306,8 +311,7 @@ def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_ **optimizer_args ) scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) - scheduler._initial_step() - optimizer._step_count = 1 + scheduler.optimizer._step_count = 1 for _ in range(step_count): scheduler.step() return scheduler.state_dict() diff --git a/utils/training_utils.py b/utils/training_utils.py index db8fc0657..ad7f0473a 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -316,7 +316,9 @@ def get_metrics(self, trainer, model): items['steps'] = str(trainer.global_step) for k, v in items.items(): if isinstance(v, float): - if 0.001 <= v < 10: + if np.isnan(v): + items[k] = 'nan' + elif 0.001 <= v < 10: items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-') elif 0.00001 <= v < 0.001: if len(np.format_float_positional(v, unique=True, precision=8, trim='-')) > 8: From 154230dc658e2edfc0c0d8e5782e0fc72283f3f2 Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 16:31:46 -0500 Subject: [PATCH 6/7] Template config change --- configs/base.yaml | 1 - configs/templates/config_variance.yaml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/base.yaml b/configs/base.yaml index a64e8dda1..6f16e22f6 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -55,7 +55,6 @@ optimizer_args: weight_decay: 0 lr_scheduler_args: scheduler_cls: torch.optim.lr_scheduler.StepLR - warmup_steps: 2000 step_size: 50000 gamma: 0.5 clip_grad_norm: 1 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 8bb2d44f0..a54ca94a7 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -67,6 +67,7 @@ lambda_var_loss: 1.0 optimizer_args: lr: 0.0006 lr_scheduler_args: + scheduler_cls: torch.optim.lr_scheduler.StepLR step_size: 12000 gamma: 0.75 max_batch_frames: 80000 From a6ac2f687ac9ac5c3466dbefee6cad1319c876e1 Mon Sep 17 00:00:00 2001 From: hrukalive Date: Fri, 4 Aug 2023 16:36:01 -0500 Subject: [PATCH 7/7] Reformat --- basics/base_task.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index ebf38d2b8..70d3d692c 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -494,7 +494,8 @@ def on_load_checkpoint(self, checkpoint): param_group[k] = v if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']: rank_zero_info( - f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}') + f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}' + ) param_group['initial_lr'] = optimizer_args['lr'] if checkpoint.get('lr_schedulers', None): @@ -505,8 +506,10 @@ def on_load_checkpoint(self, checkpoint): step_count=checkpoint['global_step'], num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups']) ) - for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], checkpoint['lr_schedulers'][0]['_last_lr']): + for param_group, new_lr in zip( + checkpoint['optimizer_states'][0]['param_groups'], + checkpoint['lr_schedulers'][0]['_last_lr'], + ): if param_group['lr'] != new_lr: - rank_zero_info( - f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}') - param_group['lr'] = new_lr \ No newline at end of file + rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}') + param_group['lr'] = new_lr