50 changes: 50 additions & 0 deletions src/diffusers/optimization.py
@@ -34,6 +34,7 @@ class SchedulerType(Enum):
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -77,6 +78,48 @@ def lr_lambda(current_step: int):
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
    """
    Create a schedule with a piecewise constant learning rate, using the learning rate set in the optimizer as the
    base value.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`str`):
            The rules for the learning rate, e.g. step_rules="1:10,0.1:20,0.01:30,0.005": the learning rate is
            multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by 0.005 for all
            remaining steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    # Parse rules such as "1:10,0.1:20,0.005" into {step_threshold: multiplier}; the trailing value
    # (the one without a ":") is the multiplier used once the last threshold has been passed.
    rules_dict = {}
    rule_list = step_rules.split(",")
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        rules_dict[steps] = value
    last_lr_multiple = float(rule_list[-1])

    def create_rules_function(rules_dict, last_lr_multiple):
        def rule_func(steps: int) -> float:
            # Return the multiplier attached to the first threshold the current step has not yet reached.
            sorted_steps = sorted(rules_dict.keys())
            for i, sorted_step in enumerate(sorted_steps):
                if steps < sorted_step:
                    return rules_dict[sorted_steps[i]]
            return last_lr_multiple

        return rule_func

    rules_func = create_rules_function(rules_dict, last_lr_multiple)

    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)


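For reference (not part of this diff), a minimal usage sketch of the new schedule; the dummy parameter, base learning rate, and rule string below are made up for illustration:

# Illustration only: assumes torch and diffusers are installed.
import torch
from diffusers.optimization import get_piecewise_constant_schedule

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")

# With a base lr of 0.1: steps 0-9 use 0.1, steps 10-19 use 0.01, steps 20-29 use 0.001,
# and step 30 onward uses 0.0005.
for step in range(35):
    optimizer.step()
    scheduler.step()
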
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
@@ -232,12 +275,14 @@ def lr_lambda(current_step: int):
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
@@ -252,6 +297,8 @@ def get_scheduler(
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional); the function will raise an error if it's unset and the scheduler type requires it.
@@ -270,6 +317,9 @@ def get_scheduler(
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
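And a hedged sketch of requesting the new scheduler type through get_scheduler; the AdamW setup and rule string are placeholders, not taken from this PR:

# Illustration only: assumes torch and diffusers are installed.
import torch
from diffusers.optimization import get_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

# "piecewise_constant" resolves to SchedulerType.PIECEWISE_CONSTANT; num_warmup_steps and
# num_training_steps are not needed for this scheduler type.
lr_scheduler = get_scheduler(
    "piecewise_constant",
    optimizer=optimizer,
    step_rules="1:1000,0.1:2000,0.01",
)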