diff --git a/medusa/model/medusa_model.py b/medusa/model/medusa_model.py
index c4fd997..c164b76 100644
--- a/medusa/model/medusa_model.py
+++ b/medusa/model/medusa_model.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
-from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
+from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM, LlamaDecoderLayer, LlamaRMSNorm, _make_causal_mask, _expand_mask
 from transformers import AutoTokenizer
 from .utils import *
 from .kv_cache import initialize_past_key_values
@@ -78,6 +78,7 @@ def __init__(
         base_model,
         medusa_num_heads=2,
         medusa_num_layers=1,
+        medusa_num_decoder_layers=2,
         base_model_name_or_path="lmsys/vicuna-7b-v1.3",
     ):
         """
@@ -93,8 +94,31 @@ def __init__(
         self.vocab_size = base_model.lm_head.weight.shape[0]
         self.medusa = medusa_num_heads
         self.medusa_num_layers = medusa_num_layers
+        self.medusa_num_decoder_layers = medusa_num_decoder_layers
         self.base_model_name_or_path = base_model_name_or_path
         self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
+
+        # ===
+        # [MEDUSA-COPY]
+        # Fork two decoder layers and RMS norm for fine tuning with Medusa heads
+        self.medusa_decoder_layers = nn.ModuleList(
+            [LlamaDecoderLayer(base_model.config) for _ in range(medusa_num_decoder_layers)]
+        )
+        self.medusa_rms_norm = LlamaRMSNorm(self.hidden_size, eps=base_model.config.rms_norm_eps)
+
+        self.medusa_decoder_layers.to(self.base_model.dtype).to(self.base_model.device)
+        self.medusa_rms_norm.to(self.base_model.dtype).to(self.base_model.device)
+
+        # Initialize Medusa decoder layers and RMS norm layer with the parameters from the last layers of the base model
+        with torch.no_grad():
+            for i in range(medusa_num_decoder_layers):
+                for name, param in self.medusa_decoder_layers[-(i + 1)].named_parameters():
+                    param.copy_(dict(base_model.model.layers[-(i + 1)].named_parameters())[name])
+
+            for name, param in self.medusa_rms_norm.named_parameters():
+                param.copy_(dict(base_model.model.norm.named_parameters())[name])
+        # ===
+
         # Create a list of Medusa heads
         self.medusa_head = nn.ModuleList(
             [
@@ -160,6 +184,87 @@ def from_pretrained(
 
         return model
 
+    # Copied from modeling_llama_kv.LlamaModel._prepare_decoder_attention_mask
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(
+        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    ):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                # inputs_embeds.dtype,
+                torch.float32,  # [MODIFIED] force to cast to float32
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            ).to(inputs_embeds.device)
+            combined_attention_mask = (
+                expanded_attn_mask
+                if combined_attention_mask is None
+                else expanded_attn_mask + combined_attention_mask
+            )
+
+        # [MODIFIED] add medusa mask
+        if hasattr(self, "medusa_mask") and self.medusa_mask is not None:
+            medusa_mask = self.medusa_mask
+            medusa_len = medusa_mask.size(-1)
+            combined_attention_mask[:, :, -medusa_len:, -medusa_len:][
+                medusa_mask == 0
+            ] = combined_attention_mask.min()
+            if hasattr(self, "medusa_mode"):
+                # debug mode
+                if self.medusa_mode == "debug":
+                    torch.save(combined_attention_mask, "medusa_mask.pt")
+
+        return combined_attention_mask
+
+    # Copied from modeling_llama_kv.LlamaModel.forward
+    def _prepare_decoder_inputs(self, hidden_states, past_key_values, input_ids, position_ids, attention_mask):
+        batch_size, seq_length, _ = hidden_states.shape
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past),
+                dtype=torch.bool,
+                device=hidden_states.device,
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            # Passing hidden_states instead of input_embeds since only used
+            # for dtype and device
+            hidden_states,
+            past_key_values_length,
+        )
+
+        return attention_mask, position_ids
+
     def forward(
         self,
         input_ids=None,
@@ -168,6 +273,8 @@ def forward(
         past_key_values=None,
         output_orig=False,
         position_ids=None,
+        # [MEDUSA-COPY]
+        output_hidden_states=True,
     ):
         """Forward pass of the MedusaModel.
 
@@ -193,8 +300,30 @@ def forward(
             )
             if output_orig:
                 orig = self.base_model.lm_head(outputs[0])
-        # Clone the output hidden states
-        hidden_states = outputs[0].clone()
+
+        # ===
+        # [MEDUSA-COPY]
+        # Clone the output hidden states before Medusa decoder fork
+        hidden_states = (outputs.hidden_states)[-1 * (self.medusa_num_decoder_layers + 1)].clone()
+
+        attention_mask, position_ids = self._prepare_decoder_inputs(
+            hidden_states, past_key_values, input_ids, position_ids, attention_mask
+        )
+
+        # Pass hidden states through medusa decoder layers
+        for decoder_layer in self.medusa_decoder_layers:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=None,
+                output_attentions=False,
+                use_cache=False,
+            )
+            hidden_states = layer_outputs[0]
+        hidden_states = self.medusa_rms_norm(hidden_states)
+        # ===
+
         medusa_logits = []
        # TODO: Consider parallelizing this loop for efficiency?
         for i in range(self.medusa):