diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index e61189c537..4795560124 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -145,7 +145,7 @@ policy:
     moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
     moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
     moe_permute_fusion: false
-    #gives ~20% training perf speedup with sequence packing
+    # gives ~20% training perf speedup with sequence packing
     apply_rope_fusion: True
     # gives ~25% training perf speedup with sequence packing and apply_rope_fusion
     bias_activation_fusion: True
diff --git a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml
index 8034e78c54..61e4c5f7b2 100644
--- a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml
+++ b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml
@@ -53,6 +53,9 @@ policy:
     sequence_parallel: true
     moe_permute_fusion: true
     apply_rope_fusion: false
+    gradient_accumulation_fusion: false
+    # Disable Multi-Token Prediction (MTP)
+    mtp_num_layers: 0
     optimizer:
       lr: 5.0e-07
       min_lr: 5.0e-08
diff --git a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml
index 75457ab802..1bea98a83f 100644
--- a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml
+++ b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml
@@ -29,8 +29,11 @@ policy:
     num_layers_in_first_pipeline_stage: 3
     num_layers_in_last_pipeline_stage: 2
     apply_rope_fusion: false
+    gradient_accumulation_fusion: false
     moe_permute_fusion: true
     defer_fp32_logits: true
+    # Disable Multi-Token Prediction (MTP)
+    mtp_num_layers: 0
     optimizer:
       lr: 5.0e-07
       min_lr: 5.0e-08
diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index 351b8e2cb1..13199cef2f 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -340,6 +340,9 @@ def setup_model_config(
     # Apply MoE settings
     _apply_moe_config(model_cfg, config)
 
+    # Apply MTP settings
+    _apply_mtp_config(model_cfg, config)
+
     # Apply precision settings
     _apply_precision_config(model_cfg, config, dtype)
 
@@ -439,6 +442,11 @@ def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None:
         model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"]
 
 
+def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None:
+    if "mtp_num_layers" in config["megatron_cfg"]:
+        model_cfg.mtp_num_layers = config["megatron_cfg"]["mtp_num_layers"]
+
+
 def _apply_precision_config(
     model_cfg: Any, config: PolicyConfig, dtype: torch.dtype
 ) -> None:
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index 1dad9dcd41..8ccd61958c 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -249,6 +249,8 @@ class MegatronConfig(TypedDict):
     # Number of tokens per chunk when computing the fused linear CE loss.
     # Smaller values reduce peak memory further but may decrease throughput.
     linear_ce_fusion_chunk_size: NotRequired[int]
+    # Number of Multi-Token Prediction (MTP) layers; 0 disables MTP.
+    mtp_num_layers: NotRequired[int]
 
 
 class DraftConfigDisabled(TypedDict):
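
Since the key is optional and read from the policy's megatron_cfg, a recipe can also turn MTP back on instead of disabling it. A minimal sketch, assuming a checkpoint that ships MTP weights (the value 1 and the surrounding keys are illustrative, not taken from this change):

    policy:
      megatron_cfg:
        # Enable one Multi-Token Prediction layer; omit the key or set 0 to keep MTP disabled.
        mtp_num_layers: 1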