79 changes: 20 additions & 59 deletions colossalai/inference/modeling/models/glide_llama.py
@@ -6,19 +6,16 @@

import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaForCausalLM,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)

from colossalai.inference.spec import GlideInput
@@ -156,31 +153,29 @@ def glide_llama_model_forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
if hasattr(glide_input, "n_spec_tokens"):
position_ids = position_ids + glide_input.n_spec_tokens

# embed positions
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
if output_hidden_states:
@@ -189,9 +184,9 @@ def glide_llama_model_forward(
# GlideLlamaDecoderLayer
layer_outputs = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -200,9 +195,6 @@ def glide_llama_model_forward(

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

@@ -212,16 +204,11 @@ def glide_llama_model_forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@@ -267,41 +254,17 @@ def __init__(self, config: GlideLlamaConfig):

self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
self._init_rope()

def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.large_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
glide_input: GlideInput = None, # Used for glimpsing main model's KV caches
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Optional[torch.Tensor]:
@@ -319,8 +282,7 @@ def forward(
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

# for RoPE
position_ids = position_ids + glide_input.n_spec_tokens
cos, sin = self.rotary_emb(query_states, position_ids)
cos, sin = position_embeddings
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
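Note on the RoPE handling above: in recent transformers releases the rotary embedding is computed once at the model level and passed to every layer as a `(cos, sin)` pair, which is why the per-attention `_init_rope` and its scaling-type branches are dropped and the cross-attention now consumes `position_embeddings` directly. A minimal sketch of the assumed pattern (transformers versions where `LlamaRotaryEmbedding` accepts a `config`):

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Rotary embeddings are built once from the model config (assumed newer API)
# and handed to each decoder layer as position_embeddings = (cos, sin).
config = LlamaConfig(hidden_size=4096, num_attention_heads=32)
rotary_emb = LlamaRotaryEmbedding(config=config)

hidden_states = torch.randn(1, 8, config.hidden_size)  # (batch, seq_len, hidden)
position_ids = torch.arange(8).unsqueeze(0)             # (batch, seq_len)
cos, sin = rotary_emb(hidden_states, position_ids)      # what the layers receive
```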
@@ -367,9 +329,10 @@ def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlam
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor = None,
position_ids: Optional[torch.LongTensor] = None,
glide_input: GlideInput = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -399,10 +362,10 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -425,9 +388,10 @@ def forward(

hidden_states = self.cross_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
position_ids=position_ids,
glide_input=glide_input,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=True,
)
@@ -441,9 +405,6 @@ def forward(

outputs = (hidden_states,)

if use_cache:
outputs += (present_key_value,)

return outputs


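Note on the cache handling in `glide_llama_model_forward` above: newer transformers versions expect a `Cache` object such as `DynamicCache` rather than the legacy tuple-of-tuples, and the cache is updated in place by the layers, so the forward no longer collects a `next_decoder_cache` and simply returns `past_key_values`. A minimal sketch of the assumed pattern:

```python
import torch
from transformers.cache_utils import DynamicCache

# Fresh cache for the first forward pass (assumed recent transformers API).
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()  # 0 for an empty cache

seq_len = 4
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
position_ids = cache_position.unsqueeze(0)
# Each decoder layer updates past_key_values in place, so the same object
# can be returned in BaseModelOutputWithPast.past_key_values.
```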
6 changes: 3 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -478,9 +478,9 @@ def from_native_module(
attn_oproj=attn_oproj,
process_group=process_group,
model_shard_infer_config=model_shard_infer_config,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
num_heads=module.config.num_attention_heads,
hidden_size=module.config.hidden_size,
num_key_value_heads=module.config.num_key_value_heads,
)

return attn_layer
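The change above reads head counts and hidden size from `module.config` instead of attributes on `LlamaAttention`, which newer transformers versions no longer expose directly. A minimal sketch of the assumed access pattern (hypothetical sizes):

```python
from transformers import LlamaConfig

# Sizes now come from the attention module's config object, e.g.
# module.config.num_attention_heads instead of module.num_heads.
config = LlamaConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
num_key_value_heads = config.num_key_value_heads
```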
6 changes: 4 additions & 2 deletions colossalai/inference/spec/drafter.py
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from transformers.cache_utils import DynamicCache

from colossalai.utils import get_current_device

@@ -93,9 +94,8 @@ def speculate(

for _ in range(n_spec_tokens):
# update past key values
kwargs["past_key_values"] = past_key_values

outputs = self._drafter_model(input_ids, **kwargs)
outputs = self._drafter_model(input_ids, past_key_values=past_key_values, **kwargs)
next_token_logits = outputs.logits[:, -1, :]

# NOTE Only use greedy search for speculating.
@@ -114,6 +114,8 @@ def speculate(
speculated_length = len(token_ids) # For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
if isinstance(past_key_values, DynamicCache):
past_key_values = past_key_values.to_legacy_cache()

out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
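The drafter above now feeds a `DynamicCache` into the model and converts it back to the legacy tuple layout before building `DrafterOutput`, so downstream consumers keep seeing the old format. A minimal sketch of the assumed round-trip:

```python
from transformers.cache_utils import DynamicCache

# Wrap a legacy cache (or start empty), let the model update it in place,
# then convert back to the legacy tuple-of-tuples layout for callers.
cache = DynamicCache.from_legacy_cache(None)  # empty cache on the first call
# ... outputs = model(input_ids, past_key_values=cache, ...) ...
legacy_cache = cache.to_legacy_cache()
```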
5 changes: 3 additions & 2 deletions colossalai/lazy/pretrained.py
@@ -69,7 +69,7 @@ def new_from_pretrained(
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@@ -286,7 +286,8 @@ def new_from_pretrained(
config.name_or_path = pretrained_model_name_or_path

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
# init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts = [no_init_weights()]

with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
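The lazy-loading path above drops the `_fast_init` flag because recent transformers releases removed the `_enable` argument from `no_init_weights`; the context manager is now always used. A minimal sketch, assuming that newer signature:

```python
from transformers.modeling_utils import no_init_weights

# no_init_weights() no longer takes an _enable flag (assumed newer API);
# weight initialization is simply skipped inside the context.
with no_init_weights():
    pass  # e.g. model = cls(config, *model_args, **model_kwargs)
```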