From 4184eb1d20b0657260caf9f0b97a645bd962601c Mon Sep 17 00:00:00 2001
From: jason9075
Date: Mon, 17 Apr 2023 20:22:11 +0800
Subject: [PATCH 1/6] add constant lr with rules

---
 src/diffusers/optimization.py | 46 ++++++++++++++++++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 657e085062e0..daeb9d5f049d 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -16,7 +16,7 @@
 import math
 from enum import Enum
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
 
@@ -34,6 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    CONSTANT_WITH_RULES = "constant_with_rules"
 
 
 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -77,6 +78,49 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
+def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate with rule for the learning rate.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        rule (`string`):
+            The rules for the learning rate.
+            ex: rules="1:10,0.1:20,0.01:30,0.005"
+            it means that the learning rate is multiple 1 for the first 10 steps, mutiple 0.1 for the
+            next 20 steps, multiple 0.01 for the next 30 steps and multiple 0.005 for the other steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    rules_dict = {}
+    rule_list = rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr = float(rule_list[-1])
+
+    def create_rules_function():
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr
+
+        return rule_func
+
+    rules_f = create_rules_function()
+
+    return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
+
+
 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after

From 8dde211806c46d857a361635a4f035e9d55b18f8 Mon Sep 17 00:00:00 2001
From: jason9075
Date: Mon, 17 Apr 2023 20:30:35 +0800
Subject: [PATCH 2/6] add constant with rules in TYPE_TO_SCHEDULER_FUNCTION

---
 src/diffusers/optimization.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index daeb9d5f049d..4127d6c492b2 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -276,6 +276,7 @@ def lr_lambda(current_step: int):
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.CONSTANT_WITH_RULES: get_constant_schedule_with_rules,
 }
 
 

From 1ad35207da7c21db8d68aab1f8d3f1cd73b91e80 Mon Sep 17 00:00:00 2001
From: jason9075
Date: Mon, 17 Apr 2023 22:40:30 +0800
Subject: [PATCH 3/6] add constant lr rate with rule

---
 src/diffusers/optimization.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 4127d6c492b2..983a2982179a 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -106,7 +106,7 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
         rules_dict[steps] = value
     last_lr = float(rule_list[-1])
 
-    def create_rules_function():
+    def create_rules_function(rules_dict, last_lr):
         def rule_func(steps: int) -> float:
             sorted_steps = sorted(rules_dict.keys())
             for i, sorted_step in enumerate(sorted_steps):
@@ -116,7 +116,7 @@ def rule_func(steps: int) -> float:
 
         return rule_func
 
-    rules_f = create_rules_function()
+    rules_f = create_rules_function(rules_dict, last_lr)
 
     return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
 
@@ -283,6 +283,7 @@ def lr_lambda(current_step: int):
 def get_scheduler(
     name: Union[str, SchedulerType],
     optimizer: Optimizer,
+    rules: Optional[str] = None,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
     num_cycles: int = 1,
@@ -315,6 +316,9 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer, last_epoch=last_epoch)
 
+    if name == SchedulerType.CONSTANT_WITH_RULES:
+        return schedule_func(optimizer, rules=rules, last_epoch=last_epoch)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

From cbf0b35901ede517122e1ccb3fa1c13f7e47a2d5 Mon Sep 17 00:00:00 2001
From: jason9075
Date: Mon, 17 Apr 2023 22:49:41 +0800
Subject: [PATCH 4/6] hotfix code quality

---
 src/diffusers/optimization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 983a2982179a..3df645180c1f 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -16,7 +16,7 @@
 import math
 from enum import Enum
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
 

From 489bb563afd68d7d819a702a5093b0d738374078 Mon Sep 17 00:00:00 2001
From: jason9075
Date: Mon, 17 Apr 2023 23:22:17 +0800
Subject: [PATCH 5/6] fix doc style

---
 src/diffusers/optimization.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 3df645180c1f..55e9f0969ba3 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -86,10 +86,9 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
         rule (`string`):
-            The rules for the learning rate.
-            ex: rules="1:10,0.1:20,0.01:30,0.005"
-            it means that the learning rate is multiple 1 for the first 10 steps, mutiple 0.1 for the
-            next 20 steps, multiple 0.01 for the next 30 steps and multiple 0.005 for the other steps.
+            The rules for the learning rate. ex: rules="1:10,0.1:20,0.01:30,0.005" it means that the learning rate is
+            multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps
+            and multiple 0.005 for the other steps.
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.
 

From 348fdbd04a92933ec43eae437c6c254329a951fc Mon Sep 17 00:00:00 2001
From: jason9075
Date: Thu, 27 Apr 2023 20:02:24 +0800
Subject: [PATCH 6/6] change name constant_with_rules to piecewise constant

---
 src/diffusers/optimization.py | 36 ++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 55e9f0969ba3..78d68b7978a9 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -34,7 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
-    CONSTANT_WITH_RULES = "constant_with_rules"
+    PIECEWISE_CONSTANT = "piecewise_constant"
 
 
 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -78,17 +78,17 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
-def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoch: int = -1):
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
     """
-    Create a schedule with a constant learning rate with rule for the learning rate.
+    Create a schedule with a constant learning rate, using the learning rate set in optimizer.
 
     Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
-        rule (`string`):
-            The rules for the learning rate. ex: rules="1:10,0.1:20,0.01:30,0.005" it means that the learning rate is
-            multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps
-            and multiple 0.005 for the other steps.
+        step_rules (`string`):
+            The rules for the learning rate. ex: step_rules="1:10,0.1:20,0.01:30,0.005" it means that the learning
+            rate is multiplied by 1 for the first 10 steps, by 0.1 until step 20, by 0.01 until step 30, and by 0.005
+            for any later step.
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.
 
@@ -97,27 +97,27 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
     """
 
     rules_dict = {}
-    rule_list = rules.split(",")
+    rule_list = step_rules.split(",")
     for rule_str in rule_list[:-1]:
         value_str, steps_str = rule_str.split(":")
         steps = int(steps_str)
         value = float(value_str)
         rules_dict[steps] = value
-    last_lr = float(rule_list[-1])
+    last_lr_multiple = float(rule_list[-1])
 
-    def create_rules_function(rules_dict, last_lr):
+    def create_rules_function(rules_dict, last_lr_multiple):
         def rule_func(steps: int) -> float:
             sorted_steps = sorted(rules_dict.keys())
             for i, sorted_step in enumerate(sorted_steps):
                 if steps < sorted_step:
                     return rules_dict[sorted_steps[i]]
-            return last_lr
+            return last_lr_multiple
 
         return rule_func
 
-    rules_f = create_rules_function(rules_dict, last_lr)
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)
 
-    return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
+    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
 
 
 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
@@ -275,14 +275,14 @@ def lr_lambda(current_step: int):
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
-    SchedulerType.CONSTANT_WITH_RULES: get_constant_schedule_with_rules,
+    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
 }
 
 
 def get_scheduler(
     name: Union[str, SchedulerType],
     optimizer: Optimizer,
-    rules: Optional[str] = None,
+    step_rules: Optional[str] = None,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
     num_cycles: int = 1,
@@ -297,6 +297,8 @@ def get_scheduler(
             The name of the scheduler to use.
         optimizer (`torch.optim.Optimizer`):
             The optimizer that will be used during training.
+        step_rules (`str`, *optional*):
+            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
         num_warmup_steps (`int`, *optional*):
             The number of warmup steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -315,8 +317,8 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer, last_epoch=last_epoch)
 
-    if name == SchedulerType.CONSTANT_WITH_RULES:
-        return schedule_func(optimizer, rules=rules, last_epoch=last_epoch)
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
 
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
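
Usage note (illustrative, not part of the patch series): the sketch below assumes the final state after [PATCH 6/6] and uses a throwaway parameter and SGD optimizer as stand-ins for a real training setup; it only shows how the step-rules string drives the learning-rate multipliers.

    import torch

    from diffusers.optimization import get_piecewise_constant_schedule

    # Placeholder parameter and optimizer; any optimizer with an initial lr works.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)

    # "1:10,0.1:20,0.01:30,0.005": multiply the base lr by 1 before step 10,
    # by 0.1 before step 20, by 0.01 before step 30, and by 0.005 afterwards.
    scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")

    for _ in range(35):
        optimizer.step()
        scheduler.step()
        # scheduler.get_last_lr() passes through 0.1, 0.01, 0.001 and finally 0.0005.

The same string can also be passed to get_scheduler via its new step_rules argument when the PIECEWISE_CONSTANT scheduler is selected.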