From 8ebcc1b0a0f3e9504395983188f44f61320d1654 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 28 Jan 2025 06:16:57 +0000 Subject: [PATCH 1/6] Initial Commit for mllama Signed-off-by: Amit Raj --- QEfficient/__init__.py | 3 +- QEfficient/base/__init__.py | 2 +- .../models/mllama/modeling_mllama.py | 649 +++++++++++++----- .../transformers/models/modeling_auto.py | 542 ++++++++++++++- .../transformers/models/pytorch_transforms.py | 8 +- QEfficient/utils/constants.py | 5 + 6 files changed, 1034 insertions(+), 175 deletions(-) diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 1bc06ccf4..956ccf316 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -25,7 +25,7 @@ def check_qaic_sdk(): # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" if QAIC_INSTALLED: - from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader + from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader,QEFFAutoModelForImageTextToText from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv @@ -43,6 +43,7 @@ def check_qaic_sdk(): "QEFFAutoModel", "QEFFAutoModelForCausalLM", "QEffAutoPeftModelForCausalLM", + "QEFFAutoModelForImageTextToText", "QEFFCommonLoader", ] diff --git a/QEfficient/base/__init__.py b/QEfficient/base/__init__.py index 86cff11c1..4ae6dd9c0 100644 --- a/QEfficient/base/__init__.py +++ b/QEfficient/base/__init__.py @@ -6,4 +6,4 @@ # ----------------------------------------------------------------------------- from QEfficient.base.common import QEFFCommonLoader # noqa: F401 -from QEfficient.transformers.models.modeling_auto import QEFFAutoModel, QEFFAutoModelForCausalLM # noqa: F401 +from QEfficient.transformers.models.modeling_auto import QEFFAutoModel, QEFFAutoModelForCausalLM,QEFFAutoModelForImageTextToText # noqa: F401 diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index e2f551415..b74519636 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -11,12 +11,14 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( + BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ) @@ -27,91 +29,17 @@ MllamaForCausalLM, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, + MllamaForConditionalGeneration, MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, + MllamaVisionModel, + logger, repeat_kv, rotate_half, ) - -from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask - - -class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): - """ - Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py - The only differences are: - - Add static sin/cos computations. - """ - - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[MllamaConfig] = None, - ): - super(MllamaRotaryEmbedding, self).__init__() # Initialize nn.Module - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) +from QEfficient.transformers.cache_utils import QEffDynamicCache def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): @@ -144,6 +72,74 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) + attention_mask = ( + attention_mask + @ attention_mask.transpose(-1, -2) + * torch.tensor(-10000.0, dtype=torch.float32) + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + +def _create_causal_mask( + position_ids, + target_length, + sliding_window: Optional[int] = None, +): + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + """ + if sliding_window is not None: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, -1) + # --- Rolling buffer --- + pos_max = position_ids.max(1, keepdim=True).values + kv_start = (pos_max // target_length) * target_length + kv_indices_high = kv_indices + kv_start + kv_indices_low = torch.where( + kv_indices_high < target_length, kv_indices, kv_indices_high - target_length + ) + kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) + kv_indices = kv_indices.unsqueeze(1) + # ------ + causal_mask = kv_indices > query_indices + attention_mask = causal_mask + + window_indices = query_indices - sliding_window + 1 + window_mask = kv_indices < window_indices + attention_mask = attention_mask | window_mask + attention_mask = attention_mask.unsqueeze(1) + else: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, 1, -1) + attention_mask = kv_indices > query_indices + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): """ @@ -226,7 +222,6 @@ def forward( return attn_output, attn_weights, past_key_value - class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py @@ -255,19 +250,22 @@ def forward( if cross_attention_states is not None: key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) if past_key_value is not None: # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"batch_index": batch_index, "position_ids": position_ids} + key_states, + value_states, + self.layer_idx, + {"batch_index": batch_index, "position_ids": position_ids}, ) - elif cache_position[0] != 0: + elif past_key_value is not None: key_states, value_states = ( past_key_value.key_cache[self.layer_idx], past_key_value.value_cache[self.layer_idx], @@ -277,12 +275,25 @@ def forward( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + key_states = self.k_norm(key_states) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + # attn_weights = torch.where( + # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights + # ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -315,7 +326,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.45 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -428,22 +441,251 @@ def forward( return outputs -class QEffMllamaTextModel(MllamaTextModel): +class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py The only differences are: - - add new args cache idx for the kv retention + - Add static sin/cos computations. """ - # def __init__(self, config: MllamaTextConfig): - # super().__init__(config) - # self.config = config - # self.__qeff_init__() + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[MllamaConfig] = None, + ): + super(MllamaRotaryEmbedding, self).__init__() # Initialize nn.Module + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - # def __qeff_init__(self): - # self.layers = nn.ModuleList( - # [MllamaSelfAttentionDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.num_hidden_layers)] - # ) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as( + self.inv_freq + ) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +class QEffMllamaVisionModel(MllamaVisionModel): + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim + ) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) + intermediate_hidden_states = intermediate_hidden_states[ + ..., self.intermediate_layers_indices + ] + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + if output_hidden_states: + hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) + else: + hidden_states = None + + if output_attentions: + # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range + global_attn = ( + tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) + ) + attentions = tuple(output[2]) + global_attn + else: + attentions = None + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + +class QEffMllamaTextModel(MllamaTextModel): + """ + Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py + The only differences are: + - add new args cache idx for the kv retention + """ def forward( self, @@ -462,31 +704,13 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - """ - - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, MllamaTextModel - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaTextModel.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> text = "<|image|>If I had to write a haiku for this one" - >>> inputs = processor(text=text, return_tensors="pt") - - >>> output = model(**inputs) - - >>> print(output.last_hidden_state.shape) - torch.Size([1, 13, 4096]) - ``` - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -506,27 +730,39 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance( + past_key_values, Cache + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions + attention_mask, + inputs_embeds, + cache_position, + position_ids, + past_key_values, + output_attentions, ) # embed positions hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = None # decoder layers all_hidden_states = () if output_hidden_states else None @@ -552,8 +788,11 @@ def forward( # TODO: vbaddi: since past_key_values are retained from previous states, the condition for is_cross_attention_cache_empty is False # so explicitly making it true in order to skip the cross attention for language model # comment once there is vision and cross attention support - is_cross_attention_cache_empty = True - if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: + if ( + is_cross_attention_layer + and cross_attention_states is None + and is_cross_attention_cache_empty + ): continue if self.gradient_checkpointing and self.training: @@ -620,7 +859,11 @@ def forward( next_cache = next_cache.to_legacy_cache() if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -665,7 +908,11 @@ def _update_causal_mask( # TODO: vbaddi: unused, comment to fix linters # sequence_length = input_tensor.shape[1] - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) @@ -710,42 +957,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -797,3 +1015,94 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class VisionEncoder(nn.Module): + def __init__(self, mllama: MllamaForConditionalGeneration): + super().__init__() + self.mllama = mllama + self.cross_attention_layers = ( + self.mllama.config.get_text_config().cross_attention_layers + ) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + ) -> List[Tuple[torch.Tensor]]: + vision_outputs = self.mllama.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.mllama.multi_modal_projector( + cross_attention_states + ).reshape(-1, cross_attention_states.shape[-2], self.mllama.hidden_size) + + bsz = pixel_values.shape[0] + outputs = [] + for i in self.cross_attention_layers: + cross_attn = self.mllama.language_model.model.layers[i].cross_attn + key_states = cross_attn.k_proj(cross_attention_states) + value_states = cross_attn.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim + ).transpose(1, 2) + + outputs.append((key_states, value_states)) + return outputs + +class ModelWrapper(nn.Module): + def __init__(self, mllama): + super().__init__() + self.mllama = mllama + self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ): + if past_key_values is not None: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + outputs = self.mllama( + input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + attention_mask=attention_mask, + cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + if "past_key_values" in outputs: + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + return outputs \ No newline at end of file diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c2e3777bc..cef448917 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -9,12 +9,22 @@ import logging import warnings from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np +import requests import torch import torch.nn as nn -from transformers import AutoModel, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast +from PIL import Image +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TextStreamer, +) import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel @@ -26,6 +36,10 @@ from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.cache import to_hashable + +from QEfficient.transformers.models.mllama.modeling_mllama import VisionEncoder, ModelWrapper +from single_qpc.qeff_classes import QEffDynamicCache + logger = logging.getLogger(__file__) @@ -59,6 +73,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = Fals kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + print(model) return cls(model, is_tlm=is_tlm) @property @@ -674,3 +689,526 @@ def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray torch.Tensor: A list of output features generated by the model for each prompt. """ return model(**inputs) + + + + + +class QEFFAutoModelForImageTextToText(QEFFTransformersBase): + + + _hf_auto_class = AutoModelForImageTextToText + _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + def __init__( + self, + model: nn.Module, + **kwargs, + ): + + if kwargs.pop("full_batch_size", None): + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + + super().__init__(model) + self.model.config.use_cache = True + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, two_qpc_method: bool = False, *args, **kwargs): + + if kwargs.pop("full_batch_size", None): + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) + self.continuous_batching = continuous_batching + self.two_qpcmethod = two_qpc_method + + return self + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + # mhash.update(to_hashable(self.model.config.to_diff_dict())) + # mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + # mhash.update(to_hashable({"is_tlm": self.is_tlm})) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + def _generate_inputs(self, **kwargs): + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + # seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + # fbs = constants.ONNX_EXPORT_EXAMPLE_FBS + + self.ctx_len = kwargs["ctx_len"] if "ctx_len" in kwargs else self.ctx_len + + ## PREPROCESSING THE MULTI-MODAL INPUTS for Phi-3.5 for now + # TODO: Create a map for the other models to have their own inputs accordingly + images = [] + placeholder = "" + + # Note: if OOM, you might consider reduce number of frames in this example. + for i in range(1, 2): + url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" + images.append(Image.open(requests.get(url, stream=True).raw)) + placeholder += f"<|image_{1}|>\n" + + messages = [ + {"role": "user", "content": placeholder + "Summarize the deck of slides."}, + ] + + prompt = self.processor.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = dict(self.processor(images=images, text=prompt, return_tensors="pt")) + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + inputs["past_key_values"] = [] + for i in range(self.num_layers): + inputs["past_key_values"].append( + ( + torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), + torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), + ) + ) + output_names = [ + "logits", + "pixel_values_RetainedState", + "image_sizes_RetainedState", + *[f"past_{kv}.{i}_RetainedState" for i in range(self.num_layers) for kv in ["key", "value"]], + ] + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + # "pixel_values": {0: "img_batch_size"}, + } + for i in range(self.num_layers): + dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + # Avoid issues due to index out of range + inputs["position_ids"] = torch.full(inputs["position_ids"].shape, self.ctx_len - 1) + + return inputs, dynamic_axes, output_names + + def export( + self, + export_dir: Optional[str] = None, + **kwargs, + ) -> str: + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES + max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES + image_length = constants.ONNX_EXPORT_IMAGE_LENGHT + image_width = constants.ONNX_EXPORT_IMAGE_WIDTH + num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH + + example_inputs = { + "pixel_values": torch.zeros((bs, max_num_images,max_image_tiles,num_channel, image_length, image_width ), dtype=torch.int64), + "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), + "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1 ), dtype=torch.int64) + } + vision_encoder=self.model=VisionEncoder(self.model) + vision_output_names = [] + for i in self.model.cross_attention_layers: + vision_output_names.append(f"past_key.{i}") + vision_output_names.append(f"past_value.{i}") + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": { + 0: "batch_size", + 1: "max_num_images", + 2: "max_image_tiles", + }, + } + + self._export( + example_inputs, + vision_output_names, + vision_dynamic_axes, + ) + + num_hidden_layers = self.model.config.get_text_config().num_hidden_layers + + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + lang_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), + "corss_attention_mask": torch.ones((bs, seq_len, max_num_images,max_image_tiles), dtype=torch.int64), + } + lang_inputs["position_ids"] = torch.where( + lang_inputs.pop("attention_mask") == 1, + torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) + lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + + for i in range(num_hidden_layers): + if i in vision_encoder.cross_attention_layers: + idx = vision_encoder.cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + lang_inputs["past_key_values"].key_cache[i] = vision_outputs[idx][0] + lang_inputs["past_key_values"].value_cache[i] = vision_outputs[idx][1] + else: + lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros( + (1, 8, 1024, 128) + ) + + lang_inputs["position_ids"] = torch.full( + (1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1 + ) + + lang_output_names = list(lang_outputs.keys()) + pkv_idx = lang_output_names.index("past_key_values") + lang_output_names[pkv_idx : pkv_idx + 1] = [ + f"past_{kv}.{i}_RetainedState" + for i in range(num_hidden_layers) + for kv in ["key", "value"] + ] + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "cross_attention_mask": { + 0: "batch_size", + 1: "seq_len", + 2: "max_num_images", + 3: "max_image_tiles", + }, + } + for i in range(num_hidden_layers): + if i in vision_encoder.cross_attention_layers: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + continue + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + self.model=ModelWrapper(self.model) + self.export( + lang_inputs, + lang_output_names, + lang_dynamic_axes + ) + + def _old_export( + self, + export_dir: Optional[str] = None, + **kwargs, + ) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + + ``Optional`` Args: + :export_dir (str, optional): The directory path to store ONNX-graph. + :**kwargs: Keyword arguments for ``_generate_inputs``. If "ctx_len" is passed, it will be used as the context length. Otherwise, it will be set to 1280. + + Returns: + :str: Path of the generated ``ONNX`` graph. + """ + + + example_inputs, dynamic_axes, output_names = self._generate_inputs(**kwargs) + # breakpoint() + return self._export( + example_inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + ) + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + prefill_seq_len: int = 1024, + ctx_len: int = 1280, + batch_size: int = 1, + full_batch_size: Optional[int] = None, + kv_cache_batch_size: Optional[int] = None, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + num_speculative_tokens: Optional[int] = None, + enable_qnn: bool = False, + qnn_config: Optional[str] = None, + **compiler_options, + ) -> str: + """ + This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package. + If the model has not been exported yet, this method will handle the export process. + You can pass any other arguments that the `qaic-exec` takes as extra kwargs. + + ``Optional`` Args: + :onnx_path (str, optional): Path to pre-exported onnx model. + :compile_dir (str, optional): Path for saving the qpc generated. + :num_cores (int): Number of cores used to compile the model. + :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. + :batch_size (int, optional): Batch size. ``Defaults to 1``. + :prefill_seq_len (int, optional): The length of the Prefill prompt should be less that ``prefill_seq_len``. ``Defaults to 32``. + :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 128``. + :full_batch_size (int, optional): Continuous batching batch size. + :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. + :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``. + :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. + :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``. + :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. + :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` + :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.`` + + Returns: + :str: Path of the compiled ``qpc`` package. + """ + # if self.is_tlm: + # # assert num_speculative_tokens cfg is acceptable if defined + # if num_speculative_tokens is None: + # raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") + # if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2: + # ValueError( + # f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" + # ) + # num_logits_to_keep = num_speculative_tokens + 1 + # if prefill_seq_len < num_logits_to_keep: + # raise ValueError( + # f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" + # ) + + # if self.continuous_batching and full_batch_size is None: + # raise TypeError("missing required argument: 'full_batch_size'") + + # if kv_cache_batch_size and not full_batch_size: + # raise ValueError( + # "Prefix caching is enabled only for continuous batching. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call" + # ) + + kv_cache_batch_size = ( + kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size) + ) + # Define prefill specialization + prefill_specialization = { + # Prefill is always run with single BS for continuous batching. + "batch_size": 1 if self.continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + # TODO: should be renamed to kv_cache_batch_size in specialzation too + } + prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... + if self.continuous_batching: + prefill_specialization.update({"full_batch_size": kv_cache_batch_size}) + else: + prefill_specialization.update({"batch_size": kv_cache_batch_size}) + prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ... + specializations = [ + prefill_specialization, + ] + + # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization + if prefill_seq_len != 1 or self.continuous_batching: + decode_specialization = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1, + "ctx_len": ctx_len, + } + if self.continuous_batching: + decode_specialization.update({"full_batch_size": kv_cache_batch_size}) + else: + decode_specialization.update({"batch_size": kv_cache_batch_size}) + decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ... + specializations.append(decode_specialization) + + if enable_qnn: + if compiler_options: + logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only") + + qpc_path = self._qnn_compile( + onnx_path, + compile_dir, + specializations=specializations, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + mdp_ts_num_devices=num_devices, + num_cores=num_cores, + mxfp6_matmul=mxfp6_matmul, + mxint8_kv_cache=mxint8_kv_cache, + qnn_config=qnn_config, + ) + else: + # Custom IO + custom_io = {} + kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" + custom_io["pixel_values"] = kv_cache_dtype + custom_io["pixel_values_RetainedState"] = kv_cache_dtype + for suffix in ["", "_RetainedState"]: + for i in range(self.num_layers): + for kv in ["key", "value"]: + custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + + breakpoint() + qpc_path = self._compile( + onnx_path, + compile_dir, + compile_only=True, + retained_state=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + custom_io=custom_io, + mdp_ts_num_devices=num_devices, + num_speculative_tokens=num_speculative_tokens, + aic_num_cores=num_cores, + **compiler_options, + ) + return qpc_path + + def generate( + self, + inputs: torch.Tensor, + streamer: Optional[TextStreamer], + device_ids: List[int] = None, + runtime_ai100: bool = True, + ) -> Union[torch.Tensor, np.ndarray]: + """ + This method generates output by executing PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. + ``Mandatory`` Args: + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + ``optional`` Args: + :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model + :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime. + Returns: + :dict: Output from the ``AI_100`` or ``PyTorch`` runtime. + """ + # AI_100 runtime + if runtime_ai100: + if not isinstance(self.qpc_path, Path): + raise TypeError("Please run compile API first!") + + return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) + # PyTorch runtime + else: + return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) + + # TODO: Add the code based on how we did in single inference script + def cloud_ai_100_vlm_generate( + self, + inputs: torch.Tensor, + device_ids: List[int] = [0], + ) -> np.ndarray: + """ + Generates features with list of prompts using AI 100 runtime. + + ``Mandatory`` Args: + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + ``Optional`` Args: + device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0]. + + Returns: + np.ndarray: A list of dictionaries containing the generated output features. + """ + + if self.qpc_session is None: + self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) + self.batch_size = self.qpc_session.bindings[0].dims[0] + self.seq_len = self.qpc_session.bindings[0].dims[1] + # Skip inputs/outputs + self.qpc_session.skip_buffers( + [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] + + ["pixel_values_RetainedState", "image_sizes_RetainedState"] + ) + + # Read prompt and ctx len from session + # batch_size = max( + # [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] + # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] + # ) + + # prefill_seq_len = max( + # [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] + # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] + # ) + # Prepare input + input_ids_len = inputs["input_ids"].shape[1] + input_ids = np.array( + torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + ) + attention_mask = np.array( + torch.nn.functional.pad( + inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0 + ) + ) + + inputs = dict(input_ids=input_ids, attention_mask=attention_mask) + + outputs = { + "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( + np.float32 + ), + } + self.qpc_session.set_buffers(outputs) + outputs = self.qpc_session.run(inputs) + outputs = outputs["output"][:, :input_ids_len, :] + return outputs + + def pytorch_vlm_generate( + self, + model, + inputs: Union[torch.Tensor, np.ndarray], + streamer: TextStreamer, + ) -> List[torch.Tensor]: + """ + Generates features from a list of text prompts using a PyTorch model. + + ``Mandatory`` Args: + :model: The transformed PyTorch model used for generating features. + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + :streamer (TextStreamer): A TextStreamer object used for streaming the generated text. + + Returns: + torch.Tensor: A list of output features generated by the model for each prompt. + """ + # inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + # inputs["past_key_values"] = [] + # for _ in range(model.config.num_hidden_layers): + # inputs["past_key_values"].append(( + # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len,self.head_dim), + # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len, self.head_dim), + # )) + self.batch_size = inputs["input_ids"].shape[0] + generation_len = self.ctx_len - inputs["input_ids"].shape[1] + generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id) + + outputs = model(**inputs) + + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 + streamer.put(inputs["input_ids"]) + + for _ in range(generation_len): + outputs = model(**inputs) + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] += 1 + streamer.put(inputs["input_ids"]) + generated_ids[:, _] = inputs["input_ids"].squeeze(1) + generated_texts = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(self.batch_size): + print(i, generated_texts[i]) + + return generated_ids + + def _export_two_qpc(): + pass + + def export_vision_model(): + pass + + def export_lang_model(): + pass \ No newline at end of file diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 6b8d00689..c3ad99f85 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -69,11 +69,13 @@ from transformers.models.mllama.modeling_mllama import ( MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextModel, MllamaTextRMSNorm, MllamaTextSelfAttention, + MllamaVisionModel, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel @@ -165,10 +167,12 @@ from QEfficient.transformers.models.mllama.modeling_mllama import ( QEffMllamaCrossAttentionDecoderLayer, QEffMllamaForCausalLM, + QEffMllamaRotaryEmbedding, QEffMllamaSelfAttentionDecoderLayer, QEffMllamaTextCrossAttention, QEffMllamaTextModel, QEffMllamaTextSelfAttention, + QEffMllamaVisionModel, ) from QEfficient.transformers.models.mpt.modeling_mpt import ( QEffMptAttention, @@ -256,10 +260,12 @@ class KVCacheTransform(ModuleMappingTransform): # mllama MllamaForCausalLM: QEffMllamaForCausalLM, MllamaTextModel: QEffMllamaTextModel, + MllamaVisionModel: QEffMllamaVisionModel, MllamaTextSelfAttention: QEffMllamaTextSelfAttention, MllamaTextCrossAttention: QEffMllamaTextCrossAttention, MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer, + MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding, # Mistral MistralAttention: QEffMistralAttention, MistralDecoderLayer: QEffMistralDecoderLayer, @@ -343,4 +349,4 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: f"model class {model_class} does not yet support returning multiple logits to keep." ) - return model, transformed + return model, transformed \ No newline at end of file diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index ab861a788..004d2406c 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -49,6 +49,11 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_MAX_NUM_IMAGES =1 +ONNX_EXPORT_MAX_IMAGE_TILES = 4 +ONNX_EXPORT_IMAGE_WIDTH = 560 +ONNX_EXPORT_IMAGE_LENGHT = 560 +ONNX_EXPORT_IMAGE_DEPTH =3 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] From 61788fb30b80cd405f6588d13fc5d7a9441bb11f Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 28 Jan 2025 14:31:04 +0000 Subject: [PATCH 2/6] Working-1 Signed-off-by: Amit Raj --- .../transformers/models/modeling_auto.py | 87 ++++++++++++++++--- 1 file changed, 76 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cef448917..1f880f1f5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -38,8 +38,7 @@ from QEfficient.transformers.models.mllama.modeling_mllama import VisionEncoder, ModelWrapper -from single_qpc.qeff_classes import QEffDynamicCache - +from QEfficient.transformers.cache_utils import QEffDynamicCache logger = logging.getLogger(__file__) @@ -811,6 +810,7 @@ def export( "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1 ), dtype=torch.int64) } + model=self.model vision_encoder=self.model=VisionEncoder(self.model) vision_output_names = [] for i in self.model.cross_attention_layers: @@ -826,12 +826,12 @@ def export( }, } - self._export( + self.vision_onnx_path=self._export( example_inputs, vision_output_names, vision_dynamic_axes, ) - + self.model=model num_hidden_layers = self.model.config.get_text_config().num_hidden_layers seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN @@ -853,19 +853,18 @@ def export( if i in vision_encoder.cross_attention_layers: idx = vision_encoder.cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - lang_inputs["past_key_values"].key_cache[i] = vision_outputs[idx][0] - lang_inputs["past_key_values"].value_cache[i] = vision_outputs[idx][1] + lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) else: lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros( - (1, 8, 1024, 128) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128) ) lang_inputs["position_ids"] = torch.full( (1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1 ) - lang_output_names = list(lang_outputs.keys()) + lang_output_names = ['logits', 'past_key_values'] pkv_idx = lang_output_names.index("past_key_values") lang_output_names[pkv_idx : pkv_idx + 1] = [ f"past_{kv}.{i}_RetainedState" @@ -890,13 +889,79 @@ def export( lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + lang_inputs["input_ids"] = torch.tensor([[374]]) self.model=ModelWrapper(self.model) - self.export( + self.lang_onnx_path=self._export( lang_inputs, lang_output_names, lang_dynamic_axes ) + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: int = 32, + batch_size: int = 1, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + + vision_specializations=[ + { + "batch_size": "1", + "max_num_images": "1", + "max_image_tiles": "4" + } + ] + vision_qpc_path= self._compile( + self.vision_onnx_path, + compile_dir, + compile_only=True, + specializations=vision_specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + **compiler_options, + ) + + lang_specializations=[ + { + "batch_size": "1", + "seq_len": "32", + "ctx_len": "1024", + "max_num_images": "1", + "max_image_tiles": "4" + }, + { + "batch_size": "1", + "seq_len": "1", + "ctx_len": "1024", + "max_num_images": "1", + "max_image_tiles": "4" + } + ] + + lang_qpc_path= self._compile( + self.lang_onnx_path, + compile_dir, + compile_only=True, + specializations=lang_specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + **compiler_options, + ) + + return vision_qpc_path, lang_qpc_path + + + def _old_export( self, export_dir: Optional[str] = None, @@ -923,7 +988,7 @@ def _old_export( export_dir=export_dir, ) - def compile( + def old_compile( self, onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, From 6250ba116fd42e5f7b67d28c7d3cfe96edd422cd Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 28 Jan 2025 15:27:23 +0000 Subject: [PATCH 3/6] Working-2 Signed-off-by: Amit Raj --- .../transformers/models/modeling_auto.py | 122 +++++++++++++++++- 1 file changed, 121 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 1f880f1f5..43ebda8fa 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -25,6 +25,7 @@ PreTrainedTokenizerFast, TextStreamer, ) +from transformers import AutoProcessor import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel @@ -718,8 +719,9 @@ def from_pretrained( if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) + self.processor= AutoProcessor.from_pretrained(pretrained_model_name_or_path) self.continuous_batching = continuous_batching self.two_qpcmethod = two_qpc_method @@ -798,6 +800,124 @@ def export( export_dir: Optional[str] = None, **kwargs, ) -> str: + + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + messages = [ + {"role": "user", "content": [ + {"type": "image"}, + {"type": "text", "text": "If I had to write a haiku for this one, it would be: "} + ]} + ] + input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True) + + split_inputs = self.processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=32, + ) + + lang_inputs = {} + vision_input = {} + + for k, v in split_inputs.items(): + if k in ["input_ids", "attention_mask", "cross_attention_mask"]: + lang_inputs[k] = v + else: + vision_input[k] = v + + self.vision_export_path= self.export_vision(vision_input) + self.lang_export_path = self.export_lang(lang_inputs) + + def export_vision(self, vision_input): + model=self.model + self.vision_encoder=self.model=VisionEncoder(self.model) + + vision_output_names = [] + for i in self.model.cross_attention_layers: + vision_output_names.append(f"past_key.{i}") + vision_output_names.append(f"past_value.{i}") + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles",}, + } + + self.vision_onnx_path=self._export( + vision_input, + vision_output_names, + vision_dynamic_axes, + ) + + self.model=model + return self.vision_export_path + + def export_lang(self, lang_inputs): + num_hidden_layers = self.model.config.get_text_config().num_hidden_layers + + lang_inputs["position_ids"] = torch.where( + lang_inputs.pop("attention_mask") == 1, + torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + + lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) + lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + + for i in range(num_hidden_layers): + if i in self.vision_encoder.cross_attention_layers: + idx = self.vision_encoder.cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) + else: + lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128) + ) + lang_inputs["position_ids"] = torch.full( + (1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1 + ) + lang_output_names = ['logits', 'past_key_values'] + pkv_idx = lang_output_names.index("past_key_values") + lang_output_names[pkv_idx : pkv_idx + 1] = [ + f"past_{kv}.{i}_RetainedState" + for i in range(num_hidden_layers) + for kv in ["key", "value"] + ] + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "cross_attention_mask": { 0: "batch_size", 1: "seq_len", 2: "max_num_images", 3: "max_image_tiles", + }, + } + + for i in range(num_hidden_layers): + if i in self.vision_encoder.cross_attention_layers: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + continue + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + lang_inputs["input_ids"] = torch.tensor([[374]]) + self.model = ModelWrapper(self.model) + self.lang_onnx_path=self._export( + lang_inputs, + lang_output_names, + lang_dynamic_axes + ) + return self.lang_onnx_path + + def working_old_export( + self, + export_dir: Optional[str] = None, + **kwargs, + ) -> str: bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES From 3c5c612b38e5b4fbbcf43396f38761677b846165 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Tue, 28 Jan 2025 17:54:23 +0000 Subject: [PATCH 4/6] Integrated generation part Signed-off-by: Rishin Raj --- .../transformers/models/modeling_auto.py | 619 +++++++----------- pyproject.toml | 3 +- 2 files changed, 234 insertions(+), 388 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 43ebda8fa..0899fb31f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -7,9 +7,11 @@ import hashlib import logging +import sys import warnings from pathlib import Path -from typing import List, Optional, Tuple, Union +from time import perf_counter +from typing import List, Optional, Union import numpy as np import requests @@ -25,21 +27,19 @@ PreTrainedTokenizerFast, TextStreamer, ) -from transformers import AutoProcessor import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.cache import to_hashable - -from QEfficient.transformers.models.mllama.modeling_mllama import VisionEncoder, ModelWrapper -from QEfficient.transformers.cache_utils import QEffDynamicCache logger = logging.getLogger(__file__) @@ -73,7 +73,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = Fals kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - print(model) return cls(model, is_tlm=is_tlm) @property @@ -328,7 +327,7 @@ def compile( "batch_size": 1 if self.continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - # TODO: should be renamed to kv_cache_batch_size in specialzation too + # TODO: should be renamed to kv_cache_batch_size in specialization too } prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... if self.continuous_batching: @@ -691,12 +690,7 @@ def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray return model(**inputs) - - - class QEFFAutoModelForImageTextToText(QEFFTransformersBase): - - _hf_auto_class = AutoModelForImageTextToText _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] @@ -706,24 +700,29 @@ def __init__( model: nn.Module, **kwargs, ): - if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - + super().__init__(model) self.model.config.use_cache = True @classmethod def from_pretrained( - cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, two_qpc_method: bool = False, *args, **kwargs): - + cls, + pretrained_model_name_or_path, + continuous_batching: bool = False, + is_tlm: bool = False, + kv_offload: bool = False, + *args, + **kwargs, + ): if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) - self.processor= AutoProcessor.from_pretrained(pretrained_model_name_or_path) + self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) self.continuous_batching = continuous_batching - self.two_qpcmethod = two_qpc_method + self.kv_offload = kv_offload return self @@ -795,30 +794,30 @@ def _generate_inputs(self, **kwargs): return inputs, dynamic_axes, output_names - def export( + def _generate_inputs_mllama( self, - export_dir: Optional[str] = None, - **kwargs, - ) -> str: - + ): url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" image = Image.open(requests.get(url, stream=True).raw) messages = [ - {"role": "user", "content": [ - {"type": "image"}, - {"type": "text", "text": "If I had to write a haiku for this one, it would be: "} - ]} + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}, + ], + } ] input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True) - + split_inputs = self.processor( - text=input_text, - images=image, - return_tensors="pt", - add_special_tokens=False, - padding="max_length", - max_length=32, + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=32, ) lang_inputs = {} @@ -830,108 +829,23 @@ def export( else: vision_input[k] = v - self.vision_export_path= self.export_vision(vision_input) - self.lang_export_path = self.export_lang(lang_inputs) - - def export_vision(self, vision_input): - model=self.model - self.vision_encoder=self.model=VisionEncoder(self.model) - - vision_output_names = [] - for i in self.model.cross_attention_layers: - vision_output_names.append(f"past_key.{i}") - vision_output_names.append(f"past_value.{i}") - vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, - "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, - "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles",}, - } - - self.vision_onnx_path=self._export( - vision_input, - vision_output_names, - vision_dynamic_axes, - ) - - self.model=model - return self.vision_export_path - - def export_lang(self, lang_inputs): - num_hidden_layers = self.model.config.get_text_config().num_hidden_layers - - lang_inputs["position_ids"] = torch.where( - lang_inputs.pop("attention_mask") == 1, - torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), - -1, - ) - - lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) - lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers - lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + return lang_inputs, vision_input - for i in range(num_hidden_layers): - if i in self.vision_encoder.cross_attention_layers: - idx = self.vision_encoder.cross_attention_layers.index(i) - assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" - lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) - else: - lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128) - ) - lang_inputs["position_ids"] = torch.full( - (1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1 - ) - lang_output_names = ['logits', 'past_key_values'] - pkv_idx = lang_output_names.index("past_key_values") - lang_output_names[pkv_idx : pkv_idx + 1] = [ - f"past_{kv}.{i}_RetainedState" - for i in range(num_hidden_layers) - for kv in ["key", "value"] - ] - lang_dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "position_ids": {0: "batch_size", 1: "seq_len"}, - "cross_attention_mask": { 0: "batch_size", 1: "seq_len", 2: "max_num_images", 3: "max_image_tiles", - }, - } - - for i in range(num_hidden_layers): - if i in self.vision_encoder.cross_attention_layers: - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} - continue - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() - lang_inputs["input_ids"] = torch.tensor([[374]]) - self.model = ModelWrapper(self.model) - self.lang_onnx_path=self._export( - lang_inputs, - lang_output_names, - lang_dynamic_axes - ) - return self.lang_onnx_path - - def working_old_export( + def export( self, export_dir: Optional[str] = None, **kwargs, ) -> str: - bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES - max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES - image_length = constants.ONNX_EXPORT_IMAGE_LENGHT - image_width = constants.ONNX_EXPORT_IMAGE_WIDTH - num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH + if self.kv_offload: + lang_inputs, vision_input = self._generate_inputs_mllama() + + self.vision_export_path = self.export_vision(vision_input, export_dir) + self.lang_export_path = self.export_lang(lang_inputs, export_dir) + + def export_vision(self, vision_input, export_dir): + model = self.model + self.vision_encoder = self.model = VisionEncoder(self.model) - example_inputs = { - "pixel_values": torch.zeros((bs, max_num_images,max_image_tiles,num_channel, image_length, image_width ), dtype=torch.int64), - "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64), - "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1 ), dtype=torch.int64) - } - model=self.model - vision_encoder=self.model=VisionEncoder(self.model) vision_output_names = [] for i in self.model.cross_attention_layers: vision_output_names.append(f"past_key.{i}") @@ -946,51 +860,47 @@ def working_old_export( }, } - self.vision_onnx_path=self._export( - example_inputs, + self.vision_onnx_path = self._export( + vision_input, vision_output_names, vision_dynamic_axes, + export_dir=export_dir, ) - self.model=model + + self.model = model + return self.vision_onnx_path + + def export_lang(self, lang_inputs, export_dir): num_hidden_layers = self.model.config.get_text_config().num_hidden_layers - seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - lang_inputs = { - "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), - "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64), - "corss_attention_mask": torch.ones((bs, seq_len, max_num_images,max_image_tiles), dtype=torch.int64), - } lang_inputs["position_ids"] = torch.where( lang_inputs.pop("attention_mask") == 1, torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), -1, ) + lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers - + for i in range(num_hidden_layers): - if i in vision_encoder.cross_attention_layers: - idx = vision_encoder.cross_attention_layers.index(i) + if i in self.vision_encoder.cross_attention_layers: + idx = self.vision_encoder.cross_attention_layers.index(i) assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) else: lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) - lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128) - ) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128)) - lang_inputs["position_ids"] = torch.full( - (1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1 - ) - - lang_output_names = ['logits', 'past_key_values'] + lang_inputs["position_ids"] = torch.full((1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1) + lang_output_names = ["logits", "past_key_values"] pkv_idx = lang_output_names.index("past_key_values") + lang_output_names[pkv_idx : pkv_idx + 1] = [ - f"past_{kv}.{i}_RetainedState" - for i in range(num_hidden_layers) - for kv in ["key", "value"] + f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"] ] + lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, @@ -1001,264 +911,73 @@ def working_old_export( 3: "max_image_tiles", }, } + for i in range(num_hidden_layers): - if i in vision_encoder.cross_attention_layers: + if i in self.vision_encoder.cross_attention_layers: lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} continue lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["input_ids"] = torch.tensor([[374]]) - self.model=ModelWrapper(self.model) - self.lang_onnx_path=self._export( - lang_inputs, - lang_output_names, - lang_dynamic_axes - ) + lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:] + self.model = ModelWrapper(self.model) + self.lang_onnx_path = self._export(lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir) + return self.lang_onnx_path def compile( self, - onnx_path: Optional[str] = None, + vision_onnx_path: Optional[str] = None, + lang_onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, - *, - seq_len: int = 32, - batch_size: int = 1, - num_devices: int = 1, - num_cores: int = 16, # FIXME: Make this mandatory arg - mxfp6_matmul: bool = False, - **compiler_options, - ) -> str: - - vision_specializations=[ - { - "batch_size": "1", - "max_num_images": "1", - "max_image_tiles": "4" - } - ] - vision_qpc_path= self._compile( - self.vision_onnx_path, - compile_dir, - compile_only=True, - specializations=vision_specializations, - convert_to_fp16=True, - mxfp6_matmul=mxfp6_matmul, - mdp_ts_num_devices=num_devices, - aic_num_cores=num_cores, - **compiler_options, - ) - - lang_specializations=[ - { - "batch_size": "1", - "seq_len": "32", - "ctx_len": "1024", - "max_num_images": "1", - "max_image_tiles": "4" - }, - { - "batch_size": "1", - "seq_len": "1", - "ctx_len": "1024", - "max_num_images": "1", - "max_image_tiles": "4" - } - ] - - lang_qpc_path= self._compile( - self.lang_onnx_path, - compile_dir, - compile_only=True, - specializations=lang_specializations, - convert_to_fp16=True, - mxfp6_matmul=mxfp6_matmul, - mdp_ts_num_devices=num_devices, - aic_num_cores=num_cores, - **compiler_options, - ) - - return vision_qpc_path, lang_qpc_path - - - - def _old_export( - self, - export_dir: Optional[str] = None, - **kwargs, - ) -> str: - """ - Exports the model to ``ONNX`` format using ``torch.onnx.export``. - - ``Optional`` Args: - :export_dir (str, optional): The directory path to store ONNX-graph. - :**kwargs: Keyword arguments for ``_generate_inputs``. If "ctx_len" is passed, it will be used as the context length. Otherwise, it will be set to 1280. - - Returns: - :str: Path of the generated ``ONNX`` graph. - """ - - - example_inputs, dynamic_axes, output_names = self._generate_inputs(**kwargs) - # breakpoint() - return self._export( - example_inputs, - output_names, - dynamic_axes, - export_dir=export_dir, - ) - - def old_compile( - self, - onnx_path: Optional[str] = None, - compile_dir: Optional[str] = None, - *, - prefill_seq_len: int = 1024, - ctx_len: int = 1280, + prefill_seq_len: int = 32, + ctx_len: int = 128, batch_size: int = 1, - full_batch_size: Optional[int] = None, - kv_cache_batch_size: Optional[int] = None, num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, - mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, - enable_qnn: bool = False, - qnn_config: Optional[str] = None, **compiler_options, ) -> str: - """ - This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package. - If the model has not been exported yet, this method will handle the export process. - You can pass any other arguments that the `qaic-exec` takes as extra kwargs. - - ``Optional`` Args: - :onnx_path (str, optional): Path to pre-exported onnx model. - :compile_dir (str, optional): Path for saving the qpc generated. - :num_cores (int): Number of cores used to compile the model. - :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. - :batch_size (int, optional): Batch size. ``Defaults to 1``. - :prefill_seq_len (int, optional): The length of the Prefill prompt should be less that ``prefill_seq_len``. ``Defaults to 32``. - :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 128``. - :full_batch_size (int, optional): Continuous batching batch size. - :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. - :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``. - :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. - :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``. - :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``. - :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` - :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.`` - - Returns: - :str: Path of the compiled ``qpc`` package. - """ - # if self.is_tlm: - # # assert num_speculative_tokens cfg is acceptable if defined - # if num_speculative_tokens is None: - # raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") - # if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2: - # ValueError( - # f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" - # ) - # num_logits_to_keep = num_speculative_tokens + 1 - # if prefill_seq_len < num_logits_to_keep: - # raise ValueError( - # f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" - # ) - - # if self.continuous_batching and full_batch_size is None: - # raise TypeError("missing required argument: 'full_batch_size'") - - # if kv_cache_batch_size and not full_batch_size: - # raise ValueError( - # "Prefix caching is enabled only for continuous batching. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call" - # ) - - kv_cache_batch_size = ( - kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size) - ) - # Define prefill specialization - prefill_specialization = { - # Prefill is always run with single BS for continuous batching. - "batch_size": 1 if self.continuous_batching else batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - # TODO: should be renamed to kv_cache_batch_size in specialzation too - } - prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... - if self.continuous_batching: - prefill_specialization.update({"full_batch_size": kv_cache_batch_size}) - else: - prefill_specialization.update({"batch_size": kv_cache_batch_size}) - prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ... - specializations = [ - prefill_specialization, - ] - - # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization - if prefill_seq_len != 1 or self.continuous_batching: - decode_specialization = { - "batch_size": full_batch_size if self.continuous_batching else batch_size, - "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1, - "ctx_len": ctx_len, - } - if self.continuous_batching: - decode_specialization.update({"full_batch_size": kv_cache_batch_size}) - else: - decode_specialization.update({"batch_size": kv_cache_batch_size}) - decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ... - specializations.append(decode_specialization) - - if enable_qnn: - if compiler_options: - logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only") + if self.kv_offload: - qpc_path = self._qnn_compile( - onnx_path, + vision_specializations = [{"batch_size": "1", "max_num_images": "1", "max_image_tiles": "4"}] + self.vision_qpc_path = self._compile( + vision_onnx_path, compile_dir, - specializations=specializations, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - full_batch_size=full_batch_size, - mdp_ts_num_devices=num_devices, - num_cores=num_cores, + compile_only=True, + specializations=vision_specializations, + convert_to_fp16=True, mxfp6_matmul=mxfp6_matmul, - mxint8_kv_cache=mxint8_kv_cache, - qnn_config=qnn_config, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + **compiler_options, ) - else: - # Custom IO - custom_io = {} - kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" - custom_io["pixel_values"] = kv_cache_dtype - custom_io["pixel_values_RetainedState"] = kv_cache_dtype - for suffix in ["", "_RetainedState"]: - for i in range(self.num_layers): - for kv in ["key", "value"]: - custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - breakpoint() - qpc_path = self._compile( - onnx_path, + lang_specializations = [ + {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "max_num_images": "1", "max_image_tiles": "4"}, + {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "max_num_images": "1", "max_image_tiles": "4"}, + ] + + self.lang_qpc_path = self._compile( + lang_onnx_path, compile_dir, compile_only=True, - retained_state=True, - specializations=specializations, + specializations=lang_specializations, convert_to_fp16=True, mxfp6_matmul=mxfp6_matmul, - custom_io=custom_io, mdp_ts_num_devices=num_devices, - num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, **compiler_options, ) - return qpc_path + + return self.vision_qpc_path, self.lang_qpc_path def generate( self, inputs: torch.Tensor, - streamer: Optional[TextStreamer], + streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, ) -> Union[torch.Tensor, np.ndarray]: @@ -1274,10 +993,12 @@ def generate( """ # AI_100 runtime if runtime_ai100: - if not isinstance(self.qpc_path, Path): - raise TypeError("Please run compile API first!") - - return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) + # if not isinstance(self.qpc_path, Path): + # raise TypeError("Please run compile API first!") + if self.kv_offload: + self.encoder_decoder_generate(inputs, streamer, device_ids) + else: + return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) # PyTorch runtime else: return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) @@ -1389,11 +1110,135 @@ def pytorch_vlm_generate( return generated_ids - def _export_two_qpc(): - pass - - def export_vision_model(): - pass + def encoder_decoder_generate( + self, + inputs: List[str] = None, + streamer: Optional[TextStreamer] = None, + device_id: List[int] = None, + generation_len: int = None, + ctx_len: int = 512, + stream: bool = True, + **kwargs, + ): + # self.lang_qpc_path = Path( + # "/home/rishinr/vision/vision_infra/llama-vision/qpc/Llama-3.2-11B-Vision-Instruct-language" + # ) + # self.vision_qpc_path = Path( + # "/home/rishinr/vision/vision_infra/llama-vision/qpc/Llama-3.2-11B-Vision-Instruct-vision" + # ) + + lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) + vision_session = QAICInferenceSession(self.vision_qpc_path, device_id) + + tokenizer = self.processor.tokenizer + + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + if streamer is None: + streamer = TextStreamer(tokenizer) + + # Skip inputs/outputs + lang_session.skip_buffers( + [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] + ) + + input_len = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + if generation_len is None: + generation_len = ctx_len - input_len.max() + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), tokenizer.pad_token_id) + + # Prepare inputs for prefill + start = perf_counter() + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_outputs = vision_session.run(dict(vision_inputs)) + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + lang_inputs = dict(lang_inputs) + + vision_session.deactivate() + lang_session.activate() + + lang_session.set_buffers(vision_outputs) + + # Run prefill + for i in range(num_chunks): + chunk_inputs = lang_inputs.copy() + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ + :, i * prefill_seq_len : (i + 1) * prefill_seq_len + ] + outputs = lang_session.run(chunk_inputs) + + # Skip inputs/outputs again + lang_session.skip_buffers( + [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] + ) - def export_lang_model(): - pass \ No newline at end of file + # Get first token + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] = input_len + lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:, :, :] + generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1) + finished_sequences = lang_inputs["input_ids"] == tokenizer.eos_token_id + if stream: + streamer.put(lang_inputs["input_ids"][0]) + + # Decode loop + loop_start = perf_counter() + for num_token in range(1, generation_len): + outputs = lang_session.run(lang_inputs) + + # Prepare inputs for next iteration + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] += 1 + generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1) + finished_sequences |= lang_inputs["input_ids"] == tokenizer.eos_token_id + + if stream: + streamer.put(lang_inputs["input_ids"][0]) + if finished_sequences.all(): + break + + end = perf_counter() + if stream: + streamer.end() + generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(1 if stream else 0, batch_size): + print(i, generated_texts[i]) + + prefill_perf = 1 / (loop_start - start) + decode_perf = (num_token - 1) / (end - loop_start) + total_perf = num_token / (end - start) + + print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) + print("E2ET:", round(end - start, 2), "s", file=sys.stderr) + print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) + print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) + print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) + if batch_size > 1: + print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) diff --git a/pyproject.toml b/pyproject.toml index 9867181ca..e04bba103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.11" dependencies = [ - "transformers==4.45.2", + "transformers==4.46.0", "huggingface-hub==0.27.0", "peft==0.13.2", "datasets==2.20.0", @@ -32,6 +32,7 @@ dependencies = [ "numpy==1.26.4", "protobuf==3.20.2", "onnxscript==0.1.0.dev20240327", + "pillow===11.1.0", "sympy", "tensorboard", "fire", From ce5259f4c0ab64851b3c0898100c61259753fb3c Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Wed, 29 Jan 2025 11:07:25 +0000 Subject: [PATCH 5/6] Working-3 Signed-off-by: Amit Raj --- QEfficient/base/__init__.py | 6 +++- .../transformers/models/modeling_auto.py | 31 ++++++++++++++++--- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/QEfficient/base/__init__.py b/QEfficient/base/__init__.py index 4ae6dd9c0..4344cac53 100644 --- a/QEfficient/base/__init__.py +++ b/QEfficient/base/__init__.py @@ -6,4 +6,8 @@ # ----------------------------------------------------------------------------- from QEfficient.base.common import QEFFCommonLoader # noqa: F401 -from QEfficient.transformers.models.modeling_auto import QEFFAutoModel, QEFFAutoModelForCausalLM,QEFFAutoModelForImageTextToText # noqa: F401 +from QEfficient.transformers.models.modeling_auto import ( # noqa: F401 + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForImageTextToText, +) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0899fb31f..0c16f5cc4 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -706,6 +706,7 @@ def __init__( super().__init__(model) self.model.config.use_cache = True + @classmethod def from_pretrained( cls, @@ -836,6 +837,7 @@ def export( export_dir: Optional[str] = None, **kwargs, ) -> str: + self.kv_offload=True if self.kv_offload: lang_inputs, vision_input = self._generate_inputs_mllama() @@ -868,10 +870,11 @@ def export_vision(self, vision_input, export_dir): ) self.model = model + self.vision_output_names=vision_output_names return self.vision_onnx_path def export_lang(self, lang_inputs, export_dir): - num_hidden_layers = self.model.config.get_text_config().num_hidden_layers + self.num_layers=num_hidden_layers = self.model.config.get_text_config().num_hidden_layers lang_inputs["position_ids"] = torch.where( lang_inputs.pop("attention_mask") == 1, @@ -923,8 +926,11 @@ def export_lang(self, lang_inputs, export_dir): lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["input_ids"] = torch.tensor([[374]]) lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:] + model=self.model self.model = ModelWrapper(self.model) self.lang_onnx_path = self._export(lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir) + self.model=model + self.lang_output_names=lang_output_names return self.lang_onnx_path def compile( @@ -940,11 +946,20 @@ def compile( mxfp6_matmul: bool = False, **compiler_options, ) -> str: + self.kv_offload = True if self.kv_offload: vision_specializations = [{"batch_size": "1", "max_num_images": "1", "max_image_tiles": "4"}] + + custom_io = {} + kv_cache_dtype ="float16" + for output_name in self.vision_output_names: + custom_io[output_name] = kv_cache_dtype + + model=self.model + self.model=self.vision_encoder self.vision_qpc_path = self._compile( - vision_onnx_path, + self.vision_onnx_path, compile_dir, compile_only=True, specializations=vision_specializations, @@ -952,16 +967,21 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + custom_io=custom_io, **compiler_options, ) - + self.model=ModelWrapper(model) lang_specializations = [ {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "max_num_images": "1", "max_image_tiles": "4"}, {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "max_num_images": "1", "max_image_tiles": "4"}, ] + custom_io_lang={} + for output_name in self.lang_output_names: + custom_io_lang[output_name]=kv_cache_dtype + self.lang_qpc_path = self._compile( - lang_onnx_path, + self.lang_onnx_path, compile_dir, compile_only=True, specializations=lang_specializations, @@ -969,9 +989,10 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + custom_io= custom_io_lang, **compiler_options, ) - + self.model=model return self.vision_qpc_path, self.lang_qpc_path def generate( From 3a066178166c5d9f6ce2bfaa5be2de711ea0d9e1 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Wed, 29 Jan 2025 16:49:31 +0000 Subject: [PATCH 6/6] Custom IO fix Signed-off-by: Rishin Raj --- .../models/mllama/modeling_mllama.py | 6 +- .../transformers/models/modeling_auto.py | 99 +++++++++++++------ 2 files changed, 75 insertions(+), 30 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index b74519636..90be64096 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -27,18 +27,18 @@ MllamaConfig, MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaForConditionalGeneration, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, - MllamaForConditionalGeneration, MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, MllamaVisionModel, - logger, repeat_kv, rotate_half, ) + from QEfficient.transformers.cache_utils import QEffDynamicCache @@ -1023,6 +1023,7 @@ def __init__(self, mllama: MllamaForConditionalGeneration): self.cross_attention_layers = ( self.mllama.config.get_text_config().cross_attention_layers ) + self.config = self.mllama.config.get_text_config() def forward( self, @@ -1061,6 +1062,7 @@ def __init__(self, mllama): super().__init__() self.mllama = mllama self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers + self.config = self.mllama.config.get_text_config() def forward( self, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0c16f5cc4..2e714840d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -32,6 +32,7 @@ from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.text_generation_inference import get_compilation_dims from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform @@ -706,7 +707,6 @@ def __init__( super().__init__(model) self.model.config.use_cache = True - @classmethod def from_pretrained( cls, @@ -724,6 +724,7 @@ def from_pretrained( self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) self.continuous_batching = continuous_batching self.kv_offload = kv_offload + self.is_tlm = is_tlm return self @@ -731,9 +732,9 @@ def from_pretrained( def model_hash(self) -> str: # Compute the hash with: model_config, continuous_batching, transforms mhash = hashlib.sha256() - # mhash.update(to_hashable(self.model.config.to_diff_dict())) - # mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) - # mhash.update(to_hashable({"is_tlm": self.is_tlm})) + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) mhash.update(to_hashable(self._transform_names())) mhash = mhash.hexdigest()[:16] return mhash @@ -837,11 +838,13 @@ def export( export_dir: Optional[str] = None, **kwargs, ) -> str: - self.kv_offload=True + self.kv_offload = True if self.kv_offload: + print("generating input") lang_inputs, vision_input = self._generate_inputs_mllama() - + print("generating vision model") self.vision_export_path = self.export_vision(vision_input, export_dir) + print("generating lang model") self.lang_export_path = self.export_lang(lang_inputs, export_dir) def export_vision(self, vision_input, export_dir): @@ -870,11 +873,11 @@ def export_vision(self, vision_input, export_dir): ) self.model = model - self.vision_output_names=vision_output_names + self.vision_output_names = vision_output_names return self.vision_onnx_path def export_lang(self, lang_inputs, export_dir): - self.num_layers=num_hidden_layers = self.model.config.get_text_config().num_hidden_layers + self.num_layers = num_hidden_layers = self.model.config.get_text_config().num_hidden_layers lang_inputs["position_ids"] = torch.where( lang_inputs.pop("attention_mask") == 1, @@ -926,11 +929,12 @@ def export_lang(self, lang_inputs, export_dir): lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["input_ids"] = torch.tensor([[374]]) lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:] - model=self.model - self.model = ModelWrapper(self.model) + self.lang_output_names = lang_output_names + model = self.model + self.model = ModelWrapper(model) + self.lang_onnx_path = self._export(lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir) - self.model=model - self.lang_output_names=lang_output_names + self.model = model return self.lang_onnx_path def compile( @@ -948,16 +952,19 @@ def compile( ) -> str: self.kv_offload = True if self.kv_offload: - + model = self.model + self.model = VisionEncoder(model) vision_specializations = [{"batch_size": "1", "max_num_images": "1", "max_image_tiles": "4"}] custom_io = {} - kv_cache_dtype ="float16" + kv_cache_dtype = "float16" + custom_io["pixel_values"] = kv_cache_dtype for output_name in self.vision_output_names: custom_io[output_name] = kv_cache_dtype - model=self.model - self.model=self.vision_encoder + model = self.model + self.model = self.vision_encoder + print("compiling vision model") self.vision_qpc_path = self._compile( self.vision_onnx_path, compile_dir, @@ -970,16 +977,45 @@ def compile( custom_io=custom_io, **compiler_options, ) - self.model=ModelWrapper(model) + self.model = ModelWrapper(model) + lang_specializations = [ - {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "max_num_images": "1", "max_image_tiles": "4"}, - {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "max_num_images": "1", "max_image_tiles": "4"}, + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, ] - custom_io_lang={} + custom_io_lang = {} + # Inputs for output_name in self.lang_output_names: - custom_io_lang[output_name]=kv_cache_dtype - + if output_name.startswith("past_"): + custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype + # outputs + for output_name in self.lang_output_names: + if output_name.startswith("past_"): + custom_io_lang[output_name] = kv_cache_dtype + + # custom_io = {} + # kv_cache_dtype = "float16" + # custom_io["pixel_values"] = kv_cache_dtype + # custom_io["pixel_values_RetainedState"] = kv_cache_dtype + # for suffix in ["", "_RetainedState"]: + # for i in range(self.num_layers): + # for kv in ["key", "value"]: + # custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + + print("generating lang model") self.lang_qpc_path = self._compile( self.lang_onnx_path, compile_dir, @@ -989,10 +1025,10 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, - custom_io= custom_io_lang, + custom_io=custom_io_lang, **compiler_options, ) - self.model=model + self.model = model return self.vision_qpc_path, self.lang_qpc_path def generate( @@ -1017,7 +1053,7 @@ def generate( # if not isinstance(self.qpc_path, Path): # raise TypeError("Please run compile API first!") if self.kv_offload: - self.encoder_decoder_generate(inputs, streamer, device_ids) + self.kv_offload_generate(inputs, streamer, device_ids) else: return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) # PyTorch runtime @@ -1131,13 +1167,12 @@ def pytorch_vlm_generate( return generated_ids - def encoder_decoder_generate( + def kv_offload_generate( self, inputs: List[str] = None, streamer: Optional[TextStreamer] = None, device_id: List[int] = None, generation_len: int = None, - ctx_len: int = 512, stream: bool = True, **kwargs, ): @@ -1147,15 +1182,23 @@ def encoder_decoder_generate( # self.vision_qpc_path = Path( # "/home/rishinr/vision/vision_infra/llama-vision/qpc/Llama-3.2-11B-Vision-Instruct-vision" # ) + # self.lang_qpc_path = Path( + # "/home/rishinr/.cache/qeff_models/mllama_bc/ModelWrapper-e34b1a9bd1cf14cb/qpc-0fd0400e8969c49e/qpc" + # ) + # self.vision_qpc_path = Path( + # "/home/rishinr/.cache/qeff_models/mllama_bc/VisionEncoder-e34b1a9bd1cf14cb/qpc-b4c5b2ba8c79d148/qpc" + # ) lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) vision_session = QAICInferenceSession(self.vision_qpc_path, device_id) + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_qpc_path) + tokenizer = self.processor.tokenizer if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id - + if streamer is None: streamer = TextStreamer(tokenizer)