From ad5f976c16dec6fa6f1589c351798ef551690d1a Mon Sep 17 00:00:00 2001 From: asmigosw Date: Mon, 3 Feb 2025 06:16:53 +0000 Subject: [PATCH 1/8] Single qpc support till export Signed-off-by: Amit Raj --- QEfficient/base/modeling_qeff.py | 1 + QEfficient/base/onnx_transforms.py | 46 ++ .../models/mllama/modeling_mllama.py | 427 ++++++++++++++---- .../transformers/models/modeling_auto.py | 64 ++- .../transformers/models/pytorch_transforms.py | 14 +- 5 files changed, 443 insertions(+), 109 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2760cf52f..e48475b78 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -175,6 +175,7 @@ def _export( } if onnx_transform_kwargs is not None: transform_kwargs.update(onnx_transform_kwargs) + for transform in self._onnx_transforms: model, transformed = transform.apply(model, **transform_kwargs) model.metadata_props.append( diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 543ec4e2d..802252ce2 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -91,3 +91,49 @@ def apply( current_file_size = tsize external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") return model, transformed + + +class RemoveCrossAttentionIOTransform(OnnxTransform): + """ + Removes the input and output names of cross-attention layers. + """ + + @classmethod + def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]: + """ + :param onnx_base_dir: Base directory to load tensors (if not already loaded). + """ + layers_to_remove = [3, 8, 13, 18, 23, 28, 33, 38] + names_to_remove = [] + for layer in layers_to_remove: + names_to_remove.append(f'past_key.{layer}_RetainedState') + names_to_remove.append(f'past_value.{layer}_RetainedState') + names_to_remove.append(f'past_key.{layer}') + names_to_remove.append(f'past_value.{layer}') + + graph = model.graph + transformed = False + + # Remove outputs + for name in names_to_remove: + output_to_remove = None + for output in graph.output: + if output.name == name: + output_to_remove = output + break + if output_to_remove: + graph.output.remove(output_to_remove) + transformed = True + + # # Remove inputs + # for name in names_to_remove: + # input_to_remove = None + # for input in graph.input: + # if input.name == name: + # input_to_remove = input + # break + # if input_to_remove: + # graph.input.remove(input_to_remove) + # transformed = True + + return model, transformed \ No newline at end of file diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 90be64096..b4bdd4339 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -10,11 +10,14 @@ import math from typing import List, Optional, Tuple, Union +import requests import torch import torch.nn.functional as F import torch.utils.checkpoint +from PIL import Image from torch import nn from torch.nn import CrossEntropyLoss +from transformers import AutoProcessor from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( @@ -71,6 +74,34 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k_embed = (k * cos) + (rotate_half(k) * sin) # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32) + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32) + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value) + .any(dim=-1) + .type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask def _prepare_aspect_ratio_attention_mask( aspect_ratio_mask: torch.Tensor, @@ -141,80 +172,94 @@ def _create_causal_mask( return attention_mask -class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): +class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py The only differences are: - add new args cache idx for the kv retention """ - def __init__(self, config: MllamaConfig, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - # Define the general __qeff_init__() for any changes in the init calls - # Set the init in the module mapping pytorch transforms - self.config = config - self.__qeff_init__() - - def __qeff_init__(self): - self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - position_embeddings: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - use_cache: bool = False, - cache_position=None, - **kwargs, - ): + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states = self.q_norm(query_states) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # elif past_key_value is not None: + # Fetch old cache + key_states_old = past_key_value.key_cache[self.layer_idx] + value_states_old = past_key_value.value_cache[self.layer_idx] + + # if cross_attention_states is not None: + # Compute new KV states + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + # if past_key_value is not None: + # # if we have a new image + new tokens, we only computed key_states on that new image + # # we still update the cross key states, past_image, new_image. And use it! + # key_states, value_states = past_key_value.update( + # key_states, + # value_states, + # self.layer_idx, + # {"batch_index": batch_index, "position_ids": position_ids}, + # ) + + # Out-of-place Scatter new into old + # out-of-place is important so the original tensor is not affected, + # otherwise leads to same operations in both graphs + indices = (torch.arange(bsz),) + key_states_new = torch.index_put(key_states_old, indices, key_states) + value_states_new = torch.index_put(value_states_old, indices, value_states) + + # Select old or new image KV states based on q_len + key_states = torch.where(q_len == 1, key_states_old, key_states_new) + value_states = torch.where(q_len == 1, value_states_old, value_states_new) + + # Update the image cache + past_key_value.key_cache[self.layer_idx] = key_states + past_key_value.value_cache[self.layer_idx] = value_states key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + key_states = self.k_norm(key_states) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + # attn_weights = torch.where( + # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights + # ) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -222,82 +267,99 @@ def forward( return attn_output, attn_weights, past_key_value -class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): +class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py The only differences are: - add new args cache idx for the kv retention """ + def __init__(self, config: MllamaConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.config = config + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) + def forward( self, hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, batch_index: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, + position_embeddings: torch.Tensor = None, output_attentions: bool = False, - use_cache: bool = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" + use_cache: bool = False, + cache_position=None, + **kwargs, + ): bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.q_norm(query_states) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( + 1, 2 + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, - value_states, - self.layer_idx, - {"batch_index": batch_index, "position_ids": position_ids}, + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." ) - elif past_key_value is not None: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - key_states = self.k_norm(key_states) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( self.head_dim ) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - # attn_weights = torch.where( - # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights - # ) + attn_weights = torch.where( + attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights + ) + # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) if not output_attentions: @@ -1015,7 +1077,108 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - + + +class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError( + "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" + ) + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError( + "`aspect_ratio_ids` must be provided if `pixel_values` is provided" + ) + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + class VisionEncoder(nn.Module): def __init__(self, mllama: MllamaForConditionalGeneration): super().__init__() @@ -1107,4 +1270,88 @@ def forward( ) if "past_key_values" in outputs: outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() - return outputs \ No newline at end of file + return outputs + def generate_mllama_single(self, processor): + ctx_len = 1024 + txt_cfg = self.mllama.config.get_text_config() + num_hidden_layers = txt_cfg.num_hidden_layers + cross_attention_layers = txt_cfg.cross_attention_layers + num_key_value_heads = txt_cfg.num_key_value_heads + head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads + + vis_cfg = self.mllama.config.vision_config + num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 + image_tokens_len = vis_cfg.max_num_tiles * num_patches + + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + image = Image.open(requests.get(url, stream=True).raw) + conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + { + "type": "text", + "text": "How long does it take from invoice date to due date? Be short and concise.", + }, + ], + } + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = processor(text=prompt, images=image, return_tensors="pt", add_special_tokens=False) + inputs["position_ids"] = torch.where( + inputs.pop("attention_mask") == 1, + torch.arange(inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + inputs = dict(inputs) + inputs["past_key_values"] = DynamicCache(num_hidden_layers) + inputs["past_key_values"].key_cache = [0] * num_hidden_layers + inputs["past_key_values"].value_cache = [0] * num_hidden_layers + for i in range(num_hidden_layers): + if i in cross_attention_layers: + idx = cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + inputs["past_key_values"].key_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) + inputs["past_key_values"].value_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) + else: + inputs["past_key_values"].key_cache[i] = torch.zeros( + 1, num_key_value_heads, ctx_len, head_dim + ) + inputs["past_key_values"].value_cache[i] = torch.zeros( + 1, num_key_value_heads, ctx_len, head_dim + ) + + output_names = [ + "logits", + # "pixel_values_RetainedState", + *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]], + ] + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "cross_attention_mask": { + 0: "batch_size", + 1: "seq_len", + 2: "max_num_images", + 3: "max_image_tiles", + }, + } + for i in range(num_hidden_layers): + if i in cross_attention_layers: + dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + else: + dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + inputs["past_key_values"] = inputs["past_key_values"].to_legacy_cache() + inputs["position_ids"] = torch.full(inputs["position_ids"].shape, ctx_len - 1) + return inputs, output_names, dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 7fd8ef94f..fde29942e 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -30,7 +30,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import FP16ClipTransform, RemoveCrossAttentionIOTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import get_compilation_dims from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -724,6 +724,7 @@ def from_pretrained( self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) self.continuous_batching = continuous_batching self.kv_offload = kv_offload + # self.model_name=pretrained_model_name_or_path self.is_tlm = is_tlm return self @@ -832,21 +833,33 @@ def _generate_inputs_mllama( vision_input[k] = v return lang_inputs, vision_input + def export( self, export_dir: Optional[str] = None, **kwargs, ) -> str: - self.kv_offload = True + if self.kv_offload: print("generating input") lang_inputs, vision_input = self._generate_inputs_mllama() print("generating vision model") + self.vision_export_path = self.export_vision(vision_input, export_dir) print("generating lang model") self.lang_export_path = self.export_lang(lang_inputs, export_dir) - + else: + self.model=ModelWrapper(self.model) + inputs,output_names, dynamic_axes=self.model.generate_mllama_single(self.processor) + print("Generating single qpc onnx") + self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir + ) + def export_vision(self, vision_input, export_dir): model = self.model self.vision_encoder = self.model = VisionEncoder(self.model) @@ -932,7 +945,7 @@ def export_lang(self, lang_inputs, export_dir): self.lang_output_names = lang_output_names model = self.model self.model = ModelWrapper(model) - + # self._onnx_transforms.append(RemoveCrossAttentionIOTransform) self.lang_onnx_path = self._export(lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir) self.model = model return self.lang_onnx_path @@ -995,17 +1008,34 @@ def compile( "max_image_tiles": "4", }, ] - + # num_devices=4 custom_io_lang = {} # Inputs for output_name in self.lang_output_names: if output_name.startswith("past_"): custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype + + # key_to_remove=[] + # for names in self.vision_encoder.cross_attention_layers: + # key_to_remove.append(f"past_key.{names}") + # key_to_remove.append(f"past_value.{names}") + + # for key in key_to_remove: + # del custom_io_lang[key] + # outputs for output_name in self.lang_output_names: if output_name.startswith("past_"): custom_io_lang[output_name] = kv_cache_dtype + # key_to_remove=[] + # for names in self.vision_encoder.cross_attention_layers: + # key_to_remove.append(f"past_key.{names}_RetainedState") + # key_to_remove.append(f"past_value.{names}_RetainedState") + + # for key in key_to_remove: + # del custom_io_lang[key] + print("generating lang model") compiler_options.update({"retained-state": True}) self.lang_qpc_path = self._compile( @@ -1168,7 +1198,13 @@ def kv_offload_generate( stream: bool = True, **kwargs, ): + + # self.lang_qpc_path="/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" + self.lang_qpc_path="/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" + self.vision_qpc_path="/home/ubuntu/.cache/qeff_models/VisionEncoder-31e62a3c446b6bb9/qpc-7412e902c95a92c9/qpc" + lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) + vision_session = QAICInferenceSession(self.vision_qpc_path, device_id) batch_size, ctx_len, fbs = get_compilation_dims(self.lang_qpc_path) @@ -1276,12 +1312,12 @@ def kv_offload_generate( decode_perf = (num_token - 1) / (end - loop_start) total_perf = num_token / (end - start) - print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) - print("E2ET:", round(end - start, 2), "s", file=sys.stderr) - print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) - print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) - print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) - if batch_size > 1: - print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) - print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) - print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) + # print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) + # print("E2ET:", round(end - start, 2), "s", file=sys.stderr) + # print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) + # print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) + # print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) + # if batch_size > 1: + # print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) + # print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) + # print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index c3ad99f85..27257fa55 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -69,6 +69,7 @@ from transformers.models.mllama.modeling_mllama import ( MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaForConditionalGeneration, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, @@ -167,6 +168,7 @@ from QEfficient.transformers.models.mllama.modeling_mllama import ( QEffMllamaCrossAttentionDecoderLayer, QEffMllamaForCausalLM, + QEffMllamaForConditionalGeneration, QEffMllamaRotaryEmbedding, QEffMllamaSelfAttentionDecoderLayer, QEffMllamaTextCrossAttention, @@ -258,14 +260,16 @@ class KVCacheTransform(ModuleMappingTransform): Gemma2Model: QEffGemma2Model, Gemma2ForCausalLM: QEffGemma2ForCausalLM, # mllama - MllamaForCausalLM: QEffMllamaForCausalLM, - MllamaTextModel: QEffMllamaTextModel, - MllamaVisionModel: QEffMllamaVisionModel, - MllamaTextSelfAttention: QEffMllamaTextSelfAttention, + MllamaTextRMSNorm: CustomRMSNormAIC, MllamaTextCrossAttention: QEffMllamaTextCrossAttention, - MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, + MllamaTextSelfAttention: QEffMllamaTextSelfAttention, MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer, + MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding, + MllamaVisionModel: QEffMllamaVisionModel, + MllamaTextModel: QEffMllamaTextModel, + MllamaForCausalLM: QEffMllamaForCausalLM, + MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration, # Mistral MistralAttention: QEffMistralAttention, MistralDecoderLayer: QEffMistralDecoderLayer, From fd573d8eebd30927aceb540726573f564a5bbdd3 Mon Sep 17 00:00:00 2001 From: asmigosw Date: Mon, 3 Feb 2025 08:16:50 +0000 Subject: [PATCH 2/8] Minor changes Signed-off-by: Amit Raj --- .../transformers/models/modeling_auto.py | 70 +++++++++++++------ 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index fde29942e..d38a32dc4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -851,11 +851,10 @@ def export( self.lang_export_path = self.export_lang(lang_inputs, export_dir) else: self.model=ModelWrapper(self.model) - inputs,output_names, dynamic_axes=self.model.generate_mllama_single(self.processor) - print("Generating single qpc onnx") + self.inputs,self.output_names, dynamic_axes=self.model.generate_mllama_single(self.processor) self._export( - inputs, - output_names, + self.inputs, + self.output_names, dynamic_axes, export_dir=export_dir ) @@ -963,7 +962,6 @@ def compile( mxfp6_matmul: bool = False, **compiler_options, ) -> str: - self.kv_offload = True if self.kv_offload: model = self.model self.model = VisionEncoder(model) @@ -1015,27 +1013,11 @@ def compile( if output_name.startswith("past_"): custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype - # key_to_remove=[] - # for names in self.vision_encoder.cross_attention_layers: - # key_to_remove.append(f"past_key.{names}") - # key_to_remove.append(f"past_value.{names}") - - # for key in key_to_remove: - # del custom_io_lang[key] - # outputs for output_name in self.lang_output_names: if output_name.startswith("past_"): custom_io_lang[output_name] = kv_cache_dtype - # key_to_remove=[] - # for names in self.vision_encoder.cross_attention_layers: - # key_to_remove.append(f"past_key.{names}_RetainedState") - # key_to_remove.append(f"past_value.{names}_RetainedState") - - # for key in key_to_remove: - # del custom_io_lang[key] - print("generating lang model") compiler_options.update({"retained-state": True}) self.lang_qpc_path = self._compile( @@ -1052,6 +1034,52 @@ def compile( ) self.model = model return self.vision_qpc_path, self.lang_qpc_path + else: + + + specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + ] + custom_io={} + kv_cache_dtype = "float16" + + #inputs + for input_name in self.output_names: + if input_name.endswith("_RetainedState"): + custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype + + # outputs + for output_name in self.output_names: + if output_name.endswith("_RetainedState"): + custom_io[output_name] = kv_cache_dtype + + compiler_options.update({"retained-state": True}) + self.lang_qpc_path = self._compile( + self.onnx_path, + compile_dir, + compile_only=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io, + **compiler_options, + ) + def generate( self, From bcff69412df26767511799264fd7a977f4303571 Mon Sep 17 00:00:00 2001 From: asmigosw Date: Mon, 3 Feb 2025 14:51:08 +0000 Subject: [PATCH 3/8] Single qpc support Signed-off-by: Amit Raj --- QEfficient/__init__.py | 7 +- QEfficient/base/modeling_qeff.py | 2 +- QEfficient/base/onnx_transforms.py | 10 +- .../models/mllama/modeling_mllama.py | 209 +++++------------- .../transformers/models/modeling_auto.py | 201 ++++++++++------- .../transformers/models/pytorch_transforms.py | 2 +- QEfficient/utils/constants.py | 4 +- 7 files changed, 196 insertions(+), 239 deletions(-) diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 956ccf316..0481ace3e 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -25,7 +25,12 @@ def check_qaic_sdk(): # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" if QAIC_INSTALLED: - from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader,QEFFAutoModelForImageTextToText + from QEfficient.base import ( + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForImageTextToText, + QEFFCommonLoader, + ) from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e48475b78..3f8705c81 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -175,7 +175,7 @@ def _export( } if onnx_transform_kwargs is not None: transform_kwargs.update(onnx_transform_kwargs) - + for transform in self._onnx_transforms: model, transformed = transform.apply(model, **transform_kwargs) model.metadata_props.append( diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 802252ce2..4268736f8 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -106,10 +106,10 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar layers_to_remove = [3, 8, 13, 18, 23, 28, 33, 38] names_to_remove = [] for layer in layers_to_remove: - names_to_remove.append(f'past_key.{layer}_RetainedState') - names_to_remove.append(f'past_value.{layer}_RetainedState') - names_to_remove.append(f'past_key.{layer}') - names_to_remove.append(f'past_value.{layer}') + names_to_remove.append(f"past_key.{layer}_RetainedState") + names_to_remove.append(f"past_value.{layer}_RetainedState") + names_to_remove.append(f"past_key.{layer}") + names_to_remove.append(f"past_value.{layer}") graph = model.graph transformed = False @@ -136,4 +136,4 @@ def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwar # graph.input.remove(input_to_remove) # transformed = True - return model, transformed \ No newline at end of file + return model, transformed diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index b4bdd4339..ac7c011e9 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -17,7 +17,6 @@ from PIL import Image from torch import nn from torch.nn import CrossEntropyLoss -from transformers import AutoProcessor from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( @@ -74,6 +73,8 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k_embed = (k * cos) + (rotate_half(k) * sin) # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) + + def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, num_vision_tokens: int, @@ -95,14 +96,13 @@ def _prepare_cross_attention_mask( # last dimension contains negative infinity values, otherwise it's 1 negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32) full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value) - .any(dim=-1) - .type_as(cross_attention_mask)[..., None] + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] ) cross_attention_mask *= full_text_row_masked_out_mask return cross_attention_mask, full_text_row_masked_out_mask + def _prepare_aspect_ratio_attention_mask( aspect_ratio_mask: torch.Tensor, num_patches: int, @@ -124,15 +124,12 @@ def _prepare_aspect_ratio_attention_mask( # Reshape to 2D and create 4D attention mask # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) - attention_mask = ( - attention_mask - @ attention_mask.transpose(-1, -2) - * torch.tensor(-10000.0, dtype=torch.float32) - ) + attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32) attention_mask = attention_mask.unsqueeze(1) return attention_mask + def _create_causal_mask( position_ids, target_length, @@ -150,9 +147,7 @@ def _create_causal_mask( pos_max = position_ids.max(1, keepdim=True).values kv_start = (pos_max // target_length) * target_length kv_indices_high = kv_indices + kv_start - kv_indices_low = torch.where( - kv_indices_high < target_length, kv_indices, kv_indices_high - target_length - ) + kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length) kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) kv_indices = kv_indices.unsqueeze(1) # ------ @@ -206,12 +201,8 @@ def forward( # Compute new KV states key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) # if past_key_value is not None: # # if we have a new image + new tokens, we only computed key_states on that new image # # we still update the cross key states, past_image, new_image. And use it! @@ -242,9 +233,7 @@ def forward( key_states = self.k_norm(key_states) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim - ) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -253,9 +242,7 @@ def forward( # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights # ) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -267,6 +254,7 @@ def forward( return attn_output, attn_weights, past_key_value + class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py @@ -304,12 +292,8 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -322,9 +306,7 @@ def forward( kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -334,26 +316,18 @@ def forward( "batch_index": batch_index, "position_ids": position_ids, } - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim - ) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where( - attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights - ) + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) @@ -388,9 +362,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - Tuple[torch.Tensor, torch.Tensor] - ] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -541,9 +513,7 @@ def __init__( else: # BC: "rope_type" was originally "type" if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -552,9 +522,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -566,9 +534,7 @@ def __init__( def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as( - self.inv_freq - ) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) @@ -597,23 +563,15 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( - pixel_values.shape - ) + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, height, width - ) + pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) # Patch embedding @@ -626,16 +584,12 @@ def forward( hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) # Add cls token - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim - ) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) hidden_state = self.apply_class_embedding(hidden_state) num_patches += 1 # Position embeddings - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches, dim - ) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) hidden_state = self.layernorm_pre(hidden_state) @@ -695,16 +649,12 @@ def forward( batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim ) hidden_state = hidden_state[:, :, :slice_index] - hidden_state = hidden_state.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, dim - ) + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) # Collect intermediate layer outputs from encoder output all_intermediate_hidden_states = output[1] intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) - intermediate_hidden_states = intermediate_hidden_states[ - ..., self.intermediate_layers_indices - ] + intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices] # Remove padding from intermediate hidden states intermediate_hidden_states = intermediate_hidden_states.reshape( @@ -725,9 +675,7 @@ def forward( if output_attentions: # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range - global_attn = ( - tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) - ) + global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) attentions = tuple(output[2]) + global_attn else: attentions = None @@ -766,13 +714,9 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -792,16 +736,12 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance( - past_key_values, Cache - ): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -850,11 +790,7 @@ def forward( # TODO: vbaddi: since past_key_values are retained from previous states, the condition for is_cross_attention_cache_empty is False # so explicitly making it true in order to skip the cross attention for language model # comment once there is vision and cross attention support - if ( - is_cross_attention_layer - and cross_attention_states is None - and is_cross_attention_cache_empty - ): + if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: continue if self.gradient_checkpointing and self.training: @@ -921,11 +857,7 @@ def forward( next_cache = next_cache.to_legacy_cache() if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -970,11 +902,7 @@ def _update_causal_mask( # TODO: vbaddi: unused, comment to fix linters # sequence_length = input_tensor.shape[1] - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens - ) + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) @@ -1019,13 +947,9 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1101,13 +1025,9 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1122,15 +1042,11 @@ def forward( ) if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") if pixel_values is not None: if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") # get vision tokens from vision model vision_outputs = self.vision_model( pixel_values=pixel_values, @@ -1179,13 +1095,12 @@ def forward( return outputs + class VisionEncoder(nn.Module): def __init__(self, mllama: MllamaForConditionalGeneration): super().__init__() self.mllama = mllama - self.cross_attention_layers = ( - self.mllama.config.get_text_config().cross_attention_layers - ) + self.cross_attention_layers = self.mllama.config.get_text_config().cross_attention_layers self.config = self.mllama.config.get_text_config() def forward( @@ -1200,9 +1115,9 @@ def forward( aspect_ratio_mask=aspect_ratio_mask, ) cross_attention_states = vision_outputs[0] - cross_attention_states = self.mllama.multi_modal_projector( - cross_attention_states - ).reshape(-1, cross_attention_states.shape[-2], self.mllama.hidden_size) + cross_attention_states = self.mllama.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.mllama.hidden_size + ) bsz = pixel_values.shape[0] outputs = [] @@ -1210,16 +1125,15 @@ def forward( cross_attn = self.mllama.language_model.model.layers[i].cross_attn key_states = cross_attn.k_proj(cross_attention_states) value_states = cross_attn.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim - ).transpose(1, 2) + key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose( + 1, 2 + ) outputs.append((key_states, value_states)) return outputs + class ModelWrapper(nn.Module): def __init__(self, mllama): super().__init__() @@ -1271,6 +1185,7 @@ def forward( if "past_key_values" in outputs: outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() return outputs + def generate_mllama_single(self, processor): ctx_len = 1024 txt_cfg = self.mllama.config.get_text_config() @@ -1312,19 +1227,13 @@ def generate_mllama_single(self, processor): if i in cross_attention_layers: idx = cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - inputs["past_key_values"].key_cache[i] = torch.zeros( - 1, num_key_value_heads, image_tokens_len, head_dim - ) + inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim) inputs["past_key_values"].value_cache[i] = torch.zeros( 1, num_key_value_heads, image_tokens_len, head_dim ) else: - inputs["past_key_values"].key_cache[i] = torch.zeros( - 1, num_key_value_heads, ctx_len, head_dim - ) - inputs["past_key_values"].value_cache[i] = torch.zeros( - 1, num_key_value_heads, ctx_len, head_dim - ) + inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) output_names = [ "logits", diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index d38a32dc4..7f95ecf3b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -17,6 +17,7 @@ import requests import torch import torch.nn as nn +import transformers from PIL import Image from transformers import ( AutoModel, @@ -30,7 +31,7 @@ import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform, RemoveCrossAttentionIOTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import get_compilation_dims from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -722,6 +723,7 @@ def from_pretrained( self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) + self.tokenizer = self.processor.tokenizer self.continuous_batching = continuous_batching self.kv_offload = kv_offload # self.model_name=pretrained_model_name_or_path @@ -833,32 +835,25 @@ def _generate_inputs_mllama( vision_input[k] = v return lang_inputs, vision_input - def export( self, export_dir: Optional[str] = None, **kwargs, ) -> str: - if self.kv_offload: print("generating input") lang_inputs, vision_input = self._generate_inputs_mllama() print("generating vision model") - + self.vision_export_path = self.export_vision(vision_input, export_dir) print("generating lang model") self.lang_export_path = self.export_lang(lang_inputs, export_dir) else: - self.model=ModelWrapper(self.model) - self.inputs,self.output_names, dynamic_axes=self.model.generate_mllama_single(self.processor) - self._export( - self.inputs, - self.output_names, - dynamic_axes, - export_dir=export_dir - ) - + self.model = ModelWrapper(self.model) + self.inputs, self.output_names, dynamic_axes = self.model.generate_mllama_single(self.processor) + self._export(self.inputs, self.output_names, dynamic_axes, export_dir=export_dir) + def export_vision(self, vision_input, export_dir): model = self.model self.vision_encoder = self.model = VisionEncoder(self.model) @@ -1035,8 +1030,6 @@ def compile( self.model = model return self.vision_qpc_path, self.lang_qpc_path else: - - specializations = [ { "batch_size": batch_size, @@ -1053,12 +1046,12 @@ def compile( "max_image_tiles": "4", }, ] - custom_io={} + custom_io = {} kv_cache_dtype = "float16" - #inputs + # inputs for input_name in self.output_names: - if input_name.endswith("_RetainedState"): + if input_name.endswith("_RetainedState"): custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype # outputs @@ -1079,7 +1072,6 @@ def compile( custom_io=custom_io, **compiler_options, ) - def generate( self, @@ -1105,71 +1097,121 @@ def generate( if self.kv_offload: self.kv_offload_generate(inputs, streamer, device_ids) else: - return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) + return self.cloud_ai_100_generate(inputs=inputs, device_ids=device_ids) # PyTorch runtime else: return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) - # TODO: Add the code based on how we did in single inference script - def cloud_ai_100_vlm_generate( + def cloud_ai_100_generate( self, inputs: torch.Tensor, device_ids: List[int] = [0], + enable_debug_logs: bool = False, ) -> np.ndarray: - """ - Generates features with list of prompts using AI 100 runtime. - - ``Mandatory`` Args: - :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. - ``Optional`` Args: - device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0]. + qpc_session = QAICInferenceSession( + self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False + ) - Returns: - np.ndarray: A list of dictionaries containing the generated output features. - """ + batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path) - if self.qpc_session is None: - self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) - self.batch_size = self.qpc_session.bindings[0].dims[0] - self.seq_len = self.qpc_session.bindings[0].dims[1] # Skip inputs/outputs - self.qpc_session.skip_buffers( - [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] - + ["pixel_values_RetainedState", "image_sizes_RetainedState"] + qpc_session.skip_buffers( + [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] ) # Read prompt and ctx len from session - # batch_size = max( - # [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] - # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] - # ) - - # prefill_seq_len = max( - # [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] - # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] - # ) - # Prepare input - input_ids_len = inputs["input_ids"].shape[1] - input_ids = np.array( - torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + batch_size = max( + [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]] ) - attention_mask = np.array( - torch.nn.functional.pad( - inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0 - ) + + prefill_seq_len = max( + [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes] + + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]] ) - inputs = dict(input_ids=input_ids, attention_mask=attention_mask) + # lang_inputs = tokenizer(prompt, return_tensors="np", padding=True) + input_len = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + generation_len = None + if generation_len is None: + generation_len = ctx_len - input_len.max() - outputs = { - "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( - np.float32 - ), - } - self.qpc_session.set_buffers(outputs) - outputs = self.qpc_session.run(inputs) - outputs = outputs["output"][:, :input_ids_len, :] - return outputs + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), self.tokenizer.pad_token_id) + stream = None + if stream: + streamer = transformers.TextStreamer(self.tokenizer) + + # Prepare inputs for prefill + start = perf_counter() + + inputs["position_ids"] = np.where( + inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + inputs = dict(inputs) + + # vision_session.deactivate() + qpc_session.activate() + + # Run prefill + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + outputs = qpc_session.run(chunk_inputs) + + # Skip inputs/outputs again + qpc_session.skip_buffers( + [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] + ) + + # Get first token + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] = input_len + inputs["cross_attention_mask"] = inputs["cross_attention_mask"][:, -1:, :, :] + generated_ids[:, 0] = inputs["input_ids"].squeeze(1) + finished_sequences = inputs["input_ids"] == self.tokenizer.eos_token_id + if stream: + streamer.put(inputs["input_ids"][0]) + + # Decode loop + loop_start = perf_counter() + for num_token in range(1, generation_len): + outputs = qpc_session.run(inputs) + + # Prepare inputs for next iteration + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] += 1 + generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) + finished_sequences |= inputs["input_ids"] == self.tokenizer.eos_token_id + if stream: + streamer.put(inputs["input_ids"][0]) + if finished_sequences.all(): + break + + end = perf_counter() + if stream: + streamer.end() + generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(1 if stream else 0, batch_size): + print(i, generated_texts[i]) + + prefill_perf = 1 / (loop_start - start) + decode_perf = (num_token - 1) / (end - loop_start) + total_perf = num_token / (end - start) + + print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) + print("E2ET:", round(end - start, 2), "s", file=sys.stderr) + print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) + print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) + print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) + if batch_size > 1: + print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) def pytorch_vlm_generate( self, @@ -1226,11 +1268,12 @@ def kv_offload_generate( stream: bool = True, **kwargs, ): - # self.lang_qpc_path="/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" - self.lang_qpc_path="/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" - self.vision_qpc_path="/home/ubuntu/.cache/qeff_models/VisionEncoder-31e62a3c446b6bb9/qpc-7412e902c95a92c9/qpc" - + self.lang_qpc_path = ( + "/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" + ) + self.vision_qpc_path = "/home/ubuntu/.cache/qeff_models/VisionEncoder-31e62a3c446b6bb9/qpc-7412e902c95a92c9/qpc" + lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) vision_session = QAICInferenceSession(self.vision_qpc_path, device_id) @@ -1340,12 +1383,12 @@ def kv_offload_generate( decode_perf = (num_token - 1) / (end - loop_start) total_perf = num_token / (end - start) - # print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) - # print("E2ET:", round(end - start, 2), "s", file=sys.stderr) - # print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) - # print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) - # print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) - # if batch_size > 1: - # print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) - # print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) - # print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) + print("E2ET:", round(end - start, 2), "s", file=sys.stderr) + print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) + print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) + print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) + if batch_size > 1: + print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 27257fa55..3580d4fda 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -353,4 +353,4 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: f"model class {model_class} does not yet support returning multiple logits to keep." ) - return model, transformed \ No newline at end of file + return model, transformed diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 462acf169..001acc8e3 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -49,11 +49,11 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 -ONNX_EXPORT_MAX_NUM_IMAGES =1 +ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 ONNX_EXPORT_IMAGE_LENGHT = 560 -ONNX_EXPORT_IMAGE_DEPTH =3 +ONNX_EXPORT_IMAGE_DEPTH = 3 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] From ddc9bfc81a66edc9897fc6ed90768932c6e33bc9 Mon Sep 17 00:00:00 2001 From: asmigosw Date: Mon, 3 Feb 2025 16:11:39 +0000 Subject: [PATCH 4/8] Minor fixes-1 Signed-off-by: Amit Raj --- .../models/mllama/modeling_mllama.py | 2 +- .../transformers/models/modeling_auto.py | 61 +------------------ 2 files changed, 3 insertions(+), 60 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index ac7c011e9..dd8e276d3 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -1186,7 +1186,7 @@ def forward( outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() return outputs - def generate_mllama_single(self, processor): + def generate_input(self, processor): ctx_len = 1024 txt_cfg = self.mllama.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 7f95ecf3b..5d6305097 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -741,64 +741,7 @@ def model_hash(self) -> str: mhash.update(to_hashable(self._transform_names())) mhash = mhash.hexdigest()[:16] return mhash - - def _generate_inputs(self, **kwargs): - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - # seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - # fbs = constants.ONNX_EXPORT_EXAMPLE_FBS - - self.ctx_len = kwargs["ctx_len"] if "ctx_len" in kwargs else self.ctx_len - - ## PREPROCESSING THE MULTI-MODAL INPUTS for Phi-3.5 for now - # TODO: Create a map for the other models to have their own inputs accordingly - images = [] - placeholder = "" - - # Note: if OOM, you might consider reduce number of frames in this example. - for i in range(1, 2): - url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" - images.append(Image.open(requests.get(url, stream=True).raw)) - placeholder += f"<|image_{1}|>\n" - - messages = [ - {"role": "user", "content": placeholder + "Summarize the deck of slides."}, - ] - - prompt = self.processor.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ) - inputs = dict(self.processor(images=images, text=prompt, return_tensors="pt")) - inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - inputs["past_key_values"] = [] - for i in range(self.num_layers): - inputs["past_key_values"].append( - ( - torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), - torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), - ) - ) - output_names = [ - "logits", - "pixel_values_RetainedState", - "image_sizes_RetainedState", - *[f"past_{kv}.{i}_RetainedState" for i in range(self.num_layers) for kv in ["key", "value"]], - ] - dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "position_ids": {0: "batch_size", 1: "seq_len"}, - # "pixel_values": {0: "img_batch_size"}, - } - for i in range(self.num_layers): - dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - - # Avoid issues due to index out of range - inputs["position_ids"] = torch.full(inputs["position_ids"].shape, self.ctx_len - 1) - - return inputs, dynamic_axes, output_names - + def _generate_inputs_mllama( self, ): @@ -851,7 +794,7 @@ def export( self.lang_export_path = self.export_lang(lang_inputs, export_dir) else: self.model = ModelWrapper(self.model) - self.inputs, self.output_names, dynamic_axes = self.model.generate_mllama_single(self.processor) + self.inputs, self.output_names, dynamic_axes = self.model.generate_inputs(self.processor) self._export(self.inputs, self.output_names, dynamic_axes, export_dir=export_dir) def export_vision(self, vision_input, export_dir): From 15161ced136c130abc7057788550cab9a5c5c78a Mon Sep 17 00:00:00 2001 From: asmigosw Date: Mon, 3 Feb 2025 18:37:35 +0000 Subject: [PATCH 5/8] Generate input restructure Signed-off-by: Amit Raj --- QEfficient/base/modeling_qeff.py | 3 +- .../models/mllama/modeling_mllama.py | 128 +++++++++----- .../transformers/models/modeling_auto.py | 156 ++++-------------- QEfficient/utils/constants.py | 1 + 4 files changed, 118 insertions(+), 170 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3f8705c81..b77279dcf 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -114,6 +114,7 @@ def compile(self, *args, **kwargs) -> Path: def _export( self, + model, example_inputs: Dict[str, torch.Tensor], output_names: List[str], dynamic_axes: Dict[str, Dict[int, str]], @@ -157,7 +158,7 @@ def _export( try: export_kwargs = {} if export_kwargs is None else export_kwargs torch.onnx.export( - self.model, + model, (example_inputs,), str(tmp_onnx_path), input_names=input_names, diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index dd8e276d3..6433197b7 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -42,6 +42,16 @@ ) from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.utils import constants +from QEfficient.utils.constants import Constants + +bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE +max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES +max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES +image_length = constants.ONNX_EXPORT_IMAGE_LENGHT +image_width = constants.ONNX_EXPORT_IMAGE_WIDTH +num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH +seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): @@ -1186,8 +1196,45 @@ def forward( outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() return outputs - def generate_input(self, processor): - ctx_len = 1024 + def generate_input(self, processor, kv_offload): + + #vision_inputs + vision_inputs = { + "pixel_values": torch.zeros((bs, max_num_images,max_image_tiles,num_channel, image_length, image_width ), dtype=torch.int64), + "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), + "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1 ), dtype=torch.int64) + } + + vision_output_names = [] + for i in self.mllama.config.text_config.cross_attention_layers: + vision_output_names.append(f"past_key.{i}") + vision_output_names.append(f"past_value.{i}") + + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": { + 0: "batch_size", + 1: "max_num_images", + 2: "max_image_tiles", + }, + } + + #lang_inputs + lang_inputs = { + "input_ids": torch.zeros((bs,seq_len),dtype=torch.int64), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "cross_attention_mask": torch.ones((bs, max_image_tiles),dtype=torch.int64), + "attention_mask": torch.ones((bs,seq_len),dtype=torch.int64) + } + + lang_inputs["position_ids"] = torch.where( + lang_inputs.pop("attention_mask") == 1, + torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + + ctx_len = Constants.CTX_LEN txt_cfg = self.mllama.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1198,54 +1245,32 @@ def generate_input(self, processor): num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 image_tokens_len = vis_cfg.max_num_tiles * num_patches - url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" - image = Image.open(requests.get(url, stream=True).raw) - conversation = [ - { - "role": "user", - "content": [ - {"type": "image"}, - { - "type": "text", - "text": "How long does it take from invoice date to due date? Be short and concise.", - }, - ], - } - ] - prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) - inputs = processor(text=prompt, images=image, return_tensors="pt", add_special_tokens=False) - inputs["position_ids"] = torch.where( - inputs.pop("attention_mask") == 1, - torch.arange(inputs["input_ids"].shape[1]).view(1, -1), - -1, - ) - inputs = dict(inputs) - inputs["past_key_values"] = DynamicCache(num_hidden_layers) - inputs["past_key_values"].key_cache = [0] * num_hidden_layers - inputs["past_key_values"].value_cache = [0] * num_hidden_layers + + lang_inputs["past_key_values"] = DynamicCache(num_hidden_layers) + lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + for i in range(num_hidden_layers): if i in cross_attention_layers: idx = cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim) - inputs["past_key_values"].value_cache[i] = torch.zeros( + lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros( 1, num_key_value_heads, image_tokens_len, head_dim ) else: - inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) - inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) - output_names = [ + + lang_output_names = [ "logits", - # "pixel_values_RetainedState", *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]], ] - dynamic_axes = { + + lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, - "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, - "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, "cross_attention_mask": { 0: "batch_size", 1: "seq_len", @@ -1253,14 +1278,31 @@ def generate_input(self, processor): 3: "max_image_tiles", }, } + for i in range(num_hidden_layers): if i in cross_attention_layers: - dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} - dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} else: - dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, ctx_len - 1) + + inputs = [] + output_names = [] + dynamic_axes = [] + + if kv_offload: + inputs.extend([vision_inputs, lang_inputs]) + output_names.extend([vision_output_names, lang_output_names]) + dynamic_axes.extend([vision_dynamic_axes, lang_dynamic_axes]) + else: + inputs.append({**vision_inputs, **lang_inputs}) + output_names = vision_output_names + lang_output_names + dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes}) - inputs["past_key_values"] = inputs["past_key_values"].to_legacy_cache() - inputs["position_ids"] = torch.full(inputs["position_ids"].shape, ctx_len - 1) return inputs, output_names, dynamic_axes + + \ No newline at end of file diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5d6305097..185467558 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -741,150 +741,54 @@ def model_hash(self) -> str: mhash.update(to_hashable(self._transform_names())) mhash = mhash.hexdigest()[:16] return mhash - - def _generate_inputs_mllama( - self, - ): - url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" - image = Image.open(requests.get(url, stream=True).raw) - - messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}, - ], - } - ] - input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True) - - split_inputs = self.processor( - text=input_text, - images=image, - return_tensors="pt", - add_special_tokens=False, - padding="max_length", - max_length=32, - ) - - lang_inputs = {} - vision_input = {} - - for k, v in split_inputs.items(): - if k in ["input_ids", "attention_mask", "cross_attention_mask"]: - lang_inputs[k] = v - else: - vision_input[k] = v - - return lang_inputs, vision_input def export( self, export_dir: Optional[str] = None, **kwargs, - ) -> str: - if self.kv_offload: - print("generating input") - lang_inputs, vision_input = self._generate_inputs_mllama() - print("generating vision model") + ) -> str: - self.vision_export_path = self.export_vision(vision_input, export_dir) - print("generating lang model") - self.lang_export_path = self.export_lang(lang_inputs, export_dir) + self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.processor) + if self.kv_offload: + self.vision_export_path = self.export_vision(export_dir) + self.lang_export_path = self.export_lang(export_dir) else: self.model = ModelWrapper(self.model) - self.inputs, self.output_names, dynamic_axes = self.model.generate_inputs(self.processor) - self._export(self.inputs, self.output_names, dynamic_axes, export_dir=export_dir) - - def export_vision(self, vision_input, export_dir): - model = self.model - self.vision_encoder = self.model = VisionEncoder(self.model) - - vision_output_names = [] - for i in self.model.cross_attention_layers: - vision_output_names.append(f"past_key.{i}") - vision_output_names.append(f"past_value.{i}") - vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, - "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, - "aspect_ratio_mask": { - 0: "batch_size", - 1: "max_num_images", - 2: "max_image_tiles", - }, - } + self._export(self.model, self.inputs[0], self.output_names[0], self.dynamic_axes[0], export_dir=export_dir) + + def export_vision(self, export_dir): + + self.vision_encoder_model=VisionEncoder(self.model) + + vision_inputs=self.inputs[0] + vision_output_names=self.output_names[0] + vision_dynamic_axes=self.dynamic_axes[0] self.vision_onnx_path = self._export( - vision_input, + self.vision_encoder_model, + vision_inputs, vision_output_names, vision_dynamic_axes, export_dir=export_dir, ) - self.model = model - self.vision_output_names = vision_output_names return self.vision_onnx_path - def export_lang(self, lang_inputs, export_dir): - self.num_layers = num_hidden_layers = self.model.config.get_text_config().num_hidden_layers - - lang_inputs["position_ids"] = torch.where( - lang_inputs.pop("attention_mask") == 1, - torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), - -1, - ) - - lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) - lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers - lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers - - for i in range(num_hidden_layers): - if i in self.vision_encoder.cross_attention_layers: - idx = self.vision_encoder.cross_attention_layers.index(i) - assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) - else: - lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128)) - - lang_inputs["position_ids"] = torch.full((1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1) - lang_output_names = ["logits", "past_key_values"] - pkv_idx = lang_output_names.index("past_key_values") + def export_lang(self, export_dir): + self.lang_model= ModelWrapper(self.model) - lang_output_names[pkv_idx : pkv_idx + 1] = [ - f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"] - ] + lang_inputs=self.inputs[1] + lang_output_names=self.output_names[1] + lang_dynamic_axes=self.dynamic_axes[1] - lang_dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "position_ids": {0: "batch_size", 1: "seq_len"}, - "cross_attention_mask": { - 0: "batch_size", - 1: "seq_len", - 2: "max_num_images", - 3: "max_image_tiles", - }, - } - - for i in range(num_hidden_layers): - if i in self.vision_encoder.cross_attention_layers: - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} - continue - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - - lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() - lang_inputs["input_ids"] = torch.tensor([[374]]) - lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:] - self.lang_output_names = lang_output_names - model = self.model - self.model = ModelWrapper(model) - # self._onnx_transforms.append(RemoveCrossAttentionIOTransform) - self.lang_onnx_path = self._export(lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir) - self.model = model + self.lang_onnx_path = self._export( + self.lang_model, + lang_inputs, + lang_output_names, + lang_dynamic_axes, + export_dir=export_dir + ) + return self.lang_onnx_path def compile( @@ -903,7 +807,7 @@ def compile( if self.kv_offload: model = self.model self.model = VisionEncoder(model) - vision_specializations = [{"batch_size": "1", "max_num_images": "1", "max_image_tiles": "4"}] + vision_specializations = [{"batch_size": batch_size, "max_num_images": "1", "max_image_tiles": "4"}] custom_io = {} kv_cache_dtype = "float16" diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 001acc8e3..028dd13b7 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -54,6 +54,7 @@ def get_models_dir(): ONNX_EXPORT_IMAGE_WIDTH = 560 ONNX_EXPORT_IMAGE_LENGHT = 560 ONNX_EXPORT_IMAGE_DEPTH = 3 +ONNX_EXPORT_CTX_LEN = 1024 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] From a3271c1d65093923761054eb3779cd1dab11cd4e Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 3 Feb 2025 19:04:20 +0000 Subject: [PATCH 6/8] ruff fix Signed-off-by: Amit Raj --- .../models/mllama/modeling_mllama.py | 35 +++++++++---------- .../transformers/models/modeling_auto.py | 33 +++++++---------- 2 files changed, 28 insertions(+), 40 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 6433197b7..7a4835e03 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -10,11 +10,9 @@ import math from typing import List, Optional, Tuple, Union -import requests import torch import torch.nn.functional as F import torch.utils.checkpoint -from PIL import Image from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache @@ -1197,12 +1195,13 @@ def forward( return outputs def generate_input(self, processor, kv_offload): - - #vision_inputs + # vision_inputs vision_inputs = { - "pixel_values": torch.zeros((bs, max_num_images,max_image_tiles,num_channel, image_length, image_width ), dtype=torch.int64), + "pixel_values": torch.zeros( + (bs, max_num_images, max_image_tiles, num_channel, image_length, image_width), dtype=torch.int64 + ), "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), - "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1 ), dtype=torch.int64) + "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles, 1), dtype=torch.int64), } vision_output_names = [] @@ -1220,19 +1219,19 @@ def generate_input(self, processor, kv_offload): }, } - #lang_inputs + # lang_inputs lang_inputs = { - "input_ids": torch.zeros((bs,seq_len),dtype=torch.int64), + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), - "cross_attention_mask": torch.ones((bs, max_image_tiles),dtype=torch.int64), - "attention_mask": torch.ones((bs,seq_len),dtype=torch.int64) + "cross_attention_mask": torch.ones((bs, max_image_tiles), dtype=torch.int64), + "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), } lang_inputs["position_ids"] = torch.where( lang_inputs.pop("attention_mask") == 1, torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), -1, - ) + ) ctx_len = Constants.CTX_LEN txt_cfg = self.mllama.config.get_text_config() @@ -1245,7 +1244,6 @@ def generate_input(self, processor, kv_offload): num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 image_tokens_len = vis_cfg.max_num_tiles * num_patches - lang_inputs["past_key_values"] = DynamicCache(num_hidden_layers) lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers @@ -1254,7 +1252,9 @@ def generate_input(self, processor, kv_offload): if i in cross_attention_layers: idx = cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim) + lang_inputs["past_key_values"].key_cache[i] = torch.zeros( + 1, num_key_value_heads, image_tokens_len, head_dim + ) lang_inputs["past_key_values"].value_cache[i] = torch.zeros( 1, num_key_value_heads, image_tokens_len, head_dim ) @@ -1262,12 +1262,11 @@ def generate_input(self, processor, kv_offload): lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim) - lang_output_names = [ "logits", *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]], ] - + lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, @@ -1286,10 +1285,10 @@ def generate_input(self, processor, kv_offload): else: lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, ctx_len - 1) - + inputs = [] output_names = [] dynamic_axes = [] @@ -1304,5 +1303,3 @@ def generate_input(self, processor, kv_offload): dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes}) return inputs, output_names, dynamic_axes - - \ No newline at end of file diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 185467558..2e0890fee 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,11 +14,9 @@ from typing import List, Optional, Union import numpy as np -import requests import torch import torch.nn as nn import transformers -from PIL import Image from transformers import ( AutoModel, AutoModelForCausalLM, @@ -34,7 +32,6 @@ from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.generation.text_generation_inference import get_compilation_dims -from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers @@ -746,8 +743,7 @@ def export( self, export_dir: Optional[str] = None, **kwargs, - ) -> str: - + ) -> str: self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.processor) if self.kv_offload: self.vision_export_path = self.export_vision(export_dir) @@ -757,12 +753,11 @@ def export( self._export(self.model, self.inputs[0], self.output_names[0], self.dynamic_axes[0], export_dir=export_dir) def export_vision(self, export_dir): - - self.vision_encoder_model=VisionEncoder(self.model) + self.vision_encoder_model = VisionEncoder(self.model) - vision_inputs=self.inputs[0] - vision_output_names=self.output_names[0] - vision_dynamic_axes=self.dynamic_axes[0] + vision_inputs = self.inputs[0] + vision_output_names = self.output_names[0] + vision_dynamic_axes = self.dynamic_axes[0] self.vision_onnx_path = self._export( self.vision_encoder_model, @@ -775,20 +770,16 @@ def export_vision(self, export_dir): return self.vision_onnx_path def export_lang(self, export_dir): - self.lang_model= ModelWrapper(self.model) + self.lang_model = ModelWrapper(self.model) - lang_inputs=self.inputs[1] - lang_output_names=self.output_names[1] - lang_dynamic_axes=self.dynamic_axes[1] + lang_inputs = self.inputs[1] + lang_output_names = self.output_names[1] + lang_dynamic_axes = self.dynamic_axes[1] self.lang_onnx_path = self._export( - self.lang_model, - lang_inputs, - lang_output_names, - lang_dynamic_axes, - export_dir=export_dir - ) - + self.lang_model, lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir + ) + return self.lang_onnx_path def compile( From eaaa5178591b35447f43b6019c00f3c51335ead7 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 4 Feb 2025 04:43:10 +0000 Subject: [PATCH 7/8] Generate input fix Signed-off-by: Amit Raj --- .../models/mllama/modeling_mllama.py | 190 +++++++++--------- .../transformers/models/modeling_auto.py | 2 +- 2 files changed, 96 insertions(+), 96 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 7a4835e03..3a6ae903f 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -1103,98 +1103,7 @@ def forward( return outputs - -class VisionEncoder(nn.Module): - def __init__(self, mllama: MllamaForConditionalGeneration): - super().__init__() - self.mllama = mllama - self.cross_attention_layers = self.mllama.config.get_text_config().cross_attention_layers - self.config = self.mllama.config.get_text_config() - - def forward( - self, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - ) -> List[Tuple[torch.Tensor]]: - vision_outputs = self.mllama.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.mllama.multi_modal_projector(cross_attention_states).reshape( - -1, cross_attention_states.shape[-2], self.mllama.hidden_size - ) - - bsz = pixel_values.shape[0] - outputs = [] - for i in self.cross_attention_layers: - cross_attn = self.mllama.language_model.model.layers[i].cross_attn - key_states = cross_attn.k_proj(cross_attention_states) - value_states = cross_attn.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose( - 1, 2 - ) - - outputs.append((key_states, value_states)) - return outputs - - -class ModelWrapper(nn.Module): - def __init__(self, mllama): - super().__init__() - self.mllama = mllama - self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers - self.config = self.mllama.config.get_text_config() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ): - if past_key_values is not None: - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - outputs = self.mllama( - input_ids=input_ids, - pixel_values=pixel_values, - aspect_ratio_mask=aspect_ratio_mask, - aspect_ratio_ids=aspect_ratio_ids, - attention_mask=attention_mask, - cross_attention_mask=cross_attention_mask, - cross_attention_states=cross_attention_states, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - ) - if "past_key_values" in outputs: - outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() - return outputs - - def generate_input(self, processor, kv_offload): + def generate_input(self, kv_offload): # vision_inputs vision_inputs = { "pixel_values": torch.zeros( @@ -1205,7 +1114,7 @@ def generate_input(self, processor, kv_offload): } vision_output_names = [] - for i in self.mllama.config.text_config.cross_attention_layers: + for i in self.config.text_config.cross_attention_layers: vision_output_names.append(f"past_key.{i}") vision_output_names.append(f"past_value.{i}") @@ -1234,13 +1143,13 @@ def generate_input(self, processor, kv_offload): ) ctx_len = Constants.CTX_LEN - txt_cfg = self.mllama.config.get_text_config() + txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers num_key_value_heads = txt_cfg.num_key_value_heads head_dim = txt_cfg.hidden_size // txt_cfg.num_attention_heads - vis_cfg = self.mllama.config.vision_config + vis_cfg = self.config.vision_config num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1 image_tokens_len = vis_cfg.max_num_tiles * num_patches @@ -1303,3 +1212,94 @@ def generate_input(self, processor, kv_offload): dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes}) return inputs, output_names, dynamic_axes + + +class VisionEncoder(nn.Module): + def __init__(self, mllama: MllamaForConditionalGeneration): + super().__init__() + self.mllama = mllama + self.cross_attention_layers = self.mllama.config.get_text_config().cross_attention_layers + self.config = self.mllama.config.get_text_config() + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + ) -> List[Tuple[torch.Tensor]]: + vision_outputs = self.mllama.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.mllama.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.mllama.hidden_size + ) + + bsz = pixel_values.shape[0] + outputs = [] + for i in self.cross_attention_layers: + cross_attn = self.mllama.language_model.model.layers[i].cross_attn + key_states = cross_attn.k_proj(cross_attention_states) + value_states = cross_attn.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose( + 1, 2 + ) + + outputs.append((key_states, value_states)) + return outputs + + +class ModelWrapper(nn.Module): + def __init__(self, mllama): + super().__init__() + self.mllama = mllama + self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers + self.config = self.mllama.config.get_text_config() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ): + if past_key_values is not None: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + outputs = self.mllama( + input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + attention_mask=attention_mask, + cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + if "past_key_values" in outputs: + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + return outputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2e0890fee..a277b1ce8 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -744,7 +744,7 @@ def export( export_dir: Optional[str] = None, **kwargs, ) -> str: - self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.processor) + self.inputs, self.output_names, self.dynamic_axes = self.model.generate_input(self.kv_offload) if self.kv_offload: self.vision_export_path = self.export_vision(export_dir) self.lang_export_path = self.export_lang(export_dir) From 67cb5efd8609de978ef66c0cd76b069658746157 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 4 Feb 2025 05:03:45 +0000 Subject: [PATCH 8/8] Addressed Comments Signed-off-by: Amit Raj --- QEfficient/base/onnx_transforms.py | 46 --------- QEfficient/transformers/modeling_utils.py | 95 +++++++++++++++++- .../models/mllama/modeling_mllama.py | 97 +------------------ .../transformers/models/modeling_auto.py | 6 -- 4 files changed, 99 insertions(+), 145 deletions(-) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 4268736f8..543ec4e2d 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -91,49 +91,3 @@ def apply( current_file_size = tsize external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") return model, transformed - - -class RemoveCrossAttentionIOTransform(OnnxTransform): - """ - Removes the input and output names of cross-attention layers. - """ - - @classmethod - def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]: - """ - :param onnx_base_dir: Base directory to load tensors (if not already loaded). - """ - layers_to_remove = [3, 8, 13, 18, 23, 28, 33, 38] - names_to_remove = [] - for layer in layers_to_remove: - names_to_remove.append(f"past_key.{layer}_RetainedState") - names_to_remove.append(f"past_value.{layer}_RetainedState") - names_to_remove.append(f"past_key.{layer}") - names_to_remove.append(f"past_value.{layer}") - - graph = model.graph - transformed = False - - # Remove outputs - for name in names_to_remove: - output_to_remove = None - for output in graph.output: - if output.name == name: - output_to_remove = output - break - if output_to_remove: - graph.output.remove(output_to_remove) - transformed = True - - # # Remove inputs - # for name in names_to_remove: - # input_to_remove = None - # for input in graph.input: - # if input.name == name: - # input_to_remove = input - # break - # if input_to_remove: - # graph.input.remove(input_to_remove) - # transformed = True - - return model, transformed diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index f749cc0c3..23364655f 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -6,8 +6,9 @@ # ----------------------------------------------------------------------------- from collections import namedtuple -from typing import Dict, Type +from typing import Dict, Optional, Tuple, Type +import torch import torch.nn as nn from transformers.models.codegen.modeling_codegen import ( CodeGenAttention, @@ -242,3 +243,95 @@ GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, } + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32) + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32) + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + return cross_attention_mask, full_text_row_masked_out_mask + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +def _create_causal_mask( + position_ids, + target_length, + sliding_window: Optional[int] = None, +): + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + """ + if sliding_window is not None: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, -1) + # --- Rolling buffer --- + pos_max = position_ids.max(1, keepdim=True).values + kv_start = (pos_max // target_length) * target_length + kv_indices_high = kv_indices + kv_start + kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length) + kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) + kv_indices = kv_indices.unsqueeze(1) + # ------ + causal_mask = kv_indices > query_indices + attention_mask = causal_mask + + window_indices = query_indices - sliding_window + 1 + window_mask = kv_indices < window_indices + attention_mask = attention_mask | window_mask + attention_mask = attention_mask.unsqueeze(1) + else: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, 1, -1) + attention_mask = kv_indices > query_indices + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 3a6ae903f..76f4bd102 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -40,6 +40,11 @@ ) from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_utils import ( + _create_causal_mask, + _prepare_aspect_ratio_attention_mask, + _prepare_cross_attention_mask, +) from QEfficient.utils import constants from QEfficient.utils.constants import Constants @@ -83,98 +88,6 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): return q_embed.to(q.dtype), k_embed.to(k.dtype) -def _prepare_cross_attention_mask( - cross_attention_mask: torch.Tensor, - num_vision_tokens: int, - dtype: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - # reshape so it can be used by attn module - batch_size, text_total_length, *_ = cross_attention_mask.shape - cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) - cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) - cross_attention_mask = cross_attention_mask.unsqueeze(1) - - # invert the mask - inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32) - ) - - # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32) - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - return cross_attention_mask, full_text_row_masked_out_mask - - -def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, -) -> torch.Tensor: - # Expand aspect ratio mask to target_length - batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) - attention_mask = attention_mask.repeat(1, 1, target_length, 1) - - # Mask padding patches - pad_patches = target_length - num_patches - attention_mask[:, :, -pad_patches:] = 0 - - # Invert the mask (0 -> 1, 1 -> 0) - attention_mask = 1 - attention_mask - - # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) - attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) - attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32) - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - - -def _create_causal_mask( - position_ids, - target_length, - sliding_window: Optional[int] = None, -): - """ - A utility attention mask class that allows one to: - - Create a causal 4d mask - - Create a causal 4d mask with slided window - """ - if sliding_window is not None: - query_indices = position_ids.unsqueeze(-1) - kv_indices = torch.arange(target_length).view(1, -1) - # --- Rolling buffer --- - pos_max = position_ids.max(1, keepdim=True).values - kv_start = (pos_max // target_length) * target_length - kv_indices_high = kv_indices + kv_start - kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length) - kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) - kv_indices = kv_indices.unsqueeze(1) - # ------ - causal_mask = kv_indices > query_indices - attention_mask = causal_mask - - window_indices = query_indices - sliding_window + 1 - window_mask = kv_indices < window_indices - attention_mask = attention_mask | window_mask - attention_mask = attention_mask.unsqueeze(1) - else: - query_indices = position_ids.unsqueeze(-1) - kv_indices = torch.arange(target_length).view(1, 1, -1) - attention_mask = kv_indices > query_indices - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - - class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index a277b1ce8..c4558cb3d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1106,12 +1106,6 @@ def kv_offload_generate( stream: bool = True, **kwargs, ): - # self.lang_qpc_path="/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" - self.lang_qpc_path = ( - "/home/ubuntu/.cache/qeff_models/ModelWrapper-31e62a3c446b6bb9_working/qpc-1e94c5946f6bdd98/qpc" - ) - self.vision_qpc_path = "/home/ubuntu/.cache/qeff_models/VisionEncoder-31e62a3c446b6bb9/qpc-7412e902c95a92c9/qpc" - lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) vision_session = QAICInferenceSession(self.vision_qpc_path, device_id)