medusa/model/medusa_model.py (135 changes: 132 additions & 3 deletions)
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM, LlamaDecoderLayer, LlamaRMSNorm, _make_causal_mask, _expand_mask
from transformers import AutoTokenizer
from .utils import *
from .kv_cache import initialize_past_key_values
@@ -78,6 +78,7 @@ def __init__(
base_model,
medusa_num_heads=2,
medusa_num_layers=1,
medusa_num_decoder_layers=2,
base_model_name_or_path="lmsys/vicuna-7b-v1.3",
):
"""
@@ -93,8 +94,31 @@ def __init__(
self.vocab_size = base_model.lm_head.weight.shape[0]
self.medusa = medusa_num_heads
self.medusa_num_layers = medusa_num_layers
self.medusa_num_decoder_layers = medusa_num_decoder_layers
self.base_model_name_or_path = base_model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)

# ===
# [MEDUSA-COPY]
# Fork the last `medusa_num_decoder_layers` decoder layers and the final RMS norm for fine-tuning with the Medusa heads
self.medusa_decoder_layers = nn.ModuleList(
[LlamaDecoderLayer(base_model.config) for _ in range(medusa_num_decoder_layers)]
)
self.medusa_rms_norm = LlamaRMSNorm(self.hidden_size, eps=base_model.config.rms_norm_eps)

self.medusa_decoder_layers.to(self.base_model.dtype).to(self.base_model.device)
self.medusa_rms_norm.to(self.base_model.dtype).to(self.base_model.device)

# Initialize the forked decoder layers and RMS norm with parameters copied from the last layers of the base model (an equivalent load_state_dict sketch follows this block)
with torch.no_grad():
for i in range(medusa_num_decoder_layers):
for name, param in self.medusa_decoder_layers[-(i + 1)].named_parameters():
param.copy_(dict(base_model.model.layers[-(i + 1)].named_parameters())[name])

for name, param in self.medusa_rms_norm.named_parameters():
param.copy_(dict(base_model.model.norm.named_parameters())[name])
# ===
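
Aside, not part of the diff: the copy loop above is equivalent to loading each trailing base layer's state dict into its forked counterpart. A minimal sketch of that same initialization, reusing the names from the code above (`base_model`, `medusa_num_decoder_layers`, the `medusa_*` attributes):

for i in range(medusa_num_decoder_layers):
    # Seed the i-th forked layer (counting from the end) with the weights of the
    # matching layer at the end of the base model; the architectures are identical,
    # so the state-dict keys line up.
    self.medusa_decoder_layers[-(i + 1)].load_state_dict(
        base_model.model.layers[-(i + 1)].state_dict()
    )
# Seed the forked RMS norm from the base model's final norm.
self.medusa_rms_norm.load_state_dict(base_model.model.norm.state_dict())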

# Create a list of Medusa heads
self.medusa_head = nn.ModuleList(
[
@@ -160,6 +184,87 @@ def from_pretrained(

return model

# Copied from modeling_llama_kv.LlamaModel._prepare_decoder_attention_mask
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
# inputs_embeds.dtype,
torch.float32, # [MODIFIED] force the causal mask dtype to float32
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)

if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)

# [MODIFIED] add medusa mask
if hasattr(self, "medusa_mask") and self.medusa_mask is not None:
medusa_mask = self.medusa_mask
medusa_len = medusa_mask.size(-1)
combined_attention_mask[:, :, -medusa_len:, -medusa_len:][
medusa_mask == 0
] = combined_attention_mask.min()
if hasattr(self, "medusa_mode"):
# debug mode
if self.medusa_mode == "debug":
torch.save(combined_attention_mask, "medusa_mask.pt")

return combined_attention_mask
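
Aside, not part of the diff: the [MODIFIED] block above overlays a tree-attention ("medusa") mask on the bottom-right corner of the additive causal mask, so each candidate token only attends to its own ancestors in the speculation tree. A self-contained sketch of the same indexing, with a hypothetical three-token tree:

import torch

seq_len, medusa_len = 8, 3
neg = torch.finfo(torch.float32).min

# Additive causal mask: 0.0 where attention is allowed, `neg` where it is blocked.
causal = torch.triu(torch.full((seq_len, seq_len), neg), diagonal=1)[None, None]

# Hypothetical tree mask for three candidate tokens: 1 = may attend, 0 = blocked.
medusa_mask = torch.tensor(
    [[1, 0, 0],
     [1, 1, 0],
     [1, 0, 1]]
)[None, None]  # [1, 1, medusa_len, medusa_len]

# Same overlay as in _prepare_decoder_attention_mask: blocked positions inside the
# trailing medusa_len x medusa_len block are pushed down to the mask's minimum.
causal[:, :, -medusa_len:, -medusa_len:][medusa_mask == 0] = causal.min()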

# Copied from modeling_llama_kv.LlamaModel.forward
def _prepare_decoder_inputs(self, hidden_states, past_key_values, input_ids, position_ids, attention_mask):
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0

if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

if position_ids is None:
device = input_ids.device if input_ids is not None else hidden_states.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=hidden_states.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
# hidden_states stands in for inputs_embeds here; it is only
# used for its dtype and device
hidden_states,
past_key_values_length,
)

return attention_mask, position_ids
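
Aside, not part of the diff: for a cache holding P past positions and S new positions, the helper above should return position_ids [P, ..., P + S - 1] of shape [bsz, S] and a 4-D additive mask of shape [bsz, 1, S, P + S]. A hypothetical shape check, assuming `model` is an already-constructed MedusaModel (the helper only reads shapes, dtype and device from its inputs):

import torch

bsz, past_len, seq_len = 1, 10, 4
hidden_states = torch.zeros(bsz, seq_len, model.hidden_size)
# Dummy cache entry shaped [bsz, num_heads, past_len, head_dim]; only .shape[2]
# (the cached sequence length) is read by the helper.
dummy_past = [(torch.zeros(bsz, 1, past_len, 1),) * 2]

attn_mask, pos_ids = model._prepare_decoder_inputs(
    hidden_states, dummy_past, None, None, None
)
assert pos_ids.tolist() == [list(range(past_len, past_len + seq_len))]
assert attn_mask.shape == (bsz, 1, seq_len, past_len + seq_len)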

def forward(
self,
input_ids=None,
@@ -168,6 +273,8 @@ def forward(
past_key_values=None,
output_orig=False,
position_ids=None,
# [MEDUSA-COPY]
output_hidden_states=True,
):
"""Forward pass of the MedusaModel.

@@ -193,8 +300,30 @@ def forward(
)
if output_orig:
orig = self.base_model.lm_head(outputs[0])
# Clone the output hidden states
hidden_states = outputs[0].clone()

# ===
# [MEDUSA-COPY]
# Clone the hidden states at the fork point: the input to the last
# `medusa_num_decoder_layers` base-model layers (see the closing note below)
hidden_states = outputs.hidden_states[-(self.medusa_num_decoder_layers + 1)].clone()

attention_mask, position_ids = self._prepare_decoder_inputs(
hidden_states, past_key_values, input_ids, position_ids, attention_mask
)

# Pass hidden states through medusa decoder layers
for decoder_layer in self.medusa_decoder_layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
)
hidden_states = layer_outputs[0]
hidden_states = self.medusa_rms_norm(hidden_states)
# ===

medusa_logits = []
# TODO: Consider parallelizing this loop for efficiency?
for i in range(self.medusa):
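
A closing note on the fork point selected in forward() above: with output_hidden_states=True, the Hugging Face Llama model returns outputs.hidden_states with num_layers + 1 entries (the embedding output followed by each decoder layer's output), so index -(medusa_num_decoder_layers + 1) picks the hidden state that feeds the last medusa_num_decoder_layers base layers, which is exactly where the forked layers take over. A small illustrative check, assuming 32 layers as in Vicuna-7B and two forked layers:

num_layers, n_fork = 32, 2
num_entries = num_layers + 1  # embedding output + one entry per decoder layer
fork_index = -(n_fork + 1)    # -3
# The fork entry sits at tuple position num_layers - n_fork (here 30): the hidden
# state recorded just before the last n_fork base-model layers run.
assert num_entries + fork_index == num_layers - n_fork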