2 changes: 1 addition & 1 deletion examples/configs/grpo_math_1B.yaml
@@ -145,7 +145,7 @@ policy:
moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
moe_permute_fusion: false
#gives ~20% training perf speedup with sequence packing
# gives ~20% training perf speedup with sequence packing
apply_rope_fusion: True
# gives ~25% training perf speedup with sequence packing and apply_rope_fusion
bias_activation_fusion: True
@@ -53,6 +53,9 @@ policy:
sequence_parallel: true
moe_permute_fusion: true
apply_rope_fusion: false
gradient_accumulation_fusion: false
# Multi-Token Prediction (MTP) disabled
mtp_num_layers: 0
optimizer:
lr: 5.0e-07
min_lr: 5.0e-08
@@ -29,8 +29,11 @@ policy:
num_layers_in_first_pipeline_stage: 3
num_layers_in_last_pipeline_stage: 2
apply_rope_fusion: false
gradient_accumulation_fusion: false
moe_permute_fusion: true
defer_fp32_logits: true
# Multi-Token Prediction (MTP) disabled
mtp_num_layers: 0
optimizer:
lr: 5.0e-07
min_lr: 5.0e-08
8 changes: 8 additions & 0 deletions nemo_rl/models/megatron/setup.py
@@ -340,6 +340,9 @@ def setup_model_config(
# Apply MoE settings
_apply_moe_config(model_cfg, config)

# Apply MTP settings
_apply_mtp_config(model_cfg, config)

# Apply precision settings
_apply_precision_config(model_cfg, config, dtype)

@@ -439,6 +442,11 @@ def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None:
model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"]


def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None:
if "mtp_num_layers" in config["megatron_cfg"]:
model_cfg.mtp_num_layers = config["megatron_cfg"]["mtp_num_layers"]


def _apply_precision_config(
model_cfg: Any, config: PolicyConfig, dtype: torch.dtype
) -> None:
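A minimal, self-contained sketch of how the new _apply_mtp_config helper behaves, assuming a plain dict in place of PolicyConfig and types.SimpleNamespace in place of the Megatron model config (both are stand-ins for illustration, not the real types used in setup.py):

from types import SimpleNamespace
from typing import Any


def _apply_mtp_config(model_cfg: Any, config: dict) -> None:
    # Guard on the optional key so configs without it keep working unchanged.
    if "mtp_num_layers" in config["megatron_cfg"]:
        model_cfg.mtp_num_layers = config["megatron_cfg"]["mtp_num_layers"]


# Stand-in for the Megatron model config object built in setup_model_config().
model_cfg = SimpleNamespace(mtp_num_layers=None)

# Mirrors `mtp_num_layers: 0` under `megatron_cfg` in the YAML configs above.
_apply_mtp_config(model_cfg, {"megatron_cfg": {"mtp_num_layers": 0}})
assert model_cfg.mtp_num_layers == 0

# A config written before this field existed leaves the model config untouched.
legacy_cfg = SimpleNamespace(mtp_num_layers=None)
_apply_mtp_config(legacy_cfg, {"megatron_cfg": {}})
assert legacy_cfg.mtp_num_layers is None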
2 changes: 2 additions & 0 deletions nemo_rl/models/policy/__init__.py
@@ -249,6 +249,8 @@ class MegatronConfig(TypedDict):
# Number of tokens per chunk when computing the fused linear CE loss.
# Smaller values reduce peak memory further but may decrease throughput.
linear_ce_fusion_chunk_size: NotRequired[int]
# Number of Multi-Token Prediction (MTP) layers; mtp_num_layers=0 disables MTP.
mtp_num_layers: NotRequired[int]


class DraftConfigDisabled(TypedDict):
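A minimal sketch of the typing contract the NotRequired annotation gives this field: a trimmed-down TypedDict (only two of MegatronConfig's fields are reproduced here, under a hypothetical name) showing that configs omitting mtp_num_layers still type-check, which is why setup.py checks for the key before using it. Assumes Python 3.11+ for typing.NotRequired; older versions would import it from typing_extensions.

from typing import NotRequired, TypedDict


class MegatronConfigSketch(TypedDict):
    # Trimmed stand-in for nemo_rl.models.policy.MegatronConfig.
    moe_permute_fusion: bool
    # Number of Multi-Token Prediction (MTP) layers; 0 disables MTP.
    mtp_num_layers: NotRequired[int]


# Both literals satisfy the type: the key may be present...
mtp_disabled: MegatronConfigSketch = {"moe_permute_fusion": True, "mtp_num_layers": 0}
# ...or omitted entirely, as in configs written before this change.
legacy: MegatronConfigSketch = {"moe_permute_fusion": True}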