fix: report the correct number of workers during FLOPs calculation #1034

Merged

terrykong merged 14 commits into main from ybgao/sep2-fix-flops-accounting on Sep 9, 2025

Conversation

@ybgao-nvidia (Contributor) commented on Sep 2, 2025:

What does this PR do?

This PR:

  1. corrects the theoretical FLOPs accounting used when computing floating-point utilization (MFU); a minimal sketch of the corrected accounting follows this list
  2. adds unit tests to ensure the correct number of model FLOPs is reported
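
The sketch below is illustrative, not the exact nemo_rl implementation: it shows why the worker count alone is the wrong denominator for MFU when each worker drives several GPUs. The function name and the peak-FLOPS constant (an H100 BF16 dense figure) are assumptions for the example.

```python
def mfu(model_flops_per_step: float,
        step_time_s: float,
        num_workers: int,
        gpus_per_worker: int,
        peak_flops_per_gpu: float = 989.5e12) -> float:  # assumed H100 BF16 dense peak
    """Model FLOPs utilization = achieved FLOPS / aggregate peak FLOPS."""
    achieved = model_flops_per_step / step_time_s
    # The fix: scale available compute by GPUs per worker, not just workers.
    available = num_workers * gpus_per_worker * peak_flops_per_gpu
    return achieved / available
```

Dividing by num_workers alone would overstate MFU by a factor of gpus_per_worker whenever a worker owns more than one GPU.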

Issues

This PR resolves #933, resolves #1015.

Usage

N/A

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Added FLOPs tracking support for Qwen3 MoE and expanded FLOPs config to include MoE-specific fields.
  • Bug Fixes / Behavior Changes

    • Adjusted multi-GPU FLOPs aggregation to account for GPUs per worker.
    • Unified Llama FLOPs formula and reduced encoder-length contribution.
  • Tests

    • Added FLOPs range and utilization checks for Megatron and DTensor policies.
    • Added comprehensive FLOPs counter validation across multiple LLMs.

@coderabbitai ignore

@ybgao-nvidia changed the title from "report the correct number of workers during FLOPs calculation" to "fix: report the correct number of workers during FLOPs calculation" on Sep 2, 2025
@terrykong (Collaborator) left a comment:


Is it possible to add unit tests for this? It may also be good to add a few TFLOP unit tests (I noticed that wasn't covered in https://github.com/NVIDIA-NeMo/RL/pull/632/files); expected values can be grabbed from some of the internal sheets, e.g., Llama 8B / Qwen3 30B_A3B / Nemotron-H @ bs=128.
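
A minimal, self-contained sketch of the kind of test being requested: pin a FLOPs formula to a hand-computed reference for a tiny configuration so coefficient regressions are caught. The toy formula and numbers below are illustrative stand-ins, not the repo's formulas or the internal-sheet TFLOP values (those still need to be filled in).

```python
def dense_flops(gbs, s, h, ffn, layers, vocab):
    attn = 6 * gbs * s * (4 * h * h) * layers      # Q/K/V/O projections
    ffn_ = 6 * gbs * s * (3 * h * ffn) * layers    # gated MLP: 3 matmuls
    head = 6 * gbs * s * h * vocab                 # LM head
    return attn + ffn_ + head

def test_dense_flops_tiny():
    # hand-computed: attn = 6*2*64 = 768, ffn = 6*2*96 = 1152, head = 6*2*64 = 768
    assert dense_flops(gbs=1, s=2, h=4, ffn=8, layers=1, vocab=16) == 768 + 1152 + 768
```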

parthchadha previously approved these changes on Sep 2, 2025.

@parthchadha (Contributor) left a comment:


LGTM; let's add a unit test as @terrykong requested.

@ybgao-nvidia added and removed the CI:L1 (Run doctests, unit tests, and functional tests) label on Sep 4, 2025
wangshangsam previously approved these changes on Sep 4, 2025.
@coderabbitai (Bot) left a comment:


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/utils/flops_tracker.py (1)

55-67: Update Qwen3-MoE field mapping with actual HF config attributes
The qwen3 FLOPs formula uses moe_router_topk, moe_layer_freq, moe_shared_expert_intermediate_size, and moe_ffn_hidden_size. Map these to Qwen3MoeConfig as follows:

--- a/nemo_rl/utils/flops_tracker.py
+++ b/nemo_rl/utils/flops_tracker.py
@@ -55,13 +55,17 @@ def get_flops_config(config):
     elif isinstance(config, Qwen3MoeConfig):
-        return FLOPSConfig(
+        return FLOPSConfig(
             gbs=0,
             hs=config.hidden_size,
             layers=config.num_hidden_layers,
             ffn_hs=config.intermediate_size,
             vocab_size=config.vocab_size,
-            query_groups=getattr(config, "num_key_value_heads", getattr(config, "num_kv_heads", None)),
+            query_groups=getattr(config, "num_key_value_heads", getattr(config, "num_kv_heads", None)),
             attention_heads=config.num_attention_heads,
-            moe_ffn_hidden_size=getattr(
-                config, "moe_intermediate_size", getattr(config, "expert_intermediate_size", config.intermediate_size)
-            ),
-            moe_shared_expert_intermediate_size=getattr(config, "moe_shared_expert_intermediate_size", None),
-            moe_layer_freq=getattr(config, "moe_layer_freq", None),
-            moe_router_topk=getattr(
-                config, "moe_router_topk", getattr(config, "num_experts_per_tok", getattr(config, "moe_topk", None))
-            ),
+            # MoE-specific mapping to HF Qwen3-MoE config
+            moe_ffn_hidden_size=getattr(config, "moe_intermediate_size", config.intermediate_size),
+            moe_shared_expert_intermediate_size=getattr(config, "moe_intermediate_size", None),
+            moe_layer_freq=getattr(config, "decoder_sparse_step", getattr(config, "mlp_only_layers", None)),
+            moe_router_topk=config.num_experts_per_tok,
         ), qwen3
🧹 Nitpick comments (2)
nemo_rl/utils/flops_tracker.py (2)

77-77: Model-name gating is fine; consider future-proofing.

"llama-3" in model_name.lower() works for 3.x variants. Optionally, gate by HF config (e.g., a model_type check) to avoid brittle string matches if naming shifts.

-        ), llama3 if "llama-3" in model_name.lower() else llama2
+        ), llama3 if ("llama-3" in model_name.lower() or getattr(config, "model_type", "") == "llama3") else llama2

23-23: Add fallback import for dense Qwen3Config

Replace the lone MoE import in nemo_rl/utils/flops_tracker.py with a guarded import that also supports dense Qwen3:

-from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
+from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
+try:
+    from transformers.models.qwen3.configuration_qwen3 import Qwen3Config  # type: ignore
+except ImportError:
+    Qwen3Config = None  # type: ignore

This ensures forward‐compatibility for users loading dense Qwen3 models in environments without MoE.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 88c4e51 and 47431ca.

📒 Files selected for processing (2)
  • nemo_rl/utils/flops_tracker.py (3 hunks)
  • tests/unit/utils/test_flops_counter.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/utils/test_flops_counter.py
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/utils/flops_tracker.py (1)
nemo_rl/utils/flops_formulas.py (3)
  • FLOPSConfig (21-59)
  • llama3 (90-105)
  • llama2 (72-87)
🔇 Additional comments (1)
nemo_rl/utils/flops_tracker.py (1)

74-75: LLaMA query_groups source looks correct for GQA.

Using num_key_value_heads for query_groups aligns with GQA where groups = KV heads. This fixes the previous ratio error.
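
A toy numeric check of that ratio (values assumed to match Llama-3-8B's published config: 32 attention heads, 8 KV heads, head_dim 128):

```python
attention_heads = 32
query_groups = 8                      # = num_key_value_heads in the HF config
head_dim = 128
hidden = attention_heads * head_dim   # 4096

q_proj = hidden * attention_heads * head_dim    # ~16.8M params (full-width)
kv_proj = 2 * hidden * query_groups * head_dim  # ~8.4M, vs ~33.6M if K/V were full-width
# Deriving query_groups from attention_heads by mistake would overcount the
# K/V projection FLOPs by the heads/groups = 4x factor.
```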

@ybgao-nvidia added and removed the CI:L1 (Run doctests, unit tests, and functional tests) label on Sep 9, 2025
Comment thread on nemo_rl/utils/flops_tracker.py (outdated)
wangshangsam previously approved these changes on Sep 9, 2025.

@wangshangsam (Contributor) left a comment:


I have no further comments.

@coderabbitai (Bot) left a comment:


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
nemo_rl/utils/flops_formulas.py (2)

460-469: AttributeError risk: gated_linear_unit is not a field on FLOPSConfig.

_mlp_layer_flops accesses config.gated_linear_unit, which isn’t defined on FLOPSConfig. This will raise at runtime if the hybrid path is used.

Apply:

@@ class FLOPSConfig:
     mamba_num_heads: Optional[int] = None
+    # Whether MLP uses a gated activation (e.g., SwiGLU); affects 2x factor in FFN FLOPs.
+    gated_linear_unit: bool = False
@@ def _mlp_layer_flops(config: FLOPSConfig):
-    return (
+    gl_mult = 2 if getattr(config, "gated_linear_unit", False) else 1
+    return (
         6
         * config.gbs
         * config.enc_seq_len
         * config.hs
         * config.ffn_hs
-        * (2 if config.gated_linear_unit else 1)
+        * gl_mult
     )

488-508: Potential None-multiply in Mamba path (mamba_num_groups).

config.mamba_num_groups is Optional but multiplied directly; if None, this will crash.

Apply:

@@ def _mamba_layer_flops(config: FLOPSConfig):
-    return (
+    groups = config.mamba_num_groups if config.mamba_num_groups is not None else 1
+    return (
         (
             6
             * config.gbs
             * config.enc_seq_len
             * config.hs
-            * (2 * d_in + 2 * config.mamba_num_groups * config.mamba_state_dim + nheads)
+            * (2 * d_in + 2 * groups * config.mamba_state_dim + nheads)
         )
         + (3 * 2 * config.gbs * config.enc_seq_len * d_in * config.mamba_state_dim)
         + (6 * config.gbs * config.enc_seq_len * d_in * config.hs)
     )
nemo_rl/utils/flops_tracker.py (1)

55-67: Use documented Qwen3 MoE config fields for KV heads and router top-k

Apply:

@@ nemo_rl/utils/flops_tracker.py:55-67
-            query_groups=config.num_key_value_heads,
+            query_groups=getattr(config, "num_key_value_heads", config.num_attention_heads),
-            moe_router_topk=1,
+            moe_router_topk=(
+                getattr(config, "num_experts_per_tok", None)
+                or getattr(config, "router_topk", None)
+                or getattr(config, "n_routed_experts", 1)
+            ),

This ensures query_groups uses num_key_value_heads and moe_router_topk uses num_experts_per_tok.

🧹 Nitpick comments (6)
nemo_rl/utils/flops_formulas.py (3)

71-86: Llama: confirm the 0.5x attention term and update the docstring.

You halved the attention-seq term to 6·(enc_seq_len/hs). If this encodes causal self-attn (upper-triangular), please update the docstring to “Llama family (Llama 2/3), causal” and add a short comment citing the causal mask assumption so future edits don’t regress this coefficient.

Apply:

-def llama(config: FLOPSConfig):
-    """Model FLOPs for llama3 family."""
+def llama(config: FLOPSConfig):
+    """Model FLOPs for Llama family (Llama 2/3) under causal self-attention."""
+    # Note: attention term uses 0.5 factor due to causal mask (upper-triangular)
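
For intuition, a minimal sketch of where the 0.5x comes from, using the standard score/context matmul accounting (assumed here rather than taken from the repo):

```python
def attn_matmul_flops(gbs, s, h, causal=True):
    # Two s x s x h matmuls per layer (QK^T and scores @ V); the factor 6
    # covers forward (2 FLOPs/MAC) plus backward (~2x forward).
    full = 6 * gbs * 2 * s * s * h
    # Under a causal mask, position i attends to i+1 keys, so the average
    # attended length is ~s/2, halving the s^2 term.
    return full // 2 if causal else full
```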

167-207: Qwen3 MoE: router top-k handling looks hard-coded; plumb the real value from HF config.

Formula uses moe_router_topk but elsewhere we set it to 1 by default. Wire the actual top-k from the HF config (e.g., num_experts_per_tok/router_topk/n_routed_experts) to avoid undercounting MoE FLOPs.

If applicable in your HF version, use a defensive getter:

-        * (config.moe_ffn_hidden_size * config.moe_router_topk)  # MoE layers
+        * (config.moe_ffn_hidden_size * config.moe_router_topk)  # MoE layers

And ensure moe_router_topk is set correctly in convert_config_to_flops_config (see tracker comment with diff).


472-485: Minor: simplify an expression in non-MLA attn.

config.enc_seq_len / 2 * 2 cancels to config.enc_seq_len. Simplify for clarity.

-            + config.enc_seq_len / 2 * 2
+            + config.enc_seq_len
nemo_rl/utils/flops_tracker.py (3)

82-95: Device TFLOPS table: confirm BF16 peak numbers and add tensor-core variants as needed.

H100 BF16 is set to 1979/2. Validate this against your MFU denominator definition (whether a fused FMA counts as 2 FLOPs) and consider adding other H100 SKUs if tests rely on them.
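
For reference, a sketch of such a table with the dense-vs-sparse distinction made explicit; the figures are datasheet numbers, and the H100 entry matches the 1979/2 noted above:

```python
# Peak BF16 TFLOPS per GPU. NVIDIA datasheets quote the 2:4-sparsity figure
# for H100 (1979), so the dense value -- the right MFU denominator for
# typical dense training -- is half of that.
DEVICE_PEAK_TFLOPS_BF16 = {
    "H100": 1979 / 2,  # 989.5 dense
    "A100": 312,       # dense (624 with sparsity)
}
```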


117-135: Tracker core looks good; small optimization optional.

Per-token track_batch loops item-wise; for large batches you could accumulate counts by length histogram to reduce Python overhead. Low priority unless hot.
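
A sketch of that histogram idea (hypothetical helper names, not current repo code): bucket sequences by length so the Python-level loop runs over distinct lengths rather than every item in the batch.

```python
from collections import Counter

def batch_flops(seq_lens, flops_for_len):
    # O(#distinct lengths) instead of O(batch size) calls to the FLOPs formula.
    return sum(count * flops_for_len(length)
               for length, count in Counter(seq_lens).items())
```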


29-41: HF config init: consider surfacing exceptions with model hint.

If AutoConfig fails, catching and rethrowing with model_name and trust_remote_code=True context would ease debugging. Optional.
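
A sketch of that re-raise, assuming the tracker loads configs via AutoConfig (the helper name is hypothetical):

```python
from transformers import AutoConfig

def load_hf_config(model_name: str):
    try:
        return AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    except Exception as e:
        # Surface the model name so config-loading failures are easy to attribute.
        raise RuntimeError(
            f"Failed to load HF config for {model_name!r} (trust_remote_code=True)"
        ) from e
```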

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7a68653 and 3ee48ec.

📒 Files selected for processing (2)
  • nemo_rl/utils/flops_formulas.py (2 hunks)
  • nemo_rl/utils/flops_tracker.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/utils/flops_tracker.py (2)
nemo_rl/models/policy/utils.py (1)
  • sliding_window_overwrite (169-195)
nemo_rl/utils/flops_formulas.py (4)
  • FLOPSConfig (21-59)
  • llama (71-86)
  • qwen2 (125-165)
  • qwen3 (168-208)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Coverage (doc-test)
  • GitHub Check: Coverage (e2e)
  • GitHub Check: Coverage (unit-test)
🔇 Additional comments (5)
nemo_rl/utils/flops_tracker.py (5)

26-26: Consolidated llama import: good cleanup.

Single entry-point reduces branching in formula selection.


69-78: Llama path: correct use of KV groups; nice unification.

Mapping query_groups=num_key_value_heads matches the formula’s query_groups/attention_heads ratio for GQA.


112-115: from_config decoupling from model_name for formula selection: LGTM.

Cleaner API; keeps model_name only for display/telemetry.


44-46: API change is local; no other call sites found


26-26: No remaining llama2/llama3 code references. Only string occurrences in test/docs and example configs remain, which is expected.

Comment thread on nemo_rl/utils/flops_tracker.py
@terrykong added this pull request to the merge queue on Sep 9, 2025
Merged via the queue into main with commit 62112f6 on Sep 9, 2025
26 checks passed
@terrykong deleted the ybgao/sep2-fix-flops-accounting branch on September 9, 2025 at 23:35
@samodi-nv (Contributor) commented:
@ybgao-nvidia After applying this MR I no longer see FLOP utilization for the Qwen3-8B family of models. I think it's because Qwen3Config was changed to Qwen3MoeConfig in flops_tracker.

guyueh1 pushed a commit to guyueh1/NeMo-RL that referenced this pull request Sep 15, 2025
@guyueh1 (Contributor) commented on Sep 15, 2025:

@ybgao-nvidia I wonder why we are putting Qwen3MoeConfig and Qwen3Config in the same if-branch. For MoE, shouldn't we use

            moe_ffn_hidden_size=config.moe_intermediate_size,
            moe_router_topk=config.num_experts_per_tok,

instead of

            moe_ffn_hidden_size=config.intermediate_size,
            moe_router_topk=1,

?
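
A sketch of the split these two comments point at, assuming the HF field names moe_intermediate_size and num_experts_per_tok exist on Qwen3MoeConfig (dense Qwen3Config has neither, which is also why the dense branch needs to stay separate, per the report above):

```python
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig

def qwen3_moe_fields(config):
    if isinstance(config, Qwen3MoeConfig):
        return dict(
            moe_ffn_hidden_size=config.moe_intermediate_size,
            moe_router_topk=config.num_experts_per_tok,
        )
    if isinstance(config, Qwen3Config):
        # Dense Qwen3 (e.g., Qwen3-8B): no MoE terms in the FLOPs formula.
        return dict(moe_ffn_hidden_size=None, moe_router_topk=None)
    raise TypeError(f"unexpected config type: {type(config)}")
```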

@coderabbitai (Bot) mentioned this pull request on Nov 19, 2025
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
@coderabbitai (Bot) mentioned this pull request on Dec 1, 2025

Labels

CI:L1 Run doctests, unit tests, and functional tests


Development

Successfully merging this pull request may close these issues.

  • Incorrect MFU computation during multi-GPU training with megatron
  • Incorrect MFU printed to console

6 participants