Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def train():
"eagle_decoder_type": eagle_args.eagle_decoder_type,
"eagle_offline": use_offline_training,
"eagle_architecture_config": custom_config,
"eagle_train_length": training_args.training_seq_len,
}

mtsp.convert(model, [("eagle", config)])
Expand Down
9 changes: 9 additions & 0 deletions modelopt/torch/export/plugins/hf_spec_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module):
if self.hf_quant_config is not None:
template_config["quantization_config"] = self.hf_quant_config

# For long context quality, we disable rope scaling for training
# and set yarn during export for inference.
template_config["rope_scaling"] = {
"rope_type": "yarn",
"rope_theta": 10000,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure rope theta goes on the main config and not the rope scaling, and should be set the same for training/inference. Where did this template come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"factor": 32.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure these are the best choices for rope theta and factor. I think these might depend on how long max_position_embeddings actually is.

Some testing may be required. Gpt Oss uses rope theta 150k, for example. This may be some tradeoff between short-context and long-context accuracy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

theta=10k is the default from HF: ref

Actually my guess is that it should match the theta used in training.

factor should be a tradeoff I think.

"original_max_position_embeddings": model.eagle_train_length,
}

return template_config

def export_quant_config(self):
Expand Down
7 changes: 7 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,10 @@ class EagleConfig(ModeloptBaseConfig):
default="llama",
description=("The class of eagle decoder to use. Available options: llama, kimik2"),
)

eagle_train_length: int = ModeloptField(
default=2048,
description=(
"The length of the training data. Used to set original_max_position_embeddings in rope_scaling."
),
)
1 change: 1 addition & 0 deletions modelopt/torch/speculative/eagle/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
eagle_loss_decay_factor=config.eagle_loss_decay_factor,
eagle_architecture_config=config.eagle_architecture_config,
eagle_decoder_type=config.eagle_decoder_type,
eagle_train_length=config.eagle_train_length,
)

# no metadata, all specified via config.
Expand Down
19 changes: 3 additions & 16 deletions modelopt/torch/speculative/eagle/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,7 @@
"hidden_act": "silu",
"torch_dtype": "bfloat16",
"position_embedding_type": "rope",
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
"rope_theta": 500000.0,
"rope_scaling": {"rope_type": "default", "rope_theta": 10000},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can go further and actually just set rope scaling to null. Not sure if there's a difference in HF

Copy link
Contributor Author

@h-guo18 h-guo18 Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting it to None triggers an error. We are using the Llama definition from transformers 5.0, which requires a rope type. "rope_type": "default" here will use the traditional RoPE without scaling.

ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L85

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a transformers version difference. In transformers 4.x it's ok to leave it None. We are using transformers 5 here.

"num_hidden_layers": 1,
"intermediate_size": 14336,
"num_attention_heads": 32,
Expand Down Expand Up @@ -83,15 +76,9 @@
"qk_rope_head_dim": 64,
"rms_norm_eps": 0.00001,
"rope_scaling": {
"beta_fast": 1.0,
"beta_slow": 1.0,
"factor": 64.0,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
"rope_type": "default",
"rope_theta": 10000,
},
"rope_theta": 50000.0,
"routed_scaling_factor": 2.827,
"scoring_func": "sigmoid",
"seq_aux": True,
Expand Down
2 changes: 2 additions & 0 deletions modelopt/torch/speculative/eagle/eagle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def modify(
eagle_loss_decay_factor,
eagle_architecture_config,
eagle_decoder_type,
eagle_train_length,
):
"""Base Eagle Model modify function. Child class should implement the details."""
self.eagle_offline = eagle_offline
Expand All @@ -45,3 +46,4 @@ def modify(
self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
self.eagle_loss_decay_factor = eagle_loss_decay_factor
self.eagle_decoder_type = eagle_decoder_type
self.eagle_train_length = eagle_train_length
2 changes: 2 additions & 0 deletions modelopt/torch/speculative/plugins/megatron_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def modify(
eagle_loss_decay_factor,
eagle_architecture_config,
eagle_decoder_type,
eagle_train_length,
):
if self.config.pipeline_model_parallel_size > 1:
warnings.warn(
Expand All @@ -715,6 +716,7 @@ def modify(
eagle_loss_decay_factor=eagle_loss_decay_factor,
eagle_architecture_config=eagle_architecture_config,
eagle_decoder_type=eagle_decoder_type,
eagle_train_length=eagle_train_length,
)

# sequence_parallel is not used in offline eagle
Expand Down
2 changes: 2 additions & 0 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def modify(
eagle_loss_decay_factor,
eagle_architecture_config,
eagle_decoder_type,
eagle_train_length,
):
"""Constructor.

Expand All @@ -576,6 +577,7 @@ def modify(
eagle_loss_decay_factor=eagle_loss_decay_factor,
eagle_architecture_config=eagle_architecture_config,
eagle_decoder_type=eagle_decoder_type,
eagle_train_length=eagle_train_length,
)

if eagle_decoder_type == "llama":
Expand Down