diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 682111184..82a3a117f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -205,6 +205,7 @@ def train(): "eagle_decoder_type": eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, "eagle_architecture_config": custom_config, + "eagle_train_length": training_args.training_seq_len, } mtsp.convert(model, [("eagle", config)]) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index d287b7474..1c72ffbc3 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -175,6 +175,15 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module): if self.hf_quant_config is not None: template_config["quantization_config"] = self.hf_quant_config + # For long context quality, we disable rope scaling for training + # and set yarn during export for inference. + template_config["rope_scaling"] = { + "rope_type": "yarn", + "rope_theta": 10000, + "factor": 32.0, + "original_max_position_embeddings": model.eagle_train_length, + } + return template_config def export_quant_config(self): diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 41987d4e4..754108a2b 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -105,3 +105,10 @@ class EagleConfig(ModeloptBaseConfig): default="llama", description=("The class of eagle decoder to use. Available options: llama, kimik2"), ) + + eagle_train_length: int = ModeloptField( + default=2048, + description=( + "The length of the training data. Used to set original_max_position_embeddings in rope_scaling." + ), + ) diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index 2b085d5e3..646897c30 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -58,6 +58,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu eagle_loss_decay_factor=config.eagle_loss_decay_factor, eagle_architecture_config=config.eagle_architecture_config, eagle_decoder_type=config.eagle_decoder_type, + eagle_train_length=config.eagle_train_length, ) # no metadata, all specified via config. diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py index f8c4924c1..9102e69da 100644 --- a/modelopt/torch/speculative/eagle/default_config.py +++ b/modelopt/torch/speculative/eagle/default_config.py @@ -19,14 +19,7 @@ "hidden_act": "silu", "torch_dtype": "bfloat16", "position_embedding_type": "rope", - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3", - }, - "rope_theta": 500000.0, + "rope_scaling": {"rope_type": "default", "rope_theta": 10000}, "num_hidden_layers": 1, "intermediate_size": 14336, "num_attention_heads": 32, @@ -83,15 +76,9 @@ "qk_rope_head_dim": 64, "rms_norm_eps": 0.00001, "rope_scaling": { - "beta_fast": 1.0, - "beta_slow": 1.0, - "factor": 64.0, - "mscale": 1.0, - "mscale_all_dim": 1.0, - "original_max_position_embeddings": 4096, - "type": "yarn", + "rope_type": "default", + "rope_theta": 10000, }, - "rope_theta": 50000.0, "routed_scaling_factor": 2.827, "scoring_func": "sigmoid", "seq_aux": True, diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index d54fdc843..591f9f791 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -35,6 +35,7 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + eagle_train_length, ): """Base Eagle Model modify function. Child class should implement the details.""" self.eagle_offline = eagle_offline @@ -45,3 +46,4 @@ def modify( self.eagle_reuse_base_decoder = eagle_reuse_base_decoder self.eagle_loss_decay_factor = eagle_loss_decay_factor self.eagle_decoder_type = eagle_decoder_type + self.eagle_train_length = eagle_train_length diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index e37e8f931..09e745c55 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -693,6 +693,7 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + eagle_train_length, ): if self.config.pipeline_model_parallel_size > 1: warnings.warn( @@ -715,6 +716,7 @@ def modify( eagle_loss_decay_factor=eagle_loss_decay_factor, eagle_architecture_config=eagle_architecture_config, eagle_decoder_type=eagle_decoder_type, + eagle_train_length=eagle_train_length, ) # sequence_parallel is not used in offline eagle diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 23d0254e8..d0499c599 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -560,6 +560,7 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + eagle_train_length, ): """Constructor. @@ -576,6 +577,7 @@ def modify( eagle_loss_decay_factor=eagle_loss_decay_factor, eagle_architecture_config=eagle_architecture_config, eagle_decoder_type=eagle_decoder_type, + eagle_train_length=eagle_train_length, ) if eagle_decoder_type == "llama":