feat: FP8 Training in Megatron Path #971
Conversation
I need to rebase after #905 is merged; currently the latest code is in https://github.com/guyueh1/NeMo-RL/tree/fp8_training_mbridge
📝 Walkthrough

Adds FP8 generation and FP8 training docs, example configs and a new GRPO FP8 training recipe, FP8 wiring and padding guards in MegatronPolicyWorker and LM policy, unit and nightly tests for FP8 training, and updates the Megatron-Bridge submodule to an FP8 branch.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Runner as GRPO Runner
    participant Worker as MegatronPolicyWorker
    participant SeqPack as SequencePacking
    participant Model as Megatron Model
    Note over Runner,Worker: Training (packable sequences)
    Runner->>Worker: train(batch)
    Worker->>SeqPack: compute pad_factor(cp_size, tp_size)
    alt fp8_cfg.enabled == true
        Note right of SeqPack #D6EAF8: enforce min pad = 16
        SeqPack-->>Worker: pad_factor = max(default, 16)
    else
        SeqPack-->>Worker: pad_factor = default
    end
    Worker->>Model: set model fp8 flags (if enabled) and forward/backward
    Model-->>Worker: losses, grads
    Worker-->>Runner: step metrics
```
```mermaid
sequenceDiagram
    autonumber
    participant Client as Caller
    participant Worker as MegatronPolicyWorker
    participant SeqPack as SequencePacking
    participant Model as Megatron Model
    Note over Client,Worker: Logprobs path (packable requests)
    Client->>Worker: get_logprobs(requests)
    Worker->>SeqPack: compute pad_factor
    opt fp8_cfg.enabled == true
        Note right of SeqPack #D6EAF8: pad_factor = max(default, 16)
    end
    Worker->>Model: forward for logprobs (model fp8 flags applied if enabled)
    Model-->>Worker: logits/logprobs
    Worker-->>Client: logprobs
```
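The padding rule shared by both diagrams can be sketched in plain Python. This is an illustrative toy function, not the actual NeMo-RL API; the names `compute_pad_factor`, `cp_size`, and `tp_size` mirror the diagram labels.

```python
def compute_pad_factor(cp_size: int, tp_size: int, fp8_enabled: bool) -> int:
    """Sequence-packing pad factor as described in the diagrams above.

    With context parallelism (cp_size > 1) each packed sequence must split
    evenly across 2 * cp_size chunks per tensor-parallel rank; FP8 GEMMs
    additionally want dimensions that are multiples of 16.
    """
    pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size
    if fp8_enabled:
        # enforce min pad = 16 when FP8 is enabled
        pad_factor = max(pad_factor, 16)
    return pad_factor


print(compute_pad_factor(1, 2, False))  # 2
print(compute_pad_factor(1, 2, True))   # 16
print(compute_pad_factor(2, 4, True))   # 16
```

Note that `max(..., 16)` is what this PR ships; a later review comment in this thread argues for an LCM-based rule instead.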
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🧹 Nitpick comments (8)
docs/fp8.md (3)
3-6: Fix documentation typos and improve clarity.

The documentation contains several typos and unclear language that should be addressed:
- Line 3: "developement" should be "development"
- Line 5: "using Deepseek style FP8" should be more descriptive - consider "using DeepSeek-style FP8 with sub-channel scaling"
- Line 6: The sentence structure could be improved for better readability
Apply this diff to fix the issues:
```diff
-This module provides a suite of tools to enable FP8 quantization for large language models. This module is still in developement. Currently we support
+This module provides a suite of tools to enable FP8 quantization for large language models. This module is still in development. Currently we support
-* FP8 generation, using Deepseek style FP8 (sub channel scaling)
-* FP8 training, using TransformerEngine as linear layer implementation, supporting Deepseek style FP8 (sub channel scaling) and per-tensor scaling
+* FP8 generation, using DeepSeek-style FP8 (sub-channel scaling)
+* FP8 training, using TransformerEngine as the linear layer implementation, supporting DeepSeek-style FP8 (sub-channel scaling) and per-tensor scaling
```
40-49: Specify language for code blocks and fix typos.

The markdown linter correctly identifies missing language specifications for code blocks, and there's a typo in the configuration comments.
Apply this diff to fix the issues:
````diff
-FP8 training requires megatron path, and is recommented to be configured with the following settings:
+FP8 training requires the Megatron path, and is recommended to be configured with the following settings:
-```
+```yaml
 policy:
   megatron_cfg:
     fp8_cfg:
       fp8: "hybrid" # choices: [hybrid, e4m3]
-      fp8_recipe: "tensorwise" # choicse: [tensorwise, blockwise]
+      fp8_recipe: "tensorwise" # choices: [tensorwise, blockwise]
       fp8_param: false # boolean value
-```
+```
````
54-64: Add language specification for error traceback.

The error traceback should be formatted as a code block with proper language specification for better readability.
````diff
-```
+```python
 File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 646, in fp8_autocast
     FP8GlobalStateManager.fp8_autocast_enter(
 File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 465, in fp8_autocast_enter
     assert fp8_block_available, reason_for_no_fp8_block
            ^^^^^^^^^^^^^^^^^^^
 AssertionError: FP8 block scaled GEMM requires Hopper and CUDA >= 12.9.
-```
+```
````

nemo_rl/models/policy/megatron_policy_worker.py (1)
589-596: Address exception handling best practices.

The static analysis tools correctly identify issues with the exception handling. The error message should be more descriptive and follow exception chaining best practices.
Apply this diff to improve exception handling:
```diff
-        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
-        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
-            try:
-                model_cfg.fp8 = fp8_cfg["fp8"]
-                model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
-                model_cfg.fp8_param = fp8_cfg["fp8_param"]
-            except KeyError as e:
-                raise KeyError(f"Missing key in fp8_cfg: {e}")
+        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
+        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
+            required_keys = ["fp8", "fp8_recipe", "fp8_param"]
+            missing_keys = [key for key in required_keys if key not in fp8_cfg]
+            if missing_keys:
+                raise KeyError(
+                    f"Missing required FP8 configuration keys: {missing_keys}. "
+                    f"Required keys are: {required_keys}"
+                ) from None
+
+            model_cfg.fp8 = fp8_cfg["fp8"]
+            model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
+            model_cfg.fp8_param = fp8_cfg["fp8_param"]
```

tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh (3)
6-11: Address unused variable warnings and improve robustness.

The shellcheck warnings identify several unused variables and a potential issue with directory change failure.

Apply this diff to address the issues:
Apply this diff to address the issues:
```diff
 # ===== BEGIN CONFIG =====
 NUM_NODES=1
 STEPS_PER_RUN=100
 MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
-NUM_MINUTES=300
+# NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up - currently unused
+# NUM_MINUTES=300 # Currently unused
 # ===== END CONFIG =====
```
16-16: Add error handling for directory change.

The shellcheck warning correctly identifies that the `cd` command should handle potential failures.

```diff
 # Run the experiment
-cd $PROJECT_ROOT
+cd "$PROJECT_ROOT" || { echo "Failed to change to PROJECT_ROOT: $PROJECT_ROOT"; exit 1; }
```
28-28: Fix argument expansion.

The shellcheck error correctly identifies that the array expansion should be quoted to avoid re-splitting.

```diff
-    $@ \
+    "$@" \
```

examples/configs/grpo_math_8B_megatron_fp8.yaml (1)
30-30: Add missing newline at end of file.

The YAML linter correctly identifies the missing newline character at the end of the file.

```diff
 env_vars:
   NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "1"
+
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- `3rdparty/Megatron-Bridge-workspace/Megatron-Bridge` (1 hunks)
- `docs/fp8.md` (2 hunks)
- `examples/configs/grpo_math_8B_megatron_fp8.yaml` (1 hunks)
- `examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.yaml` (1 hunks)
- `nemo_rl/models/policy/megatron_policy_worker.py` (1 hunks)
- `tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh` (1 hunks)
- `tests/test_suites/nightly.txt` (1 hunks)
- `tests/unit/models/policy/test_megatron_worker.py` (3 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
nemo_rl/models/policy/megatron_policy_worker.py
596-596: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
596-596: Avoid specifying long messages outside the exception class
(TRY003)
🪛 Shellcheck (0.10.0)
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
🪛 markdownlint-cli2 (0.17.2)
docs/fp8.md
42-42: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
54-54: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🪛 YAMLlint (1.37.1)
examples/configs/grpo_math_8B_megatron_fp8.yaml
[error] 30-30: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (8)
3rdparty/Megatron-Bridge-workspace/Megatron-Bridge (1)
1-1: Submodule bump to dfaf656 — manual verification required (missing .gitmodules entry).

Verification script failed with: "ERROR: Submodule path not found in .gitmodules: 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge".
Actions:
- Confirm .gitmodules contains path 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge and that its URL points to the intended upstream (not a fork).
- Locally obtain SHAs and compare:
- New SHA: git ls-tree HEAD 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge | awk '{print $3}'
- Old SHA (origin/main): git show origin/main:3rdparty/Megatron-Bridge-workspace/Megatron-Bridge | sed -En 's/^Subproject commit ([0-9a-f]{7,40}).*/\1/p'
- If upstream is on GitHub, run a compare (gh api repos/<owner/repo>/compare/<OLD_SHA>...<NEW_SHA>) and inspect commit messages/files for FP8-only changes and any API/interface/config changes.
- Consider pinning to a tagged release for reproducibility if available.
examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.yaml (1)
1-162: Excellent comprehensive FP8 training configuration.

This is a well-structured GRPO training recipe that properly integrates Megatron FP8 training with vLLM for generation. The configuration correctly sets up all necessary components:
- Proper GRPO parameters with validation and checkpointing
- Complete Megatron configuration with FP8 settings
- Appropriate optimizer and scheduler settings
- Comprehensive logging setup
The configuration aligns well with the FP8 capabilities being added to the codebase.
tests/test_suites/nightly.txt (1)
39-39: FP8 training test properly integrated into nightly suite.

The addition of the FP8 training test script to the nightly suite is correctly placed and follows the existing pattern. This ensures the FP8 training functionality will be regularly tested.
tests/unit/models/policy/test_megatron_worker.py (2)
131-136: FP8 configuration added to test setup.

The FP8 configuration block is correctly added to the test setup, providing good test coverage for the FP8 functionality. The configuration values align with the documentation.
341-352: FP8 test case properly parameterized.

The FP8 test case is well-integrated into the existing test parameterization. The test ID "2gpu_tp2_llama_fp8" clearly identifies the FP8 variant, and the config updates are appropriate.
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh (1)
35-39: Excellent metrics validation approach.

The conditional metrics checking based on whether the target step is reached is a robust approach. The specific metric being checked (`train/token_mult_prob_error`) is appropriate for FP8 training validation, ensuring numerical stability.

examples/configs/grpo_math_8B_megatron_fp8.yaml (2)
19-30: Well-configured FP8 Megatron settings.

The FP8 configuration is well-structured and aligns with best practices:

- `fp8: "hybrid"` uses both E4M3 and E5M2 formats optimally
- `fp8_recipe: "blockwise"` provides better numerical stability than tensorwise
- Environment variable `NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "1"` enables FP32 scaling factors for better precision

The optimizer and DDP settings are also appropriate for FP8 training.
12-17: vLLM configuration optimizations look appropriate.

The vLLM configuration changes appear well-optimized for FP8:

- `gpu_memory_utilization: 0.5` provides reasonable memory usage
- Compilation config with fusion and noop passes enabled should improve performance
These settings complement the FP8 training setup nicely.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)
1812-1823: Fix O(N) repeated state_dict reload in move_model (severe perf hit).

The nested loop rebuilds and loads the full state_dict once per parameter. Build once and load once.
Apply this diff:
```diff
-        else:
-            # Ordinary offload case
-            if move_params:
-                for name, param in model.state_dict().items():
-                    new_state_dict = {}
-                    for name, item in model.state_dict().items():
-                        if isinstance(item, torch.Tensor):
-                            item = item.detach().to(
-                                device=device, non_blocking=True, copy=True
-                            )
-                        new_state_dict[name] = item
-                    model.load_state_dict(new_state_dict)
+        else:
+            # Ordinary offload case: build once, load once
+            if move_params:
+                new_state = {}
+                for name, item in model.state_dict().items():
+                    if isinstance(item, torch.Tensor):
+                        item = item.detach().to(device=device, non_blocking=True, copy=True)
+                    new_state[name] = item
+                model.load_state_dict(new_state)
```
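The cost difference the comment describes can be illustrated without torch. This toy sketch (hypothetical names, plain dicts standing in for a model's state dict) counts copies under each pattern: rebuilding the dict once per parameter does O(N^2) copies, building once does O(N).

```python
# Count copies under the buggy and fixed offload patterns.
copies = 0


def copy_item(x):
    """Stand-in for tensor.detach().to(...) that just counts invocations."""
    global copies
    copies += 1
    return x


state = {f"param_{i}": i for i in range(100)}  # toy "state dict"

# Buggy pattern: full rebuild inside the outer per-parameter loop.
copies = 0
for _ in state:
    rebuilt = {name: copy_item(v) for name, v in state.items()}
quadratic = copies  # 100 rebuilds x 100 items

# Fixed pattern: build once, load once.
copies = 0
rebuilt = {name: copy_item(v) for name, v in state.items()}
linear = copies

print(quadratic, linear)  # 10000 100
```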
♻️ Duplicate comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)
589-597: Validate fp8_cfg values and raise cleanly; set normalized values on model_cfg.

Add required-key check, normalize/case-fold strings, validate against allowed sets, and use "raise … from None/err" to satisfy B904. This prevents hard-to-debug TE errors at runtime.
Apply this diff:
```diff
-        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
-        self.fp8_cfg = fp8_cfg
-        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
-            try:
-                model_cfg.fp8 = fp8_cfg["fp8"]
-                model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
-                model_cfg.fp8_param = fp8_cfg["fp8_param"]
-            except KeyError as e:
-                raise KeyError(f"Missing key in fp8_cfg: {e}")
+        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
+        self.fp8_cfg = fp8_cfg
+        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
+            required = ("fp8", "fp8_recipe", "fp8_param")
+            missing = [k for k in required if k not in fp8_cfg]
+            if missing:
+                raise KeyError(f"Missing required FP8 config keys: {missing}") from None
+
+            valid_formats = {"hybrid", "e4m3"}
+            valid_recipes = {"tensorwise", "blockwise"}
+
+            fmt = str(fp8_cfg["fp8"]).lower()
+            recipe = str(fp8_cfg["fp8_recipe"]).lower()
+            param = fp8_cfg["fp8_param"]
+
+            if fmt not in valid_formats:
+                raise ValueError(f"Invalid fp8 '{fmt}'. Allowed: {sorted(valid_formats)}") from None
+            if recipe not in valid_recipes:
+                raise ValueError(f"Invalid fp8_recipe '{recipe}'. Allowed: {sorted(valid_recipes)}") from None
+            if not isinstance(param, bool):
+                raise ValueError(f"fp8_param must be boolean; got {type(param).__name__}") from None
+
+            model_cfg.fp8 = fmt
+            model_cfg.fp8_recipe = recipe
+            model_cfg.fp8_param = param
```
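As a standalone sketch of the validation pattern suggested above (a hypothetical helper, not the worker's actual method), the report-all-missing-keys-at-once behavior looks like this:

```python
REQUIRED_FP8_KEYS = ("fp8", "fp8_recipe", "fp8_param")


def validate_fp8_cfg(fp8_cfg: dict) -> dict:
    """Check all required FP8 keys up front and report them together,
    instead of failing on the first missing key inside a try/except."""
    missing = [k for k in REQUIRED_FP8_KEYS if k not in fp8_cfg]
    if missing:
        raise KeyError(
            f"Missing required FP8 configuration keys: {missing}. "
            f"Required keys are: {list(REQUIRED_FP8_KEYS)}"
        )
    return {k: fp8_cfg[k] for k in REQUIRED_FP8_KEYS}


cfg = {"enabled": True, "fp8": "hybrid", "fp8_recipe": "tensorwise", "fp8_param": False}
print(validate_fp8_cfg(cfg))
```

Collecting all missing keys first gives the user one actionable error rather than a fix-rerun loop, which is the motivation behind the suggested diff.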
🧹 Nitpick comments (2)
examples/configs/grpo_math_8B_megatron_fp8.yaml (2)
19-19: Ensure sequence_length_round aligns with CP/TP baseline.

If CP > 1, the minimal safe multiple is `cp_size * 2 * tp_size`. Keep this ≥ baseline to avoid repacking and performance loss. Consider documenting this in the config comment.
Would you like me to compute and log the effective multiple at startup to warn when a too-small value is provided?
29-29: EOF newline nit.

File is missing a trailing newline. Add one to satisfy linters.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `examples/configs/grpo_math_8B_megatron_fp8.yaml` (1 hunks)
- `nemo_rl/models/policy/lm_policy.py` (1 hunks)
- `nemo_rl/models/policy/megatron_policy_worker.py` (3 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
nemo_rl/models/policy/megatron_policy_worker.py
597-597: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
597-597: Avoid specifying long messages outside the exception class
(TRY003)
🪛 YAMLlint (1.37.1)
examples/configs/grpo_math_8B_megatron_fp8.yaml
[error] 29-29: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (2)
examples/configs/grpo_math_8B_megatron_fp8.yaml (2)
21-26: fp8_cfg values must match validated sets in code.

The worker now validates: fp8 in {"hybrid", "e4m3"}, fp8_recipe in {"tensorwise", "blockwise"}, fp8_param: bool. Current values look valid.
Please confirm docs/fp8.md reflects these exact allowed values and notes that values are case-insensitive.
12-17: vLLM compilation_config flags: require vLLM ≥ v0.9.0; FP8 supported since v0.5.x — verify vLLM pin & GPU support.
enable_fusion and enable_noop appear in vLLM v0.9.0+; precision: "fp8" is documented from ~v0.5.x. If this config relies on those flags, pin the repo's vLLM >= 0.9.0 and confirm target GPUs/drivers support FP8 (NVIDIA compute capability >8.9 for full FP8; Ampere has weight-only W8A16 variants).
Location: examples/configs/grpo_math_8B_megatron_fp8.yaml (lines 12–17)
Actionable comments posted: 0
♻️ Duplicate comments (2)
nemo_rl/models/policy/megatron_policy_worker.py (2)
954-955: Pad factor should be LCM of CP/TP baseline, sequence_length_round, and FP8 (16); also round pad_full_seq_to.

Using max(..., 16) can violate configured rounding and cause PP/CP shape mismatches.

```diff
-            pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size
-            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
-                pad_factor = max(pad_factor, 16)
+            pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size
+            cfg_round = self.cfg["sequence_packing"].get("sequence_length_round")
+            if cfg_round:
+                pad_factor = math.lcm(pad_factor, int(cfg_round))
+            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
+                pad_factor = math.lcm(pad_factor, 16)
             if self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1:
                 _, pad_full_seq_to = (
                     batch.get_microbatch_iterator_for_packable_sequences_len()
                 )
+                if pad_full_seq_to is not None and (pad_full_seq_to % pad_factor):
+                    pad_full_seq_to = ((pad_full_seq_to + pad_factor - 1) // pad_factor) * pad_factor
```
1161-1163: Mirror LCM pad_factor logic in logprobs path to match training.

Keep packing constraints identical across train/logprobs.

```diff
-            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
-                # if fp8 is enabled, pad the sequence length to multiples of 16
-                pad_factor = max(pad_factor, 16)
+            cfg_round = self.cfg["sequence_packing"].get("sequence_length_round")
+            if cfg_round:
+                pad_factor = math.lcm(pad_factor, int(cfg_round))
+            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
+                pad_factor = math.lcm(pad_factor, 16)
```
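A quick numeric check shows why the reviewer prefers `math.lcm` over `max` for combining padding constraints (toy values; the baseline of 12 is just an illustration, e.g. `tp_size = 12`):

```python
import math

# A sequence padded to a multiple of max(12, 16) = 16 is NOT a multiple
# of the CP/TP baseline 12, so downstream sharding would break.
baseline = 12
print(max(baseline, 16) % baseline)       # 4 -> violates the baseline constraint
print(math.lcm(baseline, 16) % baseline)  # 0 -> lcm(12, 16) = 48 satisfies both
print(math.lcm(baseline, 16) % 16)        # 0
```

The LCM is the smallest pad factor satisfying every divisibility constraint simultaneously, which is exactly what the suggested diffs rely on.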
🧹 Nitpick comments (2)
.gitmodules (1)
9-9: Pin submodule to a commit SHA instead of a moving branch.

Branch-tracking submodules are fragile and risky for reproducibility. Pin Megatron-Bridge to a specific commit and drop the branch line. Keep the submodule pointer updated in the parent repo when you want to advance.

```diff
-	branch = guyueh/nemo-rl-use-fp8
```

nemo_rl/models/policy/megatron_policy_worker.py (1)
595-603: Validate FP8 cfg and raise KeyError with `from None`; remove try/except.

Adds guardrails and fixes Ruff B904/TRY003. Also ensures `fp8_param` is boolean.

```diff
-        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
-        self.fp8_cfg = fp8_cfg
-        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
-            try:
-                model_cfg.fp8 = fp8_cfg["fp8"]
-                model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
-                model_cfg.fp8_param = fp8_cfg["fp8_param"]
-            except KeyError as e:
-                raise KeyError(f"Missing key in fp8_cfg: {e}")
+        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
+        self.fp8_cfg = fp8_cfg
+        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
+            required = ("fp8", "fp8_recipe", "fp8_param")
+            missing = [k for k in required if k not in fp8_cfg]
+            if missing:
+                raise KeyError(f"Missing required FP8 keys: {missing}") from None
+            valid_fp8 = {"hybrid", "e4m3"}
+            valid_recipe = {"tensorwise", "blockwise"}
+            if fp8_cfg["fp8"] not in valid_fp8:
+                raise ValueError(f"Invalid fp8='{fp8_cfg['fp8']}', expected {valid_fp8}") from None
+            if fp8_cfg["fp8_recipe"] not in valid_recipe:
+                raise ValueError(f"Invalid fp8_recipe='{fp8_cfg['fp8_recipe']}', expected {valid_recipe}") from None
+            if not isinstance(fp8_cfg["fp8_param"], bool):
+                raise ValueError(f"fp8_param must be bool, got {type(fp8_cfg['fp8_param'])}") from None
+            model_cfg.fp8 = fp8_cfg["fp8"]
+            model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
+            model_cfg.fp8_param = fp8_cfg["fp8_param"]
```

Add once at top-level imports:

```python
import math
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `.gitmodules` (1 hunks)
- `nemo_rl/models/policy/lm_policy.py` (1 hunks)
- `nemo_rl/models/policy/megatron_policy_worker.py` (3 hunks)
- `tests/unit/models/policy/test_megatron_worker.py` (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- nemo_rl/models/policy/lm_policy.py
- tests/unit/models/policy/test_megatron_worker.py
🧰 Additional context used
🪛 Ruff (0.12.2)
nemo_rl/models/policy/megatron_policy_worker.py
603-603: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
603-603: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Force-pushed from 3acee86 to 44be79e
Actionable comments posted: 2
♻️ Duplicate comments (3)
nemo_rl/models/policy/megatron_policy_worker.py (3)
595-603: Validate fp8_cfg values and raise from None (Ruff B904).

Adds guardrails for invalid config and addresses the linter hint.

```diff
-        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
-        self.fp8_cfg = fp8_cfg
-        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
-            try:
-                model_cfg.fp8 = fp8_cfg["fp8"]
-                model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
-                model_cfg.fp8_param = fp8_cfg["fp8_param"]
-            except KeyError as e:
-                raise KeyError(f"Missing key in fp8_cfg: {e}")
+        fp8_cfg = self.cfg["megatron_cfg"].get("fp8_cfg", None)
+        self.fp8_cfg = fp8_cfg
+        if fp8_cfg is not None and fp8_cfg.get("enabled", False):
+            # Validate presence and values
+            required = ("fp8", "fp8_recipe", "fp8_param")
+            missing = [k for k in required if k not in fp8_cfg]
+            if missing:
+                raise KeyError(f"Missing required FP8 configuration keys: {missing}") from None
+            valid_fp8 = {"hybrid", "e4m3"}
+            valid_recipe = {"tensorwise", "blockwise"}
+            if fp8_cfg["fp8"] not in valid_fp8:
+                raise ValueError(f"Invalid fp8 '{fp8_cfg['fp8']}', expected one of {sorted(valid_fp8)}") from None
+            if fp8_cfg["fp8_recipe"] not in valid_recipe:
+                raise ValueError(f"Invalid fp8_recipe '{fp8_cfg['fp8_recipe']}', expected one of {sorted(valid_recipe)}") from None
+            if not isinstance(fp8_cfg["fp8_param"], bool):
+                raise ValueError(f"fp8_param must be boolean, got {type(fp8_cfg['fp8_param'])}") from None
+            model_cfg.fp8 = fp8_cfg["fp8"]
+            model_cfg.fp8_recipe = fp8_cfg["fp8_recipe"]
+            model_cfg.fp8_param = fp8_cfg["fp8_param"]
```

Additional import required if not present:

```python
import math  # for LCM below
```
951-956: Use LCM for pad_factor (include config round and FP8 multiple).

max(..., 16) can violate CP/TP or configured rounding; LCM keeps constraints aligned.

```diff
-            pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size
-            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
-                pad_factor = max(pad_factor, 16)
+            pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size
+            cfg_round = self.cfg["sequence_packing"].get("sequence_length_round")
+            if cfg_round:
+                pad_factor = math.lcm(pad_factor, int(cfg_round))
+            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
+                pad_factor = math.lcm(pad_factor, 16)
```
1161-1163: Mirror LCM pad_factor in logprobs path.

Keeps train/logprobs packing identical.

```diff
-            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
-                # if fp8 is enabled, pad the sequence length to multiples of 16
-                pad_factor = max(pad_factor, 16)
+            cfg_round = self.cfg["sequence_packing"].get("sequence_length_round")
+            if cfg_round:
+                pad_factor = math.lcm(pad_factor, int(cfg_round))
+            if self.fp8_cfg is not None and self.fp8_cfg.get("enabled", False):
+                pad_factor = math.lcm(pad_factor, 16)
```
🧹 Nitpick comments (12)
examples/configs/grpo_math_70B_megatron_fp8.yaml (1)
1-24: Add trailing newline.

Fixes yamllint error and keeps diffs clean.
docs/fp8.md (6)
3-3: Fix typos and tighten wording.

Change "developement" → "development"; add comma after "Currently".

```diff
-This module provides a suite of tools to enable FP8 quantization for large language models. This module is still in developement. Currently we support
+This module provides a suite of tools to enable FP8 quantization for large language models. This module is still in development. Currently, we support
```
5-7: Minor terminology/typo nits.

"Deepseek style" → "DeepSeek-style"; "per-tensor scaling" hyphenation is optional; keep consistent across doc.
29-33: Fix spelling.

"signficant" → "significant".

```diff
-# a signficant effect to performance
+# a significant effect on performance
```
40-40: Grammar: "recommended" and proper noun.

"megatron path" → "Megatron path"; "recommented" → "recommended".

```diff
-FP8 training requires megatron path, and is recommented to be configured with the following settings:
+FP8 training requires the Megatron path, and is recommended to be configured with the following settings:
```
42-49: Add fenced-code language and fix typos in inline comments.

Also correct "choicse" → "choices".

````diff
-```
+```yaml
 policy:
   megatron_cfg:
     fp8_cfg:
       fp8: "hybrid" # choices: [hybrid, e4m3]
-      fp8_recipe: "tensorwise" # choicse: [tensorwise, blockwise]
+      fp8_recipe: "tensorwise" # choices: [tensorwise, blockwise]
       fp8_param: false # boolean value
-```
+```
````

51-65: Code-fence language + version claim verification.

- Remove stray "*" in the heading.
- Add language to fenced blocks ("text" for traceback).
- Please double-check the exact Torch/CUDA versions pinned in this repo before claiming CUDA 12.8/12.9 requirements to avoid stale guidance.

```diff
-### Special note with using FP8 training with Deepseek-style FP8 (sub channel scaling)*
+### Special note on FP8 training with DeepSeek-style FP8 (sub-channel scaling)
```

````diff
-```
+```text
 File "/opt/ray_venvs/...
 ...
 AssertionError: FP8 block scaled GEMM requires Hopper and CUDA >= 12.9.
-```
+```
````

tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh (3)

6-11: Unused vars (shellcheck) — export or remove.

If consumed by sourced helpers, export; otherwise drop them.

```diff
-NUM_NODES=1
+export NUM_NODES=1
 STEPS_PER_RUN=100
 MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
-NUM_MINUTES=300
+export NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
+export NUM_MINUTES=300
```
16-17: Handle cd failure.

Avoid continuing if PROJECT_ROOT is missing.

```diff
-cd $PROJECT_ROOT
+cd "$PROJECT_ROOT" || exit 1
```
31-39: Quote paths in post-processing.

Prevents breakage with spaces/special chars.

```diff
-uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"
-if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
-  uv run tests/check_metrics.py $JSON_METRICS \
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
+  uv run tests/check_metrics.py "$JSON_METRICS" \
     'mean(data["train/token_mult_prob_error"]) < 1.1' \
     'data["train/token_mult_prob_error"]["100"] < 1.1'
 fi
```

examples/configs/grpo_math_8B_megatron_fp8.yaml (1)
13-24: Add trailing newline and confirm FP8 recipe support.
- Fix newline at EOF (yamllint).
- Validate fp8="hybrid" with fp8_recipe="blockwise" on your TE build.
examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.yaml (1)
123-144: Clarify max_new_tokens semantics vs prompt length.

You pass tokens_to_generate = max_new_tokens − prompt_len in generate(); verify this field represents max total sequence length, not "new tokens". Otherwise generations may be shorter than intended.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- `.gitmodules` (1 hunks)
- `3rdparty/Megatron-Bridge-workspace/Megatron-Bridge` (1 hunks)
- `docs/fp8.md` (2 hunks)
- `examples/configs/grpo_math_70B_megatron_fp8.yaml` (1 hunks)
- `examples/configs/grpo_math_8B_megatron_fp8.yaml` (1 hunks)
- `examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.yaml` (1 hunks)
- `nemo_rl/models/policy/lm_policy.py` (1 hunks)
- `nemo_rl/models/policy/megatron_policy_worker.py` (3 hunks)
- `tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh` (1 hunks)
- `tests/test_suites/nightly.txt` (1 hunks)
- `tests/unit/models/policy/test_megatron_worker.py` (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
- .gitmodules
- tests/test_suites/nightly.txt
- tests/unit/models/policy/test_megatron_worker.py
- nemo_rl/models/policy/lm_policy.py
🧰 Additional context used
🪛 YAMLlint (1.37.1)
examples/configs/grpo_math_70B_megatron_fp8.yaml
[error] 24-24: no new line character at the end of file
(new-line-at-end-of-file)
examples/configs/grpo_math_8B_megatron_fp8.yaml
[error] 24-24: no new line character at the end of file
(new-line-at-end-of-file)
🪛 markdownlint-cli2 (0.17.2)
docs/fp8.md
42-42: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
54-54: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🪛 Ruff (0.12.2)
nemo_rl/models/policy/megatron_policy_worker.py
603-603: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
603-603: Avoid specifying long messages outside the exception class
(TRY003)
🪛 Shellcheck (0.10.0)
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
🔇 Additional comments (4)
examples/configs/grpo_math_70B_megatron_fp8.yaml (1)
15-24: Validate FP8 format/recipe compatibility. Confirm that the TransformerEngine build supports fp8="hybrid" with fp8_recipe="blockwise" and fp8_param=true for this model; mismatches will fail at init.
examples/configs/grpo_math_8B_megatron_fp8.yaml (1)
10-13: Generation memory cap looks conservative; OK to merge. gpu_memory_utilization=0.5 is a safe default for 8B; tune later if needed.
examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-training.yaml (2)
21-22: Token-level loss requires token_mask. Ensure the data pipeline populates token_mask; train() asserts its presence for TOKEN_LEVEL loss.
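As an illustration of the kind of guard train() performs, here is a hedged sketch; `validate_token_level_batch` is a hypothetical helper, not the actual NeMo-RL API:

```python
def validate_token_level_batch(batch: dict) -> None:
    """Mimic the train()-side assertion: token-level loss needs a
    per-token mask so padding/prompt tokens are excluded from the loss."""
    if "token_mask" not in batch:
        raise KeyError("token_mask is required for TOKEN_LEVEL loss")
    if len(batch["token_mask"]) != len(batch["input_ids"]):
        raise ValueError("token_mask must align with input_ids")

# A well-formed batch passes silently; a batch missing token_mask raises.
validate_token_level_batch({"input_ids": [1, 2, 3], "token_mask": [0, 1, 1]})
```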
114-121: Environment guard for FP8 block scaling. Good to surface NVTE_FP8_BLOCK_SCALING_FP32_SCALES here; confirm that the runtime actually exports env_vars from the config before model init.
jgerh
left a comment
There was a problem hiding this comment.
Completed the tech pubs review of docs/fp8.md and provided a few suggested text revisions and added some headings for readability.
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
examples/configs/grpo_math_70B_megatron_fp8.yaml (2)
21-22: Precision-aware optimizer disabled: confirm this is intended. With FP8 training, disabling use_precision_aware_optimizer may affect stability. If this is intentional, consider documenting the rationale in the recipe.
24-24: Add a trailing newline. YAMLlint flags the missing newline at EOF. Please add one.
```diff
-  NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "1"
+  NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "1"
+
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/configs/grpo_math_70B_megatron.yaml (1 hunks)
- examples/configs/grpo_math_70B_megatron_fp8.yaml (1 hunks)
- nemo_rl/models/policy/lm_policy.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- nemo_rl/models/policy/lm_policy.py
🧰 Additional context used
🪛 YAMLlint (1.37.1)
examples/configs/grpo_math_70B_megatron_fp8.yaml
[error] 24-24: no new line character at the end of file
(new-line-at-end-of-file)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (5)
examples/configs/grpo_math_70B_megatron_fp8.yaml (4)
10-12: Confirm vLLM FP8 + Deep GEMM support on your build. precision: "fp8" with use_deep_gemm: true requires the runtime build to have FP8 Deep GEMM enabled. If it is not available, add a fallback (e.g., precision: "bfloat16") or feature-flag this.
Run to verify any feature flags/env required by your vLLM build:
13-13: Rename sequence_length_round → sequence_length_pad_multiple under sequence_packing. sequence_packing expects sequence_length_pad_multiple (sequence_length_round is used only by dynamic_batching); the current YAML places sequence_length_round under sequence_packing, so it will be ignored. Either rename the key or remove it; if you want a user-configurable pad, you must also change lm_policy to read the config value instead of computing it.
Apply this diff:
```diff
-sequence_packing:
-  sequence_length_round: 16
+sequence_packing:
+  sequence_length_pad_multiple: 16
```
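The padding behavior this key controls can be sketched as follows. This is only an illustration under the assumption, stated in this PR's walkthrough, that FP8 training enforces a minimum pad multiple of 16; the function names are hypothetical, not the actual worker code:

```python
def effective_pad_multiple(base_multiple: int, fp8_enabled: bool) -> int:
    # FP8 GEMMs need dimensions divisible by 16, so raise the floor when enabled.
    return max(base_multiple, 16) if fp8_enabled else base_multiple

def padded_length(seq_len: int, multiple: int) -> int:
    # Round seq_len up to the nearest multiple.
    return ((seq_len + multiple - 1) // multiple) * multiple

# With FP8 on, a base multiple of 8 is raised to 16, so 1000 pads to 1008.
print(padded_length(1000, effective_pad_multiple(8, fp8_enabled=True)))  # 1008
```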
15-20: Parallelism settings: verify the effective TP/PP/CP world size with FP8 enabled. Do not assume TP=4; the code reads tensor_model_parallel_size / pipeline_model_parallel_size / context_parallel_size from megatron_cfg and applies fp8_cfg to model_cfg. Confirm the effective tensor_model_parallel_size, pipeline_model_parallel_size, and context_parallel_size so that TP * PP * CP equals the intended GPUs per replica, and verify that the Megatron-Bridge FP8 combination ("hybrid" + "blockwise" with fp8_param=false) is supported by your FP8 branch.
Check these locations: examples/configs/grpo_math_70B_megatron_fp8.yaml, nemo_rl/models/policy/megatron_policy_worker.py, nemo_rl/models/megatron/community_import.py, nemo_rl/models/generation/fp8.py.
Use your earlier grep (or cat the YAML) to confirm the runtime/merged megatron_cfg values.
23-24: Env var plumbing verified: NVTE_* variables from policy.megatron_cfg.env_vars are propagated to trainer ranks. lm_policy.py reads megatron_cfg.env_vars and passes it into RayWorkerGroup; RayWorkerGroup merges those with os.environ and injects them into the worker runtime_env (nemo_rl/distributed/worker_groups.py), so NVTE_FP8_BLOCK_SCALING_FP32_SCALES will reach the trainer processes.
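The merge order described above can be sketched like this; `merge_env_vars` is an illustrative stand-in for the logic in nemo_rl/distributed/worker_groups.py, not its real API:

```python
import os

def merge_env_vars(config_env_vars, base_env=None):
    # Config-supplied overrides win over the inherited environment,
    # mirroring how worker runtime_env injection behaves.
    merged = dict(os.environ if base_env is None else base_env)
    merged.update({k: str(v) for k, v in config_env_vars.items()})
    return merged

runtime_env = merge_env_vars(
    {"NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1"}, base_env={"PATH": "/usr/bin"}
)
print(runtime_env["NVTE_FP8_BLOCK_SCALING_FP32_SCALES"])  # 1
```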
examples/configs/grpo_math_70B_megatron.yaml (1)
62-62: TP=8 for vLLM: verify orchestration/resources. With tensor_parallel_size: 8, vLLM will occupy 8 GPUs per engine. Ensure that the RL orchestration reserves nodes/GPUs accordingly and that training (TP=4, PP=4 → 16 GPUs) and generation don’t contend on the same hosts.
Run to locate the launcher/allocator that wires vLLM TP to nodes:
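As a back-of-the-envelope resource check, the arithmetic the comment above relies on can be sketched as follows (a sketch only; it assumes each training replica needs TP * PP * CP GPUs, and the function names are illustrative):

```python
def gpus_per_replica(tp: int, pp: int, cp: int) -> int:
    # Effective GPUs consumed by one training replica.
    return tp * pp * cp

def replicas_that_fit(tp: int, pp: int, cp: int,
                      gpus_per_node: int, num_nodes: int) -> int:
    # How many replicas the cluster can host, failing loudly on a mismatch.
    total = gpus_per_node * num_nodes
    need = gpus_per_replica(tp, pp, cp)
    if total % need != 0:
        raise ValueError(f"{total} GPUs cannot be evenly split into replicas of {need}")
    return total // need

# TP=4, PP=4, CP=1 -> 16 GPUs per replica; two 8-GPU nodes host exactly one.
print(replicas_that_fit(4, 4, 1, gpus_per_node=8, num_nodes=2))  # 1
```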
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com> Signed-off-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com>
Signed-off-by: root <root@cw-dfw-h100-004-236-026.cm.cluster>
fixed another bug for unit test, please queue again @terrykong
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
@terrykong I am dropping the submodule change, as it turns out that change was not needed; please review again. Details: the previous mbridge change was to support refit with the fp8_param=true setup; in this setup we need to first dequantize the fp8 params to bf16 and then do the mcore->hf conversion. But even so, there is still a data-race bug with this scenario (#1164), and I am disabling it for all benchmarks. The perf impact should be quite small.
Signed-off-by: Guyue Huang <guyueh@nvidia.com> Signed-off-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com> Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com> Signed-off-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com> Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com> Signed-off-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com> Co-authored-by: Terry Kong <terrycurtiskong@gmail.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do ?
Enable FP8 training in the Megatron path.
Issues
List issues that this PR closes (syntax):
Closes: #820
Usage
```shell
uv run python examples/run_grpo_math.py \
  --config examples/configs/grpo_math_1B_megatron.yaml \
  policy.megatron_cfg.fp8="hybrid"
```

Before your PR is "Ready for review"
Pre checks:
Additional Information
docs/fp8.md

Summary by CodeRabbit