diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 1bc06ccf4..956ccf316 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -25,7 +25,7 @@ def check_qaic_sdk(): # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" if QAIC_INSTALLED: - from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader + from QEfficient.base import QEFFAutoModel, QEFFAutoModelForCausalLM, QEFFCommonLoader,QEFFAutoModelForImageTextToText from QEfficient.compile.compile_helper import compile from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv @@ -43,6 +43,7 @@ def check_qaic_sdk(): "QEFFAutoModel", "QEFFAutoModelForCausalLM", "QEffAutoPeftModelForCausalLM", + "QEFFAutoModelForImageTextToText", "QEFFCommonLoader", ] diff --git a/QEfficient/base/__init__.py b/QEfficient/base/__init__.py index 86cff11c1..4344cac53 100644 --- a/QEfficient/base/__init__.py +++ b/QEfficient/base/__init__.py @@ -6,4 +6,8 @@ # ----------------------------------------------------------------------------- from QEfficient.base.common import QEFFCommonLoader # noqa: F401 -from QEfficient.transformers.models.modeling_auto import QEFFAutoModel, QEFFAutoModelForCausalLM # noqa: F401 +from QEfficient.transformers.models.modeling_auto import ( # noqa: F401 + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForImageTextToText, +) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index e2f551415..90be64096 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -11,12 +11,14 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( + BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ) @@ -25,93 +27,19 @@ MllamaConfig, MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaForConditionalGeneration, MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, + MllamaVisionModel, logger, repeat_kv, rotate_half, ) -from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask - - -class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): - """ - Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py - The only differences are: - - Add static sin/cos computations. - """ - - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[MllamaConfig] = None, - ): - super(MllamaRotaryEmbedding, self).__init__() # Initialize nn.Module - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. 
All other arguments will be removed in v4.45" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) +from QEfficient.transformers.cache_utils import QEffDynamicCache def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): @@ -144,6 +72,74 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): # Cast back to original dtype return q_embed.to(q.dtype), k_embed.to(k.dtype) +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) + attention_mask = ( + attention_mask + @ attention_mask.transpose(-1, -2) + * torch.tensor(-10000.0, dtype=torch.float32) + ) + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + +def _create_causal_mask( + position_ids, + target_length, + sliding_window: Optional[int] = None, +): + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + """ + if sliding_window is not None: + query_indices = 
position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, -1) + # --- Rolling buffer --- + pos_max = position_ids.max(1, keepdim=True).values + kv_start = (pos_max // target_length) * target_length + kv_indices_high = kv_indices + kv_start + kv_indices_low = torch.where( + kv_indices_high < target_length, kv_indices, kv_indices_high - target_length + ) + kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) + kv_indices = kv_indices.unsqueeze(1) + # ------ + causal_mask = kv_indices > query_indices + attention_mask = causal_mask + + window_indices = query_indices - sliding_window + 1 + window_mask = kv_indices < window_indices + attention_mask = attention_mask | window_mask + attention_mask = attention_mask.unsqueeze(1) + else: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, 1, -1) + attention_mask = kv_indices > query_indices + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): """ @@ -226,7 +222,6 @@ def forward( return attn_output, attn_weights, past_key_value - class QEffMllamaTextCrossAttention(MllamaTextCrossAttention): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py @@ -255,19 +250,22 @@ def forward( if cross_attention_states is not None: key_states = self.k_proj(cross_attention_states) value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) if past_key_value is not None: # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"batch_index": batch_index, "position_ids": position_ids} + key_states, + value_states, + self.layer_idx, + {"batch_index": batch_index, "position_ids": position_ids}, ) - elif cache_position[0] != 0: + elif past_key_value is not None: key_states, value_states = ( past_key_value.key_cache[self.layer_idx], past_key_value.value_cache[self.layer_idx], @@ -277,12 +275,25 @@ def forward( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - if attention_mask is not None: # no matter the length, we just slice it - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + key_states = self.k_norm(key_states) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + # attn_weights = torch.where( + # attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights + # ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -315,7 +326,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.45 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -428,22 +441,251 @@ def forward( return outputs -class QEffMllamaTextModel(MllamaTextModel): +class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): """ Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py The only differences are: - - add new args cache idx for the kv retention + - Add static sin/cos computations. """ - # def __init__(self, config: MllamaTextConfig): - # super().__init__(config) - # self.config = config - # self.__qeff_init__() + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[MllamaConfig] = None, + ): + super(MllamaRotaryEmbedding, self).__init__() # Initialize nn.Module + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.45" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as( + self.inv_freq + ) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) - # def __qeff_init__(self): - # self.layers = nn.ModuleList( - # [MllamaSelfAttentionDecoderLayer(self.config, layer_idx) for layer_idx in range(self.config.num_hidden_layers)] - # ) + +class QEffMllamaVisionModel(MllamaVisionModel): + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = ( + pixel_values.shape + ) + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, height, width + ) + aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + + # Add cls token + 
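+        # NOTE: a learned class embedding is prepended to each tile's patch
+        # sequence below, so num_patches grows by one and every later reshape
+        # accounts for the extra token.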
hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim + ) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches, dim + ) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, + 0, + 0, + num_padding_patches, + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim + ) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, dim + ) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) + intermediate_hidden_states = intermediate_hidden_states[ + ..., self.intermediate_layers_indices + ] + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + if output_hidden_states: + hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) + else: + hidden_states = None + + if output_attentions: + # global transformer in contrast to `self.transformer` doesn't always return hidden states 
so we might go index out-of-range + global_attn = ( + tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) + ) + attentions = tuple(output[2]) + global_attn + else: + attentions = None + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + +class QEffMllamaTextModel(MllamaTextModel): + """ + Copied from MllamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py + The only differences are: + - add new args cache idx for the kv retention + """ def forward( self, @@ -462,31 +704,13 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - """ - - Returns: - - Example: - - ```python - >>> from transformers import AutoProcessor, MllamaTextModel - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaTextModel.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> text = "<|image|>If I had to write a haiku for this one" - >>> inputs = processor(text=text, return_tensors="pt") - - >>> output = model(**inputs) - - >>> print(output.last_hidden_state.shape) - torch.Size([1, 13, 4096]) - ``` - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -506,27 +730,39 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance( + past_key_values, Cache + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions + attention_mask, + inputs_embeds, + cache_position, + position_ids, + past_key_values, + output_attentions, ) # embed positions hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = None # decoder layers all_hidden_states = () if 
output_hidden_states else None @@ -552,8 +788,11 @@ def forward( # TODO: vbaddi: since past_key_values are retained from previous states, the condition for is_cross_attention_cache_empty is False # so explicitly making it true in order to skip the cross attention for language model # comment once there is vision and cross attention support - is_cross_attention_cache_empty = True - if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: + if ( + is_cross_attention_layer + and cross_attention_states is None + and is_cross_attention_cache_empty + ): continue if self.gradient_checkpointing and self.training: @@ -620,7 +859,11 @@ def forward( next_cache = next_cache.to_legacy_cache() if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -665,7 +908,11 @@ def _update_causal_mask( # TODO: vbaddi: unused, comment to fix linters # sequence_length = input_tensor.shape[1] - target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) @@ -710,42 +957,13 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. 
- I love the idea of snowflakes gently falling, each one - ``` - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -797,3 +1015,96 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class VisionEncoder(nn.Module): + def __init__(self, mllama: MllamaForConditionalGeneration): + super().__init__() + self.mllama = mllama + self.cross_attention_layers = ( + self.mllama.config.get_text_config().cross_attention_layers + ) + self.config = self.mllama.config.get_text_config() + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + ) -> List[Tuple[torch.Tensor]]: + vision_outputs = self.mllama.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.mllama.multi_modal_projector( + cross_attention_states + ).reshape(-1, cross_attention_states.shape[-2], self.mllama.hidden_size) + + bsz = pixel_values.shape[0] + outputs = [] + for i in self.cross_attention_layers: + cross_attn = self.mllama.language_model.model.layers[i].cross_attn + key_states = cross_attn.k_proj(cross_attention_states) + value_states = cross_attn.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim + ).transpose(1, 2) + + outputs.append((key_states, value_states)) + return outputs + +class ModelWrapper(nn.Module): + def __init__(self, mllama): + super().__init__() + self.mllama = mllama + self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers + self.config = self.mllama.config.get_text_config() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ): + if past_key_values is not None: + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + outputs = self.mllama( + input_ids=input_ids, + pixel_values=pixel_values, + aspect_ratio_mask=aspect_ratio_mask, + aspect_ratio_ids=aspect_ratio_ids, + attention_mask=attention_mask, + 
cross_attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + if "past_key_values" in outputs: + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + return outputs \ No newline at end of file diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c2e3777bc..2e714840d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -7,19 +7,34 @@ import hashlib import logging +import sys import warnings from pathlib import Path +from time import perf_counter from typing import List, Optional, Union import numpy as np +import requests import torch import torch.nn as nn -from transformers import AutoModel, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast +from PIL import Image +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TextStreamer, +) import QEfficient from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.text_generation_inference import get_compilation_dims +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.models.mllama.modeling_mllama import ModelWrapper, VisionEncoder from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform @@ -313,7 +328,7 @@ def compile( "batch_size": 1 if self.continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - # TODO: should be renamed to kv_cache_batch_size in specialzation too + # TODO: should be renamed to kv_cache_batch_size in specialization too } prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... if self.continuous_batching: @@ -674,3 +689,620 @@ def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray torch.Tensor: A list of output features generated by the model for each prompt. 
""" return model(**inputs) + + +class QEFFAutoModelForImageTextToText(QEFFTransformersBase): + _hf_auto_class = AutoModelForImageTextToText + _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + def __init__( + self, + model: nn.Module, + **kwargs, + ): + if kwargs.pop("full_batch_size", None): + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + + super().__init__(model) + self.model.config.use_cache = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + continuous_batching: bool = False, + is_tlm: bool = False, + kv_offload: bool = False, + *args, + **kwargs, + ): + if kwargs.pop("full_batch_size", None): + raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + + self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs) + self.processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, padding_side="right", **kwargs) + self.continuous_batching = continuous_batching + self.kv_offload = kv_offload + self.is_tlm = is_tlm + + return self + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + def _generate_inputs(self, **kwargs): + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + # seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + # fbs = constants.ONNX_EXPORT_EXAMPLE_FBS + + self.ctx_len = kwargs["ctx_len"] if "ctx_len" in kwargs else self.ctx_len + + ## PREPROCESSING THE MULTI-MODAL INPUTS for Phi-3.5 for now + # TODO: Create a map for the other models to have their own inputs accordingly + images = [] + placeholder = "" + + # Note: if OOM, you might consider reduce number of frames in this example. 
+ for i in range(1, 2): + url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" + images.append(Image.open(requests.get(url, stream=True).raw)) + placeholder += f"<|image_{1}|>\n" + + messages = [ + {"role": "user", "content": placeholder + "Summarize the deck of slides."}, + ] + + prompt = self.processor.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = dict(self.processor(images=images, text=prompt, return_tensors="pt")) + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + inputs["past_key_values"] = [] + for i in range(self.num_layers): + inputs["past_key_values"].append( + ( + torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), + torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), + ) + ) + output_names = [ + "logits", + "pixel_values_RetainedState", + "image_sizes_RetainedState", + *[f"past_{kv}.{i}_RetainedState" for i in range(self.num_layers) for kv in ["key", "value"]], + ] + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + # "pixel_values": {0: "img_batch_size"}, + } + for i in range(self.num_layers): + dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + # Avoid issues due to index out of range + inputs["position_ids"] = torch.full(inputs["position_ids"].shape, self.ctx_len - 1) + + return inputs, dynamic_axes, output_names + + def _generate_inputs_mllama( + self, + ): + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "If I had to write a haiku for this one, it would be: "}, + ], + } + ] + input_text = self.processor.apply_chat_template(messages, add_generation_prompt=True) + + split_inputs = self.processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=32, + ) + + lang_inputs = {} + vision_input = {} + + for k, v in split_inputs.items(): + if k in ["input_ids", "attention_mask", "cross_attention_mask"]: + lang_inputs[k] = v + else: + vision_input[k] = v + + return lang_inputs, vision_input + + def export( + self, + export_dir: Optional[str] = None, + **kwargs, + ) -> str: + self.kv_offload = True + if self.kv_offload: + print("generating input") + lang_inputs, vision_input = self._generate_inputs_mllama() + print("generating vision model") + self.vision_export_path = self.export_vision(vision_input, export_dir) + print("generating lang model") + self.lang_export_path = self.export_lang(lang_inputs, export_dir) + + def export_vision(self, vision_input, export_dir): + model = self.model + self.vision_encoder = self.model = VisionEncoder(self.model) + + vision_output_names = [] + for i in self.model.cross_attention_layers: + vision_output_names.append(f"past_key.{i}") + vision_output_names.append(f"past_value.{i}") + vision_dynamic_axes = { + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_image_tiles"}, + "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": { + 0: "batch_size", + 1: "max_num_images", + 2: "max_image_tiles", + }, + } + + self.vision_onnx_path = self._export( + 
vision_input, + vision_output_names, + vision_dynamic_axes, + export_dir=export_dir, + ) + + self.model = model + self.vision_output_names = vision_output_names + return self.vision_onnx_path + + def export_lang(self, lang_inputs, export_dir): + self.num_layers = num_hidden_layers = self.model.config.get_text_config().num_hidden_layers + + lang_inputs["position_ids"] = torch.where( + lang_inputs.pop("attention_mask") == 1, + torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1), + -1, + ) + + lang_inputs["past_key_values"] = QEffDynamicCache(num_hidden_layers) + lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers + lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers + + for i in range(num_hidden_layers): + if i in self.vision_encoder.cross_attention_layers: + idx = self.vision_encoder.cross_attention_layers.index(i) + assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}" + lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 6404, 128)) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 6404, 128)) + else: + lang_inputs["past_key_values"].key_cache[i] = torch.zeros((1, 8, 1024, 128)) + lang_inputs["past_key_values"].value_cache[i] = torch.zeros((1, 8, 1024, 128)) + + lang_inputs["position_ids"] = torch.full((1, 1), lang_inputs["past_key_values"].key_cache[0].shape[2] - 1) + lang_output_names = ["logits", "past_key_values"] + pkv_idx = lang_output_names.index("past_key_values") + + lang_output_names[pkv_idx : pkv_idx + 1] = [ + f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"] + ] + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "cross_attention_mask": { + 0: "batch_size", + 1: "seq_len", + 2: "max_num_images", + 3: "max_image_tiles", + }, + } + + for i in range(num_hidden_layers): + if i in self.vision_encoder.cross_attention_layers: + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size"} + continue + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() + lang_inputs["input_ids"] = torch.tensor([[374]]) + lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:] + self.lang_output_names = lang_output_names + model = self.model + self.model = ModelWrapper(model) + + self.lang_onnx_path = self._export(lang_inputs, lang_output_names, lang_dynamic_axes, export_dir=export_dir) + self.model = model + return self.lang_onnx_path + + def compile( + self, + vision_onnx_path: Optional[str] = None, + lang_onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + prefill_seq_len: int = 32, + ctx_len: int = 128, + batch_size: int = 1, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + self.kv_offload = True + if self.kv_offload: + model = self.model + self.model = VisionEncoder(model) + vision_specializations = [{"batch_size": "1", "max_num_images": "1", "max_image_tiles": "4"}] + + custom_io = {} + kv_cache_dtype = "float16" + custom_io["pixel_values"] = kv_cache_dtype + for output_name in self.vision_output_names: + custom_io[output_name] = kv_cache_dtype + + model = self.model + self.model = self.vision_encoder + print("compiling vision model") + 
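+            # The vision encoder is compiled into its own QPC with one static
+            # specialization (batch 1, one image, up to four tiles); the
+            # language model is compiled separately below with two
+            # specializations, one for chunked prefill (seq_len=prefill_seq_len)
+            # and one for decode (seq_len=1), both sharing the same ctx_len.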
self.vision_qpc_path = self._compile( + self.vision_onnx_path, + compile_dir, + compile_only=True, + specializations=vision_specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io, + **compiler_options, + ) + self.model = ModelWrapper(model) + + lang_specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": "1", + "max_image_tiles": "4", + }, + ] + + custom_io_lang = {} + # Inputs + for output_name in self.lang_output_names: + if output_name.startswith("past_"): + custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype + # outputs + for output_name in self.lang_output_names: + if output_name.startswith("past_"): + custom_io_lang[output_name] = kv_cache_dtype + + # custom_io = {} + # kv_cache_dtype = "float16" + # custom_io["pixel_values"] = kv_cache_dtype + # custom_io["pixel_values_RetainedState"] = kv_cache_dtype + # for suffix in ["", "_RetainedState"]: + # for i in range(self.num_layers): + # for kv in ["key", "value"]: + # custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + + print("generating lang model") + self.lang_qpc_path = self._compile( + self.lang_onnx_path, + compile_dir, + compile_only=True, + specializations=lang_specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_lang, + **compiler_options, + ) + self.model = model + return self.vision_qpc_path, self.lang_qpc_path + + def generate( + self, + inputs: torch.Tensor, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] = None, + runtime_ai100: bool = True, + ) -> Union[torch.Tensor, np.ndarray]: + """ + This method generates output by executing PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. + ``Mandatory`` Args: + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + ``optional`` Args: + :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model + :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime. + Returns: + :dict: Output from the ``AI_100`` or ``PyTorch`` runtime. + """ + # AI_100 runtime + if runtime_ai100: + # if not isinstance(self.qpc_path, Path): + # raise TypeError("Please run compile API first!") + if self.kv_offload: + self.kv_offload_generate(inputs, streamer, device_ids) + else: + return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) + # PyTorch runtime + else: + return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) + + # TODO: Add the code based on how we did in single inference script + def cloud_ai_100_vlm_generate( + self, + inputs: torch.Tensor, + device_ids: List[int] = [0], + ) -> np.ndarray: + """ + Generates features with list of prompts using AI 100 runtime. + + ``Mandatory`` Args: + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + ``Optional`` Args: + device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0]. + + Returns: + np.ndarray: A list of dictionaries containing the generated output features. 
+ """ + + if self.qpc_session is None: + self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) + self.batch_size = self.qpc_session.bindings[0].dims[0] + self.seq_len = self.qpc_session.bindings[0].dims[1] + # Skip inputs/outputs + self.qpc_session.skip_buffers( + [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] + + ["pixel_values_RetainedState", "image_sizes_RetainedState"] + ) + + # Read prompt and ctx len from session + # batch_size = max( + # [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] + # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] + # ) + + # prefill_seq_len = max( + # [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] + # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] + # ) + # Prepare input + input_ids_len = inputs["input_ids"].shape[1] + input_ids = np.array( + torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + ) + attention_mask = np.array( + torch.nn.functional.pad( + inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0 + ) + ) + + inputs = dict(input_ids=input_ids, attention_mask=attention_mask) + + outputs = { + "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( + np.float32 + ), + } + self.qpc_session.set_buffers(outputs) + outputs = self.qpc_session.run(inputs) + outputs = outputs["output"][:, :input_ids_len, :] + return outputs + + def pytorch_vlm_generate( + self, + model, + inputs: Union[torch.Tensor, np.ndarray], + streamer: TextStreamer, + ) -> List[torch.Tensor]: + """ + Generates features from a list of text prompts using a PyTorch model. + + ``Mandatory`` Args: + :model: The transformed PyTorch model used for generating features. + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + :streamer (TextStreamer): A TextStreamer object used for streaming the generated text. + + Returns: + torch.Tensor: A list of output features generated by the model for each prompt. 
+ """ + # inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + # inputs["past_key_values"] = [] + # for _ in range(model.config.num_hidden_layers): + # inputs["past_key_values"].append(( + # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len,self.head_dim), + # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len, self.head_dim), + # )) + self.batch_size = inputs["input_ids"].shape[0] + generation_len = self.ctx_len - inputs["input_ids"].shape[1] + generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id) + + outputs = model(**inputs) + + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 + streamer.put(inputs["input_ids"]) + + for _ in range(generation_len): + outputs = model(**inputs) + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] += 1 + streamer.put(inputs["input_ids"]) + generated_ids[:, _] = inputs["input_ids"].squeeze(1) + generated_texts = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(self.batch_size): + print(i, generated_texts[i]) + + return generated_ids + + def kv_offload_generate( + self, + inputs: List[str] = None, + streamer: Optional[TextStreamer] = None, + device_id: List[int] = None, + generation_len: int = None, + stream: bool = True, + **kwargs, + ): + # self.lang_qpc_path = Path( + # "/home/rishinr/vision/vision_infra/llama-vision/qpc/Llama-3.2-11B-Vision-Instruct-language" + # ) + # self.vision_qpc_path = Path( + # "/home/rishinr/vision/vision_infra/llama-vision/qpc/Llama-3.2-11B-Vision-Instruct-vision" + # ) + # self.lang_qpc_path = Path( + # "/home/rishinr/.cache/qeff_models/mllama_bc/ModelWrapper-e34b1a9bd1cf14cb/qpc-0fd0400e8969c49e/qpc" + # ) + # self.vision_qpc_path = Path( + # "/home/rishinr/.cache/qeff_models/mllama_bc/VisionEncoder-e34b1a9bd1cf14cb/qpc-b4c5b2ba8c79d148/qpc" + # ) + + lang_session = QAICInferenceSession(self.lang_qpc_path, device_id, activate=False) + vision_session = QAICInferenceSession(self.vision_qpc_path, device_id) + + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_qpc_path) + + tokenizer = self.processor.tokenizer + + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + if streamer is None: + streamer = TextStreamer(tokenizer) + + # Skip inputs/outputs + lang_session.skip_buffers( + [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] + ) + + input_len = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + if generation_len is None: + generation_len = ctx_len - input_len.max() + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), tokenizer.pad_token_id) + + # Prepare inputs for prefill + start = perf_counter() + 
vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_outputs = vision_session.run(dict(vision_inputs)) + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + lang_inputs = dict(lang_inputs) + + vision_session.deactivate() + lang_session.activate() + + lang_session.set_buffers(vision_outputs) + + # Run prefill + for i in range(num_chunks): + chunk_inputs = lang_inputs.copy() + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ + :, i * prefill_seq_len : (i + 1) * prefill_seq_len + ] + outputs = lang_session.run(chunk_inputs) + + # Skip inputs/outputs again + lang_session.skip_buffers( + [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")] + ) + + # Get first token + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] = input_len + lang_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"][:, -1:, :, :] + generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1) + finished_sequences = lang_inputs["input_ids"] == tokenizer.eos_token_id + if stream: + streamer.put(lang_inputs["input_ids"][0]) + + # Decode loop + loop_start = perf_counter() + for num_token in range(1, generation_len): + outputs = lang_session.run(lang_inputs) + + # Prepare inputs for next iteration + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] += 1 + generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1) + finished_sequences |= lang_inputs["input_ids"] == tokenizer.eos_token_id + + if stream: + streamer.put(lang_inputs["input_ids"][0]) + if finished_sequences.all(): + break + + end = perf_counter() + if stream: + streamer.end() + generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(1 if stream else 0, batch_size): + print(i, generated_texts[i]) + + prefill_perf = 1 / (loop_start - start) + decode_perf = (num_token - 1) / (end - loop_start) + total_perf = num_token / (end - start) + + print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) + print("E2ET:", round(end - start, 2), "s", file=sys.stderr) + print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) + print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr) + print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) + if batch_size > 1: + print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) + print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 6b8d00689..c3ad99f85 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -69,11 +69,13 @@ from transformers.models.mllama.modeling_mllama import ( MllamaCrossAttentionDecoderLayer, MllamaForCausalLM, + MllamaRotaryEmbedding, MllamaSelfAttentionDecoderLayer, MllamaTextCrossAttention, MllamaTextModel, 
    MllamaTextRMSNorm,
     MllamaTextSelfAttention,
+    MllamaVisionModel,
 )
 from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
 from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
@@ -165,10 +167,12 @@
 from QEfficient.transformers.models.mllama.modeling_mllama import (
     QEffMllamaCrossAttentionDecoderLayer,
     QEffMllamaForCausalLM,
+    QEffMllamaRotaryEmbedding,
     QEffMllamaSelfAttentionDecoderLayer,
     QEffMllamaTextCrossAttention,
     QEffMllamaTextModel,
     QEffMllamaTextSelfAttention,
+    QEffMllamaVisionModel,
 )
 from QEfficient.transformers.models.mpt.modeling_mpt import (
     QEffMptAttention,
@@ -256,10 +260,12 @@ class KVCacheTransform(ModuleMappingTransform):
         # mllama
         MllamaForCausalLM: QEffMllamaForCausalLM,
         MllamaTextModel: QEffMllamaTextModel,
+        MllamaVisionModel: QEffMllamaVisionModel,
         MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
         MllamaTextCrossAttention: QEffMllamaTextCrossAttention,
         MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer,
         MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer,
+        MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding,
         # Mistral
         MistralAttention: QEffMistralAttention,
         MistralDecoderLayer: QEffMistralDecoderLayer,
@@ -343,4 +349,4 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
                 f"model class {model_class} does not yet support returning multiple logits to keep."
             )
 
-    return model, transformed
+    return model, transformed
\ No newline at end of file
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index ab861a788..004d2406c 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -49,6 +49,11 @@ def get_models_dir():
 ONNX_EXPORT_EXAMPLE_FBS = 4
 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep
 ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_MAX_NUM_IMAGES = 1
+ONNX_EXPORT_MAX_IMAGE_TILES = 4
+ONNX_EXPORT_IMAGE_WIDTH = 560
+ONNX_EXPORT_IMAGE_LENGTH = 560
+ONNX_EXPORT_IMAGE_DEPTH = 3
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"]
diff --git a/pyproject.toml b/pyproject.toml
index 9867181ca..e04bba103 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
 ]
 requires-python = ">=3.8,<3.11"
 dependencies = [
-    "transformers==4.45.2",
+    "transformers==4.46.0",
     "huggingface-hub==0.27.0",
     "peft==0.13.2",
     "datasets==2.20.0",
@@ -32,6 +32,7 @@ dependencies = [
     "numpy==1.26.4",
     "protobuf==3.20.2",
     "onnxscript==0.1.0.dev20240327",
+    "pillow==11.1.0",
     "sympy",
     "tensorboard",
     "fire",
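
Usage sketch (not part of the patch): a minimal end-to-end flow for the new class, pieced together from the methods added above. The checkpoint name, core/device counts, and sequence lengths are illustrative assumptions, and `generate` here relies on the two-QPC `kv_offload` path.

    import requests
    from PIL import Image
    from transformers import AutoProcessor, TextStreamer

    from QEfficient import QEFFAutoModelForImageTextToText

    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed checkpoint
    model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, kv_offload=True)

    # Export the vision and language ONNX graphs, then compile the two QPCs.
    model.export()
    model.compile(num_cores=16, num_devices=4, prefill_seq_len=32, ctx_len=128)

    # Build numpy inputs; kv_offload_generate consumes them as numpy arrays.
    processor = AutoProcessor.from_pretrained(model_id)
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="np",
        padding="max_length",
        max_length=32,
    )

    model.generate(inputs=inputs, streamer=TextStreamer(processor.tokenizer))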