Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions deepmd/dpmodel/utils/learning_rate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,56 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
)

import numpy as np

from deepmd.common import (
j_get_type,
)
from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)


class BaseLR(ABC, PluginVariant, make_plugin_registry("lr")):
    def __new__(cls: type, *args: Any, **kwargs: Any) -> Any:
        # Subclasses are constructed directly; only construction through the
        # base class dispatches to the registered concrete schedule, looked up
        # by the "type" key of the keyword arguments.
        if cls is not BaseLR:
            return super().__new__(cls)
        concrete = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
        return super().__new__(concrete)

    def __init__(
        self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any
    ) -> None:
        """
        Base class for learning rate schedules.

        Parameters
        ----------
        start_lr
            The initial learning rate.
        stop_lr
            The final learning rate.
        stop_steps
            The total training steps for learning rate scheduler.
        """
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        self.stop_steps = stop_steps

    @abstractmethod
    def value(self, step: int) -> np.float64:
        """Get the learning rate at the given step."""


class LearningRateExp:
@BaseLR.register("exp")
class LearningRateExp(BaseLR):
def __init__(
self,
start_lr: float,
Expand Down Expand Up @@ -37,7 +81,7 @@ def __init__(
If provided, the decay rate will be set instead of
calculating it through interpolation between start_lr and stop_lr.
"""
self.start_lr = start_lr
super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
self.decay_steps = decay_steps
if self.decay_steps >= stop_steps:
Expand All @@ -47,11 +91,49 @@ def __init__(
)
if decay_rate is not None:
self.decay_rate = decay_rate
self.min_lr = stop_lr
self.min_lr = self.stop_lr

def value(self, step: int) -> np.float64:
    """Get the learning rate at the given step.

    The rate decays by a factor of ``decay_rate`` every ``decay_steps``
    steps and is clamped from below at ``min_lr``.
    """
    decayed = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
    # max() returns the decayed value whenever it has not yet fallen
    # below the floor, matching the original conditional clamp.
    return max(decayed, self.min_lr)


@BaseLR.register("cosine")
class LearningRateCosine(BaseLR):
    def __init__(
        self,
        start_lr: float,
        stop_lr: float,
        stop_steps: int,
        **kwargs: Any,
    ) -> None:
        """
        Defines a cosine annealing learning rate schedule.

        The learning rate follows half a cosine period: it starts at
        `start_lr` and gradually decreases to `stop_lr` over the
        training steps.

        Parameters
        ----------
        start_lr
            The initial learning rate at the beginning of training.
        stop_lr
            The final learning rate at the end of training.
        stop_steps
            The total number of training steps over which the learning rate
            will be annealed from start_lr to stop_lr.
        """
        super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
        # Ratio of the final to the initial learning rate; the schedule
        # interpolates between 1.0 and this factor.
        self.lr_min_factor = stop_lr / start_lr

    def value(self, step: int) -> np.float64:
        """Get the learning rate at the given step."""
        if step >= self.stop_steps:
            # Past the end of the schedule: hold at the final learning rate.
            return self.start_lr * self.lr_min_factor
        # (1 + cos) runs from 2 at step 0 down to ~0 at stop_steps, so the
        # scale interpolates from 1.0 down to lr_min_factor.
        cosine = 1 + np.cos(np.pi * (step / self.stop_steps))
        scale = self.lr_min_factor + 0.5 * (1 - self.lr_min_factor) * cosine
        return self.start_lr * scale
11 changes: 4 additions & 7 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
SAMPLER_RECORD,
)
from deepmd.pt.utils.learning_rate import (
LearningRateExp,
BaseLR,
)
from deepmd.pt.utils.stat import (
make_stat_input,
Expand Down Expand Up @@ -266,13 +266,10 @@ def get_sample() -> Any:
_stat_file_path.root.close()
return get_sample

def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
assert lr_params.get("type", "exp") == "exp", (
"Only learning rate `exp` is supported!"
)
def get_lr(lr_params: dict[str, Any]) -> BaseLR:
lr_params["stop_steps"] = self.num_steps - self.warmup_steps
lr_exp = LearningRateExp(**lr_params)
return lr_exp
lr_schedule = BaseLR(**lr_params)
return lr_schedule

# Optimizer
if self.multi_task and training_params.get("optim_dict", None) is not None:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/utils/learning_rate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.utils.learning_rate import (
BaseLR,
LearningRateCosine,
LearningRateExp,
)

__all__ = [
"BaseLR",
"LearningRateCosine",
"LearningRateExp",
]
24 changes: 23 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2477,6 +2477,10 @@ def linear_ener_model_args() -> Argument:


# --- Learning rate configurations: --- #
lr_args_plugin = ArgsPlugin()


@lr_args_plugin.register("exp")
def learning_rate_exp() -> list[Argument]:
doc_start_lr = "The learning rate at the start of the training."
doc_stop_lr = (
Expand Down Expand Up @@ -2509,12 +2513,30 @@ def learning_rate_exp() -> list[Argument]:
return args


@lr_args_plugin.register("cosine", doc=doc_only_pt_supported)
def learning_rate_cosine() -> list[Argument]:
    """
    Defines a cosine annealing learning rate schedule.

    The learning rate starts at `start_lr` and gradually decreases to `stop_lr`
    following a cosine curve over the training steps.
    """
    return [
        Argument(
            "start_lr",
            float,
            optional=True,
            default=1e-3,
            doc="The learning rate at the start of the training.",
        ),
        Argument(
            "stop_lr",
            float,
            optional=True,
            default=1e-5,
            doc="The desired learning rate at the end of the training. ",
        ),
    ]


def learning_rate_variant_type_args() -> Variant:
doc_lr = "The type of the learning rate."

return Variant(
"type",
[Argument("exp", dict, learning_rate_exp())],
lr_args_plugin.get_all_argument(),
optional=True,
default_tag="exp",
doc=doc_lr,
Expand Down
17 changes: 17 additions & 0 deletions source/tests/pt/test_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
tf.disable_eager_execution()

from deepmd.pt.utils.learning_rate import (
LearningRateCosine,
LearningRateExp,
)
from deepmd.tf.utils import (
Expand Down Expand Up @@ -102,5 +103,21 @@ def decay_rate_pt(self) -> None:
)


class TestLearningRateCosine(unittest.TestCase):
    def test_basic_curve(self) -> None:
        """Check endpoints, clamping past stop_steps, and the midpoint value."""
        start_lr, stop_lr, stop_steps = 1.0, 0.1, 10
        schedule = LearningRateCosine(start_lr, stop_lr, stop_steps)

        # Endpoints of the cosine curve.
        self.assertTrue(np.allclose(schedule.value(0), start_lr))
        self.assertTrue(np.allclose(schedule.value(stop_steps), stop_lr))
        # Steps beyond stop_steps hold the final learning rate.
        self.assertTrue(np.allclose(schedule.value(stop_steps + 5), stop_lr))

        # Halfway through, the cosine schedule sits at the arithmetic mean
        # of the start and stop learning rates.
        halfway = stop_steps // 2
        expected = (start_lr + stop_lr) * 0.5
        self.assertTrue(np.allclose(schedule.value(halfway), expected))


if __name__ == "__main__":
unittest.main()