cp: fix: more robust fp8 rollout metric check (1307) into r0.4.0 #1386
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
📝 Walkthrough

This PR enhances the test metrics framework with new evaluation functions and upgrades the FP8 rollout test suite from v2 to v3.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as Test Suite (v3)
    participant Eval as evaluate_check()
    participant Metrics as Metric Functions
    participant Data as Metric Data
    Test->>Eval: check expression with mean/ratio_above
    Eval->>Metrics: parse & execute mean(..., ignore_top_p=0.05)
    Metrics->>Data: fetch train/token_mult_prob_error values
    Data-->>Metrics: return dict of values
    Metrics->>Metrics: filter top outliers (ignore_top_p)
    Metrics-->>Eval: return filtered mean
    Eval->>Metrics: execute ratio_above(data, 1.1)
    Metrics->>Data: fetch train/token_mult_prob_error values
    Data-->>Metrics: return dict of values
    Metrics->>Metrics: count values >= threshold
    Metrics-->>Eval: return proportion
    Eval->>Eval: compare results to gate condition
    alt Gate passes
        Eval-->>Test: ✓ metrics pass
    else Gate fails
        Eval-->>Test: ✗ metrics fail
    end
```
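To make the flow concrete, here is a minimal, hypothetical Python sketch of the gating logic the diagram describes. The helper names (`mean`, `ratio_above`, `ignore_top_p`) come from this PR, but the bodies below are illustrative stand-ins, not the actual `tests/check_metrics.py` implementation:

```python
import statistics


def mean(values: dict[str, float], ignore_top_p: float = 0.0) -> float:
    """Mean of per-step metric values, optionally dropping the top fraction of outliers."""
    vals = sorted(float(v) for v in values.values())
    if ignore_top_p > 0.0:
        keep = len(vals) - int(len(vals) * ignore_top_p)
        vals = vals[:keep]  # drop the largest ignore_top_p fraction
    return statistics.mean(vals)


def ratio_above(values: dict[str, float], threshold: float) -> float:
    """Fraction of per-step values at or above the threshold (0.0 for empty input)."""
    vals = [float(v) for v in values.values()]
    return sum(v >= threshold for v in vals) / len(vals) if vals else 0.0


# Gate condition in the spirit of the v3 test suite:
data = {"train/token_mult_prob_error": {"1": 1.02, "2": 1.04, "3": 1.05}}
errs = data["train/token_mult_prob_error"]
passed = mean(errs, ignore_top_p=0.05) < 1.1 and ratio_above(errs, 1.1) < 0.1
print("✓ metrics pass" if passed else "✗ metrics fail")
```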
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

The changes span heterogeneous areas: new public API functions with parameter additions (…).
Actionable comments posted: 3
🧹 Nitpick comments (6)
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh (1)
7-10: Config bumps LGTM; SC2034 warnings are expected. NUM_RUNS and NUM_MINUTES are consumed by external tooling/common.env; safe to keep as-is. If you want to silence shellcheck, export them or mark them readonly:
```diff
 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
-NUM_MINUTES=180
+NUM_MINUTES=180
+# shellcheck disable=SC2034  # Used by launcher/common.env
+readonly NUM_RUNS NUM_MINUTES
```

Based on learnings.
tests/check_metrics.py (4)
78-82: Ensure range filtering is order-robust by iterating steps numerically. Dict insertion order may not match step order; iterate over sorted step keys.

```diff
-    vals = []
-    for step, v in value.items():
-        if range_start <= int(step) and int(step) < range_end:
-            vals.append(float(v))
+    vals = []
+    for step_str in sorted(value.keys(), key=int):
+        step = int(step_str)
+        if range_start <= step < range_end:
+            vals.append(float(value[step_str]))
```
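The subtlety here: metric steps arrive as JSON string keys, and lexicographic order diverges from numeric order once steps reach double digits, so `key=int` matters:

```python
steps = {"9": 0.1, "10": 0.2, "2": 0.3}
print(sorted(steps))           # ['10', '2', '9'] — lexicographic, wrong order
print(sorted(steps, key=int))  # ['2', '9', '10'] — numeric, intended order
```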
83-96: ignore_top_p logic is solid; add an empty-range guard to improve error messaging. statistics.mean([]) raises StatisticsError; raise a clear ValueError instead when no values remain after filtering/ranging.

```diff
-    # Filter out top outliers if requested
+    # Filter out top outliers if requested
     if ignore_top_p > 0.0 and len(vals) > 0:
@@
-    return statistics.mean(vals)
+    if not vals:
+        raise ValueError("No values in selected range after filtering")
+    return statistics.mean(vals)
```
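Without the guard, an empty selection surfaces as an opaque stdlib error rather than one that points at the range/filter arguments:

```python
import statistics

try:
    statistics.mean([])  # what an exhausted range/filter currently triggers
except statistics.StatisticsError as e:
    print(e)  # "mean requires at least one data point"
```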
107-113: Limit eval's builtins to reduce attack surface. These checks run in CI with trusted strings, but it's safer to remove full builtins exposure.

```diff
-    local_context = {
+    local_context = {
         "data": data,
         "min": min,
         "max": max,
         "mean": mean,
         "ratio_above": ratio_above,
     }
```

And in the eval calls:

```diff
-    value = eval(value_expr, {"__builtins__": builtins}, local_context)
+    value = eval(value_expr, {"__builtins__": {}}, local_context)
@@
-    result = eval(check, {"__builtins__": builtins}, local_context)
+    result = eval(check, {"__builtins__": {}}, local_context)
```

This keeps only the whitelisted helpers available. If you need specific safe builtins (e.g., True/False/None), pass them explicitly.
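As a self-contained sketch of the whitelisting idea (the stub `mean` and the variable names here are illustrative, not the PR's code):

```python
data = {"train/loss": {"1": 1.2, "2": 1.1}}


def mean(values):  # stand-in for the real helper in tests/check_metrics.py
    vals = [float(v) for v in values.values()]
    return sum(vals) / len(vals)


safe_globals = {"__builtins__": {}}
local_context = {"data": data, "mean": mean}

# Whitelisted names resolve normally:
print(eval('mean(data["train/loss"]) < 1.5', safe_globals, local_context))  # True

# Everything else, including __import__, is unreachable:
try:
    eval('__import__("os").system("echo pwned")', safe_globals, local_context)
except NameError as e:
    print(e)  # name '__import__' is not defined
```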
155-157: Update usage examples to mirror project tooling. Our shell drivers use uv; reflect that in the examples.

```diff
-    python check_metrics.py results.json "mean(data['loss'], ignore_top_p=0.05) < 1.5"
-    python check_metrics.py results.json "ratio_above(data['error'], 1.05) < 0.02"
+    uv run tests/check_metrics.py results.json "mean(data['loss'], ignore_top_p=0.05) < 1.5"
+    uv run tests/check_metrics.py results.json "ratio_above(data['error'], 1.05) < 0.02"
```

tests/unit/test_check_metrics.py (1)
30-97: Prefer pytest.approx for floats to avoid flakiness. Replace exact float equality with approx where applicable.

```diff
-    assert result == 3.0
+    assert result == pytest.approx(3.0)
@@
-    assert result_no_filter == 22.0  # (1+2+3+4+100)/5
+    assert result_no_filter == pytest.approx(22.0)  # (1+2+3+4+100)/5
@@
-    assert result_with_filter == 2.5  # (1+2+3+4)/4
+    assert result_with_filter == pytest.approx(2.5)  # (1+2+3+4)/4
```

(Apply similarly across other float assertions in this file.)
Also applies to: 112-133, 166-219, 224-307, 313-407
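The motivation is ordinary float rounding; exact equality can fail even when the arithmetic looks trivially right:

```python
import pytest

assert 0.1 + 0.2 != 0.3                  # exact comparison fails: 0.30000000000000004
assert 0.1 + 0.2 == pytest.approx(0.3)   # approx passes with the default relative tolerance
```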
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml (0 hunks)
- tests/check_metrics.py (5 hunks)
- tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh (2 hunks)
- tests/test_suites/nightly.txt (1 hunks)
- tests/unit/test_check_metrics.py (1 hunks)
💤 Files with no reviewable changes (1)
- examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml
🧰 Additional context used
📓 Path-based instructions (5)
tests/test_suites/nightly.txt
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Append the new driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt
Files:
tests/test_suites/nightly.txt
tests/test_suites/**
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Place driver shell scripts and common.env under tests/test_suites// and list nightly tests in tests/test_suites/nightly.txt
Files:
tests/test_suites/nightly.txt
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
tests/check_metrics.py
tests/unit/test_check_metrics.py
**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.sh: Follow the Google Shell Style Guide for all shell scripts
Use `uv run` to execute Python scripts in shell/driver scripts instead of activating virtualenvs and calling `python` directly
Add the NVIDIA copyright header (with current year) at the top of all shell scripts, excluding tests/ and test-only scripts
Files:
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
tests/test_suites/llm/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
LLM driver script filenames must mirror the YAML base name and follow the same pattern with .sh extension
Files:
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
🧠 Learnings (2)
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
PR: NVIDIA-NeMo/RL#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to tests/test_suites/nightly.txt : Append the new driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt
Applied to files:
tests/test_suites/nightly.txt
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
PR: NVIDIA-NeMo/RL#1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.
Applied to files:
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
🧬 Code graph analysis (1)
tests/unit/test_check_metrics.py (1)
tests/check_metrics.py (5)
evaluate_check (100-133), max (30-32), mean (52-97), min (25-27), ratio_above (35-49)
🪛 Ruff (0.14.0)
tests/check_metrics.py
85-87: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Use of possibly insecure function; consider using ast.literal_eval
(S307)
123-123: Use of possibly insecure function; consider using ast.literal_eval
(S307)
tests/unit/test_check_metrics.py
103-103: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
108-108: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🪛 Shellcheck (0.11.0)
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Lint check
- GitHub Check: sphinx-build / Build docs
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
tests/check_metrics.py (1)
35-50: New ratio_above helper — concise and correct. Counts values >= threshold; handles the empty dict by returning 0.0. Nice.
```diff
 if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
     # With a few number of steps the logprob can have spikes that can move the average up.
     uv run tests/check_metrics.py $JSON_METRICS \
-        'mean(data["train/token_mult_prob_error"]) < 1.1' \
-        'data["train/token_mult_prob_error"]["40"] < 1.1'
+        'mean(data["train/token_mult_prob_error"], ignore_top_p=0.05) < 1.1' \
+        'ratio_above(data["train/token_mult_prob_error"], 1.1) < 0.1'
+    # ratio_above @ 1.1 was 0.03,0.06,0.05: 3sigma ~=0.1
 fi
```
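As a quick sanity check of the inline "3sigma ~= 0.1" note, one can recompute the threshold from the three historical readings (assuming a sample standard deviation; this is back-of-envelope, not from the PR):

```python
import statistics

observed = [0.03, 0.06, 0.05]       # past ratio_above @ 1.1 readings
mu = statistics.mean(observed)      # ≈ 0.0467
sigma = statistics.stdev(observed)  # ≈ 0.0153
print(round(mu + 3 * sigma, 3))     # ≈ 0.092, so a 0.1 gate sits just above 3 sigma
```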
Harden the step-reached check and quote paths.
Guard against missing keys (jq returning empty) and avoid word-splitting. Suggested patch:

```diff
-# Only run metrics if the target step is reached
-if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
-  # With a few number of steps the logprob can have spikes that can move the average up.
-  uv run tests/check_metrics.py $JSON_METRICS \
-    'mean(data["train/token_mult_prob_error"], ignore_top_p=0.05) < 1.1' \
-    'ratio_above(data["train/token_mult_prob_error"], 1.1) < 0.1'
-  # ratio_above @ 1.1 was 0.03,0.06,0.05: 3sigma ~=0.1
-fi
+# Only run metrics if the target step is reached
+steps_reached="$(jq -r '((.["train/loss"] | keys) // []) | map(tonumber) | (max? // 0)' "$JSON_METRICS")"
+if (( steps_reached >= MAX_STEPS )); then
+  # With few steps the logprob can spike and skew the average.
+  uv run tests/check_metrics.py "$JSON_METRICS" \
+    'mean(data["train/token_mult_prob_error"], ignore_top_p=0.05) < 1.1' \
+    'ratio_above(data["train/token_mult_prob_error"], 1.1) < 0.1'
+  # ratio_above @ 1.1 was 0.03,0.06,0.05: 3sigma ~= 0.1
+fi
```

Additionally, consider enabling strict mode near the top for Google Shell style:

```bash
set -Eeuo pipefail
```

🤖 Prompt for AI Agents
In tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh around lines 35-41, the step-reached check can fail when jq returns empty, and the command substitutions are unquoted. Update the check to safely handle missing keys and avoid word-splitting: capture the jq output into a quoted variable with a safe default (e.g., treat empty as 0) or use jq's `//` alternative operator to return a number, then compare that numeric value with MAX_STEPS. Also quote $JSON_METRICS and any other path variables passed to jq or uv, and enable strict shell mode near the top of the script (set -Eeuo pipefail) to catch errors early.
```diff
 # FP8
-tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v2.sh
+tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
```
🧩 Analysis chain
Nightly entry updated to v3 — looks good; please confirm no residual v2 entry remains.
The path matches the new driver script name. Double‑check that any old v2 entry was removed to avoid duplicate runs.
🏁 Script executed:

```bash
#!/bin/bash
# Verify only v3 is listed and the script exists
set -euo pipefail
rg -n 'grpo-llama3\.1-8b-instruct-1n8g-megatron-fp8-rollouts\.(v2|v3)\.sh' tests/test_suites/nightly.txt
fd -t f 'grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh' tests/test_suites/llm -x ls -l {}
```

Length of output: 213
Need to verify path format consistency against the coding guidelines. Let me check the structure of other entries in nightly.txt.
🏁 Script executed:

```bash
#!/bin/bash
# Check the format of entries in nightly.txt
head -50 tests/test_suites/nightly.txt
```

Length of output: 1811
v3 entry added and v2 duplication removed — verified. However, the path format must be corrected per the coding guidelines.

The v3 entry is present at line 41, the v2 entry has been removed (no duplicates), and the script file exists. However, the path violates the coding guidelines: it uses the full path `tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh` instead of the relative format the guidelines require (`llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh`). Update the entry at line 41 to use the relative path format.
🤖 Prompt for AI Agents
In tests/test_suites/nightly.txt around line 41, the v3 entry uses an
absolute-ish path that violates guidelines; change the line from
"tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh"
to the required relative format
"llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh" so the file
reference follows the repository relative path convention.
| """Test that invalid ignore_top_p values raise an error.""" | ||
| data = {"1": 1.0, "2": 2.0, "3": 3.0} | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match="ignore_top_p must be between 0.0 and 1.0" | ||
| ): | ||
| mean(data, ignore_top_p=1.5) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match="ignore_top_p must be between 0.0 and 1.0" | ||
| ): | ||
| mean(data, ignore_top_p=-0.1) | ||
|
|
Fix regex patterns in pytest.raises match=.
Use raw strings or re.escape to avoid unintended regex semantics.
```diff
-    with pytest.raises(
-        ValueError, match="ignore_top_p must be between 0.0 and 1.0"
-    ):
+    with pytest.raises(
+        ValueError, match=r"ignore_top_p must be between 0\.0 and 1\.0"
+    ):
         mean(data, ignore_top_p=1.5)
@@
-    with pytest.raises(
-        ValueError, match="ignore_top_p must be between 0.0 and 1.0"
-    ):
+    with pytest.raises(
+        ValueError, match=r"ignore_top_p must be between 0\.0 and 1\.0"
+    ):
         mean(data, ignore_top_p=-0.1)
```
mean(data, ignore_top_p=-0.1)🧰 Tools
🪛 Ruff (0.14.0)
103-103: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
108-108: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🤖 Prompt for AI Agents
In tests/unit/test_check_metrics.py around lines 99 to 111, the pytest.raises
match strings are interpreted as regexes and may be mis-parsed; update the match
arguments to use raw string literals (prefix with r) or wrap the message in
re.escape to ensure the message is treated literally (e.g., r"ignore_top_p must
be between 0.0 and 1.0" or re.escape("ignore_top_p must be between 0.0 and
1.0")) so the assertion matches the exact error text without regex side-effects.
beep boop [🤖]: Hi @terrykong 👋,
Summary by CodeRabbit
New Features
Tests