
feat: add WSD scheduler#5326

Draft
OutisLi wants to merge 1 commit into deepmodeling:master from OutisLi:pr/wsd

Conversation


@OutisLi OutisLi commented Mar 18, 2026

The doc will be added through PR #5276

Summary by CodeRabbit

  • New Features

    • Introduced "wsd" (warmup-stable-decay) learning rate schedule with configurable warmup, stable, and decay phases. Supports three decay modes: inverse_linear, cosine, and linear.
  • Tests

    • Added comprehensive test coverage validating the new WSD schedule across different decay types and edge cases.
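The three-phase schedule summarized above can be sketched as a standalone function. This is a minimal illustration only: the function name, signature defaults, and the exact interpolation formulas (especially for inverse_linear) are assumptions for exposition, not the PR's actual implementation in deepmd/dpmodel/utils/learning_rate.py.

```python
import math

def wsd_lr(step, start_lr=1e-3, stop_lr=1e-5, num_steps=10000,
           warmup_steps=0, decay_phase_ratio=0.1, decay_type="inverse_linear"):
    """Hypothetical warmup-stable-decay (WSD) schedule sketch."""
    decay_steps = int(decay_phase_ratio * num_steps)
    decay_start = num_steps - decay_steps
    if step < warmup_steps:
        # warmup phase: ramp linearly up to start_lr
        return start_lr * (step + 1) / warmup_steps
    if step < decay_start:
        # stable phase: hold the peak learning rate
        return start_lr
    # decay phase: interpolate from start_lr down to stop_lr
    frac = min((step - decay_start) / decay_steps, 1.0)
    if decay_type == "linear":
        return start_lr + (stop_lr - start_lr) * frac
    if decay_type == "cosine":
        return stop_lr + 0.5 * (start_lr - stop_lr) * (1.0 + math.cos(math.pi * frac))
    if decay_type == "inverse_linear":
        # 1/x-shaped interpolation between the two endpoints
        return 1.0 / ((1.0 - frac) / start_lr + frac / stop_lr)
    raise ValueError(f"unsupported decay_type: {decay_type!r}")
```

With the defaults above, the rate holds at 1e-3 for the first 9000 steps and falls to 1e-5 over the final 1000; all three modes agree at the phase endpoints and differ only in the shape of the descent.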

Copilot AI review requested due to automatic review settings March 18, 2026 04:51
@dosubot dosubot bot added the new feature label Mar 18, 2026
stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)
decay_num_steps = xp.asarray(self.decay_num_steps, dtype=step_dtype)

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable decay_num_steps is not used.
@OutisLi OutisLi marked this pull request as draft March 18, 2026 04:54

Copilot AI left a comment


Pull request overview

Adds a new warmup-stable-decay (WSD) learning-rate scheduler to the backend-agnostic BaseLR registry, wires it into argument checking and backend exports, and expands cross-backend test coverage to validate behavior and consistency.

Changes:

  • Implement LearningRateWSD (type="wsd") with warmup support plus configurable decay modes (inverse_linear, cosine, linear).
  • Extend CLI/config arg validation to accept and validate WSD-specific parameters.
  • Add/extend tests across universal, TF, PT, PD, and consistency suites for WSD behavior.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Summary per file:
  • deepmd/dpmodel/utils/learning_rate.py: Adds the LearningRateWSD scheduler implementation and registers it under BaseLR.
  • deepmd/utils/argcheck.py: Adds WSD-specific validation and exposes WSD arguments via the lr args plugin registry.
  • deepmd/pt/utils/learning_rate.py: Re-exports LearningRateWSD for the PyTorch backend utilities.
  • deepmd/pd/utils/learning_rate.py: Re-exports LearningRateWSD for the Paddle backend utilities.
  • source/tests/universal/dpmodel/utils/test_learning_rate.py: Adds extensive unit tests for WSD (modes, warmup, array input, boundary conditions).
  • source/tests/tf/test_lr.py: Adds TF wrapper build/value tests for WSD (default + cosine).
  • source/tests/pt/test_lr.py: Adds PT-side WSD curve tests for all decay modes.
  • source/tests/pd/test_lr.py: Adds PD-side WSD curve tests for all decay modes.
  • source/tests/consistent/test_learning_rate.py: Extends consistency parameterization to include WSD and adjusts the sampling step to hit the decay phase.


Comment on lines +556 to +560
stop_lr = xp.asarray(self.stop_lr, dtype=step_dtype)
stable_steps = xp.asarray(self.stable_steps, dtype=step_dtype)
decay_phase_steps = xp.asarray(self.decay_phase_steps, dtype=step_dtype)
decay_num_steps = xp.asarray(self.decay_num_steps, dtype=step_dtype)


coderabbitai bot commented Mar 18, 2026

📝 Walkthrough

Walkthrough

Introduces a new learning rate schedule class LearningRateWSD implementing warmup-stable-decay phases with three decay modes (inverse_linear, cosine, linear). The class is implemented in the core module, re-exported through multiple submodules (pd, pt), and supported by validation logic and comprehensive test coverage across all framework backends.

Changes

Cohort / File(s) / Summary:
  • Core Implementation (deepmd/dpmodel/utils/learning_rate.py): Added the LearningRateWSD class with warmup, stable, and decay phases. Supports three decay modes via a _decay_value() method. Validates inputs (positive start_lr/stop_lr, decay_phase_ratio in (0, 1], valid decay_type).
  • Module Re-exports (deepmd/pd/utils/learning_rate.py, deepmd/pt/utils/learning_rate.py): Imported LearningRateWSD from the core module and added it to __all__ for public API exposure in each submodule.
  • Validation Logic (deepmd/utils/argcheck.py): Added a _check_wsd_args() validation function and learning_rate_wsd() plugin registration to enforce WSD-specific constraints (start_lr, stop_lr, decay_phase_ratio, decay_type).
  • Core Framework Tests (source/tests/universal/dpmodel/utils/test_learning_rate.py): Added comprehensive test coverage for LearningRateWSD, including plateau behavior, mid-period values, stop_lr computation, edge cases with and without warmup, and decay mode variants (inverse_linear, cosine, linear).
  • PD Tests (source/tests/pd/test_lr.py): Added a TestLearningRateWSD class with three tests validating different decay_type behaviors across boundary and midpoint steps.
  • PT Tests (source/tests/pt/test_lr.py): Added a TestLearningRateWSD class with three tests verifying decay mode calculations and boundary behavior at stop_steps.
  • TensorFlow Tests (source/tests/tf/test_lr.py): Added WSD scheduler integration tests verifying correct instantiation and value computation with the optional cosine decay type.
  • Consistent Tests (source/tests/consistent/test_learning_rate.py): Added three new WSD test configurations with decay_phase_ratio 0.1 and decay_type variants; introduced dynamic step calculation using optional scheduler attributes (stable_steps, decay_phase_steps).
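The constraints listed for the Validation Logic cohort can be illustrated with a small stand-in checker. The name check_wsd_args and the error messages here are hypothetical; they mirror the rules described above, not the actual _check_wsd_args in deepmd/utils/argcheck.py.

```python
def check_wsd_args(start_lr, stop_lr, decay_phase_ratio, decay_type):
    """Illustrative stand-in for the WSD argument constraints described above."""
    if start_lr <= 0 or stop_lr <= 0:
        raise ValueError("start_lr and stop_lr must be positive")
    if not 0.0 < decay_phase_ratio <= 1.0:
        raise ValueError("decay_phase_ratio must be in (0, 1]")
    if decay_type not in ("inverse_linear", "cosine", "linear"):
        raise ValueError(f"unsupported decay_type: {decay_type!r}")
```

A valid configuration passes silently; a zero decay_phase_ratio or an unknown decay_type raises ValueError, matching the validation behavior exercised by the tests.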

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

new feature, Python, enhancement

Suggested reviewers

  • iProzd
  • njzjz
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 71.79%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title 'feat: add WSD scheduler' clearly and specifically identifies the main change, introducing a new warmup-stable-decay (WSD) learning rate scheduler throughout the codebase.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
source/tests/tf/test_lr.py (1)

106-140: These assertions don't exercise the TF graph path.

lr_schedule.value() just delegates to base_lr.value(), so both new tests still pass if LearningRateSchedule.build() is wrong for type="wsd". Please evaluate the tensor returned by build() at a mid-decay step and compare that result instead.

🧪 Suggested coverage improvement
-        global_step = tf.constant(0, dtype=tf.int64)
-        lr_schedule.build(global_step, num_steps=10000)
+        g = tf.Graph()
+        with g.as_default():
+            global_step = tf.placeholder(shape=[], dtype=tf.int64)
+            lr_tensor = lr_schedule.build(global_step, num_steps=10000)
 
         self.assertIsInstance(lr_schedule.base_lr, LearningRateWSD)
-        np.testing.assert_allclose(
-            lr_schedule.value(9500), lr_schedule.base_lr.value(9500), rtol=1e-10
-        )
+        with tf.Session(graph=g) as sess:
+            tensor_value = sess.run(lr_tensor, feed_dict={global_step: 9500})
+        np.testing.assert_allclose(
+            tensor_value, lr_schedule.base_lr.value(9500), rtol=1e-10
+        )

Apply the same pattern to the cosine WSD variant as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/tf/test_lr.py` around lines 106 - 140, The tests are only
exercising the Python path because lr_schedule.value(9500) passes a plain int;
update both tests to pass a TF tensor so the built TF graph is evaluated: after
calling LearningRateSchedule.build(global_step, num_steps=10000), call
lr_schedule.value(tf.constant(9500, dtype=tf.int64)) and compare it to
lr_schedule.base_lr.value(tf.constant(9500, dtype=tf.int64)) (and do the same
change in test_wsd_cosine_build_and_value) to ensure LearningRateSchedule.build,
base_lr and LearningRateWSD TF-paths are exercised.
source/tests/universal/dpmodel/utils/test_learning_rate.py (1)

142-165: Consider splitting mixed validation assertions into separate tests.

test_invalid_decay_phase_ratio also checks invalid decay_type; splitting improves failure localization.

♻️ Suggested test split
-    def test_invalid_decay_phase_ratio(self) -> None:
-        """Test invalid WSD decay_phase_ratio values."""
+    def test_invalid_decay_phase_ratio(self) -> None:
+        """Test invalid WSD decay_phase_ratio values."""
         with self.assertRaises(ValueError):
             LearningRateWSD(
                 start_lr=1e-3,
                 stop_lr=1e-5,
                 num_steps=10000,
                 decay_phase_ratio=0.0,
             )
         with self.assertRaises(ValueError):
             LearningRateWSD(
                 start_lr=1e-3,
                 stop_lr=1e-5,
                 num_steps=10000,
                 decay_phase_ratio=1.1,
             )
+
+    def test_invalid_decay_type(self) -> None:
+        """Test invalid WSD decay_type values."""
         with self.assertRaises(ValueError):
             LearningRateWSD(
                 start_lr=1e-3,
                 stop_lr=1e-5,
                 num_steps=10000,
                 decay_type="bad_mode",
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/universal/dpmodel/utils/test_learning_rate.py` around lines 142
- 165, test_invalid_decay_phase_ratio currently asserts multiple unrelated
invalid cases (invalid decay_phase_ratio and invalid decay_type) in one test;
split it into separate tests so each assertion checks a single validation rule.
Create one test (e.g., test_invalid_decay_phase_ratio_values) that calls
LearningRateWSD with decay_phase_ratio=0.0 and decay_phase_ratio=1.1 and asserts
ValueError, and another test (e.g., test_invalid_decay_type) that calls
LearningRateWSD with decay_type="bad_mode" and asserts ValueError; keep
function/class names LearningRateWSD, decay_phase_ratio, and decay_type to
locate the code. Ensure each new test has a clear name and only one
responsibility for better failure localization.


📥 Commits

Reviewing files that changed from the base of the PR and between b2805fb and f271bf8.

📒 Files selected for processing (9)
  • deepmd/dpmodel/utils/learning_rate.py
  • deepmd/pd/utils/learning_rate.py
  • deepmd/pt/utils/learning_rate.py
  • deepmd/utils/argcheck.py
  • source/tests/consistent/test_learning_rate.py
  • source/tests/pd/test_lr.py
  • source/tests/pt/test_lr.py
  • source/tests/tf/test_lr.py
  • source/tests/universal/dpmodel/utils/test_learning_rate.py

Comment on lines +514 to +525
self.decay_phase_steps = int(self.decay_phase_ratio * self.num_steps)
if self.decay_phase_steps <= 0:
raise ValueError(
"decay_phase_ratio results in zero decay steps. "
"Increase num_steps or decay_phase_ratio."
)
if self.decay_phase_steps > self.decay_num_steps:
raise ValueError(
"decay phase steps must not exceed the post-warmup steps. "
f"Got decay_phase_steps={self.decay_phase_steps}, "
f"post_warmup_steps={self.decay_num_steps}."
)

⚠️ Potential issue | 🟠 Major

Don't floor the decay phase into an impossible span.

Line 514 uses int(decay_phase_ratio * self.num_steps), so valid-looking configs can fail only at runtime: num_steps=5 with the default ratio produces 0 decay steps, and num_steps=10, warmup_steps=9, decay_phase_ratio=0.2 produces 2 > decay_num_steps=1. Please clamp the derived span, or derive it from self.decay_num_steps, so short runs and warmup-heavy runs still work.

💡 Possible fix
-        self.decay_phase_steps = int(self.decay_phase_ratio * self.num_steps)
-        if self.decay_phase_steps <= 0:
-            raise ValueError(
-                "decay_phase_ratio results in zero decay steps. "
-                "Increase num_steps or decay_phase_ratio."
-            )
-        if self.decay_phase_steps > self.decay_num_steps:
-            raise ValueError(
-                "decay phase steps must not exceed the post-warmup steps. "
-                f"Got decay_phase_steps={self.decay_phase_steps}, "
-                f"post_warmup_steps={self.decay_num_steps}."
-            )
+        if self.decay_num_steps <= 0:
+            raise ValueError("WSD requires at least one post-warmup step.")
+        self.decay_phase_steps = min(
+            max(int(self.decay_phase_ratio * self.num_steps), 1),
+            self.decay_num_steps,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/utils/learning_rate.py` around lines 514 - 525, The current
int(self.decay_phase_ratio * self.num_steps) can produce 0 or exceed post-warmup
steps; change the computation of decay_phase_steps to derive from and clamp
against decay_num_steps instead of raising: compute e.g. desired = max(0,
int(round(self.decay_phase_ratio * self.num_steps))) (or 0 if you prefer), then
set self.decay_phase_steps = min(desired, self.decay_num_steps) and if
self.decay_phase_steps == 0 and self.decay_num_steps > 0 set it to 1 (or
otherwise handle the zero-decay case), replacing the two ValueError checks;
update code around decay_phase_steps, decay_phase_ratio, num_steps and
decay_num_steps to use this clamped value so short runs or heavy-warmup configs
don’t raise at runtime.


codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 71.01449% with 20 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.29%. Comparing base (b2805fb) to head (f271bf8).
⚠️ Report is 1 commit behind head on master.

Files with missing lines:
  • deepmd/utils/argcheck.py: 38.46% patch coverage, 16 lines missing ⚠️
  • deepmd/dpmodel/utils/learning_rate.py: 90.69% patch coverage, 4 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5326      +/-   ##
==========================================
- Coverage   82.29%   82.29%   -0.01%     
==========================================
  Files         775      775              
  Lines       77627    77696      +69     
  Branches     3675     3675              
==========================================
+ Hits        63887    63939      +52     
- Misses      12566    12584      +18     
+ Partials     1174     1173       -1     

☔ View full report in Codecov by Sentry.
