diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py index 10f7ec8d04..f82a42660b 100644 --- a/deepmd/dpmodel/utils/learning_rate.py +++ b/deepmd/dpmodel/utils/learning_rate.py @@ -1,12 +1,56 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) from typing import ( Any, ) import numpy as np +from deepmd.common import ( + j_get_type, +) +from deepmd.utils.plugin import ( + PluginVariant, + make_plugin_registry, +) + + +class BaseLR(ABC, PluginVariant, make_plugin_registry("lr")): + def __new__(cls: type, *args: Any, **kwargs: Any) -> Any: + if cls is BaseLR: + cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) + return super().__new__(cls) + + def __init__( + self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any + ) -> None: + """ + Base class for learning rate schedules. + + Parameters + ---------- + start_lr + The initial learning rate. + stop_lr + The final learning rate. + stop_steps + The total training steps for learning rate scheduler. + """ + self.start_lr = start_lr + self.stop_lr = stop_lr + self.stop_steps = stop_steps + + @abstractmethod + def value(self, step: int) -> np.float64: + """Get the learning rate at the given step.""" + pass + -class LearningRateExp: +@BaseLR.register("exp") +class LearningRateExp(BaseLR): def __init__( self, start_lr: float, @@ -37,7 +81,7 @@ def __init__( If provided, the decay rate will be set instead of calculating it through interpolation between start_lr and stop_lr. """ - self.start_lr = start_lr + super().__init__(start_lr, stop_lr, stop_steps, **kwargs) default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1 self.decay_steps = decay_steps if self.decay_steps >= stop_steps: @@ -47,7 +91,7 @@ def __init__( ) if decay_rate is not None: self.decay_rate = decay_rate - self.min_lr = stop_lr + self.min_lr = self.stop_lr def value(self, step: int) -> np.float64: """Get the learning rate at the given step.""" @@ -55,3 +99,41 @@ def value(self, step: int) -> np.float64: if step_lr < self.min_lr: step_lr = self.min_lr return step_lr + + +@BaseLR.register("cosine") +class LearningRateCosine(BaseLR): + def __init__( + self, + start_lr: float, + stop_lr: float, + stop_steps: int, + **kwargs: Any, + ) -> None: + """ + Defines a cosine annealing learning rate schedule. + The learning rate starts at `start_lr` and gradually decreases to `stop_lr` + following a cosine curve over the training steps. + + Parameters + ---------- + start_lr + The initial learning rate at the beginning of training. + stop_lr + The final learning rate at the end of training. + stop_steps + The total number of training steps over which the learning rate + will be annealed from start_lr to stop_lr. + """ + super().__init__(start_lr, stop_lr, stop_steps, **kwargs) + self.lr_min_factor = stop_lr / start_lr + + def value(self, step: int) -> np.float64: + if step >= self.stop_steps: + return self.start_lr * self.lr_min_factor + return self.start_lr * ( + self.lr_min_factor + + 0.5 + * (1 - self.lr_min_factor) + * (1 + np.cos(np.pi * (step / self.stop_steps))) + ) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index d98b23d25c..7d768cf66b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -63,7 +63,7 @@ SAMPLER_RECORD, ) from deepmd.pt.utils.learning_rate import ( - LearningRateExp, + BaseLR, ) from deepmd.pt.utils.stat import ( make_stat_input, @@ -266,13 +266,10 @@ def get_sample() -> Any: _stat_file_path.root.close() return get_sample - def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: - assert lr_params.get("type", "exp") == "exp", ( - "Only learning rate `exp` is supported!" - ) + def get_lr(lr_params: dict[str, Any]) -> BaseLR: lr_params["stop_steps"] = self.num_steps - self.warmup_steps - lr_exp = LearningRateExp(**lr_params) - return lr_exp + lr_schedule = BaseLR(**lr_params) + return lr_schedule # Optimizer if self.multi_task and training_params.get("optim_dict", None) is not None: diff --git a/deepmd/pt/utils/learning_rate.py b/deepmd/pt/utils/learning_rate.py index 3502434bc0..ff7d4f7ec7 100644 --- a/deepmd/pt/utils/learning_rate.py +++ b/deepmd/pt/utils/learning_rate.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.dpmodel.utils.learning_rate import ( + BaseLR, + LearningRateCosine, LearningRateExp, ) __all__ = [ + "BaseLR", + "LearningRateCosine", "LearningRateExp", ] diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 7fcc117ab5..1809b19083 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2477,6 +2477,10 @@ def linear_ener_model_args() -> Argument: # --- Learning rate configurations: --- # +lr_args_plugin = ArgsPlugin() + + +@lr_args_plugin.register("exp") def learning_rate_exp() -> list[Argument]: doc_start_lr = "The learning rate at the start of the training." doc_stop_lr = ( @@ -2509,12 +2513,30 @@ def learning_rate_exp() -> list[Argument]: return args +@lr_args_plugin.register("cosine", doc=doc_only_pt_supported) +def learning_rate_cosine() -> list[Argument]: + """ + Defines a cosine annealing learning rate schedule. + + The learning rate starts at `start_lr` and gradually decreases to `stop_lr` + following a cosine curve over the training steps. + """ + doc_start_lr = "The learning rate at the start of the training." + doc_stop_lr = "The desired learning rate at the end of the training. " + + args = [ + Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr), + Argument("stop_lr", float, optional=True, default=1e-5, doc=doc_stop_lr), + ] + return args + + def learning_rate_variant_type_args() -> Variant: doc_lr = "The type of the learning rate." return Variant( "type", - [Argument("exp", dict, learning_rate_exp())], + lr_args_plugin.get_all_argument(), optional=True, default_tag="exp", doc=doc_lr, diff --git a/source/tests/pt/test_lr.py b/source/tests/pt/test_lr.py index 2d6bf156e1..75f663f041 100644 --- a/source/tests/pt/test_lr.py +++ b/source/tests/pt/test_lr.py @@ -7,6 +7,7 @@ tf.disable_eager_execution() from deepmd.pt.utils.learning_rate import ( + LearningRateCosine, LearningRateExp, ) from deepmd.tf.utils import ( @@ -102,5 +103,21 @@ def decay_rate_pt(self) -> None: ) +class TestLearningRateCosine(unittest.TestCase): + def test_basic_curve(self) -> None: + start_lr = 1.0 + stop_lr = 0.1 + stop_steps = 10 + lr = LearningRateCosine(start_lr, stop_lr, stop_steps) + + self.assertTrue(np.allclose(lr.value(0), start_lr)) + self.assertTrue(np.allclose(lr.value(stop_steps), stop_lr)) + self.assertTrue(np.allclose(lr.value(stop_steps + 5), stop_lr)) + + mid_step = stop_steps // 2 + expected_mid = stop_lr + (start_lr - stop_lr) * 0.5 + self.assertTrue(np.allclose(lr.value(mid_step), expected_mid)) + + if __name__ == "__main__": unittest.main()