diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 260a32f1d..424e9fc4c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,22 +1,160 @@
 ## Contributing to PROJECT
 Hi there!
-We’re thrilled that you’d like to contribute to this project.
+We're thrilled that you'd like to contribute to this project.
 Your help is essential for keeping this project great and for making it better.
-## Branching Strategy
-In general, contributors should develop on branches based off of `main` and pull requests should be made against `main`.
+## Submitting Your Contribution
-## Submitting a pull request
+Follow these steps to submit your example to the QEfficient repository:
 1. Please read our [code of conduct](CODE-OF-CONDUCT.md) and [license](LICENSE).
-1. Fork and clone the repository.
-1. Create a new branch based on `main`: `git checkout -b main`.
-1. Make your changes, add tests, and make sure the tests still pass.
-1. Commit your changes using the [DCO](http://developercertificate.org/). You can attest to the DCO by commiting with the **-s** or **--signoff** options or manually adding the "Signed-off-by".
-1. Push to your fork and submit a pull request from your branch to `main`.
-1. Pat yourself on the back and wait for your pull request to be reviewed.
+
+### 1. Fork and Clone the Repository
+
+First, fork the repository to your GitHub account, then clone your fork:
+
+```bash
+# Fork the repository on GitHub (click the "Fork" button)
+# Then clone your fork
+git clone git@github.com:YOUR_USERNAME/efficient-transformers.git
+cd efficient-transformers
+
+# Add upstream remote to keep your fork in sync
+git remote add upstream git@github.com:quic/efficient-transformers.git
+```
+
+### 2. Create a Feature Branch
+
+Create a descriptive branch for your changes:
+
+```bash
+# Update your main branch
+git checkout main
+git pull upstream main
+
+# Create a new branch
+git checkout -b <your-branch-name>
+```
+
+### 3. Make Your Changes
+
+When making changes to the codebase:
+
+- **Follow Existing Design Patterns**
+  - Review similar implementations before creating new code
+  - Maintain consistency with the project's architecture and coding style
+  - Reuse existing utilities and base classes where applicable
+
+- **Onboarding New Models**
+  - For adding new model support, refer to the comprehensive guide: `examples/onboarding_guide/causallm/`
+  - Follow the step-by-step process with code examples provided
+
+- **Testing is Mandatory**
+  - Add tests for all new features in the appropriate `tests/` subdirectory
+  - Run tests locally before pushing: `pytest tests/path/to/your/test.py -v`
+  - For model additions, verify all 4 pipeline stages (PyTorch HF → KV → ORT → AI 100) and make sure the generated tokens match the reference PyTorch HF output
+
+- **Documentation**
+  - **For New Features/Flags:**
+    - Document usage in `docs/source/` with feature description and usage examples
+    - Ensure documentation is clear enough for others to understand and use the feature
+  - **For New Models:**
+    - Test with basic inference scripts in the `examples/` folder
+    - If specific changes are needed, create a dedicated example file
+    - Update `docs/source/validate.md` with the model's HuggingFace card name and relevant details
+
+- **Code Quality Checks**
+  - Pre-commit hooks, DCO sign-off, and CI checks are covered in the following steps
+  - Ensure you complete steps 4-8 before finalizing your PR
+
+### 4. Run Pre-commit Checks
+
+Before committing, ensure your code passes all quality checks:
+
+```bash
+# Install pre-commit and ruff if not already installed
+pip install pre-commit
+pip install ruff
+
+# Run pre-commit on your changed files
+pre-commit run --files path/to/your/file1.py path/to/your/file2.py
+
+# Run Ruff check
+ruff check
+```
+
+**Important:** If pre-commit reports any failures:
+- Some issues will be auto-fixed (formatting, trailing whitespace, etc.)
+- For issues that aren't auto-fixed, manually correct them
+- Re-run `pre-commit run --files <your-files>` or `ruff check` until all checks pass
+
+### 5. Commit with Sign-off (DCO)
+
+All commits must be signed off to comply with the Developer Certificate of Origin (DCO):
+
+```bash
+# Stage your changes
+git add examples/your_domain/your_example.py
+git add examples/your_domain/README.md
+
+# Commit with sign-off
+git commit -s --author "Your Name <your.email@example.com>" -m "Add [model-name] support
+
+- Implements inference for [model-name]
+- Includes documentation and usage examples
+- Tested with [specific configurations]"
+```
+
+**Commit Message Guidelines:**
+- Use a clear, descriptive title
+- Add a blank line, then a detailed description if needed
+- Always include the `-s` flag for DCO sign-off
+
+### 6. Push to Your Fork
+
+Push your branch to your forked repository:
+
+```bash
+git push origin <your-branch-name>
+```
+
+### 7. Create a Pull Request
+
+1. Go to your fork on GitHub
+2. Click "Compare & pull request" for your branch
+3. Fill out the PR template with:
+   - **Title:** Clear, descriptive title (e.g., "Add Llama-3.2-Vision Support" or "Fix memory leak in KV cache")
+   - **Description:**
+     - What changes were made and why
+     - What problem it solves or feature it adds
+     - Any special considerations or breaking changes
+     - Links to relevant documentation, issues, or model cards (if applicable)
+   - **Testing:** Describe how you tested your changes
+
+### 8. Ensure CI Checks Pass
+
+After creating the PR, verify that all automated checks pass:
+
+- ✅ **DCO Check:** Ensures all commits are signed off
+- ✅ **Lint Check:** Code style and formatting validation
+- ✅ **Tests:** Automated test suite (if applicable)
+
+If any checks fail:
+1. Review the error messages in the PR
+2. Make necessary fixes in your local branch
+3. Commit and push the fixes (with sign-off)
+4. The PR will automatically update and re-run checks
+
+### 9.
Address Review Feedback + +Maintainers will review your PR and may request changes: +- Make requested changes in your local branch +- Commit with sign-off and push to update the PR +- Respond to comments to facilitate discussion + Here are a few things you can do that will increase the likelihood of your pull request to be accepted: diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 33c6f5588..3c9f68efd 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -6,23 +6,64 @@ # ----------------------------------------------------------------------------- import os -import warnings - -import QEfficient.utils.model_registery # noqa: F401 -from QEfficient.utils import custom_format_warning -from QEfficient.utils.logging_utils import logger +# ----------------------------------------------------------------------------- # # For faster downloads via hf_transfer # This code is put above import statements as this needs to be executed before # hf_transfer is imported (will happen on line 15 via leading imports) os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +# DO NOT ADD ANY CODE ABOVE THIS LINE +# Please contact maintainers if you must edit this file above this line. +# ----------------------------------------------------------------------------- # # Placeholder for all non-transformer models registered in QEfficient +import warnings # noqa: I001 +import QEfficient.utils.model_registery # noqa: F401 +from QEfficient.base import ( + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForCTC, + QEFFAutoModelForImageTextToText, + QEFFAutoModelForSpeechSeq2Seq, + QEFFCommonLoader, +) +from QEfficient.compile.compile_helper import compile +from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline +from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline +from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter +from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv +from QEfficient.peft import QEffAutoPeftModelForCausalLM +from QEfficient.transformers.transform import transform +from QEfficient.utils import custom_format_warning +from QEfficient.utils.logging_utils import logger # custom warning for the better logging experience warnings.formatwarning = custom_format_warning +# Users can use QEfficient.export for exporting models to ONNX +export = qualcomm_efficient_converter +__all__ = [ + "transform", + "export", + "compile", + "cloud_ai_100_exec_kv", + "QEFFAutoModel", + "QEFFAutoModelForCausalLM", + "QEFFAutoModelForCTC", + "QEffAutoPeftModelForCausalLM", + "QEFFAutoModelForImageTextToText", + "QEFFAutoModelForSpeechSeq2Seq", + "QEFFCommonLoader", + "QEffFluxPipeline", + "QEffWanPipeline", +] + + +# Conditionally import QAIC-related modules if the SDK is installed +__version__ = "0.0.1.dev0" + + def check_qaic_sdk(): """Check if QAIC SDK is installed""" try: @@ -37,40 +78,5 @@ def check_qaic_sdk(): return False -# Conditionally import QAIC-related modules if the SDK is installed -__version__ = "0.0.1.dev0" - -if check_qaic_sdk(): - from QEfficient.base import ( - QEFFAutoModel, - QEFFAutoModelForCausalLM, - QEFFAutoModelForCTC, - QEFFAutoModelForImageTextToText, - QEFFAutoModelForSpeechSeq2Seq, - QEFFCommonLoader, - ) - from QEfficient.compile.compile_helper import compile - from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter - from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv - from QEfficient.peft import 
QEffAutoPeftModelForCausalLM - from QEfficient.transformers.transform import transform - - # Users can use QEfficient.export for exporting models to ONNX - export = qualcomm_efficient_converter - - __all__ = [ - "transform", - "export", - "compile", - "cloud_ai_100_exec_kv", - "QEFFAutoModel", - "QEFFAutoModelForCausalLM", - "QEFFAutoModelForCTC", - "QEffAutoPeftModelForCausalLM", - "QEFFAutoModelForImageTextToText", - "QEFFAutoModelForSpeechSeq2Seq", - "QEFFCommonLoader", - ] - -else: +if not check_qaic_sdk(): logger.warning("QAIC SDK is not installed, eager mode features won't be available!") diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6ecbf0fc0..b5c838a94 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -18,7 +18,10 @@ import onnx import torch -from QEfficient.base.onnx_transforms import OnnxTransform +from QEfficient.base.onnx_transforms import ( + BaseOnnxTransform, + OnnxTransformPipeline, +) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession @@ -27,11 +30,11 @@ create_json, create_model_params, dump_qconfig, - export_wrapper, generate_mdp_partition_config, hash_dict_params, load_json, ) +from QEfficient.utils.export_utils import export_wrapper logger = logging.getLogger(__name__) @@ -47,7 +50,7 @@ class QEFFBaseModel(ABC): """ _pytorch_transforms: List[PytorchTransform] - _onnx_transforms: List[OnnxTransform] + _onnx_transforms = [BaseOnnxTransform] @classmethod def _transform_names(cls) -> List[str]: @@ -57,6 +60,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) + self.prefill_onnx_path: Optional[str] = None self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None @@ -78,26 +82,26 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: else: logger.info(f"Pytorch transforms applied to model: {self.model_name}") - def _offload_model_weights(self, offload_pt_weights) -> bool: - """ - Clear PyTorch weights after export if offload_pt_weights is set to True - - Returns: - bool: True if weights were successfully offloaded, False otherwise - """ - # Check if offloading is enabled and weights are not already offloaded + def _offload_model_weights(self, offload_pt_weights: bool) -> bool: + """Clear PyTorch model weights to reduce memory usage after ONNX export.""" if offload_pt_weights and not self._is_weights_offloaded: try: - self.model = self.model.to_empty(device="meta") - self._is_weights_offloaded = True - logger.info("Model weights offloaded to meta device") - + for param in self.model.parameters(): + if param.storage(): + param.storage().resize_(0) + for buffer in self.model.buffers(): + if buffer.storage(): + buffer.storage().resize_(0) + + meta_model = self.model.to("meta") + del self.model gc.collect() - logger.info("PyTorch weights cleared after export") - return True + self.model = meta_model + self._is_weights_offloaded = True + return True except Exception as e: - logger.error(f"Failed to offload model weights: {e}") + logger.warning(f"Weight clearing failed, continuing: {e}") return False return False @@ -116,9 +120,35 @@ def _model_offloaded_check(self) -> None: logger.error(error_msg) raise RuntimeError(error_msg) + @property + def model_name(self) -> str: + 
""" + Get the model class name without QEff/QEFF prefix. + + This property extracts the underlying model's class name and removes + any QEff or QEFF prefix that may have been added during wrapping. + + Returns: + str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel") + """ + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + @property @abstractmethod - def model_name(self) -> str: ... + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + This is an abstract property that must be implemented by all subclasses. + Typically returns: self.model.config.__dict__ + + Returns: + Dict: The configuration dictionary of the underlying model + """ + pass @abstractmethod def export(self, export_dir: Optional[str] = None) -> Path: @@ -175,10 +205,11 @@ def _export( example_inputs: Dict[str, torch.Tensor], output_names: List[str], dynamic_axes: Dict[str, Dict[int, str]], - export_kwargs: Optional[Dict[str, any]] = None, onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, + prefill_only: Optional[bool] = False, + **export_kwargs, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -203,11 +234,16 @@ def _export( instance using from_pretrained() for re-export. """ + # TODO: Hack for retain_full_kv, handle this outside + export_kwargs.pop("retain_full_kv", None) onnx_path = export_dir / f"{self.model_name}.onnx" # Return early if ONNX already exists if onnx_path.is_file(): - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path # check if the model is in meta state or weights are offloaded @@ -243,7 +279,6 @@ def _export( input_names.append(param) try: - export_kwargs = {} if export_kwargs is None else export_kwargs torch.onnx.export( self.model, (example_inputs,), @@ -255,10 +290,10 @@ def _export( **export_kwargs, ) logger.info("PyTorch export successful") - _ = self._offload_model_weights(offload_pt_weights) - model = onnx.load(tmp_onnx_path, load_external_data=False) + + # Clear temporary references transform_kwargs = { "onnx_base_dir": str(tmp_onnx_dir), "model_name": self.model_name, @@ -266,15 +301,18 @@ def _export( if onnx_transform_kwargs is not None: transform_kwargs.update(onnx_transform_kwargs) - for transform in self._onnx_transforms: - model, transformed = transform.apply(model, **transform_kwargs) + onnx_transforms = OnnxTransformPipeline(transforms=self._onnx_transforms) + model, transformed = onnx_transforms.apply(model, **transform_kwargs) + # Add metadata to the model model.metadata_props.append( onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names())) ) logger.info("ONNX transforms applied") onnx.save(model, onnx_path) + del model + gc.collect() logger.info("Transformed ONNX saved") except Exception as e: @@ -284,9 +322,42 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path + def get_onnx_path( + self, + prefill_only: Optional[bool] = False, + enable_chunking: Optional[bool] = False, + specializations: Optional[List[Dict[str, int]]] = None, + offload_pt_weights: Optional[bool] = True, + use_onnx_subfunctions: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): 
+ kwargs = { + "offload_pt_weights": offload_pt_weights, + "use_onnx_subfunctions": use_onnx_subfunctions, + "retain_full_kv": retain_full_kv, + } + if prefill_only: + if self.prefill_onnx_path is None: + kwargs.update( + { + "prefill_only": prefill_only, + "prefill_seq_len": specializations[0].get("seq_len"), + "enable_chunking": enable_chunking, + } + ) + self.export(**kwargs) + return self.prefill_onnx_path + else: + if self.onnx_path is None: + self.export(**kwargs) + return self.onnx_path + @dump_qconfig def _compile( self, @@ -300,6 +371,11 @@ def _compile( num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, + use_onnx_subfunctions: bool = False, + prefill_only: Optional[str] = None, + offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -325,10 +401,18 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. """ - if onnx_path is None and self.onnx_path is None: - self.export() - - onnx_path = Path(onnx_path or self.onnx_path) + onnx_path = Path( + onnx_path + if onnx_path + else self.get_onnx_path( + prefill_only, + enable_chunking, + specializations, + offload_pt_weights, + use_onnx_subfunctions, + retain_full_kv, + ) + ) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): @@ -383,6 +467,10 @@ def _compile( else: mdp_ts_json = None + if use_onnx_subfunctions: + logger.info("Using ONNX subfunctions for compilation.") + command.append("-sub-functions") + compile_hash_params = { "command": command, "specializations": specializations, @@ -390,6 +478,7 @@ def _compile( "mdp_ts_num_devices": mdp_ts_num_devices, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, + "prefill_only": prefill_only, } compile_hash = hash_dict_params(compile_hash_params) @@ -429,6 +518,7 @@ def _compile( command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") + try: subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: @@ -449,5 +539,4 @@ def _compile( logger.info("Hashed parameters exported successfully.") self.qpc_path = qpc_path - return qpc_path diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 61b5c00f6..16697cec9 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -5,97 +5,273 @@ # # ---------------------------------------------------------------------------- -from typing import Optional, Tuple +import logging +import os +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np -from onnx import ModelProto, external_data_helper, numpy_helper +import onnx +import torch +from onnx import ModelProto, TensorProto, external_data_helper, numpy_helper +from QEfficient.customop.ctx_scatter_gather import ( + CtxGather, + CtxGather3D, + CtxGatherBlockedKV, + CtxGatherFunc, + CtxGatherFunc3D, + CtxGatherFuncBlockedKV, + CtxScatter, + CtxScatter3D, + CtxScatterFunc, + CtxScatterFunc3D, +) +from QEfficient.customop.ctx_scatter_gather_cb import ( + CtxGatherBlockedKVCB, + CtxGatherCB, + CtxGatherCB3D, + CtxGatherFuncBlockedKVCB, + CtxGatherFuncCB, + CtxGatherFuncCB3D, + CtxScatterCB, + CtxScatterCB3D, + CtxScatterFuncCB, 
+ CtxScatterFuncCB3D, +) +from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc +from QEfficient.utils.constants import FILE_CHUNK_SIZE_DEFAULT, ONNX_EXPORT_OPSET, SIZE_THRESHOLD_DEFAULT -class OnnxTransform: - """ - OnnxTransform is the base class for graph modifications on exported onnx. - """ +logger = logging.getLogger(__name__) + + +class BaseOnnxTransform: + """Base class for ONNX graph modifications. Should NOT be instantiated.""" def __init__(self): - raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.") + raise TypeError("Transform classes are not to be instantiated. Use the `apply` method directly.") @classmethod def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]: - """ - Override this class to apply a transformation. - :param model: The model's ONNX graph to transform - :param kwargs: Parameters needed for specific transforms. All transforms should take **kwargs to ignore unneeded kwargs. - - :returns: ONNX graph after applying the transform - :returns: Boolean indicating whether transform was applied - """ raise NotImplementedError("Use subclasses for ONNX transform") -class FP16ClipTransform(OnnxTransform): - """ - Clips the tensor values to be in FP16 range, but preserves -inf values. - """ +class FP16ClipTransform(BaseOnnxTransform): + """Clip FP32 tensors to FP16 range to avoid overflow during conversion.""" @classmethod - def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]: - """ - :param onnx_base_dir: Base directory to load tensors - """ - finfo = np.finfo(np.float16) - fp16_max = finfo.max - fp16_min = finfo.min - transformed = False + def apply(cls, tensor: TensorProto, onnx_base_dir: str, fp16_max: float, fp16_min: float) -> bool: + nptensor = numpy_helper.to_array(tensor, onnx_base_dir) + if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)): + neg_inf_mask = np.isinf(nptensor) & (nptensor < 0) + clipped_tensor = np.clip(nptensor, fp16_min, fp16_max) + + if neg_inf_mask.any(): + clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor) + + tensor.CopyFrom(numpy_helper.from_array(clipped_tensor, tensor.name)) + return True + return False + + +class SplitTensorsTransform(BaseOnnxTransform): + """Split large tensors into external data files for efficient storage.""" + + @classmethod + def apply( + cls, tensor: TensorProto, model_name: str, file_num: int, mapping: Dict[str, Tuple[TensorProto, str]] + ) -> None: + file_name = f"{model_name}_{file_num}.onnx.data" + mapping[tensor.name] = (tensor, file_name) + + +class CustomOpTransform(BaseOnnxTransform): + """Register custom ONNX ops and append their function prototypes to the model.""" - for tensor in external_data_helper._get_all_tensors(model): - nptensor = numpy_helper.to_array(tensor, onnx_base_dir) - if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)): - neg_inf_mask = np.isinf(nptensor) & (nptensor < 0) - clipped_tensor = np.clip(nptensor, fp16_min, fp16_max) + _custom_ops: Dict[str, Tuple[Any, Any]] = { + "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm), + "CtxScatterFunc": (CtxScatterFunc, CtxScatter), + "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), + "CtxGatherFunc": (CtxGatherFunc, CtxGather), + "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), + "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), + "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, 
CtxGatherCB3D), + "CtxGatherFuncBlockedKV": (CtxGatherFuncBlockedKV, CtxGatherBlockedKV), + "CtxGatherFuncBlockedKVCB": (CtxGatherFuncBlockedKVCB, CtxGatherBlockedKVCB), + "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), + "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), + } - # Restore -inf values - if neg_inf_mask.any(): - clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor) + @classmethod + def apply(cls, model: ModelProto) -> bool: + op_applied = False + for op_name, (func_class, _) in cls._custom_ops.items(): + if hasattr(func_class, "symbolic"): + torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, ONNX_EXPORT_OPSET) + + existing = {f.name for f in model.functions} + for _, onnxscript_func in cls._custom_ops.values(): + proto = onnxscript_func.to_function_proto() + if proto.name not in existing: + model.functions.append(proto) + op_applied = True + return op_applied + + +class RenameFunctionOutputsTransform(BaseOnnxTransform): + """Rename outputs of decoder-related functions for better clarity.""" + + @classmethod + def apply(cls, model: ModelProto) -> bool: + graph = model.graph + op_type_to_func = {f.name: f for f in model.functions} + decoder_patterns = ["DecoderLayer", "Block", "Layer"] + renamed = False + model_out_map = {v.name: i for i, v in enumerate(graph.output)} + layer_idx = 0 + + for node in graph.node: + if any(p in node.name or p in node.op_type for p in decoder_patterns): + func = op_type_to_func.get(node.op_type) + if not func: + continue + for i, out_name in enumerate(func.output): + if "_InternalRetainedState" in out_name: + renamed = True + orig = node.output[i] + new = ( + f"past_key.{layer_idx}_RetainedState" + if "key" in out_name + else f"past_value.{layer_idx}_RetainedState" + if "value" in out_name + else orig + ) + node.output[i] = new + if orig in model_out_map: + graph.output[model_out_map[orig]].name = new + layer_idx += 1 + return renamed + + +class AdapterWeightsToInputsTransform(BaseOnnxTransform): + @classmethod + def apply(cls, model: onnx.ModelProto, *, adapter_name: str, **kwargs) -> Tuple[onnx.ModelProto, bool]: + transformed = False + removed_initializers = [] - new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name) - tensor.CopyFrom(new_tensor) + # Find nodes with lora weights as inputs + weight_suffix = f".{adapter_name}.weight" + lora_weight_nodes = { + inp: node for node in model.graph.node for inp in node.input if inp.endswith(weight_suffix) + } + + for i, weight in enumerate(model.graph.initializer): + if weight.name.endswith(weight_suffix): transformed = True + # Create input/output for lora weights + new_weight_name = weight.name[: -len(weight_suffix)] + ".weight" + type_proto = onnx.helper.make_tensor_type_proto(weight.data_type, shape=list(weight.dims)) + inp = onnx.ValueInfoProto(name=new_weight_name, type=type_proto) + out = onnx.ValueInfoProto(name=new_weight_name + "_RetainedState", type=type_proto) + model.graph.input.append(inp) + model.graph.output.append(out) + + # Create a node that connects input -> output + node = onnx.helper.make_node("Identity", [inp.name], [out.name], new_weight_name + "_identity") + model.graph.node.append(node) + + # Rename weight input + lora_weight_node = lora_weight_nodes[weight.name] + for j, inp in enumerate(lora_weight_node.input): + if inp == weight.name: + lora_weight_node.input[j] = new_weight_name + + # Remove weight initializers + removed_initializers.append(i) + + if transformed: + for i in sorted(removed_initializers, 
reverse=True): + model.graph.initializer.pop(i) + return model, transformed -class SplitTensorsTransform(OnnxTransform): - """ - Split external tensors file - """ +class OnnxTransformPipeline(BaseOnnxTransform): + """Pipeline to apply multiple ONNX transformations in sequence.""" + + def __init__(self, transforms: List[Type[BaseOnnxTransform]]): + if not transforms: + warnings.warn("Transform list is empty. No transformations will be applied.") + self.transforms = transforms - @classmethod def apply( - cls, + self, model: ModelProto, *, - model_name: str, + model_name: str = "", onnx_base_dir: Optional[str] = None, - file_chunk_size: int = 10 * 2**30, # 10 GiB - size_threshold: int = 1024, + file_chunk_size: int = FILE_CHUNK_SIZE_DEFAULT, + size_threshold: int = SIZE_THRESHOLD_DEFAULT, **kwargs, ) -> Tuple[ModelProto, bool]: - """ - :param model_name: Used for naming external files. i.e. {model_name}_0.onnx.data - :param onnx_base_dir: Base directory to load tensors (if not already loaded). - :param file_chunk_size: Chunk size to split external files into. - :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally. - """ - file_num = 0 - current_file_size = 0 - transformed = False + if not self.transforms: + return model, False + + # Same logic as before, but replace `transforms` with `self.transforms` + mapping: Dict[str, Tuple[TensorProto, str]] = {} + requested = set(self.transforms) + applied = {t: False for t in requested} + f16_applied = False + do_fp16 = FP16ClipTransform in requested + do_split = SplitTensorsTransform in requested + fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max + file_num_tracker = {"num": 0, "size": 0} external_data_helper.load_external_data_for_model(model, onnx_base_dir) - for tensor in external_data_helper._get_all_tensors(model): - if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold): - transformed = True - current_file_size += tsize - if current_file_size > file_chunk_size: - file_num += 1 - current_file_size = tsize - external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data") - return model, transformed + + if do_fp16 or do_split: + for tensor in external_data_helper._get_all_tensors(model): + if do_fp16 and FP16ClipTransform.apply(tensor, onnx_base_dir, fp16_max, fp16_min): + f16_applied = True + applied[FP16ClipTransform] = f16_applied + + if do_split and tensor.HasField("raw_data"): + tsize = len(tensor.raw_data) + if tsize > size_threshold: + if file_num_tracker["size"] + tsize > file_chunk_size: + file_num_tracker["num"] += 1 + file_num_tracker["size"] = tsize + else: + file_num_tracker["size"] += tsize + applied[SplitTensorsTransform] = True + SplitTensorsTransform.apply(tensor, model_name, file_num_tracker["num"], mapping) + + def _set_external_data(tensor, file_name): + external_data_helper.set_external_data(tensor, file_name) + + max_workers = min(32, (os.cpu_count() or 1) * 4) + logger.info(f"Applying external data mapping with {max_workers} threads") + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_set_external_data, tensor, file_name) for tensor, file_name in mapping.values()] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Failed to set external data: {e}") + + # Non-looping transforms + if CustomOpTransform in requested: + applied[CustomOpTransform] = CustomOpTransform.apply(model) + + if 
RenameFunctionOutputsTransform in requested: + applied[RenameFunctionOutputsTransform] = RenameFunctionOutputsTransform.apply(model) + + if AdapterWeightsToInputsTransform in requested: + applied[AdapterWeightsToInputsTransform] = AdapterWeightsToInputsTransform.apply(model, **kwargs) + + for t, done in applied.items(): + logger.info(f"Transform '{t.__name__}' applied={done}") + + return model, any(applied.values()) diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index a20fc4cb3..e503a057f 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: class SplitGateUpWeightsTransform(PytorchTransform): """ - split fused Gate+Up weights and copy into the model + Split fused Gate+Up weights and copy into the model. + Handles both standard MoE models and GptOss models. For every transformer layer inside `model`: - • expects .experts.gate_up_proj in the *source* `sd` - • copies halves into - .experts.gate_proj <-- Gate [E,H,I] - .experts.up_proj <-- Up [E,H,I] + • expects .experts.gate_up_proj in the *source* `sd` + • copies halves into + .experts.gate_proj <-- Gate [E,H,I] + .experts.up_proj <-- Up [E,H,I] + + Handles both interleaved weights (GptOss) and concatenated weights (standard MoE). + Also handles bias terms when present. """ @classmethod def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__ - if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS: return model, transformed model_tmp = model.language_model if hasattr(model, "language_model") else model - num_layers = len(model_tmp.model.layers) delete_fused_key = True sd = model_tmp.state_dict() + for layer_idx in range(num_layers): + # Determine if this is a GptOss model or standard MoE model + is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp") + # ---- build the textual prefix once per layer ---------- - prefix = f"model.layers.{layer_idx}.feed_forward.experts." + if is_gpt_oss: + prefix = f"model.layers.{layer_idx}.mlp.experts." + experts = model_tmp.model.layers[layer_idx].mlp.experts + else: + prefix = f"model.layers.{layer_idx}.feed_forward.experts." + experts = model_tmp.model.layers[layer_idx].feed_forward.experts fused_key = prefix + "gate_up_proj" gate_key = prefix + "gate_proj" up_key = prefix + "up_proj" - # ---- split [E,H,2I] → two [E,H,I] tensors ---------------------- - fused = sd[fused_key] # [E, H, 2I] (no .weight here) + # Check if we have bias terms (GptOss case) + has_bias = fused_key + "_bias" in sd + if has_bias: + fused_bias_key = fused_key + "_bias" + gate_bias_key = gate_key + "_bias" + up_bias_key = up_key + "_bias" + + # ---- split weights based on model type ---------------------- + fused = sd[fused_key] # [E, H, 2I] E, H, two_I = fused.shape - ffn_dim = two_I // 2 - gate, up = fused.split(ffn_dim, dim=-1) # views – no copy - experts = model_tmp.model.layers[layer_idx].feed_forward.experts + if is_gpt_oss: + # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...] 
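+                # The strided slices below are views into the fused tensor; the copy_() calls
+                # afterwards materialize them into gate_proj / up_proj.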
+ gate = fused[..., ::2] # [E, H, I] - even indices + up = fused[..., 1::2] # [E, H, I] - odd indices + else: + # For standard MoE, gate/up are concatenated: [gate, up] + ffn_dim = two_I // 2 + gate, up = fused.split(ffn_dim, dim=-1) # views – no copy + + # Copy weights to model experts.gate_proj.data.copy_(gate) experts.up_proj.data.copy_(up) + # Handle bias if present + if has_bias: + fused_bias = sd[fused_bias_key] # [E, 2I] + + if is_gpt_oss: + gate_bias = fused_bias[..., ::2] # [E, I] - even indices + up_bias = fused_bias[..., 1::2] # [E, I] - odd indices + else: + ffn_dim = fused_bias.shape[-1] // 2 + gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1) + + experts.gate_proj_bias.data.copy_(gate_bias) + experts.up_proj_bias.data.copy_(up_bias) + # ---- update the state-dict so load_state_dict sees the right keys sd[gate_key] = gate sd[up_key] = up + if has_bias: + sd[gate_bias_key] = gate_bias + sd[up_bias_key] = up_bias + + # Delete fused keys if delete_fused_key: del sd[fused_key] + if has_bias: + del sd[fused_bias_key] - logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") + logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})") transformed = True if hasattr(model, "language_model"): model.language_model = model_tmp else: model = model_tmp + return model, transformed -VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"} +# Keep the existing list of supported models +VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"} diff --git a/QEfficient/cloud/execute.py b/QEfficient/cloud/execute.py index 27ea529cd..09e989ea0 100644 --- a/QEfficient/cloud/execute.py +++ b/QEfficient/cloud/execute.py @@ -115,7 +115,7 @@ def main( "--prompts_txt_file_path", "--prompts-txt-file-path", type=str, - help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder", + help="File path for taking input prompts from txt file, sample prompts.txt file present in examples/sample_prompts folder", ) parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate") parser.add_argument( diff --git a/QEfficient/cloud/finetune_experimental.py b/QEfficient/cloud/finetune_experimental.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/cloud/finetune_experimental.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 814122b9d..ef05d29ab 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -340,6 +340,18 @@ def main( "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." 
     )
     parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
+    parser.add_argument(
+        "--comp-ctx-lengths-prefill",
+        type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")],
+        default=[512],
+        help="Comma-separated list of compute-context-lengths (CCL) for prefill (e.g., --comp-ctx-lengths-prefill 512,1024,2048).",
+    )
+    parser.add_argument(
+        "--comp-ctx-lengths-decode",
+        type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")],
+        default=[2048],
+        help="Comma-separated list of compute-context-lengths (CCL) for decode (e.g., --comp-ctx-lengths-decode 512,1024,2048).",
+    )
     parser.add_argument(
         "--mxfp6",
         "--mxfp6_matmul",
@@ -378,7 +390,7 @@ def main(
         "--prompts_txt_file_path",
         "--prompts-txt-file-path",
         type=str,
-        help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder",
+        help="File path for taking input prompts from txt file, sample prompts.txt file present in examples/sample_prompts folder",
     )
     parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate")
     parser.add_argument(
diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py
index ff0709f82..35830aa91 100644
--- a/QEfficient/customop/__init__.py
+++ b/QEfficient/customop/__init__.py
@@ -5,8 +5,15 @@
 #
 # -----------------------------------------------------------------------------
-from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
+from QEfficient.customop.ctx_scatter_gather import (
+    CtxGatherFunc,
+    CtxGatherFunc3D,
+    CtxGatherFuncBlockedKV,
+    CtxScatterFunc,
+    CtxScatterFunc3D,
+)
 from QEfficient.customop.ctx_scatter_gather_cb import (
+    CtxGatherFuncBlockedKVCB,
     CtxGatherFuncCB,
     CtxGatherFuncCB3D,
     CtxScatterFuncCB,
@@ -16,12 +23,14 @@
 __all__ = [
     "CtxGatherFunc",
+    "CtxGatherFuncBlockedKV",
     "CtxScatterFunc",
     "CtxGatherFunc3D",
     "CtxScatterFunc3D",
     "CustomRMSNormAIC",
     "GemmaCustomRMSNormAIC",
     "CtxGatherFuncCB",
+    "CtxGatherFuncBlockedKVCB",
     "CtxScatterFuncCB",
     "CtxGatherFuncCB3D",
     "CtxScatterFuncCB3D",
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index c4f5a7bbd..7b15effe7 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)
@@ -126,6 +132,33 @@ class CtxGatherFunc(torch.autograd.Function):
     Function to gather only the valid key values from KV-cache.
""" + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): + batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) + return data[batch_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: + ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) + return ops.GatherND(data, ctx_indices, batch_dims=2) + + +class CtxGatherFuncBlockedKV(torch.autograd.Function): + """ + Function to gather only the valid key values from KV-cache. + """ + @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) @@ -138,4 +171,4 @@ def setup_context(ctx, inputs, outputs): @staticmethod def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) + return g.onnxscript_op(CtxGatherBlockedKV, data, ctx_indices).setTypeAs(data) diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 75d9a12ef..c15b60810 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -97,11 +97,56 @@ def symbolic( @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB( + data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 +) -> onnxscript.FLOAT: + batch_size = ops.Gather(ops.Shape(batch_index), [0]) + num_heads = ops.Gather(ops.Shape(data), [1]) + # using compute-context-length (CCL) instead of context-length to do gather process based on CCL and later do attention computations based on CCL as well. 
+ ctx_len = ops.Reshape(comp_ctx_len, [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + exp_shape = ops.Concat( + ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0 + ) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape) + head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) + ctx_idx = ops.Expand(ops.Unsqueeze(ctx_indices, [3]), exp_shape) + indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) + + return ops.GatherND(data, indices) + + +class CtxGatherFuncCB(torch.autograd.Function): + @staticmethod + def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): + batch_indices = batch_index.view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices) + return data[batch_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic( + g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int + ) -> torch.Value: + return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGatherBlockedKVCB( data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 ) -> onnxscript.FLOAT: batch_size = ops.Gather(ops.Shape(batch_index), [0]) num_heads = ops.Gather(ops.Shape(data), [1]) - ctx_len = ops.Gather(ops.Shape(data), [2]) + ctx_len = ops.Gather(ops.Shape(ctx_indices), [2]) # Expanded shape to create indices zero = ops.Constant(value_ints=[0]) @@ -117,7 +162,7 @@ def CtxGatherCB( return ops.GatherND(data, indices) -class CtxGatherFuncCB(torch.autograd.Function): +class CtxGatherFuncBlockedKVCB(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = batch_index.view(-1, 1, 1) @@ -130,7 +175,7 @@ def setup_context(ctx, inputs, outputs): @staticmethod def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data) + return g.onnxscript_op(CtxGatherBlockedKVCB, data, batch_index, ctx_indices).setTypeAs(data) @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md new file mode 100644 index 000000000..4777d48fb --- /dev/null +++ b/QEfficient/diffusers/README.md @@ -0,0 +1,95 @@ + +
+ + +# **Diffusion Models on Qualcomm Cloud AI 100** + + +
+
+### 🎨 **Experience the Future of AI Image Generation**
+
+*Optimized for Qualcomm Cloud AI 100*
+
+Sample Output
+
+**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale • ⚡
+
+
+ + + +[![Diffusers](https://img.shields.io/badge/Diffusers-0.35.1-orange.svg)](https://github.com/huggingface/diffusers) +
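+
+---
+
+## 🚀 Quick Example
+
+A minimal text-to-image sketch. This assumes `QEffFluxPipeline` mirrors the upstream HuggingFace
+Diffusers `FluxPipeline` API (`from_pretrained` plus a prompt call); see the scripts under
+[`examples/diffusers/`](../../examples/diffusers/) for the exact supported invocation and compile options.
+
+```python
+from QEfficient import QEffFluxPipeline
+
+# Load FLUX.1-schnell (assumed diffusers-style API; exact kwargs may differ)
+pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+
+# 4 steps and guidance_scale=0.0 match the sample shown above
+image = pipeline("A girl laughing", num_inference_steps=4, guidance_scale=0.0).images[0]
+image.save("sample_output.png")
+```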
+ +--- + +## ✨ Overview + +QEfficient Diffusers brings the power of state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipeline provides seamless inference on Qualcomm Cloud AI 100 hardware. + +## 🛠️ Installation + +### Prerequisites + +Ensure you have Python 3.8+ and the required dependencies: + +```bash +# Create Python virtual environment (Recommended Python 3.10) +sudo apt install python3.10-venv +python3.10 -m venv qeff_env +source qeff_env/bin/activate +pip install -U pip +``` + +### Install QEfficient + +```bash +# Install from GitHub (includes diffusers support) +pip install git+https://github.com/quic/efficient-transformers + +# Or build from source +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install build wheel +python -m build --wheel --outdir dist +pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl +``` + +--- + +## 🎯 Supported Models +- ✅ [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) +- ✅ [`lightx2v/Wan2.2-Lightning`](https://huggingface.co/lightx2v/Wan2.2-Lightning) + +--- + + +## 📚 Examples + +Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory: + +--- + +## 🤝 Contributing + +We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details. + + + +--- + +## 🙏 Acknowledgments + +- **HuggingFace Diffusers**: For the excellent foundation library +--- + +## 📞 Support + +- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/) +- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues) + +--- + diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/models/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/modeling_utils.py b/QEfficient/diffusers/models/modeling_utils.py new file mode 100644 index 000000000..59727be2d --- /dev/null +++ b/QEfficient/diffusers/models/modeling_utils.py @@ -0,0 +1,456 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import math +import os +from typing import Optional + +import torch + + +def get_attention_blocking_config(): + """ + Get attention blocking configuration from environment variables. 
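+
+    The following environment variables are read (unset or 0 falls back to None, and the
+    mode falls back to "default"):
+        ATTENTION_BLOCKING_MODE: one of "kv", "q", "qkv", "default"
+        head_block_size: number of attention heads per block
+        num_kv_blocks: number of key-value blocks
+        num_q_blocks: number of query blocks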
+ + Returns: + tuple: (blocking_mode, head_block_size, num_kv_blocks, num_q_blocks) + - blocking_mode (str): The blocking strategy ('kv', 'q', 'qkv', 'default') + - head_block_size (int or None): Number of attention heads per block + - num_kv_blocks (int or None): Number of key-value blocks + - num_q_blocks (int or None): Number of query blocks + """ + mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower() + head_block_size = int(os.environ.get("head_block_size", 0)) or None + num_kv_blocks = int(os.environ.get("num_kv_blocks", 0)) or None + num_q_blocks = int(os.environ.get("num_q_blocks", 0)) or None + + # Validate blocking mode + valid_modes = ["kv", "qkv", "q", "default"] + if mode not in valid_modes: + raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {mode}. Must be one of {valid_modes}") + + return mode, head_block_size, num_kv_blocks, num_q_blocks + + +def apply_head_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with head-only blocking (default mode). + + This method processes attention heads in blocks while computing full attention + matrices for each head block. It's less memory-efficient than other blocking + modes but simpler and faster for moderate sequence lengths. + + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get head blocking configuration + head_block_size = head_block_size or NH + num_head_blocks = math.ceil(NH / head_block_size) + + # Optimization: Handle small sequences with standard attention + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + outputs = [] + + # Process each head block independently + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + + # Extract head blocks + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + + # Compute full attention matrix for this head block + qkblock = torch.matmul(q_g, k_g.transpose(-2, -1)) * scale_factor + + # Standard softmax computation + probs = torch.softmax(qkblock, dim=-1) + + # Compute attention output + output_blocks = torch.matmul(probs, v_g) + outputs.append(output_blocks) + + # Concatenate all head blocks along head dimension + out = torch.cat(outputs, dim=1) # (BS, NH, CL, DH) + return out + + +def apply_kv_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_kv_blocks: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with Key-Value blocking and head blocking. + + This method processes key-value pairs in blocks while keeping queries intact. 
+ It uses online softmax to maintain numerical stability and reduce memory usage + compared to computing full attention matrices. + + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get blocking configuration + head_block_size = head_block_size or NH + num_kv_blocks = num_kv_blocks or CL + num_head_blocks = math.ceil(NH / head_block_size) + block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)] + + # Handle small sequences with standard attention + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + head_outputs = [] + + # Process each head block + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + num_h = h_end - h_start + + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + + # Initialize online softmax statistics + running_exp_sum = torch.zeros((BS, num_h, CL), device=q.device, dtype=q.dtype) + running_max = torch.full((BS, num_h, CL), float("-inf"), device=q.device, dtype=q.dtype) + output_blocks = torch.zeros_like(q_g) + + # Process K,V in blocks using online softmax + for kv_block_idx in range(num_kv_blocks): + ki = block_positions[kv_block_idx] + + # Calculate KV block size + if kv_block_idx == num_kv_blocks - 1: + real_kv_len = CL - ki + else: + real_kv_len = block_positions[kv_block_idx + 1] - ki + + k_block = k_g[:, :, ki : ki + real_kv_len, :] + v_block = v_g[:, :, ki : ki + real_kv_len, :] + + # Compute attention scores for current KV block + qkblock = torch.matmul(q_g, k_block.transpose(-2, -1)) * scale_factor + + # Online softmax: Update running maximum + prev_max = running_max.clone() + running_max = torch.maximum(prev_max, torch.max(qkblock, dim=-1)[0]) + + # Calculate numerical stability adjustments + delta_max = prev_max - running_max + curr_exp = torch.exp(qkblock - running_max.unsqueeze(-1)) + + # Update running sum of exponentials + prev_exp_sum = running_exp_sum.clone() + curr_exp_sum = torch.einsum("bhqk->bhq", curr_exp) + running_exp_sum = prev_exp_sum * torch.exp(delta_max) + curr_exp_sum + + # Compute normalized attention weights + inv_running_exp_sum = 1.0 / running_exp_sum + softmax_qkblock = curr_exp * inv_running_exp_sum.unsqueeze(-1) + + # Update output with rescaling + prev_out = output_blocks.clone() + rescale_factor = (prev_exp_sum * inv_running_exp_sum) * torch.exp(delta_max) + output_blocks = rescale_factor.unsqueeze(-1) * prev_out + torch.matmul(softmax_qkblock, v_block) + + head_outputs.append(output_blocks) + + out = torch.cat(head_outputs, dim=1) # (BS, NH, CL, DH) + return out + + +def apply_q_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_q_blocks: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with Query blocking and head 
blocking. + + This method processes query tokens in blocks while keeping key-value pairs intact. + It's useful when the sequence length is large but memory constraints are primarily + due to the query dimension. + + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get blocking configuration + head_block_size = head_block_size or NH + num_q_blocks = num_q_blocks or CL + num_head_blocks = math.ceil(NH / head_block_size) + q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)] + + # Handle small sequences with standard attention + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + head_outputs = [] + + # Process each head block + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + + q_output_list = [] + + # Process queries in blocks + for q_block_idx in range(num_q_blocks): + qi = q_block_positions[q_block_idx] + + # Calculate Q block size + if q_block_idx == num_q_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + q_block = q_g[:, :, qi : qi + real_q_len, :] + + # Compute attention for this query block against all keys + scores = torch.matmul(q_block, k_g.transpose(-2, -1)) * scale_factor + probs = torch.softmax(scores, dim=-1) + out_block = torch.matmul(probs, v_g) + + q_output_list.append(out_block) + + # Concatenate query blocks + head_output = torch.cat(q_output_list, dim=2) + head_outputs.append(head_output) + + out = torch.cat(head_outputs, dim=1) # (BS, NH, CL, DH) + return out + + +def apply_qkv_blocking( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_kv_blocks: int, + num_q_blocks: int, + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Forward pass with combined Query, Key, Value blocking and head blocking. + + This method implements the most memory-efficient attention computation by blocking + along all three dimensions: heads, queries, and key-values. 
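+    For each (head block, query block) pair, the key-value blocks are streamed through an
+    online softmax, so the full attention matrix is never materialized at once.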
+ + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + BS, NH, CL, DH = q.shape + scale_factor = 1.0 / math.sqrt(DH) + + # Get blocking configuration from environment variables + head_block_size = head_block_size or NH + num_kv_blocks = num_kv_blocks or CL + num_q_blocks = num_q_blocks or CL + num_head_blocks = math.ceil(NH / head_block_size) + + # Calculate block positions for even distribution + kv_block_positions = [(i * CL) // num_kv_blocks for i in range(num_kv_blocks)] + q_block_positions = [(i * CL) // num_q_blocks for i in range(num_q_blocks)] + + # Optimization: Use standard attention for small sequences + BS, NH, K_CL, DH = k.shape + if K_CL <= 512: + scores = torch.matmul(q, k.transpose(-2, -1)) * scale_factor + if attention_mask is not None: + scores = torch.where(attention_mask, scores, torch.tensor(-1e4, dtype=scores.dtype, device=scores.device)) + probs = torch.softmax(scores, dim=-1) + out = torch.matmul(probs, v) + return out + + head_outputs = [] + + # Process attention heads in blocks to reduce memory usage + for head_block_idx in range(num_head_blocks): + h_start = head_block_idx * head_block_size + h_end = min(h_start + head_block_size, NH) + num_h = h_end - h_start + + # Extract current head block + q_g = q[:, h_start:h_end, :, :] + k_g = k[:, h_start:h_end, :, :] + v_g = v[:, h_start:h_end, :, :] + q_output_list = [] + + # Process queries in blocks within each head block + for q_block_idx in range(num_q_blocks): + qi = q_block_positions[q_block_idx] + + # Calculate actual Q block size (handle remainder for last block) + if q_block_idx == num_q_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + q_block = q_g[:, :, qi : qi + real_q_len, :] + + # Initialize online softmax statistics for this Q block + running_exp_sum = torch.zeros((BS, num_h, real_q_len), device=q.device, dtype=q.dtype) + running_max = torch.full((BS, num_h, real_q_len), float("-inf"), device=q.device, dtype=q.dtype) + output_blocks = torch.zeros((BS, num_h, real_q_len, DH), device=q.device, dtype=q.dtype) + + # Process K,V in blocks for this Q block (online softmax) + for kv_block_idx in range(num_kv_blocks): + ki = kv_block_positions[kv_block_idx] + + # Calculate actual KV block size + if kv_block_idx == num_kv_blocks - 1: + real_kv_len = CL - ki + else: + real_kv_len = kv_block_positions[kv_block_idx + 1] - ki + + k_block = k_g[:, :, ki : ki + real_kv_len, :] + v_block = v_g[:, :, ki : ki + real_kv_len, :] + + # Compute attention scores for current Q-K block + qkblock = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale_factor + + # Online softmax: Update running maximum + prev_max = running_max.clone() + if qkblock.shape[-1] == 0: + running_max = prev_max + else: + running_max = torch.maximum(prev_max, torch.max(qkblock, dim=-1)[0]) + + # Calculate adjustment factor for numerical stability + delta_max = prev_max - running_max + curr_exp = torch.exp(qkblock - running_max.unsqueeze(-1)) + + # Online softmax: Update running sum of exponentials + prev_exp_sum = running_exp_sum.clone() + curr_exp_sum = torch.einsum("bhqk->bhq", curr_exp) + running_exp_sum = prev_exp_sum * torch.exp(delta_max) + curr_exp_sum + + # Compute normalized attention weights for 
this block + inv_running_exp_sum = 1.0 / running_exp_sum + softmax_qkblock = curr_exp * inv_running_exp_sum.unsqueeze(-1) + + # Online softmax: Update output with rescaling of previous blocks + prev_out = output_blocks.clone() + rescale_factor = (prev_exp_sum * inv_running_exp_sum) * torch.exp(delta_max) + output_blocks = rescale_factor.unsqueeze(-1) * prev_out + torch.matmul(softmax_qkblock, v_block) + + q_output_list.append(output_blocks) + + # Concatenate all Q blocks for this head block + head_output = torch.cat(q_output_list, dim=2) + head_outputs.append(head_output) + + # Concatenate all head blocks + out = torch.cat(head_outputs, dim=1) + return out + + +def compute_blocked_attention( + q: torch.FloatTensor, + k: torch.FloatTensor, + v: torch.FloatTensor, + head_block_size: int, + num_kv_blocks: int, + num_q_blocks: int, + blocking_mode: str = "default", + attention_mask: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + """ + Main dispatcher function for different attention blocking strategies. + + Args: + q (torch.FloatTensor): Query tensor of shape (BS, NH, CL, DH) + k (torch.FloatTensor): Key tensor of shape (BS, NH, CL, DH) + v (torch.FloatTensor): Value tensor of shape (BS, NH, CL, DH) + head_block_size (int) : Head blocking size + num_kv_blocks (int) : Number of KV blocks + num_q_blocks (int) : Number of Q blocks + blocking_mode (str): Blocking strategy ('kv', 'q', 'qkv', 'default') + attention_mask (Optional[torch.FloatTensor]): Attention mask tensor + + Returns: + torch.FloatTensor: Attention output of shape (BS, NH, CL, DH) + """ + if blocking_mode == "kv": + return apply_kv_blocking(q, k, v, head_block_size, num_kv_blocks, attention_mask) + elif blocking_mode == "q": + return apply_q_blocking(q, k, v, head_block_size, num_q_blocks, attention_mask) + elif blocking_mode == "qkv": + return apply_qkv_blocking(q, k, v, head_block_size, num_kv_blocks, num_q_blocks, attention_mask) + else: # default + return apply_head_blocking(q, k, v, head_block_size, attention_mask) diff --git a/QEfficient/diffusers/models/normalization.py b/QEfficient/diffusers/models/normalization.py new file mode 100644 index 000000000..933832ed8 --- /dev/null +++ b/QEfficient/diffusers/models/normalization.py @@ -0,0 +1,40 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Optional
+
+import torch
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+class QEffAdaLayerNormZero(AdaLayerNormZero):
+    def forward(
+        self,
+        x: torch.Tensor,
+        shift_msa: Optional[torch.Tensor] = None,
+        scale_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x
+
+
+class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
+    def forward(
+        self,
+        x: torch.Tensor,
+        scale_msa: Optional[torch.Tensor] = None,
+        shift_msa: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x
+
+
+class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        emb = conditioning_embedding
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py
new file mode 100644
index 000000000..4fb5c3f12
--- /dev/null
+++ b/QEfficient/diffusers/models/pytorch_transforms.py
@@ -0,0 +1,65 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.models.transformers.transformer_flux import (
+    FluxAttention,
+    FluxAttnProcessor,
+    FluxSingleTransformerBlock,
+    FluxTransformer2DModel,
+    FluxTransformerBlock,
+)
+from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor, WanTransformer3DModel
+from torch import nn
+
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.customop.rms_norm import CustomRMSNormAIC
+from QEfficient.diffusers.models.normalization import (
+    QEffAdaLayerNormContinuous,
+    QEffAdaLayerNormZero,
+    QEffAdaLayerNormZeroSingle,
+)
+from QEfficient.diffusers.models.transformers.transformer_flux import (
+    QEffFluxAttention,
+    QEffFluxAttnProcessor,
+    QEffFluxSingleTransformerBlock,
+    QEffFluxTransformer2DModel,
+    QEffFluxTransformerBlock,
+)
+from QEfficient.diffusers.models.transformers.transformer_wan import (
+    QEffWanAttention,
+    QEffWanAttnProcessor,
+    QEffWanTransformer3DModel,
+)
+
+
+class CustomOpsTransform(ModuleMappingTransform):
+    _module_mapping = {
+        RMSNorm: CustomRMSNormAIC,
+        nn.RMSNorm: CustomRMSNormAIC,  # for torch.nn.RMSNorm
+    }
+
+
+class AttentionTransform(ModuleMappingTransform):
+    _module_mapping = {
+        FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
+        FluxTransformerBlock: QEffFluxTransformerBlock,
+        FluxTransformer2DModel: QEffFluxTransformer2DModel,
+        FluxAttention: QEffFluxAttention,
+        FluxAttnProcessor: QEffFluxAttnProcessor,
+        WanAttnProcessor: QEffWanAttnProcessor,
+        WanAttention: QEffWanAttention,
+        WanTransformer3DModel: QEffWanTransformer3DModel,
+    }
+
+
+class NormalizationTransform(ModuleMappingTransform):
+    _module_mapping = {
+        AdaLayerNormZero: QEffAdaLayerNormZero,
+        AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
+        AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
+    }
diff --git a/QEfficient/diffusers/models/transformers/__init__.py b/QEfficient/diffusers/models/transformers/__init__.py
new file mode 100644
index 000000000..75daf1953
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py
new file mode 100644
index 000000000..40b7e3e7e
--- /dev/null
+++ b/QEfficient/diffusers/models/transformers/transformer_flux.py
@@ -0,0 +1,339 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import (
+    FluxAttention,
+    FluxAttnProcessor,
+    FluxSingleTransformerBlock,
+    FluxTransformer2DModel,
+    FluxTransformerBlock,
+    _get_qkv_projections,
+)
+
+from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention, get_attention_blocking_config
+from QEfficient.utils.logging_utils import logger
+
+
+def qeff_apply_rotary_emb(
+    x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to an input tensor using the given frequency tensors. This function applies rotary
+    embeddings to the given query or key tensor 'x' using the provided cosine and sine tensors in 'freqs_cis'.
+    The input is viewed as interleaved real/imaginary pairs, rotated, and returned as a real tensor with the
+    same shape and dtype as the input.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings to. [B, S, H, D]
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensors (cos, sin), each of shape [S, D].
+
+    Returns:
+        torch.Tensor: The query or key tensor with rotary embeddings applied.
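+
+    Example (shapes only; values are illustrative):
+        >>> x = torch.randn(1, 128, 24, 128)  # [B, S, H, D]
+        >>> cos = sin = torch.randn(128, 128)  # [S, D]
+        >>> qeff_apply_rotary_emb(x, (cos, sin)).shape
+        torch.Size([1, 128, 24, 128])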
+ """ + cos, sin = freqs_cis # [S, D] + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + cos, sin = cos.to(x.device), sin.to(x.device) + B, S, H, D = x.shape + x_real, x_imag = x.reshape(B, -1, H, D // 2, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +class QEffFluxAttnProcessor(FluxAttnProcessor): + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "QEffFluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = qeff_apply_rotary_emb(query, image_rotary_emb) + key = qeff_apply_rotary_emb(key, image_rotary_emb) + + # Get blocking configuration + blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config() + # Apply blocking using pipeline_utils + hidden_states = compute_blocked_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + blocking_mode=blocking_mode, + head_block_size=head_block_size, + num_kv_blocks=num_kv_blocks, + num_q_blocks=num_q_blocks, + attention_mask=attention_mask, + ) + + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class QEffFluxAttention(FluxAttention): + def __qeff_init__(self): + processor = QEffFluxAttnProcessor() + self.processor = processor + + +class QEffFluxSingleTransformerBlock(FluxSingleTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + shift_msa, scale_msa, gate = torch.split(temb, 1) + residual = hidden_states + norm_hidden_states = self.norm(hidden_states, scale_msa, shift_msa) + mlp_hidden_states = 
self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + # if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(torch.finfo(torch.float32).min, torch.finfo(torch.float32).max) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class QEffFluxTransformerBlock(FluxTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + temb1 = tuple(torch.split(temb[:6], 1)) + temb2 = tuple(torch.split(temb[6:], 1)) + norm_hidden_states = self.norm1(hidden_states, shift_msa=temb1[0], scale_msa=temb1[1]) + gate_msa, shift_mlp, scale_mlp, gate_mlp = temb1[-4:] + + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, shift_msa=temb2[0], scale_msa=temb2[1]) + + c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = temb2[-4:] + + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + # if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QEffFluxTransformer2DModel(FluxTransformer2DModel): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + adaln_emb: torch.Tensor = None, + adaln_single_emb: torch.Tensor = None, + adaln_out: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_single_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] + + hidden_states = self.norm_out(hidden_states, adaln_out) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py new file mode 100644 index 000000000..31d3be2ce --- /dev/null +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -0,0 +1,291 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +""" +QEfficient WAN Transformer Implementation + +This module provides optimized implementations of WAN transformers +with various attention blocking strategies for memory efficiency and performance optimization. +The implementation includes multiple blocking modes: head-only, KV-blocking, Q-blocking, +and combined QKV-blocking. 
+""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.transformers.transformer_wan import ( + WanAttention, + WanAttnProcessor, + WanTransformer3DModel, + _get_qkv_projections, +) +from diffusers.utils import set_weights_and_activate_adapters + +from QEfficient.diffusers.models.modeling_utils import ( + compute_blocked_attention, + get_attention_blocking_config, +) + + +class QEffWanAttnProcessor(WanAttnProcessor): + """ + QEfficient WAN Attention Processor with Memory-Efficient Blocking Strategies. + + This processor implements multiple attention blocking modes to reduce memory usage + and enable processing of longer sequences. It supports: + - Head blocking: Process attention heads in chunks + - KV blocking: Process key-value pairs in blocks + - Q blocking: Process query tokens in blocks + - QKV blocking: Combined query, key, and value blocking + + Environment Variables: + ATTENTION_BLOCKING_MODE: Controls blocking strategy ('kv', 'q', 'qkv', 'default') + head_block_size: Number of attention heads to process per block + num_kv_blocks: Number of blocks for key-value processing + num_q_blocks: Number of blocks for query processing + """ + + def __call__( + self, + attn: "WanAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Main attention processing pipeline with support for multiple blocking strategies. + + This method orchestrates the complete attention computation including: + 1. QKV projection and normalization + 2. Rotary position embedding application + 3. Attention computation with selected blocking strategy + 4. 
Output projection + + Args: + attn (WanAttention): The attention module instance + hidden_states (torch.Tensor): Input hidden states + encoder_hidden_states (Optional[torch.Tensor]): Cross-attention encoder states + attention_mask (Optional[torch.Tensor]): Attention mask + rotary_emb (Optional[Tuple[torch.Tensor, torch.Tensor]]): Rotary embeddings (cos, sin) + + Returns: + torch.Tensor: Processed hidden states after attention + """ + # Project inputs to query, key, value + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + # Apply layer normalization to queries and keys + query = attn.norm_q(query) + key = attn.norm_k(key) + + # Reshape for multi-head attention: (batch, seq, dim) -> (batch, seq, heads, head_dim) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # Apply rotary position embeddings if provided + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + """Apply rotary position embeddings to the input tensor.""" + # Split into real and imaginary parts for complex rotation + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2].type_as(hidden_states) + sin = freqs_sin[..., 1::2].type_as(hidden_states) + + # Apply rotation: (x1 + ix2) * (cos + isin) = (x1*cos - x2*sin) + i(x1*sin + x2*cos) + real = x1 * cos - x2 * sin + img = x1 * sin + x2 * cos + x_rot = torch.stack([real, img], dim=-1) + return x_rot.flatten(-2).type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # Get blocking configuration + blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config() + # Apply blocking using pipeline_utils + hidden_states = compute_blocked_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + head_block_size, + num_kv_blocks, + num_q_blocks, + blocking_mode=blocking_mode, + attention_mask=attention_mask, + ) + + # Reshape back to original format + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + # Apply output projection layers + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class QEffWanAttention(WanAttention): + """ + QEfficient WAN Attention module with optimized processor. + + This class extends the base WanAttention with QEfficient optimizations, + automatically setting up the QEffWanAttnProcessor for memory-efficient + attention computation. + """ + + def __qeff_init__(self): + """Initialize the QEfficient attention processor.""" + processor = QEffWanAttnProcessor() + self.processor = processor + + +class QEffWanTransformer3DModel(WanTransformer3DModel): + """ + QEfficient 3D WAN Transformer Model with adapter support. + + This model extends the base WanTransformer3DModel with QEfficient optimizations. + """ + + def set_adapters( + self, + adapter_names: Union[List[str], str], + weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None, + ): + """ + Set the currently active adapters for use in the diffusion network. + + This method manages PEFT adapters, allowing for efficient fine-tuning + and model customization without modifying the base model parameters. 
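+
+        A usage sketch (adapter names here are illustrative):
+
+            >>> transformer.set_adapters(["style_lora", "detail_lora"], weights=[0.8, 0.5])
+            >>> transformer.set_adapters("style_lora")  # single adapter, default weight 1.0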
+ + Args: + adapter_names (Union[List[str], str]): Names of adapters to activate + weights (Optional[Union[float, Dict, List[float], List[Dict], List[None]]]): + Weights for each adapter. Can be: + - Single float: Applied to all adapters + - List of floats: One weight per adapter + - Dict: Detailed weight configuration + - None: Uses default weight of 1.0 + + Raises: + ValueError: If adapter names and weights lists have different lengths + + Note: + - Adapters enable parameter-efficient fine-tuning + - Multiple adapters can be active simultaneously with different weights + - Weights control the influence of each adapter on the model output + """ + # Normalize adapter names to list format + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + # Expand weights into a list, one entry per adapter + # Examples for 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None] + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}." + ) + + # Set None values to default of 1.0 + # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0] + weights = [w if w is not None else 1.0 for w in weights] + + # Expand weights using model-specific scaling function + # e.g. [{...}, 7] -> [{expanded dict...}, 7] + scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[ + self.config._class_name + ] # updated to use WanTransformer3DModel + weights = scale_expansion_fn(self, weights) + set_weights_and_activate_adapters(self, adapter_names, weights) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + rotary_emb: torch.Tensor, + temb: torch.Tensor, + timestep_proj: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass of the 3D WAN Transformer. + + This method implements the complete forward pass including: + 1. Patch embedding of input + 2. Rotary embedding preparation + 3. Cross-attention with encoder states + 4. Transformer block processing + 5. 
Output normalization and projection + + Args: + hidden_states (torch.Tensor): Input tensor to transform + encoder_hidden_states (torch.Tensor): Cross-attention encoder states + rotary_emb (torch.Tensor): Rotary position embeddings + temb (torch.Tensor): Time embedding for diffusion process + timestep_proj (torch.Tensor): Projected timestep embeddings + encoder_hidden_states_image (Optional[torch.Tensor]): Image encoder states for I2V + return_dict (bool): Whether to return a dictionary or tuple + attention_kwargs (Optional[Dict[str, Any]]): Additional attention arguments + + Returns: + Union[torch.Tensor, Dict[str, torch.Tensor]]: + Transformed hidden states, either as tensor or in a dictionary + """ + # Prepare rotary embeddings by splitting along batch dimension + rotary_emb = torch.split(rotary_emb, 1, dim=0) + + # Apply patch embedding and reshape for transformer processing + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) + + # Concatenate image and text encoder states if image conditioning is present + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # Standard forward pass + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # Output normalization, projection & unpatchify + if temb.ndim == 3: + # Handle 3D time embeddings: batch_size, seq_len, inner_dim (WAN 2.2 T2V) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # Handle 2D time embeddings: batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Ensure tensors are on the same device as hidden_states + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + # Apply adaptive layer normalization with time conditioning + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + + # Final output projection + hidden_states = self.proj_out(hidden_states) + + # Store output for return (compiler optimization) + output = hidden_states + + # Return in requested format + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/configs/flux_config.json b/QEfficient/diffusers/pipelines/configs/flux_config.json new file mode 100644 index 000000000..73b92265f --- /dev/null +++ b/QEfficient/diffusers/pipelines/configs/flux_config.json @@ -0,0 +1,99 @@ +{ + "description": "Default configuration for Flux pipeline", + + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only": true + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + } + } +} diff --git a/QEfficient/diffusers/pipelines/configs/wan_config.json b/QEfficient/diffusers/pipelines/configs/wan_config.json new file mode 100644 index 000000000..3f5edce07 --- /dev/null +++ b/QEfficient/diffusers/pipelines/configs/wan_config.json @@ -0,0 +1,36 @@ +{ + "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)", + "modules": { + "transformer": { + "specializations": [ + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 1 + }, + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 2 + } + ], + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 16, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null + } + } + } +} \ No newline at end of file diff --git a/QEfficient/diffusers/pipelines/flux/__init__.py b/QEfficient/diffusers/pipelines/flux/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 000000000..eeb260c53 --- /dev/null +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,854 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +# TODO: Pipeline Architecture Improvements +# 1. Introduce QEffDiffusionPipeline base class to provide unified export, compile, +# and inference APIs across all diffusion pipelines, promoting code reusability +# and consistent interface design. +# 2. Implement persistent QPC session management strategy to retain/drop compiled model +# sessions in memory across all pipeline modules. + +import os +import time +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from tqdm import tqdm + +from QEfficient.diffusers.pipelines.pipeline_module import ( + QEffFluxTransformerModel, + QEffTextEncoder, + QEffVAE, +) +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ONNX_SUBFUNCTION_MODULE, + ModulePerf, + QEffPipelineOutput, + calculate_compressed_latent_dimension, + compile_modules_parallel, + compile_modules_sequential, + config_manager, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils.logging_utils import logger + + +class QEffFluxPipeline: + """ + QEfficient-optimized Flux pipeline for high-performance text-to-image generation on Qualcomm AI hardware. + + This pipeline provides an optimized implementation of the Flux diffusion model specifically designed + for deployment on Qualcomm AI Cloud (QAIC) devices. It wraps the original HuggingFace Flux model + components with QEfficient-optimized versions that can be exported to ONNX format and compiled + into Qualcomm Program Container (QPC) files for efficient inference. + + The pipeline supports the complete Flux workflow including: + - Dual text encoding with CLIP and T5 encoders + - Transformer-based denoising with adaptive layer normalization + - VAE decoding for final image generation + - Performance monitoring and optimization + + Attributes: + text_encoder (QEffTextEncoder): Optimized CLIP text encoder for pooled embeddings + text_encoder_2 (QEffTextEncoder): Optimized T5 text encoder for sequence embeddings + transformer (QEffFluxTransformerModel): Optimized Flux transformer for denoising + vae_decode (QEffVAE): Optimized VAE decoder for latent-to-image conversion + modules (Dict[str, Any]): Dictionary of all pipeline modules for batch operations + model (FluxPipeline): Original HuggingFace Flux model reference + tokenizer: CLIP tokenizer for text preprocessing + scheduler: Diffusion scheduler for timestep management + + Example: + >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> images = pipeline( + ... prompt="A beautiful sunset over mountains", + ... height=512, + ... width=512, + ... num_inference_steps=28 + ... 
) + >>> images.images[0].save("generated_image.png") + """ + + _hf_auto_class = FluxPipeline + + def __init__(self, model, *args, **kwargs): + """ + Initialize the QEfficient Flux pipeline. + + This pipeline provides an optimized implementation of the Flux text-to-image model + for deployment on Qualcomm AI hardware. It wraps the original HuggingFace Flux model + components with QEfficient-optimized versions that can be exported to ONNX and compiled + for QAIC devices. + + Args: + model: Pre-loaded FluxPipeline model + **kwargs: Additional arguments including height and width + """ + + # Wrap model components with QEfficient optimized versions + self.model = model + self.text_encoder = QEffTextEncoder(model.text_encoder) + self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2) + self.transformer = QEffFluxTransformerModel(model.transformer) + self.vae_decode = QEffVAE(model.vae, "decoder") + + # Store all modules in a dictionary for easy iteration during export/compile + self.modules = { + "text_encoder": self.text_encoder, + "text_encoder_2": self.text_encoder_2, + "transformer": self.transformer, + "vae_decoder": self.vae_decode, + } + + # Copy tokenizers and scheduler from the original model + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.text_encoder_2.tokenizer = model.tokenizer_2 + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + + # Override VAE forward method to use decode directly + self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode( + latent_sample, return_dict + ) + + # Sync max position embeddings between text encoders + self.text_encoder_2.model.config.max_position_embeddings = ( + self.text_encoder.model.config.max_position_embeddings + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **kwargs, + ): + """ + Load a pretrained Flux model from HuggingFace Hub or local path and wrap it with QEfficient optimizations. + + This class method provides a convenient way to instantiate a QEffFluxPipeline from a pretrained + Flux model. It automatically loads the base FluxPipeline model in float32 precision on CPU + and wraps all components with QEfficient-optimized versions for QAIC deployment. + + Args: + pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier + (e.g., "black-forest-labs/FLUX.1-schnell") or a local path to a saved model directory. + **kwargs: Additional keyword arguments passed to FluxPipeline.from_pretrained(). + + Returns: + QEffFluxPipeline: A fully initialized pipeline instance with QEfficient-optimized components + ready for export, compilation, and inference on QAIC devices. + + Raises: + ValueError: If the model path is invalid or model cannot be loaded + OSError: If there are issues accessing the model files + RuntimeError: If model initialization fails + + Example: + >>> # Load from HuggingFace Hub + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> + >>> # Load from local path + >>> pipeline = QEffFluxPipeline.from_pretrained("/path/to/local/flux/model") + >>> + >>> # Load with custom cache directory + >>> pipeline = QEffFluxPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... cache_dir="/custom/cache/dir" + ... 
)
+        """
+        # Load the base Flux model in float32 on CPU
+        model = cls._hf_auto_class.from_pretrained(
+            pretrained_model_name_or_path,
+            torch_dtype=torch.float32,
+            device_map="cpu",
+            **kwargs,
+        )
+
+        return cls(
+            model=model,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            **kwargs,
+        )
+
+    def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> None:
+        """
+        Export all pipeline modules to ONNX format for deployment preparation.
+
+        This method systematically exports each pipeline component (CLIP text encoder, T5 text encoder,
+        Flux transformer, and VAE decoder) to ONNX format. Each module is exported with its specific
+        configuration including dynamic axes, input/output specifications, and optimization settings.
+
+        The export process prepares the models for subsequent compilation to QPC format, enabling
+        efficient inference on QAIC hardware. ONNX subfunctions can be used for certain modules
+        to optimize memory usage and performance.
+
+        Args:
+            export_dir (str, optional): Target directory for saving ONNX model files. If None,
+                uses the default export directory structure based on model name and configuration.
+                The directory will be created if it doesn't exist.
+            use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction
+                optimization for supported modules. This can optimize the graph and
+                improve compilation efficiency for models like the transformer.
+
+        Returns:
+            None: Each module stores the path to its exported ONNX file on the module
+                object (e.g., ``self.transformer.onnx_path``), with one subdirectory per module.
+
+        Raises:
+            RuntimeError: If ONNX export fails for any module
+            OSError: If there are issues creating the export directory or writing files
+            ValueError: If module configurations are invalid
+
+        Note:
+            - All models are exported in float32 precision for maximum compatibility
+            - Dynamic axes are configured to support variable batch sizes and sequence lengths
+            - The export process may take several minutes depending on model size
+            - Exported ONNX files can be large (several GB for the complete pipeline)
+
+        Example:
+            >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
+            >>> pipeline.export(
+            ...     export_dir="/path/to/export",
+            ...     use_onnx_subfunctions=True
+            ... )
+        """
+        for module_name, module_obj in tqdm(self.modules.items(), desc="Exporting modules", unit="module"):
+            # Get ONNX export configuration for this module
+            example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params()
+
+            export_params = {
+                "inputs": example_inputs,
+                "output_names": output_names,
+                "dynamic_axes": dynamic_axes,
+                "export_dir": export_dir,
+            }
+
+            if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE:
+                export_params["use_onnx_subfunctions"] = True
+
+            module_obj.export(**export_params)
+
+    @staticmethod
+    def get_default_config_path() -> str:
+        """
+        Get the path to the default Flux pipeline configuration file.
+
+        Returns:
+            str: Path to the flux_config.json file containing default pipeline
+                configuration settings for compilation and device allocation.
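+
+        Example:
+            >>> QEffFluxPipeline.get_default_config_path()
+            'QEfficient/diffusers/pipelines/configs/flux_config.json'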
+ """ + return "QEfficient/diffusers/pipelines/configs/flux_config.json" + + def compile( + self, + compile_config: Optional[str] = None, + parallel: bool = False, + height: int = 512, + width: int = 512, + use_onnx_subfunctions: bool = False, + ) -> None: + """ + Compile ONNX models into optimized QPC format for deployment on Qualcomm AI hardware. + + Args: + compile_config (str, optional): Path to a JSON configuration file containing + compilation settings, device mappings, and optimization parameters. If None, + uses the default configuration from get_default_config_path(). + parallel (bool, default=False): Compilation mode selection: + - True: Compile modules in parallel using ThreadPoolExecutor for faster processing + - False: Compile modules sequentially for lower resource usage + height (int, default=512): Target image height in pixels. + width (int, default=512): Target image width in pixels. + use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX + subfunctions before compilation. + + Raises: + RuntimeError: If compilation fails for any module or if QAIC compiler is not available + FileNotFoundError: If ONNX models haven't been exported or config file is missing + ValueError: If configuration parameters are invalid + OSError: If there are issues with file I/O during compilation + + Example: + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> # Sequential compilation with default config + >>> pipeline.compile(height=1024, width=1024) + >>> + >>> # Parallel compilation with custom config + >>> pipeline.compile( + ... compile_config="/path/to/custom_config.json", + ... parallel=True, + ... height=512, + ... width=512 + ... ) + """ + # Ensure all modules are exported to ONNX before compilation + if any( + path is None + for path in [ + self.text_encoder.onnx_path, + self.text_encoder_2.onnx_path, + self.transformer.onnx_path, + self.vae_decode.onnx_path, + ] + ): + self.export(use_onnx_subfunctions=use_onnx_subfunctions) + + # Load compilation configuration + config_manager(self, config_source=compile_config, use_onnx_subfunctions=use_onnx_subfunctions) + + # Calculate compressed latent dimension using utility function + cl, latent_height, latent_width = calculate_compressed_latent_dimension( + height, width, self.model.vae_scale_factor + ) + + # Prepare dynamic specialization updates based on image dimensions + specialization_updates = { + "transformer": {"cl": cl}, + "vae_decoder": { + "latent_height": latent_height, + "latent_width": latent_width, + }, + } + + # Use generic utility functions for compilation + if parallel: + compile_modules_parallel(self.modules, self.custom_config, specialization_updates) + else: + compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device_ids: Optional[List[int]] = None, + ): + """ + Encode text prompts using the T5 text encoder for detailed semantic understanding. + + T5 provides rich sequence embeddings that capture fine-grained text details, + complementing CLIP's global representation in Flux's dual encoder setup. 
+
+        Args:
+            prompt (str or List[str]): Input prompt(s) to encode
+            num_images_per_prompt (int): Number of images to generate per prompt
+            max_sequence_length (int): Maximum token sequence length (default: 512)
+            device_ids (List[int], optional): QAIC device IDs for inference
+
+        Returns:
+            tuple: (prompt_embeds, inference_time)
+                - prompt_embeds (torch.Tensor): Encoded embeddings [batch*num_images, seq_len, 4096]
+                - inference_time (float): T5 encoder inference time in seconds
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        # Tokenize prompts with padding and truncation
+        text_inputs = self.text_encoder_2.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        # Check for truncation and warn user
+        untruncated_ids = self.text_encoder_2.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.text_encoder_2.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                f"The following part of your input was truncated because `max_sequence_length` is set to "
+                f"{max_sequence_length} tokens: {removed_text}"
+            )
+
+        # Initialize QAIC inference session if not already created
+        if self.text_encoder_2.qpc_session is None:
+            self.text_encoder_2.qpc_session = QAICInferenceSession(
+                str(self.text_encoder_2.qpc_path), device_ids=device_ids
+            )
+
+        # Allocate output buffers for QAIC inference
+        text_encoder_2_output = {
+            "last_hidden_state": np.random.rand(
+                batch_size, max_sequence_length, self.text_encoder_2.model.config.d_model
+            ).astype(np.float32),
+        }
+        self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
+
+        # Prepare input for QAIC inference
+        aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+        # Run T5 encoder inference and measure time
+        start_t5_time = time.perf_counter()
+        prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+        end_t5_time = time.perf_counter()
+        text_encoder_2_perf = end_t5_time - start_t5_time
+
+        # Duplicate embeddings for multiple images per prompt
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds, text_encoder_2_perf
+
+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device_ids: Optional[List[int]] = None,
+    ):
+        """
+        Encode text prompts using the CLIP text encoder for global semantic representation.
+
+        CLIP provides pooled embeddings that capture high-level semantic meaning,
+        working alongside T5's detailed sequence embeddings in Flux's dual encoder setup.
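+
+        A shape-only sketch (CLIP hidden size of 768 assumed, matching the shapes
+        noted in ``encode_prompt``):
+
+            >>> pooled, clip_time = pipeline._get_clip_prompt_embeds("a cat")
+            >>> pooled.shape
+            torch.Size([1, 768])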
+
+        Args:
+            prompt (str or List[str]): Input prompt(s) to encode
+            num_images_per_prompt (int): Number of images to generate per prompt
+            device_ids (List[int], optional): QAIC device IDs for inference
+
+        Returns:
+            tuple: (pooled_prompt_embeds, inference_time)
+                - pooled_prompt_embeds (torch.Tensor): Pooled embeddings [batch*num_images, 768]
+                - inference_time (float): CLIP encoder inference time in seconds
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        # Tokenize prompts
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer_max_length,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+
+        # Check for truncation and warn user
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            logger.warning(
+                f"The following part of your input was truncated because CLIP can only handle sequences up to "
+                f"{self.tokenizer_max_length} tokens: {removed_text}"
+            )
+
+        # Initialize QAIC inference session if not already created
+        if self.text_encoder.qpc_session is None:
+            self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
+
+        # Allocate output buffers for QAIC inference
+        text_encoder_output = {
+            "last_hidden_state": np.random.rand(
+                batch_size, self.tokenizer_max_length, self.text_encoder.model.config.hidden_size
+            ).astype(np.float32),
+            "pooler_output": np.random.rand(batch_size, self.text_encoder.model.config.hidden_size).astype(np.float32),
+        }
+        self.text_encoder.qpc_session.set_buffers(text_encoder_output)
+
+        # Prepare input for QAIC inference
+        aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+
+        # Run CLIP encoder inference and measure time
+        start_text_encoder_time = time.perf_counter()
+        aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+        end_text_encoder_time = time.perf_counter()
+        text_encoder_perf = end_text_encoder_time - start_text_encoder_time
+
+        # Extract pooled output (used for conditioning in Flux)
+        prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
+
+        # Duplicate embeddings for multiple images per prompt
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        return prompt_embeds, text_encoder_perf
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        max_sequence_length: int = 512,
+    ):
+        """
+        Encode text prompts using Flux's dual text encoder architecture.
+
+        Flux employs both CLIP and T5 encoders for comprehensive text understanding:
+        - CLIP provides pooled embeddings for global semantic conditioning
+        - T5 provides detailed sequence embeddings for fine-grained text control
+
+        Args:
+            prompt (str or List[str]): Primary prompt(s) for both encoders
+            prompt_2 (str or List[str], optional): Secondary prompt(s) for T5.
If None, uses primary prompt + num_images_per_prompt (int): Number of images to generate per prompt + prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 embeddings + pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings + max_sequence_length (int): Maximum sequence length for T5 tokenization + + Returns: + tuple: (prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf_times) + - prompt_embeds (torch.Tensor): T5 sequence embeddings [batch*num_images, seq_len, 4096] + - pooled_prompt_embeds (torch.Tensor): CLIP pooled embeddings [batch*num_images, 768] + - text_ids (torch.Tensor): Position IDs for text tokens [seq_len, 3] + - encoder_perf_times (List[float]): Performance times [CLIP_time, T5_time] + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + # Use primary prompt for both encoders if secondary not provided + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # Encode with CLIP (returns pooled embeddings) + pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds( + prompt=prompt, + device_ids=self.text_encoder.device_ids, + num_images_per_prompt=num_images_per_prompt, + ) + + # Encode with T5 (returns sequence embeddings) + prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device_ids=self.text_encoder_2.device_ids, + ) + + # Create text position IDs (required by Flux transformer) + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + + return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf] + + def __call__( + self, + height: int = 512, + width: int = 512, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + use_onnx_subfunctions: bool = False, + ): + """ + Generate images from text prompts using the QEfficient-optimized Flux pipeline on QAIC hardware. + + This is the main entry point for text-to-image generation. It orchestrates the complete Flux + diffusion pipeline optimized for Qualcomm AI Cloud devices. + + Args: + height (int, optional): Target image height in pixels. Must be divisible by 8. Default: 512. + width (int, optional): Target image width in pixels. Must be divisible by 8. Default: 512. + prompt (str or List[str]): Primary text prompt(s) describing the desired image(s). + Required unless `prompt_embeds` is provided. + prompt_2 (str or List[str], optional): Secondary prompt for T5 encoder. 
If None, uses `prompt`. + negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid. + Only used when `true_cfg_scale > 1.0`. + negative_prompt_2 (str or List[str], optional): Secondary negative prompt for T5. If None, uses `negative_prompt`. + true_cfg_scale (float, optional): True classifier-free guidance scale. Values > 1.0 enable + negative prompting. Default: 1.0 (disabled). + num_inference_steps (int, optional): Number of denoising steps. Default: 28. + timesteps (List[int], optional): Custom timestep schedule. If provided, overrides `num_inference_steps`. + guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.5. + num_images_per_prompt (int, optional): Number of images to generate per prompt. Default: 1. + generator (torch.Generator or List[torch.Generator], optional): Random generator for reproducibility. + latents (torch.FloatTensor, optional): Pre-generated latent tensors. If None, random latents are generated. + prompt_embeds (torch.FloatTensor, optional): Pre-computed T5 text embeddings. Shape: [batch, seq_len, 4096]. + pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed CLIP pooled embeddings. Shape: [batch, 768]. + negative_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative T5 embeddings. + negative_pooled_prompt_embeds (torch.FloatTensor, optional): Pre-computed negative CLIP embeddings. + output_type (str, optional): Output format. Options: "pil" (default), "np", or "latent". + callback_on_step_end (Callable, optional): Callback function executed after each denoising step. + callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. Default: ["latents"]. + max_sequence_length (int, optional): Maximum token sequence length for T5 encoder. Default: 512. + custom_config_path (str, optional): Path to custom JSON configuration file for compilation settings. + parallel_compile (bool, optional): Whether to compile modules in parallel. Default: False. + use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. Default: False. + + Returns: + QEffPipelineOutput: A dataclass containing: + - images: Generated image(s) in the format specified by `output_type` + - pipeline_module: Performance metrics for each pipeline component (text encoders, transformer, VAE) + + Raises: + ValueError: If input validation fails or parameters are incompatible. + RuntimeError: If compilation fails or QAIC devices are unavailable. + FileNotFoundError: If custom config file is specified but not found. + + Example: + >>> from QEfficient.diffusers.pipelines.flux import QEffFluxPipeline + >>> pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + >>> result = pipeline( + ... prompt="A serene mountain landscape at sunset", + ... height=1024, + ... width=1024, + ... num_inference_steps=28, + ... guidance_scale=7.5 + ... ) + >>> result.images[0].save("mountain_sunset.png") + >>> print(f"Transformer inference time: {sum(result.pipeline_module[2].perf):.2f}s") + """ + device = self.model._execution_device + + if height is None or width is None: + logger.warning("Height or width is None. 
Setting default values of 512 for both dimensions.") + + self.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + height=height, + width=width, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + # Set device IDs for all modules based on configuration + set_module_device_ids(self) + + # Validate all inputs + self.model.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # Step 2: Determine batch size from inputs + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 3: Encode prompts with both text encoders + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Encode negative prompts if using true classifier-free guidance + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # Step 4: Prepare timesteps for denoising + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Step 5: Prepare initial latents + num_channels_latents = self.transformer.model.config.in_channels // 4 + latents, latent_image_ids = self.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Step 6: Calculate compressed latent dimension for transformer buffer allocation + cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor) + + # Initialize transformer inference session + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) + + # Allocate output buffer for transformer + output_buffer = { + "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), + } + self.transformer.qpc_session.set_buffers(output_buffer) + + transformer_perf = [] + self.scheduler.set_begin_index(0) + + # Step 7: Denoising loop + with self.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Prepare timestep embedding + 
timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds) + + # Compute AdaLN (Adaptive Layer Normalization) embeddings for dual transformer blocks + adaln_emb = [] + for block_idx in range(len(self.transformer.model.transformer_blocks)): + block = self.transformer.model.transformer_blocks[block_idx] + # Process through norm1 and norm1_context + f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1) + f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + adaln_dual_emb = torch.stack(adaln_emb) + + # Compute AdaLN embeddings for single transformer blocks + adaln_emb = [] + for block_idx in range(len(self.transformer.model.single_transformer_blocks)): + block = self.transformer.model.single_transformer_blocks[block_idx] + f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + adaln_single_emb = torch.stack(adaln_emb) + + # Compute output AdaLN embedding + temp = self.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + # Normalize timestep to [0, 1] range + timestep = timestep / 1000 + + # Prepare all inputs for transformer inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": prompt_embeds.detach().numpy(), + "pooled_projections": pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy(), + } + + # Run transformer inference and measure time + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + noise_pred = torch.from_numpy(outputs["output"]) + + # Update latents using scheduler (x_t -> x_t-1) + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Handle dtype mismatch (workaround for MPS backend bug) + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + # Execute callback if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # Step 8: Decode latents to images (unless output_type is "latent") + if output_type == "latent": + image = latents + else: + # Unpack and denormalize latents + latents = self.model._unpack_latents(latents, height, width, self.model.vae_scale_factor) + latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor + + # Initialize VAE decoder inference session + if self.vae_decode.qpc_session is None: + self.vae_decode.qpc_session = QAICInferenceSession( + str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids + ) + + # Allocate output 
buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.int32)} + self.vae_decode.qpc_session.set_buffers(output_buffer) + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.perf_counter() + image = self.vae_decode.qpc_session.run(inputs) + end_decode_time = time.perf_counter() + vae_decode_perf = end_decode_time - start_decode_time + + # Post-process image + image_tensor = torch.from_numpy(image["sample"]) + image = self.model.image_processor.postprocess(image_tensor, output_type=output_type) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]), + ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]), + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decode_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py new file mode 100644 index 000000000..19e7701d4 --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -0,0 +1,632 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from diffusers.models.transformers.transformer_wan import WanTransformerBlock + +from QEfficient.base.modeling_qeff import QEFFBaseModel +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.diffusers.models.pytorch_transforms import ( + AttentionTransform, + CustomOpsTransform, + NormalizationTransform, +) +from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxSingleTransformerBlock, + QEffFluxTransformerBlock, +) +from QEfficient.transformers.models.pytorch_transforms import ( + T5ModelTransform, +) +from QEfficient.utils import constants + + +class QEffTextEncoder(QEFFBaseModel): + """ + Wrapper for text encoder models with ONNX export and QAIC compilation capabilities. + + This class handles text encoder models (CLIP, T5) with specific transformations and + optimizations for efficient inference on Qualcomm AI hardware. It applies custom + PyTorch and ONNX transformations to prepare models for deployment. + + Attributes: + model (nn.Module): The wrapped text encoder model (deep copy of original) + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform, T5ModelTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying text encoder model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the text encoder wrapper. 
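+
+        Illustrative construction (sketch; `pipe` is assumed to be a loaded HuggingFace
+        FluxPipeline, and the deep copy reflects the documented behaviour of wrapping a
+        copy of the original encoder):
+
+            import copy
+            qeff_t5 = QEffTextEncoder(copy.deepcopy(pipe.text_encoder_2))   # T5 encoder
+            qeff_clip = QEffTextEncoder(copy.deepcopy(pipe.text_encoder))   # CLIP encoder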
+ + Args: + model (nn.Module): The text encoder model to wrap (CLIP or T5) + """ + super().__init__(model) + self.model = model + + def get_onnx_params(self) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the text encoder. + + Creates example inputs, dynamic axes specifications, and output names + tailored to the specific text encoder type (CLIP vs T5). + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # Create example input with max sequence length + example_inputs = { + "input_ids": torch.zeros((bs, self.model.config.max_position_embeddings), dtype=torch.int64), + } + + # Define which dimensions can vary at runtime + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}} + + # T5 only outputs hidden states, CLIP outputs both hidden states and pooled output + if self.model.__class__.__name__ == "T5EncoderModel": + output_names = ["last_hidden_state"] + else: + output_names = ["last_hidden_state", "pooler_output"] + example_inputs["output_hidden_states"] = False + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + ) -> str: + """ + Export the text encoder model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffUNet(QEFFBaseModel): + """ + Wrapper for UNet models with ONNX export and QAIC compilation capabilities. + + This class handles UNet models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. UNet is commonly used in + diffusion models for image generation tasks. + + Attributes: + model (nn.Module): The wrapped UNet model + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying UNet model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the UNet wrapper. 
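+
+        Illustrative construction (sketch; `sd_pipe` is assumed to be a diffusers pipeline
+        object exposing a `.unet` attribute, which is what this wrapper extracts):
+
+            qeff_unet = QEffUNet(sd_pipe)   # wraps sd_pipe.unet internally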
+ + Args: + model (nn.Module): The pipeline model containing the UNet + """ + super().__init__(model.unet) + self.model = model.unet + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + ) -> str: + """ + Export the UNet model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffVAE(QEFFBaseModel): + """ + Wrapper for Variational Autoencoder (VAE) models with ONNX export and QAIC compilation. + + This class handles VAE models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. VAE models are used in diffusion + pipelines for encoding images to latent space and decoding latents back to images. + + Attributes: + model (nn.Module): The wrapped VAE model (deep copy of original) + type (str): VAE operation type ("encoder" or "decoder") + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying VAE model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module, type: str) -> None: + """ + Initialize the VAE wrapper. + + Args: + model (nn.Module): The pipeline model containing the VAE + type (str): VAE operation type ("encoder" or "decoder") + """ + super().__init__(model) + self.model = model + + # To have different hashing for encoder/decoder + self.model.config["type"] = type + + def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the VAE decoder. 
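+
+        Illustrative usage (sketch; `qeff_vae` is an assumed QEffVAE decoder instance and
+        the batch dimension comes from ONNX_EXPORT_EXAMPLE_BATCH_SIZE):
+
+            example_inputs, dynamic_axes, output_names = qeff_vae.get_onnx_params(
+                latent_height=64, latent_width=64
+            )
+            # example_inputs["latent_sample"] has shape [batch, 16, 64, 64]
+            # output_names == ["sample"]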
+ + Args: + latent_height (int): Height of latent representation (default: 32) + latent_width (int): Width of latent representation (default: 32) + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + + # VAE decoder takes latent representation as input + example_inputs = { + "latent_sample": torch.randn(bs, 16, latent_height, latent_width), + "return_dict": False, + } + + output_names = ["sample"] + + # All dimensions except channels can be dynamic + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "latent_height", 3: "latent_width"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + ) -> str: + """ + Export the VAE model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments + + Returns: + str: Path to the exported ONNX model + """ + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffFluxTransformerModel(QEFFBaseModel): + """ + Wrapper for Flux Transformer2D models with ONNX export and QAIC compilation capabilities. + + This class handles Flux Transformer2D models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. Flux uses a transformer-based diffusion + architecture instead of traditional UNet, with dual transformer blocks and adaptive layer + normalization (AdaLN) for conditioning. + + Attributes: + model (nn.Module): The wrapped Flux transformer model + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying Flux transformer model + """ + return self.model.config.__dict__ + + def __init__(self, model: nn.Module) -> None: + """ + Initialize the Flux transformer wrapper. + + Args: + model (nn.Module): The Flux transformer model to wrap + """ + super().__init__(model) + + def get_onnx_params( + self, + batch_size: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_length: int = constants.FLUX_ONNX_EXPORT_SEQ_LENGTH, + cl: int = constants.FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM, + ) -> Tuple[Dict, Dict, List[str]]: + """ + Generate ONNX export configuration for the Flux transformer. 
+ + Creates example inputs for all Flux-specific inputs including hidden states, + text embeddings, timestep conditioning, and AdaLN embeddings. + + Args: + batch_size (int): Batch size for example inputs (default: FLUX_ONNX_EXPORT_BATCH_SIZE) + seq_length (int): Text sequence length (default: FLUX_ONNX_EXPORT_SEQ_LENGTH) + cl (int): Compressed latent dimension (default: FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM) + + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + example_inputs = { + # Latent representation of the image + "hidden_states": torch.randn(batch_size, cl, self.model.config.in_channels, dtype=torch.float32), + "encoder_hidden_states": torch.randn( + batch_size, seq_length, self.model.config.joint_attention_dim, dtype=torch.float32 + ), + "pooled_projections": torch.randn(batch_size, self.model.config.pooled_projection_dim, dtype=torch.float32), + "timestep": torch.tensor([1.0], dtype=torch.float32), + "img_ids": torch.randn(cl, 3, dtype=torch.float32), + "txt_ids": torch.randn(seq_length, 3, dtype=torch.float32), + # AdaLN embeddings for dual transformer blocks + # Shape: [num_layers, FLUX_ADALN_DUAL_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM] + "adaln_emb": torch.randn( + self.model.config["num_layers"], + constants.FLUX_ADALN_DUAL_BLOCK_CHUNKS, + constants.FLUX_ADALN_HIDDEN_DIM, + dtype=torch.float32, + ), + # AdaLN embeddings for single transformer blocks + # Shape: [num_single_layers, FLUX_ADALN_SINGLE_BLOCK_CHUNKS, FLUX_ADALN_HIDDEN_DIM] + "adaln_single_emb": torch.randn( + self.model.config["num_single_layers"], + constants.FLUX_ADALN_SINGLE_BLOCK_CHUNKS, + constants.FLUX_ADALN_HIDDEN_DIM, + dtype=torch.float32, + ), + # Output AdaLN embedding + # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection + "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32), + } + + output_names = ["output"] + + # Define dynamic dimensions for runtime flexibility + dynamic_axes = { + "hidden_states": {0: "batch_size", 1: "cl"}, + "encoder_hidden_states": {0: "batch_size", 1: "seq_len"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "steps"}, + "img_ids": {0: "cl"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Export the Flux transformer model to ONNX format. 
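+
+        Typical flow (a sketch under assumed names; `flux_transformer` is an already wrapped
+        QEffFluxTransformerModel and the export directory is hypothetical). The `cl` value
+        passed to get_onnx_params is the packed latent length, e.g.
+        (1024 // 8) * (1024 // 8) // 4 = 4096 for a 1024x1024 image with a VAE scale factor of 8:
+
+            inputs, dynamic_axes, output_names = flux_transformer.get_onnx_params(cl=4096)
+            onnx_path = flux_transformer.export(
+                inputs=inputs,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                export_dir="/path/to/export",
+                use_onnx_subfunctions=True,   # export transformer blocks as ONNX functions
+            )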
+ + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) + use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions + for better modularity and potential optimization + + Returns: + str: Path to the exported ONNX model + """ + + if use_onnx_subfunctions: + export_kwargs = { + "export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}, + "use_onnx_subfunctions": True, + } + + # Sort _use_default_values in config to ensure consistent hash generation during export + self.model.config["_use_default_values"].sort() + + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=False, # As weights are needed with AdaLN changes + **export_kwargs, + ) + + def compile(self, specializations: List[Dict], **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. + + Args: + specializations (List[Dict]): Model specialization configurations + **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) + """ + self._compile(specializations=specializations, **compiler_options) + + +class QEffWanUnifiedTransformer(QEFFBaseModel): + """ + Wrapper for WAN Unified Transformer with ONNX export and QAIC compilation capabilities. + + This class handles the unified WAN transformer model that combines high and low noise transformers + into a single model for efficient deployment. Based on the timestep shape, the model dynamically + selects between high and low noise transformers during inference. + + The wrapper applies specific transformations and optimizations for efficient inference on + Qualcomm AI hardware, particularly for video diffusion models. + + Attributes: + model (nn.Module): The QEffWanUnifiedWrapper model that combines high/low noise transformers + _pytorch_transforms (List): PyTorch transformations applied before ONNX export + _onnx_transforms (List): ONNX transformations applied after export + """ + + _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + def __init__(self, unified_transformer): + """ + Initialize the Wan unified transformer. + + Args: + model (nn.Module): Wan unified transformer model + """ + super().__init__(unified_transformer) + self.model = unified_transformer + + @property + def get_model_config(self) -> Dict: + """ + Get the model configuration as a dictionary. + + Returns: + Dict: The configuration dictionary of the underlying Wan transformer model + """ + return self.model.config.__dict__ + + def get_onnx_params(self): + """ + Generate ONNX export configuration for the Wan transformer. 
+ + Creates example inputs for all Wan-specific inputs including hidden states, + text embeddings, timestep conditioning, + Returns: + Tuple containing: + - example_inputs (Dict): Sample inputs for ONNX export + - dynamic_axes (Dict): Specification of dynamic dimensions + - output_names (List[str]): Names of model outputs + """ + batch_size = constants.WAN_ONNX_EXPORT_BATCH_SIZE + example_inputs = { + # hidden_states = [ bs, in_channels, frames, latent_height, latent_width] + "hidden_states": torch.randn( + batch_size, + self.model.config.in_channels, + constants.WAN_ONNX_EXPORT_LATENT_FRAMES, + constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P, + constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P, + dtype=torch.float32, + ), + # encoder_hidden_states = [BS, seq len , text dim] + "encoder_hidden_states": torch.randn( + batch_size, constants.WAN_ONNX_EXPORT_SEQ_LEN, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32 + ), + # Rotary position embeddings: [2, context_length, 1, rotary_dim]; 2 is from tuple of cos, sin freqs + "rotary_emb": torch.randn( + 2, constants.WAN_ONNX_EXPORT_CL_180P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32 + ), + # Timestep embeddings: [batch_size=1, embedding_dim] + "temb": torch.randn(batch_size, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32), + # Projected timestep embeddings: [batch_size=1, projection_dim, embedding_dim] + "timestep_proj": torch.randn( + batch_size, + constants.WAN_PROJECTION_DIM, + constants.WAN_TEXT_EMBED_DIM, + dtype=torch.float32, + ), + # Timestep parameter: Controls high/low noise transformer selection based on shape + "tsp": torch.ones(1, dtype=torch.int64), + } + + output_names = ["output"] + + dynamic_axes = { + "hidden_states": { + 0: "batch_size", + 1: "num_channels", + 2: "num_frames", + 3: "latent_height", + 4: "latent_width", + }, + "timestep": {0: "steps"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "rotary_emb": {1: "cl"}, + "tsp": {0: "model_type"}, + } + + return example_inputs, dynamic_axes, output_names + + def export( + self, + inputs: Dict, + output_names: List[str], + dynamic_axes: Dict, + export_dir: str = None, + export_kwargs: Dict = {}, + use_onnx_subfunctions: bool = False, + ) -> str: + """Export the Wan transformer model to ONNX format. + + Args: + inputs (Dict): Example inputs for ONNX export + output_names (List[str]): Names of model outputs + dynamic_axes (Dict): Specification of dynamic dimensions + export_dir (str, optional): Directory to save ONNX model + export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions) + use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions + for better modularity and potential optimization + Returns: + str: Path to the exported ONNX model + """ + if use_onnx_subfunctions: + export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True} + + return self._export( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=True, + **export_kwargs, + ) + + def compile(self, specializations, **compiler_options) -> None: + """ + Compile the ONNX model for Qualcomm AI hardware. 
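+
+        Illustrative call (sketch; the keys and values shown are placeholders, since in the
+        pipeline the specializations come from the JSON compile config merged with the
+        computed latent dimensions, one entry per noise branch):
+
+            wan_transformer.compile(
+                specializations=[
+                    {"cl": 32760, "latent_height": 60, "latent_width": 104, "num_frames": 21},  # high noise
+                    {"cl": 32760, "latent_height": 60, "latent_width": 104, "num_frames": 21},  # low noise
+                ],
+                num_cores=16,
+            )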
+
+        Args:
+            specializations (List[Dict]): Model specialization configurations
+            **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
+        """
+        self._compile(specializations=specializations, **compiler_options)
diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py
new file mode 100644
index 000000000..135a6bd07
--- /dev/null
+++ b/QEfficient/diffusers/pipelines/pipeline_utils.py
@@ -0,0 +1,355 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import math
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+from QEfficient.utils._utils import load_json
+from QEfficient.utils.logging_utils import logger
+
+
+def calculate_compressed_latent_dimension(height: int, width: int, vae_scale_factor: int) -> Tuple[int, int, int]:
+    """
+    Calculate the compressed latent dimension and the latent spatial dimensions.
+
+    Args:
+        height (int): Target image height in pixels
+        width (int): Target image width in pixels
+        vae_scale_factor (int): VAE downsampling factor (typically 8 for Flux)
+
+    Returns:
+        tuple: (cl, latent_height, latent_width)
+            - cl (int): Compressed latent dimension for transformer input buffer allocation
+            - latent_height (int): Height in latent space
+            - latent_width (int): Width in latent space
+    """
+    latent_height = height // vae_scale_factor
+    latent_width = width // vae_scale_factor
+    # cl = compressed latent dimension (divided by 4 for Flux's 2x2 packing)
+    cl = (latent_height * latent_width) // 4
+    return cl, latent_height, latent_width
+
+
+def calculate_latent_dimensions_with_frames(
+    height: int,
+    width: int,
+    num_frames: int,
+    vae_scale_factor_spatial: int,
+    vae_scale_factor_temporal: int,
+    patch_height: int,
+    patch_width: int,
+) -> Tuple[int, int, int, int]:
+    """
+    Calculate the latent dimensions for video generation models.
+
+    This function computes the compressed sequence length (cl), latent height,
+    latent width, and latent frames based on the target video dimensions,
+    VAE scale factors, and patch sizes.
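+
+    Worked example (illustrative numbers only; assumes a 480x832, 81-frame request with a
+    spatial scale factor of 8, a temporal scale factor of 4, and a 2x2 spatial patch):
+
+        latent_height = 480 // 8 = 60
+        latent_width  = 832 // 8 = 104
+        latent_frames = ceil(81 / 4) = 21
+        cl = (60 // 2) * (104 // 2) * 21 = 30 * 52 * 21 = 32760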
+
+    Args:
+        height (int): Target video height in pixels
+        width (int): Target video width in pixels
+        num_frames (int): Target number of video frames
+        vae_scale_factor_spatial (int): Spatial VAE scale factor from the model config
+        vae_scale_factor_temporal (int): Temporal VAE scale factor from the model config
+        patch_height (int): Patch height from the model config
+        patch_width (int): Patch width from the model config
+
+    Returns:
+        tuple: (cl, latent_height, latent_width, latent_frames)
+            - cl (int): Compressed latent dimension for transformer input
+            - latent_height (int): Height in latent space
+            - latent_width (int): Width in latent space
+            - latent_frames (int): Frames in latent space
+
+    Mathematical Formula:
+        latent_height = height // vae_scale_factor_spatial
+        latent_width = width // vae_scale_factor_spatial
+        latent_frames = math.ceil(num_frames / vae_scale_factor_temporal)
+        cl = (latent_height // patch_height) * (latent_width // patch_width) * latent_frames
+
+    """
+    # Calculate latent space dimensions after VAE encoding
+    latent_height = height // vae_scale_factor_spatial
+    latent_width = width // vae_scale_factor_spatial
+    latent_frames = math.ceil(num_frames / vae_scale_factor_temporal)
+    cl = (latent_height // patch_height * latent_width // patch_width) * latent_frames
+    return cl, latent_height, latent_width, latent_frames
+
+
+def config_manager(cls, config_source: Optional[str] = None, use_onnx_subfunctions: bool = False):
+    """
+    JSON-based compilation configuration manager for diffusion pipelines.
+
+    Supports loading configuration from JSON files only. If no path is given, the
+    pipeline-specific default configuration from `get_default_config_path()` is used.
+    The loaded settings are stored on the pipeline as `custom_config`.
+
+    Args:
+        cls: Pipeline instance whose `custom_config` is populated
+        config_source: Path to a JSON configuration file. If None, uses the default config.
+        use_onnx_subfunctions: Whether to enable ONNX subfunction export for supported modules
+    """
+    if config_source is None:
+        config_source = cls.get_default_config_path()
+
+    if not isinstance(config_source, str):
+        raise ValueError("config_source must be a path to a JSON configuration file")
+
+    # Validate the path, then load the configuration with the load_json utility
+    if not os.path.exists(config_source):
+        raise FileNotFoundError(f"Configuration file not found: {config_source}")
+
+    cls.custom_config = load_json(config_source)
+
+    # Enable ONNX subfunctions for specific modules if requested
+    for module_name, _ in cls.modules.items():
+        if module_name in ONNX_SUBFUNCTION_MODULE:
+            cls.custom_config["modules"][module_name]["compilation"]["use_onnx_subfunctions"] = use_onnx_subfunctions
+
+
+def set_module_device_ids(cls):
+    """
+    Set device IDs for each module based on the custom configuration.
+
+    Iterates through all modules in the pipeline and assigns device IDs
+    from the configuration file to each module's device_ids attribute.
+    """
+    config_modules = cls.custom_config["modules"]
+    for module_name, module_obj in cls.modules.items():
+        module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
+
+
+def compile_modules_parallel(
+    modules: Dict[str, Any],
+    config: Dict[str, Any],
+    specialization_updates: Dict[str, Dict[str, Any]] = None,
+) -> None:
+    """
+    Compile multiple pipeline modules in parallel using ThreadPoolExecutor.
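+
+    Illustrative call (sketch; the module names and config keys mirror the structure loaded
+    by config_manager, and all values are placeholders only):
+
+        modules = {"transformer": qeff_transformer}
+        config = {
+            "modules": {
+                "transformer": {
+                    "specializations": {"batch_size": 1},
+                    "compilation": {"num_cores": 16},
+                    "execute": {"device_ids": [0]},
+                }
+            }
+        }
+        compile_modules_parallel(modules, config, specialization_updates={"transformer": {"cl": 4096}})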
+ + Args: + modules: Dictionary of module_name -> module_object pairs to compile + config: Configuration dictionary containing module-specific compilation settings + specialization_updates: Optional dictionary of module_name -> specialization_updates + to apply dynamic values (e.g., image dimensions) + """ + + def _prepare_and_compile(module_name: str, module_obj: Any) -> None: + """Prepare specializations and compile a single module.""" + specializations = config["modules"][module_name]["specializations"].copy() + compile_kwargs = config["modules"][module_name]["compilation"] + + if ( + specialization_updates and module_name in specialization_updates + ): # Apply specialization updates if available + if isinstance(specializations, list): # for unified models spec will be [{high_noise}, {low_noise}] + for i, spec in enumerate(specializations): + spec.update(specialization_updates[module_name][i]) + else: + specializations.update(specialization_updates[module_name]) + specializations = [specializations] + else: + specializations = [specializations] + # Compile with prepared specializations + module_obj.compile(specializations=specializations, **compile_kwargs) + + # Execute compilations in parallel + with ThreadPoolExecutor(max_workers=len(modules)) as executor: + futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()} + + with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar: + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Compilation failed for {futures[future]}: {e}") + raise + pbar.update(1) + + +def compile_modules_sequential( + modules: Dict[str, Any], + config: Dict[str, Any], + specialization_updates: Dict[str, Dict[str, Any]] = None, +) -> None: + """ + Compile multiple pipeline modules sequentially. + + This function provides a generic way to compile diffusion pipeline modules + sequentially, which is the default behavior for backward compatibility. + + Args: + modules: Dictionary of module_name -> module_object pairs to compile + config: Configuration dictionary containing module-specific compilation settings + specialization_updates: Optional dictionary of module_name -> specialization_updates + to apply dynamic values (e.g., image dimensions) + + """ + for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"): + module_config = config["modules"] + specializations = module_config[module_name]["specializations"].copy() + compile_kwargs = module_config[module_name]["compilation"] + + if ( + specialization_updates and module_name in specialization_updates + ): # Apply specialization updates if available + if isinstance(specializations, list): # for unified models spec will be [{high_noise}, {low_noise}] + for i, spec in enumerate(specializations): + spec.update(specialization_updates[module_name][i]) + else: + specializations.update(specialization_updates[module_name]) + specializations = [specializations] + else: + specializations = [specializations] + # Compile with prepared specializations + module_obj.compile(specializations=specializations, **compile_kwargs) + + +@dataclass(frozen=True) +class ModulePerf: + """ + Data class to store performance metrics for a pipeline module. + + Attributes: + module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder') + perf: Performance metric in seconds. 
Can be a single float for modules that run once, + or a list of floats for modules that run multiple times (e.g., transformer steps) + """ + + module_name: str + perf: int + + +@dataclass(frozen=True) +class QEffPipelineOutput: + """ + Data class to store the output of a QEfficient diffusion pipeline. + + Attributes: + pipeline_module: List of ModulePerf objects containing performance metrics for each module + images: Generated images as either a list of PIL Images or numpy array + """ + + pipeline_module: list[ModulePerf] + images: Union[List[PIL.Image.Image], np.ndarray] + + def __repr__(self): + output_str = "=" * 60 + "\n" + output_str += "QEfficient Diffusers Pipeline Inference Report\n" + output_str += "=" * 60 + "\n\n" + + # Module-wise inference times + output_str += "Module-wise Inference Times:\n" + output_str += "-" * 60 + "\n" + + # Calculate E2E time while iterating + e2e_time = 0 + for module_perf in self.pipeline_module: + module_name = module_perf.module_name + inference_time = module_perf.perf + + # Add to E2E time + e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time + + # Format module name for display + display_name = module_name.replace("_", " ").title() + + # Handle transformer specially as it has a list of times + if isinstance(inference_time, list) and len(inference_time) > 0: + total_time = sum(inference_time) + avg_time = total_time / len(inference_time) + output_str += f" {display_name:25s} {total_time:.4f} s\n" + output_str += f" - Total steps: {len(inference_time)}\n" + output_str += f" - Average per step: {avg_time:.4f} s\n" + output_str += f" - Min step time: {min(inference_time):.4f} s\n" + output_str += f" - Max step time: {max(inference_time):.4f} s\n" + else: + # Single inference time value + output_str += f" {display_name:25s} {inference_time:.4f} s\n" + + output_str += "-" * 60 + "\n\n" + + # Print E2E time after all modules + output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n" + output_str += "=" * 60 + "\n" + + return output_str + + +# List of module name that require special handling during export +# when use_onnx_subfunctions is enabled +ONNX_SUBFUNCTION_MODULE = ["transformer"] + + +class QEffWanUnifiedWrapper(nn.Module): + """ + A wrapper class that combines WAN high and low noise transformers into a single unified transformer. + + This wrapper dynamically selects between high and low noise transformers based on the timestep shape + in the ONNX graph during inference. This approach enables efficient deployment of both transformer + variants in a single model. 
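+
+    Illustrative selection behaviour (sketch; the tensor arguments are placeholders with the
+    shapes expected by the underlying transformers). Both branches are evaluated and
+    torch.where picks the result, which keeps the exported ONNX graph static:
+
+        unified = QEffWanUnifiedWrapper(transformer_high, transformer_low)
+        tsp_high = torch.ones(1, dtype=torch.int64)   # shape[0] == 1 -> high-noise output
+        tsp_low = torch.ones(2, dtype=torch.int64)    # shape[0] != 1 -> low-noise output
+        noise_pred = unified(hidden_states, encoder_hidden_states, rotary_emb,
+                             temb, timestep_proj, tsp_high)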
+ + Attributes: + transformer_high(nn.Module): The high noise transformer component + transformer_low(nn.Module): The low noise transformer component + config: Configuration shared between both transformers (from high noise transformer) + """ + + def __init__(self, transformer_high, transformer_low): + super().__init__() + self.transformer_high = transformer_high + self.transformer_low = transformer_low + # Both high and low noise transformers share the same configuration + self.config = transformer_high.config + + def forward( + self, + hidden_states, + encoder_hidden_states, + rotary_emb, + temb, + timestep_proj, + tsp, + attention_kwargs=None, + return_dict=False, + ): + # Condition based on timestep shape + is_high_noise = tsp.shape[0] == torch.tensor(1) + + high_hs = hidden_states.detach() + ehs = encoder_hidden_states.detach() + rhs = rotary_emb.detach() + ths = temb.detach() + projhs = timestep_proj.detach() + + noise_pred_high = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + + noise_pred_low = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + + # Select based on timestep condition + noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) + return noise_pred diff --git a/QEfficient/diffusers/pipelines/wan/__init__.py b/QEfficient/diffusers/pipelines/wan/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/wan/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py new file mode 100644 index 000000000..888763af0 --- /dev/null +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -0,0 +1,758 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +""" +QEfficient WAN Pipeline Implementation + +This module provides an optimized implementation of the WAN pipeline +for high-performance text-to-video generation on Qualcomm AI hardware. +The pipeline supports WAN 2.2 architectures with unified transformer. + +TODO: 1. 
Update Vae, umt5 to Qaic; present running on cpu +""" + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import WanPipeline + +from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ONNX_SUBFUNCTION_MODULE, + ModulePerf, + QEffPipelineOutput, + QEffWanUnifiedWrapper, + calculate_latent_dimensions_with_frames, + compile_modules_parallel, + compile_modules_sequential, + config_manager, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants +from QEfficient.utils.logging_utils import logger + + +class QEffWanPipeline: + """ + QEfficient-optimized WAN pipeline for high-performance text-to-video generation on Qualcomm AI hardware. + + This pipeline provides an optimized implementation of the WAN diffusion model + specifically designed for deployment on Qualcomm AI Cloud (QAIC) devices. It extends the original + HuggingFace WAN model with QEfficient-optimized components that can be exported to ONNX format + and compiled into Qualcomm Program Container (QPC) files for efficient video generation. + + The pipeline supports the complete WAN workflow including: + - UMT5 text encoding for rich semantic understanding + - Unified transformer architecture: Combines multiple transformer stages into a single optimized model + - VAE decoding for final video output + - Performance monitoring and hardware optimization + + Attributes: + text_encoder: UMT5 text encoder for semantic text understanding (TODO: QEfficient optimization) + unified_wrapper (QEffWanUnifiedWrapper): Wrapper combining transformer stages + transformer (QEffWanUnifiedTransformer): Optimized unified transformer for denoising + vae_decode: VAE decoder for latent-to-video conversion + modules (Dict[str, Any]): Dictionary of pipeline modules for batch operations + model (WanPipeline): Original HuggingFace WAN model reference + tokenizer: Text tokenizer for preprocessing + scheduler: Diffusion scheduler for timestep management + + Example: + >>> from QEfficient.diffusers.pipelines.wan import QEffWanPipeline + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> videos = pipeline( + ... prompt="A cat playing in a garden", + ... height=480, + ... width=832, + ... num_frames=81, + ... num_inference_steps=4 + ... ) + >>> # Save generated video + >>> videos.images[0].save("generated_video.mp4") + """ + + _hf_auto_class = WanPipeline + + def __init__(self, model, **kwargs): + """ + Initialize the QEfficient WAN pipeline. + + This pipeline provides an optimized implementation of the WAN text-to-video model + for deployment on Qualcomm AI hardware. It wraps the original HuggingFace WAN model + components with QEfficient-optimized versions that can be exported to ONNX and compiled + for QAIC devices. 
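+
+        Illustrative construction (sketch; normally QEffWanPipeline.from_pretrained builds
+        this for you, and "path/to/wan/model" is a placeholder):
+
+            from diffusers import WanPipeline
+            base = WanPipeline.from_pretrained("path/to/wan/model", torch_dtype=torch.float32)
+            pipeline = QEffWanPipeline(base)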
+ + Args: + model: Pre-loaded WanPipeline model with transformer and transformer_2 components + **kwargs: Additional keyword arguments including configuration parameters + """ + # Store original model and configuration + self.model = model + self.kwargs = kwargs + self.custom_config = None + + # Text encoder (TODO: Replace with QEfficient UMT5 optimization) + self.text_encoder = model.text_encoder + + # Create unified transformer wrapper combining dual-stage models(high, low noise DiTs) + self.unified_wrapper = QEffWanUnifiedWrapper(model.transformer, model.transformer_2) + self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper) + + # VAE decoder for latent-to-video conversion + self.vae_decode = model.vae + + # Store all modules in a dictionary for easy iteration during export/compile + # TODO: add text encoder, vae decoder on QAIC + self.modules = {"transformer": self.transformer} + + # Copy tokenizers and scheduler from the original model + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.scheduler = model.scheduler + # Extract patch dimensions from transformer configuration + _, self.patch_height, self.patch_width = self.transformer.model.config.patch_size + + @property + def do_classifier_free_guidance(self): + """ + Determine if classifier-free guidance should be used. + + Returns: + bool: True if CFG should be applied based on current guidance scales + """ + return self._guidance_scale > 1.0 and (self._guidance_scale_2 is None or self._guidance_scale_2 > 1.0) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **kwargs, + ): + """ + Load a pretrained WAN model from HuggingFace Hub or local path and wrap it with QEfficient optimizations. + + This class method provides a convenient way to instantiate a QEffWanPipeline from a pretrained + WAN model. It automatically loads the base WanPipeline model in float32 precision on CPU + and wraps all components with QEfficient-optimized versions for QAIC deployment. + + Args: + pretrained_model_name_or_path (str or os.PathLike): Either a HuggingFace model identifier + or a local path to a saved WAN model directory. Should contain transformer, transformer_2, + text_encoder, and VAE components. + **kwargs: Additional keyword arguments passed to WanPipeline.from_pretrained(). + + Returns: + QEffWanPipeline: A fully initialized pipeline instance with QEfficient-optimized components + ready for export, compilation, and inference on QAIC devices. + + Raises: + ValueError: If the model path is invalid or model cannot be loaded + OSError: If there are issues accessing the model files + RuntimeError: If model initialization fails + + Example: + >>> # Load from HuggingFace Hub + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> + >>> # Load from local path + >>> pipeline = QEffWanPipeline.from_pretrained("/local/path/to/wan") + >>> + >>> # Load with custom cache directory + >>> pipeline = QEffWanPipeline.from_pretrained( + ... "wan-model-id", + ... cache_dir="/custom/cache/dir" + ... 
) + """ + # Load the base WAN model in float32 on CPU for optimization + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + device_map="cpu", + **kwargs, + ) + return cls( + model=model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, + ) + + def export( + self, + export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Export all pipeline modules to ONNX format for deployment preparation. + + This method systematically exports the unified transformer to ONNX format with + video-specific configurations including temporal dimensions, dynamic axes, and + optimization settings. The export process prepares the model for subsequent + compilation to QPC format for efficient inference on QAIC hardware. + + Args: + export_dir (str, optional): Target directory for saving ONNX model files. If None, + uses the default export directory structure. The directory will be created + if it doesn't exist. + use_onnx_subfunctions (bool, default=False): Whether to enable ONNX subfunction + optimization for supported modules. This can optimize the graph structure + and improve compilation efficiency for complex models like the transformer. + + Returns: + str: Absolute path to the export directory containing all ONNX model files. + + Raises: + RuntimeError: If ONNX export fails for any module + OSError: If there are issues creating the export directory or writing files + ValueError: If module configurations are invalid + + Example: + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> export_path = pipeline.export( + ... export_dir="/path/to/export", + ... use_onnx_subfunctions=True + ... ) + """ + + # Export each module with video-specific parameters + for module_name, module_obj in self.modules.items(): + # Get ONNX export configuration with video dimensions + example_inputs, dynamic_axes, output_names = module_obj.get_onnx_params() + + # Prepare export parameters + export_params = { + "inputs": example_inputs, + "output_names": output_names, + "dynamic_axes": dynamic_axes, + "export_dir": export_dir, + } + + # Enable ONNX subfunctions for supported modules if requested + if use_onnx_subfunctions and module_name in ONNX_SUBFUNCTION_MODULE: + export_params["use_onnx_subfunctions"] = True + + module_obj.export(**export_params) + + @staticmethod + def get_default_config_path(): + """ + Get the default configuration file path for WAN pipeline. + + Returns: + str: Path to the default WAN configuration JSON file. + """ + return os.path.join(os.path.dirname(__file__), "wan_config.json") + + def compile( + self, + compile_config: Optional[str] = None, + parallel: bool = False, + height: int = constants.WAN_ONNX_EXPORT_HEIGHT_180P, + width: int = constants.WAN_ONNX_EXPORT_WIDTH_180P, + num_frames: int = constants.WAN_ONNX_EXPORT_FRAMES, + use_onnx_subfunctions: bool = False, + ) -> str: + """ + Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware. + + This method takes the ONNX paths of the transformer and compiles them into an optimized format + for inference using JSON-based configuration. + + Args: + compile_config (str, optional): Path to a JSON configuration file containing + compilation settings, device mappings, and optimization parameters. If None, + uses the default configuration. 
+ parallel (bool, default=False): Compilation mode selection: + - True: Compile modules in parallel using ThreadPoolExecutor for faster processing + - False: Compile modules sequentially for lower resource usage + height (int, default=192): Target image height in pixels. + width (int, default=320): Target image width in pixels. + num_frames (int, deafult=81) : Target num of frames in pixel space + use_onnx_subfunctions (bool, default=False): Whether to export models with ONNX + subfunctions before compilation if not already exported. + + Raises: + RuntimeError: If compilation fails for any module or if QAIC compiler is not available + FileNotFoundError: If ONNX models haven't been exported or config file is missing + ValueError: If configuration parameters are invalid + OSError: If there are issues with file I/O during compilation + + Example: + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> # Sequential compilation with default config + >>> pipeline.compile(height=480, width=832, num_frames=81) + >>> + >>> # Parallel compilation with custom config + >>> pipeline.compile( + ... compile_config="/path/to/custom_config.json", + ... parallel=True, + ... height=480, + ... width=832, + ... num_frames=81 + ... ) + """ + # Ensure all modules are exported to ONNX before compilation + if any( + path is None + for path in [ + self.transformer.onnx_path, + ] + ): + self.export(use_onnx_subfunctions=use_onnx_subfunctions) + + # Load compilation configuration + config_manager(self, config_source=compile_config, use_onnx_subfunctions=use_onnx_subfunctions) + + # Configure pipeline dimensions and calculate compressed latent parameters + cl, latent_height, latent_width, latent_frames = calculate_latent_dimensions_with_frames( + height, + width, + num_frames, + self.model.vae.config.scale_factor_spatial, + self.model.vae.config.scale_factor_temporal, + self.patch_height, + self.patch_width, + ) + # Prepare dynamic specialization updates based on video dimensions + specialization_updates = { + "transformer": [ + # high noise + { + "cl": cl, # Compressed latent dimension + "latent_height": latent_height, # Latent space height + "latent_width": latent_width, # Latent space width + "num_frames": latent_frames, # Latent frames + }, + # low noise + { + "cl": cl, # Compressed latent dimension + "latent_height": latent_height, # Latent space height + "latent_width": latent_width, # Latent space width + "num_frames": latent_frames, # Latent frames + }, + ] + } + + # Use generic utility functions for compilation + if parallel: + compile_modules_parallel(self.modules, self.custom_config, specialization_updates) + else: + compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None]]] = None, + callback_on_step_end_tensor_inputs: 
List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + use_onnx_subfunctions: bool = False, + parallel_compile: bool = True, + ): + """ + Generate videos from text prompts using the QEfficient-optimized WAN pipeline on QAIC hardware. + + This is the main entry point for text-to-video generation. It orchestrates the complete WAN + diffusion pipeline optimized for Qualcomm AI Cloud devices. + + Args: + prompt (str or List[str]): Primary text prompt(s) describing the desired video content. + Required unless `prompt_embeds` is provided. + negative_prompt (str or List[str], optional): Negative prompt(s) describing what to avoid + in the generated video. Used with classifier-free guidance. + height (int, optional): Target video height in pixels. Must be divisible by VAE scale factor. + Default: 480. + width (int, optional): Target video width in pixels. Must be divisible by VAE scale factor. + Default: 832. + num_frames (int, optional): Number of video frames to generate. Must satisfy temporal + divisibility requirements. Default: 81. + num_inference_steps (int, optional): Number of denoising steps. More steps generally + improve quality but increase generation time. Default: 50. + guidance_scale (float, optional): Guidance scale for classifier-free guidance. Default: 3.0. + guidance_scale_2 (float, optional): Guidance scale for low-noise stage in WAN 2.2. + If None, uses guidance_scale value. + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Default: 1. + generator (torch.Generator or List[torch.Generator], optional): Random generator for + reproducible generation. + latents (torch.Tensor, optional): Pre-generated latent tensors. If None, random latents + are generated based on video dimensions. + prompt_embeds (torch.Tensor, optional): Pre-computed text embeddings from UMT5 encoder. + Shape: [batch, seq_len, hidden_dim]. + negative_prompt_embeds (torch.Tensor, optional): Pre-computed negative text embeddings. + output_type (str, optional): Output format. Options: "np" (default), "pil", or "latent". + return_dict (bool, optional): Whether to return a dictionary or tuple. Default: True. + attention_kwargs (Dict[str, Any], optional): Additional attention arguments for transformer. + callback_on_step_end (Callable, optional): Callback function executed after each denoising step. + callback_on_step_end_tensor_inputs (List[str], optional): Tensor names to pass to callback. + Default: ["latents"]. + max_sequence_length (int, optional): Maximum token sequence length for text encoder. Default: 512. + custom_config_path (str, optional): Path to custom JSON configuration file for compilation. + use_onnx_subfunctions (bool, optional): Whether to export transformer blocks as ONNX subfunctions. + Default: False. + parallel_compile (bool, optional): Whether to compile modules in parallel. Default: True. + + Returns: + QEffPipelineOutput: A dataclass containing: + - images: Generated video(s) in the format specified by `output_type` + - pipeline_module: Performance metrics for each pipeline component + + Raises: + ValueError: If input validation fails or parameters are incompatible + RuntimeError: If compilation fails or QAIC devices are unavailable + FileNotFoundError: If custom config file is specified but not found + + Example: + >>> from QEfficient.diffusers.pipelines.wan import QEffWanPipeline + >>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model") + >>> result = pipeline( + ... 
prompt="A cat playing in a sunny garden", + ... height=480, + ... width=832, + ... num_frames=81, + ... num_inference_steps=4, + ... guidance_scale=3.0 + ... ) + >>> # Save generated video + >>> result.images[0].save("cat_garden.mp4") + """ + device = "cpu" + + # Compile models with custom configuration if needed + self.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + use_onnx_subfunctions=use_onnx_subfunctions, + height=height, + width=width, + num_frames=num_frames, + ) + + # Set device IDs for all modules based on configuration + set_module_device_ids(self) + + # Step 1: Validate all inputs + self.model.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + # Ensure num_frames satisfies temporal divisibility requirements + if num_frames % self.model.vae.config.scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.model.vae.config.scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = ( + num_frames // self.model.vae.config.scale_factor_temporal * self.model.vae.config.scale_factor_temporal + + 1 + ) + num_frames = max(num_frames, 1) + + if self.model.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + # Initialize pipeline state + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 if guidance_scale_2 is not None else guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Step 2: Determine batch size from inputs + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 3: Encode input prompts using UMT5 text encoder + # TODO: Update UMT5 on QAIC + prompt_embeds, negative_prompt_embeds = self.model.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Convert embeddings to transformer dtype for compatibility + transformer_dtype = self.transformer.model.transformer_high.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # Step 4: Prepare timesteps for denoising process + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Step 5: Prepare initial latent variables for video generation + num_channels_latents = self.transformer.model.config.in_channels + + latents = self.model.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # Create mask for temporal processing (used in expand_timesteps mode) + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # Step 6: Configure dual-stage processing for WAN 2.2 + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Calculate boundary timestep for stage switching in WAN 2.2 + if self.model.config.boundary_ratio is not None: + 
boundary_timestep = self.model.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + # Step 7: Initialize QAIC inference session for transformer + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) + + # Calculate compressed latent dimension for transformer buffer allocation + cl, _, _, _ = calculate_latent_dimensions_with_frames( + height, + width, + num_frames, + self.model.vae.config.scale_factor_spatial, + self.model.vae.config.scale_factor_temporal, + self.patch_height, + self.patch_width, + ) + # Allocate output buffer for QAIC inference + output_buffer = { + "output": np.random.rand( + batch_size, + cl, # Compressed latent dimension + constants.WAN_DIT_OUT_CHANNELS, + ).astype(np.int32), + } + self.transformer.qpc_session.set_buffers(output_buffer) + transformer_perf = [] + + # Step 8: Denoising loop with dual-stage processing + with self.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self._interrupt: + continue + + self._current_timestep = t + + # Determine which model to use based on boundary timestep + if boundary_timestep is None or t >= boundary_timestep: + # High-noise stage + current_model = self.transformer.model.transformer_high + current_guidance_scale = guidance_scale + model_type = torch.ones(1, dtype=torch.int64) # High-noise model indicator + else: + # Low-noise stage + current_model = self.transformer.model.transformer_low + current_guidance_scale = guidance_scale_2 + model_type = torch.ones(2, dtype=torch.int64) # Low-noise model indicator + + # Prepare latent input with proper dtype + latent_model_input = latents.to(transformer_dtype) + + # Handle timestep expansion for temporal consistency + if self.model.config.expand_timesteps: + # Expand timesteps spatially for better temporal modeling + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # Standard timestep broadcasting + timestep = t.expand(latents.shape[0]) + + # Extract dimensions for patch processing + batch_size, num_channels, num_frames, height, width = latents.shape + p_t, p_h, p_w = current_model.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Generate rotary position embeddings + rotary_emb = current_model.rope(latent_model_input) + rotary_emb = torch.cat(rotary_emb, dim=0) + ts_seq_len = None + timestep = timestep.flatten() + + # Generate conditioning embeddings (time + text) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( + current_model.condition_embedder( + timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len + ) + ) + + # Generate negative conditioning for classifier-free guidance + if self.do_classifier_free_guidance: + temb, timestep_proj, encoder_hidden_states_neg, encoder_hidden_states_image = ( + current_model.condition_embedder( + timestep, + negative_prompt_embeds, + encoder_hidden_states_image=None, + timestep_seq_len=ts_seq_len, + ) + ) + + # Reshape timestep projection for transformer input + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # Prepare inputs for QAIC inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states.detach().numpy(), + "rotary_emb": 
rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": timestep_proj.detach().numpy(), + "tsp": model_type.detach().numpy(), # Transformer stage pointer + } + + # Prepare negative inputs for classifier-free guidance + if self.do_classifier_free_guidance: + inputs_aic2 = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states_neg.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": timestep_proj.detach().numpy(), + } + + # Run conditional prediction with caching context + with current_model.cache_context("cond"): + # QAIC inference for conditional prediction + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + print(f"DIT {i} time {end_transformer_step_time - start_transformer_step_time:.2f} seconds") + + # Process transformer output + hidden_states = torch.tensor(outputs["output"]) + + # Reshape output from patches back to video format + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + + # Permute dimensions to reconstruct video tensor + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + # Run unconditional prediction for classifier-free guidance + if self.do_classifier_free_guidance: # Note: CFG is False for WAN Lightning + with current_model.cache_context("uncond"): + # QAIC inference for unconditional prediction + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic2) + end_transformer_step_time = time.perf_counter() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + # Process unconditional output + hidden_states = torch.tensor(outputs["output"]) + + # Reshape unconditional output + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_uncond = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + # Apply classifier-free guidance + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # Update latents using scheduler (x_t -> x_t-1) + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Execute callback if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + # Step 9: Decode latents to video + if not output_type == "latent": + # Prepare latents for VAE decoding + latents = latents.to(self.vae_decode.dtype) + + # Apply VAE normalization (denormalization) + latents_mean = ( + torch.tensor(self.vae_decode.config.latents_mean) + 
.view(1, self.vae_decode.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae_decode.config.latents_std).view( + 1, self.vae_decode.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # TODO: Enable VAE on QAIC + # VAE Decode latents to video using CPU (temporary) + video = self.model.vae.decode(latents, return_dict=False)[0] # CPU fallback + + # Post-process video for output + video = self.model.video_processor.postprocess_video(video.detach()) + else: + video = latents + + # Step 10: Collect performance metrics + perf_data = { + "transformer": transformer_perf, # Unified transformer (QAIC) + } + + # Build performance metrics for output + perf_metrics = [ModulePerf(module_name=name, perf=perf_data[name]) for name in perf_data.keys()] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=video, + ) diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index b769680ef..2547d9db3 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -202,7 +202,7 @@ def export_kvstyle_transformed_model_to_onnx( batch_size=len(Constants.INPUT_STR), tokenizer=tokenizer, config=transformed_model.config, - prompt=Constants.INPUT_STR, + prompt=Constants.INPUT_STR * (full_batch_size if full_batch_size else 1), prompt_len=Constants.PROMPT_LEN, ctx_len=seq_len, full_batch_size=full_batch_size, diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py index f86a0f254..fac2441c8 100644 --- a/QEfficient/exporter/export_utils.py +++ b/QEfficient/exporter/export_utils.py @@ -17,7 +17,8 @@ import torch from onnx import external_data_helper -from QEfficient.base.onnx_transforms import FP16ClipTransform +from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransformPipeline +from QEfficient.utils import constants def export_onnx( @@ -97,7 +98,7 @@ def export_onnx( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=13, + opset_version=constants.ONNX_EXPORT_OPSET, custom_opsets={"com.qti.aisw.onnx": 1}, ) except Exception as e: @@ -218,7 +219,8 @@ def fix_onnx_fp16( :str: Updated base name of exported ONNX model. 
""" model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx")) - model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path) + onnx_transforms = OnnxTransformPipeline(transforms=[FP16ClipTransform]) + model, fp16_fix = onnx_transforms.apply(model, model_name="", onnx_base_dir=gen_models_path) if fp16_fix: # Save FP16 model diff --git a/QEfficient/finetune/dataset/alpaca_dataset.py b/QEfficient/finetune/dataset/alpaca_dataset.py index c6ddb6ce1..ff44860eb 100644 --- a/QEfficient/finetune/dataset/alpaca_dataset.py +++ b/QEfficient/finetune/dataset/alpaca_dataset.py @@ -58,10 +58,15 @@ def __getitem__(self, index): else: prompt = PROMPT_DICT["prompt_input"].format_map(ann) example = prompt + ann["output"] + + if self.context_length is not None: + padding_type = "max_length" + else: + padding_type = True prompt = torch.tensor( - self.tokenizer.encode(prompt, max_length=self.context_length, pad_to_max_length=True), dtype=torch.int64 + self.tokenizer.encode(prompt, max_length=self.context_length, padding=padding_type), dtype=torch.int64 ) - example = self.tokenizer.encode(example, max_length=self.context_length, pad_to_max_length=True) + example = self.tokenizer.encode(example, max_length=self.context_length, padding=padding_type) example.append(self.tokenizer.eos_token_id) example = torch.tensor(example, dtype=torch.int64) labels = copy.deepcopy(example) diff --git a/QEfficient/finetune/dataset/custom_dataset/sample_dataset_preproc.py b/QEfficient/finetune/dataset/custom_dataset/sample_dataset_preproc.py index 78db5674c..383d6fd67 100644 --- a/QEfficient/finetune/dataset/custom_dataset/sample_dataset_preproc.py +++ b/QEfficient/finetune/dataset/custom_dataset/sample_dataset_preproc.py @@ -61,17 +61,22 @@ def apply_prompt_template(sample): dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) def tokenize_add_label(sample): + if context_length is not None: + padding_type = "max_length" + else: + padding_type = True + input = tokenizer.encode( tokenizer.bos_token + sample["input"], add_special_tokens=False, max_length=context_length, - pad_to_max_length=True, + padding=padding_type, ) label = tokenizer.encode( sample["label"] + tokenizer.pad_token + tokenizer.eos_token, add_special_tokens=False, max_length=context_length, - pad_to_max_length=True, + padding=padding_type, ) sample = { diff --git a/QEfficient/finetune/dataset/grammar_dataset.py b/QEfficient/finetune/dataset/grammar_dataset.py index e40c01e97..8fb3eb152 100644 --- a/QEfficient/finetune/dataset/grammar_dataset.py +++ b/QEfficient/finetune/dataset/grammar_dataset.py @@ -44,17 +44,23 @@ def convert_to_features(self, example_batch): target_ = example_batch["target"] prompt = f"Correct this to standard English: {input_}\n---\nCorrected: " + + if self.context_length is not None: + padding_type = "max_length" + else: + padding_type = True + prompt_ids = self.tokenizer.encode( self.tokenizer.bos_token + prompt, add_special_tokens=False, max_length=self.context_length, - pad_to_max_length=True, + padding=padding_type, ) label_ids = self.tokenizer.encode( target_ + self.tokenizer.eos_token, add_special_tokens=False, max_length=self.context_length, - pad_to_max_length=True, + padding=padding_type, ) sample = { diff --git a/QEfficient/finetune/experimental/__init__.py b/QEfficient/finetune/experimental/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/__init__.py @@ -0,0 +1,6 @@ +# 
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/configs/sample_config.yaml b/QEfficient/finetune/experimental/configs/sample_config.yaml new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/finetune/experimental/core/__init__.py b/QEfficient/finetune/experimental/core/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py new file mode 100644 index 000000000..7744d71e6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/component_registry.py @@ -0,0 +1,200 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +import logging +from typing import Callable, Dict, Optional, Type + +# from QEfficient.finetune.experimental.core.logger import get_logger + +# logger = get_logger() +logger = logging.getLogger(__name__) + + +def get_object(obj_dict: Dict, name: str, object_type: str, list_fn: Callable) -> Optional[Type]: + """Utility to get object from a dictionary with error handling.""" + obj = obj_dict.get(name) + if obj is None: + raise ValueError(f"Unknown {object_type}: {name}. Available: {list_fn()}") + return obj + + +class ComponentRegistry: + """Registry for managing different training components.""" + + def __init__(self): + self._optimizers: Dict[str, Type] = {} + self._schedulers: Dict[str, Type] = {} + self._datasets: Dict[str, Type] = {} + self._models: Dict[str, Type] = {} + self._data_collators: Dict[str, Type] = {} + self._metrics: Dict[str, Type] = {} + self._loss_functions: Dict[str, Type] = {} + self._callbacks: Dict[str, Type] = {} + self._hooks: Dict[str, Type] = {} + self._trainer_modules: Dict[str, Type] = {} + + def trainer_module(self, name: str, args_cls=None, required_kwargs=None): + """ + Decorator to register a trainer module with its configuration. + Each trainer module has to be binded to its args class and required kwargs. 
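+
+        A minimal registration sketch (illustrative; it mirrors how BaseTrainer and
+        SFTTrainerModule register themselves with this decorator):
+
+            @registry.trainer_module(name="base", args_cls=TrainingArguments,
+                                     required_kwargs={"peft_config": PeftConfig})
+            class BaseTrainer(Trainer):
+                ...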
+ + Args: + name: Name of the trainer type + args_cls: The arguments class for this trainer + required_kwargs: Dictionary of required keyword arguments and their default values + """ + required_kwargs = required_kwargs or {} + + def decorator(trainer_cls): + self._trainer_modules[name] = { + "trainer_cls": trainer_cls, + "args_cls": args_cls, + "required_kwargs": required_kwargs, + } + logger.info(f"Registered trainer module: {name}") + return self._trainer_modules[name] + + return decorator + + def optimizer(self, name: str): + """Decorator to register an optimizer class.""" + + def decorator(cls: Type): + self._optimizers[name] = cls + logger.info(f"Registered optimizer: {name}") + return cls + + return decorator + + def scheduler(self, name: str): + """Decorator to register a scheduler class.""" + + def decorator(cls: Type): + self._schedulers[name] = cls + logger.info(f"Registered scheduler: {name}") + return cls + + return decorator + + def dataset(self, name: str): + """Decorator to register a dataset class.""" + + def decorator(cls: Type): + self._datasets[name] = cls + logger.info(f"Registered dataset: {name}") + return cls + + return decorator + + def model(self, name: str): + """Decorator to register a model class.""" + + def decorator(cls: Type): + self._models[name] = cls + logger.info(f"Registered model: {name}") + return cls + + return decorator + + def data_collator(self, name: str): + """Decorator to register a data collator class.""" + + def decorator(fn_pointer: Type): + self._data_collators[name] = fn_pointer + logger.info(f"Registered data collator: {name}") + return fn_pointer + + return decorator + + def loss_function(self, name: str): + """Decorator to register a loss function class.""" + + def decorator(cls: Type): + self._loss_functions[name] = cls + logger.info(f"Registered loss function: {name}") + return cls + + return decorator + + def callback(self, name: str): + """Decorator to register a callback class.""" + + def decorator(cls: Type): + self._callbacks[name] = cls + logger.info(f"Registered callback: {name}") + return cls + + return decorator + + def get_trainer_module(self, name: str) -> Optional[Type]: + """Get trainer module class by name.""" + return get_object(self._trainer_modules, name, "trainer module", self.list_trainer_modules) + + def get_optimizer(self, name: str) -> Optional[Type]: + """Get optimizer class by name.""" + return get_object(self._optimizers, name, "optimizer", self.list_optimizers) + + def get_scheduler(self, name: str) -> Optional[Type]: + """Get scheduler class by name.""" + return get_object(self._schedulers, name, "scheduler", self.list_schedulers) + + def get_dataset(self, name: str) -> Optional[Type]: + """Get dataset class by name.""" + return get_object(self._datasets, name, "dataset", self.list_datasets) + + def get_model(self, name: str) -> Optional[Type]: + """Get model class by name.""" + return get_object(self._models, name, "model", self.list_models) + + def get_data_collator(self, name: str) -> Optional[Type]: + """Get data collator class by name.""" + return get_object(self._data_collators, name, "data collator", self.list_data_collators) + + def get_loss_function(self, name: str) -> Optional[Type]: + """Get loss function class by name.""" + return get_object(self._loss_functions, name, "loss function", self.list_loss_functions) + + def get_callback(self, name: str) -> Optional[Type]: + """Get callback class by name.""" + return get_object(self._callbacks, name, "callback", self.list_callbacks) + + def 
list_trainer_modules(self) -> list[str]: + """List all registered trainer modules.""" + return list(self._trainer_modules.keys()) + + def list_optimizers(self) -> list[str]: + """List all registered optimizers.""" + return list(self._optimizers.keys()) + + def list_schedulers(self) -> list[str]: + """List all registered schedulers.""" + return list(self._schedulers.keys()) + + def list_datasets(self) -> list[str]: + """List all registered datasets.""" + return list(self._datasets.keys()) + + def list_models(self) -> list[str]: + """List all registered models.""" + return list(self._models.keys()) + + def list_data_collators(self) -> list[str]: + """List all registered data collators.""" + return list(self._data_collators.keys()) + + def list_loss_functions(self) -> list[str]: + """List all registered loss functions.""" + return list(self._loss_functions.keys()) + + def list_callbacks(self) -> list[str]: + """List all registered callbacks.""" + return list(self._callbacks.keys()) + + +# Global registry instance +registry = ComponentRegistry() diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/dataset.py b/QEfficient/finetune/experimental/core/dataset.py new file mode 100644 index 000000000..4a243c40b --- /dev/null +++ b/QEfficient/finetune/experimental/core/dataset.py @@ -0,0 +1,257 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Dataset components for the training system. +""" + +import importlib +import os +import re +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict + +from datasets import load_dataset, load_dataset_builder +from torch.utils.data import Dataset + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.utils.dataset_utils import ( + apply_train_test_split, +) + + +class BaseDataset(Dataset, ABC): + """Base class for all datasets to ensure consistent interface.""" + + def __init__(self, dataset_name: str, split: str, seed: int = 42, **kwargs): + self.dataset_name = dataset_name + self.split = split + self.seed = seed + self.kwargs = kwargs + self._initialize_dataset() + + @abstractmethod + def _initialize_dataset(self): + """Subclasses should implement this to load and prepare the dataset.""" + pass + + @abstractmethod + def __len__(self): + """Return the number of samples in the dataset.""" + pass + + @abstractmethod + def __getitem__(self, idx): + """Should return a dictionary with 'input_ids', 'attention_mask', and 'labels'.""" + pass + + +@registry.dataset("sft_dataset") +class SFTDataset(BaseDataset): + """ + A Supervised Fine-Tuning (SFT) dataset class for text data. 
+ + This class handles loading data from Hugging Face datasets or custom JSON files, + filtering out invalid samples, and applying a prompt/completion templating for SFT tasks. + + Args: + dataset_name (str): The name of the dataset to load from Hugging Face datasets. + Ignored if json_file_path is provided. + split (str): The dataset split to use (e.g., "train", "validation", "test"). + split_ratio (float): Ratio for train/test split when only one split is available. + seed (int): Random seed for reproducibility. + json_file_path (str, optional): Path to a custom JSON file containing the dataset. + If provided, this takes precedence over dataset_name. + prompt_template (str): A string template for constructing the prompt. Variables in the + template should be enclosed in curly braces, e.g., "Answer the question: {question}". + completion_template (str): A string template for constructing the completion (target). + Variables should be enclosed in curly braces, e.g., "{answer}". + + Raises: + RuntimeError: If any variables specified in `prompt_template` or `completion_template` + are not found as columns in the loaded dataset. + """ + + def __init__( + self, + dataset_name: str, + split: str, + split_ratio: float = 0.8, + seed: int = 42, + **kwargs, + ): + self.split_ratio = split_ratio + self.json_file_path = kwargs.get("json_file_path", None) + self.prompt_template = kwargs.get("prompt_template", None) + self.completion_template = kwargs.get("completion_template", None) + self.prompt_func_path = kwargs.get("prompt_func", None) + self.completion_func_path = kwargs.get("completion_func", None) + self.remove_samples_with_empty_columns = kwargs.get("remove_samples_with_empty_columns", True) + + if self.json_file_path not in (None, ""): + if not os.path.isfile(self.json_file_path): + raise FileNotFoundError(f"JSON file not found or invalid: '{self.json_file_path}'") + if (self.prompt_template is None and self.prompt_func_path is None) or ( + self.prompt_template is not None and self.prompt_func_path is not None + ): + raise RuntimeError("Either provide prompt_template or prompt_func in the config.") + if (self.completion_template is None and self.completion_func_path is None) or ( + self.completion_template is not None and self.completion_func_path is not None + ): + raise RuntimeError("Either provide completion_template or completion_func in the config.") + + # Call parent class __init__ which will call _initialize_dataset + super().__init__(dataset_name, split, seed, **kwargs) + + def _initialize_dataset(self): + """ + Initialize the dataset from either HuggingFace or a custom JSON file. + + This method loads the dataset, applies splitting if necessary, and prepares + it for preprocessing with prompt/completion templates. + """ + if self.json_file_path: + # Load dataset from JSON file + self.dataset = load_dataset("json", data_files=self.json_file_path, split="train") + + # Apply train/test split if needed + if self.split in ["train", "test"]: + self.dataset = apply_train_test_split(self.dataset, self.split_ratio, self.split, self.seed) + else: + # Load dataset from HuggingFace + db = load_dataset_builder(self.dataset_name) + available_splits = [] + if db.info.splits is not None: + available_splits = list(db.info.splits.keys()) + + if self.split not in available_splits: + raise ValueError(f"Split {self.split} is not available for dataset {self.dataset_name}.") + + # FIXME: Add streaming support for larger datasets. 
+ self.dataset = load_dataset(self.dataset_name, split=self.split) + + if len(available_splits) == 1: + self.dataset = apply_train_test_split(self.dataset, self.split_ratio, self.split, self.seed) + + self.dataset = self._setup_templates(self.dataset, self.dataset.column_names) + + def _setup_templates(self, dataset, dataset_columns): + """ + Set up prompt/completion templates or functions and apply preprocessing. + """ + if self.prompt_template: + self.prompt_func = None + # Extract variables from templates and check if they exist in dataset columns + prompt_variables = re.findall(r"\{(.*?)\}", self.prompt_template) + for var in prompt_variables: + if var not in dataset_columns: + raise RuntimeError( + f"Prompt template variable '{var}' not found in dataset columns: {dataset_columns}." + ) + else: + prompt_variables = dataset_columns + self.prompt_func = self.import_func(self.prompt_func_path) + + if self.completion_template: + self.completion_func = None + # Extract variables from templates and check if they exist in dataset columns + completion_variables = re.findall(r"\{(.*?)\}", self.completion_template) + for var in completion_variables: + if var not in dataset_columns: + raise RuntimeError( + f"Completion template variable '{var}' not found in dataset columns: {dataset_columns}." + ) + else: + completion_variables = dataset_columns + self.completion_func = self.import_func(self.completion_func_path) + + # Filter out samples with None or empty strings in relevant columns + relevant_columns = list(set(prompt_variables + completion_variables)) + if self.remove_samples_with_empty_columns: + dataset = dataset.filter(lambda example: self._filter_empty_or_none_samples(example, relevant_columns)) + return dataset + + def import_func(self, func_path: str) -> Callable: + if ":" not in func_path: + raise ValueError("func_path must be in the format 'module_file_path:function_name'.") + module_file_path, function_name = func_path.split(":") + + try: + module = importlib.import_module(module_file_path) + except Exception: + raise RuntimeError(f"Unable to import module : {module_file_path}.") + if not hasattr(module, function_name): + raise ValueError(f"Function {function_name} not found in module {module_file_path}.") + return getattr(module, function_name) + + def _filter_empty_or_none_samples(self, example: Dict[str, Any], relevant_columns: list) -> bool: + """ + Filters out samples where any of the relevant columns are None or contain only whitespace. + + Args: + example (Dict[str, Any]): A single sample from the dataset. + relevant_columns (list): List of column names to check for empty or None values. + + Returns: + bool: True if the sample should be kept, False otherwise. + """ + for column in relevant_columns: + value = example.get(column) + if value is None or (isinstance(value, str) and not value.strip()): + return False + return True + + def _preprocess_sample(self, example: Dict[str, Any]) -> Dict[str, str]: + """ + Applies the prompt and completion templates to a single example. + + Args: + example (Dict[str, Any]): A single sample from the dataset. + + Returns: + Dict[str, str]: A dictionary containing the 'prompt' and 'completion' strings. 
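+
+        Example (illustrative, assuming prompt_template="Question: {question}" and
+        completion_template="{answer}"):
+            {"question": "What is AI?", "answer": "Artificial Intelligence"}
+            -> {"prompt": "Question: What is AI?", "completion": "Artificial Intelligence"}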
+ """ + prompt_text = ( + self.prompt_func(example) if self.prompt_func is not None else self.prompt_template.format(**example) + ) + completion_text = ( + self.completion_func(example) + if self.completion_func is not None + else self.completion_template.format(**example) + ) + return { + "prompt": prompt_text, + "completion": completion_text, + } + + def __len__(self) -> int: + """ + Returns the number of samples in the dataset. + + Returns: + int: The total number of samples. + """ + return self.dataset.num_rows + + def __getitem__(self, idx: int) -> Dict[str, str]: + """ + Retrieves a processed sample from the dataset at the given index. + This method doesn't tokenize the input items, it is expected that the SFTTrainer will handle tokenization. + + Args: + idx (int): The index of the sample to retrieve. + + Returns: + Dict[str, str]: A dictionary containing the processed 'prompt' and 'completion' for the sample. + """ + # Get the raw example using .select and access the first element + example = self.dataset.select(indices=[int(idx)])[0] + + # Apply preprocessing (templating) on the fly + processed_example = self._preprocess_sample(example) + + return processed_example diff --git a/QEfficient/finetune/experimental/core/logger.py b/QEfficient/finetune/experimental/core/logger.py new file mode 100644 index 000000000..a1b9c771f --- /dev/null +++ b/QEfficient/finetune/experimental/core/logger.py @@ -0,0 +1,170 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +import logging +import sys +from pathlib import Path +from typing import Optional + +from transformers.utils.logging import get_logger as hf_get_logger + +from QEfficient.finetune.experimental.core.utils.dist_utils import get_local_rank + +# ----------------------------------------------------------------------------- +# Logger usage: +# Initialize logger: +# logger = Logger("my_logger", log_file="logs/output.log", level=logging.DEBUG) +# Log messages: +# logger.info("This is an info message") +# logger.error("This is an error message") +# logger.log_rank_zero("This message is logged only on rank 0") +# logger.log_exception("An error occurred", exception, raise_exception=False) +# Attach file handler later if needed: +# logger.prepare_for_logs(output_dir="logs", log_level="DEBUG") +# ----------------------------------------------------------------------------- + + +class Logger: + """Custom logger with console and file logging capabilities.""" + + def __init__( + self, + name: str = "transformers", # We are using "transformers" as default to align with HF logs + log_file: Optional[str] = None, + level: int = logging.INFO, + ): + """ + Initialize the logger. 
+ + Args: + name: Logger name + log_file: Path to log file (if None, log only to console) + level: Logging level + """ + self.logger = hf_get_logger(name) + self.logger.setLevel(level) + + # Clear any existing handlers + self.logger.handlers.clear() + + # Create formatter + self.formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(level) + console_handler.setFormatter(self.formatter) + self.logger.addHandler(console_handler) + + # File handler (if log_file is provided) + if log_file: + # Create directory if it doesn't exist + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(level) + file_handler.setFormatter(self.formatter) + self.logger.addHandler(file_handler) + + def debug(self, message: str) -> None: + """Log debug message.""" + self.logger.debug(message) + + def info(self, message: str) -> None: + """Log info message.""" + self.logger.info(message) + + def warning(self, message: str) -> None: + """Log warning message.""" + self.logger.warning(message) + + def error(self, message: str) -> None: + """Log error message.""" + self.logger.error(message) + + def critical(self, message: str) -> None: + """Log critical message.""" + self.logger.critical(message) + + def log_rank_zero(self, message: str, level: int = logging.INFO) -> None: + """ + Log message only on rank 0 process. + + Args: + message: Message to log + level: Logging level + """ + if get_local_rank() == 0: + self.logger.log(level, message) + + def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None: + """ + Log exception message and optionally raise the exception. + + Args: + message: Custom message to log + exception: Exception to log + raise_exception: Whether to raise the exception after logging + """ + error_message = f"{message}: {str(exception)}" + self.logger.error(error_message) + + if raise_exception: + raise exception + + def prepare_for_logs(self, output_dir: Optional[str] = None, log_level: str = "INFO") -> None: + """ + Prepare existing logger to log to both console and file with specified + output directory and log level. + + Args: + output_dir: Output directory for logs + log_level: Logging level as string + """ + # Convert string log level to logging constant + level = getattr(logging, log_level.upper(), logging.INFO) + self.logger.setLevel(level) + + # Update existing handlers' levels + for handler in self.logger.handlers: + handler.setLevel(level) + + # Add file handler if saving metrics + if output_dir: + log_file = Path(output_dir) / "training.log" + log_file.parent.mkdir(parents=True, exist_ok=True) + + # Check if file handler already exists + file_handler_exists = any(isinstance(handler, logging.FileHandler) for handler in self.logger.handlers) + + if not file_handler_exists: + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(level) + file_handler.setFormatter(self.formatter) + self.logger.addHandler(file_handler) + + +# Global logger instance +_logger: Optional[Logger] = None + + +def get_logger(log_file: Optional[str] = None) -> Logger: + """ + Get or create a logger instance. 
+ + Args: + log_file: Path to log file (if None, log only to console) + + Returns: + Logger instance + """ + global _logger + if _logger is None: + _logger = Logger(log_file=log_file) + return _logger diff --git a/QEfficient/finetune/experimental/core/model.py b/QEfficient/finetune/experimental/core/model.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/model.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/optimizer.py b/QEfficient/finetune/experimental/core/optimizer.py new file mode 100644 index 000000000..d4f82cbeb --- /dev/null +++ b/QEfficient/finetune/experimental/core/optimizer.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Optimizer components for the training system. +""" + +import torch.optim as optim + +from QEfficient.finetune.experimental.core.component_registry import registry + +registry.optimizer("Adam")(optim.Adam) +registry.optimizer("AdamW")(optim.AdamW) +registry.optimizer("SGD")(optim.SGD) + + +def prepare_optimizer(opt_config): + """ + Create optimizer from config. + Args: opt_config: Dictionary containing optimizer configuration. + Returns: Tuple of optimizer class and its arguments. + """ + opt_name = opt_config.pop("optimizer_name") + opt_cls = registry.get_optimizer(opt_name) + opt_config["lr"] = float(opt_config["lr"]) + optimizer_cls_and_kwargs = (opt_cls, opt_config) + return optimizer_cls_and_kwargs diff --git a/QEfficient/finetune/experimental/core/trainer/__init__.py b/QEfficient/finetune/experimental/core/trainer/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/trainer/base_trainer.py b/QEfficient/finetune/experimental/core/trainer/base_trainer.py new file mode 100644 index 000000000..0a3c50f7f --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/base_trainer.py @@ -0,0 +1,79 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Optional + +from peft import get_peft_model +from transformers import Trainer, TrainingArguments + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.config_manager import PeftConfig + + +@registry.trainer_module(name="base", args_cls=TrainingArguments, required_kwargs={"peft_config": PeftConfig}) +class BaseTrainer(Trainer): + """ + Extended Trainer class that supports PEFT (Parameter-Efficient Fine-Tuning). + + This trainer extends the standard HuggingFace Trainer to optionally apply + PEFT configurations to the model before training. + """ + + def __init__( + self, + model=None, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + model_init=None, + compute_metrics=None, + callbacks=None, + optimizers=(None, None), + preprocess_logits_for_metrics=None, + peft_config: Optional[PeftConfig] = None, + **kwargs, + ): + """ + Initialize the BaseTrainer with optional PEFT support. + + Args: + model: The model to train + args: Training arguments + data_collator: Data collator for batching + train_dataset: Training dataset + eval_dataset: Evaluation dataset + processing_class: Tokenizer or processor + model_init: Function to initialize model + compute_metrics: Function to compute metrics + callbacks: List of callbacks + optimizers: Tuple of (optimizer, scheduler) + preprocess_logits_for_metrics: Function to preprocess logits + peft_config: Optional PEFT configuration. If provided, the model will be + wrapped with PEFT before training. + **kwargs: Additional keyword arguments + """ + # Apply PEFT to model if peft_config is provided + if peft_config is not None and model is not None: + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + # Initialize the parent Trainer class + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + **kwargs, + ) diff --git a/QEfficient/finetune/experimental/core/trainer/dpo_trainer.py b/QEfficient/finetune/experimental/core/trainer/dpo_trainer.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/dpo_trainer.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/trainer/grpo_trainer.py b/QEfficient/finetune/experimental/core/trainer/grpo_trainer.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/grpo_trainer.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/trainer/kd_trainer.py b/QEfficient/finetune/experimental/core/trainer/kd_trainer.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/kd_trainer.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/trainer/reward_trainer.py b/QEfficient/finetune/experimental/core/trainer/reward_trainer.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/reward_trainer.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/trainer/sft_trainer.py b/QEfficient/finetune/experimental/core/trainer/sft_trainer.py new file mode 100644 index 000000000..3223c5966 --- /dev/null +++ b/QEfficient/finetune/experimental/core/trainer/sft_trainer.py @@ -0,0 +1,15 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from trl import SFTConfig, SFTTrainer + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.config_manager import PeftConfig + + +@registry.trainer_module(name="sft", args_cls=SFTConfig, required_kwargs={"peft_config": PeftConfig}) +class SFTTrainerModule(SFTTrainer): + pass # Just using the standard SFTTrainer diff --git a/QEfficient/finetune/experimental/core/utils/__init__.py b/QEfficient/finetune/experimental/core/utils/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/utils/dataset_utils.py b/QEfficient/finetune/experimental/core/utils/dataset_utils.py new file mode 100644 index 000000000..11e2fecfc --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/dataset_utils.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +def insert_pad_token(tokenizer): + # Add pad token if it doesn't exist + if tokenizer.pad_token is None: + # Try to use existing special token as pad token + if tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + elif tokenizer.bos_token is not None: + tokenizer.pad_token = tokenizer.bos_token + elif tokenizer.sep_token is not None: + tokenizer.pad_token = tokenizer.sep_token + else: + # Add a new pad token + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + +def apply_train_test_split(dataset, split_ratio, split, seed): + """ + Apply train/test split to the dataset based on split_ratio. + """ + splitted_dataset = dataset.train_test_split(test_size=(1 - split_ratio), seed=seed) + if split == "test": + dataset = splitted_dataset["test"] + else: + dataset = splitted_dataset["train"] + return dataset diff --git a/QEfficient/finetune/experimental/core/utils/dist_utils.py b/QEfficient/finetune/experimental/core/utils/dist_utils.py new file mode 100644 index 000000000..aed88862d --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/dist_utils.py @@ -0,0 +1,39 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch.distributed as dist + + +def is_dist_available_and_initialized() -> bool: + """Check if distributed training is available and initialized.""" + return dist.is_available() and dist.is_initialized() + + +def get_rank() -> int: + """Return the global rank of the current process, else 0.""" + if not is_dist_available_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """Return the local rank of the current process on its node, else 0.""" + if not is_dist_available_and_initialized(): + return 0 + return dist.get_node_local_rank() + + +def get_world_size() -> int: + """Get the total number of processes in distributed training.""" + if not is_dist_available_and_initialized(): + return 1 + return dist.get_world_size() + + +def is_main_process() -> bool: + """Check if the current process is the main process (rank 0).""" + return get_rank() == 0 diff --git a/QEfficient/finetune/experimental/core/utils/import_utils.py b/QEfficient/finetune/experimental/core/utils/import_utils.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/import_utils.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/docs/ReadMe.md b/QEfficient/finetune/experimental/docs/ReadMe.md new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/finetune/experimental/examples/ReadMe.md b/QEfficient/finetune/experimental/examples/ReadMe.md new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/finetune/experimental/extensions/preprocessing/__init__.py b/QEfficient/finetune/experimental/extensions/preprocessing/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/extensions/preprocessing/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/tests/__init__.py b/QEfficient/finetune/experimental/tests/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/finetune/experimental/tests/test_dataset.py b/QEfficient/finetune/experimental/tests/test_dataset.py new file mode 100644 index 000000000..ca2fc1450 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_dataset.py @@ -0,0 +1,528 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Tests for dataset components. 
+""" + +import json +import os +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +from QEfficient.finetune.experimental.core.dataset import BaseDataset, SFTDataset + +SEED = 42 +SPLIT_RATIO = 0.8 + + +class TestBaseDataset(unittest.TestCase): + """Tests for BaseDataset abstract class.""" + + def test_base_dataset_cannot_be_instantiated(self): + """Test that BaseDataset cannot be instantiated directly.""" + with self.assertRaises(TypeError): + BaseDataset(dataset_name="test", split="train") + + +class TestSFTDataset(unittest.TestCase): + """Tests for SFTDataset class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a temporary directory for test files + self.test_dir = tempfile.mkdtemp() + self.json_file_path = os.path.join(self.test_dir, "test_dataset.json") + + # Create a dummy JSON dataset + self.dummy_data = [ + {"question": "What is AI?", "answer": "Artificial Intelligence"}, + {"question": "What is ML?", "answer": "Machine Learning"}, + {"question": "What is DL?", "answer": "Deep Learning"}, + {"question": "What is NLP?", "answer": "Natural Language Processing"}, + {"question": "", "answer": "Empty question"}, # Empty question + {"question": "Valid question", "answer": ""}, # Empty answer + {"question": None, "answer": "None question"}, # None question + {"question": "Valid question 2", "answer": None}, # None answer + ] + + with open(self.json_file_path, "w") as f: + json.dump(self.dummy_data, f) + + def tearDown(self): + """Clean up test fixtures.""" + # Remove temporary files and directories + import shutil + + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + @patch("QEfficient.finetune.experimental.core.dataset.load_dataset") + @patch("QEfficient.finetune.experimental.core.dataset.load_dataset_builder") + def test_sft_dataset_with_huggingface_dataset_and_templates(self, mock_builder, mock_load): + """Test loading from HuggingFace dataset with templates using mocked data.""" + # Create mock dataset with dummy data + mock_dataset = MagicMock() + mock_dataset.column_names = ["text", "label"] + mock_dataset.num_rows = 3 + + # Mock the select method to return individual samples + def mock_select(indices): + sample_data = [ + {"text": "Sample text 1", "label": "Label 1"}, + {"text": "Sample text 2", "label": "Label 2"}, + {"text": "Sample text 3", "label": "Label 3"}, + ] + return [sample_data[indices[0]]] + + mock_dataset.select = mock_select + mock_dataset.filter = lambda func: mock_dataset # Return self for filtering + + # Mock train_test_split to return a dict with train/test splits + mock_split_result = {"train": mock_dataset, "test": mock_dataset} + mock_dataset.train_test_split = lambda test_size, seed: mock_split_result + + # Mock the dataset builder to indicate multiple splits are available + mock_info = MagicMock() + mock_info.splits = {"train": MagicMock(), "test": MagicMock()} + mock_builder.return_value.info = mock_info + + # Mock load_dataset to return our mock dataset + mock_load.return_value = mock_dataset + + # Create the dataset + dataset = SFTDataset( + dataset_name="dummy_hf_dataset", + split="train", + prompt_template="Text: {text}", + completion_template="Label: {label}", + ) + + self.assertIsNotNone(dataset) + self.assertEqual(len(dataset), 3) + + # Test __getitem__ + sample = dataset[0] + self.assertIn("prompt", sample) + self.assertIn("completion", sample) + self.assertTrue(sample["prompt"].startswith("Text:")) + self.assertTrue(sample["completion"].startswith("Label:")) + + def 
test_sft_dataset_with_json_file_and_templates(self): + """Test loading from JSON file with templates.""" + dataset = SFTDataset( + dataset_name="dummy", # Ignored when json_file_path is provided + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + self.assertIsNotNone(dataset) + # After filtering empty/None values and applying train split (default 0.8) + # we get a subset of the 4 valid samples + self.assertGreater(len(dataset), 0) + self.assertLessEqual(len(dataset), 4) + + # Test __getitem__ + sample = dataset[0] + self.assertIn("prompt", sample) + self.assertIn("completion", sample) + self.assertTrue(sample["prompt"].startswith("Q:")) + self.assertTrue(sample["completion"].startswith("A:")) + + def test_sft_dataset_json_file_without_filtering(self): + """Test loading from JSON file without filtering empty samples.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + remove_samples_with_empty_columns=False, + ) + + # When filtering is disabled and split="train" is used, it still applies train/test split + # So we get ~80% of 8 samples = ~6 samples + self.assertGreater(len(dataset), 0) + self.assertLessEqual(len(dataset), 8) + + def test_sft_dataset_train_test_split_from_json(self): + """Test train/test split when loading from JSON file.""" + train_dataset = SFTDataset( + dataset_name="dummy", + split="train", + split_ratio=SPLIT_RATIO, + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + seed=SEED, + ) + + test_dataset = SFTDataset( + dataset_name="dummy", + split="test", + split_ratio=SPLIT_RATIO, + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + seed=SEED, + ) + + # After filtering, we have 4 valid samples + # With split ratio, train should have ~3 samples, test should have ~1 sample + self.assertGreater(len(train_dataset), 0) + self.assertGreater(len(test_dataset), 0) + # Total should equal the filtered dataset size + self.assertEqual(len(train_dataset) + len(test_dataset), 4) + + def test_sft_dataset_with_custom_prompt_function(self): + """Test loading with custom prompt function.""" + # Create a temporary module file with custom functions + func_file_path = os.path.join(self.test_dir, "custom_funcs.py") + with open(func_file_path, "w") as f: + f.write(""" +def custom_prompt(example): + return f"Custom prompt: {example['question']}" + +def custom_completion(example): + return f"Custom completion: {example['answer']}" +""") + + # Add the test directory to sys.path temporarily + import sys + + sys.path.insert(0, self.test_dir) + + try: + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="custom_funcs:custom_prompt", + completion_func="custom_funcs:custom_completion", + ) + + self.assertIsNotNone(dataset) + self.assertGreater(len(dataset), 0) + + # Test that custom functions are applied + sample = dataset[0] + self.assertTrue(sample["prompt"].startswith("Custom prompt:")) + self.assertTrue(sample["completion"].startswith("Custom completion:")) + finally: + # Clean up + sys.path.remove(self.test_dir) + if os.path.exists(func_file_path): + os.remove(func_file_path) + + def test_sft_dataset_missing_template_variable(self): + """Test error when template variable is not in dataset columns.""" + with 
self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {nonexistent_column}", + completion_template="A: {answer}", + ) + + self.assertIn("not found in dataset columns", str(context.exception)) + + def test_sft_dataset_missing_completion_template_variable(self): + """Test error when completion template variable is not in dataset columns.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {nonexistent_column}", + ) + + self.assertIn("not found in dataset columns", str(context.exception)) + + def test_sft_dataset_no_prompt_template_or_func(self): + """Test error when neither prompt_template nor prompt_func is provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + completion_template="A: {answer}", + ) + + self.assertIn("Either provide prompt_template or prompt_func", str(context.exception)) + + def test_sft_dataset_both_prompt_template_and_func(self): + """Test error when both prompt_template and prompt_func are provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + prompt_func="module:function", + completion_template="A: {answer}", + ) + + self.assertIn("Either provide prompt_template or prompt_func", str(context.exception)) + + def test_sft_dataset_no_completion_template_or_func(self): + """Test error when neither completion_template nor completion_func is provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + ) + + self.assertIn( + "Either provide completion_template or completion_func", + str(context.exception), + ) + + def test_sft_dataset_both_completion_template_and_func(self): + """Test error when both completion_template and completion_func are provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + completion_func="module:function", + ) + + self.assertIn( + "Either provide completion_template or completion_func", + str(context.exception), + ) + + def test_sft_dataset_invalid_func_path_format(self): + """Test error when func_path doesn't contain colon separator.""" + with self.assertRaises(ValueError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="invalid_format", + completion_template="A: {answer}", + ) + + self.assertIn("must be in the format", str(context.exception)) + + def test_sft_dataset_invalid_module_import(self): + """Test error when module cannot be imported.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="nonexistent_module:function", + completion_template="A: {answer}", + ) + + self.assertIn("Unable to import module", str(context.exception)) + + def test_sft_dataset_invalid_function_name(self): + """Test error when function doesn't exist in module.""" + # Create a temporary module file 
without the expected function + func_file_path = os.path.join(self.test_dir, "test_module.py") + with open(func_file_path, "w") as f: + f.write("def some_other_function():\n pass\n") + + import sys + + sys.path.insert(0, self.test_dir) + + try: + with self.assertRaises(ValueError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="test_module:nonexistent_function", + completion_template="A: {answer}", + ) + + self.assertIn("not found in module", str(context.exception)) + finally: + sys.path.remove(self.test_dir) + if os.path.exists(func_file_path): + os.remove(func_file_path) + + def test_sft_dataset_filter_empty_or_none_samples(self): + """Test filtering of samples with empty or None values.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + remove_samples_with_empty_columns=True, + ) + + # Verify that all samples have valid (non-empty) questions and answers + for i in range(len(dataset)): + sample = dataset[i] + # Extract the actual question and answer from the formatted strings + question = sample["prompt"].replace("Q: ", "").strip() + answer = sample["completion"].replace("A: ", "").strip() + # Verify neither is empty + self.assertTrue(len(question) > 0, f"Question should not be empty: {sample['prompt']}") + self.assertTrue(len(answer) > 0, f"Answer should not be empty: {sample['completion']}") + + def test_sft_dataset_getitem_returns_correct_format(self): + """Test that __getitem__ returns the correct format.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + sample = dataset[0] + + # Check that sample is a dictionary + self.assertIsInstance(sample, dict) + + # Check that it has the required keys + self.assertIn("prompt", sample) + self.assertIn("completion", sample) + + # Check that values are strings + self.assertIsInstance(sample["prompt"], str) + self.assertIsInstance(sample["completion"], str) + + def test_sft_dataset_len(self): + """Test __len__ method.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + # Check that len returns an integer + self.assertIsInstance(len(dataset), int) + + # Check that len is positive + self.assertGreater(len(dataset), 0) + + # Check that we can iterate through all samples + for i in range(len(dataset)): + sample = dataset[i] + self.assertIsNotNone(sample) + + def test_sft_dataset_with_multiple_template_variables(self): + """Test templates with multiple variables.""" + # Create a more complex JSON dataset + complex_data = [ + {"context": "The sky", "question": "What color?", "answer": "Blue"}, + {"context": "Math", "question": "What is 2+2?", "answer": "4"}, + ] + + complex_json_path = os.path.join(self.test_dir, "complex_dataset.json") + with open(complex_json_path, "w") as f: + json.dump(complex_data, f) + + try: + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=complex_json_path, + prompt_template="Context: {context}\nQuestion: {question}", + completion_template="Answer: {answer}", + ) + + # With split="train", it applies train/test split, so we get ~80% of 2 samples + self.assertGreater(len(dataset), 0) + self.assertLessEqual(len(dataset), 2) + + sample = 
dataset[0] + self.assertIn("Context:", sample["prompt"]) + self.assertIn("Question:", sample["prompt"]) + self.assertIn("Answer:", sample["completion"]) + finally: + if os.path.exists(complex_json_path): + os.remove(complex_json_path) + + def test_sft_dataset_seed_reproducibility(self): + """Test that using the same seed produces the same split.""" + dataset1 = SFTDataset( + dataset_name="dummy", + split="train", + split_ratio=SPLIT_RATIO, + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + seed=SEED, + ) + + dataset2 = SFTDataset( + dataset_name="dummy", + split="train", + split_ratio=SPLIT_RATIO, + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + seed=SEED, + ) + + # Both datasets should have the same length + self.assertEqual(len(dataset1), len(dataset2)) + + # Both datasets should have the same samples + for i in range(len(dataset1)): + sample1 = dataset1[i] + sample2 = dataset2[i] + self.assertEqual(sample1["prompt"], sample2["prompt"]) + self.assertEqual(sample1["completion"], sample2["completion"]) + + @patch("QEfficient.finetune.experimental.core.dataset.load_dataset") + @patch("QEfficient.finetune.experimental.core.dataset.load_dataset_builder") + def test_sft_dataset_invalid_split(self, mock_builder, mock_load): + """Test error when requesting an invalid split.""" + # Mock the dataset builder to return specific splits + mock_info = MagicMock() + mock_info.splits = {"train": MagicMock(), "validation": MagicMock()} + mock_builder.return_value.info = mock_info + + with self.assertRaises(ValueError) as context: + SFTDataset( + dataset_name="dummy_dataset", + split="nonexistent_split", + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + self.assertIn("not available", str(context.exception)) + + def test_sft_dataset_invalid_json_path(self): + """Test error when an invalid JSON file path is provided.""" + invalid_path = "/path/to/nonexistent/file.json" + + with self.assertRaises(FileNotFoundError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=invalid_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + self.assertIn("JSON file not found or invalid", str(context.exception)) + self.assertIn(invalid_path, str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/QEfficient/finetune/experimental/tests/test_logger.py b/QEfficient/finetune/experimental/tests/test_logger.py new file mode 100644 index 000000000..0af0c8b51 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_logger.py @@ -0,0 +1,233 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import logging +from unittest.mock import patch + +import pytest + +from QEfficient.finetune.experimental.core.logger import Logger, get_logger + + +class TestLogger: + def setup_method(self): + """Reset the global logger before each test method""" + import QEfficient.finetune.experimental.core.logger as logger_module + + logger_module._logger = None + + def test_init_console_only(self): + """Test logger initialization with console-only output""" + logger = Logger("test_logger") + + # Check logger attributes + assert logger.logger.name == "test_logger" + assert logger.logger.level == logging.INFO + + # Check handlers - should have console handler only + assert len(logger.logger.handlers) == 1 # Only console handler + assert isinstance(logger.logger.handlers[0], logging.StreamHandler) + + def test_init_with_file(self, tmp_path): + """Test logger initialization with file output""" + log_file = tmp_path / "test.log" + logger = Logger("file_test_logger", str(log_file)) + + # Check handlers - should have both console and file handlers + assert len(logger.logger.handlers) == 2 # Console + file handler + assert isinstance(logger.logger.handlers[0], logging.StreamHandler) + assert isinstance(logger.logger.handlers[1], logging.FileHandler) + + # Check file creation + assert log_file.exists() + + def test_log_levels(self, caplog): + """Test all log levels work correctly""" + logger = Logger("level_test_logger", level=logging.DEBUG) + + with caplog.at_level(logging.DEBUG): + logger.debug("Debug message") + logger.info("Info message") + logger.warning("Warning message") + logger.error("Error message") + logger.critical("Critical message") + + # Check all messages were logged + assert "Debug message" in caplog.text + assert "Info message" in caplog.text + assert "Warning message" in caplog.text + assert "Error message" in caplog.text + assert "Critical message" in caplog.text + + @patch("QEfficient.finetune.experimental.core.logger.get_local_rank") + def test_log_rank_zero_positive_case(self, mock_get_local_rank, caplog): + """Test rank zero logging functionality""" + mock_get_local_rank.return_value = 0 + logger = Logger("rank_test_logger") + + with caplog.at_level(logging.INFO): + logger.log_rank_zero("Rank zero message") + + assert "Rank zero message" in caplog.text + + @patch("QEfficient.finetune.experimental.core.logger.get_local_rank") + def test_log_rank_zero_negative_case(self, mock_get_local_rank, caplog): + """Test to verify that only rank‑zero messages are logged""" + mock_get_local_rank.return_value = 1 + logger = Logger("rank_test_logger") + + with caplog.at_level(logging.INFO): + logger.log_rank_zero("Should not appear") + + assert "Should not appear" not in caplog.text + + def test_log_exception_raise(self, caplog): + """Test exception logging with raising""" + logger = Logger("exception_test_logger") + + with pytest.raises(ValueError), caplog.at_level(logging.ERROR): + logger.log_exception("Custom error", ValueError("Test exception"), raise_exception=True) + + # The actual logged message is "Custom error: Test exception" + # But the exception itself contains just "Test exception" + assert "Custom error: Test exception" in caplog.text + + def test_log_exception_no_raise(self, caplog): + """Test exception logging without raising""" + logger = Logger("exception_test_logger") + + with caplog.at_level(logging.ERROR): + logger.log_exception("Custom error", 
ValueError("Test exception"), raise_exception=False) + + # Check that the formatted message was logged + assert "Custom error: Test exception" in caplog.text + + def test_prepare_for_logs(self, tmp_path): + """Test preparing logger for training logs""" + output_dir = tmp_path / "output" + logger = Logger("prepare_test_logger") + + # Prepare for logs + logger.prepare_for_logs(str(output_dir), log_level="DEBUG") + + # Check file handler was added + file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)] + assert len(file_handlers) == 1 + + # Check file exists + log_file = output_dir / "training.log" + assert log_file.exists() + + # Check log level was updated + assert logger.logger.level == logging.DEBUG + + def test_prepare_for_logs_no_file_handler(self): + """Test preparing logger without saving to file""" + logger = Logger("prepare_test_logger") + + # Prepare for logs without saving metrics + logger.prepare_for_logs(log_level="INFO") + + # Check no file handler was added + file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)] + assert len(file_handlers) == 0 + + def test_prepare_for_logs_already_has_file_handler(self, tmp_path): + """Test preparing logger when file handler already exists""" + output_dir = tmp_path / "output" + logger = Logger("prepare_test_logger") + + # Add a file handler manually first + log_file = output_dir / "manual.log" + log_file.parent.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(str(log_file)) + logger.logger.addHandler(file_handler) + + # Prepare for logs again + logger.prepare_for_logs(str(output_dir), log_level="INFO") + + # Should still have only one file handler + file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)] + assert len(file_handlers) == 1 + + def test_get_logger_singleton(self): + """Test that get_logger returns the same instance""" + logger1 = get_logger() + logger2 = get_logger() + + assert logger1 is logger2 + + def test_get_logger_with_file(self, tmp_path): + """Test get_logger with file parameter""" + log_file = tmp_path / "get_logger_test.log" + logger = get_logger(str(log_file)) + + # Check that we have 2 handlers (console + file) + assert len(logger.logger.handlers) == 2 # Console + file + assert isinstance(logger.logger.handlers[1], logging.FileHandler) + + # Check file exists + assert log_file.exists() + + +class TestLoggerIntegration: + """Integration tests for logger functionality""" + + def setup_method(self): + """Reset the global logger before each test method""" + import QEfficient.finetune.experimental.core.logger as logger_module + + logger_module._logger = None + + def test_complete_workflow(self, tmp_path, caplog): + """Test complete logger workflow""" + # Setup + log_file = tmp_path / "workflow.log" + logger = Logger("workflow_test", str(log_file), logging.DEBUG) + + # Test all methods + logger.debug("Debug test") + logger.info("Info test") + logger.warning("Warning test") + logger.error("Error test") + logger.critical("Critical test") + + # Test exception handling + try: + raise ValueError("Test exception") + except ValueError as e: + logger.log_exception("Caught exception", e, raise_exception=False) + + # Test rank zero logging + with patch("QEfficient.finetune.experimental.core.logger.get_local_rank") as mock_rank: + mock_rank.return_value = 0 + logger.log_rank_zero("Rank zero test") + + # Verify all messages were logged + with caplog.at_level(logging.DEBUG): + assert "Debug test" in 
caplog.text
+            assert "Info test" in caplog.text
+            assert "Warning test" in caplog.text
+            assert "Error test" in caplog.text
+            assert "Critical test" in caplog.text
+            assert "Caught exception: Test exception" in caplog.text
+            assert "Rank zero test" in caplog.text
+
+        # Check file was written to
+        assert log_file.exists()
+        content = log_file.read_text()
+        assert "Debug test" in content
+        assert "Info test" in content
+        assert "Warning test" in content
+        assert "Error test" in content
+        assert "Critical test" in content
+        assert "Caught exception: Test exception" in content
+        assert "Rank zero test" in content
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/QEfficient/finetune/experimental/tests/test_optimizer.py b/QEfficient/finetune/experimental/tests/test_optimizer.py
new file mode 100644
index 000000000..e105d5ddf
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_optimizer.py
@@ -0,0 +1,96 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import copy
+
+import pytest
+import torch.nn as nn
+import torch.optim as optim
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.optimizer import prepare_optimizer
+
+OPTIMIZER_CONFIGS = {
+    "Adam": {
+        "optimizer_name": "Adam",
+        "opt_cls": optim.Adam,
+        "lr": 1e-4,
+        "weight_decay": 0.01,
+        "betas": (0.9, 0.999),
+        "eps": 1e-8,
+        "amsgrad": False,
+    },
+    "AdamW": {
+        "optimizer_name": "AdamW",
+        "opt_cls": optim.AdamW,
+        "lr": 1e-4,
+        "weight_decay": 0.01,
+        "betas": (0.9, 0.999),
+        "eps": 1e-8,
+        "amsgrad": False,
+    },
+    "SGD": {
+        "optimizer_name": "SGD",
+        "opt_cls": optim.SGD,
+        "lr": 1e-4,
+        "momentum": 0.9,
+        "weight_decay": 0.01,
+        "dampening": 0.0,
+        "nesterov": False,
+    },
+    "RMSprop": {
+        "optimizer_name": "RMSprop",
+        "opt_cls": optim.RMSprop,
+    },
+}
+
+REGISTRY_CONFIG = {
+    "RMSprop": {
+        "optimizer_name": "RMSprop",
+        "opt_cls": optim.RMSprop,
+    },
+}
+
+
+@pytest.fixture
+def dummy_model():
+    return nn.Sequential(
+        nn.Linear(10, 5),
+        nn.ReLU(),
+        nn.Linear(5, 1),
+    )
+
+
+@pytest.mark.parametrize("opt_name", OPTIMIZER_CONFIGS.keys())
+def test_optimizers(opt_name, dummy_model):
+    """Test that all registered optimizers can be created with their configs."""
+    config = copy.deepcopy(OPTIMIZER_CONFIGS[opt_name])
+
+    config.pop("opt_cls")
+    try:
+        optimizer_class_and_kwargs = prepare_optimizer(config)
+        assert optimizer_class_and_kwargs is not None
+    except ValueError as e:
+        assert "Unknown optimizer" in str(e)
+        return
+    optimizer_class = optimizer_class_and_kwargs[0]
+    opt_inst = optimizer_class(dummy_model.parameters(), **optimizer_class_and_kwargs[1])
+    assert isinstance(opt_inst, optim.Optimizer)
+    assert len(list(opt_inst.param_groups)) == 1
+
+    for key in ["lr", "weight_decay", "betas", "eps", "momentum", "dampening", "nesterov", "amsgrad"]:
+        if key in config:
+            assert opt_inst.param_groups[0][key] == config[key], f"{key} mismatch"
+
+
+@pytest.mark.parametrize("opt_name, opt_cls", REGISTRY_CONFIG.items())
+def test_registered_optimizer(opt_name, opt_cls):
+    """Test that the optimizer is registered correctly."""
+    registry.optimizer(opt_name)(opt_cls)
+    optimizer_class = registry.get_optimizer(opt_name)
+    assert optimizer_class is not None
+    assert optimizer_class == opt_cls
diff --git 
a/QEfficient/finetune/experimental/tests/test_registry.py b/QEfficient/finetune/experimental/tests/test_registry.py new file mode 100644 index 000000000..3e10aa820 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_registry.py @@ -0,0 +1,167 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import pytest + +from QEfficient.finetune.experimental.core.component_registry import ComponentRegistry, get_object, registry + + +class TestComponentRegistry: + @pytest.fixture(autouse=True) + def setUp(self): + """Set up test fixtures before each test method.""" + self.registry = ComponentRegistry() + + @pytest.mark.parametrize( + "register_method, get_method, object_name", + [ + ("trainer_module", "get_trainer_module", "test_trainer"), + ("optimizer", "get_optimizer", "test_optimizer"), + ("scheduler", "get_scheduler", "test_scheduler"), + ("dataset", "get_dataset", "test_dataset"), + ("model", "get_model", "test_model"), + ("data_collator", "get_data_collator", "test_collator"), + ("loss_function", "get_loss_function", "test_loss"), + ("callback", "get_callback", "test_callback"), + ], + ) + def test_object_success(self, register_method: str, get_method: str, object_name: str): + """Test object registration decorator.""" + + class MockObject: + pass + + # Register with decorator + getattr(self.registry, register_method)(object_name)(MockObject) + + # Verify registration + retrieved = getattr(self.registry, get_method)(object_name) + if register_method == "trainer_module": + retrieved = retrieved["trainer_cls"] + assert retrieved == MockObject + + @pytest.mark.parametrize( + "object_type, get_method", + [ + ("trainer module", "get_trainer_module"), + ("optimizer", "get_optimizer"), + ("scheduler", "get_scheduler"), + ("dataset", "get_dataset"), + ("model", "get_model"), + ("data collator", "get_data_collator"), + ("loss function", "get_loss_function"), + ("callback", "get_callback"), + ], + ) + def test_object_failure(self, object_type: str, get_method: str, object_name: str = "non_existent"): + """Test failure when retrieving non-existent object.""" + with pytest.raises(ValueError) as exc_info: + getattr(self.registry, get_method)(object_name) + + assert f"Unknown {object_type}" in str(exc_info.value) + + def test_init_empty_registries(self): + """Test that all registries are initialized as empty dictionaries.""" + assert len(self.registry._optimizers) == 0 + assert len(self.registry._schedulers) == 0 + assert len(self.registry._datasets) == 0 + assert len(self.registry._models) == 0 + assert len(self.registry._data_collators) == 0 + assert len(self.registry._metrics) == 0 + assert len(self.registry._loss_functions) == 0 + assert len(self.registry._callbacks) == 0 + assert len(self.registry._hooks) == 0 + assert len(self.registry._trainer_modules) == 0 + + def test_trainer_module_with_args_and_kwargs(self): + """Test trainer module registration with args class and required kwargs.""" + + class MockArgs: + pass + + class MockTrainer: + pass + + # Register with decorator including args class and required kwargs + self.registry.trainer_module( + "test_trainer_with_args", args_cls=MockArgs, required_kwargs={"param1": "default1", "param2": "default2"} + )(MockTrainer) + + # Verify registration details + module_info = 
self.registry.get_trainer_module("test_trainer_with_args") + assert module_info["trainer_cls"] == MockTrainer + assert module_info["args_cls"] == MockArgs + assert module_info["required_kwargs"] == {"param1": "default1", "param2": "default2"} + + def test_list_methods(self): + """Test all list methods return correct keys.""" + + # Register some dummy items + class DummyClass: + pass + + self.registry.optimizer("opt1")(DummyClass) + self.registry.scheduler("sched1")(DummyClass) + self.registry.dataset("ds1")(DummyClass) + self.registry.model("model1")(DummyClass) + self.registry.data_collator("coll1")(lambda x: x) + self.registry.loss_function("loss1")(DummyClass) + self.registry.callback("cb1")(DummyClass) + self.registry.trainer_module("tm1")(DummyClass) + + # Test lists + assert self.registry.list_optimizers() == ["opt1"] + assert self.registry.list_schedulers() == ["sched1"] + assert self.registry.list_datasets() == ["ds1"] + assert self.registry.list_models() == ["model1"] + assert self.registry.list_data_collators() == ["coll1"] + assert self.registry.list_loss_functions() == ["loss1"] + assert self.registry.list_callbacks() == ["cb1"] + assert self.registry.list_trainer_modules() == ["tm1"] + + def test_logging_on_registration(self, mocker): + """Test that registration logs messages.""" + mock_logger = mocker.patch("QEfficient.finetune.experimental.core.component_registry.logger") + + class MockClass: + pass + + # Test optimizer registration logging + self.registry.optimizer("test_opt")(MockClass) + mock_logger.info.assert_called_with("Registered optimizer: test_opt") + + # Reset mock + mock_logger.reset_mock() + + # Test trainer module registration logging + self.registry.trainer_module("test_tm")(MockClass) + mock_logger.info.assert_called_with("Registered trainer module: test_tm") + + +class TestGetObjectFunction: + def test_get_object_success(self): + """Test get_object function success case.""" + test_dict = {"key1": "value1", "key2": "value2"} + + result = get_object(test_dict, "key1", "test_type", lambda: ["key1", "key2"]) + assert result == "value1" + + def test_get_object_failure(self): + """Test get_object function failure case.""" + test_dict = {"key1": "value1"} + + with pytest.raises(ValueError) as exc_info: + get_object(test_dict, "nonexistent", "test_type", lambda: ["key1", "key2"]) + + assert "Unknown test_type: nonexistent" in str(exc_info.value) + assert "Available: ['key1', 'key2']" in str(exc_info.value) + + +class TestGlobalRegistry: + def test_global_registry_instance(self): + """Test that global registry instance exists and is of correct type.""" + assert isinstance(registry, ComponentRegistry) diff --git a/QEfficient/finetune/experimental/tests/test_trainer.py b/QEfficient/finetune/experimental/tests/test_trainer.py new file mode 100644 index 000000000..20af61e36 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_trainer.py @@ -0,0 +1,493 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import shutil + +import pytest +import torch +from datasets import Dataset +from peft import LoraConfig, PeftModel +from transformers import Trainer, TrainingArguments +from trl import SFTConfig, SFTTrainer + +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry +from QEfficient.finetune.experimental.core.model import HFModel # noqa: F401 - needed for registration +from QEfficient.finetune.experimental.core.trainer.base_trainer import BaseTrainer +from QEfficient.finetune.experimental.core.trainer.sft_trainer import ( + SFTTrainerModule, +) + +LORA_R = 8 +LORA_ALPHA = 16 +LORA_DROPOUT = 0.1 +MAX_LENGTH = 128 + + +class TestBaseTrainer: + """Test suite for BaseTrainer class.""" + + def test_base_trainer_registered(self): + """Test that BaseTrainer is registered in the registry.""" + trainer_list = registry.list_trainer_modules() + assert "base" in trainer_list + + def test_base_trainer_info_structure(self): + """Test that BaseTrainer registration has correct structure.""" + trainer_info = registry.get_trainer_module("base") + + assert isinstance(trainer_info, dict) + assert "trainer_cls" in trainer_info + assert "args_cls" in trainer_info + assert "required_kwargs" in trainer_info + + def test_base_trainer_class(self): + """Test that BaseTrainer class is correct.""" + + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # The decorator returns the dict, but BaseTrainer is the original class + assert trainer_cls.__name__ == "BaseTrainer" + assert issubclass(trainer_cls, Trainer) + assert trainer_info["args_cls"] == TrainingArguments + + def test_base_trainer_required_kwargs(self): + """Test that BaseTrainer has peft_config in required_kwargs.""" + trainer_info = registry.get_trainer_module("base") + + assert "peft_config" in trainer_info["required_kwargs"] + assert callable(trainer_info["required_kwargs"]["peft_config"]) + + +class TestSFTTrainerModule: + """Test suite for SFTTrainerModule class.""" + + def test_sft_trainer_registered(self): + """Test that SFTTrainerModule is registered in the registry.""" + trainer_list = registry.list_trainer_modules() + assert "sft" in trainer_list + + def test_sft_trainer_info_structure(self): + """Test that SFTTrainerModule registration has correct structure.""" + trainer_info = registry.get_trainer_module("sft") + + assert isinstance(trainer_info, dict) + assert "trainer_cls" in trainer_info + assert "args_cls" in trainer_info + assert "required_kwargs" in trainer_info + + def test_sft_trainer_class(self): + """Test that SFTTrainerModule class is correct.""" + + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + assert trainer_cls == SFTTrainerModule["trainer_cls"] + assert issubclass(trainer_cls, SFTTrainer) + assert trainer_info["args_cls"] == SFTConfig + + def test_sft_trainer_required_kwargs(self): + """Test that SFTTrainerModule has peft_config in required_kwargs.""" + trainer_info = registry.get_trainer_module("sft") + + assert "peft_config" in trainer_info["required_kwargs"] + assert callable(trainer_info["required_kwargs"]["peft_config"]) + + +class TestTrainerRegistry: + """Test suite for trainer registration in the component registry.""" + + def test_both_trainers_registered(self): + """Test that both base and sft trainers are registered.""" + trainer_list = 
registry.list_trainer_modules() + + assert "base" in trainer_list + assert "sft" in trainer_list + assert len(trainer_list) >= 2 + + def test_registry_returns_dict(self): + """Test that registry returns dict for trainer modules.""" + base_info = registry.get_trainer_module("base") + sft_info = registry.get_trainer_module("sft") + + assert isinstance(base_info, dict) + assert isinstance(sft_info, dict) + + def test_trainer_classes_correct(self): + """Test that trainer classes are correctly stored.""" + base_info = registry.get_trainer_module("base") + sft_info = registry.get_trainer_module("sft") + assert base_info["trainer_cls"] == BaseTrainer["trainer_cls"] + assert sft_info["trainer_cls"] == SFTTrainerModule["trainer_cls"] + + +class TestBaseTrainerWithModel: + """Test suite for BaseTrainer integration with model loading and PEFT.""" + + @pytest.fixture(autouse=True) + def cleanup_output_dirs(self): + """Fixture to clean up test output directories after each test.""" + # Setup: yield control to the test + yield + + # Teardown: clean up output directories + output_dirs = ["./test_output", "./test_output_peft", "./test_output_base", "./test_output_base_peft"] + for output_dir in output_dirs: + if os.path.exists(output_dir): + try: + shutil.rmtree(output_dir) + print(f"\nCleaned up: {output_dir}") + except Exception as e: + print(f"\nWarning: Failed to clean up {output_dir}: {e}") + + @pytest.fixture + def model_config(self): + """Fixture for basic model configuration.""" + return { + "model_name": "HuggingFaceTB/SmolLM-135M", + "auto_class_name": "AutoModelForCausalLM", + "use_cache": False, + "torch_dtype": "float16", + "attn_implementation": "eager", + "device_map": None, + "num_hidden_layers": 1, + } + + @pytest.fixture + def peft_model_config(self): + """Fixture for PEFT configuration.""" + return { + "r": LORA_R, + "lora_alpha": LORA_ALPHA, + "lora_dropout": LORA_DROPOUT, + "target_modules": ["q_proj", "v_proj"], + "bias": "none", + } + + @pytest.fixture + def dummy_dataset(self): + """Fixture for creating a dummy dataset.""" + data = { + "text": [ + "This is a test sentence for training.", + "Another example text for the model.", + "Third sample to ensure proper batching.", + ] + } + return Dataset.from_dict(data) + + def test_base_trainer_instantiation_with_model(self, model_config, dummy_dataset): + """Test that BaseTrainer can be instantiated with a loaded model.""" + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create training config + training_args = TrainingArguments( + output_dir="./test_output_base", + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get BaseTrainer from registry + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer without PEFT + trainer = trainer_cls( + model=model, + args=training_args, + train_dataset=dummy_dataset, + processing_class=tokenizer, + ) + + assert trainer is not None + assert trainer.model is not None + assert trainer.processing_class is not None + + def test_base_trainer_with_peft_model(self, model_config, peft_model_config, dummy_dataset): + """Test that BaseTrainer works with PEFT-enabled models.""" + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", 
model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Load PEFT Config + peft_config = LoraConfig(**peft_model_config) + + # Create training config + training_args = TrainingArguments( + output_dir="./test_output_base_peft", + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get BaseTrainer from registry + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer with PEFT config + trainer = trainer_cls( + model=model, + args=training_args, + train_dataset=dummy_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + + assert trainer is not None + assert trainer.model is not None + + # Verify that the model is now a PEFT model + assert isinstance(trainer.model, PeftModel), "Model should be wrapped as a PeftModel" + + # Verify that the model has the expected PEFT config + assert hasattr(trainer.model, "peft_config"), "Model should have peft_config attribute" + assert trainer.model.peft_config is not None, "PEFT config should not be None" + + # Verify trainable parameters are reduced (PEFT should make only a subset trainable) + trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in trainer.model.parameters()) + + assert trainable_params < total_params, "PEFT should reduce the number of trainable parameters" + print(f"\nTrainable params: {trainable_params:,} / Total params: {total_params:,}") + + def test_base_trainer_without_peft_config(self, model_config, dummy_dataset): + """Test that BaseTrainer works without PEFT config (standard training).""" + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create training config + training_args = TrainingArguments( + output_dir="./test_output_base", + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get BaseTrainer from registry + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer without PEFT config + trainer = trainer_cls( + model=model, + args=training_args, + train_dataset=dummy_dataset, + processing_class=tokenizer, + peft_config=None, # Explicitly pass None + ) + + assert trainer is not None + assert trainer.model is not None + + # Verify that the model is NOT a PEFT model + assert not isinstance(trainer.model, PeftModel), ( + "Model should not be wrapped as a PeftModel when peft_config is None" + ) + + +class TestSFTTrainerWithModel: + """Test suite for SFTTrainer integration with model loading.""" + + @pytest.fixture(autouse=True) + def cleanup_output_dirs(self): + """Fixture to clean up test output directories after each test.""" + # Setup: yield control to the test + yield + + # Teardown: clean up output directories + output_dirs = ["./test_output", "./test_output_peft"] + for output_dir in output_dirs: + if os.path.exists(output_dir): + try: + shutil.rmtree(output_dir) + print(f"\nCleaned up: {output_dir}") + except Exception as e: + print(f"\nWarning: Failed to clean up {output_dir}: {e}") + + @pytest.fixture + def model_config(self): + """Fixture for basic model configuration.""" + return { + "model_name": "HuggingFaceTB/SmolLM-135M", + "auto_class_name": 
"AutoModelForCausalLM", + "use_cache": False, + "torch_dtype": "float16", + "attn_implementation": "eager", + "device_map": None, + "num_hidden_layers": 1, + } + + @pytest.fixture + def peft_model_config(self): + """Fixture for PEFT configuration.""" + return { + "lora_r": LORA_R, + "lora_alpha": LORA_ALPHA, + "lora_dropout": LORA_DROPOUT, + "target_modules": ["q_proj", "v_proj"], + "bias": "none", + } + + @pytest.fixture + def dummy_dataset(self): + """Fixture for creating a dummy dataset.""" + + data = { + "text": [ + "This is a test sentence for training.", + "Another example text for the model.", + "Third sample to ensure proper batching.", + ] + } + return Dataset.from_dict(data) + + def test_model_forward_pass(self, model_config): + """Test that the loaded model can perform a forward pass.""" + + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + loaded_model = hf_model.model + tokenizer = hf_model.tokenizer + + # Prepare input + text = "This is a test." + inputs = tokenizer(text, return_tensors="pt") + + # Perform forward pass + with torch.no_grad(): + outputs = loaded_model(**inputs) + + assert outputs is not None + assert hasattr(outputs, "logits") + assert outputs.logits.shape[0] == 1 # batch size + + def test_sft_trainer_instantiation_with_model(self, model_config, dummy_dataset): + """Test that SFTTrainer can be instantiated with a loaded model.""" + + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create SFT config + sft_config = SFTConfig( + output_dir="./test_output", + max_length=MAX_LENGTH, + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get SFTTrainer from registry + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer + trainer = trainer_cls( + model=model, + args=sft_config, + train_dataset=dummy_dataset, + processing_class=tokenizer, + ) + + assert trainer is not None + assert trainer.model is not None + assert trainer.tokenizer is not None + + def test_sft_trainer_with_peft_model(self, model_config, peft_model_config, dummy_dataset): + """Test that SFTTrainer works with PEFT-enabled models.""" + + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + # Load PEFT Config + peft_config = LoraConfig(peft_model_config) + tokenizer = hf_model.tokenizer + + # Create SFT config + sft_config = SFTConfig( + output_dir="./test_output_peft", + max_length=MAX_LENGTH, + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get SFTTrainer from registry + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer with PEFT config + trainer = trainer_cls( + model=model, + args=sft_config, + train_dataset=dummy_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + + assert trainer is not None + assert trainer.model is not None + + def test_sft_trainer_train_dataset_required(self, model_config): + """Test that SFTTrainer requires a training dataset.""" + + # Load model and tokenizer + model_name = model_config.pop("model_name") + 
hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create SFT config + sft_config = SFTConfig( + output_dir="./test_output", + max_length=MAX_LENGTH, + per_device_train_batch_size=1, + num_train_epochs=1, + bf16=False, + fp16=True, + ) + + # Get SFTTrainer from registry + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + # Attempt to instantiate without dataset should raise TypeError + with pytest.raises(TypeError, match="'NoneType' object is not iterable"): + trainer_cls( + model=model, + args=sft_config, + processing_class=tokenizer, + ) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index e9e1320de..45b995124 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -123,11 +123,20 @@ def train( break if train_config.use_peft and train_config.from_peft_checkpoint: - intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1 - intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1]) + path = train_config.from_peft_checkpoint.rstrip("/") + try: + intermediate_epoch = int(path.split("/")[-2].split("_")[-1]) - 1 + intermediate_step = int(path.split("/")[-1].split("_")[-1]) + except (IndexError, ValueError): + intermediate_epoch = int(path.split("/")[-1].split("_")[-1]) - 1 + intermediate_step = 0 + if epoch < intermediate_epoch: logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.") continue + if intermediate_step == 0 and epoch == intermediate_epoch: + logger.log_rank_zero(f"Skipping epoch {epoch + 1}, since fine tuning has already completed for it.") + continue logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}") if max_steps_reached: @@ -154,6 +163,7 @@ def train( # resume training from a particular checkpoint, assuming the dataset is not shuffled if train_config.use_peft and train_config.from_peft_checkpoint: # to bring the count of train_step in sync with where it left off + if epoch == intermediate_epoch and step == 0: logger.log_rank_zero( f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it." 
@@ -365,7 +375,7 @@ def train(
             eval_step_metric,
             eval_metric,
         )
-    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_epoch_time = sum(epoch_times) / len(epoch_times) if len(epoch_times) > 0 else 0
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
 
     results["last_epoch_train_loss"] = train_epoch_loss.cpu()
diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py
index 8519d824c..652a641e2 100644
--- a/QEfficient/generation/cloud_infer.py
+++ b/QEfficient/generation/cloud_infer.py
@@ -5,6 +5,8 @@
 #
 # -----------------------------------------------------------------------------
 
+import platform
+import sys
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from warnings import warn
@@ -13,32 +15,29 @@
 
 try:
     import qaicrt
+
+    is_qaicrt_imported = True
 except ImportError:
-    import platform
-    import sys
+    try:
+        sys.path.append(f"/opt/qti-aic/dev/lib/{platform.machine()}")
+        import qaicrt
 
-    sys.path.append(f"/opt/qti-aic/dev/lib/{platform.machine()}")
-    import qaicrt
+        is_qaicrt_imported = True
+    except ImportError:
+        is_qaicrt_imported = False
 
 try:
     import QAicApi_pb2 as aicapi
-except ImportError:
-    import sys
 
-    sys.path.append("/opt/qti-aic/dev/python")
-    import QAicApi_pb2 as aicapi
+    is_aicapi_imported = True
+except ImportError:
+    try:
+        sys.path.append("/opt/qti-aic/dev/python")
+        import QAicApi_pb2 as aicapi
 
-aic_to_np_dtype_mapping = {
-    aicapi.FLOAT_TYPE: np.dtype(np.float32),
-    aicapi.FLOAT_16_TYPE: np.dtype(np.float16),
-    aicapi.INT8_Q_TYPE: np.dtype(np.int8),
-    aicapi.UINT8_Q_TYPE: np.dtype(np.uint8),
-    aicapi.INT16_Q_TYPE: np.dtype(np.int16),
-    aicapi.INT32_Q_TYPE: np.dtype(np.int32),
-    aicapi.INT32_I_TYPE: np.dtype(np.int32),
-    aicapi.INT64_I_TYPE: np.dtype(np.int64),
-    aicapi.INT8_TYPE: np.dtype(np.int8),
-}
+        is_aicapi_imported = True
+    except ImportError:
+        is_aicapi_imported = False
 
 
 class QAICInferenceSession:
@@ -58,6 +57,25 @@ def __init__(
         :activate: bool. If false, activation will be disabled. Default=True.
         :enable_debug_logs: bool. If True, It will enable debug logs. Default=False.
         """
+        if not (is_qaicrt_imported and is_aicapi_imported):
+            raise ImportError(
+                "Unable to import `qaicrt` and/or `QAicApi_pb2` libraries required for executing QPC files on the CLOUD AI platform.\n"
+                "Please ensure that the QAIC platform SDK and apps SDK are installed correctly."
+ ) + + # Build dtype mapping once (depends on aicapi constants) + self.aic_to_np_dtype_mapping = { + aicapi.FLOAT_TYPE: np.dtype(np.float32), + aicapi.FLOAT_16_TYPE: np.dtype(np.float16), + aicapi.INT8_Q_TYPE: np.dtype(np.int8), + aicapi.UINT8_Q_TYPE: np.dtype(np.uint8), + aicapi.INT16_Q_TYPE: np.dtype(np.int16), + aicapi.INT32_Q_TYPE: np.dtype(np.int32), + aicapi.INT32_I_TYPE: np.dtype(np.int32), + aicapi.INT64_I_TYPE: np.dtype(np.int64), + aicapi.INT8_TYPE: np.dtype(np.int8), + } + # Load QPC if device_ids is not None: devices = qaicrt.QIDList(device_ids) @@ -77,7 +95,7 @@ def __init__( raise RuntimeError("Failed to getIoDescriptor") iodesc.ParseFromString(bytes(iodesc_data)) self.allowed_shapes = [ - [(aic_to_np_dtype_mapping[x.type].itemsize, list(x.dims)) for x in allowed_shape.shapes] + [(self.aic_to_np_dtype_mapping[x.type].itemsize, list(x.dims)) for x in allowed_shape.shapes] for allowed_shape in iodesc.allowed_shapes ] self.bindings = iodesc.selected_set.bindings @@ -90,12 +108,14 @@ def __init__( self.program = qaicrt.Program(self.context, None, qpc, prog_properties) if self.program.load() != qaicrt.QStatus.QS_SUCCESS: raise RuntimeError("Failed to load program") + self.is_active = False if activate: self.activate() + self.is_active = True # Create input qbuffers and buf_dims self.qbuffers = [qaicrt.QBuffer(bytes(binding.size)) for binding in self.bindings] self.buf_dims = qaicrt.BufferDimensionsVecRef( - [(aic_to_np_dtype_mapping[binding.type].itemsize, list(binding.dims)) for binding in self.bindings] + [(self.aic_to_np_dtype_mapping[binding.type].itemsize, list(binding.dims)) for binding in self.bindings] ) @property @@ -108,15 +128,17 @@ def output_names(self) -> List[str]: def activate(self): """Activate qpc""" - - self.program.activate() - self.execObj = qaicrt.ExecObj(self.context, self.program) + if not self.is_active: + self.program.activate() + self.execObj = qaicrt.ExecObj(self.context, self.program) + self.is_active = True def deactivate(self): """Deactivate qpc""" - - del self.execObj - self.program.deactivate() + if self.is_active: + del self.execObj + self.program.deactivate() + self.is_active = False def set_buffers(self, buffers: Dict[str, np.ndarray]): """ @@ -201,6 +223,6 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: continue outputs[output_name] = np.frombuffer( bytes(output_qbuffers[buffer_index]), - aic_to_np_dtype_mapping[self.bindings[buffer_index].type], + self.aic_to_np_dtype_mapping[self.bindings[buffer_index].type], ).reshape(self.buf_dims[buffer_index][1]) return outputs diff --git a/QEfficient/generation/embedding_handler.py b/QEfficient/generation/embedding_handler.py new file mode 100644 index 000000000..e07b5dd04 --- /dev/null +++ b/QEfficient/generation/embedding_handler.py @@ -0,0 +1,519 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Vision Handler for Vision-Language Models + +This module provides the VisionHandler class that encapsulates all vision model +operations, separating them from the main text generation logic. 
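+
+Illustrative usage sketch (the session, processor and tokenizer objects below are
+assumed to be created by the caller; only the VisionHandler API shown here comes
+from this module):
+
+    handler = VisionHandler(
+        qeff_model=None,
+        vision_session=vision_session,  # QAICInferenceSession for the vision QPC
+        processor=processor,            # AutoImageProcessor for the model
+        tokenizer=tokenizer,            # AutoTokenizer for the model
+        image_height=448,               # optional resize target (example value)
+        image_width=448,
+        lang_session=lang_session,      # optional, coordinates with the language session
+    )
+    if handler.is_available():
+        vision_inputs, lang_inputs = handler.prepare_internVL_inputs(img_url, prompt)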
+""" + +from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import requests +import torch +from PIL import Image +from transformers import AutoImageProcessor, AutoTokenizer + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants +from QEfficient.utils.logging_utils import logger + + +class VisionHandler: + """ + Handles all vision model operations for vision-language models. + + This class encapsulates vision preprocessing, inference, and output handling, + providing a clean separation between vision and language processing. + """ + + def __init__( + self, + qeff_model: Optional[QAICInferenceSession], + vision_session: Optional[QAICInferenceSession], + processor: Optional[AutoImageProcessor], + tokenizer: Optional[AutoTokenizer], + image_height: Optional[int] = None, + image_width: Optional[int] = None, + config: Optional[Dict[str, Any]] = None, + lang_session: Optional[QAICInferenceSession] = None, + ): + """ + Initialize vision handler + + Args: + vision_session: QAICInferenceSession for vision model + processor: AutoImageProcessor for image preprocessing + tokenizer: AutoTokenizer for text tokenization + image_height: Desired image height for resizing + image_width: Desired image width for resizing + config: Configuration dictionary with vision model parameters + lang_session: Optional language session for coordination (to avoid resource conflicts) + """ + self._qeff_model = qeff_model + self._vision_session = vision_session + self._processor = processor + self._tokenizer = tokenizer + self._image_height = image_height + self._image_width = image_width + self._config = config or {} + self._lang_session = lang_session # Store language session for coordination + + # Cache for vision output shapes + self._vision_output_shapes = None + + if self._vision_session and not self._processor: + logger.warning("Vision session provided but no processor. Vision functionality may be limited.") + + def is_available(self) -> bool: + """ + Check if vision processing is available + + Returns: + True if both vision session and processor are available + """ + return self._vision_session is not None and self._processor is not None + + def prepare_internVL_inputs(self, img_url: str, prompt: str) -> Dict[str, np.ndarray]: + """ + Prepare inputs for InternVL model + + Args: + image_url: URL or path to image + prompt: Text query to process with image + """ + if not self._tokenizer: + raise ValueError("Tokenizer is required for InternVL input preparation") + pixel_values = [] + num_patches_list = [] + questions = [] + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + if self._image_height and self._image_width: + image = image.resize((self._image_height, self._image_width)) + else: + logger.warning("Height and Width not specified. 
Using default image size for num_patches = 13.") + image = image.resize((constants.INTERN_IMAGE_HEIGHT, constants.INTERN_IMAGE_WIDTH)) + + # preprocess the resized image + pixel_value = self._processor.load_image(image, max_num=12) + num_patches_list.append(pixel_value.shape[0]) + pixel_values.append(pixel_value) + + question = "\n" + prompt + questions.append(question) + + pixel_values = torch.cat(pixel_values, dim=0) + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list) + + inputs = self._tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_values.clone() + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + + def prepare_molmo_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]: + """ + Download and preprocess image into model inputs + Args: + image_url: URL or path to image + query: Text query to process with image + Returns: + Dictionary of vision model inputs + Raises: + ValueError: If vision handler is not properly initialized + RuntimeError: If image processing fails + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized. 
Need both vision_session and processor.") + + try: + # Download image + if image_url.startswith(("http://", "https://")): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + image = image.resize((constants.MOLMO_IMAGE_HEIGHT, constants.MOLMO_IMAGE_WIDTH)) + inputs = self._processor.process(images=[image], text=query) + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + valid = inputs["image_input_idx"] > 0 + valid = valid.reshape(1, -1) + inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0) + inputs["pixel_values"] = inputs.pop("images") + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + except Exception as e: + raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") + + def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]: + """ + Download and preprocess image into model inputs + + Args: + image_url: URL or path to image + query: Text query to process with image + prefill_seq_len: Padded sequence length for language model + + Returns: + Dictionary of vision model inputs + + Raises: + ValueError: If vision handler is not properly initialized + RuntimeError: If image processing fails + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.") + + try: + # Download image + if image_url.startswith(("http://", "https://")): + image = Image.open(requests.get(image_url, stream=True).raw) + else: + image = Image.open(image_url) + + if self._image_height and self._image_width: + image = image.resize((self._image_width, self._image_height)) + else: + logger.warning("Height and Width not specified. 
Using default image size.") + if "mistral3" in self._qeff_model.model.config.model_type: + image = image.resize((constants.MISTRAL3_IMAGE_HEIGHT, constants.MISTRAL3_IMAGE_WIDTH)) + if "llava_next" in self._qeff_model.model.config.model_type: + image = image.resize( + (constants.GRANITEVISION_IMG_SIZE_HEIGHT, constants.GRANITEVISION_IMG_SIZE_WIDTH) + ) + + # Prepare conversation format + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process image and text + inputs = self._processor(images=image, text=prompt, return_tensors="pt") + + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "qwen2_5_vl" + ): + inputs = self._qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0] + ) + + # Convert to float32 if needed + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Convert to numpy arrays + vision_inputs = {} + for k, v in inputs.items(): + if k in { + "pixel_values", + "image_masks", + "image_input_idx", + "valid_idx", + "aspect_ratio_ids", + "aspect_ratio_mask", + }: + vision_inputs[k] = np.array(v) + + # Convert specific inputs to float16 + vision_inputs_fp16 = {"pixel_values", "image_masks"} + for k in vision_inputs_fp16: + if k in vision_inputs: + vision_inputs[k] = vision_inputs[k].astype("float16") + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + + return vision_inputs, lang_inputs + + except Exception as e: + raise RuntimeError(f"Failed to process image {image_url}: {str(e)}") + + def run_vision_inference(self, vision_inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Execute vision model inference with session coordination + + Args: + vision_inputs: Preprocessed vision inputs + + Returns: + Vision embeddings and metadata + + Raises: + ValueError: If vision session is not available + RuntimeError: If inference fails + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + lang_was_active = False + try: + # Coordinate with language session to avoid resource conflicts + if self._lang_session and self._lang_session.is_active: + logger.debug("Deactivating language session before vision inference") + self._lang_session.deactivate() + lang_was_active = True + + # Activate vision session + logger.debug("Activating vision session for inference") + self._vision_session.activate() + + # Run inference + vision_outputs = self._vision_session.run(vision_inputs) + + # Deactivate vision session + logger.debug("Deactivating vision session after inference") + self._vision_session.deactivate() + + # Reactivate language session if it was active before + if lang_was_active and self._lang_session: + logger.debug("Reactivating language session after vision inference") + self._lang_session.activate() + + return vision_outputs + + except Exception as e: + # Ensure proper cleanup on error + if self._vision_session: + try: + self._vision_session.deactivate() + except Exception: + logger.warning("Deactivating vision session failed") + + # Restore language session if needed + if lang_was_active and self._lang_session: + try: + self._lang_session.activate() + except Exception: + logger.warning("Deactivating language session failed") + + raise 
RuntimeError(f"Vision inference failed: {str(e)}") + + def get_vision_output_shapes(self) -> Dict[str, Tuple[int, ...]]: + """ + Get vision output dimensions from config or session + + Returns: + Dictionary mapping output names to shapes + """ + if self._vision_output_shapes is not None: + return self._vision_output_shapes + + # Try to get from config first + if self._config and "vision_output_shapes" in self._config: + self._vision_output_shapes = self._config["vision_output_shapes"] + return self._vision_output_shapes + + # Try to derive from vision session + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + self._vision_output_shapes = shapes + return shapes + except Exception as e: + logger.warning(f"Could not derive vision output shapes from session: {e}") + + # Fallback to default shapes (these were hard-coded in original implementation) + default_shapes = { + "vision_embeds": (2448, 5120) # This should be derived from model config + } + + logger.warning("Using default vision output shapes. Consider providing shapes in config.") + self._vision_output_shapes = default_shapes + return default_shapes + + def setup_vision_buffers(self): + """ + Configure vision model output buffers + + Raises: + ValueError: If vision session is not available + """ + if not self._vision_session: + raise ValueError("Vision session not available") + + try: + shapes = self.get_vision_output_shapes() + + # Set up output buffers + buffers = {} + for output_name, shape in shapes.items(): + # Create placeholder with appropriate dtype + if "vision_embeds" in output_name: + buffers[output_name] = np.zeros(shape, dtype=np.float16) + else: + buffers[output_name] = np.zeros(shape, dtype=np.float32) + + self._vision_session.set_buffers(buffers) + + except Exception as e: + raise RuntimeError(f"Failed to setup vision buffers: {str(e)}") + + def prepare_complete_vision_language_inputs( + self, image_url: str, query: str + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Complete pipeline: prepare inputs and run vision inference + + Args: + image_url: URL or path to image + query: Text query + + Returns: + Tuple of (vision_inputs, vision_outputs) + """ + # Prepare vision inputs + vision_inputs = self.prepare_vision_inputs(image_url, query) + + # Setup buffers + self.setup_vision_buffers() + + # Run vision inference + vision_outputs = self.run_vision_inference(vision_inputs) + + return vision_inputs, vision_outputs + + def get_processed_inputs( + self, image_url: str, query: str, prefill_seq_len: int + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """ + Process vision inputs and prepare language model inputs + + Args: + image_url: URL or path to image + query: Text query + padded_len: Padded sequence length for language model + + Returns: + Tuple of (language_inputs, vision_outputs) + """ + if not self.is_available(): + raise ValueError("Vision handler not properly initialized") + + try: + ## Get vlm inputs ## + if ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "internvl_chat" + ): + vision_inputs, lang_inputs = 
self.prepare_internVL_inputs(image_url, query) + elif ( + hasattr(self._qeff_model.model.config, "model_type") + and self._qeff_model.model.config.model_type == "molmo" + ): + vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query) + else: + vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len) + + # Handle padding for language model + pad_token_id = 1 + input_ids_length = lang_inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) + padded_len = num_chunks * prefill_seq_len + + lang_inputs["input_ids"] = torch.nn.functional.pad( + lang_inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + lang_inputs["attention_mask"] = torch.nn.functional.pad( + lang_inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + + if "cross_attention_mask" in lang_inputs: + lang_inputs["cross_attention_mask"] = torch.nn.functional.pad( + lang_inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in lang_inputs.items(): + lang_inputs[k] = np.array(v) + + vision_outputs = {} + if vision_inputs: + self.setup_vision_buffers() + vision_outputs = self.run_vision_inference(vision_inputs) + + if "position_ids" in lang_inputs: + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where(lang_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + lang_inputs["image_idx"] = np.array([[0]]) + + return lang_inputs, vision_outputs, num_chunks + + except Exception as e: + raise RuntimeError(f"Failed to process vision-language inputs: {str(e)}") diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..de10c9b88 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv( prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, generation_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: Optional[str] = None, @@ -327,6 +329,7 @@ def cloud_ai_100_exec_kv( is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ): """ @@ -354,6 +357,8 @@ def cloud_ai_100_exec_kv( next tokens. For Speculative Decoding Target Language Model, `return_pdfs`=True always. Otherwise, `return_pdfs`=True for Speculative Decoding Draft Language Model and `return_pdfs`=False for regular model. + :include_guided_decoding (bool, default=False): If True, enables guided token-level filtering + during decoding. Only works when `include_sampler`=True. sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend. 
The dictionary should contain the following keys: `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`, @@ -384,12 +389,15 @@ def cloud_ai_100_exec_kv( qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) @@ -430,26 +438,38 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: Optional[int] = None, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, + activate: bool = True, ) -> None: self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs + self.include_guided_decoding = include_guided_decoding self.sampling_params = sampling_params + self._qpc_path = qpc_path # Store qpc_path for later use # Load QPC - self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + self._session = QAICInferenceSession( + qpc_path, device_id, activate=activate, enable_debug_logs=enable_debug_logs + ) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._session.input_names), include_sampler=include_sampler + session_inputs=set(self._session.input_names), + include_sampler=include_sampler, + include_guided_decoding=include_guided_decoding, ) # Fetch the variables from the QPC @@ -616,7 +636,7 @@ def prepare_decode_inputs(self): decode_inputs["batch_index"] = self.batch_index if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if self.batch_index is not None: decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()] else: @@ -778,11 +798,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.is_tlm: inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: inputs["last_accepted_output_tokens"] = inputs["input_ids"] - for op in Constants.SAMPLER_OPS: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): if decode_batch_id is not None: inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: @@ -797,7 +818,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for 
length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + for i in range(num_chunks): + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len @@ -808,6 +841,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] outputs = self._session.run(chunk_inputs) + if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) return ( @@ -816,6 +850,21 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len, ) + def initialize_ccl(self, decode_inputs): + self.list_of_comp_ctx_lengths_decode = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode + ] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(decode_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + + return ccl_id, max_ccl_id + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -847,6 +896,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Prepare decode inputs inputs. 
decode_inputs = self.prepare_decode_inputs() + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + while prompt_queue or current_decode_ongoing.any(): outputs = self._session.run(decode_inputs) @@ -884,6 +937,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): batch_id_map[decode_batch_id] ] + if self.comp_ctx_lengths_decode is not None: + ###Recalculate ccl_id based on position ids### + # Determine the maximum value of position_ids across all batch elements + max_position_id = np.max(decode_inputs["position_ids"]) + + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + else: current_decode_ongoing[decode_batch_id] = False else: @@ -896,6 +963,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] + if self.comp_ctx_lengths_decode is not None: + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + if ( + decode_inputs["position_ids"][decode_batch_id, -1] + >= self.comp_ctx_lengths_decode[ccl_id] - 1 + ): + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + generated_id_current_index[decode_batch_id] += 1 return decode_pause_time @@ -922,7 +998,18 @@ def run_decode( self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + + cache_index = np.max(decode_inputs["position_ids"]) for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + if streamer: streamer.put(decode_inputs["input_ids"][0]) outputs = self._session.run(decode_inputs) @@ -934,6 +1021,7 @@ def run_decode( # Prepare inputs for next iteration decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 + cache_index += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if self.include_sampler: @@ -983,12 +1071,15 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: bool = False, include_sampler: bool = False, return_pdfs: bool = False, + include_guided_decoding: bool = False, sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._qaic_model = QEffTextGenerationBase( @@ -996,17 +1087,22 @@ def __init__( qpc_path=qpc_path, 
full_batch_size=full_batch_size, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, device_id=device_id, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, is_tlm=is_tlm, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, sampling_params=sampling_params, ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._perf_metrics = None self._prompt_queue = None self._text_streamer = None diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py new file mode 100644 index 000000000..adacc373e --- /dev/null +++ b/QEfficient/generation/vlm_generation.py @@ -0,0 +1,829 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +This module provides the VisionLanguageGeneration class that inherits from +QEffTextGenerationBase, enabling all advanced text generation features while +maintaining full API compatibility with the original VisionLanguageGeneration. + +Key enhancements: +- Continuous batching support for vision models +- Advanced streaming capabilities +- On-device sampling support +- LoRA adapter support +- Better performance metrics +""" + +from collections import deque +from time import perf_counter +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.embedding_handler import VisionHandler +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfo, + PerfMetrics, + QEffTextGenerationBase, + TextGeneration, + calculate_latency, + write_io_files, +) +from QEfficient.utils import LRUCache +from QEfficient.utils.constants import Constants +from QEfficient.utils.logging_utils import logger + + +class VisionLanguageGeneration(QEffTextGenerationBase): + """ + Enhanced vision-language generation class inheriting from QEffTextGenerationBase. + + This class maintains full API compatibility with VisionLanguageGeneration while + adding advanced features like continuous batching, streaming, and sampling. + + Example: + >>> # Drop-in replacement for VisionLanguageGeneration + >>> vlm = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0] + ... ) + >>> result = vlm.generate( + ... images=["image1.jpg"], + ... prompts=["Describe this image"], + ... generation_len=512 + ... ) + + >>> # Enhanced usage with new features + >>> vlm_enhanced = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0], + ... full_batch_size=8, # Enable continuous batching + ... include_sampler=True, # Enable on-device sampling + ... sampling_params=sampling_config + ... 
) + """ + + def __init__( + self, + qeff_model, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + processor: AutoImageProcessor, + lang_qpc_path: str, + vision_qpc_path: str, + device_id: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, + enable_debug_logs: bool = False, + write_io_dir: Optional[str] = None, + full_batch_size: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, + is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + include_guided_decoding: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize vision-language generation with enhanced capabilities + + Args: + qeff_model: QEff model instance + tokenizer: Text tokenizer + processor: Image processor + lang_qpc_path: Path to language model QPC + vision_qpc_path: Path to vision encoder QPC + device_id: Device IDs for execution (default: [0]) + ctx_len: Context length + enable_debug_logs: Enable debug logging + write_io_dir: Directory for I/O file writing + full_batch_size: Enable continuous batching (new feature) + image_height: Desired image height for resizing + image_width: Desired image width for resizing + is_tlm: Target language model flag + include_sampler: Enable on-device sampling (new feature) + return_pdfs: Return probability distributions + include_guided_decoding: Enable guided decoding in on-device sampling + sampling_params: Sampling parameters for on-device sampling + """ + # Validate required parameters + if not lang_qpc_path: + raise TypeError("lang_qpc_path is required") + if not vision_qpc_path: + raise TypeError("vision_qpc_path is required") + + # Initialize base class with language QPC + # Pass activate=False to prevent premature activation before vision components are ready + super().__init__( + tokenizer=tokenizer, + qpc_path=lang_qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, + sampling_params=sampling_params, + activate=False, # vision components need to be initialized first + ) + + # Vision-specific initialization + self.is_qwen2_5_vl = ( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" + ) + self.qeff_model = qeff_model + self.processor = processor + self.tokenizer = tokenizer + self.image_height = image_height + self.image_width = image_width + self._vision_qpc_path = vision_qpc_path + self.device_id = device_id # Store device_id for vision components + self.enable_debug_logs = enable_debug_logs # Store for vision components + self._vision_outputs_cache = LRUCache(max_size=100) # LRU cache for vision outputs + self._vision_cache = {} # Cache for vision outputs across batches + self._init_vision_components() + + # Now that vision components are initialized, activate the text session + self._session.activate() + + logger.info( + f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " + f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " + f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " + f"sampling={'enabled' if 
include_sampler else 'disabled'}" + ) + + def _init_vision_components(self): + """Initialize vision-specific components""" + # Vision session (separate from base class language session) + self._vision_session = QAICInferenceSession( + self._vision_qpc_path, self.device_id, activate=False, enable_debug_logs=self.enable_debug_logs + ) + + # Vision handler with language session coordination + vision_config = self._get_vision_config() + self._vision_handler = VisionHandler( + qeff_model=self.qeff_model, + vision_session=self._vision_session, + processor=self.processor, + tokenizer=self.tokenizer, + image_height=self.image_height, + image_width=self.image_width, + config=vision_config, + lang_session=self._session, # Pass language session for coordination + ) + + # Setup vision buffer skipping + self._setup_vision_buffer_skipping() + + def _get_vision_config(self) -> Dict[str, Any]: + """ + Derive vision config from session + + Returns: + Dictionary with vision configuration + """ + config = {} + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + config["vision_output_shapes"] = shapes + except Exception as e: + logger.warning(f"Could not derive vision config from session: {e}") + + return config + + def _setup_vision_buffer_skipping(self): + """Skip KV cache and retained state buffers for vision session""" + # Pre-compute skip buffers + self._vision_skip_buffers = [ + x + for x in self._vision_session.input_names + self._vision_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + self._vision_session.skip_buffers(self._vision_skip_buffers) + + # Pre-compute language skip buffers + self._lang_skip_buffers = [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + + def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs in the prompt queue and updates the decode input. + + Method iterates over the full batch size and for each decode batch ID, it pops the next prompt from the queue. It then runs prefill for the next prompt and updates the decode input with the outputs. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + + """ + for decode_batch_id in range(self.full_batch_size): + next_prompt = prompt_queue.popleft() + + # run prefill for num_chunks + outputs, position_ids, generation_len = self.run_prefill( + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + ) + + if self.is_qwen2_5_vl: + _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) + else: + _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. 
If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. + """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. + self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + + def _execute_chunked_prefill( + self, + lang_inputs: Dict[str, np.ndarray], + num_chunks: int, + decode_batch_id: Optional[np.ndarray] = None, + prefill_logit_bs: int = 1, + ) -> Dict[str, np.ndarray]: + """ + Execute chunked prefill with language inputs + + Args: + lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc. + num_chunks: Number of chunks to process + decode_batch_id: Batch ID for continuous batching (optional) + prefill_logit_bs: Batch size for prefill logits + + Returns: + Final prefill outputs + """ + # Set output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) + + # Skip buffers for dual-QPC coordination + self._session.skip_buffers(self._lang_skip_buffers) + + # Run chunked prefill + outputs = None + chunk_image_idx = None + + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + if self.include_sampler: + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): + if decode_batch_id is not None: + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + else: + lang_inputs[op] = self.sampling_params[op] + + for i in range(num_chunks): + input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] + position_ids_slice = lang_inputs["position_ids"][ + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + + chunk_inputs = { + "input_ids": input_ids_slice, + "position_ids": position_ids_slice, + "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), + } + + if decode_batch_id is not None: + chunk_inputs["batch_index"] = decode_batch_id + + if "cross_attention_mask" in lang_inputs: + chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + + if self.include_sampler: + chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] + for op in Constants.SAMPLER_OPS | ({"token_bitmasks"} if self.include_guided_decoding else set()): + chunk_inputs[op] = lang_inputs[op] + + outputs = self._session.run(chunk_inputs) + + if "image_idx_output" in outputs: + chunk_image_idx = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Prepare decode-time cross_attention_mask + if "cross_attention_mask" in lang_inputs: 
+ bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) + else: + self._decode_cross_attention_mask = None + + return outputs + + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + """ + Override base class prefill to handle vision processing + + Args: + prompt: Can be string or tuple (image_path, text_prompt) + generation_len: Generation length + prefill_logit_bs: Prefill batch size + decode_batch_id: Batch ID for continuous batching + + Returns: + Same as base class: (outputs, position_ids, generation_len) + """ + # Normalize prompt: TextGeneration passes a list even for batch_size=1 + if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], tuple) and len(prompt[0]) == 2: + # Unwrap single (image_path, text_prompt) tuple + if len(prompt) == 1: + prompt = prompt[0] + else: + raise NotImplementedError( + "VisionLanguageGeneration.run_prefill currently supports a single (image, text) pair per call." + ) + # Check if this is a vision-language prompt + if isinstance(prompt, tuple) and len(prompt) == 2: + image_path, text_prompt = prompt + + # Check cache for vision outputs + cache_key = image_path if isinstance(image_path, str) else str(image_path) + if cache_key in self._vision_cache: + lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key] + logger.debug(f"Using cached vision outputs for {cache_key}") + else: + # Build language inputs with processor-aware vision/text integration + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len + ) + # Cache for future use + self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks) + logger.debug(f"Cached vision outputs for {cache_key}") + + # Set vision buffers in language session + self._session.set_buffers(vision_outputs) + logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") + self._vision_processed = True + self._vision_outputs = vision_outputs + + # Calculate generation_len consistent with ctx_len + max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Execute chunked prefill + outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs) + + self._session.skip_buffers(vision_outputs) + + # Prepare position_ids for decode phase (next position after prefill) + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + return outputs, position_ids_decode, generation_len + else: + # Fall back to base class for text-only + return super().run_prefill(prompt, generation_len, prefill_logit_bs, decode_batch_id) + + def _prepare_vision_language_prompt(self, text_prompt, image_path): + """ + Prepare text prompt with vision context + + This method handles the integration of vision and text inputs + according to the specific model's requirements. 
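+
+        Roughly, the method builds a single-turn conversation and lets the
+        processor's chat template render it (the rendered prompt string is
+        model-specific, so only the structure is sketched here):
+
+            conversation = [
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": text_prompt}, {"type": "image"}],
+                },
+            ]
+            prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)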
+ """ + # For most vision-language models, we need to apply the chat template + # that includes both image and text components + try: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + processed_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + return processed_prompt + + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}. Using original prompt.") + return text_prompt + + def generate( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs + ) -> CloudAI100ExecInfo: + """ + Main generation method maintaining API compatibility with VisionLanguageGeneration + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + stream: Enable streaming output + **kwargs: Additional arguments passed to base class + + Returns: + CloudAI100ExecInfo with results and metrics + + Raises: + ValueError: If images and prompts lengths don't match + """ + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + # Clear vision cache for fresh generation + self._vision_cache.clear() + + logger.info(f"Generating for {len(images)} image-prompt pairs") + + # Convert to base class format: list of (image, prompt) tuples + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + # Use base class generate method with vision prompts + if self.full_batch_size is not None: + # Continuous batching mode (new capability) + return self._generate_continuous_batching(vision_prompts, generation_len, stream, **kwargs) + else: + # Regular batching mode + return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) + + def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Handle regular batching for vision-language generation without creating a second language session""" + batch_results = [] + for i in range(0, len(vision_prompts), self.batch_size): + batch = vision_prompts[i : i + self.batch_size] + + if stream: + print( + f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}" + ) + for j, (img, prompt) in enumerate(batch): + print(f"Image: {img}") + print(f"Prompt: {prompt}") + print("Completion:", flush=True, end="") + + # Setup decode storage arrays for this batch (use ctx_len or generation_len whichever is larger) + exec_batch_size = self.batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + self.initialize_decode_inputs( + num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length + ) + + # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) + start = perf_counter() + outputs, position_ids, generation_len_final = self.run_prefill( + batch, generation_len, prefill_logit_bs=self.batch_size + ) + self.update_decode_input(outputs, position_ids, generation_len_final) + + # Prepare decode + decode_inputs = self.prepare_decode_inputs() + + # Decode loop + loop_start = perf_counter() + num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None) + end = perf_counter() + + # Decode generated texts + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + # 
Latency metrics + total_decode_tokens = num_token + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end + ) + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + # Package result for this batch + batch_results.append( + CloudAI100ExecInfo( + batch_size=self.batch_size, + generated_texts=generated_texts, + generated_ids=self.generated_ids, + perf_metrics=perf_metrics, + ) + ) + + # Aggregate results across batches + return self._aggregate_batch_results(batch_results) + + def _generate_continuous_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Enable continuous batching for vision-language models (new capability)""" + logger.info("Using continuous batching for vision-language generation") + + if stream: + logger.warning("Streaming output not fully supported with continuous batching") + + # Reset vision processing state for new generation + self._vision_processed = False + self._vision_outputs = None + self._vision_outputs_cache = {} + + # Initialize decode inputs + num_prompts = len(vision_prompts) + execution_batch_size = self.full_batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + + self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + if self.is_qwen2_5_vl: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) + + # Create prompt queue + prompt_queue = deque(vision_prompts) + + start = perf_counter() + + # Pre-process ALL vision inputs and cache them + logger.info("Pre-processing all vision inputs...") + for batch_id in range(min(self.full_batch_size, len(vision_prompts))): + img, prompt = vision_prompts[batch_id] + + # Process vision for this slot + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len + ) + + # Cache vision outputs for this batch slot + self._vision_outputs_cache[batch_id] = { + "vision_outputs": vision_outputs, + "lang_inputs": lang_inputs, + "num_chunks": num_chunks, + } + + logger.debug(f"Cached vision outputs for batch_id {batch_id}") + + # Reset prompt queue for prefill + prompt_queue = deque(vision_prompts) + + self.batch_index = None + + # Run prefill for all inputs using cached vision + self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len) + + # Set vision buffers for decode (use first slot's vision for now) + # For identical images, any slot's vision works + cached_slot_0 = self._vision_outputs_cache.get(0) + if cached_slot_0: + self._session.set_buffers(cached_slot_0["vision_outputs"]) + logger.debug("Set vision buffers from slot 0 for decode phase") + + # Now set batch_index for decode phase + self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) + + loop_start = perf_counter() + decode_pause_time = self.run_continuous_batching_decode(prompt_queue, generation_len) + end = perf_counter() + + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + total_decode_tokens = sum( + np.sum(self.generated_ids[i] != self.tokenizer.pad_token_id) - 1 for i in range(len(vision_prompts)) + ) + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end, decode_pause_time + ) + prefill_time /= len(vision_prompts) # Average prefill time for continuous batching + + perf_metrics = PerfMetrics(prefill_time, decode_perf, 
total_perf, total_time) + + return CloudAI100ExecInfo( + batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics + ) + + def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs using pre-cached vision outputs. + + This avoids the vision buffer overwriting issue by using cached vision + outputs instead of processing vision during each prefill iteration. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + """ + for decode_batch_id in range(self.full_batch_size): + # Pop the promt as we are processing + _ = prompt_queue.popleft() + + # Get cached vision outputs for this batch slot + cached = self._vision_outputs_cache.get(decode_batch_id) + if cached: + vision_outputs = cached["vision_outputs"] + lang_inputs = cached["lang_inputs"] + num_chunks = cached["num_chunks"] + + # Set vision buffers for THIS prefill + self._session.set_buffers(vision_outputs) + logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill") + + # Run prefill with cached inputs + outputs = self._execute_chunked_prefill( + lang_inputs, + num_chunks, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + prefill_logit_bs=1, + ) + + self._session.skip_buffers(vision_outputs.keys()) + + # Calculate position_ids for decode + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + # Calculate generation_len + max_gen_len = ( + self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + ) + generation_len_final = self._fetch_generation_len(generation_len, max_gen_len) + + # Update decode inputs + if self.is_qwen2_5_vl: + self.update_decode_inputs_qwen2_5_vl( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) + else: + self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) + else: + logger.error(f"No cached vision outputs for batch_id {decode_batch_id}") + raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}") + + def prepare_decode_inputs(self): + """ + Override base class to handle vision-specific decode inputs + """ + decode_inputs = super().prepare_decode_inputs() + + # Add image_idx for vision-language models in CB mode during decode only + if self.batch_index is not None and hasattr(self, "_vision_outputs"): + # image_idx should be a single slot selector; decoder expects shape (1,1) + # Query binding dims if available to be robust + try: + if "image_idx" in getattr(self._session, "binding_index_map", {}): + idx = self._session.binding_index_map["image_idx"] + dims = tuple(self._session.bindings[idx].dims) + decode_inputs["image_idx"] = np.zeros(dims, dtype=np.int64) + else: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + except Exception: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + + # Include cross_attention_mask during decode if present/required + if hasattr(self, "_decode_cross_attention_mask") and self._decode_cross_attention_mask is not None: + # Decoder specialization expects a single mask (batch dim = 1) + decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask + + return decode_inputs + + def _aggregate_batch_results(self, batch_results): + """Aggregate results from multiple batches""" + if not batch_results: + raise ValueError("No batch results to aggregate") + + if len(batch_results) == 1: + return 
batch_results[0] + + # Aggregate multiple batch results + all_generated_texts = [] + all_generated_ids = [] + all_metrics = [] + + for result in batch_results: + if isinstance(result.generated_texts[0], list): + # Flatten nested lists + all_generated_texts.extend([text for batch in result.generated_texts for text in batch]) + else: + all_generated_texts.extend(result.generated_texts) + + if isinstance(result.generated_ids, list): + all_generated_ids.extend(result.generated_ids) + else: + all_generated_ids.append(result.generated_ids) + + all_metrics.append(result.perf_metrics) + + # Average metrics + avg_metrics = PerfMetrics( + prefill_time=np.mean([m.prefill_time for m in all_metrics]), + decode_perf=np.mean([m.decode_perf for m in all_metrics]), + total_perf=np.mean([m.total_perf for m in all_metrics]), + total_time=np.mean([m.total_time for m in all_metrics]), + ) + + return CloudAI100ExecInfo( + batch_size=batch_results[0].batch_size, + generated_texts=all_generated_texts, + generated_ids=all_generated_ids, + perf_metrics=avg_metrics, + ) + + def generate_stream_tokens( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, **kwargs + ): + """ + Enable token-by-token streaming for vision models (new capability) + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + **kwargs: Additional arguments + + Yields: + List of decoded tokens for each batch position + + Raises: + NotImplementedError: If continuous batching is enabled + """ + if self.full_batch_size is not None: + raise NotImplementedError("Token streaming not supported with continuous batching for VLM") + + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + logger.info(f"Starting token streaming for {len(images)} image-prompt pairs") + + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + text_gen = TextGeneration( + tokenizer=self.tokenizer, + qpc_path=self._qpc_path, + ctx_len=self._ctx_len, + device_id=self.device_id, + enable_debug_logs=self.enable_debug_logs, + is_tlm=self.is_tlm, + include_sampler=self.include_sampler, + return_pdfs=self.return_pdfs, + include_guided_decoding=self.include_guided_decoding, + sampling_params=self.sampling_params, + ) + + text_gen._qaic_model = self + + # Yield tokens as they're generated + for tokens in text_gen.generate_stream_tokens(vision_prompts, generation_len, **kwargs): + yield tokens + + def __repr__(self): + """String representation of the class""" + return ( + f"VisionLanguageGeneration(" + f"batch_size={self.batch_size}, " + f"ctx_len={self._ctx_len}, " + f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " + f"sampling={'enabled' if self.include_sampler else 'disabled'})" + ) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 592c0c1d3..6c7173072 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -18,11 +18,15 @@ from transformers.generation.streamers import BaseStreamer from QEfficient.base.modeling_qeff import QEFFBaseModel -from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import ( + AdapterWeightsToInputsTransform, + BaseOnnxTransform, + FP16ClipTransform, + SplitTensorsTransform, +) from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.generation.cloud_infer import QAICInferenceSession from 
QEfficient.peft.lora import QEffAutoLoraModelForCausalLM -from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform from QEfficient.utils import constants @@ -66,7 +70,11 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel): """ _pytorch_transforms: List[PytorchTransform] = [CustomOpsTransform, KVCacheTransform, PeftModelInputsTransform] - _onnx_transforms: List[OnnxTransform] = [FP16ClipTransform, AdapterWeightsToInputsTransform, SplitTensorsTransform] + _onnx_transforms: List[BaseOnnxTransform] = [ + FP16ClipTransform, + AdapterWeightsToInputsTransform, + SplitTensorsTransform, + ] _hf_auto_class = AutoPeftModelForCausalLM def __init__(self, model: nn.Module): @@ -245,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with the active adapter to ONNX format. @@ -283,9 +291,10 @@ def export(self, export_dir: Optional[str] = None) -> str: example_inputs, output_names, dynamic_axes, - export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights + do_constant_folding=False, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, + **kwargs, ) def compile( @@ -300,6 +309,7 @@ def compile( num_cores: int = 16, mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -367,6 +377,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 8196cd769..8ff8335f5 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. 
@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + **kwargs, ) def generate( diff --git a/QEfficient/peft/lora/layers.py b/QEfficient/peft/lora/layers.py index 6b75e696f..79abeba77 100644 --- a/QEfficient/peft/lora/layers.py +++ b/QEfficient/peft/lora/layers.py @@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): # multilora implementation: lora_ids other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1) selected_lora_a_weights = CtxGatherFuncCB.apply( - self.lora_a_weights, lora_ids, other_indices_a + self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2] ) # other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1) selected_lora_b_weights = CtxGatherFuncCB.apply( - self.lora_b_weights, lora_ids, other_indices_b + self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2] ) # other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1) selected_lora_scalings = CtxGatherFuncCB.apply( - self.lora_scalings, lora_ids, other_indices_s + self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2] ) # selected_lora_a_weights = selected_lora_a_weights.squeeze(1) diff --git a/QEfficient/peft/onnx_transforms.py b/QEfficient/peft/onnx_transforms.py index d31d35243..c949e028b 100644 --- a/QEfficient/peft/onnx_transforms.py +++ b/QEfficient/peft/onnx_transforms.py @@ -9,10 +9,10 @@ import onnx -from QEfficient.base.onnx_transforms import OnnxTransform +from QEfficient.base.onnx_transforms import BaseOnnxTransform -class AdapterWeightsToInputsTransform(OnnxTransform): +class AdapterWeightsToInputsTransform(BaseOnnxTransform): @classmethod def apply(cls, model: onnx.ModelProto, *, adapter_name: str, **kwargs) -> Tuple[onnx.ModelProto, bool]: transformed = False diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..faadaba6b 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -15,6 +15,8 @@ from QEfficient.customop import ( CtxGatherFunc, CtxGatherFunc3D, + CtxGatherFuncBlockedKV, + CtxGatherFuncBlockedKVCB, CtxGatherFuncCB, CtxGatherFuncCB3D, CtxScatterFunc, @@ -24,6 +26,34 @@ ) +class InvalidIndexProvider: + SUBFUNC_ENABLED = False + + @classmethod + def enable_subfunc(cls): + cls.SUBFUNC_ENABLED = True + + @classmethod + def _get_invalid_idx_value(cls): + """ + Get the appropriate invalid index value for CtxGather operations. + + For ONNX export with functions, we use 0 to avoid INT32_MAX constants + that cause issues when functions are inlined at runtime. + + Returns: + int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise) + """ + if torch.onnx.is_in_onnx_export(): + if cls.SUBFUNC_ENABLED: + # TODO: should not return 0 remove this if condition, it can hurt perf + return 0 + else: + return torch.iinfo(torch.int32).max + else: + return 0 + + class QEffDynamicLayer(DynamicLayer): def read_only(self, cache_kwargs): """ @@ -40,11 +70,52 @@ def read_only(self, cache_kwargs): k_out, v_out = self.keys, self.values position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) - ctx_len = k_out.shape[2] + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
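For reference, a minimal standalone sketch of the gather pattern used in this method, with toy shapes and a plain `torch.gather` standing in for the CtxGather custom ops: the cache keeps `ctx_len` slots, but only the first `CCL` slots are read, and any slot beyond the largest valid position is redirected to a harmless index (in the model code the corresponding zeroing is applied to the gathered values).

```python
import torch

# Toy stand-ins (illustrative only): cache of 8 slots, CCL of 4, 3 valid positions.
ccl = 4
position_ids = torch.tensor([[0, 1, 2]])                                  # (batch, seq_len)
k_cache = torch.arange(8, dtype=torch.float32).view(1, 1, 8, 1)           # (batch, heads, slots, dim)

ctx_indices = torch.arange(ccl)[None, None, ...]                          # (1, 1, CCL)
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)      # largest written slot: 2
invalid_mask = ctx_indices > gather_limit                                  # slot 3 holds no data yet
ctx_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)

k_read = torch.gather(k_cache, 2, ctx_indices.unsqueeze(-1).expand(1, 1, ccl, 1))
k_read = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0), k_read)
print(k_read.flatten())  # tensor([0., 1., 2., 0.]) -> only written slots contribute
```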
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + def read_only_blockedKV(self, start_index, end_index, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer for each KV block. + + Parameters: + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + start_index (`int`): + Start index of the K/V block to read + + end_index (`int`): + End index of the K/V block to read + + Return: + A tuple containing the updated key and value states. + """ + # Gather + k_out, v_out = self.keys, self.values + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + batch, num_kv_heads, _, _ = k_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: @@ -53,11 +124,12 @@ def read_only(self, cache_kwargs): ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices) + v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -137,23 +209,20 @@ def update( k_out, v_out = self.keys, self.values # Gather - ctx_len = k_out.shape[2] + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) ctx_indices = torch.arange(ctx_len)[None, None, ...] 
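As a rough sketch of the `read_only_blockedKV` path above (toy sizes; `torch.gather` stands in for the CtxGatherFuncBlockedKV custom op): the cached context is read in fixed-size blocks, each block gathers only its own index range, and indices past the last valid position are masked.

```python
import torch

batch, num_kv_heads, ctx_len, head_dim = 1, 2, 16, 4
block_size = 4
k_cache = torch.randn(batch, num_kv_heads, ctx_len, head_dim)
position_ids = torch.tensor([[5]])                                        # decode step at position 5

for start_index in range(0, ctx_len, block_size):
    end_index = start_index + block_size
    ctx_indices = torch.arange(start_index, end_index)[None, None, :]     # (1, 1, block)
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = ctx_indices > gather_limit
    safe_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)
    # One index row per KV head, mirroring the expand() in the non-batched branch above.
    idx = safe_indices.expand(batch, num_kv_heads, block_size).unsqueeze(-1).expand(-1, -1, -1, head_dim)
    k_block = torch.gather(k_cache, 2, idx)
    # ...attention over this K/V block would be computed and accumulated here...
```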
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -262,6 +331,25 @@ def read_only(self, layer_idx, cache_kwargs): """ return self.layers[layer_idx].read_only(cache_kwargs) + def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + start_index (`int`): + Start index of the K/V block to read + end_index (`int`): + End index of the K/V block to read + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs) + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. @@ -414,23 +502,21 @@ def update( k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather - ctx_len = self.key_cache[layer_idx].shape[2] + ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) ctx_indices = torch.arange(ctx_len)[None, None, ...] 
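The sliding-window branch a few lines below remaps the circular cache so entries are read in oldest-to-newest order once the window has wrapped; a toy version of that index arithmetic (window of 4 slots, newest position 9):

```python
import torch

layer_ctx_len = 4                       # sliding-window size == cache slots for this layer
max_pos = torch.tensor(9)               # kv_position_ids.max() in the model code

all_indices = torch.arange(layer_ctx_len) + max_pos + 1            # tensor([10, 11, 12, 13])
rolling_indices = torch.where(
    all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices
)
# Positions 6..9 live in slots 2, 3, 0, 1 (position % window), so reading the cache
# with rolling_indices == tensor([2, 3, 0, 1]) yields them oldest first.
print(rolling_indices)
```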
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -516,7 +602,8 @@ def update( k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather - ctx_len = min(layer_ctx_len, k_out.shape[2]) + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_len = min(layer_ctx_len, ctx_len) ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -529,11 +616,255 @@ def update( # Rolling indices for sliding window all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS: + def __init__(self, config, batch_size, max_cache_len, sliding_window_len): + self.max_cache_len = max_cache_len + self.batch_size = batch_size + self.sliding_window_len = sliding_window_len + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "HybridCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. 
Used for + backward compatibility.""" + cache = cls( + config, + batch_size=past_key_values[0][0].shape[0], + max_cache_len=past_key_values[1][0].shape[2], + sliding_window_len=past_key_values[0][0].shape[2], + ) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache[layer_idx]) == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + def write_only( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + _, _, ctx_len, _ = self.key_cache[layer_idx].shape + if is_sliding_layer: + kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1) + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + sliding_window = cache_kwargs.get("sliding_window") + batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs + + if is_sliding_layer: + kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window) + else: + 
kv_position_ids = position_ids + + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + invalid_scatter_index = torch.iinfo(torch.int32).max + scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids) + else: + scatter_position_ids = kv_position_ids + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Original Gather + if is_sliding_layer: + ctx_len = self.key_cache[layer_idx].shape[2] + else: + ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) + + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out + + def full_cache_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Gather + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
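The export-time scatter above replaces padded positions (`position_ids == -1`) with INT32_MAX so padding can never land on a real cache slot; a plain-PyTorch sketch of the same sanitisation (the model code routes the write through the CtxScatter custom ops):

```python
import torch

ctx_len = 8
cache = torch.zeros(1, 1, ctx_len, 1)
new_keys = torch.tensor([1.0, 2.0, -9.0]).view(1, 1, 3, 1)    # last token is padding
position_ids = torch.tensor([[4, 5, -1]])                      # -1 marks padding

invalid = torch.iinfo(torch.int32).max
scatter_pos = torch.where(position_ids < 0, torch.full_like(position_ids, invalid), position_ids)

# Mimic "out-of-range writes are dropped": keep only targets that fit in the cache.
keep = scatter_pos[0] < ctx_len
cache[0, 0, scatter_pos[0][keep]] = new_keys[0, 0][keep]
print(cache.flatten())  # slots 4 and 5 updated; the padded token never touches the cache
```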
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out + + def sliding_window_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + sliding_window_len = cache_kwargs.get("sliding_window") + + # Gather + ctx_len = position_ids.shape[1] + sliding_window_len + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
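The offset computed just below keeps the gather window anchored to the current prefill chunk: a chunk of `seq_len` queries starting at position `p` only ever attends to the `sliding_window` positions immediately before it plus the chunk itself, so only `seq_len + sliding_window` slots need to be gathered. A toy sketch with illustrative sizes:

```python
import torch

sliding_window = 128
seq_len = 256                                                   # chunk size during chunked prefill
position_ids = torch.arange(4096, 4096 + seq_len).unsqueeze(0)  # chunk starting at position 4096

ctx_len = seq_len + sliding_window                               # 384 gathered slots per chunk
ctx_indices = torch.arange(ctx_len)[None, None, :]
first_pos = position_ids[0][0]
add_idx = torch.where(first_pos >= sliding_window, first_pos - sliding_window, torch.tensor(0))
ctx_indices = ctx_indices + add_idx
print(ctx_indices.min().item(), ctx_indices.max().item())        # 3968 4351: everything the chunk can see
```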
+ first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0) + ctx_indices += add_idx + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out diff --git a/QEfficient/transformers/modeling_attn_mask_utils.py b/QEfficient/transformers/modeling_attn_mask_utils.py index 4faedba33..629c10dd6 100644 --- a/QEfficient/transformers/modeling_attn_mask_utils.py +++ b/QEfficient/transformers/modeling_attn_mask_utils.py @@ -14,6 +14,7 @@ def _create_causal_mask( position_ids, target_length, sliding_window: Optional[int] = None, + start_index: Optional[int] = 0, ): """ A utility attention mask class that allows one to: @@ -40,7 +41,7 @@ def _create_causal_mask( attention_mask = attention_mask.unsqueeze(1) else: query_indices = position_ids.unsqueeze(-1) - kv_indices = torch.arange(target_length).view(1, 1, -1) + kv_indices = torch.arange(start=start_index, end=target_length).view(1, 1, -1) attention_mask = kv_indices > query_indices attention_mask = attention_mask.unsqueeze(1) diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index c692d1beb..47059d8dc 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -185,8 +185,12 @@ ] ) +# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} +# This is for supporting different modelling classes specially written for prefill-only model +SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"} + # Define a transformers layers to QEff layers dictionary # While onboarding new models make sure to add the new layer maps to this dictionary. 
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = { diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index 776bfce43..3addd7501 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -72,6 +72,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -124,6 +125,9 @@ def forward( if layer_past is not None: cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -147,6 +151,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -245,6 +250,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, attention_mask=attention_mask, position_ids=position_ids, @@ -294,6 +300,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -312,6 +319,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, batch_index=batch_index, @@ -348,6 +356,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -361,6 +370,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 8f2c3730d..1cfdf88e1 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -117,6 +117,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, @@ -141,6 +142,9 @@ def forward( if layer_past is 
not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) if attention_mask is not None: @@ -172,6 +176,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None, head_mask: Optional[torch.Tensor] = None, @@ -195,6 +200,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, alibi=alibi, head_mask=head_mask, @@ -245,6 +251,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -307,6 +314,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask[i], use_cache=use_cache, @@ -352,6 +360,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -368,6 +377,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index eea1e3898..1edb8ef53 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -154,6 +155,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -186,6 +190,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, 
use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -214,6 +219,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -243,6 +249,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -299,6 +306,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -334,6 +342,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -350,6 +359,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index be3ba942d..2944601c9 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -144,6 +144,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -161,7 +162,15 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -194,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -266,6 +277,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = 
None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -338,6 +350,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -381,6 +394,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -404,6 +418,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 20b7036fd..a6e451bec 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import copy -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -215,6 +215,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -254,6 +255,9 @@ def forward( "is_sliding": self.is_sliding, "sliding_window_pattern": self.config.sliding_window_pattern, } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -297,6 +301,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -323,6 +328,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -363,6 +369,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -429,6 +436,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, 
output_attentions=output_attentions, use_cache=use_cache, @@ -466,6 +474,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -525,6 +534,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -592,7 +602,16 @@ def __init__(self, model): self.config = self.model.config self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_index @@ -603,7 +622,12 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -620,7 +644,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffGemma3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): image_features = self.get_image_features(pixel_values=pixel_values) inputs_embeds = self.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -632,7 +664,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -647,7 +683,12 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = 
None, **compiler_options, ): prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 @@ -667,24 +708,77 @@ def get_specializations( "ctx_len": ctx_len, } ] - lang = [ - { - "batch_size": batch_size, + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, "mm_tokens_per_image": mm_tokens_per_image, - }, - { - "batch_size": batch_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, "mm_tokens_per_image": mm_tokens_per_image, - }, - ] + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang = [lang_prefill, lang_decode] + specializations = {} if kv_offload: @@ -694,17 +788,21 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} - pkv_dynamic_sliding_axes = {0: "batch_size", 2: "sliding_window"} + pkv_dynamic_axes = 
{0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} + pkv_dynamic_sliding_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} layer_switch = ( self.language_model.config.sliding_window_pattern if hasattr(self.language_model.config, "sliding_window_pattern") @@ -719,6 +817,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): ) lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -767,7 +868,9 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -806,13 +909,22 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV lang_inputs["past_key_values"] = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index d68a65430..6136a2c5d 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -65,6 +65,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -121,6 +122,10 @@ def forward( # save all key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -156,6 +161,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -174,6 +180,7 @@ def forward( hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, + comp_ctx_lengths=comp_ctx_lengths, 
position_ids=position_ids, batch_index=batch_index, head_mask=head_mask, @@ -232,6 +239,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -341,6 +349,7 @@ def forward( outputs = block( hidden_states, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -392,6 +401,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -418,6 +428,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index af233870b..85ea42674 100644 --- a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -98,6 +98,7 @@ def forward( self, hidden_states: torch.Tensor, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -153,6 +154,9 @@ def forward( if layer_past is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key, value = curr_past_key_value.update(key, value, self.layer_idx, cache_kwargs) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if self.is_cross_attention: @@ -180,6 +184,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.Tensor]], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -194,6 +199,7 @@ def forward( attn_outputs = self.attn( hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -242,6 +248,7 @@ def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -333,6 +340,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, batch_index=batch_index, attention_mask=attention_mask, @@ -374,6 +382,7 
@@ def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -399,6 +408,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gpt_oss/__init__.py b/QEfficient/transformers/models/gpt_oss/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 000000000..3efe890b8 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,1334 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import math +import os +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssConfig, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRotaryEmbedding, + repeat_kv, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger + + +class QEffGptOssExperts(GptOssExperts): + def __qeff_init__(self): + self.gate_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.gate_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + + +class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, 
E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + +class QEffPrefillOnlyGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + if os.environ.get("NUM_FFN_BLOCKS", None) is not None: + return self.blocked_ffn_forward(hidden) + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + expert_out += down_out * routing_weight + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, 
dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+ # ────────────────── allocate the output tensor ─────
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+ target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (T // target_blocks))
+ # ───────────────────────── Expert computation loop ─────────────────────────────
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+ # Calculate block size (last block should be handled with remainder)
+ if block_idx == target_blocks - 1:
+ real_q_len = T - qi
+ else:
+ real_q_len = block_positions[block_idx + 1] - qi
+
+ tgb = hidden[qi : qi + real_q_len, :]
+ # Gate and Up projections
+ gate = (tgb @ W_g) + b_g # [T, I]
+ up = (tgb @ W_u) + b_u # [T, I]
+
+ # Apply GptOss activation with clamping
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+ # GLU activation
+ glu = gate * torch.sigmoid(gate * self.experts.alpha)
+ intermediate = (up + 1) * glu # [T, I]
+
+ # Down projection
+ down_out_block = (intermediate @ W_d) + b_d # [T, H]
+
+ outs.append(down_out_block)
+
+ down_out = torch.cat(outs, dim=0)
+
+ # Apply routing weights and accumulate
+ expert_out += down_out * routing_weight
+
+ # original shape [B, S, H]
+ return expert_out.view(B, S, H), router_logits
+
+ def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor):
+ B, S, H = hidden.shape
+ T = B * S
+ hidden = hidden.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ masked_logits = torch.zeros_like(router_logits)
+ masked_logits.scatter_(1, top_i, top_w)
+
+ # Routing weights for each expert [T, E]
+ routing_weights = masked_logits
+
+ # ────────────────── allocate the output tensor ─────
+ expert_out = hidden.new_zeros((T, H)) # accumulation buffer
+ target_blocks = int(os.environ.get("NUM_BLOCKS", 1))
+ block_positions = []
+ for j in range(target_blocks):
+ block_positions.append(j * (T // target_blocks))
+ # ───────────────────────── Expert computation loop ─────────────────────────────
+ for e in range(self.experts.num_experts):
+ routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
+
+ W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I]
+ b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I]
+ W_d = self.experts.down_proj[e] # [I, H]
+ b_d = self.experts.down_proj_bias[e] # [H]
+
+ block_count = 0
+ outs = []
+ for block_idx in range(target_blocks):
+ block_count += 1
+ qi = block_positions[block_idx]
+
+ # Calculate block size (last block should be handled with remainder)
+
if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + + wg_col_shape = W_g.shape[1] + wg_num_blocks = math.ceil(wg_col_shape / 128) + last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128 + + intermediates = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:] + cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:] + else: + cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128] + cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128] + + cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit) + cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha) + cur_intermediate = (cur_up + 1) * cur_glu + intermediates.append(cur_intermediate) + + intermediate = torch.cat(intermediates, dim=-1) + + downs = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:]) + else: + downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128]) + + down_out_block = torch.cat(downs, dim=1) + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + +class QEffGptOssMLP(GptOssMLP): + # ------------------- Gather based, weights as activation approach --------------- + def forward_weights_as_activation(self, hidden_states): + bs, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts + gate_up_proj = self.experts.gate_up_proj[router_indices.flatten()] + gate_up_proj_bias = self.experts.gate_up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Apply Chosen Experts (without routing weights first) + # expert_in = hidden_states.repeat_interleave(self.router.top_k, dim=0) + # expert_in = expert_in.view(-1, 1, self.experts.hidden_size) + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + gate_up = torch.bmm(expert_in, gate_up_proj) + gate_up_proj_bias.unsqueeze(1) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = 
experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size)
+
+ # Apply routing weights AFTER expert computation (Llama4 applies them before the experts)
+ experts_out = experts_out * router_top_value.unsqueeze(-1)
+ experts_out = experts_out.sum(dim=1)
+
+ return experts_out, router_logits
+
+ # ------------------- Gather-based, weights-as-activation approach, with separate gate and up projections ---------------
+ def forward(self, hidden_states):
+ bs, seq_len, _ = hidden_states.shape
+ hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size)
+
+ # Router computation
+ router_logits = F.linear(hidden_states, self.router.weight, self.router.bias)
+ router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1)
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+
+ # GATHER - collect weights for selected experts (separate gate and up projections)
+ gate_proj = self.experts.gate_proj[router_indices.flatten()]
+ gate_proj_bias = self.experts.gate_proj_bias[router_indices.flatten()]
+ up_proj = self.experts.up_proj[router_indices.flatten()]
+ up_proj_bias = self.experts.up_proj_bias[router_indices.flatten()]
+ down_proj = self.experts.down_proj[router_indices.flatten()]
+ down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()]
+
+ # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size)
+ expert_in = (
+ hidden_states.unsqueeze(1)
+ .expand(-1, self.router.top_k, -1)
+ .contiguous()
+ .view(-1, 1, self.experts.hidden_size)
+ )
+
+ # Apply gate and up projections separately using bmm
+ gate = torch.bmm(expert_in, gate_proj) + gate_proj_bias.unsqueeze(1)
+ up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1)
+
+ # Apply activation with clamping
+ gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
+ up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+ # GLU activation
+ glu = gate * torch.sigmoid(gate * self.experts.alpha)
+ gated_output = (up + 1) * glu
+
+ # Down projection
+ experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1)
+ experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size)
+
+ # Apply routing weights AFTER expert computation
+ experts_out = experts_out * router_top_value.unsqueeze(-1)
+ experts_out = experts_out.sum(dim=1)
+
+ return experts_out, router_logits
+
+ def optimized_moe_forward(self, hidden_states: torch.Tensor):
+ B, S, H = hidden_states.shape
+ T = B * S
+ hidden_states = hidden_states.view(T, H)
+
+ # Router computation
+ router_logits = F.linear(hidden_states, self.router.weight, self.router.bias)
+
+ # Top-k selection
+ top_w, selected_experts = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K]
+ top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
+
+ # Creating experts mask and routing weights masked
+ awesome_experts_mask_1 = (
+ torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.experts.num_experts)
+ .bool()
+ .T.unsqueeze(-1)
+ )
+ awesome_experts_mask_2 = (
+ torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.experts.num_experts)
+ .bool()
+ .T.unsqueeze(-1)
+ )
+ awesome_experts_mask_3 = (
+ torch.nn.functional.one_hot(selected_experts[:, 2], num_classes=self.experts.num_experts)
+ .bool()
+ .T.unsqueeze(-1)
+ )
+ awesome_experts_mask_4 = (
+ torch.nn.functional.one_hot(selected_experts[:, 3], num_classes=self.experts.num_experts)
+ .bool()
+ .T.unsqueeze(-1)
+ )
+
+
gateupout1 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout2 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout3 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout4 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + + # Gate and Up projections + gate = (hidden_states @ W_g) + b_g # [T, I] + up = (hidden_states @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + gateupout1 += torch.where(awesome_experts_mask_1[e], intermediate, torch.zeros_like(gateupout1)) + gateupout2 += torch.where(awesome_experts_mask_2[e], intermediate, torch.zeros_like(gateupout2)) + gateupout3 += torch.where(awesome_experts_mask_3[e], intermediate, torch.zeros_like(gateupout3)) + gateupout4 += torch.where(awesome_experts_mask_4[e], intermediate, torch.zeros_like(gateupout4)) + + concat_down = torch.zeros((self.router.top_k, T, H)) + concat_mask = torch.cat( + ( + awesome_experts_mask_1.unsqueeze(0), + awesome_experts_mask_2.unsqueeze(0), + awesome_experts_mask_3.unsqueeze(0), + awesome_experts_mask_4.unsqueeze(0), + ), + dim=0, + ) + + concat_gateout = torch.cat( + (gateupout1.unsqueeze(0), gateupout2.unsqueeze(0), gateupout3.unsqueeze(0), gateupout4.unsqueeze(0)), dim=0 + ) + + for e in range(self.experts.num_experts): + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Down projection + down_out = (concat_gateout @ W_d) + b_d # [T, H] + + concat_down += torch.where(concat_mask[:, e, :], down_out, torch.zeros_like(concat_down)) + + downout1, downout2, downout3, downout4 = concat_down[0], concat_down[1], concat_down[2], concat_down[3] + hidden_states = ( + downout1 * top_w[:, 0].unsqueeze(-1) + + downout2 * top_w[:, 1].unsqueeze(-1) + + downout3 * top_w[:, 2].unsqueeze(-1) + + downout4 * top_w[:, 3].unsqueeze(-1) + ).reshape(B, S, H) + + # original shape [B, S, H] + return hidden_states, router_logits + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: GptOssConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache(
+ seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+ self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+ )
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + q_block = query[:, :, qi : qi + real_q_len, :] + scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, value_states) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +def opt_eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: 
Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + # Calculate block size (last block should be handled with remainder) + + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + if block_idx == 0: + kv_start_idx = 0 + else: + kv_start_idx = qi - 128 + + q_block = query[:, :, qi : qi + real_q_len, :] + if kwargs.get("sliding_window"): + k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :] + v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :] + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len] + else: + k_block = key_states + v_block = value_states + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, v_block) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + 
"batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": self.sliding_window, + } + if self.sliding_window is not None: + key_states, value_states = past_key_value.sliding_window_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + key_states, value_states = past_key_value.full_cache_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if self.sliding_window is not None: + attention_mask = sliding_mask + # positive_pos_ids = torch.where(position_ids<0, 0, position_ids) + ctx_len = position_ids.shape[1] + self.sliding_window + ctx_indices = torch.arange(ctx_len) + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx - self.sliding_window, 0) + # start_idx = torch.where(first_pos_idx>=self.sliding_window, first_pos_idx-self.sliding_window, 0) + # end_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx+position_ids.shape[1], position_ids.shape[1]+self.sliding_window) + ctx_indices += add_idx + attention_mask = attention_mask[:, :, :, ctx_indices] + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffPrefillOnlyGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + if self.sliding_window is not None: + sliding_window_len = 
past_key_value.sliding_window_len + short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2]) + read_idx = short_read_idx + torch.where( + position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 + ) + # This is a trick to export with seq_len position_ids.max(), 0, read_idx) + k_cache = key_states[:, :, read_idx, :] + v_cache = value_states[:, :, read_idx, :] + else: + k_cache, v_cache = key_states, value_states + _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + if os.environ.get("ENABLE_OPT_SWA", "0") == "1": + attention_interface: Callable = opt_eager_attention_forward_blocked + else: + attention_interface: Callable = eager_attention_forward_blocked + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + 
key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssDecoderLayer(GptOssDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = hidden_states.reshape(residual.shape) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffPrefillOnlyGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = 
self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.max_cache_len, + sliding_window=self.config.sliding_window, + ) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], 
device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.sliding_window_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False): + pkv_dynamic_axes = [] + for layer_type in self.config.layer_types: + if layer_type == "sliding_attention" and not retain_full_kv: + pkv_dynamic_axes.append( + {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} + ) + else: + pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}) + return pkv_dynamic_axes + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + **kwargs, + ): + batch_size = batch_size if batch_size else 1 + if kwargs.get("prefill_only") and not kwargs.get("enable_chunking") and ctx_len != prefill_seq_len: + ctx_len = prefill_seq_len + logger.warning( + f"overriding ctx_len={prefill_seq_len}, currently we don't support ctx_len different than prefill_seq_len for prefill_only model" + ) + + specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": 128, + }, + { + "batch_size": batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "sliding_window": 128, + }, + ] + return specializations diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py index dc3e5e6d2..1a9e45e97 100644 --- a/QEfficient/transformers/models/gptj/modeling_gptj.py +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -83,6 +83,7 @@ def forward( self, hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -135,6 +136,9 @@ def forward( if layer_past is not None: cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + 
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -151,6 +155,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -164,6 +169,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -191,6 +197,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -270,6 +277,7 @@ def forward( outputs = block( hidden_states=hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -314,6 +322,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -339,6 +348,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 2a2d47d6d..62be5f54d 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -17,6 +17,7 @@ from transformers.models.granite.modeling_granite import ( GraniteAttention, GraniteConfig, + GraniteDecoderLayer, GraniteForCausalLM, GraniteModel, GraniteRotaryEmbedding, @@ -129,6 +130,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -146,7 +148,15 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -164,6 +174,80 @@ def forward( return attn_output, attn_weights 
+class QEffGraniteDecoderLayer(GraniteDecoderLayer): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py + The only differences are: + - add new args batch idx for the CB models although its not supported yet. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + batch_index: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + class QEffGraniteModel(GraniteModel): def forward( self, @@ -171,6 +255,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -226,6 +311,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -267,6 +353,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -319,6 +406,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index c085f6a5e..b158b4046 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -123,6 +123,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -150,6 +151,9 @@ def forward( "batch_index": batch_index, "position_ids": position_ids, } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -209,6 +213,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = 
None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -286,6 +291,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -297,6 +303,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -492,6 +499,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -546,6 +554,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 567a8e070..2d8fc412d 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -55,6 +55,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -94,6 +95,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -205,6 +209,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -235,6 +240,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -277,6 +283,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -351,6 +358,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, 
batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -395,6 +403,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -441,6 +450,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 38d0fe167..b47db7eda 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List, Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -34,7 +36,16 @@ def __init__(self, model): self.config = self.model.language_model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) @@ -55,7 +66,12 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -74,7 +90,12 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): num_patches = compiler_options.pop("num_patches", None) @@ -104,24 +125,75 @@ def get_specializations( "batched_num_patches": batch_size * num_patches, } ] - lang = [ - { - "batch_size": batch_size, + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size 
+ else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "num_patches": num_patches, "img_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "num_patches": num_patches, "img_size": img_size, "vision_size": vision_size, - }, - ] + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] specializations = {} @@ -130,22 +202,31 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {1: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"} for i in range(self.language_model.config.num_hidden_layers): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -173,7 +254,9 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -222,10 +305,13 @@ def get_dummy_inputs(self, kv_offload: bool = False): ) lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64) + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = 
constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -234,6 +320,11 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -244,7 +335,15 @@ def get_dummy_inputs(self, kv_offload: bool = False): return inputs - def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + pixel_values, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): input_embeds = self.language_model.get_input_embeddings()(input_ids) vision_embeds = self.extract_feature(pixel_values) B, N, C = input_embeds.shape @@ -266,7 +365,11 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index f2a68f80e..fb3aed556 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -113,6 +113,7 @@ def eager_attention_forward( attn_weights = torch.where( attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -120,6 +121,80 @@ def eager_attention_forward( return attn_output, attn_weights +def eager_attention_forward_blockedKV( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: Optional[torch.Tensor] = None, + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + # Initialize result tensor + output = torch.zeros_like(query) + + # Initialize Running Maximum + batch_size, num_heads, seq_len, _ = query.shape + current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE)) + + # Initialize Denominator + 
current_denominator = torch.zeros(batch_size, num_heads, seq_len) + + past_seen_tokens = cache_kwargs.get("past_seen_tokens") + position_ids = cache_kwargs.get("position_ids") + block_size = -(-past_seen_tokens // num_kv_blocks) + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32) + + for j in range(num_kv_blocks): + start_index = j * block_size + end_index = (j + 1) * block_size + K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) + K_block_states = repeat_kv(K_block, module.num_key_value_groups) + V_block_states = repeat_kv(V_block, module.num_key_value_groups) + past_seen_tokens_start = start_index + past_seen_tokens_end = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start + ) + + # Compute attention scores for the block + attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) + + # Update Running row maximum + prev_max = current_max + current_max = torch.max(prev_max, attn_weights_block.max(dim=-1).values) + delta_max = prev_max - current_max + + current_exp = torch.exp( + attn_weights_block - current_max.unsqueeze(-1) + ) # Subract current_max from each column of attn_weights_block + + # update running denominator + prev_denominator = current_denominator + current_denominator = prev_denominator * torch.exp(delta_max) + current_exp.sum(axis=-1) + + prob = current_exp / current_denominator.unsqueeze(-1) + + prev_output = output + output = ((prev_denominator / current_denominator).unsqueeze(-1)) * prev_output * torch.exp( + delta_max.unsqueeze(-1) + ) + torch.matmul(prob, V_block_states) + attn_output = output.transpose(1, 2).contiguous() + attn_weights = None + + return attn_output, attn_weights + + class QEffLlamaAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -132,9 +207,11 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + num_kv_blocks: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -150,14 +227,29 @@ def forward( value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface = eager_attention_forward + if num_kv_blocks is not None: + cache_kwargs = { + "batch_index": batch_index, + "position_ids": position_ids, + 
"past_seen_tokens": past_seen_tokens, + } + past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + else: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if num_kv_blocks is not None: + attention_interface = eager_attention_forward_blockedKV + else: + attention_interface = eager_attention_forward attn_output, attn_weights = attention_interface( self, @@ -166,6 +258,10 @@ def forward( value_states, attention_mask, scaling=self.scaling, + num_kv_blocks=num_kv_blocks, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_value=past_key_value, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -187,6 +283,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -202,6 +299,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -229,6 +327,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -277,6 +376,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -310,6 +410,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -326,6 +427,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..834ee8880 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -470,6 +470,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -503,14 +504,20 @@ def forward( if past_key_value is not None: chunk_position_ids = position_ids - if self.use_rope: chunk_position_ids = 
torch.where( chunk_position_ids != -1, chunk_position_ids % self.config.attention_chunk_size, chunk_position_ids ) # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids} + cache_kwargs = { + "batch_index": batch_index, + "position_ids": chunk_position_ids, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -543,6 +550,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -562,6 +570,7 @@ def forward( position_embeddings=position_embeddings, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -615,6 +624,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -682,6 +692,7 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -731,6 +742,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -754,6 +766,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -820,7 +833,7 @@ def forward(self, pixel_values): ) vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.model.multi_modal_projector(vision_flat) - return projected_vision_flat + return projected_vision_flat # , pixel_values # This wrapper utilizes the 'vision_embeds', which contains vision embeddings, and an 'image_idx' index starting at 0. 
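Reviewer note: the `eager_attention_forward_blockedKV` path added to `modeling_llama.py` above reads the KV cache block by block and folds each block into a running softmax. Below is a minimal standalone sketch of that online-softmax accumulation (not part of the patch; the names, shapes, `-inf` initialisation, and the omitted causal mask are illustrative), checked against the dense result:

```python
import torch

def blocked_attention(query, key, value, num_kv_blocks, scaling):
    # query: (B, H, S_q, D); key/value: (B, H, S_kv, D)
    bsz, heads, q_len, _ = query.shape
    kv_len = key.shape[2]
    block = -(-kv_len // num_kv_blocks)          # ceil-divide, as in the patch

    out = torch.zeros_like(query)                # running (already rescaled) output
    run_max = torch.full((bsz, heads, q_len), float("-inf"))
    run_den = torch.zeros(bsz, heads, q_len)     # running softmax denominator

    for j in range(num_kv_blocks):
        k_blk = key[:, :, j * block : (j + 1) * block]
        v_blk = value[:, :, j * block : (j + 1) * block]
        if k_blk.shape[2] == 0:                  # last block can be empty
            continue
        scores = torch.matmul(query, k_blk.transpose(2, 3)) * scaling

        prev_max, prev_den = run_max, run_den
        run_max = torch.maximum(prev_max, scores.max(dim=-1).values)
        delta = prev_max - run_max               # <= 0; rescales the older terms
        exp_scores = torch.exp(scores - run_max.unsqueeze(-1))
        run_den = prev_den * torch.exp(delta) + exp_scores.sum(dim=-1)

        # rescale the partial output, then add this block's contribution
        out = (prev_den / run_den).unsqueeze(-1) * out * torch.exp(delta.unsqueeze(-1)) + torch.matmul(
            exp_scores / run_den.unsqueeze(-1), v_blk
        )
    return out

# Sanity check against the dense computation.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 6, 8)
v = torch.randn(1, 2, 6, 8)
ref = torch.softmax((q @ k.transpose(2, 3)) * 0.125, dim=-1) @ v
assert torch.allclose(blocked_attention(q, k, v, num_kv_blocks=3, scaling=0.125), ref, atol=1e-5)
```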
@@ -836,7 +849,16 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +868,12 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -860,7 +887,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlama4DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) vision_feature_layer = self.config.vision_config.vision_feature_layer vision_feature_select_strategy = self.config.vision_config.vision_feature_select_strategy @@ -880,7 +915,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -893,9 +932,15 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if max_num_tiles is None: logger.warning( "User should pass `max_num_tiles` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 17" @@ -941,9 +986,54 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, 
len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "max_num_tiles": max_num_tiles, @@ -951,18 +1041,31 @@ def get_specializations( "vision_size": vision_size, "chunk_length": prefill_seq_len, "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, "ctx_len": ctx_len, "max_num_tiles": max_num_tiles, "img_size": img_size, "vision_size": vision_size, "chunk_length": prefill_seq_len, "chunk_ctx_len": chunk_ctx_len, - }, - ] + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] specializations = {} @@ -971,18 +1074,24 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. 
if int((i + 1) % 4 != 0): @@ -993,6 +1102,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -1011,6 +1123,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: + # vision_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") output_names["vision"] = vision_output_names @@ -1045,7 +1158,9 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1090,10 +1205,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1102,6 +1221,12 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 9fd1ed782..fa42b3f96 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -89,6 +89,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.LongTensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, ) -> torch.Tensor: @@ -98,6 +99,7 @@ def forward( # Reshape the query, key, and value tensors. query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -105,8 +107,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -155,6 +159,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, + comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -166,6 +171,7 @@ def forward( hidden_states=hidden_states, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, ) @@ -201,11 +207,19 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def _run_swiftkv_layers( - self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + comp_ctx_lengths, + causal_mask, + batch_index, ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states = layer(hidden_states, position_ids, past_key_values, causal_mask, batch_index) + hidden_states = layer( + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index + ) hidden_states = self.norm(hidden_states) return hidden_states, past_key_values @@ -289,6 +303,7 @@ def forward( input_ids: Optional[torch.Tensor], position_ids: torch.Tensor, past_key_values: List[torch.Tensor], + comp_ctx_lengths: Optional[torch.LongTensor], batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.embed_tokens(input_ids) @@ -328,6 +343,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=False, use_cache=True, @@ -373,7 +389,7 @@ def forward( causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :] hidden_states, next_decoder_cache = self._run_swiftkv_layers( - hidden_states, position_ids, past_key_values, causal_mask, batch_index + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index ) # We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction # we only need the last valid pos_indices hidden_states. 
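Reviewer note: throughout this patch the same `comp_ctx_lengths` ("CCL") handling is inserted into each attention forward before the KV-cache update, as in the SwiftKV hunk just above: the 4-D causal mask is sliced down to the active compute-context-length bucket and that length is passed to the cache through `cache_kwargs["CCL"]`. A minimal standalone sketch of the pattern (not part of the patch; shapes and values are illustrative):

```python
from typing import Dict, Optional

import torch


def apply_ccl(attention_mask: torch.Tensor,
              comp_ctx_lengths: Optional[torch.Tensor],
              cache_kwargs: Dict) -> torch.Tensor:
    # attention_mask: (batch, 1, q_len, ctx_len) boolean causal mask
    if comp_ctx_lengths is not None:
        # keep only the first CCL key positions; the cache update and the
        # attention scores are then bounded by this shorter effective context
        attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
        cache_kwargs["CCL"] = attention_mask.shape[-1]
    return attention_mask


mask = torch.zeros(1, 1, 1, 1024, dtype=torch.bool)   # compiled ctx_len = 1024
ccl = torch.zeros(512, dtype=torch.int8)               # active CCL bucket = 512
cache_kwargs: Dict = {"position_ids": None, "batch_index": None}
mask = apply_ccl(mask, ccl, cache_kwargs)
assert mask.shape[-1] == 512 and cache_kwargs["CCL"] == 512
```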
@@ -405,9 +421,12 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, ): - hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index) + hidden_states, output_past_key_values = self.model( + input_ids, position_ids, past_key_values, comp_ctx_lengths, batch_index + ) logits = self.lm_head(hidden_states) return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index e260beb05..abdb77ea5 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List, Optional + import torch import torch.nn as nn import torch.utils.checkpoint @@ -16,6 +18,7 @@ from QEfficient.utils.logging_utils import logger BS = 1 +FBS = 4 NUM_CHANNEL = 3 SEQ_LEN = 592 CTX_LEN = 1024 @@ -51,7 +54,16 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -65,6 +77,8 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, return_dict=True, ) @@ -83,7 +97,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFLlavaDecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.get_input_embeddings()(input_ids) # Image features image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -109,6 +131,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -120,7 +143,13 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // 
self.config.text_config.num_attention_heads @@ -145,11 +174,17 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for i in range(num_layers): lang_inputs["past_key_values"].append( ( - torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), - torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(FBS if continuous_batching else BS, num_key_value_heads, CTX_LEN, head_dim), ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1) inputs = {} if kv_offload: @@ -166,7 +201,12 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_images = compiler_options.pop("max_num_images", 1) @@ -187,24 +227,78 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, + + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - ] + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] + 
specializations = {} if kv_offload: @@ -212,9 +306,13 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -224,11 +322,22 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} dynamic_axes = {} if kv_offload: diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 2fa1d9234..627f7393e 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- +from typing import List, Optional + import numpy as np import torch import torch.nn as nn @@ -18,6 +20,9 @@ from QEfficient.utils._utils import IOInfo from QEfficient.utils.logging_utils import logger +BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE +FBS = constants.ONNX_EXPORT_EXAMPLE_FBS + class QEffLlavaNextEncoderWrapper(nn.Module): def __init__(self, model): @@ -123,7 +128,16 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.config.image_token_index @@ -138,6 +152,8 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -154,7 +170,13 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + 
kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -203,13 +225,13 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): lang_inputs["past_key_values"].append( ( torch.zeros( - constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + FBS if continuous_batching else BS, num_key_value_heads, constants.GRANITEVISION_CTX_LEN, head_dim, ), torch.zeros( - constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + FBS if continuous_batching else BS, num_key_value_heads, constants.GRANITEVISION_CTX_LEN, head_dim, @@ -217,6 +239,13 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(BS).view(BS, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -232,7 +261,12 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_images = compiler_options.pop("max_num_images", 1) @@ -285,9 +319,54 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "image_size_height": image_size_height, @@ -296,9 +375,17 @@ def get_specializations( "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - 
{ - "batch_size": batch_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "image_size_height": image_size_height, @@ -307,17 +394,28 @@ def get_specializations( "max_num_images": max_num_images, "img_size": img_size, "vision_size": vision_size, - }, - ] + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] + specializations = {} if kv_offload: specializations["vision"] = vision specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { @@ -327,11 +425,23 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "vision_embeds": {0: "batch_size", 1: "vision_size"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index ca23cc144..5edfb8f3a 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -140,6 +140,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -164,6 +165,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -196,6 +200,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: 
Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +231,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -256,6 +262,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -316,6 +323,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -354,6 +362,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -377,6 +386,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 735eec9e5..d2149b6bd 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -106,6 +106,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, use_cache: Optional[bool] = None, @@ -126,6 +127,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -166,19 +168,30 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): - inputs_embeds = self.model.get_input_embeddings()(input_ids) - vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) mask = input_ids == 
self.model.config.image_token_index indices1 = mask.to(torch.int64).cumsum(1) - 1 indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) indices0 = torch.arange(mask.shape[0]).view(-1, 1) image_features_expanded = vision_embeds.unsqueeze(0)[indices0, indices1] - inputs_embeds_1 = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) - outputs = self.model.model( - inputs_embeds=inputs_embeds_1, + image_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) + outputs = self.language_model( + inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, ) # Cast to int32 to avoid ONNXRT issue @@ -198,7 +211,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.get_input_embeddings()(input_ids) image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) image_features = self.get_image_features( @@ -219,6 +240,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -230,7 +252,13 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -270,10 +298,14 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( - config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -282,6 +314,11 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -298,7 +335,12 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: 
Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): if img_size is None and hasattr(self.config.vision_config, "image_size"): @@ -323,22 +365,70 @@ def get_specializations( "vision_size": vision_size, } ] - lang = [ - { - "batch_size": batch_size, + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size": img_size, + "vision_size": vision_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + lang.append(lang_prefill) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size": img_size, + "vision_size": vision_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "image_size": img_size, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "image_size": img_size, "vision_size": vision_size, - }, - ] + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + lang = [lang_prefill, lang_decode] specializations = {} @@ -351,7 +441,9 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -364,9 +456,21 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): "vision_embeds": {0: "vision_size"}, } + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} dynamic_axes = {} if 
kv_offload: diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 9b9e3448a..862714fea 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -160,6 +161,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -245,6 +249,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -282,6 +287,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -314,6 +320,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -375,6 +382,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, @@ -412,6 +420,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -435,6 +444,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index cb24f1de4..74de1c6c1 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -177,6 +177,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] 
= None, use_cache: bool = None, @@ -249,6 +250,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_embeddings: torch.Tensor = None, use_cache: bool = False, @@ -282,6 +284,9 @@ def forward( "batch_index": batch_index, "position_ids": position_ids, } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_self_attention_forward @@ -316,6 +321,7 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -350,6 +356,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -379,6 +386,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -396,13 +404,17 @@ def forward( key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! 
key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, - {"batch_index": batch_index, "position_ids": position_ids}, + cache_kwargs, ) elif past_key_value is not None: key_states, value_states = ( @@ -448,6 +460,7 @@ def forward( full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -461,6 +474,7 @@ def forward( attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, ) @@ -594,6 +608,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.FloatTensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, @@ -658,6 +673,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, ) @@ -688,6 +704,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.LongTensor] = None, cross_attention_mask: Optional[torch.LongTensor] = None, @@ -707,6 +724,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -774,6 +792,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -820,6 +839,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, inputs_embeds=inputs_embeds, cache_position=cache_position, @@ -853,6 +873,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -869,6 +890,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -879,7 +901,7 @@ def forward( logits = 
self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -943,6 +965,10 @@ def get_dummy_inputs(self, kv_offload: bool = False): lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + inputs = {} if kv_offload: @@ -959,6 +985,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -973,22 +1001,53 @@ def get_specializations( logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config") vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + ] + specializations = {} if kv_offload: @@ -998,7 +1057,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1023,6 +1082,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2a00577f2..236f6c9f5 100644 --- 
a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,15 +5,17 @@ # # ---------------------------------------------------------------------------- +import os import warnings from pathlib import Path from time import perf_counter -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union import numpy as np import torch import torch.nn as nn from transformers import ( + AutoImageProcessor, AutoModel, AutoModelForCausalLM, AutoModelForCTC, @@ -35,12 +37,21 @@ calculate_latency, get_compilation_dims, ) -from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH +from QEfficient.generation.vlm_generation import VisionLanguageGeneration +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + SPECIALIZED_PREFILL_ONLY_MODEL_ARCH, +) from QEfficient.transformers.models.pytorch_transforms import ( + BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyChunkedTransform, + PrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, + RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, VlmKVOffloadTransform, @@ -51,12 +62,15 @@ AwqToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, ) from QEfficient.utils import ( constants, get_padding_shape_from_config, ) +from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs class QEFFTransformersBase(QEFFBaseModel): @@ -119,21 +133,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path) - @property - def model_name(self) -> str: - """ - Get the name of the underlying HuggingFace model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - class MultimodalUtilityMixin: """ @@ -311,7 +310,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -323,6 +322,8 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. 
Defaults to False Returns ------- @@ -346,6 +347,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -358,6 +360,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -384,6 +387,8 @@ def compile( Number of cores to use for compilation. mxfp6_matmul : bool, optional Use MXFP6 compression for weights. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. These are passed directly to the underlying compilation command. @@ -427,6 +432,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -591,7 +597,7 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the vision encoder component to ONNX format. @@ -607,6 +613,8 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Directory path where the exported ONNX graph will be saved. Default is None. offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -614,7 +622,12 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the vision encoder. """ return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -627,6 +640,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -650,6 +664,8 @@ def compile( Number of cores to use for compilation. custom_io : Dict[str, str] Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : Additional compiler options passed to the underlying compilation command. @@ -667,24 +683,10 @@ def compile( mdp_ts_num_devices=mdp_ts_num_devices, aic_num_cores=aic_num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying vision encoder model. 
- - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -718,7 +720,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, **kwargs): + def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs): """ Initializes the language decoder component for multimodal models. @@ -726,14 +728,21 @@ def __init__(self, model, **kwargs): ---------- model : nn.Module The full HuggingFace multimodal model from which the language decoder is extracted. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. Supported keys include: + - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. **kwargs : Additional keyword arguments passed to the base class constructor. """ - super().__init__(model, **kwargs) + super().__init__(model, qaic_config=qaic_config, **kwargs) self.model = model.get_qeff_language_decoder() + self.model.qaic_config = qaic_config self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: + BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) + + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the language decoder component to ONNX format. @@ -749,6 +758,8 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Directory path where the exported ONNX graph will be saved. Default is None. offload_pt_weights : bool, optional If True, PyTorch weights will be offloaded after export. Default is True. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -756,7 +767,12 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt Path to the generated ONNX graph file for the language decoder. """ return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -769,6 +785,7 @@ def compile( mdp_ts_num_devices, aic_num_cores, custom_io, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -792,6 +809,8 @@ def compile( Number of cores to use for compilation. custom_io : Dict[str, str] Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : Additional compiler options passed to the underlying compilation command. 
@@ -809,24 +828,10 @@ def compile( mdp_ts_num_devices=mdp_ts_num_devices, aic_num_cores=aic_num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying language decoder model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -856,6 +861,8 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -865,39 +872,35 @@ def __init__( ---------- model : nn.Module The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : - Additional keyword arguments. `full_batch_size` is not supported here. - - Raises - ------ - NotImplementedError - If `full_batch_size` is provided. + Additional keyword arguments. """ if kwargs.pop("full_batch_size", None): - raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) self.model = model self.config = model.config + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) + self.continuous_batching = continuous_batching + self.ccl_enabled = False + if qaic_config: + self.ccl_enabled = qaic_config.get("ccl_enabled", False) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None self.input_shapes, self.output_names = None, None - - @property - def model_name(self) -> str: - """ - Get the name of the underlying multimodal model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): """ Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path. 
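The dual-QPC plumbing above threads `qaic_config` and `continuous_batching` from the public wrapper down into the vision and language components. As a rough orientation for reviewers, the intended call pattern looks like the sketch below; the model card name and config values are placeholders, not part of this diff, and the import path assumes the package's top-level exports:

```python
from QEfficient import QEFFAutoModelForImageTextToText

# Hypothetical values for illustration only.
qaic_config = {
    "ccl_enabled": True,   # opt in to compute-context-length (CCL) specializations
    "num_kv_blocks": 4,    # triggers BlockedKVAttentionTransform on the language decoder
}

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "org/example-vlm",          # placeholder HuggingFace model card
    kv_offload=True,            # dual-QPC path handled by the classes above
    continuous_batching=True,
    qaic_config=qaic_config,
)
```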
@@ -922,8 +925,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) @property def onnx_path(self): @@ -958,6 +967,7 @@ def qpc_path(self): def export( self, export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -970,6 +980,8 @@ def export( ---------- export_dir : str, optional Directory path where the exported ONNX graphs will be saved. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **kwargs : Additional keyword arguments. @@ -978,9 +990,37 @@ def export( ------- List[str] A list containing the paths to the generated ONNX graph files for both components. """ - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + # TODO: This is a temporary change, as continuous batching is enabled only for a few models. Once support is added for all models, this exception handling can be removed. + try: + inputs = self.model.get_dummy_inputs( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) + except TypeError: + inputs = self.model.get_dummy_inputs(kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode + ) output_names = self.model.get_output_names(kv_offload=True) + if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get( + "include_sampler", False + ): + logits_index = output_names["lang"].index("logits") + output_names["lang"][logits_index] = "next_tokens" + inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs( + example_inputs=inputs["lang"], + output_names=output_names["lang"], + dynamic_axes=dynamic_axes["lang"], + continuous_batching=self.continuous_batching, + vocab_size=self.model.language_model.config.vocab_size, + qaic_config=self.lang_model.model.qaic_config, + ) self.vision_model.export( inputs["vision"], @@ -988,9 +1028,15 @@ def export( dynamic_axes["vision"], export_dir=export_dir, offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, ) self.lang_model.export( - inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True + inputs["lang"], + output_names["lang"], + dynamic_axes["lang"], + export_dir=export_dir, + offload_pt_weights=True, + use_onnx_subfunctions=use_onnx_subfunctions, ) return self.onnx_path @@ -1003,6 +1049,8 @@ def compile( self, onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, *, prefill_seq_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None,
batch_size: int = 1, full_batch_size: Optional[int] = None, @@ -1011,9 +1059,9 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, skip_vision: Optional[bool] = False, skip_lang: Optional[bool] = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1053,6 +1101,8 @@ def compile( If True, skips compilation of the vision encoder. Default is False. skip_lang : bool, optional If True, skips compilation of the language decoder. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1068,23 +1118,47 @@ def compile( If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. If both `skip_lang` and `skip_vision` are True. """ - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." 
) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) + # if ccl_enabled is True read Compute-Context-Length lists + if self.ccl_enabled: + if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: + logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # For supporting VLLM and Disaggregated with CCL + elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + specializations, compiler_options = self.model.get_specializations( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -1109,7 +1183,14 @@ def compile( if (self.vision_model.onnx_path is None and vision_onnx_path is None) or ( self.lang_model.onnx_path is None and lang_onnx_path is None ): - self.export() + self.export( + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + # TODO: This should be removed once continuous batching is supported for all the models. + compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) if not skip_vision: self.vision_model._compile( @@ -1122,6 +1203,7 @@ def compile( aic_num_cores=num_cores, custom_io=custom_io_vision, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -1150,17 +1232,25 @@ def compile( aic_num_cores=num_cores, custom_io=custom_io_lang, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) return self.qpc_path def generate( self, - inputs: torch.Tensor, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor: Optional[AutoImageProcessor] = None, + images: List[str] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, generation_len: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, + **kwargs, ) -> Union[torch.Tensor, np.ndarray]: """ Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. @@ -1172,6 +1262,14 @@ def generate( inputs : Dict[str, Union[torch.Tensor, np.ndarray]] Inputs to run the execution, typically includes `pixel_values`, `input_ids`, `attention_mask`, etc. + tokenizer : PreTrainedTokenizer or PreTrainedTokenizerFast, optional + Tokenizer for the model. Used when images and prompts are provided. + processor : AutoImageProcessor, optional + Processor for the model. Used when images and prompts are provided.
+ images : List[str], optional + List of image paths or PIL images to process. + prompts : List[str], optional + List of text prompts corresponding to the images. streamer : TextStreamer, optional A streamer object to display generated tokens in real-time. Default is None. device_ids : List[int], optional @@ -1196,6 +1294,35 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + # Use VisionLanguageGeneration for image-prompt pairs + if (processor and images) or (tokenizer and prompts): + # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) + vlm_gen = VisionLanguageGeneration( + qeff_model=self, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + image_height=image_height, + image_width=image_width, + **kwargs, + ) + + # Call generate method + return vlm_gen.generate( + images=images, + prompts=prompts, + generation_len=generation_len, + stream=streamer is not None, + ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) @@ -1313,9 +1440,14 @@ def kv_offload_generate( vision_end = perf_counter() lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} - lang_inputs["position_ids"] = np.where( - lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 - ) # Need to use -1 as position_ids for invalid tokens + + if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" if not_mllama: @@ -1327,21 +1459,33 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) - # Prepare inputs for prefill - chunk_inputs = lang_inputs.copy() - prefill_start = perf_counter() + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ - :, i * prefill_seq_len : (i + 1) * prefill_seq_len + ..., i * prefill_seq_len : (i + 1) * prefill_seq_len ] outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] - prefill_time = perf_counter() - 
prefill_start + vision_end - vision_start + prefill_time = perf_counter() - lang_start + vision_end - vision_start # Skip inputs/outputs again lang_session.skip_buffers( [ @@ -1350,10 +1494,12 @@ def kv_offload_generate( if x.startswith("past_") or x.endswith("_RetainedState") ] ) + if not_mllama: + lang_session.skip_buffers(vision_outputs.keys()) # Get first token lang_inputs["input_ids"] = outputs["logits"].argmax(2) - lang_inputs["position_ids"] = input_len.numpy() + lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 if "cross_attention_mask" in lang_inputs: bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy() @@ -1363,8 +1509,27 @@ def kv_offload_generate( streamer.put(lang_inputs["input_ids"][0]) # Decode loop + if self.comp_ctx_lengths_decode is not None: + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + list_of_comp_ctx_lengths_decode = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode + ] + max_position_id = np.max(lang_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = lang_session.run(lang_inputs) # Prepare inputs for next iteration @@ -1414,6 +1579,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal def __init__( self, model: nn.Module, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -1423,18 +1589,28 @@ def __init__( ---------- model : nn.Module The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. Supported keys include: + - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. **kwargs : Additional keyword arguments. `full_batch_size` is not supported here. Raises ------ NotImplementedError - If `full_batch_size` is provided. + If `full_batch_size` is provided or `include_sampler` is True. """ if kwargs.pop("full_batch_size", None): + warnings.warn( + "full_batch_size argument is deprecated. 
Use continuous_batching=True instead.", DeprecationWarning, 2 + ) raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if qaic_config is not None and qaic_config.pop("include_sampler", False): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) + self.model.qaic_config = qaic_config + # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): self.model.config.llm_config.use_cache = True @@ -1446,11 +1622,19 @@ def __init__( else: self.model.config.use_cache = True self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.ccl_enabled = False + if qaic_config: + self.ccl_enabled = qaic_config.get("ccl_enabled", False) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + + if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: + BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) @classmethod def from_pretrained( cls, pretrained_model_name_or_path, + qaic_config: Optional[dict] = None, *args, **kwargs, ): @@ -1481,6 +1665,7 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + from transformers import AutoConfig config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) @@ -1488,11 +1673,17 @@ def from_pretrained( config.vision_config.use_flash_attn = "false" model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) def export( self, export_dir: Optional[str] = None, + use_onnx_subfunctions: bool = False, **kwargs, ) -> str: """ @@ -1510,10 +1701,16 @@ def export( str Path to the generated ONNX graph file. """ - inputs = self.model.get_dummy_inputs() - dynamic_axes = self.model.get_onnx_dynamic_axes() + inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=use_onnx_subfunctions, + ) def compile( self, @@ -1523,6 +1720,8 @@ def compile( *, prefill_seq_len: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, batch_size: int = 1, full_batch_size: Optional[int] = None, kv_cache_batch_size: Optional[int] = None, @@ -1531,6 +1730,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -1566,6 +1766,8 @@ def compile( Use MXINT8 compression for KV cache. Default is False. num_speculative_tokens : int, optional Not supported for this model; must be None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. 
Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -1585,14 +1787,32 @@ def compile( f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " ) + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names() + # if ccl_enabled is True read Compute-Context-Length lists + if self.ccl_enabled: + if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: + logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # For supporting VLLM and Disaggregated with CCL + elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # Get specializations from modelling file # TODO: expose this via the auto class as well specializations, compiler_options = self.model.get_specializations( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + kv_cache_batch_size=kv_cache_batch_size, img_size=img_size, **compiler_options, ) @@ -1611,6 +1831,11 @@ def compile( if output_name.endswith("_RetainedState"): custom_io[output_name] = "float16" if "pixel_values" in output_name else kv_cache_dtype + # TODO: This should be removed once continuous batching is supported for all the models.
+ compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -1623,6 +1848,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) return self.qpc_path @@ -1777,12 +2003,26 @@ def cloud_ai_100_generate( inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs["image_idx"] = np.array([[0]]) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + qpc_session.activate() chunk_inputs = inputs.copy() prefill_start = perf_counter() # Run prefill for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] outputs = qpc_session.run(chunk_inputs) @@ -1806,8 +2046,27 @@ def cloud_ai_100_generate( inputs.pop("pixel_values") # Decode loop + if self.comp_ctx_lengths_decode is not None: + list_of_comp_ctx_lengths_decode = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode + ] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = qpc_session.run(inputs) # Prepare inputs for next iteration inputs["input_ids"] = outputs["logits"].argmax(2) @@ -1832,21 +2091,6 @@ def cloud_ai_100_generate( ), ) - @property - def model_name(self) -> str: - """ - Get the name of the underlying multimodal model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname - @property def get_model_config(self) -> dict: """ @@ -1925,7 +2169,14 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__( + self, + model: nn.Module, + kv_offload: Optional[bool] = True, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): """ Instantiate the appropriate internal class for single or dual QPC mode. 
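Taken together with the `compile()` and `generate()` changes above, the end-to-end flow can be pictured with the following sketch, continuing the example given earlier; all shape values are illustrative, and `processor`/`tokenizer` are assumed to come from the matching HuggingFace `AutoProcessor`/`AutoTokenizer` for the same model card:

```python
# Continuing the sketch above (placeholder values throughout).
model.compile(
    prefill_seq_len=128,
    ctx_len=4096,
    full_batch_size=4,            # required once continuous_batching=True
    num_devices=1,
    num_cores=16,
    mxfp6_matmul=True,
    use_onnx_subfunctions=True,   # new flag: export modules as ONNX subfunctions
)

# generate() now also accepts raw images and prompts and routes them through
# VisionLanguageGeneration instead of requiring pre-built tensor inputs.
outputs = model.generate(
    processor=processor,
    tokenizer=tokenizer,
    images=["sample_image.jpg"],            # placeholder path
    prompts=["Describe this image."],
    generation_len=64,
)
```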
@@ -1946,13 +2197,22 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) The wrapped model instance, configured for either dual or single QPC. """ if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC( + model, continuous_batching, qaic_config=qaic_config, **kwargs + ) else: - return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) + return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. @@ -1964,6 +2224,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona If True, uses the dual QPC approach (vision encoder KV offloaded). If False, uses the single QPC approach (entire model in one QPC). If None, the default behavior of the internal classes is used (typically dual QPC). + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. **kwargs : Additional arguments passed to HuggingFace's ``from_pretrained``. @@ -1981,18 +2243,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona If `continuous_batching` is provided as True. """ # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. + if continuous_batching and not kv_offload: + NotImplementedError("Continuous batching is not supported for kv_offload = False") + if kwargs.get("attn_implementation", None) not in {None, "eager"}: logger.warning('Updating attn_implementation="eager"') if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - if kwargs.pop("continuous_batching", None): - NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + continuous_batching=continuous_batching, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -2027,18 +2296,39 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + def prefill( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, + ): + if enable: + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) + + else: + if retain_full_kv: + self.model, tf = 
RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + def __init__( self, model: nn.Module, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, **kwargs, ): """ @@ -2058,6 +2348,9 @@ def __init__( - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens. For Speculative Decoding Target Language Models, this is always True. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. + - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. **kwargs : Additional keyword arguments passed to the base class constructor. @@ -2084,6 +2377,8 @@ def __init__( ) # Set use_cache=True to get KV values as output during ONNX export model.config.use_cache = True + + setattr(model.config, "max_seq_len_cached", max_seq_len_cached) super().__init__(model, qaic_config=qaic_config, **kwargs) self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching @@ -2092,6 +2387,11 @@ def __init__( self.is_tlm = transformed self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.ccl_enabled = False + if qaic_config: + self.ccl_enabled = qaic_config.get("ccl_enabled", False) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + self.hash_params["max_seq_len_cached"] = max_seq_len_cached # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -2103,20 +2403,8 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True - @property - def model_name(self) -> str: - """ - Get the name of the underlying Causal Language Model. - - Returns - ------- - str - The model's class name, with "QEff" or "QEFF" prefix removed if present. - """ - mname = self.model.__class__.__name__ - if mname.startswith("QEff") or mname.startswith("QEFF"): - mname = mname[4:] - return mname + if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: + BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -2128,6 +2416,7 @@ def from_pretrained( pretrained_model_name_or_path, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, *args, **kwargs, ): @@ -2157,6 +2446,8 @@ def from_pretrained( and ``return_pdfs=False`` for regular model. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. The values provided in ``top_ks`` tensor must be less than this maximum limit. + - **include_guided_decoding** (bool): If True, enables guided token-level filtering + during decoding. Only works when include_sampler=True. *args : Positional arguments passed directly to `cls._hf_auto_class.from_pretrained`. 
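# Illustrative sketch (assumed values; not every combination of options is
# necessarily supported together): a qaic_config dictionary using the keys
# documented above for QEFFAutoModelForCausalLM.from_pretrained.
from QEfficient import QEFFAutoModelForCausalLM

qaic_config = {
    "include_sampler": True,           # on-device sampling
    "return_pdfs": False,              # return sampled tokens only
    "max_top_k_ids": 512,              # upper bound for values passed in top_ks
    "include_guided_decoding": False,  # token-level filtering; needs include_sampler
    "num_kv_blocks": 16,               # switches on the BlockedKV attention transform
    "ccl_enabled": True,               # auto-generate CCL lists at compile time
}

model = QEFFAutoModelForCausalLM.from_pretrained(
    "gpt2",                            # placeholder model card
    continuous_batching=False,
    qaic_config=qaic_config,
)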
@@ -2191,16 +2482,21 @@ def from_pretrained( qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class - if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + continuous_batching=continuous_batching, + **kwargs, ) return cls( model, continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + max_seq_len_cached=max_seq_len_cached, **kwargs, ) @@ -2216,7 +2512,56 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def get_seq_len_and_handle_specialized_prefill_model( + self, prefill_seq_len: Optional[int] = None, enable_chunking=False + ) -> int: + self.hash_params["prefill_only"] = True + if enable_chunking: + self.hash_params["chunking"] = True + return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + + num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) + if num_q_blocks is None: + block_size = 256 + if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " + f"Or set `NUM_Q_BLOCKS` ENV variable" + f"Received: prefill_seq_len={prefill_seq_len}" + ) + + num_q_blocks = prefill_seq_len // block_size + logger.warning( + f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override" + ) + os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks) + num_q_blocks = int(num_q_blocks) + + num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None) + num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks + min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks + if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0: + raise ValueError( + f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but," + "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values." + ) + + self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks + self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks + self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0") + return ( + min_seq_len + if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + ) + + def export( + self, + export_dir: Optional[str] = None, + prefill_only: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + **kwargs, + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2229,7 +2574,8 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. - + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. 
Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- str @@ -2241,6 +2587,33 @@ def export(self, export_dir: Optional[str] = None) -> str: kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) + enable_chunking = kwargs.get("enable_chunking", False) + if prefill_only: + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + self.prefill(enable=True, enable_chunking=enable_chunking) + self.hash_params.pop("retain_full_kv", None) + seq_len = ( + self.get_seq_len_and_handle_specialized_prefill_model( + prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking + ) + if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH + else seq_len + ) + kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len + else: + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + self.hash_params.pop("prefill_only", None) + self.hash_params.pop("NUM_Q_BLOCKS", None) + self.hash_params.pop("NUM_FFN_BLOCKS", None) + self.hash_params.pop("ENABLE_OPT_SWA", None) + self.hash_params.pop("chunking", None) + if kwargs.get("retain_full_kv", False): + kv_cache_shape[2] = seq_len + self.model.config.sliding_window + self.hash_params["retain_full_kv"] = True + example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), @@ -2250,6 +2623,10 @@ def export(self, export_dir: Optional[str] = None) -> str: "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + if self.comp_ctx_lengths_prefill is not None: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int8) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { 0: "full_batch_size" if self.continuous_batching else "batch_size", @@ -2283,10 +2660,26 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names.append(f"past_{kv}.{i}_RetainedState") else: + # HACK: create common function for this including above if condition code + pkv_dynamic_axes = ( + self.model.get_pkv_dynamic_axes( + retain_full_kv=kwargs.get("retain_full_kv", False) + or (prefill_only and kwargs.get("enable_chunking", False)), + continuous_batching=self.continuous_batching, + ) + if hasattr(self.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * self.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) + for i in range(self.num_layers): for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] output_names.append(f"past_{kv}.{i}_RetainedState") if self.continuous_batching: @@ -2299,105 +2692,33 @@ def export(self, export_dir: Optional[str] = None) -> str: dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"} if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False): - example_inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs( + 
example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs( example_inputs=example_inputs, output_names=output_names, dynamic_axes=dynamic_axes, + continuous_batching=self.continuous_batching, + vocab_size=self.model.config.vocab_size, + qaic_config=self.model.qaic_config, ) - return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + offload_pt_weights=kwargs.get("offload_pt_weights", True), + prefill_only=prefill_only, ) - def get_sampling_inputs_and_outputs( - self, - example_inputs: Dict[str, torch.Tensor], - output_names: List[str], - dynamic_axes: Dict[str, Dict[int, str]], - ): - """ - Updates the example inputs, output names, and dynamic axes to include - parameters relevant for on-device sampling during ONNX export. - - Parameters - ---------- - example_inputs : Dict[str, torch.Tensor] - Current dictionary of example inputs. - output_names : List[str] - Current list of output names. - dynamic_axes : Dict[str, Dict[int, str]] - Current dictionary of dynamic axes configurations. - - Returns - ------- - Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] - Updated example inputs, output names, and dynamic axes including - sampling-related parameters. - """ - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - - example_inputs["last_accepted_output_tokens"] = torch.zeros( - (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 - ) - dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} - - example_inputs["past_repetition_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_repetition_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_repetition_penalty_buffer_RetainedState") - - example_inputs["repetition_penalties"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES - ) - dynamic_axes["repetition_penalties"] = {0: "batch_size"} - - example_inputs["past_presence_penalty_buffer"] = torch.zeros( - (fbs if self.continuous_batching else bs, self.model.config.vocab_size), dtype=torch.bool - ) - dynamic_axes["past_presence_penalty_buffer"] = { - 0: "full_batch_size" if self.continuous_batching else "batch_size", - } - output_names.append("past_presence_penalty_buffer_RetainedState") - - example_inputs["presence_penalties"] = ( - torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES - ) - dynamic_axes["presence_penalties"] = {0: "batch_size"} - - example_inputs["temperatures"] = ( - torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES - ) - dynamic_axes["temperatures"] = {0: "batch_size"} - - max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) - example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) - dynamic_axes["top_ks"] = {0: "batch_size"} - - example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS - dynamic_axes["top_ps"] = {0: "batch_size"} - - example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS - dynamic_axes["min_ps"] = {0: "batch_size"} - - example_inputs["random_numbers"] = torch.rand((bs, 1), 
dtype=torch.float) - dynamic_axes["random_numbers"] = {0: "batch_size"} - - return example_inputs, output_names, dynamic_axes - def build_prefill_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the prefill phase. @@ -2420,12 +2741,27 @@ def build_prefill_specialization( Dict[str, Union[int, str]] A dictionary defining the prefill specialization. """ - spec = { - "batch_size": 1 if self.continuous_batching else batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_logits_to_keep": 1 if self.is_tlm else None, - } + if prefill_seq_len == 1 and self.continuous_batching: + exec_batch_size = full_batch_size + else: + exec_batch_size = 1 if self.continuous_batching else batch_size + + if hasattr(self.model, "get_specializations"): + spec = self.model.get_specializations( + batch_size=exec_batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + **kwargs, + )[0] + else: + spec = { + "batch_size": 1 if self.continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + spec["num_logits_to_keep"] = 1 if self.is_tlm else None if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -2438,10 +2774,12 @@ def build_decode_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the decode phase. 
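# Illustrative sketch (assumed values): the shape of the specialization dicts the
# two builders above produce for a non-continuous-batching, non-TLM model with
# prefill_seq_len=128 and ctx_len=4096. When CCL is enabled, a comp_ctx_lengths
# entry is added and one such dict is emitted per CCL bucket.
prefill_spec = {
    "batch_size": 1,
    "seq_len": 128,              # prefill processes prefill_seq_len tokens per chunk
    "ctx_len": 4096,
    "num_logits_to_keep": None,  # only set for target LMs (speculative decoding)
}
decode_spec = {
    "batch_size": 1,
    "seq_len": 1,                # decode emits one token at a time
    "ctx_len": 4096,
    "num_logits_to_keep": None,
}
specializations = [prefill_spec, decode_spec]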
@@ -2469,12 +2807,23 @@ def build_decode_specialization( """ if prefill_seq_len == 1 and not self.continuous_batching: return None # Avoid duplication with prefill - spec = { - "batch_size": full_batch_size if self.continuous_batching else batch_size, - "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, - "ctx_len": ctx_len, - "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, - } + + if hasattr(self.model, "get_specializations"): + spec = self.model.get_specializations( + batch_size=full_batch_size if self.continuous_batching else batch_size, + prefill_seq_len=(num_speculative_tokens + 1) if self.is_tlm else 1, + ctx_len=ctx_len, + )[1] + else: + spec = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, + "ctx_len": ctx_len, + } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + + spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -2489,6 +2838,8 @@ def compile( *, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, batch_size: int = 1, full_batch_size: Optional[int] = None, kv_cache_batch_size: Optional[int] = None, @@ -2498,6 +2849,10 @@ def compile( mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, + use_onnx_subfunctions: bool = False, + offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -2539,6 +2894,8 @@ def compile( prefill_only : bool, optional If True, compiles only for the prefill stage. If False, compiles only for the decode stage. If None, compiles for both stages. Default is None. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC or QNN compilers. @@ -2576,6 +2933,51 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" + if (kv_cache_batch_size or full_batch_size) and not self.continuous_batching: + logger.warning( + "`kv_cache_batch_size` or `full_batch_size` is being passed" + "This will be ignored as `continuous_batching` is set to `False` in `from_pretrained`" + ) + + if prefill_only is None or not prefill_only: + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + else: + if self.continuous_batching and kv_cache_batch_size is None and full_batch_size is None: + raise ValueError( + "Please pass valid integer for kv_cache_batch_size or full_batch_size, both have same meaning, as continuous_batching is enabled for prefill-only model" + ) + + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size + + # if ccl_enabled is True read Compute-Context-Length lists + if self.ccl_enabled: + if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: + logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # For supporting VLLM and Disaggregated with CCL + elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: + if isinstance(comp_ctx_lengths_prefill, str): + import ast + + try: + # Safely evaluate the string to a Python list for disaggregated input + self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill) + self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode) + + except (ValueError, SyntaxError): + raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.") + else: + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode + + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) # --- Validation --- if prefill_only is not None and not isinstance(prefill_only, bool): raise TypeError("`prefill_only` must be a boolean.") @@ -2583,15 +2985,6 @@ def compile( if self.is_tlm: num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) - if self.continuous_batching and full_batch_size is None: - raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") - - if kv_cache_batch_size and not full_batch_size: - raise ValueError( - "KV caching requires continuous batching. Please set `full_batch_size` and " - "enable `continuous_batching=True` in `from_pretrained`." 
- ) - if ( self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False) @@ -2600,32 +2993,67 @@ def compile( ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") - # Infer kv_cache_batch_size if not provided - kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size - # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: - specializations.append( - self.build_prefill_specialization( + # TODO: we are handling decode-only case inside prefill call which is utterly mis-leading + if self.comp_ctx_lengths_prefill is not None: + # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization + for i in range(0, len(self.comp_ctx_lengths_prefill)): + if prefill_only or enable_chunking: + raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL") + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_prefill[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + else: + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + ) + ) + + if prefill_only is None or not prefill_only: + if self.comp_ctx_lengths_decode is not None: + # Adding elements from self.comp_ctx_lengths_decode to decode_specialization + for i in range(0, len(self.comp_ctx_lengths_decode)): + decode_spec = self.build_decode_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_decode[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) + + else: + decode_spec = self.build_decode_specialization( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, ) - ) - if prefill_only is None or not prefill_only: - decode_spec = self.build_decode_specialization( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - kv_cache_batch_size=kv_cache_batch_size, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - ) - if decode_spec: - specializations.append(decode_spec) + if decode_spec: + specializations.append(decode_spec) # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -2635,7 +3063,6 @@ def compile( for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -2649,6 +3076,11 @@ def compile( num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + offload_pt_weights=offload_pt_weights, + enable_chunking=enable_chunking, + retain_full_kv=retain_full_kv, **compiler_options, ) @@ -2700,9 +3132,11 @@ def generate( raise TypeError("Please run compile 
API first!") generation_len = kwargs.pop("generation_len", None) return QEfficient.cloud_ai_100_exec_kv( - tokenizer, - self.qpc_path, + tokenizer=tokenizer, + qpc_path=self.qpc_path, prompt=prompts, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, device_id=device_id, generation_len=generation_len, automation=kwargs.pop("automation", False), @@ -2838,7 +3272,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2850,6 +3284,8 @@ def export(self, export_dir: Optional[str] = None) -> str: export_dir : str, optional Directory path where the exported ONNX graph will be saved. If not provided, the default export directory is used. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns ------- @@ -2859,7 +3295,13 @@ def export(self, export_dir: Optional[str] = None) -> str: inputs = self.model.get_dummy_inputs() dynamic_axes = self.model.get_onnx_dynamic_axes() output_names = self.model.get_output_names() - return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) + return self._export( + inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) def compile( self, @@ -2877,6 +3319,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -2918,6 +3361,8 @@ def compile( Not yet supported for this model. num_speculative_tokens : int, optional Not yet supported for this model. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False **compiler_options : dict Additional compiler options for QAIC. @@ -2985,6 +3430,7 @@ def compile( mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) @@ -3202,12 +3648,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. ``Optional`` Args: :export_dir (str, optional): The directory path to store ONNX-graph. + :use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False Returns: :str: Path of the generated ``ONNX`` graph. 
@@ -3228,6 +3676,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -3240,6 +3689,7 @@ def compile( num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, + use_onnx_subfunctions: bool = False, **compiler_options, ) -> str: """ @@ -3255,6 +3705,7 @@ def compile( :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. :num_cores (int): Number of cores used to compile the model. :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``. + :use_onnx_subfunctions: bool, optional: whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False :compiler_options (dict, optional): Additional compiler options. For QAIC Compiler: Extra arguments for qaic-exec can be passed. @@ -3287,6 +3738,7 @@ def compile( mxfp6_matmul=mxfp6_matmul, mdp_ts_num_devices=num_devices, aic_num_cores=num_cores, + use_onnx_subfunctions=use_onnx_subfunctions, **compiler_options, ) diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index 4f92316ca..b686e6aed 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -43,14 +43,14 @@ def eager_attention_forward( if num_q_heads != num_kv_heads: assert num_q_heads % num_kv_heads == 0 repeat_factor = num_q_heads // num_kv_heads - _, _, S, D = k.shape + B, _, S, D = k.shape k = k.unsqueeze(2) k = k.expand(-1, -1, repeat_factor, -1, -1) - k = k.reshape(1, num_q_heads, S, D) + k = k.reshape(B, num_q_heads, S, D) v = v.unsqueeze(2) v = v.expand(-1, -1, repeat_factor, -1, -1) - v = v.reshape(1, num_q_heads, S, D) + v = v.reshape(B, num_q_heads, S, D) attn_weights = torch.matmul(q, k.transpose(2, 3)) * scale_factor @@ -243,6 +243,7 @@ def attention( attention_bias: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, @@ -279,7 +280,15 @@ def attention( if layer_past is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } + if comp_ctx_lengths is not None: + attention_bias = attention_bias[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_bias.shape[-1] k, v = layer_past.update(k, v, self.layer_id, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -311,6 +320,7 @@ def forward( attention_bias: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, @@ -334,6 +344,7 @@ def forward( attention_bias, position_ids=position_ids, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, ) @@ -380,6 +391,7 @@ def forward( subsegment_ids: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, last_logits_only: bool = False, @@ -496,6 +508,7 @@ def forward( attention_bias=causal_mask, position_ids=position_ids, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, ) @@ -518,6 +531,7 @@ def forward( attention_bias=causal_mask, position_ids=position_ids, layers_past=layers_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -574,7 +588,16 @@ def __init__(self, model): # self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + batch_index: Optional[torch.LongTensor] = None, + ): if input_ids is not None: input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) inputs_embeds = self.model.model.transformer.wte(input_ids) @@ -587,7 +610,12 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.model.forward( - input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + input_embeddings=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -608,7 +636,16 @@ def get_qeff_language_decoder(self): """ def forward( - self, pixel_values, image_masks, image_input_idx, valid_idx, input_ids, position_ids, image_idx, past_key_values + self, + pixel_values, + image_masks, + image_input_idx, + valid_idx, + input_ids, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): image_features, _ = self.model.vision_backbone(pixel_values, image_masks) num_image, num_patch = image_features.shape[1:3] @@ -637,7 +674,11 @@ def forward( inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.forward( - input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + input_embeddings=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -651,8 +692,13 @@ def get_specializations( ctx_len: int, num_images: int = None, img_size: int = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, valid_size: int = None, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): prefill_seq_len = prefill_seq_len if prefill_seq_len else 1024 @@ -679,30 +725,108 @@ def get_specializations( } ] - lang_prefill = { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - 
"ctx_len": ctx_len, - "valid_size": valid_size, - } + if comp_ctx_lengths_prefill is not None and comp_ctx_lengths_decode is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "valid_size": valid_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_prefill[key] = value + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "valid_size": valid_size, + "vision_batch_size": batch_size, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_decode[key] = value + + lang.append(lang_decode) - lang_decode = {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "valid_size": valid_size} + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "valid_size": valid_size, + "vision_batch_size": batch_size, + } - if kv_offload: - values = { - "img_size": img_size, - "img_tile": img_tile, - "num_images": num_images, - "num_patch": num_patch, + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "valid_size": valid_size, + "vision_batch_size": batch_size, } - for key, value in values.items(): - lang_prefill[key] = value - lang_decode[key] = value + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_prefill[key] = value + lang_decode[key] = value + + lang = [lang_prefill, lang_decode] - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) specializations = {} if kv_offload: @@ -712,13 +836,15 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - 
lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "valid_size"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "valid_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 1: "num_images", 2: "img_tile", 3: "img_size"} vision_dynamic_axes["image_input_idx"] = {0: "batch_size", 1: "num_images", 2: "num_patch"} @@ -728,8 +854,20 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): num_layers = self.model.config.n_layers for i in range(num_layers): - lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} - lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} dynamic_axes = {} if kv_offload: @@ -760,7 +898,13 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): inputs_shapes = {} inputs_shapes_lang = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -811,10 +955,14 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV kv_cache_shape = get_padding_shape_from_config( config=self.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -823,6 +971,11 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 9bf6a4422..c1d98c1f8 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -39,6 +39,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, ): @@ -52,6 +53,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, 
value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale @@ -101,6 +105,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, use_cache: bool = False, output_attentions: bool = False, ): @@ -118,6 +123,7 @@ def forward( batch_index=batch_index, attention_mask=attention_mask, past_key_value=layer_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -144,6 +150,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -205,6 +212,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -250,6 +258,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -271,6 +280,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 6dae7ac84..00755cae5 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -27,6 +27,7 @@ from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding): @@ -109,7 +110,9 @@ def eager_attention_forward( attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: - attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -129,6 +132,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -153,6 +157,9 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"batch_index": batch_index, "position_ids": 
position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -185,6 +192,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -200,6 +208,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -230,6 +239,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -283,6 +293,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -319,6 +330,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -340,6 +352,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index 18557f1ca..4bf2e8785 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -67,6 +67,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -105,7 +106,15 @@ def forward( if past_key_value is not None: # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -140,6 +149,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + 
comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, @@ -181,6 +191,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -213,6 +224,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -274,6 +286,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -316,6 +329,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -370,6 +384,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 4b5234a5a..b97a0ab8d 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -140,6 +140,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, position_ids=Optional[torch.Tensor], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -166,6 +167,9 @@ def forward( "batch_index": batch_index, "position_ids": position_ids, } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -198,6 +202,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -235,6 +240,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -265,6 +271,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: 
Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +321,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -350,6 +358,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -366,6 +375,7 @@ def forward( batch_index=batch_index, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index c910ab387..b978b6193 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import warnings +from functools import partial from types import MethodType from typing import Callable, Optional, Tuple, Union @@ -51,9 +52,19 @@ GPTBigCodeForCausalLM, GPTBigCodeModel, ) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRMSNorm, +) from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel from transformers.models.granite.modeling_granite import ( GraniteAttention, + GraniteDecoderLayer, GraniteForCausalLM, GraniteModel, GraniteRMSNorm, @@ -152,6 +163,18 @@ Qwen2Model, Qwen2RMSNorm, ) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLDecoderLayer, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5_VLModel, + Qwen2_5_VLTextModel, + Qwen2_5_VLVisionAttention, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2RMSNorm as Qwen2_5RMSNorm, +) from transformers.models.qwen3.modeling_qwen3 import ( Qwen3Attention, Qwen3DecoderLayer, @@ -174,6 +197,10 @@ Starcoder2ForCausalLM, Starcoder2Model, ) +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerNorm, +) from transformers.models.whisper.modeling_whisper import ( WhisperAttention, WhisperDecoder, @@ -215,6 +242,7 @@ QEffGemma3Attention, QEffGemma3CustomRMSNormAIC, QEffGemma3DecoderLayer, + QEffGemma3DecoderWrapper, QEffGemma3ForCausalLMModel, QEffGemma3ForConditionalGeneration, QEffGemma3TextModel, @@ -231,6 +259,19 @@ QEffGPTBigCodeForCausalLM, QEffGPTBigCodeModel, ) +from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssAttention, + QEffGptOssDecoderLayer, + QEffGptOssExperts, + QEffGptOssForCausalLM, + QEffGptOssMLP, + QEffGptOssModel, + QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyChunkedGptOssMLP, + QEffPrefillOnlyGptOssAttention, + QEffPrefillOnlyGptOssMLP, + QEffPrefillOnlyGptOssModel, +) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, QEffGPTJBlock, @@ -239,6 +280,7 @@ ) from 
QEfficient.transformers.models.granite.modeling_granite import ( QEffGraniteAttention, + QEffGraniteDecoderLayer, QEffGraniteForCausalLM, QEffGraniteModel, ) @@ -260,6 +302,7 @@ QEffGrok1MultiHeadAttention, ) from QEfficient.transformers.models.internvl.modeling_internvl import ( + QEffInternDecoderWrapper, QEffInternVisionEmbeddings, QEffInternVLModel, ) @@ -271,6 +314,7 @@ QEffLlamaRotaryEmbedding, ) from QEfficient.transformers.models.llama4.modeling_llama4 import ( + QEffLlama4DecoderWrapper, QEffLlama4ForCausalLM, QEffLlama4ForConditionalGeneration, QEffLlama4Router, @@ -283,9 +327,11 @@ QEffLlama4VisionModel, ) from QEfficient.transformers.models.llava.modeling_llava import ( + QEFFLlavaDecoderWrapper, QEffLlavaForConditionalGeneration, ) from QEfficient.transformers.models.llava_next.modeling_llava_next import ( + QEffLlavaNextDecoderWrapper, QEffLlavaNextForConditionalGeneration, ) from QEfficient.transformers.models.mistral.modeling_mistral import ( @@ -356,6 +402,16 @@ QEffQwen2ForCausalLM, QEffQwen2Model, ) +from QEfficient.transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + QEffQwen2_5_VisionTransformerPretrainedModel, + QEffQwen2_5_VLAttention, + QEffQwen2_5_VLDecoderLayer, + QEffQwen2_5_VLModel, + QEffQwen2_5_VLTextModel, + QEffQwen2_5_VLVisionAttention, + QEffQwen_2_5_vl_DecoderWrapper, + QEffQwen_2_5_vl_ForConditionalGeneration, +) from QEfficient.transformers.models.qwen3.modeling_qwen3 import ( QEffQwen3Attention, QEffQwen3DecoderLayer, @@ -376,6 +432,10 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.models.t5.modeling_t5 import ( + QEffT5Attention, + QEffT5LayerNorm, +) from QEfficient.transformers.models.whisper.modeling_whisper import ( QEffWhisperAttention, QEffWhisperDecoder, @@ -396,6 +456,7 @@ class CustomOpsTransform(ModuleMappingTransform): _module_mapping = { GemmaRMSNorm: GemmaCustomRMSNormAIC, Gemma2RMSNorm: GemmaCustomRMSNormAIC, + GptOssRMSNorm: CustomRMSNormAIC, LlamaRMSNorm: CustomRMSNormAIC, Llama4TextRMSNorm: CustomRMSNormAIC, MistralRMSNorm: CustomRMSNormAIC, @@ -404,6 +465,7 @@ class CustomOpsTransform(ModuleMappingTransform): Phi3RMSNorm: CustomRMSNormAIC, Qwen2RMSNorm: CustomRMSNormAIC, Qwen3RMSNorm: CustomRMSNormAIC, + Qwen2_5RMSNorm: CustomRMSNormAIC, MllamaTextRMSNorm: CustomRMSNormAIC, GraniteRMSNorm: CustomRMSNormAIC, PixtralRMSNorm: CustomRMSNormAIC, @@ -480,10 +542,18 @@ class KVCacheTransform(ModuleMappingTransform): Gemma3TextModel: QEffGemma3TextModel, Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # GPT_OSS + GptOssAttention: QEffGptOssAttention, + GptOssDecoderLayer: QEffGptOssDecoderLayer, + GptOssModel: QEffGptOssModel, + GptOssForCausalLM: QEffGptOssForCausalLM, + GptOssMLP: QEffGptOssMLP, + GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, GraniteAttention: QEffGraniteAttention, + GraniteDecoderLayer: QEffGraniteDecoderLayer, # GraniteMoe GraniteMoeModel: QEffGraniteMoeModel, GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM, @@ -544,6 +614,14 @@ class KVCacheTransform(ModuleMappingTransform): Qwen3DecoderLayer: QEffQwen3DecoderLayer, Qwen3Model: QEffQwen3Model, Qwen3ForCausalLM: QEffQwen3ForCausalLM, + # Qwen2.5 VL + Qwen2_5_VLForConditionalGeneration: QEffQwen_2_5_vl_ForConditionalGeneration, + Qwen2_5_VLModel: QEffQwen2_5_VLModel, + Qwen2_5_VLAttention: QEffQwen2_5_VLAttention, + Qwen2_5_VLDecoderLayer: QEffQwen2_5_VLDecoderLayer, + 
Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention, + Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, # Starcoder2 Starcoder2Attention: QEffStarcoder2Attention, Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, @@ -575,6 +653,39 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +class PrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, + } + + +class PrefillOnlyChunkedTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, + } + + +class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, + QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, + } + + +class RevertPrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()}, + **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}, + } + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. @@ -646,8 +757,22 @@ class SamplerTransform: # supported architectures _module_mapping = { - # Llama + QEffFalconForCausalLM, + QEffGemmaForCausalLM, + QEffGemma3DecoderWrapper, + QEffGPT2LMHeadModel, + QEffGPTJForCausalLM, + QEffGraniteForCausalLM, + QEffGraniteMoeForCausalLM, + QEffInternDecoderWrapper, QEffLlamaForCausalLM, + QEffLlama4DecoderWrapper, + QEFFLlavaDecoderWrapper, + QEffLlavaNextDecoderWrapper, + QEffMptForCausalLM, + QEffPhi3ForCausalLM, + QEffQwen2ForCausalLM, + QEffQwen_2_5_vl_DecoderWrapper, } @classmethod @@ -741,6 +866,14 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): _match_class_replace_method = {} +class T5ModelTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + T5Attention: QEffT5Attention, + T5LayerNorm: QEffT5LayerNorm, + } + + class PoolingTransform: """ Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. @@ -758,3 +891,49 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu model = PooledModel(model, pooling_method) warnings.warn("Pooling is applied to the model.") return model, transformed + + +def get_decoder_layer_classes_for_export(model: nn.Module) -> set: + """ + Dynamically determine which DecoderLayer classes should be exported as functions + based on the model's architecture using the existing KVCacheTransform mapping. 
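+    A QEff mapping target qualifies when its class name contains one of the decoder-layer
+    patterns ("DecoderLayer", "Block", "Layer") and at least one module of that class is
+    actually present in the given model.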
+ """ + # Define patterns that identify decoder layer classes + DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"] + + # Get all QEff classes that are decoder layers from the existing mapping + decoder_layer_classes = set() + + for original_class, qeff_class in KVCacheTransform._module_mapping.items(): + # Check if the QEff class name contains decoder layer patterns + qeff_class_name = qeff_class.__name__ + if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS): + decoder_layer_classes.add(qeff_class) + + # Filter to only include classes that are actually used in the current model + model_decoder_classes = set() + for module in model.modules(): + if module.__class__ in decoder_layer_classes: + model_decoder_classes.add(module.__class__) + + return model_decoder_classes + + +class BlockedKVAttentionTransform: + _module_mapping = { + QEffLlamaAttention, + QEffQwen2_5_VLAttention, + } + + @classmethod + def apply(cls, model: nn.Module, num_kv_blocks) -> Tuple[nn.Module, bool]: + transformed = False + for module in model.modules(): + if type(module) in cls._module_mapping: + repl_module = type(module) + module.__class__ = repl_module + module.forward = MethodType(partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module) + transformed = True # Set to True if at least one transformation occurs + elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping: + warnings.warn(f"KV blocking is not yet supported for {type(module)}.") + return model, transformed diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 24e8df46c..7c093a4b0 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -150,6 +150,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -167,6 +168,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -200,6 +204,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -231,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -261,6 +267,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ 
-313,6 +320,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -348,6 +356,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -364,6 +373,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2_5_vl/__init__.py b/QEfficient/transformers/models/qwen2_5_vl/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/qwen2_5_vl/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py new file mode 100644 index 000000000..21d2e026e --- /dev/null +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -0,0 +1,1278 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLConfig, + Qwen2_5_VLDecoderLayer, + Qwen2_5_VLModelOutputWithPast, + Qwen2_5_VLRotaryEmbedding, + Qwen2_5_VLTextModel, + Qwen2_5_VLVisionAttention, + apply_rotary_pos_emb_vision, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache + +# from transformers import Qw +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. 
The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position indices (temporal, + height and width) of text embedding are always the same, so the text embedding rotary position embedding is no + different from modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offset position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding. + """ + + mrope_section = mrope_section * 2 + cos = cos[position_ids] + sin = sin[position_ids] + + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class QEffQwen2_5_VLVisionAttention(Qwen2_5_VLVisionAttention): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory."
+ ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + + # Create index grids + seq_len = attention_mask.shape[-1] + rows = torch.arange(seq_len).view(1, -1) + cols = torch.arange(seq_len).view(-1, 1) + + # Prepare start and end indices + start = cu_seqlens[:-1].view(-1, 1, 1) + end = cu_seqlens[1:].view(-1, 1, 1) + + # Create block masks using broadcasting + row_mask = (rows >= start) & (rows < end) + col_mask = (cols >= start) & (cols < end) + block_mask = row_mask & col_mask # shape: (num_blocks, seq_len, seq_len) + + # Combine all blocks into one mask + final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask[block_mask.any(dim=0)] = 0 + + final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) + + attention_mask[0] = final_mask + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class QEffQwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel): + def rot_pos_emb(self, grid_thw): + pos_ids = [] + + bs, t, h, w = grid_thw.shape + + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + + x_expanded = pos_ids.unsqueeze(0) + x_expanded = x_expanded.expand(bs, -1, -1) + pos_ids = x_expanded.reshape(-1, pos_ids.size(1)) + + max_grid_size = max(grid_thw.shape) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + bs, grid_t, grid_h, grid_w = grid_thw.shape + + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + + index_padded = 
index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + + x_expanded = seqlens.unsqueeze(0) + x_expanded = x_expanded.expand(bs, -1) + seqlens = x_expanded.reshape(-1) + + index_padded = index_padded.reshape(-1) + + mask = (index_padded == -100).to(torch.int32) + + if torch.jit.is_tracing(): + order = torch.argsort(mask) + else: + order = torch.argsort(mask, stable=True) + + index_new = index_padded[order] + index_new = index_new[: index.reshape(-1).size(0)] + + step = grid_t * llm_grid_h * llm_grid_w + batch_indices = torch.arange(bs) + batch_indices = batch_indices.view(-1, 1) + offsets = batch_indices * step + window_index_tmp = index_new.unsqueeze(0) + offsets + window_index = window_index_tmp.reshape(-1) + + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + + cu_window_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens_tmp.dtype), cu_seqlens_tmp]) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + + cu_window_seqlens = cu_window_seqlens.to( + device=hidden_states.device, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32 + ) + + # cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + + hidden_states = hidden_states[window_index, :, :] + + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + bs, t, h, w = grid_thw.shape + + t = torch.arange(t, t + 1).squeeze().expand(bs) + h = torch.arange(h, h + 1).squeeze().expand(bs) + w = torch.arange(w, w + 1).squeeze().expand(bs) + + cu_seqlens = (h * w).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + + cu_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), cu_seqlens]) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + 
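+        # The patch merger fuses neighbouring patch tokens (spatial_merge_unit at a time) and projects them
+        # to the language-model hidden size; argsort(window_index) then restores the original patch order
+        # that was shuffled above for windowed attention.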
hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class QEffQwen2_5_VLRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: Qwen2_5_VLConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def eager_attention_forward_blockedKV( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + num_kv_blocks: Optional[torch.Tensor] = None, + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + # Initialize result tensor + output = torch.zeros_like(query) + + # Initialize Running Maximum + batch_size, num_heads, seq_len, _ = query.shape + current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE)) + + # Initialize Denominator + current_denominator = torch.zeros(batch_size, num_heads, seq_len) + + past_seen_tokens = cache_kwargs.get("past_seen_tokens") + position_ids = cache_kwargs.get("position_ids") + block_size = -(-past_seen_tokens // num_kv_blocks) + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32) + + for j in range(num_kv_blocks): + start_index = j * block_size + end_index = (j + 1) * block_size + K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) + K_block_states = repeat_kv(K_block, module.num_key_value_groups) + V_block_states = repeat_kv(V_block, module.num_key_value_groups) + past_seen_tokens_start = start_index + past_seen_tokens_end = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start + ) + + # Compute attention scores for the block + attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) + + # Update Running row maximum + prev_max = current_max + current_max = torch.max(prev_max, 
attn_weights_block.max(dim=-1).values) + delta_max = prev_max - current_max + + current_exp = torch.exp( + attn_weights_block - current_max.unsqueeze(-1) + ) # Subtract current_max from each column of attn_weights_block + + # Update running denominator + prev_denominator = current_denominator + current_denominator = prev_denominator * torch.exp(delta_max) + current_exp.sum(axis=-1) + + prob = current_exp / current_denominator.unsqueeze(-1) + + prev_output = output + output = ((prev_denominator / current_denominator).unsqueeze(-1)) * prev_output * torch.exp( + delta_max.unsqueeze(-1) + ) + torch.matmul(prob, V_block_states) + attn_output = output.transpose(1, 2).contiguous() + attn_weights = None + + return attn_output, attn_weights + + +def eager_attention_forward_q_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + **kwargs, +): + """ + Q-blocked attention for Qwen2.5-VL. + Blocks only the query sequence-length (SL) dimension. + + Args: + query: (BS, NH, Q_LEN, DH) + key: (BS, NH_KV, KV_LEN, DH) + value: (BS, NH_KV, KV_LEN, DH) + attention_mask: (BS, NH, Q_LEN, KV_LEN) or broadcastable + """ + BS, NH, Q_LEN, DH = query.shape + _, _, KV_LEN, _ = key.shape + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + target_blocks_q = int(os.environ.get("num_q_blocks", Q_LEN)) + q_block_positions = [(i * Q_LEN) // target_blocks_q for i in range(target_blocks_q)] + scaling = 1.0 / math.sqrt(module.head_dim) + + q_output_blocks = [] + q_attn_weights_blocks = [] + + # Process each Q block + for q_block_idx in range(target_blocks_q): + qi = q_block_positions[q_block_idx] + + # Calculate Q block size + if q_block_idx == target_blocks_q - 1: + real_q_len = Q_LEN - qi + else: + real_q_len = q_block_positions[q_block_idx + 1] - qi + + # Extract Q block + q_block = query[:, :, qi : qi + real_q_len, :] + attn_mask_block = None + if attention_mask is not None: + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + # Compute attention scores for this Q block + attn_weights = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + if attn_mask_block is not None: + attn_weights = torch.where( + attn_mask_block, + torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32, device=attn_weights.device), + attn_weights, + ) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # Compute output for this Q block + output_block = torch.matmul(attn_weights, value_states) + + q_output_blocks.append(output_block) + q_attn_weights_blocks.append(attn_weights) + + attn_output = torch.cat(q_output_blocks, dim=2) + attn_output = attn_output.transpose(1, 2).contiguous() + + # Concatenate attention weights + attn_weights = torch.cat(q_attn_weights_blocks, dim=2) + + return attn_output, attn_weights + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + num_kv_blocks: Optional[torch.Tensor] = None, + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + """ + Wrapper that routes to blocked or default attention based on environment variable.
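+    ATTENTION_BLOCKING_MODE="q" selects the query-blocked path (block count read from the
+    num_q_blocks environment variable); any other value combined with num_kv_blocks selects the
+    KV-blocked online-softmax path; "default" falls back to the unblocked eager attention.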
+ """ + blocking_mode = os.environ.get("ATTENTION_BLOCKING_MODE", "default").lower() + + if blocking_mode == "q": + return eager_attention_forward_q_blocked(module, query, key, value, attention_mask, **kwargs) + elif blocking_mode != "q" and num_kv_blocks is not None: + return eager_attention_forward_blockedKV( + module, + query, + key, + value, + attention_mask, + cache_kwargs=cache_kwargs, + num_kv_blocks=num_kv_blocks, + layer_idx=layer_idx, + past_key_value=past_key_value, + **kwargs, + ) + elif blocking_mode == "default": + # Original implementation + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + else: + raise ValueError(f"Invalid ATTENTION_BLOCKING_MODE: {blocking_mode}. Must be 'q' or 'default'") + + +class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __qeff_init__(self): + self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + num_kv_blocks: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids[1:], self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + if num_kv_blocks is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids[0], + "past_seen_tokens": past_seen_tokens, + } + past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + else: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": 
position_ids[0], + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + num_kv_blocks=num_kv_blocks, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_value=past_key_value, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class QEffQwen2_5_VLDecoderLayer(Qwen2_5_VLDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffQwen2_5_VLTextModel(Qwen2_5_VLTextModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
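+        # When no position_ids are supplied, the 1D cache_position is simply broadcast across the three
+        # (temporal, height, width) planes; the QEff decoder wrapper instead passes explicit multi-plane
+        # position_ids, with plane 0 used for cache indexing and the remaining planes for M-RoPE.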
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids[0], target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return (hidden_states, past_key_values) + + +class QEffQwen2_5_VLModel(Qwen2_5_VLModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen2_5_VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +class QEffQwen_2_5_vl_EncoderWrapper(nn.Module): + def __init__(self, 
model): + super().__init__() + self.model = model + self.model.vision_model = self.model.visual + + def forward(self, pixel_values, image_grid_thw): + image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw) + bs = image_grid_thw.shape[0] + split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) + image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1)) + + return image_embeds + + +class QEffQwen_2_5_vl_DecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.language_model = self.model.model.language_model + + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, + ): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + outputs = self.model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=True, + ) + + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + + return logits, vision_embeds, image_idx, outputs.past_key_values + + +class QEffQwen_2_5_vl_ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): + def get_qeff_vision_encoder(self): + return QEffQwen_2_5_vl_EncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEffQwen_2_5_vl_DecoderWrapper(self) + + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + + vision_size = 3577 + inputs_shapes["vision_embeds"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + vision_size, + self.model.config.hidden_size, + ) + inputs_shapes["image_grid_thw"] = (1, 1, 98, 146) + inputs_shapes["position_ids"] = ( + 3, + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = (14308, 1176) + inputs_shapes["image_idx"] = (1, 1) + inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2) + # Define inputs + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["image_grid_thw"] = torch.zeros((inputs_shapes["image_grid_thw"]), dtype=torch.int64) + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + 
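+        # The dummy position_ids below carries 4 planes: plane 0 is the linear text position used for
+        # KV-cache indexing and logit gathering, planes 1-3 are the temporal/height/width M-RoPE indices
+        # (the same layout that prepare_inputs_for_generation produces).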
lang_inputs["position_ids"] = ( + ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + .unsqueeze(0) + .repeat(4, 1, 1) + ) + lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + # Add data for KV + kv_cache_shape = get_padding_shape_from_config( + config=self.model.config.text_config, + batch_size=fbs if continuous_batching else bs, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8) + + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vision_embeds") + inputs = {**vision_inputs, **lang_inputs} + + return inputs + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: None, + height: int = None, + width: int = None, + num_frames: int = 1, + kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, + **compiler_options, + ): + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + + if height is None or width is None: + height = constants.QWEN2_5_VL_HEIGHT + width = constants.QWEN2_5_VL_WIDTH + logger.warning( + f"Setting height and width to be {height} and {width} respectively, as it was neither passed nor found in vision_config" + ) + prefill_seq_len = prefill_seq_len if prefill_seq_len else 128 + ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN + channel = 3 + patch_size = self.config.vision_config.patch_size + temporal_patch_size = self.config.vision_config.temporal_patch_size + + IMAGE_FACTOR = 28 + MIN_PIXELS = 4 * 28 * 28 + MAX_PIXELS = 16384 * 28 * 28 + MAX_RATIO = 200 + + def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + def smart_resize( + height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + ) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. 
The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + resized_height, resized_width = smart_resize(height=height, width=width) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + grid_height = grid_h * grid_w + grid_width = patch_size * patch_size * temporal_patch_size * channel + vision_size = grid_height // 4 + vision_size = vision_size * num_frames + grid_height = grid_height * batch_size + + vision = [ + { + "batch_size": batch_size, + "vision_size": vision_size, + "grid_height": grid_height, + "grid_width": grid_width, + "grid_h": grid_h, + "grid_w": grid_w, + } + ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [lang_prefill, lang_decode] + + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") + return lang, compiler_options + + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: 
bool = False, continuous_batching: bool = False + ): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + + vision_dynamic_axes = { + "pixel_values": {0: "grid_height", 1: "grid_width"}, + "image_grid_thw": {0: "batch_size", 2: "grid_h", 3: "grid_w"}, + } + + lang_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {1: "batch_size", 2: "seq_len"}, + "vision_embeds": {0: "vision_batch_size", 1: "vision_size"}, + } + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + lang_dynamic_axes[f"past_value.{i}"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + 2: "ctx_len", + } + + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + + dynamic_axes = {} + + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + lang_dynamic_axes.pop("vision_embeds") + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + lang_output_names = ["logits"] + for i in range(self.model.config.text_config.num_hidden_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") + return lang_output_names + return output_names + + def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1): + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) + + pos_ids, rope_deltas = self.model.get_rope_index( + inputs["input_ids"], + None if "image_grid_thw" not in inputs else inputs["image_grid_thw"], + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) + + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + inputs.pop("image_grid_thw", None) + + return inputs + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")), + ] diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index ecdb36019..540bad4c7 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -151,6 +151,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: 
Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -168,6 +169,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -201,6 +205,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -232,6 +237,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -262,6 +268,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +321,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -349,6 +357,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -367,6 +376,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 591f7c1b0..cbd80d8ca 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -201,6 +201,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -218,6 +219,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 
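The same pattern recurs in every decoder layer this patch touches: when `comp_ctx_lengths` is provided, the causal attention mask is truncated to the active compute-context-length (CCL) bucket before the KV-cache update, and the resulting width is passed along as `CCL` in `cache_kwargs`. A minimal host-side sketch of that slicing, with hypothetical shapes:

```python
import torch

# Illustrative only -- not the QEfficient kernel. Slicing the causal mask to the CCL
# bucket limits attention (and the KV positions read) to the first N cached positions.
batch_size, num_heads, q_len, ctx_len = 1, 8, 1, 4096
attention_mask = torch.zeros(batch_size, num_heads, q_len, ctx_len, dtype=torch.bool)

# Placeholder tensor whose last dimension encodes the active CCL bucket (here 1024)
comp_ctx_lengths = torch.zeros(1024)

if comp_ctx_lengths is not None:
    attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]

print(attention_mask.shape)  # torch.Size([1, 8, 1, 1024])
```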
attention_interface = eager_attention_forward @@ -243,6 +247,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -274,6 +279,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -300,6 +306,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, batch_index: Optional[torch.LongTensor] = None, @@ -342,6 +349,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -369,6 +377,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -385,6 +394,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index 9a327761d..c86e7478b 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -69,6 +69,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -85,6 +86,9 @@ def forward( if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -118,6 +122,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -153,6 +158,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, 
cache_position=cache_position, @@ -184,6 +190,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -237,6 +244,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -273,6 +281,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -289,6 +298,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/transformers/models/t5/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py new file mode 100644 index 000000000..f54201465 --- /dev/null +++ b/QEfficient/transformers/models/t5/modeling_t5.py @@ -0,0 +1,145 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +from transformers import EncoderDecoderCache +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerNorm, +) + + +class QEffT5LayerNorm(T5LayerNorm): + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class QEffT5Attention(T5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` + if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not 
self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + if past_key_value is not None: # This block is where the patch applies + position_bias = position_bias[:, :, -1:, :] # Added by patch + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index e078493a7..a03ffecf7 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -55,6 +55,7 @@ def forward( position_ids_layer: torch.Tensor = None, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -100,6 +101,9 @@ def forward( value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: cache_kwargs = {"position_ids": position_ids_layer} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -181,6 +185,7 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, @@ -215,6 +220,7 @@ def forward( hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -388,6 +394,7 @@ def forward( cross_attn_head_mask=None, position_ids=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds=None, use_cache=None, output_attentions=None, @@ -532,6 +539,7 @@ def forward( 
layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, position_ids_layer=position_ids, @@ -643,6 +651,7 @@ def forward( cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds=None, use_cache=None, output_attentions=None, @@ -674,6 +683,7 @@ def forward( head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -719,6 +729,7 @@ def forward( cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, @@ -740,6 +751,7 @@ def forward( decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, decoder_inputs_embeds=decoder_inputs_embeds, decoder_position_ids=position_ids, use_cache=use_cache, diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index d647b73a6..dc2308e99 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -4,3 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers + +__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"] diff --git a/QEfficient/transformers/quantizers/auto.py b/QEfficient/transformers/quantizers/auto.py index ba204e419..d73909211 100644 --- a/QEfficient/transformers/quantizers/auto.py +++ b/QEfficient/transformers/quantizers/auto.py @@ -11,7 +11,8 @@ from transformers.quantizers.quantizer_awq import AwqQuantizer from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer from transformers.quantizers.quantizer_gptq import GptqHfQuantizer -from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig, Mxfp4Config from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer from QEfficient.transformers.quantizers.quantizer_compressed_tensors import ( @@ -21,30 +22,35 @@ QEffFP8Quantizer, ) from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4Config, QEffMxfp4HfQuantizer QEFF_AUTO_QUANTIZER_MAPPING = { "awq": QEffAwqQuantizer, "gptq": QEffGPTQQuantizer, "compressed-tensors": QEffCompressedTensorsFP8Quantizer, "fp8": QEffFP8Quantizer, + "mxfp4": QEffMxfp4HfQuantizer, } 
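These four dictionaries keep the QEfficient and upstream transformers quantizer registries aligned, keyed by the `quant_method` string carried in a checkpoint's quantization config; this patch adds the `mxfp4` entries. A toy lookup (assuming the patched QEfficient package is importable; this is not the `replace_transformers_quantizers` implementation):

```python
# Resolve the QEfficient quantizer and config classes for a given quantization method.
from QEfficient.transformers.quantizers.auto import (
    QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING,
    QEFF_AUTO_QUANTIZER_MAPPING,
)

quant_method = "mxfp4"  # e.g. read from config.quantization_config["quant_method"]
quantizer_cls = QEFF_AUTO_QUANTIZER_MAPPING[quant_method]          # QEffMxfp4HfQuantizer
config_cls = QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]   # QEffMxfp4Config
print(quantizer_cls.__name__, config_cls.__name__)
```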
QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": QEffAwqConfig, "gptq": QEffGPTQConfig, "compressed-tensors": QEffCompressedTensorsConfig, "fp8": QEffFP8Config, + "mxfp4": QEffMxfp4Config, } DUPLICATE_AUTO_QUANTIZER_MAPPING = { "awq": AwqQuantizer, "gptq": GptqHfQuantizer, "compressed-tensors": CompressedTensorsHfQuantizer, "fp8": None, + "mxfp4": Mxfp4HfQuantizer, } DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = { "awq": AwqConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, "fp8": None, + "mxfp4": Mxfp4Config, } diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index 0427bca37..69d6380f0 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -7,13 +7,19 @@ import torch from torch import nn +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts from QEfficient.base.pytorch_transforms import ModuleMutatorTransform from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear -from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights +from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts +from QEfficient.transformers.quantizers.quantizer_utils import ( + convert_moe_packed_tensors, + dequantize_gptq, + unpack_weights, +) class AwqToMatmulNbitsTransform(ModuleMutatorTransform): @@ -115,3 +121,28 @@ def mutate(cls, original_module, parent_module): if original_module.bias is not None: dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float()) return dequant_linear_layer + + +class Mxfp4GptOssExpertDequantizeTransform(ModuleMutatorTransform): + """ + Used to dequantize the weights of an Mxfp4GptOssExpert module and replace with transformers GptOssExperts with dequantized weights + """ + + _match_class = QEffMxfp4GptOssExperts + + @classmethod + def mutate(cls, original_module, parent_module): + dequant_module = GptOssExperts(original_module.config) + dequant_module.gate_up_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.gate_up_proj_blocks, original_module.gate_up_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.down_proj = torch.nn.Parameter( + convert_moe_packed_tensors( + original_module.down_proj_blocks, original_module.down_proj_scales, dtype=torch.float32 + ) + ) + dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias + dequant_module.down_proj_bias = original_module.down_proj_bias + return dequant_module diff --git a/QEfficient/transformers/quantizers/quantizer_mxfp4.py b/QEfficient/transformers/quantizers/quantizer_mxfp4.py new file mode 100644 index 000000000..2ffba1bea --- /dev/null +++ b/QEfficient/transformers/quantizers/quantizer_mxfp4.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import re +from typing import Optional + +import torch +import torch.nn as nn +from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer +from transformers.utils.quantization_config import Mxfp4Config + +from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert +from QEfficient.utils.logging_utils import logger + + +class QEffMxfp4GptOssExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + + self.gate_up_proj_blocks = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + ) + + self.down_proj_blocks = nn.Parameter( + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_scales = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), + requires_grad=False, + ) + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + ) + self.alpha = 1.702 + self.limit = 7.0 + + self.gate_up_proj_precision_config = None + self.down_proj_precision_config = None + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + gate_up_proj = convert_moe_packed_tensors( + self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32 + ) + down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32) + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + +def should_convert_module(current_key_name, patterns): + current_key_name_str = ".".join(current_key_name) + if not any( + re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns + ): + return True + return False + + +class QEffMxfp4Config(Mxfp4Config): + """ + Currently there is not need to change the implementation of Mxfp4Config + 
This is placeholder for future when we would want to change this + """ + + pass + + +class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer): + def validate_environment(self, *args, **kwargs): + return True + + def update_torch_dtype(self, torch_dtype): + if torch_dtype not in [None, torch.float32]: + logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") + return None + + def _process_model_before_weight_loading( + self, + model: torch.nn.Module, + keep_in_fp32_modules: Optional[list[str]] = None, + **kwargs, + ): + self.modules_to_not_convert = get_keys_to_not_convert(model) + self.modules_to_not_convert = ( + ["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert + ) + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + self.modules_to_not_convert = list(set(self.modules_to_not_convert)) + config = model.config + + # -- Defining local method as it uses lot of local variables -- + def _replace_with_mxfp4_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + ): + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + if not should_convert_module(current_key_name, modules_to_not_convert): + current_key_name.pop(-1) + continue + if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: + model._modules[name] = QEffMxfp4GptOssExperts(config) + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_mxfp4_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + current_key_name.pop(-1) + return model, has_been_replaced + + _replace_with_mxfp4_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py index a318fb8e4..424692d08 100644 --- a/QEfficient/transformers/quantizers/quantizer_utils.py +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import copy +import math import torch from torch import nn @@ -378,3 +379,70 @@ def repack_zeros(qzeros, bits): break qzeros = qzeros.T return qzeros + + +FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def convert_moe_packed_tensors( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 32768 * 1024, +) -> torch.Tensor: + """ + reference for this function is taken from: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt_oss#L98 + """ + + scales = scales.to(torch.int32) - 127 + + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}" + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, 
rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + out = out.to(dtype).permute(0, 2, 1).contiguous() + return out diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py index 96846e712..5c86b6355 100644 --- a/QEfficient/transformers/sampler/sampler.py +++ b/QEfficient/transformers/sampler/sampler.py @@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput): probs: torch.FloatTensor = None next_tokens: torch.IntTensor = None + vision_embeds: Optional[torch.FloatTensor] = None # For VLMs + image_idx: Optional[torch.IntTensor] = None # for VLMs past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_repetition_penalty_buffer: Optional[torch.Tensor] = None past_presence_penalty_buffer: Optional[torch.Tensor] = None @@ -47,7 +49,6 @@ def prefill_path( positions_mask = (position_ids[:, :1] != zero_tensor).view(-1, 1) mul_value = CtxScatterFuncCB3D.apply(mul_value, batch_index, zero_tensor, positions_mask) past_repetition_penalty_buffer *= mul_value - past_presence_penalty_buffer *= mul_value # Mask out-of-bounds or invalid position_ids or input_ids input_ids = torch.where(position_ids == -1, torch.iinfo(torch.int32).max, input_ids) @@ -59,6 +60,9 @@ def prefill_path( input_ids, torch.ones(input_ids.shape, dtype=torch.bool), ) + + mul_value = torch.zeros(past_presence_penalty_buffer.shape[0], 1, dtype=torch.bool) + past_presence_penalty_buffer *= mul_value return past_repetition_penalty_buffer, past_presence_penalty_buffer @@ -103,6 +107,7 @@ def sampler_forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -112,6 +117,8 @@ def sampler_forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: Optional[int] = None, + vision_embeds: Optional[torch.FloatTensor] = None, + image_idx: Optional[torch.IntTensor] = None, last_accepted_output_tokens: Optional[torch.Tensor] = None, # (batch_size, spec_length or less) past_repetition_penalty_buffer: Optional[torch.Tensor] = None, repetition_penalties: Optional[torch.Tensor] = None, @@ -122,11 +129,15 @@ def sampler_forward( top_ps: Optional[torch.Tensor] = None, min_ps: Optional[torch.Tensor] = None, random_numbers: Optional[torch.Tensor] = None, + token_bitmasks: Optional[torch.Tensor] = None, ) -> Union[Tuple, SamplerOutput]: r""" Perform the sampling of next tokens on the QAIC device (instead of the host) and return the next tokens and/or probability distributions. + The vision_embeds and image_idx parameters are optional + and are used only for VLMs when supported by the original forward function. + Args: last_accepted_output_tokens (`torch.Tensor`, *optional*): Output tokens accepted by the Speculative Decoding Draft Language Model. 
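The retained `past_repetition_penalty_buffer` and `past_presence_penalty_buffer` feed the penalty steps later in `sampler_forward`. For reference, a conventional host-side formulation of those two penalties (illustrative only, assuming the usual divide/multiply repetition rule; the on-device sampler operates on the retained buffers shown in this patch):

```python
import torch

vocab_size = 8
logits = torch.randn(1, vocab_size)
seen = torch.zeros(1, vocab_size, dtype=torch.bool)
seen[0, [2, 5]] = True                      # tokens already present in prompt/output
repetition_penalty = torch.tensor(1.2)
presence_penalty = torch.tensor(0.5)

# Repetition penalty: shrink positive logits and amplify negative ones for seen tokens
penalized = torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty)
logits = torch.where(seen, penalized, logits)

# Presence penalty: flat subtraction for every token that has appeared at least once
logits = logits - presence_penalty * seen.float()
```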
@@ -169,21 +180,43 @@ def sampler_forward( random_numbers (`torch.Tensor`, *optional*): Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1]. - """ - outputs = self.old_forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - batch_index=batch_index, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) + token_bitmasks (`torch.Tensor`, *optional*): + Boolean mask used to guide token-level filtering during decoding. Each + element of this tensor indicates whether the corresponding token should be + kept (1) or masked (0). Shape: (batch_size, vocab_size) + """ + if vision_embeds is not None: + forward_kwargs = dict( + input_ids=input_ids, + vision_embeds=vision_embeds, + position_ids=position_ids, + image_idx=image_idx, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + ) + if batch_index is not None: + forward_kwargs["batch_index"] = batch_index + + logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs) + outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values) + if position_ids.dim() == 3: # For models using m-rope + position_ids = position_ids[0] + else: + outputs = self.old_forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) logits = outputs.get("logits", None) assert logits is not None, f"{self.model.__class__.__name__} does not return logits." 
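The guided-decoding mask documented above is applied in the next hunk with a single `torch.where`; a small standalone sketch of the same operation:

```python
import torch

batch_size, vocab_size = 1, 6
logits = torch.randn(batch_size, vocab_size)
token_bitmasks = torch.tensor([[1, 0, 1, 1, 0, 1]])  # hypothetical bitmask: 1 = keep, 0 = mask

# Disallowed positions are pushed to the fp16 minimum so they can never win the argmax
masked_logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min)
next_token = masked_logits.argmax(dim=-1)  # always an allowed token
```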
@@ -197,6 +230,13 @@ def sampler_forward( batch_index = torch.arange(batch_size).view(-1, 1) batch_index_reshaped = batch_index.view(-1) + + # Guided decoding + if token_bitmasks is not None and (token_bitmasks != 1).any(): + assert spec_length == 1, "Currently, guided decoding is not supported with Speculative Decoding" + # Mask logits where token_bitmasks is 0 with -inf + logits = torch.where(token_bitmasks == 1, logits, torch.finfo(torch.float16).min) + # Prefill past_repetition_penalty_buffer_prefill, past_presence_penalty_buffer_prefill = prefill_path( input_ids=input_ids, @@ -224,17 +264,6 @@ def sampler_forward( is_prefill, past_presence_penalty_buffer_prefill, past_presence_penalty_buffer_decode ) - # Greedy Sampling - greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1) - if (temperatures == 0).all() and not self.qaic_config.get("return_pdfs", False): - return SamplerOutput( - probs=None, - next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, - past_repetition_penalty_buffer=past_repetition_penalty_buffer, - past_presence_penalty_buffer=past_presence_penalty_buffer, - ) - # Repetition Penalty if (repetition_penalties != 1.0).any(): past_repetition_penalty_buffer_selected = past_repetition_penalty_buffer[batch_index_reshaped].repeat( @@ -253,6 +282,19 @@ def sampler_forward( ) # (batch_size * spec_length, vocab_size) logits -= presence_penalties * past_presence_penalty_buffer_selected + # Greedy Sampling + greedy_samples = torch.argmax(logits, dim=1, keepdim=True) # (batch_size * spec_length, 1) + if (temperatures == 0).all() and not self.qaic_config.get("return_pdfs", False): + return SamplerOutput( + probs=None, + next_tokens=greedy_samples.reshape(-1, spec_length, 1), # Return sampled next tokens instead of logits + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), + past_repetition_penalty_buffer=past_repetition_penalty_buffer, + past_presence_penalty_buffer=past_presence_penalty_buffer, + ) + # TODO: Frequency Penalty # Temperature Scaling @@ -300,9 +342,8 @@ def sampler_forward( ) # (batch_size, spec_length, vocab_size) # Random Sampling - topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids) gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick - y = topk_probs_asc + gumbel_noise + y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids) random_samples_indices = torch.argmax(y, dim=1, keepdim=True) random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1) @@ -314,7 +355,9 @@ def sampler_forward( return SamplerOutput( probs=probs, next_tokens=next_tokens, # Return sampled next tokens instead of logits - past_key_values=outputs.past_key_values, + vision_embeds=outputs.get("vision_embeds", None), + image_idx=outputs.get("image_idx", None), + past_key_values=outputs.get("past_key_values", None), past_repetition_penalty_buffer=past_repetition_penalty_buffer, past_presence_penalty_buffer=past_presence_penalty_buffer, ) diff --git a/QEfficient/transformers/spd/spd_transform_forward.py b/QEfficient/transformers/spd/spd_transform_forward.py index e82bf4cdf..4703cb18d 100644 --- a/QEfficient/transformers/spd/spd_transform_forward.py +++ b/QEfficient/transformers/spd/spd_transform_forward.py @@ -76,6 
+76,7 @@ def tlm_forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -123,6 +124,7 @@ def tlm_forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py index e487d4af4..3d6583f85 100755 --- a/QEfficient/utils/__init__.py +++ b/QEfficient/utils/__init__.py @@ -10,12 +10,12 @@ undo_transformers_quantizers, ) from QEfficient.utils._utils import ( # noqa: F401 + LRUCache, check_and_assign_cache_dir, create_json, create_model_params, custom_format_warning, dump_qconfig, - export_wrapper, generate_mdp_partition_config, get_num_layers_from_config, get_num_layers_vlm, diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index abe383556..26bae7a34 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -12,7 +12,6 @@ import subprocess import xml.etree.ElementTree as ET from dataclasses import dataclass -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import requests @@ -27,12 +26,41 @@ PreTrainedTokenizerFast, ) -from QEfficient.utils.cache import QEFF_HOME from QEfficient.utils.constants import KWARGS_INCLUSION_LIST, QEFF_MODELS_DIR, Constants, QnnConstants -from QEfficient.utils.hash_utils import create_export_hash, json_serializable +from QEfficient.utils.hash_utils import json_serializable from QEfficient.utils.logging_utils import logger +class LRUCache: + """Simple LRU cache with size limit for vision outputs""" + + def __init__(self, max_size=100): + self._cache = {} + self._access_order = [] + self._max_size = max_size + + def get(self, key): + if key in self._cache: + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + return None + + def put(self, key, value): + if key in self._cache: + self._access_order.remove(key) + elif len(self._cache) >= self._max_size: + oldest = self._access_order.pop(0) + del self._cache[oldest] + + self._cache[key] = value + self._access_order.append(key) + + def clear(self): + self._cache.clear() + self._access_order.clear() + + class DownloadRetryLimitExceeded(Exception): """ Used for raising error when hf_download fails to download the model after given max_retries. 
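The `LRUCache` helper added above is re-exported from `QEfficient.utils`; a short usage sketch with hypothetical keys and values (per its docstring it targets caching of vision outputs, e.g. keyed by an image hash), assuming this patch is installed:

```python
from QEfficient.utils import LRUCache

cache = LRUCache(max_size=2)
cache.put("image_hash_a", "vision_embeds_a")
cache.put("image_hash_b", "vision_embeds_b")

cache.get("image_hash_a")                      # marks "a" as most recently used
cache.put("image_hash_c", "vision_embeds_c")   # evicts "b", the least recently used entry

assert cache.get("image_hash_b") is None
assert cache.get("image_hash_a") == "vision_embeds_a"
```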
@@ -502,59 +530,11 @@ def create_model_params(qeff_model, **kwargs) -> Dict: """ model_params = copy.deepcopy(kwargs) model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST} - model_params["config"] = qeff_model.model.config.to_diff_dict() model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None) model_params["applied_transform_names"] = qeff_model._transform_names() return model_params -def export_wrapper(func): - def wrapper(self, *args, **kwargs): - export_dir = kwargs.get("export_dir", None) - parent_dir = self.model_architecture or self.model_name - export_dir = Path(export_dir or (QEFF_HOME / parent_dir / self.model_name)) - - # PREPROCESSING OF PARAMETERS - - # Get the original signature - original_sig = inspect.signature(func) - - # Remove 'self' from parameters - params = list(original_sig.parameters.values())[1:] # skip 'self' - new_sig = inspect.Signature(params) - - # Bind args and kwargs to the new signature - bound_args = new_sig.bind(*args, **kwargs) - bound_args.apply_defaults() - - # Get arguments as a dictionary - all_args = bound_args.arguments - - export_hash, filtered_hash_params = create_export_hash( - model_params=self.hash_params, - output_names=all_args.get("output_names"), - dynamic_axes=all_args.get("dynamic_axes"), - export_kwargs=all_args.get("export_kwargs", None), - onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), - ) - export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) - kwargs["export_dir"] = export_dir - self.export_hash = export_hash - - # _EXPORT CALL - onnx_path = func(self, *args, **kwargs) - - # POST-PROCESSING - # Dump JSON file with hashed parameters - hashed_params_export_path = export_dir / "hashed_export_params.json" - create_json(hashed_params_export_path, filtered_hash_params) - logger.info("Hashed parameters exported successfully.") - - return onnx_path - - return wrapper - - def execute_command(process: str, command: str, output_file_path: Optional[str] = None): """ Executes the give command using subprocess. diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py new file mode 100644 index 000000000..cc259ee36 --- /dev/null +++ b/QEfficient/utils/check_ccl_specializations.py @@ -0,0 +1,164 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Tuple + +from QEfficient.utils import constants +from QEfficient.utils.logging_utils import logger + + +# Better performance when context length is multiple of 1024 → map CL to the next multiple of 1024 +def next_multiple_of_1024(n: int) -> int: + """Ceil 'n' to the next multiple of 1024.""" + if n <= 0: + return 0 + return ((n + 1023) // 1024) * 1024 + + +def floor_to_1000(n: int) -> int: + """Floor 'n' to the nearest lower multiple of 1000.""" + if n <= 0: + return 0 + return (n // 1000) * 1000 + + +def is_power_of_two(n: int) -> bool: + """Return True if n is a power of two (n > 0 and n & (n - 1) == 0).""" + return n > 0 and (n & (n - 1)) == 0 + + +def build_doubling_list(start: int, limit: int, max_elements: int, last_value: int = None) -> List[int]: + """ + Build a STRICT doubling list: {start, start*2, start*4, ...} up to 'limit', + collecting at most 'max_elements' values. Returns a list. 
+ Ensure the last element equals 'last_value' by appending or replacing the final element. + """ + values: List[int] = [] + if max_elements <= 0 or start <= 0 or limit <= 0: + return values + + element = start + while element <= limit and len(values) < max_elements: + values.append(element) + element *= 2 + + if last_value is not None and values[-1] != last_value: + if len(values) < max_elements: + values.append(last_value) + else: + values[-1] = last_value + return values[:max_elements] + + +def automatic_ccl_generation( + ctx_len: int, + prefill_seq_len: int, +) -> Tuple[List[int], List[int], int]: + """ + Automatic Compute-Context-Length Lists Generation + Purpose: + Compute decode and prefill CCL lists based on an input context length (CL), + prefill sequence length, and optional pre-specified lists. + """ + # Handle non-positive CL + if ctx_len <= 0: + mapped_cl = next_multiple_of_1024(1) + seq = [mapped_cl] + return seq, seq, mapped_cl + + mapped_cl = next_multiple_of_1024(ctx_len) + + # Early small-ctx_len case for identical lists + if mapped_cl <= constants.CCL_START_CTX_LEN: + seq = [mapped_cl] + return seq, seq, mapped_cl + + # To limit the number of elements in CCL list, the starting point will be calculated based on context length + for upper_bound, (decode_start, prefill_start) in constants.CCL_START_MAP.items(): + if mapped_cl <= upper_bound: + break + + if prefill_seq_len > 1: + # ---- Decode: strict doubling up to mapped_cl, then enforce last = mapped_cl + decode_list = build_doubling_list( + start=decode_start, limit=mapped_cl, max_elements=constants.CCL_MAX_ELEMENTS_LISTS, last_value=mapped_cl + ) + + # ---- Prefill: + if is_power_of_two(mapped_cl): + # STRICT doubling only, bounded by mapped_cl + prefill_list = build_doubling_list( + start=prefill_start, limit=mapped_cl, max_elements=constants.CCL_MAX_ELEMENTS_LISTS + ) + else: + # Doubles bounded by mapped_cl, but last must equal floor_to_1000(mapped_cl) + prefill_last = floor_to_1000(mapped_cl) + prefill_list = build_doubling_list( + start=prefill_start, + limit=mapped_cl, + max_elements=constants.CCL_MAX_ELEMENTS_LISTS, + last_value=prefill_last, + ) + + return prefill_list, decode_list, mapped_cl + + elif prefill_seq_len == 1: + # When prefill_seq_len=1 such as in MoE models, prefilling and decoding processes can use the same specializations and we can double the length of ccl lists. + # Due to limitations in the number of specializations during compilation, we set the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists to 2*constants.CCL_MAX_ELEMENTS_LISTS. 
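For reference, a worked example of the doubling-list construction under the constants this patch defines (`CCL_START_MAP` gives a 4096/4000 start below 32K and `CCL_MAX_ELEMENTS_LISTS` is 5), assuming the patched package is importable:

```python
from QEfficient.utils.check_ccl_specializations import build_doubling_list, next_multiple_of_1024

ctx_len = 6000
mapped_cl = next_multiple_of_1024(ctx_len)  # 6144

decode = build_doubling_list(start=4096, limit=mapped_cl, max_elements=5, last_value=mapped_cl)
prefill = build_doubling_list(start=4000, limit=mapped_cl, max_elements=5, last_value=6000)

print(decode)   # [4096, 6144] -- doubling stops past the limit, then the mapped CL is appended
print(prefill)  # [4000, 6000] -- last value is the mapped CL floored to the nearest 1000
```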
+ max_elems = 2 * constants.CCL_MAX_ELEMENTS_LISTS + + if mapped_cl < constants.CCL_START_CTX_LEN: + seq = [mapped_cl] + return seq, seq, mapped_cl + + limit = min(mapped_cl, constants.CCL_START_CTX_LEN * (2 ** (max_elems - 1))) + + seq_list = build_doubling_list( + start=constants.CCL_START_CTX_LEN, limit=limit, max_elements=max_elems, last_value=mapped_cl + ) + + return seq_list, seq_list, mapped_cl + else: + logger.warning("prefill_seq_len cannot be less than 1!") + + +def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len): + # Automatic CCL generation: If both ccl_prefill and ccl_decode are None + if ccl_prefill is None and ccl_decode is None: + # Generate optimized context length lists for prefill and decode based on ctx_len + # Due to compiler limitations, ccl_prefill and ccl_decode must have distinct values + ccl_prefill, ccl_decode, ctx_len = automatic_ccl_generation(ctx_len, prefill_seq_len) + else: + if prefill_seq_len == 1: + if ccl_prefill is not None and ccl_decode is not None: + # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. + ccl_union_all = sorted(set([min(x, ctx_len) for x in ccl_prefill + ccl_decode])) + ccl_prefill = ccl_union_all + ccl_decode = ccl_union_all + else: + if ccl_prefill: + ccl_prefill = sorted({min(x, ctx_len) for x in (ccl_prefill)}) + if ccl_decode: + ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)}) + + if ccl_prefill is not None and ccl_decode is not None: + tmp_prefill = ccl_prefill + ccl_prefill = [] + for val in tmp_prefill: + while val in ccl_decode or val in ccl_prefill: + val -= 1 + if val < 0: + break # Prevent negative values + if val >= 0: + ccl_prefill.append(val) + ccl_prefill.sort() + + logger.info("CCL Configuration:") + logger.info(f" - Prefill context lengths: {ccl_prefill}") + logger.info(f" - Decode context lengths: {ccl_decode}") + logger.info(f" - Max context length: {ctx_len}") + return ccl_prefill, ccl_decode, ctx_len diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 57fba282b..d0318ac3e 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -17,7 +17,6 @@ ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 13 ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 @@ -84,10 +83,14 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 +FILE_CHUNK_SIZE_DEFAULT = 10 * 2**30 # 10 GB +SIZE_THRESHOLD_DEFAULT = 1024 -COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"] + +COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-compile-only"] DEFAULT_AIC_HW_VERSION = "ai100" +ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL = 100 # InternVL constants # Fixing the feature size with reference to OpenGVLab/InternVL2_5-1B, OpenGVLab/InternVL2_5-38B and OpenGVLab/InternVL2_5-78B @@ -97,6 +100,8 @@ def get_models_dir(): INTERN_CTX_LEN = 4096 INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256 # 4096-256 INTERN_NUM_CHANNELS = 3 +INTERN_IMAGE_HEIGHT = 1000 +INTERN_IMAGE_WIDTH = 747 INTERN_IMG_CONTEXT_TOKEN = 151667 # Specific to InternVL3_5 series, same token won't work for InternVL2_5 series @@ -125,6 +130,54 @@ def get_models_dir(): # Wav2Vec2 Constant WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz 
sampling rate (16,000 samples/sec × 30 sec) +# Qwen2_5_vl Constants +QWEN2_5_VL_HEIGHT = 354 +QWEN2_5_VL_WIDTH = 536 + +# Modules to cache while clearing the pytorch weights +CACHE_MODULES = ["get_output_names", "get_dummy_inputs", "get_onnx_dynamic_axes", "get_specializations"] + +# Mistral3 Constants +MISTRAL3_IMAGE_HEIGHT = 1540 +MISTRAL3_IMAGE_WIDTH = 1540 + +# Molmo Constants +MOLMO_IMAGE_HEIGHT = 536 +MOLMO_IMAGE_WIDTH = 354 +# Flux Transformer Constants +FLUX_ONNX_EXPORT_SEQ_LENGTH = 256 +FLUX_ONNX_EXPORT_COMPRESSED_LATENT_DIM = 4096 +FLUX_ADALN_HIDDEN_DIM = 3072 +FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context +FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3 +FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM + +# Wan Transformer Constants +WAN_TEXT_EMBED_DIM = 5120 +WAN_PROJECTION_DIM = 6 +WAN_ONNX_EXPORT_BATCH_SIZE = 1 +WAN_ONNX_EXPORT_FRAMES = 81 +WAN_ONNX_EXPORT_LATENT_FRAMES = 21 +WAN_ONNX_EXPORT_SEQ_LEN = 512 +WAN_ONNX_EXPORT_ROTARY_DIM = 128 +WAN_DIT_OUT_CHANNELS = 64 +# Wan dims for 180p +WAN_ONNX_EXPORT_CL_180P = 5040 +WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 24 +WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 40 +WAN_ONNX_EXPORT_HEIGHT_180P = 192 +WAN_ONNX_EXPORT_WIDTH_180P = 320 + +# For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length +CCL_START_MAP = { + 32768: (4096, 4000), + 65536: (8192, 8000), + float("inf"): (16384, 16000), +} +# Limitation in the maximum number of elements in comp_ctx_lengths_decode and comp_ctx_lengths_prefill lists during automatic lists generation process. +CCL_MAX_ELEMENTS_LISTS = 5 +CCL_START_CTX_LEN = 4096 + class Constants: # Export Constants. @@ -136,6 +189,7 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 + NUM_KV_BLOCKS = 8 MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS SAMPLER_OPS = { "repetition_penalties", diff --git a/QEfficient/utils/export_utils.py b/QEfficient/utils/export_utils.py new file mode 100644 index 000000000..33ba694cf --- /dev/null +++ b/QEfficient/utils/export_utils.py @@ -0,0 +1,221 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +import inspect +import re +import warnings +from pathlib import Path +from typing import Dict + +from QEfficient.base.onnx_transforms import CustomOpTransform, RenameFunctionOutputsTransform +from QEfficient.transformers.cache_utils import InvalidIndexProvider +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export +from QEfficient.utils.cache import QEFF_HOME +from QEfficient.utils.hash_utils import create_export_hash +from QEfficient.utils.logging_utils import logger +from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches + + +def export_wrapper(func): + """ + Decorator for export methods that orchestrates the complete export lifecycle. + + Responsibilities: + 1. Prepare export directory structure + 2. Generate reproducible hash for export configuration + 3. Setup ONNX subfunction environment (if enabled) + 4. Execute the wrapped export function + 5. 
Cleanup subfunction environment (if enabled) + 6. Save export metadata + + Args: + func: The export method to wrap (typically _export) + + Returns: + Wrapped function with complete export lifecycle management + """ + + def wrapper(self, *args, **kwargs): + # 1. Setup ONNX subfunctions if requested + if use_onnx_subfunctions := kwargs.pop("use_onnx_subfunctions", False): + args, kwargs = _setup_onnx_subfunctions(self, args, kwargs) + + # 2. Prepare export directory + export_dir = _prepare_export_directory(self, kwargs) + + # 3. Generate hash and finalize export directory path + export_hash, filtered_hash_params = _generate_export_hash(self, args, kwargs, func) + export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) + kwargs["export_dir"] = export_dir + self.export_hash = export_hash + + # 4. Execute the actual export + onnx_path = func(self, *args, **kwargs) + + # 5. Save export metadata + _save_export_metadata(export_dir, filtered_hash_params) + + # 6. Always cleanup subfunctions if they were setup + if use_onnx_subfunctions: + _cleanup_onnx_subfunctions(self) + + return onnx_path + + return wrapper + + +def _prepare_export_directory(qeff_model, kwargs) -> Path: + """ + Prepare and return the base export directory path. + + Args: + qeff_model: The QEff model instance + kwargs: Keyword arguments containing optional export_dir + + Returns: + Path object for the base export directory + """ + export_dir = kwargs.get("export_dir", None) + parent_dir = qeff_model.model_architecture or qeff_model.model_name + return Path(export_dir or (QEFF_HOME / parent_dir / qeff_model.model_name)) + + +def _generate_export_hash(qeff_model, args, kwargs, func): + """ + Generate export hash from model parameters and export arguments. + + The hash ensures reproducibility and prevents conflicts between + different export configurations. 
+ + Args: + qeff_model: The QEff model instance + args: Positional arguments to the export function + kwargs: Keyword arguments to the export function + func: The export function being wrapped + + Returns: + Tuple of (export_hash: str, filtered_hash_params: dict) + """ + # Extract function signature + original_sig = inspect.signature(func) + params = list(original_sig.parameters.values())[1:] # Skip 'self' + new_sig = inspect.Signature(params) + # Bind all arguments + bound_args = new_sig.bind(*args, **kwargs) + bound_args.apply_defaults() + all_args = bound_args.arguments + + # Use the model's current configuration for hashing to ensure any post-load modifications are captured + # TODO: Replace with get_model_config property of modeling classes and remove the if-else + # Determine the config dict to use, preferring .to_diff_dict() if available + if hasattr(qeff_model.model, "config") and hasattr(qeff_model.model.config, "to_diff_dict"): + config_val = qeff_model.model.config.to_diff_dict() + elif hasattr(qeff_model.model, "model") and hasattr(qeff_model.model.model.config, "to_diff_dict"): + config_val = qeff_model.model.model.config.to_diff_dict() + else: + config_val = qeff_model.model.config + + copy_of_hash_params = copy.deepcopy(qeff_model.hash_params) + copy_of_hash_params.update( + { + "config": config_val, + } + ) + # Generate hash from relevant parameters + export_hash, filtered_hash_params = create_export_hash( + model_params=copy_of_hash_params, + output_names=all_args.get("output_names"), + dynamic_axes=all_args.get("dynamic_axes"), + export_kwargs=all_args.get("export_kwargs", None), + onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), + ) + + return export_hash, filtered_hash_params + + +def _setup_onnx_subfunctions(qeff_model, args, kwargs): + """ + Setup ONNX subfunction export environment. + + This function prepares the model and environment for exporting with + ONNX subfunctions enabled. It: + - Applies necessary torch patches + - Modifies output names for subfunction compatibility + - Adds subfunction-specific ONNX transforms + - Updates export kwargs with module classes + + Args: + qeff_model: The QEff model instance + kwargs: Export keyword arguments (modified in-place). + """ + warnings.warn( + "The subfunction feature is experimental. Please note that using compile " + "consecutively with and without subfunction may produce inconsistent results." + ) + + # Apply torch patches for subfunction support + apply_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = True + # Transform output names for subfunction compatibility + if "output_names" in kwargs: + kwargs["output_names"] = [ + re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"] + ] + else: + args = list(args) + args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]] + args = tuple(args) + # Add subfunction-specific ONNX transforms + qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform) + qeff_model._onnx_transforms.append(CustomOpTransform) + + # TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation + decoder_layer_classes = get_decoder_layer_classes_for_export(qeff_model.model) + if decoder_layer_classes: + kwargs["export_modules_as_functions"] = decoder_layer_classes + return args, kwargs + + +def _cleanup_onnx_subfunctions(qeff_model): + """ + Cleanup ONNX subfunction export environment. 
+ + Restores the model and environment to pre-subfunction state by: + - Undoing torch patches + - Resetting InvalidIndexProvider flag + - Restoring original ONNX transforms list + + Args: + qeff_model: The QEff model instance + + Note: + This function is called in a finally block to ensure cleanup + even if export fails. Errors during cleanup are logged but + not re-raised to avoid masking the original exception. + """ + # Undo torch patches + undo_torch_patches() + InvalidIndexProvider.SUBFUNC_ENABLED = False + qeff_model._onnx_transforms.remove(RenameFunctionOutputsTransform) + qeff_model._onnx_transforms.remove(CustomOpTransform) + + +def _save_export_metadata(export_dir: Path, filtered_hash_params: Dict): + """ + Save export metadata to JSON file for reproducibility. + + Args: + export_dir: Directory where the export was saved + filtered_hash_params: Dictionary of parameters used for hashing + """ + # Import here to avoid circular dependency + from QEfficient.utils._utils import create_json + + hashed_params_path = export_dir / "hashed_export_params.json" + create_json(hashed_params_path, filtered_hash_params) + logger.info("Hashed parameters exported successfully.") diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index eb1f7c8e6..95474acfd 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -68,7 +68,8 @@ def prepare_pytorch_inputs(self): batch_size, input_len = input_ids.shape inputs.pop("attention_mask") inputs.pop("token_type_ids", None) - position_ids = torch.arange(input_len).view(1, -1) + usable_bs = self.full_batch_size if self.full_batch_size else 1 + position_ids = torch.arange(input_len).view(1, input_len).repeat(usable_bs, 1) inputs["input_ids"] = torch.concat( [ input_ids, @@ -87,13 +88,20 @@ def prepare_pytorch_inputs(self): if self.full_batch_size: inputs["input_ids"] = input_ids - inputs["position_ids"] = torch.arange(input_len).view(1, input_len) - inputs["batch_index"] = torch.arange(1).view(-1, 1) + inputs["position_ids"] = position_ids + inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) past_key_values = [] for i in range(self.n_layer): - past_key = torch.zeros((self.padding_shape), dtype=torch.float32) - past_value = torch.zeros((self.padding_shape), dtype=torch.float32) + if ( + all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) + and self.config.layer_types[i] == "sliding_attention" + ): + pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] + else: + pad_shape = self.padding_shape + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) past_key_values.append(pkv) inputs["past_key_values"] = tuple(past_key_values) @@ -113,18 +121,15 @@ def update_pytorch_inputs(self, inputs, pt_outputs): """ updated_inputs = {} if self.full_batch_size: - batch_index = torch.arange(1).view(-1, 1) - input_ids = pt_outputs.logits.detach().argmax(2) updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id) - updated_inputs["input_ids"][batch_index.view(-1)] = input_ids + updated_inputs["input_ids"][inputs["batch_index"].view(-1)] = input_ids position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1 updated_inputs["position_ids"] = torch.full((self.full_batch_size, 1), 0) - updated_inputs["position_ids"][batch_index.view(-1)] = position_ids - - updated_inputs["batch_index"] = 
torch.arange(self.full_batch_size).view(-1, 1) + updated_inputs["position_ids"][inputs["batch_index"].view(-1)] = position_ids + updated_inputs["batch_index"] = inputs["batch_index"] else: updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 @@ -169,8 +174,17 @@ def prepare_ort_inputs(self): inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32) else: for i in range(self.n_layer): - inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) - inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) + if ( + all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) + and self.config.layer_types[i] == "sliding_attention" + ): + pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] + else: + pad_shape = self.padding_shape + inputs["past_key." + str(i)] = np.zeros((pad_shape), dtype=np.float32) + inputs["past_value." + str(i)] = np.zeros((pad_shape), dtype=np.float32) + if self.full_batch_size: + inputs["batch_index"] = np.arange(self.full_batch_size).reshape(-1, 1) return inputs def update_ort_inputs(self, inputs, ort_outputs): @@ -191,7 +205,8 @@ def update_ort_inputs(self, inputs, ort_outputs): for i in range(self.n_layer): updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2] updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1] - + if self.full_batch_size: + updated_inputs["batch_index"] = inputs["batch_index"] return updated_inputs def update_ort_outputs(self, ort_outputs): diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index b6b38b8b4..10e6686d0 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -14,7 +14,8 @@ def json_serializable(obj): if isinstance(obj, set): - return sorted(obj) + # Convert set to a sorted list of strings for consistent hashing + return sorted([cls.__name__ if isinstance(cls, type) else str(cls) for cls in obj]) raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") @@ -55,7 +56,6 @@ def create_export_hash(**kwargs): export_params = {} export_params["output_names"] = kwargs.get("output_names") export_params["dynamic_axes"] = kwargs.get("dynamic_axes") - export_hash_params["export_params"] = export_params export_kwargs = kwargs.get("export_kwargs") @@ -67,5 +67,4 @@ def create_export_hash(**kwargs): export_hash_params.update(onnx_transform_kwargs) if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict): export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict() - return hash_dict_params(export_hash_params), export_hash_params diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index c54dadeac..61553e7ea 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import os +from typing import List import numpy as np import onnx @@ -276,6 +277,54 @@ def __init__( self.config = config self.gen_len = max_gen_len + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + + for idx, (image, query) in enumerate(zip(images, queries)): + # Prepare conversation format for each image-query pair + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + }, + ] + prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + # Process inputs + inputs = self.processor(images=image, text=prompt, return_tensors="pt") + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Generate tokens + output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) + offset_output = output[0, inputs["input_ids"].shape[1] :] + + # Decode and print output + py_output = self.processor.tokenizer.decode(offset_output).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Query:", repr(query)) + print("Completion:", repr(py_output)) + + generated_ids.append(offset_output.numpy()) + + return generated_ids + @torch.no_grad() def run_vlm_hf_model_on_pytorch(self, model, inputs): output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False) @@ -448,6 +497,57 @@ def __init__(self, batch_size, processor, config, image, prompt, prompt_len, ctx self.config = config self.gen_len = max_gen_len + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + + for idx, (image, query) in enumerate(zip(images, queries)): + num_patches_list = [] + pixel_values = [] + questions = [] + + pixel_value = self.processor.load_image(image, max_num=12) + num_patches_list.append(pixel_value.shape[0]) + question = "\n" + query + + pixel_values.append(pixel_value) + pixel_values = torch.cat(pixel_values, dim=0) + questions.append(question) + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = self.processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list) + + inputs = self.processor.tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_values.clone() + + generation_config = dict(max_new_tokens=self.gen_len, do_sample=False) + generation_config["eos_token_id"] = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip()) + + # Decode and print output + outputs = model.generate(**inputs, **generation_config) + offset_output = outputs[0].detach().numpy() + + py_output = self.processor.tokenizer.decode(offset_output, skip_special_tokens=True).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Completion:", repr(py_output)) + generated_ids.append(offset_output) + + return generated_ids + @torch.no_grad() def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config): outputs = model.generate(**inputs, **generation_config) @@ -490,3 +590,34 @@ def run_vlm_hf_model_on_pytorch(self, model, inputs, generation_config): print("Original HF Model Outputs (Torch CPU):") print("Completion:", repr(py_output)) return generated_ids + + @torch.no_grad() + def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries, generation_config): + """ + Function responsible for running HuggingFace ``PyTorch`` model for continuous batching + and return the output tokens for each prompt/image pair. 
+ + ``Mandatory`` Args: + :model (torch.nn.module): Original ``PyTorch`` model + :images (List[PIL.Image]): List of input images + :queries (List[str]): List of input queries + :generation_config (dict): Generation configuration parameters + + Return: + :List[numpy.ndarray]: List of generated output tokens for each prompt + """ + generated_ids = [] + for idx, (image, query) in enumerate(zip(images, queries)): + inputs = self.processor.process(images=[image], text=query) + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + outputs = model.generate_from_batch( + inputs, generation_config, tokenizer=self.processor.tokenizer, do_sample=False + ) + + offset_output = outputs[0, inputs["input_ids"].size(1) :] + + py_output = self.processor.tokenizer.decode(offset_output, skip_special_tokens=True).strip() + print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:") + print("Completion:", repr(py_output)) + generated_ids.append(offset_output) + return generated_ids diff --git a/QEfficient/utils/sampler_utils.py b/QEfficient/utils/sampler_utils.py index 6fb1b326f..82a0843bc 100644 --- a/QEfficient/utils/sampler_utils.py +++ b/QEfficient/utils/sampler_utils.py @@ -5,13 +5,18 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Set +from typing import Dict, List, Optional, Set +import torch + +from QEfficient.utils import constants from QEfficient.utils.constants import Constants from QEfficient.utils.logging_utils import logger -def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[bool] = None) -> bool: +def validate_sampler_inputs( + session_inputs: Set[str], include_sampler: Optional[bool] = None, include_guided_decoding: Optional[bool] = None +) -> bool: """ Validates whether the `QAICInferenceSession` inputs match inputs required for on-device sampling. @@ -28,7 +33,7 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ValueError if partial support is detected or if user intent conflicts with QPC capabilities. """ - sampler_inputs = Constants.SAMPLER_INPUTS + sampler_inputs = Constants.SAMPLER_INPUTS | ({"token_bitmasks"} if include_guided_decoding else set()) count = len(sampler_inputs & session_inputs) session_includes_sampler = True @@ -56,3 +61,92 @@ def validate_sampler_inputs(session_inputs: Set[str], include_sampler: Optional[ ) return session_includes_sampler + + +def get_sampling_inputs_and_outputs( + example_inputs: Dict[str, torch.Tensor], + output_names: List[str], + dynamic_axes: Dict[str, Dict[int, str]], + continuous_batching: bool, + vocab_size: int, + qaic_config: Dict, +): + """ + Updates the example inputs, output names, and dynamic axes to include + parameters relevant for on-device sampling during ONNX export. + + Parameters + ---------- + example_inputs : Dict[str, torch.Tensor] + Current dictionary of example inputs. + output_names : List[str] + Current list of output names. + dynamic_axes : Dict[str, Dict[int, str]] + Current dictionary of dynamic axes configurations. + continuous_batching : bool + Whether this model will be used for continuous batching in the future. + vocab_size: int + Vocabulary size for this model. + qaic_config : Dict + QAIC config dictionary. + + Returns + ------- + Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]] + Updated example inputs, output names, and dynamic axes including + sampling-related parameters. 
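+
+    Examples
+    --------
+    A minimal usage sketch; the surrounding export flow and the concrete values
+    shown here (e.g. ``vocab_size=32000`` and the ``qaic_config`` entries) are
+    illustrative assumptions, not defaults::
+
+        example_inputs, output_names, dynamic_axes = get_sampling_inputs_and_outputs(
+            example_inputs=example_inputs,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            continuous_batching=True,
+            vocab_size=32000,  # illustrative value only
+            qaic_config={"max_top_k_ids": 512, "include_guided_decoding": False},
+        )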
+ """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + seq_len: int = example_inputs["input_ids"].shape[-1] + + example_inputs["last_accepted_output_tokens"] = torch.zeros((bs, seq_len), dtype=torch.int64) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (fbs if continuous_batching else bs, vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + if qaic_config.get("include_guided_decoding", False): + example_inputs["token_bitmasks"] = torch.zeros((bs, vocab_size), dtype=torch.bool) + dynamic_axes["token_bitmasks"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes diff --git a/QEfficient/utils/torch_patches.py b/QEfficient/utils/torch_patches.py new file mode 100644 index 000000000..0b9b37afa --- /dev/null +++ b/QEfficient/utils/torch_patches.py @@ -0,0 +1,115 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""Monkey patches for torch.onnx.utils to fix ONNX export issues.""" + +import torch +import torch.onnx.utils as onnx_utils +from torch import _C + +# Store original references before patching +_original_setup_trace_module_map = onnx_utils._setup_trace_module_map +_original_get_module_attributes = getattr(onnx_utils, "_get_module_attributes", None) + + +def _setup_trace_module_map_patched( + model, + export_modules_as_functions, +): + """Patched version of _setup_trace_module_map that fixes onnx_attrs type mismatch.""" + + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + # FIX: use empty dict to avoid type mismatch + onnx_attrs = {} + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n)) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." 
+ ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _get_module_attributes(module): + """Helper function to get module attributes safely.""" + import typing + + import torch.nn + + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def apply_torch_patches(): + """Apply monkey patches for ONNX export.""" + onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched + if hasattr(onnx_utils, "_get_module_attributes"): + onnx_utils._get_module_attributes = _get_module_attributes + + +def undo_torch_patches(): + """Undo monkey patches and restore original functions.""" + onnx_utils._setup_trace_module_map = _original_setup_trace_module_map + if _original_get_module_attributes: + onnx_utils._get_module_attributes = _original_get_module_attributes diff --git a/README.md b/README.md index b396daede..cb6f32382 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,9 @@ *Latest news* :fire:
+- [10/2025] Added support for Qwen2.5VL Multi-Model [Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct) +- [10/2025] Added support for Mistral3 Multi-Model [mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) +- [10/2025] Added support for Molmo Multi-Model [allenai/Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) - [06/2025] Added support for Llama4 Multi-Model [meta-llama/Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) - [06/2025] Added support for Gemma3 Multi-Modal-Model [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it) - [06/2025] Added support of model `hpcai-tech/grok-1` [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1) @@ -90,9 +93,13 @@ python3.10 -m venv qeff_env source qeff_env/bin/activate pip install -U pip -# Clone and Install the QEfficient Repo. +# Clone and Install the QEfficient repository from the mainline branch pip install git+https://github.com/quic/efficient-transformers +# Clone and Install the QEfficient repository from a specific branch, tag or commit by appending @ref +# Release branch (e.g., release/v1.20.0): +pip install "git+https://github.com/quic/efficient-transformers@release/v1.20.0" + # Or build wheel package using the below command. pip install build wheel python -m build --wheel --outdir dist @@ -105,8 +112,8 @@ For more details about using ``QEfficient`` via Cloud AI 100 Apps SDK, visit [Li ## Documentation -* [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html#) -* [Python API](https://quic.github.io/efficient-transformers/source/hl_api.html) +* [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) +* [QEFF API](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) * [Validated Models](https://quic.github.io/efficient-transformers/source/validate.html) * [Models coming soon](https://quic.github.io/efficient-transformers/source/validate.html#models-coming-soon) diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png new file mode 100644 index 000000000..9e58da61d Binary files /dev/null and b/docs/image/girl_laughing.png differ diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 98ec72b7c..f15d8de2f 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -125,10 +125,10 @@ You can pass input prompts in single string but separate with pipe (|) symbol". python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 3 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompt "My name is|The flat earth theory is the belief that|The sun rises from" --mxfp6 --mos 1 --aic_enable_depth_first ``` -You can also pass path of txt file with input prompts when you want to run inference on lot of prompts, Example below, sample txt file(prompts.txt) is present in examples folder. +You can also pass path of txt file with input prompts when you want to run inference on lot of prompts, Example below, sample txt file(prompts.txt) is present in examples/sample_prompts folder. 
```bash -python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 3 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompts_txt_file_path examples/prompts.txt --mxfp6 --mos 1 --aic_enable_depth_first +python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 3 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompts_txt_file_path examples/sample_prompts/prompts.txt --mxfp6 --mos 1 --aic_enable_depth_first ``` **QNN CLI Inference Command** @@ -221,4 +221,26 @@ Benchmark the model on Cloud AI 100, run the infer API to print tokens and tok/s tokenizer = AutoTokenizer.from_pretrained(model_name) qeff_model.generate(prompts=["My name is"],tokenizer=tokenizer) ``` + +### Local Model Execution +If the model and tokenizer are already downloaded, we can directly load them from local path. + +```python +from QEfficient import QEFFAutoModelForCausalLM +from transformers import AutoTokenizer + +# Local path to the downloaded model. You can find downloaded HF models in: +# - Default location: ~/.cache/huggingface/hub/models--{model_name}/snapshots/{snapshot_id}/ +local_model_repo = "~/.cache/huggingface/hub/models--gpt2/snapshots/607a30d783dfa663caf39e06633721c8d4cfcd7e" + +# Load model from local path +model = QEFFAutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_model_repo) + +model.compile(num_cores=16) + +# Load tokenizer from the same local path +tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=local_model_repo) + +model.generate(prompts=["Hi there!!"], tokenizer=tokenizer) +``` End to End demo examples for various models are available in [**notebooks**](https://github.com/quic/efficient-transformers/tree/main/notebooks) directory. Please check them out. diff --git a/docs/source/release_docs.md b/docs/source/release_docs.md index 79e4bd181..97389e571 100644 --- a/docs/source/release_docs.md +++ b/docs/source/release_docs.md @@ -13,7 +13,7 @@ Welcome to the official release of **Efficient Transformer Library v1.20.0**! Th - Text & Image+Text support - Chunk attention, Single/Dual QPC support - Multi-image prompts enabled via VLLM interface - - [Llama4 Example Script](https://github.com/quic/efficient-transformers/blob/main/examples/llama4_example.py) + - [Llama4 Example Script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text/models/llama_vision/single_image.py) - **Grok-1** - Executable via [`QEffAutoModelForCausalLM`](#QEffAutoModelForCausalLM) @@ -22,7 +22,7 @@ Welcome to the official release of **Efficient Transformer Library v1.20.0**! Th - Executable via [`QEFFAutoModelForImageTextToText`](#QEFFAutoModelForImageTextToText) - Text & Image+Text support - Sliding window support - - [Gemma3 Example Script](https://github.com/quic/efficient-transformers/blob/main/examples/gemma3_example/gemma3_mm.py) + - [Gemma3 Example Script](https://github.com/quic/efficient-transformers/blob/main/examples/image_text_to_text/models/gemma_vision/inference.py) - **SwiftKV (Llama-3.1-SwiftKV-8B-Instruct)** @@ -32,7 +32,7 @@ Welcome to the official release of **Efficient Transformer Library v1.20.0**! 
Th - **GGUF Models** - Executable via [`QEffAutoModelForCausalLM`](#QEffAutoModelForCausalLM) - Execution support (non-quantized) - - [Example Script](https://github.com/quic/efficient-transformers/blob/main/examples/basic_gguf_models.py) + - [Example Script](https://github.com/quic/efficient-transformers/blob/main/examples/text_generation/gguf_models.py) - **FP8 Compressed Quantization** - Support for [`Llama-3.3-70B-Instruct-FP8-Dynamic`](https://huggingface.co/Infermatic/Llama-3.3-70B-Instruct-FP8-Dynamic) diff --git a/docs/source/supported_features.rst b/docs/source/supported_features.rst index 4177b451f..8260342f2 100644 --- a/docs/source/supported_features.rst +++ b/docs/source/supported_features.rst @@ -6,16 +6,18 @@ Supported Features * - Feature - Impact + * - `Compute Context Length (CCL) `_ + - Optimizes inference by using different context lengths during prefill and decode phases, reducing memory footprint and computation for shorter sequences while maintaining support for longer contexts. Supports both text-only and vision-language models. Refer `sample script `_ for more **details**. * - Sentence embedding, Flexible Pooling configuration and compilation with multiple sequence lengths - - Supports standard/custom pooling with AI 100 acceleration and sentence embedding. Enables efficient sentence embeddings via Efficient-Transformers. Compile with one or multiple seq_len; optimal graph auto-selected at runtime. Refer `sample script `_ for more **details**. + - Supports standard/custom pooling with AI 100 acceleration and sentence embedding. Enables efficient sentence embeddings via Efficient-Transformers. Compile with one or multiple seq_len; optimal graph auto-selected at runtime. Refer `sample script `_ for more **details**. * - `SpD, multiprojection heads `_ - - Implemented post-attention hidden size projections to speculate tokens ahead of the base model. Refer `sample script `_ for more **details**. + - Implemented post-attention hidden size projections to speculate tokens ahead of the base model. Refer `sample script `_ for more **details**. * - `QNN Compilation support `_ - Enabled for AutoModel classes QNN compilation capabilities for multi-models, embedding models and causal models. * - `Disaggregated serving `_ - It support for separate prefill and decode compilation for encoder (vision) and language models. * - `GGUF model execution `_ - - Supported GGUF model execution (without quantized weights). Refer `sample script `_ for more **details**. + - Supported GGUF model execution (without quantized weights). Refer `sample script `_ for more **details**. * - Replication of KV - Enabled FP8 model support on `replicate_kv_heads script `_. * - `gradient checkpointing `_ @@ -23,27 +25,29 @@ Supported Features * - Swift KV `Snowflake/Llama-3.1-SwiftKV-8B-Instruct `_ - Reduces computational overhead during inference by optimizing key-value pair processing, leading to improved throughput. Support for both `continuous and non-continuous batching execution `_ in SwiftKV * - :ref:`Vision Language Model ` - - Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer `sample script `_ for more **details**. + - Provides support for the AutoModelForImageTextToText class from the transformers library, enabling advanced vision-language tasks. Refer `sample script `_ for more **details**. 
* - :ref:`Speech Sequence to Sequence Model ` - - Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer `sample script `_ for more **details**. + - Provides support for the QEFFAutoModelForSpeechSeq2Seq Facilitates speech-to-text sequence models. Refer `sample script `_ for more **details**. * - Support for FP8 Execution - Enables execution with FP8 precision, significantly improving performance and reducing memory usage for computational tasks. * - Prefill caching - Enhances inference speed by caching key-value pairs for shared prefixes, reducing redundant computations and improving efficiency. + * - On Device Sampling + - Enables sampling operations to be executed directly on the QAIC device rather than the host CPU for QEffForCausalLM models. This enhancement significantly reduces host-device communication overhead and improves inference throughput and scalability. Refer `sample script `_ for more **details**. * - Prompt-Lookup Decoding - - Speeds up text generation by using overlapping parts of the input prompt and the generated text, making the process faster without losing quality. Refer `sample script `_ for more **details**. + - Speeds up text generation by using overlapping parts of the input prompt and the generated text, making the process faster without losing quality. Refer `sample script `_ for more **details**. * - :ref:`PEFT LoRA support ` - - Enables parameter-efficient fine-tuning using low-rank adaptation techniques, reducing the computational and memory requirements for fine-tuning large models. Refer `sample script `_ for more **details**. + - Enables parameter-efficient fine-tuning using low-rank adaptation techniques, reducing the computational and memory requirements for fine-tuning large models. Refer `sample script `_ for more **details**. * - :ref:`QNN support ` - Enables compilation using QNN SDK, making Qeff adaptable for various backends in the future. * - :ref:`Embedding model support ` - Facilitates the generation of vector embeddings for retrieval tasks. * - :ref:`Speculative Decoding ` - - Accelerates text generation by using a draft model to generate preliminary predictions, which are then verified by the target model, reducing latency and improving efficiency. Refer `sample script `_ for more **details**. + - Accelerates text generation by using a draft model to generate preliminary predictions, which are then verified by the target model, reducing latency and improving efficiency. Refer `sample script `_ for more **details**. * - :ref:`Finite lorax ` - - Users can activate multiple LoRA adapters and compile them with the base model. At runtime, they can specify which prompt should use which adapter, enabling mixed adapter usage within the same batch. Refer `sample script `_ for more **details**. + - Users can activate multiple LoRA adapters and compile them with the base model. At runtime, they can specify which prompt should use which adapter, enabling mixed adapter usage within the same batch. Refer `sample script `_ for more **details**. * - Python and CPP Inferencing API support - - Provides flexibility while running inference with Qeff and enabling integration with various applications and improving accessibility for developers. Refer `sample script `_ for more **details**. + - Provides flexibility while running inference with Qeff and enabling integration with various applications and improving accessibility for developers. Refer `sample script `_ for more **details**. 
* - :ref:`Continuous batching ` - Optimizes throughput and latency by dynamically batching requests, ensuring efficient use of computational resources. * - AWQ and GPTQ support @@ -54,7 +58,5 @@ Supported Features - A script for computing the perplexity of a model, allowing for the evaluation of model performance and comparison across different models and datasets. Refer `sample script `_ for more **details**. * - KV Heads Replication Script - A sample script for replicating key-value (KV) heads for the Llama-3-8B-Instruct model, running inference with the original model, replicating KV heads, validating changes, and exporting the modified model to ONNX format. Refer `sample script `_ for more **details**. - * - Context Length Specializations (upcoming) - - Increases the maximum context length that models can handle, allowing for better performance on tasks requiring long sequences of text. * - Block Attention (in progress) - - Reduces inference latency and computational cost by dividing context into blocks and reusing key-value states, particularly useful in RAG. \ No newline at end of file + - Reduces inference latency and computational cost by dividing context into blocks and reusing key-value states, particularly useful in RAG. diff --git a/docs/source/validate.md b/docs/source/validate.md index e17e85578..b5ab87629 100644 --- a/docs/source/validate.md +++ b/docs/source/validate.md @@ -4,21 +4,21 @@ ## Text-only Language Models ### Text Generation Task -**QEff Auto Class:** [`QEFFAutoModelForCausalLM`](#QEFFAutoModelForCausalLM) +**QEff Auto Class:** `QEFFAutoModelForCausalLM` -| Architecture | Model Family | Representative Models | CB Support | -|-------------------------|--------------------|--------------------------------------------------------------------------------------|------------| -| **FalconForCausalLM** | Falcon | [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b) | ✔️ | -| **Qwen3MoeForCausalLM** | Qwen3Moe | [Qwen/Qwen3-30B-A3B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-30B-A3B-Instruct-2507) | ✔️ | +| Architecture | Model Family | Representative Models | [vLLM Support](https://quic.github.io/cloud-ai-sdk-pages/latest/Getting-Started/Installation/vLLM/vLLM/index.html) | +|-------------------------|--------------------|--------------------------------------------------------------------------------------|--------------| +| **FalconForCausalLM** | Falcon** | [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b) | ✔️ | +| **Qwen3MoeForCausalLM** | Qwen3Moe | [Qwen/Qwen3-30B-A3B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-30B-A3B-Instruct-2507) | ✕ | | **GemmaForCausalLM** | CodeGemma | [google/codegemma-2b](https://huggingface.co/google/codegemma-2b)
[google/codegemma-7b](https://huggingface.co/google/codegemma-7b) | ✔️ | -| | Gemma | [google/gemma-2b](https://huggingface.co/google/gemma-2b)
[google/gemma-7b](https://huggingface.co/google/gemma-7b)
[google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b)
[google/gemma-2-9b](https://huggingface.co/google/gemma-2-9b)
[google/gemma-2-27b](https://huggingface.co/google/gemma-2-27b) | ✔️ | +| | Gemma*** | [google/gemma-2b](https://huggingface.co/google/gemma-2b)
[google/gemma-7b](https://huggingface.co/google/gemma-7b)
[google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b)
[google/gemma-2-9b](https://huggingface.co/google/gemma-2-9b)
[google/gemma-2-27b](https://huggingface.co/google/gemma-2-27b) | ✔️ | | **GPTBigCodeForCausalLM** | Starcoder1.5 | [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) | ✔️ | | | Starcoder2 | [bigcode/starcoder2-15b](https://huggingface.co/bigcode/starcoder2-15b) | ✔️ | | **GPTJForCausalLM** | GPT-J | [EleutherAI/gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b) | ✔️ | | **GPT2LMHeadModel** | GPT-2 | [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) | ✔️ | | **GraniteForCausalLM** | Granite 3.1 | [ibm-granite/granite-3.1-8b-instruct](https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
[ibm-granite/granite-guardian-3.1-8b](https://huggingface.co/ibm-granite/granite-guardian-3.1-8b) | ✔️ | | | Granite 20B | [ibm-granite/granite-20b-code-base-8k](https://huggingface.co/ibm-granite/granite-20b-code-base-8k)
[ibm-granite/granite-20b-code-instruct-8k](https://huggingface.co/ibm-granite/granite-20b-code-instruct-8k) | ✔️ | -| **InternVLChatModel** | Intern-VL | [OpenGVLab/InternVL2_5-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B) | | +| **InternVLChatModel** | Intern-VL | [OpenGVLab/InternVL2_5-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B) | ✔️ | | | | **LlamaForCausalLM** | CodeLlama | [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf)
[codellama/CodeLlama-13b-hf](https://huggingface.co/codellama/CodeLlama-13b-hf)
[codellama/CodeLlama-34b-hf](https://huggingface.co/codellama/CodeLlama-34b-hf) | ✔️ | | | DeepSeek-R1-Distill-Llama | [deepseek-ai/DeepSeek-R1-Distill-Llama-70B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B) | ✔️ | | | InceptionAI-Adapted | [inceptionai/jais-adapted-7b](https://huggingface.co/inceptionai/jais-adapted-7b)
[inceptionai/jais-adapted-13b-chat](https://huggingface.co/inceptionai/jais-adapted-13b-chat)
[inceptionai/jais-adapted-70b](https://huggingface.co/inceptionai/jais-adapted-70b) | ✔️ | @@ -31,45 +31,42 @@ | **MistralForCausalLM** | Mistral | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) | ✔️ | | **MixtralForCausalLM** | Codestral
Mixtral | [mistralai/Codestral-22B-v0.1](https://huggingface.co/mistralai/Codestral-22B-v0.1)
[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | ✔️ | | **MPTForCausalLM** | MPT | [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b) | ✔️ | -| **Phi3ForCausalLM** | Phi-3, Phi-3.5 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | ✔️ | +| **Phi3ForCausalLM** | Phi-3**, Phi-3.5** | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | ✔️ | | **QwenForCausalLM** | DeepSeek-R1-Distill-Qwen | [DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | ✔️ | | | Qwen2, Qwen2.5 | [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✔️ | | **LlamaSwiftKVForCausalLM** | swiftkv | [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct) | ✔️ | -| **Grok1ModelForCausalLM** | grok-1 | [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1) | ✔️ | - ---- - +| **Grok1ModelForCausalLM** | grok-1 | [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1) | ✕ | +- ** set "trust-remote-code" flag to True for e2e inference with vLLM +- *** pass "disable-sliding-window" flag for e2e inference of Gemma-2 family of models with vLLM ## Embedding Models ### Text Embedding Task -**QEff Auto Class:** [`QEFFAutoModel`](#QEFFAutoModel) - -| Architecture | Model Family | Representative Models | -|--------------|--------------|---------------------------------| -| **BertModel** | BERT-based | [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5)
[BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)
[BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5)
[e5-large-v2](https://huggingface.co/intfloat/e5-large-v2) | -| **LlamaModel** | Llama-based | [intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) | -| **MPNetForMaskedLM** | MPNet | [sentence-transformers/multi-qa-mpnet-base-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1) | -| **MistralModel** | Mistral | [e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) | -| **NomicBertModel** | NomicBERT | [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) | -| **Qwen2ForCausalLM** | Qwen2 | [stella_en_1.5B_v5](https://huggingface.co/NovaSearch/stella_en_1.5B_v5) | -| **RobertaModel** | RoBERTa | [ibm-granite/granite-embedding-30m-english](https://huggingface.co/ibm-granite/granite-embedding-30m-english)
[ibm-granite/granite-embedding-125m-english](https://huggingface.co/ibm-granite/granite-embedding-125m-english) | -| **XLMRobertaForSequenceClassification** | XLM-RoBERTa | [bge-reranker-v2-m3bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | -| **XLMRobertaModel** | XLM-RoBERTa |[ibm-granite/granite-embedding-107m-multilingual](https://huggingface.co/ibm-granite/granite-embedding-107m-multilingual)
[ibm-granite/granite-embedding-278m-multilingual](https://huggingface.co/ibm-granite/granite-embedding-278m-multilingual) | - ---- +**QEff Auto Class:** `QEFFAutoModel` + +| Architecture | Model Family | Representative Models | vLLM Support | +|--------------|--------------|---------------------------------|--------------| +| **BertModel** | BERT-based | [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5)
[BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)
[BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5)
[e5-large-v2](https://huggingface.co/intfloat/e5-large-v2) | ✔️ | +| **MPNetForMaskedLM** | MPNet | [sentence-transformers/multi-qa-mpnet-base-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1) | ✕ | +| **MistralModel** | Mistral | [e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) | ✕ | +| **NomicBertModel** | NomicBERT | [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) | ✕ | +| **Qwen2ForCausalLM** | Qwen2 | [stella_en_1.5B_v5](https://huggingface.co/NovaSearch/stella_en_1.5B_v5) | ✔️ | +| **RobertaModel** | RoBERTa | [ibm-granite/granite-embedding-30m-english](https://huggingface.co/ibm-granite/granite-embedding-30m-english)
[ibm-granite/granite-embedding-125m-english](https://huggingface.co/ibm-granite/granite-embedding-125m-english) | ✔️ | +| **XLMRobertaForSequenceClassification** | XLM-RoBERTa | [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | ✕ | +| **XLMRobertaModel** | XLM-RoBERTa | [ibm-granite/granite-embedding-107m-multilingual](https://huggingface.co/ibm-granite/granite-embedding-107m-multilingual)
[ibm-granite/granite-embedding-278m-multilingual](https://huggingface.co/ibm-granite/granite-embedding-278m-multilingual) | ✔️ | ## Multimodal Language Models ### Vision-Language Models (Text + Image Generation) -**QEff Auto Class:** [`QEFFAutoModelForImageTextToText`](#QEFFAutoModelForImageTextToText) +**QEff Auto Class:** `QEFFAutoModelForImageTextToText` -| Architecture | Model Family | Representative Models | CB Support | Single Qpc Support | Dual Qpc Support | -|-----------------------------|--------------|----------------------------------------------------------------------------------------|------------|--------------------|------------------| -| **LlavaForConditionalGeneration** | LLaVA-1.5 | [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) | ✕ | ✔️ | ✔️ | -| **MllamaForConditionalGeneration** | Llama 3.2 | [meta-llama/Llama-3.2-11B-Vision Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
[meta-llama/Llama-3.2-90B-Vision](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision) | ✕ | ✔️ | ✔️ | -|**LlavaNextForConditionalGeneration** | Granite Vision | [ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b) | ✕ | ✕ | ✔️ | -|**Llama4ForConditionalGeneration** | Llama-4-Scout | [Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) | ✕ | ✔️ | ✔️ | -|**Gemma3ForConditionalGeneration** | Gemma3 | [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)| ✕ | ✔️ | ✔️ | +| Architecture | Model Family | Representative Models | QEff Single QPC | QEff Dual QPC | vLLM Single QPC | vLLM Dual QPC | +|------------------------------------|--------------|----------------------------------------------------------------------------------------|-----------------|---------------|-----------------|---------------| +| **LlavaForConditionalGeneration** | LLaVA-1.5 | [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) | ✔️ | ✔️ | ✔️ | ✔️ | +| **MllamaForConditionalGeneration** | Llama 3.2 | [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
[meta-llama/Llama-3.2-90B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision-Instruct) | ✔️ | ✔️ | ✔️ | ✔️ | +| **LlavaNextForConditionalGeneration** | Granite Vision | [ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b) | ✕ | ✔️ | ✕ | ✔️ | +| **Llama4ForConditionalGeneration** | Llama-4-Scout | [Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct) | ✔️ | ✔️ | ✔️ | ✔️ | +| **Gemma3ForConditionalGeneration** | Gemma3*** | [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it) | ✔️ | ✔️ | ✔️ | ✕ | +- *** pass "disable-sliding-window" flag for e2e inference with vLLM **Dual QPC:** @@ -85,25 +82,20 @@ In the Dual QPC(Qualcomm Program Container) setup, the model is split across two **Single QPC:** In the single QPC(Qualcomm Program Container) setup, the entire model—including both image encoding and text generation—runs within a single QPC. There is no model splitting, and all components operate within the same execution environment. -**For more details click [here](#QEFFAutoModelForImageTextToText)** -```{NOTE} + +**Note:** The choice between Single and Dual QPC is determined during model instantiation using the `kv_offload` setting. If the `kv_offload` is set to `True` it runs in dual QPC and if its set to `False` model runs in single QPC mode. -``` --- - ### Audio Models (Automatic Speech Recognition) - Transcription Task +**QEff Auto Class:** `QEFFAutoModelForSpeechSeq2Seq` -**QEff Auto Class:** [`QEFFAutoModelForSpeechSeq2Seq`](#QEFFAutoModelForSpeechSeq2Seq) - -| Architecture | Model Family | Representative Models | -|--------------|--------------|----------------------------------------------------------------------------------------| -| **Whisper** | Whisper | [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny)
[openai/whisper-base](https://huggingface.co/openai/whisper-base)
[openai/whisper-small](https://huggingface.co/openai/whisper-small)
[openai/whisper-medium](https://huggingface.co/openai/whisper-medium)
[openai/whisper-large](https://huggingface.co/openai/whisper-large)
[openai/whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) | - ---- +| Architecture | Model Family | Representative Models | vLLM Support | +|--------------|--------------|----------------------------------------------------------------------------------------|--------------| +| **Whisper** | Whisper | [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny)
[openai/whisper-base](https://huggingface.co/openai/whisper-base)
[openai/whisper-small](https://huggingface.co/openai/whisper-small)
[openai/whisper-medium](https://huggingface.co/openai/whisper-medium)
[openai/whisper-large](https://huggingface.co/openai/whisper-large)
[openai/whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) | ✔️ | (models_coming_soon)= # Models Coming Soon diff --git a/examples/CONTRIBUTING.md b/examples/CONTRIBUTING.md new file mode 100644 index 000000000..d7766fa92 --- /dev/null +++ b/examples/CONTRIBUTING.md @@ -0,0 +1,260 @@ +# Contributing Examples + +This guide explains how to add new examples to the QEfficient repository. + +## When to Add an Example + +Add a new example if: +- The model requires special configuration not covered by existing examples +- You're demonstrating a new feature or optimization technique +- The model has unique requirements (dependencies, image sizes, etc.) + +Don't add an example if: +- The model works with existing generic examples (just use those) +- The only difference is the model name, you can include the model name in validated model list and model class readme file. + +## Directory Structure + +Place your example in the appropriate domain: +- `text_generation/` - Text-only language models +- `image_text_to_text/` - Vision-language models +- `embeddings/` - Embedding models +- `audio/` - Speech and audio models +- `peft/` - Fine-tuning and adapter examples +- `performance/` - Optimization techniques + + + +## File Requirements + +### 1. Python Script + +Your example script should: +- Include the copyright header +- Use argparse for command-line arguments +- Provide clear error messages +- Print results in a readable format + +Basic template: +```python +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +from transformers import AutoTokenizer +from QEfficient import QEFFAutoModelForCausalLM + +def main(): + parser = argparse.ArgumentParser(description="Description of what this example does") + parser.add_argument("--model-name", type=str, required=True, help="HuggingFace model ID") + parser.add_argument("--prompt", type=str, default="Hello", help="Input prompt") + parser.add_argument("--prefill-seq-len", type=int, default=32) + parser.add_argument("--ctx-len", type=int, default=128) + parser.add_argument("--num-cores", type=int, default=16) + parser.add_argument("--num-devices", type=int, default=1) + args = parser.parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name) + + qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=args.num_devices, + ) + + exec_info = model.generate( + tokenizer=tokenizer, + prompts=[args.prompt], + ) + + print(f"Generated: {exec_info.generated_texts[0]}") + +if __name__ == "__main__": + main() +``` + +### 2. README.md + +Each model-specific example needs a README explaining: +- What the model does +- Any special requirements +- How to run it +- Expected output + +Template: +```markdown +# [Model Name] + +## Overview +Brief description of the model and what makes it special. + +## Requirements +```bash +# For single package +pip install package-name==1.2.3 + +# For multiple packages +pip install package-name==1.2.3 another-package==4.5.6 + +# Or use a requirements.txt file +pip install -r requirements.txt +``` + +**Note:** Always specify exact versions to ensure reproducibility. Use `pip show package-name` to check installed versions. 
+ +## Usage +```bash +python inference.py --model-name [model-id] --prompt "Your prompt" +``` + +## Special Notes +Any model-specific considerations, limitations, or configuration details. + +## References +- Model card: [link] +- Paper: [link] (optional) + +## Code Guidelines + +- Use clear variable names +- Add comments for non-obvious code +- Handle errors gracefully +- Follow existing code style in the repository +- Test your example before submitting + +## Testing Your Example + +Before submitting: +1. Run the example with default parameters +2. Test with different model sizes if applicable +3. Verify the README instructions work +4. Check that all dependencies are documented + +## Submitting Your Contribution + +Follow these steps to submit your example to the QEfficient repository: + +### 1. Fork and Clone the Repository + +First, fork the repository to your GitHub account, then clone your fork: + +```bash +# Fork the repository on GitHub (click the "Fork" button) +# Then clone your fork +git clone git@github.com:YOUR_USERNAME/efficient-transformers.git +cd efficient-transformers + +# Add upstream remote to keep your fork in sync +git remote add upstream git@github.com:quic/efficient-transformers.git +``` + +### 2. Create a Feature Branch + +Create a descriptive branch for your changes: + +```bash +# Update your main branch +git checkout main +git pull upstream main + +# Create a new branch +git checkout -b add-[model-name]-example +``` + +### 3. Make Your Changes + +Add your example files following the guidelines above: +- Python script with proper copyright header +- README.md with clear documentation +- requirements.txt (if needed) + +### 4. Run Pre-commit Checks + +Before committing, ensure your code passes all quality checks: + +```bash +# Install pre-commit if not already installed +pip install pre-commit + +# Run pre-commit on your changed files +pre-commit run --files path/to/your/file1.py path/to/your/file2.md +``` + +**Important:** If pre-commit reports any failures: +- Some issues will be auto-fixed (formatting, trailing whitespace, etc.) +- For issues that aren't auto-fixed, manually correct them +- Re-run `pre-commit run --files ` until all checks pass + +### 5. Commit with Sign-off (DCO) + +All commits must be signed off to comply with the Developer Certificate of Origin (DCO): + +```bash +# Stage your changes +git add examples/your_domain/your_example.py +git add examples/your_domain/README.md + +# Commit with sign-off +git commit -s --author "Your Name " -m "Add [model-name] example + +- Implements inference for [model-name] +- Includes documentation and usage examples +- Tested with [specific configurations]" +``` + +**Commit Message Guidelines:** +- Use a clear, descriptive title +- Add a blank line, then detailed description if needed +- Always include the `-s` flag for DCO sign-off + +### 6. Push to Your Fork + +Push your branch to your forked repository: + +```bash +git push origin add-[model-name]-example +``` + +### 7. Create a Pull Request + +1. Go to your fork on GitHub +2. Click "Compare & pull request" for your branch +3. Fill out the PR template with: + - **Title:** Clear, descriptive title (e.g., "Add Llama-3.2-Vision example") + - **Description:** + - What the example demonstrates + - Why it's needed (what makes it different from existing examples) + - Any special testing considerations + - Link to model card or documentation + - **Testing:** Describe how you tested the example + +### 8. 
Ensure CI Checks Pass + +After creating the PR, verify that all automated checks pass: + +- ✅ **DCO Check:** Ensures all commits are signed off +- ✅ **Lint Check:** Code style and formatting validation +- ✅ **Tests:** Automated test suite (if applicable) + +If any checks fail: +1. Review the error messages in the PR +2. Make necessary fixes in your local branch +3. Commit and push the fixes (with sign-off) +4. The PR will automatically update and re-run checks + +### 9. Address Review Feedback + +Maintainers will review your PR and may request changes: +- Make requested changes in your local branch +- Commit with sign-off and push to update the PR +- Respond to comments to facilitate discussion + +## Questions + +For questions or issues, open a GitHub issue or discussion. diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000..3913b25ce --- /dev/null +++ b/examples/README.md @@ -0,0 +1,97 @@ +# QEfficient Examples + +Examples for running models on Qualcomm Cloud AI 100. + +For detailed documentation, see https://quic.github.io/efficient-transformers/ + +## Quick Navigation + +### Text Generation +Language model inference. + +| Example | Description | Script | +|---------|-------------|--------| +| Basic Inference | Simple text generation | [text_generation/basic_inference.py](text_generation/basic_inference.py) | +| GGUF Models | GGUF format support | [text_generation/gguf_models.py](text_generation/gguf_models.py) | +| MoE Models | Mixture of Experts | [text_generation/moe_inference.py](text_generation/moe_inference.py) | +| Continuous Batching | Dynamic batching | [text_generation/continuous_batching.py](text_generation/continuous_batching.py) | + +[See all text generation examples →](text_generation/) + +### Image-Text-to-Text +Vision-language models. + +| Example | Model | Script | +|---------|---------------|---------------| +| Basic VLM | Most VLMs | [image_text_to_text/basic_vlm_inference.py](image_text_to_text/basic_vlm_inference.py) | + +[See all vision-language examples →](image_text_to_text/) + +### Embeddings +Sentence and document embeddings. + +| Example | Model | Script | +|---------|-------|--------| +| Text Embeddings | all-MiniLM-L6-v2 | [embeddings/text_embeddings.py](embeddings/text_embeddings.py) | + +[See all embedding examples →](embeddings/) + +### Audio +Speech processing models. + +| Example | Model | Task | Script | +|---------|-------|------|--------| +| Speech-to-Text | Whisper | Transcription | [audio/speech_to_text.py](audio/speech_to_text.py) | +| CTC Speech Recognition | Wav2Vec2 | Recognition | [audio/wav2vec2_inference.py](audio/wav2vec2_inference.py) | + +[See all audio examples →](audio/) + +### PEFT +Parameter-efficient fine-tuning. + +| Example | Description | Script | +|---------|-------------|--------| +| Single Adapter | Load and use one adapter | [peft/single_adapter.py](peft/single_adapter.py) | +| Multi-Adapter | Multiple adapters with CB | [peft/multi_adapter.py](peft/multi_adapter.py) | + +**Note:** PEFT examples use hardcoded configurations to demonstrate specific adapter workflows. Modify the scripts directly to test different adapters or configurations. + +[See all PEFT examples →](peft/) + +### Performance +Optimization techniques. 
+ +| Example | Technique | Script | +|---------|-----------|--------| +| Draft-based SpD | Speculative decoding | [performance/speculative_decoding/draft_based.py](performance/speculative_decoding/draft_based.py) | +| Prompt Lookup | N-gram speculation | [performance/speculative_decoding/prompt_lookup.py](performance/speculative_decoding/prompt_lookup.py) | +| Multi-Projection | Turbo models | [performance/speculative_decoding/multi_projection.py](performance/speculative_decoding/multi_projection.py) | +| On-Device Sampling | Sampling parameters | [performance/on_device_sampling.py](performance/on_device_sampling.py) | +| Compute Context Length | Dynamic context optimization | [performance/compute_context_length/basic_inference.py](performance/compute_context_length/basic_inference.py) | +| C++ Execution | Native C++ API | [performance/cpp_execution/](performance/cpp_execution/) | + +[See all performance examples →](performance/) + +## Installation + +For installation instructions, see the [Quick Installation guide](../README.md#quick-installation) in the main README. + + +## Running Examples + +### Python Scripts + +Basic usage: +```bash +python text_generation/basic_inference.py \ + --model-name gpt2 \ + --prompt "Hello, how are you?" +``` + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on adding new examples. + +## Documentation + +Full documentation: https://quic.github.io/efficient-transformers/ diff --git a/examples/audio/README.md b/examples/audio/README.md new file mode 100644 index 000000000..df0204d87 --- /dev/null +++ b/examples/audio/README.md @@ -0,0 +1,87 @@ +# Audio Examples + +Examples for running audio processing models on Qualcomm Cloud AI 100. + +## Dependencies + +Install required packages: +```bash +pip install librosa==0.10.2 soundfile==0.13.1 +``` + +## Authentication + +For private/gated models, export your HuggingFace token: +```bash +export HF_TOKEN= +``` + +## Supported Models + +**QEff Auto Classes:** +- `QEFFAutoModelForSpeechSeq2Seq` (for Whisper models) +- `QEFFAutoModelForCTC` (for Wav2Vec2 models) + +For the complete list of supported audio models, see the [Validated Models - Audio Section](../../docs/source/validate.md#audio-models). + +Popular models include: +- Whisper (tiny, base, small, medium, large, large-v3-turbo) +- Wav2Vec2 (base-960h) + +## Available Examples + +### speech_to_text.py +Speech-to-text transcription using Whisper models. + +**Usage:** +```bash +# With default parameters +python speech_to_text.py \ + +# With custom parameters +python speech_to_text.py \ + --model-name openai/whisper-tiny \ + --ctx-len 25 \ + --num-cores 16 +``` + +**Parameters:** +- `--model-name`: HuggingFace Whisper model ID (default: `openai/whisper-tiny`) +- `--ctx-len`: Context length for generation (default: `25`) +- `--num-cores`: Number of cores (default: `16`) + +This example: +- Loads a sample audio from the librispeech dataset +- Uses Whisper-tiny model by default +- Compiles and runs inference on Cloud AI 100 +- Outputs the transcribed text + +### wav2vec2_inference.py +Speech recognition using Wav2Vec2 models with CTC (Connectionist Temporal Classification). 
+ +**Usage:** +```bash +# With default parameters +python wav2vec2_inference.py + +# With custom parameters +python wav2vec2_inference.py \ + --model-name facebook/wav2vec2-base-960h \ + --num-cores 16 +``` + +**Parameters:** +- `--model-name`: HuggingFace CTC model ID (default: `facebook/wav2vec2-base-960h`) +- `--num-cores`: Number of cores (default: `16`) + +This example: +- Loads a sample audio from the librispeech dataset +- Uses Wav2Vec2-base-960h model by default +- Compiles and runs inference on Cloud AI 100 +- Outputs the recognized text + +## Documentation + +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) +- [Validated Audio Models](https://quic.github.io/efficient-transformers/source/validate.html#audio-models) +- [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) diff --git a/examples/audio/speech_to_text.py b/examples/audio/speech_to_text.py new file mode 100644 index 000000000..9f1df19aa --- /dev/null +++ b/examples/audio/speech_to_text.py @@ -0,0 +1,66 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +from datasets import load_dataset +from transformers import AutoProcessor + +from QEfficient import QEFFAutoModelForSpeechSeq2Seq + + +def main(): + parser = argparse.ArgumentParser(description="Speech-to-text inference with Whisper") + parser.add_argument( + "--model-name", + type=str, + default="openai/whisper-tiny", + help="HuggingFace Whisper model ID", + ) + parser.add_argument( + "--ctx-len", + type=int, + default=25, + help="Context length for generation", + ) + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + args = parser.parse_args() + + print(f"Loading Whisper model: {args.model_name}") + + ## STEP 1 -- load audio sample + + # Using a standard english dataset + print("Loading audio sample from dataset...") + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample_rate = ds[0]["audio"]["sampling_rate"] + data = ds[0]["audio"]["array"] + + # Reshape so shape corresponds to data with batch size 1 + data = data.reshape(-1) + + # Load processor + processor = AutoProcessor.from_pretrained(args.model_name) + + ## STEP 2 -- init base model + qeff_model = QEFFAutoModelForSpeechSeq2Seq.from_pretrained(args.model_name) + + ## STEP 3 -- export and compile model + qeff_model.compile(num_cores=args.num_cores) + + ## STEP 4 -- generate output for loaded input and processor + exec_info = qeff_model.generate( + inputs=processor(data, sampling_rate=sample_rate, return_tensors="pt"), generation_len=args.ctx_len + ) + + ## STEP 5 -- use processor to decode output + transcription = processor.batch_decode(exec_info.generated_ids)[0] + print(f"\nTranscription: {transcription}") + + +if __name__ == "__main__": + main() diff --git a/examples/audio/wav2vec2_inference.py b/examples/audio/wav2vec2_inference.py new file mode 100644 index 000000000..9d310b1c2 --- /dev/null +++ b/examples/audio/wav2vec2_inference.py @@ -0,0 +1,54 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +from datasets import load_dataset +from transformers import AutoProcessor + +from QEfficient import QEFFAutoModelForCTC + + +def main(): + parser = argparse.ArgumentParser(description="CTC speech recognition inference with Wav2Vec2") + parser.add_argument( + "--model-name", + type=str, + default="facebook/wav2vec2-base-960h", + help="HuggingFace CTC model ID (e.g., Wav2Vec2)", + ) + + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + args = parser.parse_args() + + print(f"Loading CTC model: {args.model_name}") + + ## STEP 1 -- load audio sample + # Using a standard english dataset + print("Loading audio sample from dataset...") + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + data = ds[0]["audio"]["array"] + + # Reshape so shape corresponds to data with batch size 1 + data = data.reshape(-1) + + # Load processor + processor = AutoProcessor.from_pretrained(args.model_name) + + ## STEP 2 -- Load the model + model = QEFFAutoModelForCTC.from_pretrained(args.model_name) + + ## STEP 3 -- Compile the model + model.compile(num_cores=args.num_cores) + + ## STEP 4 -- Run the model and generate the output + model_output = model.generate(processor, inputs=data) + print(f"\nTranscription: {model_output}") + + +if __name__ == "__main__": + main() diff --git a/examples/basic_gguf_models.py b/examples/basic_gguf_models.py deleted file mode 100644 index 84fc73059..000000000 --- a/examples/basic_gguf_models.py +++ /dev/null @@ -1,23 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -# This is the work example of the GGUF models with the AI 100 - -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM - -# Load the model and tokenizer -model_name = "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF" -gguf_file = "Mistral-7B-Instruct-v0.3.fp16.gguf" -# org_model_name = "mistralai/Mistral-7B-Instruct-v0.3" - -tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file) -model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file) - -generated_qpc_path = model.compile(prefill_seq_len=32, ctx_len=128, num_cores=16, num_devices=1) -model.generate(prompts=["How are you?"], tokenizer=tokenizer) diff --git a/examples/diffusers/flux/README.md b/examples/diffusers/flux/README.md new file mode 100644 index 000000000..2a3c1605f --- /dev/null +++ b/examples/diffusers/flux/README.md @@ -0,0 +1,243 @@ +# FLUX.1-schnell Image Generation Examples + +This directory contains examples demonstrating how to use the QEffFluxPipeline to generate images using the FLUX.1-schnell model from Black Forest Labs. + +## Overview + +FLUX.1-schnell is a fast, distilled version of the FLUX.1 text-to-image model optimized for speed with minimal quality loss. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient image generation. 
+ +## Files + +- **`flux_1_schnell.py`** - Basic example showing simple image generation +- **`flux_1_shnell_custom.py`** - Advanced example with customization options +- **`flux_config.json`** - Configuration file for pipeline modules + +## Quick Start + +### Basic Usage + +The simplest way to generate images with FLUX.1-schnell: + +```python +from QEfficient import QEffFluxPipeline +import torch + +# Initialize pipeline +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + +# Generate image +output = pipeline( + prompt="A laughing girl", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Save image +output.images[0].save("girl_laughing.png") +``` + +Run the basic example: +```bash +python flux_1_schnell.py +``` + +## Advanced Customization + +The `flux_1_shnell_custom.py` example demonstrates several advanced features: + +### 1. Custom Model Components + +You can provide custom text encoders, transformers, and tokenizers: + +```python +pipeline = QEffFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + text_encoder=custom_text_encoder, + transformer=custom_transformer, + tokenizer=custom_tokenizer, +) +``` + +### 2. Custom Scheduler + +Replace the default scheduler with your own: + +```python +pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) +``` + +### 3. Reduce Model Layers for Faster Inference + +Trade quality for speed by reducing transformer blocks: + +```python +original_blocks = pipeline.transformer.model.transformer_blocks +org_single_blocks = pipeline.transformer.model.single_transformer_blocks +pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) +pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) +pipeline.transformer.model.config['num_layers'] = 1 +pipeline.transformer.model.config['num_single_layers'] = 1 +``` + +### 4. Pre-compile with Custom Configuration + +Compile the model separately before generation: + +```python +pipeline.compile( + compile_config="examples/diffusers/flux/flux_config.json", + height=512, + width=512, + use_onnx_subfunctions=False +) +``` + +### 5. Runtime Configuration + +Use custom configuration during generation: + +```python +output = pipeline( + prompt="A girl laughing", + custom_config_path="examples/diffusers/flux/flux_config.json", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) +``` + +Run the advanced example: +```bash +python flux_1_shnell_custom.py +``` + +## Configuration File + +The `flux_config.json` file controls compilation and execution settings for each pipeline module: + +### Module Structure + +The configuration includes four main modules: + +1. **text_encoder** (CLIP) - Encodes text prompts (77 token sequence) +2. **text_encoder_2** (T5) - Secondary text encoder (256 token sequence) +3. **transformer** - Core diffusion transformer model +4. 
**vae_decoder** - Decodes latents to images + +### Configuration Parameters + +Each module has three sections: + +#### Specializations +- `batch_size`: Batch size for inference +- `seq_len`: Sequence length for text encoders +- `steps`: Number of inference steps (transformer only) +- `channels`: Number of channels (VAE decoder only) + +#### Compilation +- `onnx_path`: Path to pre-exported ONNX model (null for auto-export) +- `compile_dir`: Directory for compiled artifacts (null for auto-generation) +- `mdp_ts_num_devices`: Number of devices for model data parallelism +- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication +- `convert_to_fp16`: Convert model to FP16 precision +- `aic_num_cores`: Number of AI cores to use +- `mos`: Multi-output streaming (transformer only) +- `mdts-mos`: Multi-device tensor slicing with MOS (transformer only) +- `aic-enable-depth-first`: Enable depth-first compilation (VAE only) + +#### Execute +- `device_ids`: List of device IDs to use (null for auto-selection) + +### Example Configuration Snippet + +```json +{ + "transformer": { + "specializations": { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": { + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": { + "device_ids": null + } + } +} +``` + +## Key Parameters + +### Generation Parameters + +- **`prompt`** (str): Text description of the image to generate +- **`height`** (int): Output image height in pixels (default: 1024) +- **`width`** (int): Output image width in pixels (default: 1024) +- **`guidance_scale`** (float): Classifier-free guidance scale (0.0 for schnell) +- **`num_inference_steps`** (int): Number of denoising steps (4 recommended for schnell) +- **`max_sequence_length`** (int): Maximum text sequence length (256 recommended) +- **`generator`** (torch.Generator): Random seed for reproducibility +- **`parallel_compile`** (bool): Enable parallel compilation of modules +- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export (experimental) + +### Performance Tuning + +- **Faster inference**: Reduce `num_inference_steps` or model layers +- **Better quality**: Increase `num_inference_steps` or use full model +- **Memory optimization**: Adjust `mdp_ts_num_devices` in config +- **Precision trade-offs**: Toggle `mxfp6_matmul` and `convert_to_fp16` + +## Output + +The pipeline returns an output object containing: +- `images`: List of generated PIL Image objects +- Performance metrics (timing information) + +Example output: +```python +print(output) # Displays performance information +image = output.images[0] # Access the generated image +image.save("output.png") # Save to disk +``` + +## Hardware Requirements + +- Qualcomm Cloud AI 100 accelerator +- Sufficient memory for model compilation and execution +- Multiple devices recommended for optimal transformer performance (see `mdp_ts_num_devices`) + +## Notes + +- FLUX.1-schnell is optimized for 4-step generation with `guidance_scale=0.0` +- The transformer module benefits most from multi-device parallelism +- ONNX subfunctions (`use_onnx_subfunctions=True`) is experimental and may improve compile time but is not recommended for production use +- Custom configurations allow fine-tuning for specific hardware setups + +## Troubleshooting + +- **Out of memory**: Reduce image dimensions or increase `mdp_ts_num_devices` +- **Slow compilation**: Enable `parallel_compile=True` +- **Quality issues**: Ensure using recommended parameters (4 steps, 
guidance_scale=0.0) +- **Device errors**: Check `device_ids` in config or set to `null` for auto-selection + +## References + +- [FLUX.1 Model Card](https://huggingface.co/black-forest-labs/FLUX.1-schnell) +- [QEfficient Documentation](../../../README.md) +- [Diffusers Pipeline Guide](../../README.md) diff --git a/examples/diffusers/flux/flux_1_schnell.py b/examples/diffusers/flux/flux_1_schnell.py new file mode 100644 index 000000000..46f26bb6b --- /dev/null +++ b/examples/diffusers/flux/flux_1_schnell.py @@ -0,0 +1,45 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +FLUX.1-schnell Image Generation Example + +This example demonstrates how to use the QEffFluxPipeline to generate images +using the FLUX.1-schnell model from Black Forest Labs. FLUX.1-schnell is a +fast, distilled version of the FLUX.1 text-to-image model optimized for +speed with minimal quality loss. +""" + +import torch + +from QEfficient import QEffFluxPipeline + +# Initialize the FLUX.1-schnell pipeline from pretrained weights +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + +# Generate an image from a text prompt +# use_onnx_subfunctions=True enables ONNX-based optimizations for faster compilation +output = pipeline( + prompt="A laughing girl", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Extract the generated image from the output +image = output.images[0] + +# Save the generated image to disk +image.save("girl_laughing.png") + +# Print the output object (contains perf info) +print(output) diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py new file mode 100644 index 000000000..201ebe659 --- /dev/null +++ b/examples/diffusers/flux/flux_1_shnell_custom.py @@ -0,0 +1,113 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +FLUX.1 Schnell Custom Configuration Example + +This example demonstrates how to customize the FLUX.1 model with various options: +1. Custom image dimensions (height/width) +2. Custom transformer model and text encoder +3. Custom scheduler configuration +4. Reduced model layers for faster inference +5. Custom compilation settings +6. Custom runtime configuration via JSON config file + +Use this example to learn how to fine-tune FLUX.1 for your specific needs. 
+""" + +import torch + +from QEfficient import QEffFluxPipeline + +# ============================================================================ +# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS +# ============================================================================ + +# Option 1: Basic initialization with default parameters +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") +# Option 2: Advanced initialization with custom modules +# Uncomment and modify to use your own custom components: +# +# pipeline = QEffFluxPipeline.from_pretrained( +# "black-forest-labs/FLUX.1-schnell", +# text_encoder=custom_text_encoder, # Your custom CLIP text encoder +# transformer=custom_transformer, # Your custom transformer model +# tokenizer=custom_tokenizer, # Your custom tokenizer +# ) + +# ============================================================================ +# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION +# ============================================================================ +# Uncomment to use a custom scheduler (e.g., different sampling methods): +# +# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Reduce the number of transformer blocks to speed up image generation. +# +# Trade-off: Faster inference but potentially lower image quality +# Use case: Quick testing, prototyping, or when speed is critical +# +# Uncomment the following lines to use only the first transformer block: +# +# original_blocks = pipeline.transformer.model.transformer_blocks +# org_single_blocks = pipeline.transformer.model.single_transformer_blocks +# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) +# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) +# pipeline.transformer.model.config['num_layers'] = 1 +# pipeline.transformer.model.config['num_single_layers'] = 1 + +# ============================================================================ +# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION +# ============================================================================ +# Pre-compile the model for optimized performance on target hardware. +# +# When to use: +# - When you want to compile the model separately before generation +# - When you need to skip image generation and only prepare the model +# +# NOTE-1: If compile_config is not specified, the default configuration from +# QEfficient/diffusers/pipelines/flux/flux_config.json will be used +# +# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations (Experimental so not recommended) +# This feature improves export performance by breaking down the model into smaller, +# more manageable ONNX functions, which can lead to improve compile time. +# Uncomment to compile with a custom configuration: +# pipeline.compile( +# compile_config="examples/diffusers/flux/flux_config.json", +# height=512, +# width=512, +# use_onnx_subfunctions=False +# ) + +# ============================================================================ +# IMAGE GENERATION WITH CUSTOM RUNTIME CONFIGURATION +# ============================================================================ +# Generate an image using the configured pipeline. 
+# +# Note: Use of custom_config_path provides flexibility to set device_ids for each +# module, so you can skip the separate pipeline.compile() step. + +output = pipeline( + prompt="A laughing girl", + custom_config_path="examples/diffusers/flux/flux_config.json", + height=1024, + width=1024, + guidance_scale=0.0, + num_inference_steps=4, + max_sequence_length=256, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +image = output.images[0] +# Save the generated image to disk +image.save("laughing_girl.png") +print(output) diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json new file mode 100644 index 000000000..73b92265f --- /dev/null +++ b/examples/diffusers/flux/flux_config.json @@ -0,0 +1,99 @@ +{ + "description": "Default configuration for Flux pipeline", + + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "compile_only": true + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 4, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16, + "aic-enable-depth-first": true, + "compile_only":true + }, + "execute": + { + "device_ids": null + } + } + } +} diff --git a/examples/diffusers/wan/README.md b/examples/diffusers/wan/README.md new file mode 100644 index 000000000..b90bf3908 --- /dev/null +++ b/examples/diffusers/wan/README.md @@ -0,0 +1,249 @@ +# WAN 2.2 Text-to-Video Generation Examples + +This directory contains examples demonstrating how to use the QEffWanPipeline to generate videos using the WAN 2.2 text-to-video model with Lightning LoRA optimization. + +## Overview + +WAN 2.2 is a text-to-video diffusion model that uses dual-stage processing for high-quality video generation. These examples show how to leverage Qualcomm Cloud AI 100 acceleration for efficient video generation with Lightning LoRA for fast 4-step inference. + +## Files + +- **`wan_lightning.py`** - Complete example with Lightning LoRA for fast video generation +- **`wan_config.json`** - Configuration file for transformer module compilation + +## Quick Start + +### Basic Usage + +The simplest way to generate videos with WAN 2.2 Lightning: +### 1. Load Model +```python +from QEfficient import QEffWanPipeline +import torch +from diffusers.utils import export_to_video + +# Initialize pipeline +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") +``` + +### 2. 
Lightning LoRA Integration + +Load high and low noise LoRA adapters for fast 4-step generation: + +```python +from huggingface_hub import hf_hub_download +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +import safetensors.torch + +# Download Lightning LoRAs +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + +# Load and apply LoRAs +def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) +``` + + +### 3. Compile API + +To compile the model for desired resolution: + +```python +# Compile with custom configuration +pipeline.compile( + compile_config="examples/diffusers/wan/wan_config.json", + parallel=True, + height=480, + width=832, + num_frames=81, + use_onnx_subfunctions=False, +) +``` + +### 4. Generate video +```python +output = pipeline( + prompt="A cat playing in a sunny garden", + num_frames=81, + height=480, + width=832, + guidance_scale=1.0, + num_inference_steps=4, + generator=torch.manual_seed(42), + parallel_compile=True, + use_onnx_subfunctions=False, +) + +# Export video +frames = output.images[0] +export_to_video(frames, "cat_garden.mp4", fps=16) +``` + +Run the Lightning example: +```bash +python wan_lightning.py +``` + +## Advanced Customization + + +### 1. Reduce Model Layers for Faster Inference + + +```python +# Reduce to 2 layers for faster inference +pipeline.transformer.model.transformer_high.config.num_layers = 2 +pipeline.transformer.model.transformer_low.config.num_layers = 2 + +original_blocks = pipeline.transformer.model.transformer_high.blocks +org_blocks = pipeline.transformer.model.transformer_low.blocks + +pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config.num_layers)] +) +pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( + [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.num_layers)] +) +``` + +### 2. 
To Run with Blocking + +Use environment variables to enable attention blocking: + +```bash +# For 180p Generation (192x320) with HKV blocking +ATTENTION_BLOCKING_MODE=kv head_block_size=16 num_kv_blocks=3 python wan_lightning.py + +# For 480p Generation (480x832) with HQKV blocking +ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2 python wan_lightning.py + +# for 720P Generation (720x1280) with HQKV blocking +ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=48 num_q_blocks=5 python wan_lightning.py +``` + +### Blocking Modes + +Head blocking is common in all modes + +- **`kv`**: Block key-value processing (along with Head blocking) +- **`q`**: Block query processing (along with Head blocking) +- **`qkv`**: Block query, key, and value (along with Head blocking) +- **`default`**: Head-only blocking + + +## Configuration File + +The `wan_config.json` file controls compilation settings for the transformer module: + +### Module Structure + +The configuration includes dual specializations for WAN's high and low noise models: + +```json +{ + "transformer": { + "specializations":[ + { + "batch_size":"1", + "cl":"5040", + "latent_height":"24", + "latent_width":"40", + "model_type":"1", + "num_channels":"16", + "num_frames":"21", + "sequence_length":"512", + "steps":"1" + }, + { + "batch_size":"1", + "cl":"5040", + "latent_height":"24", + "latent_width":"40", + "model_type":"2", + "num_channels":"16", + "num_frames":"21", + "sequence_length":"512", + "steps":"1" + } + ] +} +} +``` + +### Configuration Parameters + +#### Specializations +- `batch_size`: Batch size for inference +- `num_channels`: Number of latent channels (16 for WAN) +- `num_frames`: Number of latent frames (21 for 81 input frames) +- `latent_height`/`latent_width`: Latent space dimensions +- `cl`: Compressed latent dimension for transformer +- `sequence_length` : Sequence length of text encoder 512 +- `model_type`: 1 for high noise model, 2 for low noise model + +#### Compilation +- `mdp_ts_num_devices`: Number of devices for model parallelism (16 recommended) +- `mxfp6_matmul`: Enable MXFP6 quantization for matrix multiplication +- `convert_to_fp16`: Convert model to FP16 precision +- `aic_num_cores`: Number of AI cores to use (16 recommended) +- `mos`: Degree of weight splitting done across cores (1 is recommended) +- `mdts_mos`: Degree of weight splitting done across multi-device tensor slices (1 is recommended) + +## Key Parameters + +### Generation Parameters + +- **`prompt`** (str): Text description of the video to generate +- **`num_frames`** (int): Number of video frames (default: 81) +- **`height`** (int): Output video height in pixels (default: 480) +- **`width`** (int): Output video width in pixels (default: 832) +- **`guidance_scale`** (float): Guidance scale for high noise stage (1.0 for Lightning) +- **`guidance_scale_2`** (float): Guidance scale for low noise stage (1.0 for Lightning) +- **`num_inference_steps`** (int): Number of denoising steps (4 for Lightning) +- **`generator`** (torch.Generator): Random seed for reproducibility +- **`parallel_compile`** (bool): Enable parallel compilation of modules +- **`use_onnx_subfunctions`** (bool): Enable ONNX modular export + + +## Output + +The pipeline returns an output object containing: +- `images`: List of video frames as PIL Image objects +- Performance metrics (timing information) + +Example output: +```python +print(output) # Displays performance information +frames = output.images[0] # Access the generated video frames 
+export_to_video(frames, "output.mp4", fps=16) # Export to MP4 +``` + +## Notes + +- WAN 2.2 Lightning is optimized for 4-step generation with `guidance_scale=1.0` +- The transformer uses dual-stage processing (high/low noise models) +- Attention blocking is essential for higher resolutions (480p+) + + +## References + +- [WAN 2.2 Model Card](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) +- [Lightning LoRA](https://huggingface.co/lightx2v/Wan2.2-Lightning) +- [QEfficient Documentation](../../../README.md) diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json new file mode 100644 index 000000000..7e752ba14 --- /dev/null +++ b/examples/diffusers/wan/wan_config.json @@ -0,0 +1,37 @@ +{ + "description": "Default configuration for Wan pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)", + "model_type": "wan", + "modules": { + "transformer": { + "specializations": [ + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 1 + }, + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 2 + } + ], + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 16, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null + } + } + } +} \ No newline at end of file diff --git a/examples/diffusers/wan/wan_lightning.py b/examples/diffusers/wan/wan_lightning.py new file mode 100644 index 000000000..691da651f --- /dev/null +++ b/examples/diffusers/wan/wan_lightning.py @@ -0,0 +1,62 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import safetensors.torch +import torch +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline + +# Load the pipeline +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + +# Download the LoRAs +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + + +# LoRA conversion +def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + +# Load into the transformers +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + + +prompt = "In a warmly lit living room, an elderly man with gray hair sits in a wooden armchair adorned with a blue cushion. He wears a gray cardigan over a white shirt, engrossed in reading a book. 
As he turns the pages, he subtly adjusts his posture, ensuring his glasses stay in place. He then removes his glasses, holding them in his hand, and turns his head to the right, maintaining his grip on the book. The soft glow of a bedside lamp bathes the scene, creating a calm and serene atmosphere, with gentle shadows enhancing the intimate setting." + +output = pipeline( + prompt=prompt, + num_frames=81, + guidance_scale=1.0, + guidance_scale_2=1.0, + num_inference_steps=4, + generator=torch.manual_seed(0), + custom_config_path="examples/diffusers/wan/wan_config.json", + height=480, + width=832, + use_onnx_subfunctions=True, + parallel_compile=True, +) +frames = output.images[0] +export_to_video(frames, "output_t2v.mp4", fps=16) +print(output) diff --git a/examples/diffusers/wan/wan_lightning_custom.py b/examples/diffusers/wan/wan_lightning_custom.py new file mode 100644 index 000000000..a60d57bb6 --- /dev/null +++ b/examples/diffusers/wan/wan_lightning_custom.py @@ -0,0 +1,162 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Wan2.2-Lightning Custom Configuration Example + +This example demonstrates how to customize the Wan2.2-Lightning model with various options: +1. Custom video dimensions (height/width) and frame count +2. Custom scheduler configuration +3. Reduced model layers for faster inference +4. Custom compilation settings +5. Custom runtime configuration via JSON config file +6. LoRA adapter loading and configuration + +Use this example to learn how to tune Wan2.2-Lightning for your specific video generation needs. 
+""" + +import safetensors.torch +import torch +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline + +# ============================================================================ +# PIPELINE INITIALIZATION WITH CUSTOM PARAMETERS +# ============================================================================ + +# Option 1: Basic initialization with default parameters +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + +# ============================================================================ +# LORA ADAPTER LOADING FOR LIGHTNING MODEL +# ============================================================================ +# Download and load Lightning LoRA adapters for faster inference + +# Download the LoRAs from Hugging Face Hub +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + + +# LoRA conversion utility function +def load_wan_lora(path: str): + """Convert and load WAN LoRA weights from safetensors format.""" + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + +# Load LoRA adapters into the high and low noise transformers +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + +# ============================================================================ +# OPTIONAL: CUSTOM SCHEDULER CONFIGURATION +# ============================================================================ +# Uncomment to use a custom scheduler (e.g., different sampling methods): +# +# pipeline.scheduler = custom_scheduler.from_config(pipeline.scheduler.config) + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Reduce the number of transformer blocks to speed up video generation. 
+# +# Trade-off: Faster inference but potentially lower video quality +# Use case: Quick testing, prototyping, or when speed is critical +# +# Uncomment the following lines to use only a subset of transformer layers: +# +# # Configure for 2-layer model (faster inference) +# pipeline.transformer.model.transformer_high.config.num_layers = 1 +# pipeline.transformer.model.transformer_low.config.num_layers = 1 +# +# # Reduce high noise transformer blocks +# original_blocks = pipeline.transformer.model.transformer_high.blocks +# pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( +# [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config.num_layers)] +# ) +# +# # Reduce low noise transformer blocks +# org_blocks = pipeline.transformer.model.transformer_low.blocks +# pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( +# [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.num_layers)] +# ) + +# ============================================================================ +# OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION +# ============================================================================ +# Pre-compile the model for optimized performance on target hardware. +# +# When to use: +# - When you want to compile the model separately before generation +# - When you need to skip video generation and only prepare the model +# +# NOTE-1: If compile_config is not specified, the default configuration from +# QEfficient/diffusers/pipelines/wan/wan_config.json will be used +# +# NOTE-2: use_onnx_subfunctions=True enables modular ONNX export optimizations +# This feature improves export performance by breaking down the model into smaller, +# more manageable ONNX functions, which can lead to improved compile time. +# +# Uncomment to compile with a custom configuration: +# pipeline.compile( +# compile_config="examples/diffusers/wan/wan_config.json", +# parallel=True, +# height=480, +# width=832, +# num_frames=81, +# use_onnx_subfunctions=True +# ) + +# ============================================================================ +# VIDEO GENERATION WITH CUSTOM RUNTIME CONFIGURATION +# ============================================================================ +# Generate a video using the configured pipeline. +# +# Note: Use of custom_config_path provides flexibility to set device_ids for each +# module, so you can skip the separate pipeline.compile() step. 
+
+# Custom prompt for video generation
+prompt = "A cat wearing a hat walking through a magical forest with glowing mushrooms and fireflies dancing around, cinematic lighting, high quality"
+
+# Alternative video dimensions for different use cases, with their corresponding default blocking settings
+# height=192, width=320   # ATTENTION_BLOCKING_MODE=kv head_block_size=16 num_kv_blocks=3 python3 examples/diffusers/wan/wan_lightning.py
+# height=480, width=832   # ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=21 num_q_blocks=2 python3 examples/diffusers/wan/wan_lightning.py
+# height=720, width=1280  # ATTENTION_BLOCKING_MODE=qkv head_block_size=16 num_kv_blocks=48 num_q_blocks=5 python3 examples/diffusers/wan/wan_lightning.py
+
+output = pipeline(
+    prompt=prompt,
+    num_frames=81,  # Number of video frames to generate
+    guidance_scale=1.0,  # Primary guidance scale
+    guidance_scale_2=1.0,  # Secondary guidance scale for dual guidance
+    num_inference_steps=4,  # Lightning model uses fewer steps
+    generator=torch.manual_seed(42),  # For reproducible results
+    custom_config_path="examples/diffusers/wan/wan_config.json",
+    height=480,
+    width=832,
+    use_onnx_subfunctions=True,  # Enable ONNX optimizations
+    parallel_compile=False,  # Set to True for parallel compilation
+)
+
+# Extract generated frames and export to video
+frames = output.images[0]
+export_to_video(frames, "custom_wan_lightning_output.mp4", fps=16)
+print(output)
diff --git a/examples/disagg_serving/README.md b/examples/disagg_serving/README.md
new file mode 100644
index 000000000..fcf665357
--- /dev/null
+++ b/examples/disagg_serving/README.md
@@ -0,0 +1,31 @@
+# Disaggregated serving is recommended for the GPT-OSS model for best performance
+ - The GPT-OSS model has a total_experts/experts_per_tok ratio of 128/4 for the 120B variant and 32/4 for the 20B variant
+ - The prefill-only model uses a read-all-experts-only-once strategy
+ - The decode-only model treats expert weights as activations, meaning only the chosen experts are read
+
+# Prefill-only model
+## Blocking: default behaviour when `prefill_only=True` is passed to the compile API
+ - NUM_Q_BLOCKS= sets the number of Q blocks in attention
+ - NUM_FFN_BLOCKS= sets the number of blocks in the FFN
+ - ENABLE_OPT_SWA="0" or "1" to enable/disable optimized SWA. When enabled, only the valid KVs for the given block are used in attention, reducing MACs
+ - prefix_caching is not supported with this mode
+
+## Chunking: pass `enable_chunking=True` and `prefill_only=True` in the compile API
+ - Optimized SWA, i.e. reading only the valid KV as per the diagonal attention mask, is enabled by default for this version
+ - This model can be used for prefix_caching by passing `kv_cache_batch_size=` in the compile API
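+
+As a rough illustration, the snippet below sketches how a prefill-only, chunked model could be compiled. This is a minimal sketch, not a required recipe: the model ID, sequence lengths, and core count are placeholder values taken from `gpt_oss_disagg_mode_with_chunking.py`.
+
+```python
+from QEfficient import QEFFAutoModelForCausalLM
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
+
+# Prefill-only model with chunking; optimized SWA is enabled by default in this mode
+prefill_qpc_path = qeff_model.compile(
+    prefill_seq_len=128,          # chunk size used during prefill
+    ctx_len=8192,
+    num_cores=16,
+    prefill_only=True,
+    enable_chunking=True,
+    use_onnx_subfunctions=True,   # advised for the prefill-only model to keep compile times reasonable
+    # kv_cache_batch_size=...,    # uncomment to enable prefix caching with the chunked model
+)
+```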
+
+# Decode-only model
+## Retain sliding-window length of KV for sliding-window layers: default behaviour when `prefill_seq_len=1` is passed to the compile API
+ - This reduces the amount of DDR used by the model
+ - CB (continuous batching) is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, and you must pass `full_batch_size=` and optionally `kv_cache_batch_size=` if needed
+## Full KV for sliding-window layers: pass `retain_full_kv=True` along with `prefill_seq_len=1` in the compile API
+ - This uses more DDR, since ctx_len KV is retained even for sliding-window layers, but only the sliding-window length of KV is read in attention
+ - CB (continuous batching) is enabled for this version: pass `continuous_batching=True` in the `from_pretrained` call, and you must pass `full_batch_size=` and optionally `kv_cache_batch_size=` if needed
+ - This is intended for the multi-turn chat use case, where we run prefill -> decode and then use the combined prefill and decode cache to run prefill again, so the full KV must be retained for sliding-window layers
+
+
+NOTE:
+* The decode-only model currently fails compilation with `use_onnx_subfunctions=True`, so avoid using it
+* The 120B model needs an NPI; two versions of the NPI, one with and one without subfunctions, are uploaded here. Pass the appropriate one as `node_precision_info=`
+* It is advised to use `use_onnx_subfunctions=True` with the prefill-only model, otherwise compilation times are too high. With this option the model is expected to export and then fail during compile because it needs the assert SDK, so the user should run that compilation manually by pasting the command printed in the error
+
diff --git a/examples/disagg_serving/gpt_oss_disagg_mode.py b/examples/disagg_serving/gpt_oss_disagg_mode.py
new file mode 100644
index 000000000..fd0d5b045
--- /dev/null
+++ b/examples/disagg_serving/gpt_oss_disagg_mode.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import time
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+
+model_id = "openai/gpt-oss-20b"  # weights do not need to be converted to fp32
+
+prompt = """
+Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
+
+As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.
+
+The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
+""" +all_outputs = [] +# Run prefill +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 256 +CTX_LEN = 256 +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# Initialize variables specific to request +# Calculate the max generation length. +max_gen_len = CTX_LEN - position_ids.max() +generation_len = max_gen_len + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +config = qeff_model.model.config +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(config.num_hidden_layers): + cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, +) +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + use_onnx_subfunctions=True, +) + +prefill_session = QAICInferenceSession(prefill_qpc_path) + +logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) +prefill_session.set_buffers({"logits": logits_out_placeholder}) +inputs.pop("past_key_values") +inputs = {k: v.detach().numpy() for k, v in inputs.items()} +st = time.time() +qpc_out = prefill_session.run(inputs) +print(f"time for prefill_run={time.time() - st} sec\n") + +decode_session = QAICInferenceSession(decode_qpc_path) +decode_session.set_buffers({"logits": logits_out_placeholder}) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +print("pos_id for decodee", decode_inputs["position_ids"]) + +all_outputs.append(decode_inputs["input_ids"][0][0]) +for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode 
with KV as input = {time.time() - st} sec\n") +decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] +) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +st = time.time() +for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + +print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") +print(all_outputs) +print(tokenizer.decode(all_outputs)) diff --git a/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py b/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py new file mode 100644 index 000000000..cac646d5e --- /dev/null +++ b/examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py @@ -0,0 +1,146 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import time + +import numpy as np +import torch +from transformers import AutoConfig, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +dir_path = os.path.dirname(os.path.realpath(__file__)) +subfunc_npi_file_path = os.path.join(dir_path, "subfunction_120b_npi.yaml") +non_subfunc_npi_file_path = os.path.join(dir_path, "non_subfunction_120b_npi.yaml") + +model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. 
+""" +# Run prefill +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 8192 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, + # split_retained_state_io=True, # This should be used for disagg serving via VLLM + node_precision_info=non_subfunc_npi_file_path, +) + + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 +# prefill_qpc_path = "provide path here" +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, + # split_retained_state_io=True, # This should be used for disagg serving via VLLM + node_precision_info=subfunc_npi_file_path, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +decode_session = QAICInferenceSession(decode_qpc_path) +prefill_session = QAICInferenceSession(prefill_qpc_path) + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + 
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +st = time.time() +for i in range(generation_len - 2): + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/examples/disagg_serving/non_subfunction_120b_npi.yaml b/examples/disagg_serving/non_subfunction_120b_npi.yaml new file mode 100644 index 000000000..ec6cf034f --- /dev/null +++ b/examples/disagg_serving/non_subfunction_120b_npi.yaml @@ -0,0 +1,148 @@ +FP32NodeInstanceNames: + - /model/layers.0/Add_1_output_0 + - /model/layers.0/Add_output_0 + - /model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/Add_1_output_0 + - /model/layers.1/Add_output_0 + - /model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/Add_1_output_0 + - /model/layers.10/Add_output_0 + - /model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/Add_1_output_0 + - /model/layers.11/Add_output_0 + - /model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/Add_1_output_0 + - /model/layers.12/Add_output_0 + - /model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/Add_1_output_0 + - /model/layers.13/Add_output_0 + - /model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/Add_1_output_0 + - /model/layers.14/Add_output_0 + - /model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/Add_1_output_0 + - /model/layers.15/Add_output_0 + - /model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/Add_1_output_0 + - /model/layers.16/Add_output_0 + - /model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/Add_1_output_0 + - /model/layers.17/Add_output_0 + - /model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/Add_1_output_0 + - /model/layers.18/Add_output_0 + - /model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/Add_1_output_0 + - /model/layers.19/Add_output_0 + - 
/model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/Add_1_output_0 + - /model/layers.2/Add_output_0 + - /model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/Add_1_output_0 + - /model/layers.20/Add_output_0 + - /model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/Add_1_output_0 + - /model/layers.21/Add_output_0 + - /model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/Add_1_output_0 + - /model/layers.22/Add_output_0 + - /model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/Add_1_output_0 + - /model/layers.23/Add_output_0 + - /model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/Add_1_output_0 + - /model/layers.24/Add_output_0 + - /model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/Add_1_output_0 + - /model/layers.25/Add_output_0 + - /model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/Add_1_output_0 + - /model/layers.26/Add_output_0 + - /model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/Add_1_output_0 + - /model/layers.27/Add_output_0 + - /model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/Add_1_output_0 + - /model/layers.28/Add_output_0 + - /model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/Add_1_output_0 + - /model/layers.29/Add_output_0 + - /model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/Add_1_output_0 + - /model/layers.3/Add_output_0 + - /model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/Add_1_output_0 + - /model/layers.30/Add_output_0 + - /model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/Add_1_output_0 + - /model/layers.31/Add_output_0 + - /model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/Add_1_output_0 + - /model/layers.32/Add_output_0 + - /model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/Add_1_output_0 + - /model/layers.33/Add_output_0 + - /model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/Add_1_output_0 + - /model/layers.34/Add_output_0 + - /model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/Add_1_output_0 + - 
/model/layers.35/Add_output_0 + - /model/norm/Add_output_0 + - /model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/Add_1_output_0 + - /model/layers.4/Add_output_0 + - /model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/Add_1_output_0 + - /model/layers.5/Add_output_0 + - /model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/Add_1_output_0 + - /model/layers.6/Add_output_0 + - /model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/Add_1_output_0 + - /model/layers.7/Add_output_0 + - /model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/Add_1_output_0 + - /model/layers.8/Add_output_0 + - /model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/Add_1_output_0 + - /model/layers.9/Add_output_0 + - /model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /model/norm/CustomRMSNorm_output_0 + \ No newline at end of file diff --git a/examples/disagg_serving/subfunction_120b_npi.yaml b/examples/disagg_serving/subfunction_120b_npi.yaml new file mode 100644 index 000000000..3e1b6e264 --- /dev/null +++ b/examples/disagg_serving/subfunction_120b_npi.yaml @@ -0,0 +1,50 @@ +FP32NodeInstanceNames: + - onnx::Shape_139893 + - onnx::Shape_140187 + - onnx::Shape_144086 + - onnx::Shape_144410 + - onnx::Shape_883 + - onnx::Shape_1215 + - hidden_states.267 + - hidden_states.271 + - hidden_states.275 + - hidden_states.279 + - hidden_states.3 + - hidden_states.7 + - /model/norm/CustomRMSNorm_output_0 + - /model/layers.0/QEffGptOssDecoderLayer_output_127 + - /model/layers.1/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.2/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.3/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.4/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.5/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.6/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.7/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.8/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.9/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.10/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.11/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.12/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.13/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.14/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.15/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.16/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.17/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.18/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.19/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.20/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.21/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.22/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.23/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.24/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.25/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.26/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.27/QEffGptOssDecoderLayer.2_output_2 + - 
/model/layers.28/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.29/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.30/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.31/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.32/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.33/QEffGptOssDecoderLayer.2_output_2 + - /model/layers.34/QEffGptOssDecoderLayer.1_output_2 + - /model/layers.35/QEffGptOssDecoderLayer.2_output_2 diff --git a/examples/embedding_model.py b/examples/embedding_model.py deleted file mode 100644 index 7e6973e2e..000000000 --- a/examples/embedding_model.py +++ /dev/null @@ -1,46 +0,0 @@ -# ----------------------------------------------------------------------------- - -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause - -# ----------------------------------------------------------------------------- - -# This is the work example of the Embedding model with the AI 100 -# For more information, visit: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 - -import torch -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModel as AutoModel - - -def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float() - last_hidden_states[input_mask_expanded == 0] = -1e9 - return torch.max(last_hidden_states, 1)[0] - - -# Sentences we want sentence embeddings for -sentences = "This is an example sentence" - -# Load model from HuggingFace Hub -tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") - - -# You can specify the pooling strategy either as a string (e.g., "max") or by passing a custom pooling function. -# If no pooling is specified, the model will return its default output (typically token embeddings). -qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling=max_pooling) -# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling="max") -# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") - -# Here seq_len can be list of seq_len or single int -qeff_model.compile(num_cores=16, seq_len=[32, 64]) -# qeff_model.compile(num_cores=16, seq_len=32) - - -# Tokenize sentences -encoded_input = tokenizer(sentences, return_tensors="pt") - -sentence_embeddings = qeff_model.generate(encoded_input) - -print("Sentence embeddings:", sentence_embeddings) diff --git a/examples/embeddings/README.md b/examples/embeddings/README.md new file mode 100644 index 000000000..baf80919c --- /dev/null +++ b/examples/embeddings/README.md @@ -0,0 +1,71 @@ +# Embedding Examples + +Examples for running text embedding models on Qualcomm Cloud AI 100. + +## Authentication + +For private/gated models, export your HuggingFace token: +```bash +export HF_TOKEN= +``` + +## Supported Models + +**QEff Auto Class:** `QEFFAutoModel` + +For the complete list of supported embedding models, see the [Validated Models - Embedding Section](../../docs/source/validate.md#embedding-models). + +Popular model families include: +- BERT-based (BGE, E5) +- MPNet +- Mistral-based +- NomicBERT +- Qwen2 +- RoBERTa (Granite) +- XLM-RoBERTa (multilingual) + +## Available Examples + +### text_embeddings.py +Generate text embeddings using transformer models. 
+ +**Usage:** +```bash +# With default parameters +python text_embeddings.py + +# With custom parameters +python text_embeddings.py \ + --model-name sentence-transformers/all-MiniLM-L6-v2 \ + --sentences "This is an example sentence" \ + --pooling max \ + --num-cores 16 \ + --seq-len "32,64" +``` + +**Parameters:** +- `--model-name`: HuggingFace embedding model ID (default: `sentence-transformers/all-MiniLM-L6-v2`) +- `--sentences`: Input text to generate embeddings for (default: `"This is an example sentence"`) +- `--pooling`: Pooling strategy - `max`, `mean`, or `none` (default: `max`) +- `--num-cores`: Number of cores (default: `16`) +- `--seq-len`: Sequence length(s) - single int or comma-separated list (default: `"32,64"`) + +This example: +- Uses `sentence-transformers/all-MiniLM-L6-v2` by default +- Demonstrates custom pooling strategies (max pooling) +- Compiles for multiple sequence lengths [32, 64] +- Outputs text embeddings +- Works with various embedding model families (BERT, MPNet, Mistral-based, etc.) + +## Pooling Strategies + +The example supports different pooling strategies: +- **max**: Max pooling over token embeddings +- **mean**: Mean pooling over token embeddings +- **custom**: Pass your own pooling function + +## Documentation + +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) +- [Validated Embedding Models](https://quic.github.io/efficient-transformers/source/validate.html#embedding-models) +- [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) diff --git a/examples/embeddings/text_embeddings.py b/examples/embeddings/text_embeddings.py new file mode 100644 index 000000000..e69e6f1af --- /dev/null +++ b/examples/embeddings/text_embeddings.py @@ -0,0 +1,92 @@ +# ----------------------------------------------------------------------------- + +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause + +# ----------------------------------------------------------------------------- + +import argparse + +import torch +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModel as AutoModel + + +def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + """Apply max pooling to the last hidden states.""" + input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float() + last_hidden_states[input_mask_expanded == 0] = -1e9 + return torch.max(last_hidden_states, 1)[0] + + +def main(): + parser = argparse.ArgumentParser(description="Text embeddings inference") + parser.add_argument( + "--model-name", + type=str, + default="sentence-transformers/all-MiniLM-L6-v2", + help="HuggingFace embedding model ID", + ) + parser.add_argument( + "--sentences", + type=str, + default="This is an example sentence", + help="Input sentence(s) to generate embeddings for", + ) + parser.add_argument( + "--pooling", + type=str, + default="max", + choices=["max", "mean", "none"], + help="Pooling strategy: 'max' for max pooling, 'mean' for mean pooling, 'none' for no pooling", + ) + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument( + "--seq-len", + type=str, + default="32,64", + help="Sequence length(s) - single int (e.g., '32') or comma-separated list (e.g., '32,64')", + ) + args = parser.parse_args() + + # Parse seq_len argument + if "," in args.seq_len: + seq_len = [int(x.strip()) for x in args.seq_len.split(",")] + else: + seq_len = int(args.seq_len) + + print(f"Loading embedding model: {args.model_name}") + print(f"Pooling strategy: {args.pooling}") + print(f"Sequence length(s): {seq_len}") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + # Load model with pooling strategy + # You can specify the pooling strategy either as a string (e.g., "max") or by passing a custom pooling function. + # If no pooling is specified, the model will return its default output (typically token embeddings). + if args.pooling == "max": + qeff_model = AutoModel.from_pretrained(args.model_name, pooling=max_pooling) + elif args.pooling == "mean": + qeff_model = AutoModel.from_pretrained(args.model_name, pooling="mean") + else: + qeff_model = AutoModel.from_pretrained(args.model_name) + + # Compile the model + # seq_len can be a list of seq_len or single int + qeff_model.compile(num_cores=args.num_cores, seq_len=seq_len) + + # Tokenize sentences + encoded_input = tokenizer(args.sentences, return_tensors="pt") + + # Run the generation + sentence_embeddings = qeff_model.generate(encoded_input) + + print(f"\nInput: {args.sentences}") + print(f"Sentence embeddings shape: {sentence_embeddings['output'].shape}") + print(f"Sentence embeddings: {sentence_embeddings}") + + +if __name__ == "__main__": + main() diff --git a/examples/image_text_to_text/README.md b/examples/image_text_to_text/README.md new file mode 100644 index 000000000..a6f1608b4 --- /dev/null +++ b/examples/image_text_to_text/README.md @@ -0,0 +1,112 @@ +# Image-Text-to-Text (Vision-Language Models) + +Multi-modal models that process both images and text. 
+
+
+## Authentication
+
+For private/gated models, export your HuggingFace token:
+```bash
+export HF_TOKEN=
+```
+## Quick Start
+### Generic VLM Inference
+Generic script for vision-language models:
+
+```bash
+# With default parameters
+python basic_vlm_inference.py
+
+# With custom parameters
+python basic_vlm_inference.py \
+    --model-name llava-hf/llava-1.5-7b-hf \
+    --image-url "https://example.com/image.jpg" \
+    --query "Describe this image" \
+    --prefill-seq-len 128 \
+    --ctx-len 3000 \
+    --generation-len 128 \
+    --num-cores 16
+```
+
+### Single QPC Mode
+Run the entire model (vision encoder + language model) in a single QPC:
+
+```bash
+python basic_vlm_inference.py \
+    --model-name llava-hf/llava-1.5-7b-hf \
+    --image-url "https://example.com/image.jpg" \
+    --query "Describe this image" \
+    --num-cores 16 \
+    --num-devices 1
+```
+
+### Dual QPC Mode
+Split the model into two QPCs (vision encoder + language model separately):
+
+```bash
+python basic_vlm_inference.py \
+    --model-name llava-hf/llava-1.5-7b-hf \
+    --image-url "https://example.com/image.jpg" \
+    --query "Describe this image" \
+    --kv-offload \
+    --num-cores 16 \
+    --num-devices 1
+```
+
+**Note:** In Dual QPC mode (`kv_offload=True`), the vision encoder runs in one QPC and the language model in another, with outputs transferred via host. This provides flexibility for independent execution of vision and language components.
+
+### Text-Only Execution (Skip Vision)
+Run text-only inference without image processing:
+
+```bash
+python basic_vlm_inference.py \
+    --model-name llava-hf/llava-1.5-7b-hf \
+    --prompt "Tell me about yourself" \
+    --skip-vision True
+```
+
+**Note:** Use `skip_vision=True` when you want to run the language model without processing any images. This is useful for text-only tasks on vision-language models.
+
+### Continuous Batching
+Dynamic batching for VLMs:
+
+```bash
+python continuous_batching_vlm.py \
+    --model-name meta-llama/Llama-4-Scout-17B-16E-Instruct \
+    --full-batch-size 4
+```
+
+## Supported Models
+
+**QEff Auto Class:** `QEFFAutoModelForImageTextToText`
+
+For the complete list of supported vision-language models, see the [Validated Models - Vision-Language Models Section](../../docs/source/validate.md#vision-language-models-text--image-generation). 
+ +Popular model families include: +- Llama Vision (3.2, 4-Scout) +- Qwen VL (2.5) +- Mistral Vision (Small-3.1) +- Gemma-3 +- Granite Vision (3.2) +- InternVL +- Molmo +- LLaVA + +### Model-Specific Examples + +Some models have specialized examples demonstrating advanced features: + +| Model | Location | +|-------|----------| +| **Llama-4** | [models/llama4/](models/llama4/) | +| **Qwen** | [models/qwen_vl/](models/qwen_vl/) | +| **Mistral** | [models/mistral_vision/](models/mistral_vision/) | +| **Gemma** | [models/gemma_vision/](models/gemma_vision/) | +| **Granite** | [models/granite_vision/](models/granite_vision/) | +| **InternVL** | [models/internvl/](models/internvl/) | +| **Molmo** | [models/molmo/](models/molmo/) | + + +## Documentation +- **Full Guide**: [VLM Documentation](../../docs/source/quick_start.md#vision-language-models) +- **API Reference**: [QEFFAutoModelForImageTextToText](../../docs/source/qeff_autoclasses.md#QEFFAutoModelForImageTextToText) diff --git a/examples/image_text_to_text/basic_vlm_inference.py b/examples/image_text_to_text/basic_vlm_inference.py new file mode 100644 index 000000000..45d5454cb --- /dev/null +++ b/examples/image_text_to_text/basic_vlm_inference.py @@ -0,0 +1,134 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + kv_offload=True, + prefill_seq_len=32, + ctx_len=512, + generation_len=128, + img_size=336, + num_cores=16, + num_devices=1, +): + ## STEP 1: Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` determines Single QPC vs Dual QPC mode: + # - Single QPC (kv_offload=False): Entire model runs in one QPC + # - Dual QPC (kv_offload=True): Vision encoder and language model run in separate QPCs + # with outputs transferred via host for flexibility + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, attn_implementation="eager", kv_offload=kv_offload + ) + + ## STEP 2: Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP 3: Load and Process the Inputs for Inference + # Note: the message format would change for different model + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP 4: Run Inference on the Compiled Model + + streamer = TextStreamer(processor.tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + + +def main(): + parser = argparse.ArgumentParser(description="Vision-Language Model (VLM) inference") + parser.add_argument( + "--model-name", + type=str, + default="llava-hf/llava-1.5-7b-hf", + help="HuggingFace VLM model ID", + 
) + parser.add_argument( + "--query", + type=str, + default="Describe this image.", + help="Text query/question about the image", + ) + parser.add_argument( + "--image-url", + type=str, + default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + help="URL of the image to process", + ) + parser.add_argument( + "--kv-offload", + action="store_true", + default=True, + help="Enable Dual QPC mode (vision encoder and LM in separate QPCs)", + ) + parser.add_argument("--prefill-seq-len", type=int, default=128, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=3000, help="Context length") + parser.add_argument("--generation-len", type=int, default=128, help="Number of tokens to generate") + parser.add_argument("--img-size", type=int, default=336, help="Image size for processing") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument("--num-devices", type=int, default=1, help="Number of devices") + args = parser.parse_args() + + print(f"Running VLM inference with model: {args.model_name}") + print(f"KV offload (Dual QPC mode): {args.kv_offload}") + + run_model( + model_name=args.model_name, + query=args.query, + image_url=args.image_url, + kv_offload=args.kv_offload, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + generation_len=args.generation_len, + img_size=args.img_size, + num_cores=args.num_cores, + num_devices=args.num_devices, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/gemma3_example/fp32_nodes_gemma3_27b.yaml b/examples/image_text_to_text/models/gemma_vision/configs/fp32_nodes_gemma3_27b.yaml similarity index 100% rename from examples/gemma3_example/fp32_nodes_gemma3_27b.yaml rename to examples/image_text_to_text/models/gemma_vision/configs/fp32_nodes_gemma3_27b.yaml diff --git a/examples/gemma3_example/fp32_nodes_gemma3_4b.yaml b/examples/image_text_to_text/models/gemma_vision/configs/fp32_nodes_gemma3_4b.yaml similarity index 100% rename from examples/gemma3_example/fp32_nodes_gemma3_4b.yaml rename to examples/image_text_to_text/models/gemma_vision/configs/fp32_nodes_gemma3_4b.yaml diff --git a/examples/gemma3_example/gemma3_mm.py b/examples/image_text_to_text/models/gemma_vision/gemma3_example.py similarity index 94% rename from examples/gemma3_example/gemma3_mm.py rename to examples/image_text_to_text/models/gemma_vision/gemma3_example.py index e090148f7..5c1f141d4 100644 --- a/examples/gemma3_example/gemma3_mm.py +++ b/examples/image_text_to_text/models/gemma_vision/gemma3_example.py @@ -13,15 +13,17 @@ # Change model_id to "google/gemma-3-27b-it" for 27B model model_id = "google/gemma-3-4b-it" + config = AutoConfig.from_pretrained(model_id) + # For Testing Purpose Only config.text_config.num_hidden_layers = 1 config.vision_config.num_hidden_layers = 2 + tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_id) -# pass HF_TOKEN if gated model -# For running the model in single QPC approach use kv_offload=False. 
For Dual QPC approach use kv_offload=True ### +# For single QPC: kv_offload=False, For dual QPC: kv_offload=True qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, config=config, attn_implementation="eager", kv_offload=True ) @@ -105,5 +107,5 @@ ) inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) output = qeff_model.generate(inputs=inputs, generation_len=100) - print(tokenizer.batch_decode(output.generated_ids)) + print(tokenizer.batch_decode(output.generated_ids, skip_special_tokens=True)) print(output) diff --git a/examples/granite_example/readme.md b/examples/image_text_to_text/models/granite_vision/README.md similarity index 100% rename from examples/granite_example/readme.md rename to examples/image_text_to_text/models/granite_vision/README.md diff --git a/examples/image_text_to_text/models/granite_vision/continuous_batching.py b/examples/image_text_to_text/models/granite_vision/continuous_batching.py new file mode 100644 index 000000000..22c4270bc --- /dev/null +++ b/examples/image_text_to_text/models/granite_vision/continuous_batching.py @@ -0,0 +1,67 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "ibm-granite/granite-vision-3.2-2b" +config = AutoConfig.from_pretrained(model_id) +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=5500, + ctx_len=6000, + num_cores=16, + num_devices=4, + img_size=384, + mxfp6_matmul=False, +) + +image_urls = [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", +] + +prompts = [ + "Describe the image", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=10, + image_height=1610, + image_width=1109, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output.generated_texts) diff --git a/examples/granite_example/granite_vision_inference.py b/examples/image_text_to_text/models/granite_vision/granite_example.py similarity index 96% rename from examples/granite_example/granite_vision_inference.py rename to examples/image_text_to_text/models/granite_vision/granite_example.py index 230e10a40..08b01b1ef 100644 --- a/examples/granite_example/granite_vision_inference.py +++ b/examples/image_text_to_text/models/granite_vision/granite_example.py @@ -5,15 +5,14 @@ # # 
----------------------------------------------------------------------------- +import os + import requests from PIL import Image from transformers import AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText -# Add HuggingFace Token to access the model -HF_TOKEN = "" - def run_model( model_name, @@ -29,7 +28,6 @@ def run_model( num_devices=1, ): ## STEP - 1 Load the Processor and Model - processor = AutoProcessor.from_pretrained(model_name, token=token) # `kv_offload` is used to compile the model in a 2 QPCs.Currently we are not supporting 1 qpc so the flag false is not allowed. @@ -40,7 +38,6 @@ def run_model( model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, token=token, kv_offload=kv_offload) ## STEP - 2 Export & Compile the Model - model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, @@ -88,9 +85,12 @@ def run_model( num_cores = 16 num_devices = 4 + # Get HF token from environment variable (None if not set) + hf_token = os.getenv("HF_TOKEN") + run_model( model_name=model_name, - token=HF_TOKEN, + token=hf_token, query=query, kv_offload=kv_offload, image_url=image_url, diff --git a/examples/intern_example/readme.md b/examples/image_text_to_text/models/internvl/README.md similarity index 95% rename from examples/intern_example/readme.md rename to examples/image_text_to_text/models/internvl/README.md index 6b0b674c9..8371ffc50 100644 --- a/examples/intern_example/readme.md +++ b/examples/image_text_to_text/models/internvl/README.md @@ -2,7 +2,6 @@ This directory contains an example script of how to run inference on InternVL-1B model via QEFFAutoModelForCausalLM class. ## Required packages: -- `torch==2.7.0+cpu` - `torchvision==0.22.0+cpu` - `timm==1.0.14` - `einops==0.8.1` @@ -14,7 +13,7 @@ pip install torch==2.7.0+cpu --extra-index-url https://download.pytorch.org/whl/ To run example script after package installations: ```sh -python internvl_inference.py +python internvl_example.py ``` Expected output for given sample inputs in the script: diff --git a/examples/image_text_to_text/models/internvl/continuous_batching.py b/examples/image_text_to_text/models/internvl/continuous_batching.py new file mode 100644 index 000000000..ca3e0ede3 --- /dev/null +++ b/examples/image_text_to_text/models/internvl/continuous_batching.py @@ -0,0 +1,100 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.test_utils import InternProcessor + +model_id = "OpenGVLab/InternVL2_5-1B" +config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) +# For Testing Purpose Only +config.llm_config.num_hidden_layers = 2 +config.vision_config.num_hidden_layers = 2 + +# The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface. +# To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`. 
+model_hf = AutoModelForCausalLM.from_pretrained( + model_id, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=config, +) + +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False) +processor = InternProcessor(model_hf, tokenizer) + + +continuous_batching = True +if continuous_batching: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + trust_remote_code=True, + ) + + qeff_model.compile( + num_patches=13, # Set num_patches according to image_height and image_width, default is 13 (747 x 1000) + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config, trust_remote_code=True + ) + + qeff_model.compile( + num_patches=13, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[0, 1, 2, 3], + generation_len=10, + image_height=747, + image_width=1000, +) + +print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) diff --git a/examples/intern_example/internvl_inference.py b/examples/image_text_to_text/models/internvl/internvl_example.py similarity index 100% rename from examples/intern_example/internvl_inference.py rename to examples/image_text_to_text/models/internvl/internvl_example.py diff --git a/examples/image_text_to_text/models/llama4/continuous_batching.py b/examples/image_text_to_text/models/llama4/continuous_batching.py new file mode 100644 index 000000000..515e7c01b --- /dev/null +++ b/examples/image_text_to_text/models/llama4/continuous_batching.py @@ -0,0 +1,91 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +""" +Continuous Batching Example for Llama-4-Scout Vision Model + +This example demonstrates how to use continuous batching with vision-language models +to process multiple image-text pairs simultaneously in a single batch. 
+""" + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +# Model configuration +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +## STEP 1: Load Model Configuration and Processor +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only - reduce layers for faster testing +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +## STEP 2: Initialize Model with Continuous Batching +# Enable continuous batching to process multiple prompts in parallel +# Set kv_offload=True for Dual QPC mode (vision encoder + language model separately) +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, # Dual QPC mode + config=config, + continuous_batching=True, # Enable continuous batching +) + +## STEP 3: Compile the Model for Cloud AI 100 +# Configure compilation parameters for continuous batching +qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, # Batch size per request + full_batch_size=4, # Total batch size for continuous batching + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +## STEP 4: Prepare Input Images and Prompts +# Define multiple images to process in the batch +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +# Define corresponding prompts for each image +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +## STEP 5: Run Inference with Continuous Batching +# Process all image-prompt pairs in a single batch +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, # Images are processed with their corresponding prompts + device_ids=[0, 1, 2, 3], + generation_len=100, +) + +## STEP 6: Display Results +print("Generated IDs:", exec_info.generated_ids) +print("\nFull execution info:") +print(exec_info) diff --git a/examples/llama4_multi_image_example.py b/examples/image_text_to_text/models/llama4/multi_image.py similarity index 100% rename from examples/llama4_multi_image_example.py rename to examples/image_text_to_text/models/llama4/multi_image.py diff --git a/examples/llama4_example.py b/examples/image_text_to_text/models/llama4/single_image.py similarity index 65% rename from examples/llama4_example.py rename to examples/image_text_to_text/models/llama4/single_image.py index 981bac203..ca1017d58 100644 --- a/examples/llama4_example.py +++ b/examples/image_text_to_text/models/llama4/single_image.py @@ -5,29 +5,47 @@ # # ----------------------------------------------------------------------------- +""" +Single Image Inference Example for Llama-4-Scout Vision Model + 
+This example demonstrates two modes: +1. Text-only mode (skip_vision=True): Run language model without image processing +2. Vision+Text mode (skip_vision=False): Process image and text together +""" + import torch import transformers from transformers import AutoConfig, AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText +# Model configuration model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +## STEP 1: Load Model Configuration and Processor config = AutoConfig.from_pretrained(model_id) -# For Testing Purpose Only +# For Testing Purpose Only - reduce layers for faster testing config.text_config.num_hidden_layers = 4 config.vision_config.num_hidden_layers = 2 +## STEP 2: Initialize the Model +# Set kv_offload=True for Dual QPC mode (vision encoder + language model separately) qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, attn_implementation="eager", kv_offload=True, config=config ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) -### use skip_vision=Ture, if want to run only text, ow false ### +# Toggle between text-only and vision+text modes +# Set skip_vision=True for text-only execution (no image processing) +# Set skip_vision=False for vision+text execution (process images with text) skip_vision = True if skip_vision: - ## Only Text ## + ## TEXT-ONLY MODE ## + + ## STEP 3: Compile Model for Text-Only Execution + # Set skip_vision=True to bypass image processing qeff_model.compile( prefill_seq_len=128, ctx_len=3072, @@ -38,10 +56,12 @@ mxfp6_matmul=True, mxint8_kv_cache=True, aic_enable_depth_first=True, - skip_vision=True, + skip_vision=True, # Skip vision encoder for text-only inference mos=1, ) + ## STEP 4: Prepare Text-Only Input + # Create a text-only message without any image messages = [ { "role": "user", @@ -51,6 +71,7 @@ }, ] + ## STEP 5: Process Input with Chat Template inputs = processor.apply_chat_template( messages, add_generation_prompt=True, @@ -59,14 +80,20 @@ return_tensors="pt", ) + ## STEP 6: Run Text-Only Inference streamer = TextStreamer(tokenizer) output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100) + + ## STEP 7: Display Results print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) print(output) else: - ## Vision + Text ## + ## VISION + TEXT MODE ## + + ## STEP 3: Compile Model for Vision+Text Execution + # Do not set skip_vision (defaults to False) to enable image processing qeff_model.compile( prefill_seq_len=128, ctx_len=3072, @@ -80,11 +107,13 @@ mos=1, ) - ### IMAGE + TEXT ### + ## STEP 4: Prepare Image and Text Input + # Define the image URL to process image_url = ( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" ) + # Create a message with both image and text messages = [ { "role": "user", @@ -95,6 +124,7 @@ }, ] + ## STEP 5: Process Input with Chat Template inputs = processor.apply_chat_template( messages, add_generation_prompt=True, @@ -102,10 +132,14 @@ return_dict=True, return_tensors="pt", ) + # Convert pixel values to float32 for processing inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + ## STEP 6: Run Vision+Text Inference streamer = TextStreamer(tokenizer) output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100) + + ## STEP 7: Display Results print(output.generated_ids) 
print(tokenizer.batch_decode(output.generated_ids)) print(output) - print() diff --git a/examples/mistral3_example.py b/examples/image_text_to_text/models/mistral_vision/mistral3_example.py similarity index 100% rename from examples/mistral3_example.py rename to examples/image_text_to_text/models/mistral_vision/mistral3_example.py diff --git a/examples/molmo_example.py b/examples/image_text_to_text/models/molmo/molmo_example.py similarity index 96% rename from examples/molmo_example.py rename to examples/image_text_to_text/models/molmo/molmo_example.py index 09658ce41..04bba5248 100644 --- a/examples/molmo_example.py +++ b/examples/image_text_to_text/models/molmo/molmo_example.py @@ -16,7 +16,8 @@ model_id = "allenai/Molmo-7B-D-0924" config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) -config.num_hidden_layers = 2 +# For faster execution user can run on 2 layers, This is only for testing purpose +# config.num_hidden_layers = 2 # load the model qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, kv_offload=True, trust_remote_code=True, config=config) diff --git a/examples/image_text_to_text/models/qwen_vl/basic_inference.py b/examples/image_text_to_text/models/qwen_vl/basic_inference.py new file mode 100644 index 000000000..374f70ad2 --- /dev/null +++ b/examples/image_text_to_text/models/qwen_vl/basic_inference.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = True + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=8, + height=354, + width=536, + mxfp6_matmul=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=8, + height=354, 
+ width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + }, + ] + + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe about the color of the dog."}, + ], + }, + ] + + messages = [messages_2] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/image_text_to_text/models/qwen_vl/continuous_batching.py b/examples/image_text_to_text/models/qwen_vl/continuous_batching.py new file mode 100644 index 000000000..03094dc92 --- /dev/null +++ b/examples/image_text_to_text/models/qwen_vl/continuous_batching.py @@ -0,0 +1,69 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +config.text_config.num_hidden_layers = 2 + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/image_text_to_text_inference.py b/examples/image_text_to_text_inference.py deleted file mode 100644 index e722284ba..000000000 --- 
a/examples/image_text_to_text_inference.py +++ /dev/null @@ -1,120 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import requests -from PIL import Image -from transformers import AutoProcessor, TextStreamer - -from QEfficient import QEFFAutoModelForImageTextToText - -# Add HuggingFace Token to access the model -HF_TOKEN = "" - - -def run_model( - model_name, - token, - query, - image_url, - kv_offload=False, - prefill_seq_len=32, - ctx_len=512, - generation_len=128, - img_size=560, - num_cores=16, - num_devices=1, -): - ## STEP - 1 Load the Processor and Model - - processor = AutoProcessor.from_pretrained(model_name, token=token) - - # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. - # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. - # The outputs of the Vision Encoder are then passed to the Language model via host in this case. - - model = QEFFAutoModelForImageTextToText.from_pretrained( - model_name, token=token, attn_implementation="eager", kv_offload=kv_offload - ) - - ## STEP - 2 Export & Compile the Model - - model.compile( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - img_size=img_size, - num_cores=num_cores, - num_devices=num_devices, - mxfp6_matmul=False, - ) - - ## STEP - 3 Load and process the inputs for Inference - - image = Image.open(requests.get(image_url, stream=True).raw) - messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": query}, - ], - } - ] - input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] - - inputs = processor( - text=input_text, - images=image, - return_tensors="pt", - add_special_tokens=False, - padding="max_length", - max_length=prefill_seq_len, - ) - - ## STEP - 4 Run Inference on the compiled model - - streamer = TextStreamer(processor.tokenizer) - model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) - - -if __name__ == "__main__": - # Model name and Input parameters - model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" - query = "Describe this image." - image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" - - # Compilation parameters for the model - kv_offload = False - prefill_seq_len = 32 - ctx_len = 512 - generation_len = 128 - img_size = 560 - num_cores = 16 - num_devices = 1 - - run_model( - model_name=model_name, - token=HF_TOKEN, - query=query, - kv_offload=kv_offload, - image_url=image_url, - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - generation_len=generation_len, - img_size=img_size, - num_cores=num_cores, - num_devices=num_devices, - ) - - -""" -Expected Response: - -This image depicts a charming anthropomorphic rabbit standing on a dirt path in front of a picturesque stone cottage, surrounded by a serene landscape. - -The rabbit, with its light brown fur and distinctive long ears, is attired in a stylish blue coat, brown vest, and tan pants, exuding a sense of sophistication. The dirt path, flanked by vibrant flowers and lush greenery, leads to the cottage, which features a thatched roof and a chimney, adding to the rustic charm of the scene. 
In the background, rolling hills and trees create a breathtaking panorama, while the sky above is a brilliant blue with white clouds, completing the - -""" diff --git a/examples/onboarding_guide/causallm/README.md b/examples/onboarding_guide/causallm/README.md new file mode 100644 index 000000000..e7ac3f362 --- /dev/null +++ b/examples/onboarding_guide/causallm/README.md @@ -0,0 +1,292 @@ +# Onboarding a CausalLM Model + +## Prerequisites + +Install `qefficient-transformers` library in editable mode: +```sh +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install -e . +``` + +--- + +## Transformers Version Compatibility + +**Important:** QEfficient has a pinned `transformers` library version dependency. + +**Check the current version:** +```bash +grep "transformers==" pyproject.toml +``` + +See `dependencies` in [`pyproject.toml`](../../../pyproject.toml) for the exact version. + +**Compatibility rules:** +- You can only onboard models that are supported in the pinned transformers version or earlier +- Models added to transformers after this version are not yet supported +- Always verify when your target model was added to the transformers library + +**How to verify model compatibility:** + +1. Check transformers release history at [HuggingFace Transformers Releases](https://github.com/huggingface/transformers/releases) +2. Find the release where your model was first introduced +3. Compare versions: + - If model's release version ≤ QEfficient's pinned version → Proceed with onboarding + - If model's release version > QEfficient's pinned version → Cannot onboard yet + + +**Need a newer model?** + +If you need to onboard a model that requires a newer transformers version: +1. Open an issue on the [QEfficient GitHub repository](https://github.com/quic/efficient-transformers/issues) +2. Request a transformers version bump +3. Provide justification and the specific model you need + +--- + +## Introduction + +This guide walks you through onboarding a new CausalLM model to QEfficient-transformers. We use an example model named `Blueprint` to demonstrate the required changes. + +--- + +## Onboarding Process + +```mermaid +flowchart TD + A["Check Transformers Library +• Locate model in transformers/models/<model>/modeling_*.py +• Identify architecture classes (Attention, DecoderLayer, etc.)"] + + B{"Class already +Implemented"} + + C["Create Custom Files +• Create modeling_*.py +• Implement custom classes +• Add __qeff_init__ methods"] + + D["Test the model using +the auto model class +and validate the +functionality"] + + E["Add Mappings in pytorch_transforms.py +• CustomOpsTransform (RMSNorm) +• KVCacheTransform (all model classes) +• ExternalModuleMapperTransform (if needed)"] + + K{"if all test passes"} + + L["Debug & Fix Issues +Retest with test pipelines"] + + M["Submit PR +(Follow +CONTRIBUTING +guidelines)"] + + A --> B + B -->|No| C + B -->|Yes| D + C --> E + E --> F + + subgraph F["Testing Pipeline (4 Stages)"] + direction TB + G["Stage 1: PyTorch HF Model (Baseline) +(tokens should match)"] + H["Stage 2: PyTorch KV Model (After QEff transforms) +(tokens should match)"] + I["Stage 3: ONNX/ORT Model (After export) +(tokens should match)"] + J["Stage 4: Cloud AI 100 (Hardware execution) +(tokens should match)"] + + G --> H + H --> I + I --> J + end + + F --> K + K -->|No| L + L --> F + K -->|Yes| M +``` + +--- + +## Step 1: Check Transformers Library + +1. 
**Locate the model** in the transformers library: + - Path: `/src/transformers/models//modeling_.py` + - Example: `/src/transformers/models/blueprint/modeling_blueprint.py` + +2. **Identify required classes**: + - Attention Layer + - Decoder Layer + - Model (main class) + - ForCausalLM (top-level) + - RMSNorm/LayerNorm + - RotaryEmbedding (if applicable) + +3. **Check existing implementations** in `QEfficient/transformers/models/`: + - If similar classes exist → Reuse patterns + - If not → Create custom implementations + +--- + +## Step 2: Create Custom Files & Mappings + +### 2.1 Create Custom Modeling File + +Create directory structure: +``` +QEfficient/transformers/models/blueprint/ +├── __init__.py +└── modeling_blueprint.py +``` + +**Key modifications in `modeling_blueprint.py`:** +- `QEffBlueprintRotaryEmbedding`: Precompute sin/cos for rotary embeddings +- `QEffBlueprintAttention`: Use `position_ids`, return `past_key_value`, implement `__qeff_init__` +- `QEffBlueprintDecoderLayer`: Return `past_key_value` from forward pass +- `QEffBlueprintModel`: Use `QEffDynamicCache` instead of standard cache +- `QEffBlueprintForCausalLM`: Entry point with additional parameters + +See `modeling_example.py` for detailed implementation examples. + +### 2.2 Add Mappings in pytorch_transforms.py + +**CustomOpsTransform** (RMSNorm mapping): +```python +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + BlueprintRMSNorm: CustomRMSNormAIC, + } +``` + +**KVCacheTransform** (all model classes): +```python +class KVCacheTransform(ModuleMappingTransform): + _module_mapping = { + BlueprintAttention: QEffBlueprintAttention, + BlueprintDecoderLayer: QEffBlueprintDecoderLayer, + BlueprintModel: QEffBlueprintModel, + BlueprintForCausalLM: QEffBlueprintForCausalLM, + } +``` + +See `example_pytorch_transforms.py` for complete example. + +--- + +## Step 3: Testing (4-Stage Pipeline) + +Your implementation is validated through four stages: + +| Stage | Description | Validation | +|-------|-------------|------------| +| **1. PyTorch HF** | Original transformers model | Baseline tokens | +| **2. PyTorch KV** | After QEff transforms | Tokens match Stage 1 | +| **3. ONNX/ORT** | After export to ONNX | Tokens match Stage 2 | +| **4. Cloud AI 100** | Hardware execution | Tokens match Stage 3 | + +**Test function:** `check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100` in `tests/transformers/models/test_causal_lm_models.py` + +### Common Issues + +**Token mismatch (Stage 1→2):** +- Check all classes are mapped in `KVCacheTransform` +- Verify `__qeff_init__` methods exist +- Ensure `position_ids` are correctly passed + +**ONNX export failure (Stage 2→3):** +- Check for unsupported PyTorch operations +- Verify dynamic shapes are defined + +**Compilation failure (Stage 3→4):** +- Reduce `num_cores` or model size +- Check device availability: `get_available_device_id()` + +--- + +## Step 4: Add to Test Suite + +Edit `tests/transformers/models/test_causal_lm_models.py`: + +```python +test_models_causal = [ + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "gpt2", + # ... existing models ... 
+ "YourOrg/YourModel-7B", # Add your model here +] +``` + +**Run tests:** +```bash +# Test your specific model +pytest tests/transformers/models/test_causal_lm_models.py::test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100 -k "YourModel" -v + +# Run all regular tests +pytest tests/transformers/models/test_causal_lm_models.py -m regular +``` + +--- + +## Step 5: Validation Checklist + +Before submitting PR: + +**Implementation:** +- [ ] Created `QEfficient/transformers/models//` directory +- [ ] Implemented all required custom classes +- [ ] Added mappings in `CustomOpsTransform` and `KVCacheTransform` +- [ ] Added imports at top of `pytorch_transforms.py` + +**Testing:** +- [ ] Model added to `test_models_causal` list +- [ ] All 4 stages pass (PyTorch HF → KV → ORT → AI 100) +- [ ] Continuous batching tests pass +- [ ] `qconfig.json` generated successfully + +**Code Quality:** +- [ ] Code follows project style guidelines +- [ ] Commits use DCO sign-off (`git commit -s`) +- [ ] Branch created from `main` + +--- + +## Step 6: Submit Pull Request + +Follow guidelines in [CONTRIBUTING.md](../../../CONTRIBUTING.md): + +1. Create feature branch: `git checkout -b add-yourmodel-support main` +2. Commit with DCO: `git commit -s -m "Add support for YourModel"` +3. Push and create PR targeting `main` branch +4. Include test results in PR description + +--- + +## Troubleshooting Quick Reference + +| Issue | Solution | +|-------|----------| +| Token mismatch between stages | Check class mappings, verify `position_ids` handling | +| Shape errors | Verify KV cache dimensions, check `past_key_value` returns | +| ONNX export fails | Replace unsupported ops, define dynamic shapes | +| Compilation fails | Reduce `num_cores`, check device availability | +| Runtime errors | Verify input shapes match specializations | + +**Debug tip:** Start with `n_layer=1` and short prompts, then gradually increase complexity. + +--- + +## References + +- [Hugging Face Transformers](https://github.com/huggingface/transformers) +- [QEfficient Transformers](https://github.com/quic/efficient-transformers) +- [Contributing Guidelines](../../../CONTRIBUTING.md) +- [Test Suite](../../../tests/transformers/models/test_causal_lm_models.py) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py new file mode 100644 index 000000000..ff62588f9 --- /dev/null +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -0,0 +1,291 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Example pytorch_transforms.py showing common model onboarding patterns. + +This file demonstrates three representative patterns: +1. Blueprint - Standard decoder-only model (example for onboarding) +2. Llama - Most common architecture pattern +3. 
Mixtral - Mixture of Experts (MoE) model + +For more examples and patterns, see: +- Production transforms: QEfficient/base/pytorch_transforms.py +- All model implementations: QEfficient/transformers/models/ +- Specific patterns: + * Gemma (custom RMSNorm): QEfficient/transformers/models/gemma/ + * Multimodal (Llama4, Mllama): QEfficient/transformers/models/llama4/ + * External models (Grok): QEfficient/transformers/models/grok_1/ + * Vision-Language models: QEfficient/transformers/models/mllama/ +""" + +import warnings +from types import MethodType +from typing import Callable, Optional, Tuple, Union + +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) +from torch import nn + +# Example imports for three representative models +from transformers.models.blueprint.modeling_blueprint import ( + BlueprintAttention, + BlueprintDecoderLayer, + BlueprintForCausalLM, + BlueprintModel, + BlueprintRMSNorm, +) +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MixtralRMSNorm, + MixtralSparseMoeBlock, +) + +from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.llama.modeling_llama import ( + QEffLlamaAttention, + QEffLlamaDecoderLayer, + QEffLlamaForCausalLM, + QEffLlamaModel, +) +from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( + QEffMixtralAttention, + QeffMixtralDecoderLayer, + QEffMixtralForCausalLM, + QEffMixtralModel, + QEffMixtralSparseMoeBlock, +) +from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.sampler.sampler import sampler_forward +from QEfficient.transformers.spd.spd_transform_forward import tlm_forward + +SPD_TARGET = "target" + + +class CustomOpsTransform(ModuleMappingTransform): + """ + Maps RMSNorm classes to custom implementations optimized for Cloud AI 100. + + Most models use the standard CustomRMSNormAIC. For special cases (like Gemma), + you can create custom RMSNorm in QEfficient.customop. + """ + + _module_mapping = { + # Blueprint - Example model for onboarding + BlueprintRMSNorm: CustomRMSNormAIC, + # Llama - Most common pattern + LlamaRMSNorm: CustomRMSNormAIC, + # Mixtral - MoE model pattern + MixtralRMSNorm: CustomRMSNormAIC, + # TODO: Add your model's RMSNorm mapping here: + # YourModelRMSNorm: CustomRMSNormAIC, + } + + +class KVCacheTransform(ModuleMappingTransform): + """ + Maps model classes to their QEfficient counterparts with KV cache support. + + This is the most critical transform for enabling efficient inference. + All model classes (Attention, DecoderLayer, Model, ForCausalLM) must be mapped. 
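+    Missing a class here is a common cause of token mismatches between the
+    PyTorch HF baseline and the KV-transformed model (Stage 1 -> 2 of the
+    test pipeline), so revisit this mapping first when tokens diverge.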
+ """ + + _module_mapping = { + # Blueprint - Example model for onboarding + BlueprintAttention: QEffBlueprintAttention, + BlueprintDecoderLayer: QEffBlueprintDecoderLayer, + BlueprintModel: QEffBlueprintModel, + BlueprintForCausalLM: QEffBlueprintForCausalLM, + # Llama - Most common pattern (standard decoder-only) + LlamaAttention: QEffLlamaAttention, + LlamaDecoderLayer: QEffLlamaDecoderLayer, + LlamaModel: QEffLlamaModel, + LlamaForCausalLM: QEffLlamaForCausalLM, + # Mixtral - MoE model pattern (includes SparseMoeBlock) + MixtralAttention: QEffMixtralAttention, + MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, + MixtralDecoderLayer: QeffMixtralDecoderLayer, + MixtralModel: QEffMixtralModel, + MixtralForCausalLM: QEffMixtralForCausalLM, + # TODO: Add your model's class mappings here: + # YourModelAttention: QEffYourModelAttention, + # YourModelDecoderLayer: QEffYourModelDecoderLayer, + # YourModelModel: QEffYourModelModel, + # YourModelForCausalLM: QEffYourModelForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + + +class SpDTransform: + """ + Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. + This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits + against the speculated tokens from a smaller model. + Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. + + ``Mandatory`` Args: + :model (nn.Module): PyTorch model. + + Returns: + :model (nn.Module): PyTorch model. + :transformed (bool): whether transformation was applied successfully. + """ + + # supported architectures + _module_mapping = { + QEffBlueprintForCausalLM, + # TODO: Add your model's ForCausalLM class here if using Speculative Decoding: + # QEffYourModelForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + transformed = False + pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None) + + if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None: + return model, transformed + + if speculative_model_type not in (supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys())): + raise ValueError( + f"Speculative model type {speculative_model_type} is not supported. " + f"Currently only support {supported_spd_model_types}" + ) + + if (model_class := model.__class__) in cls._module_mapping: + model.forward = MethodType(tlm_forward, model) + if speculative_model_type != SPD_TARGET: + pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"] + model = build_and_attach_mlp( + model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs + ) + transformed = True + else: + raise NotImplementedError( + f"Model class {model_class} does not yet support returning multiple logits to keep." 
+ ) + + kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp + return model, transformed + + +class SamplerTransform: + """ + Add nodes at the output of any generic QEffForCausalLM model to enable the + sampling of next tokens at the device (instead of the host) and return the + next tokens and/or probability distributions. + + Note: To achieve this, the generic QEffForCausalLM model must provide the + logits as output. + + ``Mandatory`` Args: + :model (nn.Module): PyTorch model. + + Returns: + :model (nn.Module): PyTorch model. + :transformed (bool): whether transformation was applied successfully. + """ + + # supported architectures + _module_mapping = { + # TODO: Add your model's ForCausalLM class here if using on-device sampling: + # QEffYourModelForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + transformed = False + if qaic_config is None or not qaic_config.get("include_sampler", False): + return model, transformed + + if (model_class := model.__class__) in cls._module_mapping: + model.old_forward = model.forward + model.forward = MethodType(sampler_forward, model) + transformed = True + else: + raise NotImplementedError(f"Model class {model_class} does not support on device sampling.") + + return model, transformed + + +class VlmKVOffloadTransform(ModuleMappingTransform): + """ + Vision-Language Model transform with KV offloading (two QPC setup). + + Used for multimodal models where vision and text processing are separated. + See QEfficient/transformers/models/mllama/ for implementation examples. + """ + + _module_mapping = { + # TODO: Add VLM models with KV offloading here: + # YourVLMTextCrossAttention: QEffYourVLMTextCrossAttentionTwoQPC, + } + + +class VlmNoKVOffloadTransform(ModuleMappingTransform): + """ + Vision-Language Model transform without KV offloading (single QPC setup). + + Used for multimodal models in single QPC configuration. + See QEfficient/transformers/models/mllama/ for implementation examples. + """ + + _module_mapping = { + # TODO: Add VLM models without KV offloading here: + # YourVLMTextCrossAttention: QEffYourVLMTextCrossAttentionSingleQPC, + } + + +class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_string_replace_method = { + # TODO: Add external model mappings here (for models not in transformers library): + # "YourExternalModelClass": { + # "forward": QEffYourExternalModel.forward, + # "__qeff_init__": QEffYourExternalModel.__qeff_init__, + # }, + } + + _match_class_replace_method = {} + + +class PoolingTransform: + """ + Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. + The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling. 
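+    The pooling argument may be either a key into POOLING_MAP (a string) or a
+    user-supplied callable, which is validated via validate_user_pooling_function
+    before the model is wrapped into PooledModel.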
+ """ + + @classmethod + def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]: + transformed = False + pooling_method = ( + POOLING_MAP[pooling] + if isinstance(pooling, str) and pooling in POOLING_MAP + else validate_user_pooling_function(pooling) + ) + model = PooledModel(model, pooling_method) + warnings.warn("Pooling is applied to the model.") + return model, transformed diff --git a/examples/onboarding_guide/causallm/modeling_example.py b/examples/onboarding_guide/causallm/modeling_example.py new file mode 100644 index 000000000..195c9d7db --- /dev/null +++ b/examples/onboarding_guide/causallm/modeling_example.py @@ -0,0 +1,394 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""PyTorch Blueprint model.""" + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.blueprint.modeling_blueprint import ( + BlueprintAttention, + BlueprintConfig, + BlueprintDecoderLayer, + BlueprintForCausalLM, + BlueprintModel, + BlueprintRotaryEmbedding, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +class QEffBlueprintRotaryEmbedding(BlueprintRotaryEmbedding): + """ + Add the required Rotary Embedding functionality to the model based on the Class in the transformers modeling file. + The purpose of this class is to precompute sin and cos values for the rotary embedding and cache it for faster inference. + This class is more or less the same for all models that are onboarded. + """ + + def __init__(self, config: BlueprintConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors. + + We modify this method to enable the application of the rotary embedding based on position_ids + instead of seq_len. This is needed as our modified modelling accepts position_ids and not + the attention_mask as an input. 
+ """ + # + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + """ + Implements the forward pass of Eager Attention for the model. + We explicitly support Eager mode based attention on our device. + The method would mostly be generic so we don't expect it to have much changes. + MIN_MASKED_ATTENTION_VALUE is a special value which helps our compiler know what -inf should be represented by. + """ + pass + + +class QEffBlueprintAttention(BlueprintAttention): + """ + Here we'll setup the forward pass of the Attention module as implemented in the original model. + We initialize our own RotaryEmbedding module via __qeff_init__ method call. + + """ + + # < We load our own custom class for the rotary embedding to enable supporting position_ids> + # Since we map the custom classes to the original classes, __init__ method wouldn't work as expected, + # Hence we use __qeff_init__ method to initialize something while the mapping happens. + + def __qeff_init__(self): + self.rotary_emb = QEffBlueprintRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Most of the implementation remains the same as original forward method. + The parts where difference occurs are the way we apply the rotary embeddings. + Also, we return the past_key_values instead of storing it in the default transformers cache. + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + + # We build the rotary embeddings different from the transformers method. + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # Application of the rotary embeddings requires position_ids as well. 
+ query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # < We add all the required items for cache kwargs which would enable updating QEffDynamicCache > + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # < We override the attention_interface method with our own to enable Eager Attention> + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffBlueprintDecoderLayer(BlueprintDecoderLayer): + """ + Overrides the forward method of the original BlueprintDecoderLayer. + Only changes being that the past_key_value is returned and `self.self_attn` method + is now an object of QEffBlueprintAttention instead. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + The modified forward function also stores and returns the past_key_value. + Every other operation remains the same. + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # < Self attention would also have to return the past_key_value as well and we capture it here> + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffBlueprintModel(BlueprintModel): + """ + Replaces the original BlueprintModel with a modified version. + We initialize the custom `QEffDynamicCache` for past_key_values here instead of the DynamicCache class. + This custom Cache class has all the required custom ops to perform CtxScatter/CtxGather as well as other required operations. + This enables us to cache the past key values in the way we want for AIC. The component won't require any changes mostly. 
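+    Note that the causal mask in forward() is built from position_ids via
+    _create_causal_mask rather than derived from attention_mask, matching the
+    position_ids-driven attention implementation above.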
+ """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + # < We create the custom QEffDynamicCache here to be used during the AIC execution> + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class QEffBlueprintForCausalLM(BlueprintForCausalLM): + """ + No major changes are needed in the forward method of this class, it is the entry point for the model during inference. + We add the additionally required parameters and pass those down the line as well. 
+ """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # < We add the additional parameters that we use for our models here and pass them down the line > + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/examples/onboarding_guide/customop/CustomGELU/src/customgelu_aic.cpp b/examples/onboarding_guide/customop/CustomGELU/src/customgelu_aic.cpp new file mode 100644 index 000000000..ac018b91d --- /dev/null +++ b/examples/onboarding_guide/customop/CustomGELU/src/customgelu_aic.cpp @@ -0,0 +1,23 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// SPDX-License-Identifier: BSD-3-Clause +// +//----------------------------------------------------------------------------- + +#include "CustomOpAICInterface.h" +#include "stddef.h" + +extern "C" { + +/* The AIC compilation target supports an API similar to the Interpreter API. +Additionally, threadId, which is the AIC thread ID, is passed. +Kernel is invoked by four AIC threads with threadId equal to 0, 1, 2, and 3. */ + +void CustomGELUAIC( + const CustomOpContext *ctx, + const int32_t threadId) +{ +} + +} diff --git a/examples/onboarding_guide/customop/CustomGELU/src/customgelu_functions.cpp b/examples/onboarding_guide/customop/CustomGELU/src/customgelu_functions.cpp new file mode 100644 index 000000000..f0ebb8f89 --- /dev/null +++ b/examples/onboarding_guide/customop/CustomGELU/src/customgelu_functions.cpp @@ -0,0 +1,74 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+// SPDX-License-Identifier: BSD-3-Clause +// +//----------------------------------------------------------------------------- + +#include "CustomOpFunctions.h" +#include "CustomOpInterpreterInterface.h" +#include "CustomOpTileConfigHelpers.h" +#include "CustomOpTypes.h" + +#include + +extern "C" { +bool customOpVerify( + const CustomOpPropertiesHandle *const opProp) +{ + /* Refer to function declaration at CustomOpFunctions.h for usage. */ + + return true; +} + +const char * customOpSelectImpl( + const CustomOpPropertiesHandle *const opProp, + const CustomOpKernelInfo *const kernelInfos, + const int32_t numKernels, + const char *backend) +{ + /* Refer to function declaration at CustomOpFunctions.h for usage. */ + + /* For AIC pick 'AIC', for Interpreter pick 'Interpreter' */ + if (strcmp(backend, "AIC") == 0) + { + return ""; + } + else if (strcmp(backend, "Interpreter") == 0) + { + return ""; + } + return nullptr; +} + +bool customOpInferShape( + CustomOpPropertiesHandle *const opProp) +{ + /* Refer to function declaration at CustomOpFunctions.h for usage. */ + + return false; +} + +bool customOpSetProperties( + CustomOpPropertiesHandle *opProp) +{ + /* Refer to function declaration at CustomOpFunctions.h for usage. */ + + return false; +} + +bool customOpMapTiles( + CustomOpPropertiesHandle *opProp) +{ + /* Refer to function declaration at CustomOpFunctions.h for usage. */ + + return false; +} +void customOpDeallocateMemory( + CustomOpPropertiesHandle *opProp) +{ + /* Refer to function declaration at CustomOpFunctions.h for usage. */ + + CustomOpTileConfigHelpers::destroyTileConfigsAndMergeConfigs(opProp); +} +} diff --git a/examples/onboarding_guide/customop/CustomGELU/src/customgelu_interpreter.cpp b/examples/onboarding_guide/customop/CustomGELU/src/customgelu_interpreter.cpp new file mode 100644 index 000000000..bdae3430a --- /dev/null +++ b/examples/onboarding_guide/customop/CustomGELU/src/customgelu_interpreter.cpp @@ -0,0 +1,33 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// SPDX-License-Identifier: BSD-3-Clause +// +//----------------------------------------------------------------------------- + +/* +* This file can be compiled separately and can be loaded using dlopen +* Compilation command: (tried with gcc 5.5) +* g++ -shared -std=c++11 -fPIC -o _lib.so .cpp .cpp -I/opt/qti-aic/dev/inc +* for example: g++ -shared -std=c++11 -fPIC -o reluop_lib.so reluop_functions.cpp reluop_interpreter.cpp -I/opt/qti-aic/dev/inc +*/ + +#include "CustomOpInterpreterInterface.h" + +extern "C" { +void CustomGELUInterpreter( + CustomOpContext *ctx) +{ + /* The interpreter implementation is provided to the compiler as a shared library + (or collection of shared libraries). Each shared library can contain multiple + versions (flavors) of implementations of the operation, refered onwards as kernels. + A kernel is selected at model compilation time by the selection function. The + developer is responsible for compilation of these shared libraries. As the interface + is C, the shared libraries can be compiled by various compilers (GCC, CLANG, etc). + In addition, as these shared libraries are running on the Host CPU, the developer + can open files, dump results, use stdout/stderr for printing debug messages, etc. + This makes the Interpreter implementation a very effective way for debugging the + operation functionality as part of model execution. 
The signature of the + kernel (implementation) is generic, and fits any custom operation. */ +} +} diff --git a/examples/onboarding_guide/customop/CustomGELU/src/customop_lib.so b/examples/onboarding_guide/customop/CustomGELU/src/customop_lib.so new file mode 100644 index 000000000..f1013b1e3 Binary files /dev/null and b/examples/onboarding_guide/customop/CustomGELU/src/customop_lib.so differ diff --git a/examples/onboarding_guide/customop/README.md b/examples/onboarding_guide/customop/README.md new file mode 100644 index 000000000..e8e523b70 --- /dev/null +++ b/examples/onboarding_guide/customop/README.md @@ -0,0 +1,343 @@ +# Adding Custom Operations to QEfficient + +Custom ops are hardware optimized implementations of neural network operators for Qualcomm Cloud AI 100. This example walks you through the complete process of creating, registering, and deploying custom operations. + +## When to Add Custom Ops + +Add custom ops when: + +- Replacing standard PyTorch ops with faster hardware-optimized versions +- Implementing operations the compiler doesn't support natively +- Optimizing frequently used operations in your model + +## Understanding the 3-Layer Pattern + +Every custom op in QEfficient follows a 3-layer architecture: + +1. **ONNX Script** - Defines the ONNX representation that gets exported (which the compiler later reads) +2. **PyTorch Autograd Function** - Bridges PyTorch execution and ONNX export +3. **nn.Module Wrapper** - What users interact with in their code + +--- + +## Step 1: Review the Example Template + +Before creating a custom op, examine the example template to understand the structure: + +**File:** `examples/onboarding_guide/customop/example_custom_op.py` + +This file demonstrates a complete custom GELU implementation with all three layers. Study how: +- Constants are defined in the ONNX script +- The autograd function handles both PyTorch execution and ONNX export +- The module wrapper provides a clean user interface + +--- + +## Step 2: Create the Custom Op File + +Create a new file under `QEfficient/customop/.py` with the following structure: + +### Layer 1: ONNX Script + +This defines the ONNX representation that PyTorch exports (which the compiler will later read and compile). + +```python +import onnxscript +import torch +from torch import nn +from QEfficient.utils import constants + +ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) + +@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)) +def CustomOpBluePrint(input: onnxscript.FLOAT): + """ + ONNX implementation of the operation. + + Important constraints: + - Domain MUST be "com.qti.aisw.onnx" for custom ops + - Use ops.Constant() for any constant values + """ + +``` + + +PyTorch's ONNX exporter uses this to generate the ONNX graph + +### Layer 2: PyTorch Autograd Function + +This bridges PyTorch execution and ONNX export. + +```python +class CustomOpBluePrintFunc(torch.autograd.Function): + + @staticmethod + def forward(input: torch.Tensor, mode: str = "default"): + """ + PyTorch implementation - can use ANY PyTorch operations. + This runs during normal PyTorch execution (training, inference, etc.) + """ + + + @staticmethod + def setup_context(ctx, inputs, outputs): + """Store any tensors needed for backward pass (not needed for inference-only ops)""" + + + @staticmethod + def symbolic(g: torch.Graph, input: torch.Value, mode: str = "default") -> torch.Value: + """ + Called during ONNX export - maps to the ONNX script. 
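+        The g.onnxscript_op call below emits the ONNX-script function from
+        Layer 1, so the exported graph carries the com.qti.aisw.onnx
+        custom-domain op instead of the eager PyTorch implementation.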
+ """ + return g.onnxscript_op(CustomOpBluePrint, input) +``` + +### Layer 3: Module Wrapper + +This is the user facing interface. + +```python +class CustomOpBluePrintAIC(nn.Module): + """ + User-facing module wrapper. + This is what users instantiate and use in their models. + """ + + def __init__(self, mode: str = "default"): + super().__init__() + pass + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Forward pass delegates to the autograd function""" + return CustomOpBluePrintFunc.apply(input, self.mode) +``` + +Provides a clean, standard `nn.Module` interface + +--- + +## Step 3: Register the Custom Op in Qeff + +Add the custom op to `QEfficient/customop/__init__.py`: + +```python +from QEfficient.customop. import CustomOpBluePrintAIC + +__all__ = [ + # ... existing exports ... + "CustomOpBluePrintAIC", +] +``` + +This makes the custom op importable from `QEfficient.customop`. + +--- + +## Step 4: Map to PyTorch Transforms + +Add the mapping in `QEfficient/transformers/models/pytorch_transforms.py` to automatically replace standard PyTorch modules with the custom op. + + +### Example from the Codebase + +```python +from transformers.activations import NewGELUActivation +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from QEfficient.customop import CustomGELUAIC, CustomRMSNormAIC + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + # Activation functions + NewGELUActivation: CustomGELUAIC, + + # Normalization layers + LlamaRMSNorm: CustomRMSNormAIC, + } +``` + +**How the transform works:** +1. When a model is loaded, QEfficient scans all modules +2. For each module type in `_module_mapping`, it replaces the module with the custom implementation +3. The replacement happens automatically before ONNX export + +--- + +## Step 5: Export the Model with QEff + +Use QEfficient to export the model with custom ops applied: + +```python +from QEfficient import QEFFAutoModelForCausalLM +from transformers import AutoTokenizer + +# Model name +model_name = "gpt2" + +# Load model - custom ops are automatically applied via transforms +model = QEFFAutoModelForCausalLM.from_pretrained(model_name) + +# Export to ONNX +# The custom ops will be included in the ONNX graph +export_path = model.export() + +print("Model exported successfully with custom ops!") +print(f"Model path : {export_path}") +``` + +The export process: +1. Loads the PyTorch model +2. Applies `CustomOpsTransform` to replace standard ops with custom ops +3. Saves the ONNX model + +--- + +## Step 6: Generate C++ and YAML Files + +Use the SDK tool to generate the custom op package structure from the ONNX model: + +```bash +# Generate custom op package from ONNX model +python /opt/qti-aic/tools/custom-ops/gen_custom_op_package.py \ + --onnx_model path/to/model.onnx \ + --output_dir CustomOp_Package \ + --domain com.qti.aisw.onnx +``` + +This generates: +``` +CustomOp_Package/ +├── custom_op_config.yaml # Op configuration for compiler +└── / + └── src/ + ├── customop_functions.cpp # Utility functions (to be implemented) + └── customop_interpreter.cpp # Backend implementation (to be implemented) +``` + +--- + +## Step 7: Build the Shared Library + +Compile the C++ code into a shared library (.so file) that the runtime can load. + +### Reference the SDK Example + +See `/opt/qti-aic/examples/apps/custom-op/basic-example/README.md` for detailed instructions. 
+ +### Build Command + +```bash +# Navigate to the custom op package +cd CustomOp_Package//src + +# Compile the shared library +g++ -shared -std=c++11 -fPIC \ + -o customop_lib.so \ + customop_functions.cpp \ + customop_interpreter.cpp \ + -I/opt/qti-aic/dev/inc \ + -I/opt/qti-aic/dev/lib +``` + +After building, verify the .so file exists: +```bash +ls -lh customop_lib.so +``` + +--- + +## Step 8: Compile the Model with Custom Op + +Pass the custom op YAML config to the compiler: + +```python +from QEfficient import QEFFAutoModelForCausalLM +from transformers import AutoTokenizer + +model_name = "gpt2" + +# Load model with custom ops +model = QEFFAutoModelForCausalLM.from_pretrained(model_name) + +# Compile with custom op - pass the YAML config path +model.compile( + num_cores=16, + registered_custom_op="CustomOp_Package/custom_op_config.yaml" +) + +print("Model compiled successfully with custom op!") +``` +--- + + +## Workflow Summary + +Here's the complete workflow from start to finish: + +### Quick Reference + +1. **Review** `example_custom_op.py` to understand the 3-layer pattern +2. **Create** `QEfficient/customop/your_op.py` with ONNX script, autograd function, and module wrapper +3. **Register** in `QEfficient/customop/__init__.py` +4. **Map** in `QEfficient/transformers/models/pytorch_transforms.py` +5. **Export** model with QEff (custom ops automatically applied) +6. **Generate** C++ and YAML files using SDK tool +7. **Implement** C++ functions and build .so library +8. **Compile** model with `registered_custom_op` parameter + +### Workflow Diagram + +```mermaid +graph TD + A[Start: Need Custom Op] --> B[Review example_custom_op.py] + B --> C[Create QEfficient/customop/your_op.py] + C --> C1[Layer 1: ONNX Script] + C --> C2[Layer 2: Autograd Function] + C --> C3[Layer 3: Module Wrapper] + C1 --> D[Register in __init__.py] + C2 --> D + C3 --> D + D --> E[Add mapping in pytorch_transforms.py] + E --> F[Export model with QEff] + F --> G[Generate C++ & YAML with SDK tool] + G --> H[Implement C++ functions] + H --> I[Build .so library] + I --> J[Compile model with registered_custom_op] + J --> K[Run inference on device] + K --> L[Verify performance] + L --> M{Performance OK?} + M -->|Yes| N[Done!] 
+ M -->|No| O[Optimize C++ implementation] + O --> I + + style A fill:#2196F3,stroke:#1976D2,stroke-width:2px,color:#fff + style N fill:#4CAF50,stroke:#388E3C,stroke-width:2px,color:#fff + style C fill:#FFC107,stroke:#FFA000,stroke-width:2px,color:#000 + style J fill:#FF9800,stroke:#F57C00,stroke-width:2px,color:#fff + style K fill:#E91E63,stroke:#C2185B,stroke-width:2px,color:#fff +``` + +### Key Files Reference + +| Purpose | File Path | +|---------|-----------| +| Example template | `examples/onboarding_guide/customop/example_custom_op.py` | +| SDK custom op tool | `/opt/qti-aic/tools/custom-ops/gen_custom_op_package.py` | +| SDK examples | `/opt/qti-aic/examples/apps/custom-op/` | + +--- + +## Examples and References + +### Example Implementations + +- **`example_custom_op.py`** - Complete template showing the 3-layer pattern +- **`example_pytorch_transforms.py`** - How to register custom ops in transforms +- **`QEfficient/customop/rms_norm.py`** - Real implementation with learnable parameters +- **`QEfficient/customop/ctx_scatter_gather.py`** - Advanced custom op example + +### Documentation + +- **[Custom Ops Directory](../../../QEfficient/customop/)** - All custom op implementations +- **[PyTorch Transforms](../../../QEfficient/transformers/models/pytorch_transforms.py)** - Transform registry +- **[SDK Custom Op Documentation](/opt/qti-aic/examples/apps/custom-op)** - Hardware-specific details +- **[Contributing Guide](../../../CONTRIBUTING.md)** - How to contribute custom ops diff --git a/examples/onboarding_guide/customop/custom_op_config.yaml b/examples/onboarding_guide/customop/custom_op_config.yaml new file mode 100644 index 000000000..06187e80b --- /dev/null +++ b/examples/onboarding_guide/customop/custom_op_config.yaml @@ -0,0 +1,27 @@ +--- +version: Major.Minor.Patch +CustomOps: +# CustomGELU + - type: CustomGELU + package: com.qti.aisw.onnx + inputs: + - name: in1 + maxDims: 5 + parameters: [] + outputs: + - name: out1 + maxDims: 5 + functionsLibrary: ./CustomGELU/src/customgelu_lib.so + implementations: + - backend: Interpreter + type: CustomGELUInterpreter + impl: ./CustomGELU/src/customgelu_lib.so + - backend: AIC + type: CustomGELUAIC + impl: ./CustomGELU/src/customgelu_aic.cpp + memoryConfig: + DDR: + CacheableDDR: + VTCM: [in1, out1] + requiredFor: +... diff --git a/examples/onboarding_guide/customop/example_custom_op.py b/examples/onboarding_guide/customop/example_custom_op.py new file mode 100644 index 000000000..f682bcebf --- /dev/null +++ b/examples/onboarding_guide/customop/example_custom_op.py @@ -0,0 +1,86 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Custom op template - shows the 3-layer pattern we use for all custom ops. +""" + +import onnxscript +import torch +from torch import nn + +from QEfficient.utils import constants + +ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) + + +# Layer 1: ONNX Script +# This is what the compiler sees when it compiles your model + + +@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)) +def CustomOpBluePrint(input: onnxscript.FLOAT): + """ + ONNX implementation of your operation. + Important: Only use tensor inputs - no strings or other types! 
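+    The body below is exact GELU, input * 0.5 * (1 + erf(input / sqrt(2))),
+    expressed purely with opset tensor ops and ops.Constant for the literals.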
+ """ + sqrt_2 = ops.Constant(value_floats=[1.4142135623730951]) + half = ops.Constant(value_floats=[0.5]) + one = ops.Constant(value_floats=[1.0]) + + x_scaled = ops.Div(input, sqrt_2) + erf_x = ops.Erf(x_scaled) + result = ops.Mul(input, ops.Mul(half, ops.Add(one, erf_x))) + + return result + + +# Layer 2: PyTorch Autograd Function +# Connects PyTorch execution to ONNX export +# Pytorch forward function is called during PyTorch execution (CPU/GPU). +# When running on ONNX runtime, the CustomOpBluePrint function (Layer 1) is called instead. + + +class CustomOpBluePrintFunc(torch.autograd.Function): + @staticmethod + def forward(input: torch.Tensor, mode: str = "default"): + """PyTorch implementation - can use any PyTorch ops""" + if mode == "approximate": + return 0.5 * input * (1.0 + torch.tanh(0.7978845608028654 * (input + 0.044715 * input**3))) + else: + return input * 0.5 * (1.0 + torch.erf(input / 1.4142135623730951)) + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, input: torch.Value, mode: str = "default") -> torch.Value: + """Called during ONNX export - don't pass string params!""" + return g.onnxscript_op(CustomOpBluePrint, input).setTypeAs(input) + + +# Layer 3: Module Wrapper +# What users actually interact with + + +class CustomOpBluePrintAIC(nn.Module): + def __init__(self, mode: str = "default"): + super().__init__() + if mode not in ["default", "approximate"]: + raise ValueError(f"mode must be 'default' or 'approximate', got {mode}") + self._mode_str = mode + + @property + def mode(self) -> str: + return self._mode_str if hasattr(self, "_mode_str") else "default" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return CustomOpBluePrintFunc.apply(input, self.mode) + + def extra_repr(self) -> str: + return f"mode={self.mode}" diff --git a/examples/onboarding_guide/customop/example_pytorch_transforms.py b/examples/onboarding_guide/customop/example_pytorch_transforms.py new file mode 100644 index 000000000..591890c52 --- /dev/null +++ b/examples/onboarding_guide/customop/example_pytorch_transforms.py @@ -0,0 +1,54 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Example pytorch_transforms.py showing how to register custom operations. + +This file demonstrates how to add custom operation mappings to the transform +system so they are automatically applied when loading models. + +For the actual production transforms, see: +- QEfficient/transformers/models/pytorch_transforms.py +""" + +from transformers.activations import ( + NewGELUActivation, +) + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.customop import CustomGELUAIC + + +class CustomOpsTransform(ModuleMappingTransform): + """ + Maps standard PyTorch operations to custom Cloud AI 100 implementations. + + How it works: + 1. When a model is loaded, this transform scans all modules + 2. For each module type in _module_mapping, it replaces the module + with the corresponding custom implementation + 3. 
The replacement happens automatically before ONNX export + """ + + _module_mapping = { + # ACTIVATION FUNCTIONS + # GELU + NewGELUActivation: CustomGELUAIC, + # TODO: Add other activation functions + # nn.SiLU: CustomSiLUAIC, + # nn.Mish: CustomMishAIC, + # NORMALIZATION LAYERS + # RMSNorm - Used by Llama, Mistral, Mixtral, etc. + # from transformers.models.llama.modeling_llama import LlamaRMSNorm + # LlamaRMSNorm: CustomRMSNormAIC, + # TODO: Add your model's normalization layers + # YourModelRMSNorm: CustomRMSNormAIC, + # OTHER OPERATIONS + # TODO: Add other custom operations + # nn.Linear: CustomLinearAIC, # If you have a custom linear layer + # nn.Embedding: CustomEmbeddingAIC, # If you have custom embeddings + } diff --git a/examples/peft/README.md b/examples/peft/README.md new file mode 100644 index 000000000..fbc8c99b7 --- /dev/null +++ b/examples/peft/README.md @@ -0,0 +1,83 @@ +# PEFT Examples + +Examples for running Parameter-Efficient Fine-Tuning (PEFT) models with LoRA adapters on Qualcomm Cloud AI 100. + + +## Authentication + +For private/gated models, export your HuggingFace token: +```bash +export HF_TOKEN= +``` + +## Supported Models + +**QEff Auto Class:** `QEffAutoPeftModelForCausalLM` + +PEFT/LoRA adapters work with any supported base model architecture. + +Popular base models include: +- Llama +- Mistral, Mixtral + + +## Available Examples + +### single_adapter.py +Load and use a single LoRA adapter with a base model. + +**Usage:** +```python +python single_adapter.py +``` + +This example: +- Loads Mistral-7B base model with a LoRA adapter +- Demonstrates adapter switching +- Shows inference with different adapters (magicoder, tldr, gsm8k, agnews) + +### multi_adapter.py +Use multiple LoRA adapters with continuous batching. + +**Usage:** +```python +python multi_adapter.py +``` + +This example: +- Runs multiple adapters simultaneously in one batch +- Demonstrates continuous batching with `full_batch_size=4` +- Shows different prompts using different adapters in the same batch + +## Key Features + +### Single Adapter Mode +- Load one LoRA adapter at a time +- Switch between adapters dynamically +- Suitable for single-task inference + +### Multi-Adapter Mode (Continuous Batching) +- Run multiple adapters simultaneously +- Different prompts can use different adapters in the same batch +- Efficient for multi-task scenarios +- Requires `continuous_batching=True` and `finite_adapters=True` + +## Adapter Management + +```python +# Load adapter +qeff_model.load_adapter("predibase/adapter_name", "adapter_name") + +# Set active adapter +qeff_model.set_adapter("adapter_name") + +# Unload adapter +qeff_model.unload_adapter("adapter_name") +``` + +## Documentation + +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) +- [Validated Base Models](https://quic.github.io/efficient-transformers/source/validate.html#text-only-language-models) +- [PEFT Documentation](https://huggingface.co/docs/peft) +- [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) diff --git a/examples/lora_models.py b/examples/peft/multi_adapter.py similarity index 100% rename from examples/lora_models.py rename to examples/peft/multi_adapter.py diff --git a/examples/peft_models.py b/examples/peft/single_adapter.py similarity index 60% rename from examples/peft_models.py rename to examples/peft/single_adapter.py index 63c196a22..4f84bd13c 100644 --- a/examples/peft_models.py +++ b/examples/peft/single_adapter.py @@ -5,6 +5,8 @@ # # 
----------------------------------------------------------------------------- +## This example demonstrates single adapter usage with sequential adapter switching ## + from transformers import AutoTokenizer, TextStreamer from QEfficient import QEffAutoPeftModelForCausalLM @@ -12,19 +14,27 @@ base_model_name = "mistralai/Mistral-7B-v0.1" tokenizer = AutoTokenizer.from_pretrained(base_model_name) streamer = TextStreamer(tokenizer) +prefill_seq_len = 32 +ctx_len = 1024 +generation_len = 1024 + + +## STEP 1 -- init base model +qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder") + +## STEP 2 -- export & compile qeff model +qeff_model.compile(prefill_seq_len=prefill_seq_len, ctx_len=ctx_len) -m = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder", "magicoder") -m.export() -m.compile(prefill_seq_len=32, ctx_len=1024) +## STEP 3 -- run inference with different adapters -# Magicoder adapter -m.set_adapter("magicoder") +# Magicoder adapter - code generation +qeff_model.set_adapter("magicoder") inputs = tokenizer("def fibonacci", return_tensors="pt") -m.generate(**inputs, streamer=streamer, max_new_tokens=1024) +qeff_model.generate(**inputs, streamer=streamer, max_new_tokens=generation_len) -# TLDR, summary generator -m.load_adapter("predibase/tldr_headline_gen", "tldr_headline_gen") -m.set_adapter("tldr_headline_gen") +## STEP 3.1 -- load and use TLDR headline generator adapter +qeff_model.load_adapter("predibase/tldr_headline_gen", "tldr_headline_gen") +qeff_model.set_adapter("tldr_headline_gen") inputs = tokenizer( """Summarize this passage in one sentence or less: Jeffrey Berns, CEO of Blockchains LLC, wants the Nevada government to allow companies like \ his to form local governments on land they own, granting them power over everything from \ @@ -36,21 +46,21 @@ Summary: """, return_tensors="pt", ) -m.generate(**inputs, streamer=streamer, max_new_tokens=1024) +qeff_model.generate(**inputs, streamer=streamer, max_new_tokens=1024) -# Math problems -m.load_adapter("predibase/gsm8k", "gsm8k") -m.set_adapter("gsm8k") +## STEP 3.2 -- load and use GSM8K adapter for math problems +qeff_model.load_adapter("predibase/gsm8k", "gsm8k") +qeff_model.set_adapter("gsm8k") inputs = tokenizer( "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. \ How many total meters does he run a week?", return_tensors="pt", ) -m.generate(**inputs, streamer=streamer, max_new_tokens=1024) +qeff_model.generate(**inputs, streamer=streamer, max_new_tokens=1024) -# News explanation -m.load_adapter("predibase/agnews_explained", "agnews_explained") -m.set_adapter("agnews_explained") +## STEP 3.3 -- load and use AGNews adapter for news classification +qeff_model.load_adapter("predibase/agnews_explained", "agnews_explained") +qeff_model.set_adapter("agnews_explained") inputs = tokenizer( """Below is a news article. Please classify it under one of the following \ classes (World, Business, Sports, Sci/Tech) and provide a reasonable coherent explanation for \ @@ -65,4 +75,4 @@ """, return_tensors="pt", ) -m.generate(**inputs, streamer=streamer, max_new_tokens=1024) +qeff_model.generate(**inputs, streamer=streamer, max_new_tokens=1024) diff --git a/examples/performance/README.md b/examples/performance/README.md new file mode 100644 index 000000000..9308ce6db --- /dev/null +++ b/examples/performance/README.md @@ -0,0 +1,160 @@ +# Performance Optimization Examples + +Examples demonstrating performance optimization techniques for Qualcomm Cloud AI 100. 
+
+## Authentication
+
+For private/gated models, export your HuggingFace token:
+```bash
+export HF_TOKEN=
+```
+
+## Available Examples
+
+### Speculative Decoding
+
+Accelerate text generation using speculative decoding techniques.
+
+#### draft_based.py
+Draft-based speculative decoding with separate draft and target models.
+
+**Basic Usage:**
+```bash
+python speculative_decoding/draft_based.py \
+    --target-model-name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
+    --draft-model-name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
+    --num-speculative-tokens 4
+```
+
+**Advanced Usage:**
+```bash
+python speculative_decoding/draft_based.py \
+    --target-model-name meta-llama/Llama-3.1-8B \
+    --draft-model-name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
+    --num-speculative-tokens 4 \
+    --prefill-seq-len 32 \
+    --ctx-len 128 \
+    --target-device-group 0,1 \
+    --draft-device-group 2
+```
+
+#### prompt_lookup.py
+Prompt Lookup Decoding (PLD) - N-gram based speculation without a draft model.
+
+**Basic Usage:**
+```bash
+python speculative_decoding/prompt_lookup.py \
+    --target-model-name TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
+    --num-speculative-tokens 3 \
+    --max-ngram-size 3
+```
+
+#### multi_projection.py
+Multi-projection speculative decoding (Turbo models).
+
+**Basic Usage:**
+```bash
+python speculative_decoding/multi_projection.py \
+    --pretrained-model-name-or-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
+```
+
+### On-Device Sampling
+
+Control sampling parameters directly on the AI 100 hardware.
+
+#### on_device_sampling.py
+Configure sampling parameters (temperature, top-k, top-p, etc.) on-device.
+
+**Basic Usage:**
+```bash
+python on_device_sampling.py \
+    --model-name meta-llama/Llama-3.1-8B \
+    --num-cores 16 \
+    --prompt-len 128 \
+    --ctx-len 256
+```
+
+**Advanced Usage with Sampling Parameters:**
+```bash
+python on_device_sampling.py \
+    --model-name meta-llama/Llama-3.1-8B \
+    --prompt-len 128 \
+    --ctx-len 256 \
+    --full-batch-size 2 \
+    --device-group 0,1,2,3 \
+    --num-cores 16 \
+    --mxint8-kv-cache \
+    --mxfp6-matmul \
+    --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \
+    --repetition-penalty 1.9 \
+    --temperature 0.67 \
+    --top-k 54720 \
+    --top-p 0.89
+```
+
+### Compute-Context-Length
+
+Compute-Context-Length (CCL) adjusts the context length dynamically during inference, so each phase runs with the smallest context window that covers the current sequence position, giving the best performance within each window of context length.
+
+#### compute_context_length/basic_inference.py
+Configure the CCL parameters: 1) `--ccl-enabled`: activates the CCL feature, 2) `--comp-ctx-lengths-prefill`: the list of context lengths to use during the prefill phase, and 3) `--comp-ctx-lengths-decode`: the list of context lengths to use during the decode phase.
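+
+The same configuration can also be driven from Python. The minimal sketch below mirrors what `compute_context_length/basic_inference.py` does (the `qaic_config` key `ccl_enabled`, the `compile()` keyword arguments, and the `generate()` call are taken from that script); treat it as a sketch rather than a reference implementation:
+
+```python
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+model_name = "meta-llama/Llama-3.2-1B"
+
+# Enable CCL when loading the model (mirrors basic_inference.py).
+model = QEFFAutoModelForCausalLM.from_pretrained(
+    model_name,
+    qaic_config={"ccl_enabled": True},
+)
+
+# Pass the prefill/decode context-length lists at compile time.
+model.compile(
+    prefill_seq_len=32,
+    ctx_len=1024,
+    num_cores=16,
+    comp_ctx_lengths_prefill=[500, 1000],
+    comp_ctx_lengths_decode=[512, 1024],
+)
+
+# Run a short generation to verify the compiled model.
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model.generate(prompts=["My name is "], tokenizer=tokenizer, generation_len=128)
+```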
+ +**Usage for Text-only models:** +```bash +python compute_context_length/basic_inference.py \ + --model-name meta-llama/Llama-3.1-8B \ + --num-cores 16 \ + --prefill-seq-len 32 \ + --ctx-len 1024 \ + --ccl-enabled \ + --comp-ctx-lengths-prefill 500,1000 \ + --comp-ctx-lengths-decode 512,1024 +``` + +**Usage for VLM models such as mllama and llava:** +```bash +python compute_context_length/vlm_inference.py \ + --model-name meta-llama/Llama-3.2-11B-Vision-Instruct \ + --hf-token "" \ + --num-cores 16 \ + --prefill-seq-len 32 \ + --ctx-len 8192 \ + --img-size 560 \ + --ccl-enabled \ + --comp-ctx-lengths-prefill 4096 \ + --comp-ctx-lengths-decode 6144,8192 +``` + +**Usage with other MoE and Multimodal models:** +For various models available in compute_context_length directory such as gemma3, gpt_oss, granite_vision, internvl, llama4_cb, llama4_multi_image, llama4, mistral3, molmo, qwen2_5_vl, qwen2_5_vl_cb, and qwen3moe, use the related inference script and only change the model-name and ccl configuration in the related script. The following is an example of each model: +```bash +python compute_context_length/gemma3.py +python compute_context_length/gpt_oss.py +python compute_context_length/granite_vision.py +python compute_context_length/internvl.py +python compute_context_length/llama4_cb.py +python compute_context_length/llama4_multi_image.py +python compute_context_length/llama4.py +python compute_context_length/mistral3.py +python compute_context_length/molmo.py +python compute_context_length/qwen2_5_vl.py +python compute_context_length/qwen2_5_vl_cb.py +python compute_context_length/qwen3moe.py +``` + +## Performance Tips + +1. **Speculative Decoding**: Best for long-form generation where draft model is much faster than target +2. **Prompt Lookup**: No draft model needed, works well for repetitive patterns +3. **Multi-Projection**: Optimal for models with built-in speculation support +4. **On-Device Sampling**: Reduces host-device communication overhead +5. **C++ Execution**: Maximum performance for production deployments + +## Documentation + +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) +- [Performance Features](https://quic.github.io/efficient-transformers/source/features_enablement.html) +- [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) diff --git a/examples/performance/compute_context_length/README.md b/examples/performance/compute_context_length/README.md new file mode 100644 index 000000000..2115251e2 --- /dev/null +++ b/examples/performance/compute_context_length/README.md @@ -0,0 +1,348 @@ +# Compute Context Length (CCL) Examples + +Examples demonstrating Compute Context Length (CCL) optimization for efficient inference on Qualcomm Cloud AI 100. + +## What is CCL? 
+ +Compute Context Length (CCL) is a performance optimization feature that allows models to use different context lengths during different phases of inference: + +- **Prefill Phase**: Processing the initial prompt with optimized context lengths +- **Decode Phase**: Generating new tokens with dynamically adjusted context lengths + +This optimization provides: +- **Memory Efficiency**: Uses smaller context lengths when possible +- **Performance Optimization**: Reduces computation for shorter sequences +- **Flexible Scaling**: Adapts context length based on actual sequence position +- **Hardware Optimization**: Optimized for Qualcomm Cloud AI 100 accelerators + +## Authentication + +For private/gated models, export your HuggingFace token: +```bash +export HF_TOKEN= +``` + +## Quick Start + +### Text-Only Models + +Run basic CCL inference with default settings: +```bash +python basic_inference.py +``` + +Customize with command-line arguments: +```bash +python basic_inference.py \ + --model-name meta-llama/Llama-3.2-1B \ + --prompt "Hello, how are you?" \ + --ctx-len 1024 \ + --ccl-enabled \ + --comp-ctx-lengths-prefill "256,500" \ + --comp-ctx-lengths-decode "512,1024" \ + --generation-len 100 +``` + +# For automatic CCL lists generation, simply not pass CCL lists and only pass ccl-enabled flag +```bash +python basic_inference.py \ + --model-name meta-llama/Llama-3.2-1B \ + --prompt "Hello, how are you?" \ + --ctx-len 1024 \ + --ccl-enabled \ + --generation-len 100 +``` + +### Vision-Language Models + +Run VLM inference with CCL: +```bash +python vlm_inference.py +``` + +Customize with command-line arguments: +```bash +python vlm_inference.py \ + --model-name meta-llama/Llama-3.2-11B-Vision-Instruct \ + --query "Describe this image" \ + --image-url "https://..." \ + --ccl-enabled \ + --comp-ctx-lengths-prefill "4096" \ + --comp-ctx-lengths-decode "6144,8192" \ + --ctx-len 8192 +``` + +# For automatic CCL lists generation, simply not pass CCL lists and only pass ccl-enabled flag +```bash +python vlm_inference.py \ + --model-name meta-llama/Llama-3.2-11B-Vision-Instruct \ + --query "Describe this image" \ + --image-url "https://..." \ + --ccl-enabled \ + --ctx-len 8192 +``` + +## Available Examples + +### Text-Only Models + +#### basic_inference.py +Basic CCL usage with text-only language models. 
+ +**Supported Models:** +- Llama (3.2, 3.3, swiftkv) +- Gemma/Gemma-2 +- Mistral +- Phi/Phi-3 +- Qwen +- Granite +- GPT-2, GPT-J +- CodeGen +- OLMo-2 +- Mistral/Mixtral +- Qwen2 +- Falcon + +**Command-Line Arguments:** +- `--model-name`: HuggingFace model ID (default: meta-llama/Llama-3.2-1B) +- `--prompt`: Input prompt (default: "My name is ") +- `--ctx-len`: Maximum context length (default: 1024) +- `--comp-ctx-lengths-prefill`: Comma-separated prefill context lengths (default: 256,500) +- `--comp-ctx-lengths-decode`: Comma-separated decode context lengths (default: 512,1024) +- `--generation-len`: Number of tokens to generate (default: 128) +- `--continuous-batching`: Enable continuous batching mode +- `--num-cores`: Number of cores (default: 16) +- `--num-devices`: Number of devices (default: 1) + +**Usage Examples:** +```bash +# Basic usage with defaults +python basic_inference.py + +# Custom model and prompt +python basic_inference.py \ + --model-name Qwen/Qwen2.5-7B-Instruct \ + --prompt "Explain quantum computing" + +# With continuous batching +python basic_inference.py \ + --continuous-batching \ + --full-batch-size 4 + +# Larger context with progressive CCL +python basic_inference.py \ + --ctx-len 4096 \ + --comp-ctx-lengths-prefill "1024,2048" \ + --comp-ctx-lengths-decode "2048,3072,4096" +``` + +**Python API:** +```python +from transformers import AutoTokenizer +from QEfficient import QEFFAutoModelForCausalLM + +model = QEFFAutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + qaic_config={ + "comp_ctx_lengths_prefill": [256, 500], + "comp_ctx_lengths_decode": [512, 1024], + "ctx_len": 1024, # Required for CCL validation + }, +) +``` + +#### gpt_oss.py +CCL for GPT-OSS MoE models with prefill_seq_len=1 optimization. + +**Usage:** +```bash +python gpt_oss.py +``` + +**Note:** For MoE models, both prefill and decode CCL lists can be similar when using prefill_seq_len=1. + +### Vision-Language Models + +#### vlm_inference.py +General VLM inference with CCL optimization. + +**Usage:** +```bash +python vlm_inference.py +``` + +#### gemma3.py +CCL for Gemma-3 multimodal models (4B/27B). + +**Usage:** +```bash +python gemma3.py +``` + +#### granite_vision.py +CCL for IBM Granite Vision models. + +**Usage:** +```bash +python granite_vision.py +``` + +#### internvl.py +CCL for InternVL2.5 models with custom processor. + +**Usage:** +```bash +python internvl.py +``` + +#### llama4.py +CCL for Llama-4 Scout vision-language models. + +**Usage:** +```bash +python llama4.py +``` + +#### llama4_cb.py +CCL for Llama-4 with continuous batching. + +**Usage:** +```bash +python llama4_cb.py +``` + +#### llama4_multi_image.py +CCL for Llama-4 with multiple images. + +**Usage:** +```bash +python llama4_multi_image.py +``` + +#### mistral3.py +CCL for Mistral-Small-3.1 vision models. + +**Usage:** +```bash +python mistral3.py +``` + +#### molmo.py +CCL for Molmo-7B multimodal models. + +**Usage:** +```bash +python molmo.py +``` + +#### qwen2_5_vl.py +CCL for Qwen2.5-VL models (32B). + +**Usage:** +```bash +python qwen2_5_vl.py +``` + +#### qwen2_5_vl_cb.py +CCL for Qwen2.5-VL with continuous batching. + +**Usage:** +```bash +python qwen2_5_vl_cb.py +``` + +## Configuration Guidelines + +### Choosing CCL Values + +1. **Prefill Context Lengths** (`comp_ctx_lengths_prefill`): + - Start with smaller values (e.g., [256, 512, 1024]) + - Should be less than or equal to your prefill_seq_len + - Gradually increase based on prompt chunk position + +2. 
**Decode Context Lengths** (`comp_ctx_lengths_decode`): + - Start from a value based on expected prompt length + - Include intermediate steps (e.g., [512, 1024, 2048, ctx_len]) + - Final value should match ctx_len + +3. **Context Length** (`ctx_len`): + - Maximum context length for the model + - Required parameter for CCL validation + - Should match your model's maximum supported length + +### Example Configurations + +**Small Context (1K-2K):** +```python +ctx_len = 2048 +comp_ctx_lengths_prefill = [256, 512] +comp_ctx_lengths_decode = [1024, ctx_len] +``` + +**Medium Context (4K-8K):** +```python +ctx_len = 8192 +comp_ctx_lengths_prefill = [3072, 4096] +comp_ctx_lengths_decode = [4096, 6144, ctx_len] +``` + +**Large Context (16K+):** +```python +ctx_len = 16384 +comp_ctx_lengths_prefill = [4096, 8192] +comp_ctx_lengths_decode = [8192, 12288, ctx_len] +``` + +## Performance Tips + +1. **Memory Optimization**: Use smaller CCL values for prefill to reduce memory footprint +2. **Progressive Scaling**: Include intermediate CCL values in decode list for smooth transitions +3. **Vision Models**: Larger prefill contexts needed for image embeddings +4. **Continuous Batching**: CCL works seamlessly with CB for dynamic workloads +5. **MoE Models**: Consider prefill_seq_len=1 for optimal performance + +## Common Patterns + +### Text-Only Model +```python +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + qaic_config={ + "comp_ctx_lengths_prefill": [256, 500], + "comp_ctx_lengths_decode": [512, 1024], + "ctx_len": 1024, + }, +) +``` + +### Vision-Language Model +```python +model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + kv_offload=True, + qaic_config={ + "comp_ctx_lengths_prefill": [3072], + "comp_ctx_lengths_decode": [4096, 8192], + "ctx_len": 8192, + }, +) +``` + +### Continuous Batching +```python +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + continuous_batching=True, + qaic_config={ + "comp_ctx_lengths_prefill": [256, 500], + "comp_ctx_lengths_decode": [512, 1024], + "ctx_len": 1024, + }, +) +``` + +## Documentation + +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) +- [Performance Features](https://quic.github.io/efficient-transformers/source/features_enablement.html) +- [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) diff --git a/examples/performance/compute_context_length/basic_inference.py b/examples/performance/compute_context_length/basic_inference.py new file mode 100644 index 000000000..6e8c045fb --- /dev/null +++ b/examples/performance/compute_context_length/basic_inference.py @@ -0,0 +1,156 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Basic Compute Context Length (CCL) inference example. + +This example demonstrates how to use CCL optimization for text generation models. +CCL allows using different context lengths during prefill and decode phases, +reducing memory footprint and computation for shorter sequences. 
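+
+Example invocation (a minimal sketch; every flag used here is defined in main() below):
+
+    python basic_inference.py --model-name meta-llama/Llama-3.2-1B \
+        --ccl-enabled --comp-ctx-lengths-prefill 256,500 --comp-ctx-lengths-decode 512,1024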
+""" + +import argparse + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="Text generation with Compute Context Length (CCL) optimization") + parser.add_argument( + "--model-name", + type=str, + default="meta-llama/Llama-3.2-1B", + help="HuggingFace model ID", + ) + parser.add_argument( + "--prompt", + type=str, + default="My name is ", + help="Input prompt for text generation", + ) + parser.add_argument( + "--prefill-seq-len", + type=int, + default=128, + help="Prefill sequence length", + ) + parser.add_argument( + "--ctx-len", + type=int, + default=1024, + help="Maximum context length", + ) + parser.add_argument( + "--ccl-enabled", + action="store_true", + help="Enable compute-context-length (CCL) feature", + ) + parser.add_argument( + "--comp-ctx-lengths-prefill", + type=lambda x: [int(i) for i in x.split(",")], + default=None, + help="Comma-separated list of context lengths for prefill phase (e.g., '256,500')", + ) + parser.add_argument( + "--comp-ctx-lengths-decode", + type=lambda x: [int(i) for i in x.split(",")], + default=None, + help="Comma-separated list of context lengths for decode phase (e.g., '512,1024')", + ) + parser.add_argument( + "--generation-len", + type=int, + default=128, + help="Number of tokens to generate", + ) + parser.add_argument( + "--num-cores", + type=int, + default=16, + help="Number of cores for compilation", + ) + parser.add_argument( + "--num-devices", + type=int, + default=1, + help="Number of devices to use", + ) + parser.add_argument( + "--continuous-batching", + action="store_true", + help="Enable continuous batching mode", + ) + parser.add_argument( + "--full-batch-size", + type=int, + default=1, + help="Full batch size for continuous batching", + ) + parser.add_argument( + "--mxint8-kv-cache", + action="store_true", + default=True, + help="Enable MX INT8 KV cache", + ) + parser.add_argument( + "--mxfp6-matmul", + action="store_true", + default=True, + help="Enable MX FP6 matrix multiplication", + ) + args = parser.parse_args() + + print(f"Loading model: {args.model_name}") + print(f"Continuous batching: {args.continuous_batching}") + + # Load model with CCL configuration + model = QEFFAutoModelForCausalLM.from_pretrained( + args.model_name, + continuous_batching=args.continuous_batching, + qaic_config={ + "ccl_enabled": args.ccl_enabled, + }, + ) + + # Compile the model + print("\nCompiling model...") + compile_kwargs = { + "prefill_seq_len": args.prefill_seq_len, + "ctx_len": args.ctx_len, + "num_cores": args.num_cores, + "num_devices": args.num_devices, + "mxint8_kv_cache": args.mxint8_kv_cache, + "mxfp6_matmul": args.mxfp6_matmul, + } + + if args.continuous_batching: + compile_kwargs["full_batch_size"] = args.full_batch_size + if args.ccl_enabled: + compile_kwargs["comp_ctx_lengths_prefill"] = args.comp_ctx_lengths_prefill + compile_kwargs["comp_ctx_lengths_decode"] = args.comp_ctx_lengths_decode + + qpc_path = model.compile(**compile_kwargs) + print(f"Model compiled successfully to: {qpc_path}") + + # Load tokenizer and generate + print("\nGenerating text...") + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + exec_info = model.generate( + prompts=[args.prompt], + tokenizer=tokenizer, + generation_len=args.generation_len, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"Generated: {exec_info.generated_texts[0]}") + + +if __name__ == "__main__": + main() diff --git 
a/examples/performance/compute_context_length/fp32_nodes_gemma3_27b.yaml b/examples/performance/compute_context_length/fp32_nodes_gemma3_27b.yaml new file mode 100755 index 000000000..d2a4bf164 --- /dev/null +++ b/examples/performance/compute_context_length/fp32_nodes_gemma3_27b.yaml @@ -0,0 +1,685 @@ +FP32NodeInstanceNames: + + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.0/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.5/Add_2_output_0 + - /language_model/layers.5/Add_3_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.11/Add_2_output_0 + - /language_model/layers.11/Add_3_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.17/Add_2_output_0 + - /language_model/layers.17/Add_3_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - 
/language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.23/Add_2_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.29/Add_2_output_0 + - /language_model/layers.29/Add_3_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.34/Add_1_output_0 + - /language_model/layers.34/Add_2_output_0 + - /language_model/layers.34/Add_3_output_0 + - /language_model/layers.34/Add_output_0 + - /language_model/layers.35/Add_1_output_0 + - /language_model/layers.35/Add_2_output_0 + - /language_model/layers.35/Add_3_output_0 + - /language_model/layers.35/Add_output_0 + - /language_model/layers.36/Add_1_output_0 + - /language_model/layers.36/Add_2_output_0 + - /language_model/layers.36/Add_3_output_0 + - /language_model/layers.36/Add_output_0 + - /language_model/layers.37/Add_1_output_0 + - /language_model/layers.37/Add_2_output_0 + - /language_model/layers.37/Add_3_output_0 + - /language_model/layers.37/Add_output_0 + - /language_model/layers.38/Add_1_output_0 + - 
/language_model/layers.38/Add_2_output_0 + - /language_model/layers.38/Add_3_output_0 + - /language_model/layers.38/Add_output_0 + - /language_model/layers.39/Add_1_output_0 + - /language_model/layers.39/Add_2_output_0 + - /language_model/layers.39/Add_3_output_0 + - /language_model/layers.39/Add_output_0 + - /language_model/layers.40/Add_1_output_0 + - /language_model/layers.40/Add_2_output_0 + - /language_model/layers.40/Add_3_output_0 + - /language_model/layers.40/Add_output_0 + - /language_model/layers.41/Add_1_output_0 + - /language_model/layers.41/Add_2_output_0 + - /language_model/layers.41/Add_3_output_0 + - /language_model/layers.41/Add_output_0 + - /language_model/layers.42/Add_1_output_0 + - /language_model/layers.42/Add_2_output_0 + - /language_model/layers.42/Add_3_output_0 + - /language_model/layers.42/Add_output_0 + - /language_model/layers.43/Add_1_output_0 + - /language_model/layers.43/Add_2_output_0 + - /language_model/layers.43/Add_3_output_0 + - /language_model/layers.43/Add_output_0 + - /language_model/layers.44/Add_1_output_0 + - /language_model/layers.44/Add_2_output_0 + - /language_model/layers.44/Add_3_output_0 + - /language_model/layers.44/Add_output_0 + - /language_model/layers.45/Add_1_output_0 + - /language_model/layers.45/Add_2_output_0 + - /language_model/layers.45/Add_3_output_0 + - /language_model/layers.45/Add_output_0 + - /language_model/layers.46/Add_1_output_0 + - /language_model/layers.46/Add_2_output_0 + - /language_model/layers.46/Add_3_output_0 + - /language_model/layers.46/Add_output_0 + - /language_model/layers.47/Add_1_output_0 + - /language_model/layers.47/Add_2_output_0 + - /language_model/layers.47/Add_3_output_0 + - /language_model/layers.47/Add_output_0 + - /language_model/layers.48/Add_1_output_0 + - /language_model/layers.48/Add_2_output_0 + - /language_model/layers.48/Add_3_output_0 + - /language_model/layers.48/Add_output_0 + - /language_model/layers.49/Add_1_output_0 + - /language_model/layers.49/Add_2_output_0 + - /language_model/layers.49/Add_3_output_0 + - /language_model/layers.49/Add_output_0 + - /language_model/layers.50/Add_1_output_0 + - /language_model/layers.50/Add_2_output_0 + - /language_model/layers.50/Add_3_output_0 + - /language_model/layers.50/Add_output_0 + - /language_model/layers.51/Add_1_output_0 + - /language_model/layers.51/Add_2_output_0 + - /language_model/layers.51/Add_3_output_0 + - /language_model/layers.51/Add_output_0 + - /language_model/layers.52/Add_1_output_0 + - /language_model/layers.52/Add_2_output_0 + - /language_model/layers.52/Add_3_output_0 + - /language_model/layers.52/Add_output_0 + - /language_model/layers.53/Add_1_output_0 + - /language_model/layers.53/Add_2_output_0 + - /language_model/layers.53/Add_3_output_0 + - /language_model/layers.53/Add_output_0 + - /language_model/layers.54/Add_1_output_0 + - /language_model/layers.54/Add_2_output_0 + - /language_model/layers.54/Add_3_output_0 + - /language_model/layers.54/Add_output_0 + - /language_model/layers.55/Add_1_output_0 + - /language_model/layers.55/Add_2_output_0 + - /language_model/layers.55/Add_3_output_0 + - /language_model/layers.55/Add_output_0 + - /language_model/layers.56/Add_1_output_0 + - /language_model/layers.56/Add_2_output_0 + - /language_model/layers.56/Add_3_output_0 + - /language_model/layers.56/Add_output_0 + - /language_model/layers.57/Add_1_output_0 + - /language_model/layers.57/Add_2_output_0 + - /language_model/layers.57/Add_3_output_0 + - /language_model/layers.57/Add_output_0 + - 
/language_model/layers.58/Add_1_output_0 + - /language_model/layers.58/Add_2_output_0 + - /language_model/layers.58/Add_3_output_0 + - /language_model/layers.58/Add_output_0 + - /language_model/layers.59/Add_1_output_0 + - /language_model/layers.59/Add_2_output_0 + - /language_model/layers.59/Add_3_output_0 + - /language_model/layers.59/Add_output_0 + - /language_model/layers.60/Add_1_output_0 + - /language_model/layers.60/Add_2_output_0 + - /language_model/layers.60/Add_3_output_0 + - /language_model/layers.60/Add_output_0 + - /language_model/layers.61/Add_1_output_0 + - /language_model/layers.61/Add_2_output_0 + - /language_model/layers.61/Add_3_output_0 + - /language_model/layers.61/Add_output_0 + - /language_model/norm/Add_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.34/self_attn/Mul_output_0 + - /language_model/layers.35/self_attn/Mul_output_0 + - /language_model/layers.36/self_attn/Mul_output_0 + - /language_model/layers.37/self_attn/Mul_output_0 + - /language_model/layers.38/self_attn/Mul_output_0 + - /language_model/layers.39/self_attn/Mul_output_0 + - /language_model/layers.40/self_attn/Mul_output_0 + - /language_model/layers.41/self_attn/Mul_output_0 + - /language_model/layers.42/self_attn/Mul_output_0 + - /language_model/layers.43/self_attn/Mul_output_0 + - /language_model/layers.44/self_attn/Mul_output_0 + - /language_model/layers.45/self_attn/Mul_output_0 + - /language_model/layers.46/self_attn/Mul_output_0 + - /language_model/layers.47/self_attn/Mul_output_0 + - /language_model/layers.48/self_attn/Mul_output_0 + - /language_model/layers.49/self_attn/Mul_output_0 + - /language_model/layers.50/self_attn/Mul_output_0 + - /language_model/layers.51/self_attn/Mul_output_0 + - /language_model/layers.52/self_attn/Mul_output_0 + - /language_model/layers.53/self_attn/Mul_output_0 + - 
/language_model/layers.54/self_attn/Mul_output_0 + - /language_model/layers.55/self_attn/Mul_output_0 + - /language_model/layers.56/self_attn/Mul_output_0 + - /language_model/layers.57/self_attn/Mul_output_0 + - /language_model/layers.58/self_attn/Mul_output_0 + - /language_model/layers.59/self_attn/Mul_output_0 + - /language_model/layers.60/self_attn/Mul_output_0 + - /language_model/layers.61/self_attn/Mul_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/norm/CustomRMSNorm_output_0 + diff 
--git a/examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml b/examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml new file mode 100755 index 000000000..1c8aa1c41 --- /dev/null +++ b/examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml @@ -0,0 +1,698 @@ +FP32NodeInstanceNames: + + - /language_model/layers.0/Add_output_0 + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.16/Add_3_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/Add_2_output_0 + - 
/language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.0/self_attn/Mul_1_output_0 + - 
/language_model/layers.0/self_attn/Mul_2_output_0 + - /language_model/layers.0/self_attn/Mul_3_output_0 + - /language_model/layers.0/self_attn/Mul_4_output_0 + - /language_model/layers.0/self_attn/Mul_5_output_0 + - /language_model/layers.0/self_attn/Mul_6_output_0 + - /language_model/layers.0/self_attn/Mul_7_output_0 + - /language_model/layers.0/self_attn/Mul_8_output_0 + - /language_model/layers.1/self_attn/Mul_9_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_1_output_0 + - /language_model/layers.2/self_attn/Mul_2_output_0 + - /language_model/layers.2/self_attn/Mul_3_output_0 + - /language_model/layers.2/self_attn/Mul_4_output_0 + - /language_model/layers.2/self_attn/Mul_5_output_0 + - /language_model/layers.2/self_attn/Mul_6_output_0 + - /language_model/layers.2/self_attn/Mul_7_output_0 + - /language_model/layers.2/self_attn/Mul_8_output_0 + - /language_model/layers.2/self_attn/Mul_9_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_1_output_0 + - /language_model/layers.3/self_attn/Mul_2_output_0 + - /language_model/layers.3/self_attn/Mul_3_output_0 + - /language_model/layers.3/self_attn/Mul_4_output_0 + - /language_model/layers.3/self_attn/Mul_5_output_0 + - /language_model/layers.3/self_attn/Mul_6_output_0 + - /language_model/layers.3/self_attn/Mul_7_output_0 + - /language_model/layers.3/self_attn/Mul_8_output_0 + - /language_model/layers.3/self_attn/Mul_9_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_1_output_0 + - /language_model/layers.4/self_attn/Mul_2_output_0 + - /language_model/layers.4/self_attn/Mul_3_output_0 + - /language_model/layers.4/self_attn/Mul_4_output_0 + - /language_model/layers.4/self_attn/Mul_5_output_0 + - /language_model/layers.4/self_attn/Mul_6_output_0 + - /language_model/layers.4/self_attn/Mul_7_output_0 + - /language_model/layers.4/self_attn/Mul_8_output_0 + - /language_model/layers.4/self_attn/Mul_9_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_1_output_0 + - /language_model/layers.5/self_attn/Mul_2_output_0 + - /language_model/layers.5/self_attn/Mul_3_output_0 + - /language_model/layers.5/self_attn/Mul_4_output_0 + - /language_model/layers.5/self_attn/Mul_5_output_0 + - /language_model/layers.5/self_attn/Mul_6_output_0 + - /language_model/layers.5/self_attn/Mul_7_output_0 + - /language_model/layers.5/self_attn/Mul_8_output_0 + - /language_model/layers.5/self_attn/Mul_9_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_1_output_0 + - /language_model/layers.6/self_attn/Mul_2_output_0 + - /language_model/layers.6/self_attn/Mul_3_output_0 + - /language_model/layers.6/self_attn/Mul_4_output_0 + - /language_model/layers.6/self_attn/Mul_5_output_0 + - /language_model/layers.6/self_attn/Mul_6_output_0 + - /language_model/layers.6/self_attn/Mul_7_output_0 + - /language_model/layers.6/self_attn/Mul_8_output_0 + - /language_model/layers.6/self_attn/Mul_9_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_1_output_0 + - /language_model/layers.7/self_attn/Mul_2_output_0 + - /language_model/layers.7/self_attn/Mul_3_output_0 + - /language_model/layers.7/self_attn/Mul_4_output_0 + - /language_model/layers.7/self_attn/Mul_5_output_0 + - /language_model/layers.7/self_attn/Mul_6_output_0 + - /language_model/layers.7/self_attn/Mul_7_output_0 + - 
/language_model/layers.7/self_attn/Mul_8_output_0 + - /language_model/layers.7/self_attn/Mul_9_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_1_output_0 + - /language_model/layers.8/self_attn/Mul_2_output_0 + - /language_model/layers.8/self_attn/Mul_3_output_0 + - /language_model/layers.8/self_attn/Mul_4_output_0 + - /language_model/layers.8/self_attn/Mul_5_output_0 + - /language_model/layers.8/self_attn/Mul_6_output_0 + - /language_model/layers.8/self_attn/Mul_7_output_0 + - /language_model/layers.8/self_attn/Mul_8_output_0 + - /language_model/layers.8/self_attn/Mul_9_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_1_output_0 + - /language_model/layers.9/self_attn/Mul_2_output_0 + - /language_model/layers.9/self_attn/Mul_3_output_0 + - /language_model/layers.9/self_attn/Mul_4_output_0 + - /language_model/layers.9/self_attn/Mul_5_output_0 + - /language_model/layers.9/self_attn/Mul_6_output_0 + - /language_model/layers.9/self_attn/Mul_7_output_0 + - /language_model/layers.9/self_attn/Mul_8_output_0 + - /language_model/layers.9/self_attn/Mul_9_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_1_output_0 + - /language_model/layers.10/self_attn/Mul_2_output_0 + - /language_model/layers.10/self_attn/Mul_3_output_0 + - /language_model/layers.10/self_attn/Mul_4_output_0 + - /language_model/layers.10/self_attn/Mul_5_output_0 + - /language_model/layers.10/self_attn/Mul_6_output_0 + - /language_model/layers.10/self_attn/Mul_7_output_0 + - /language_model/layers.10/self_attn/Mul_8_output_0 + - /language_model/layers.10/self_attn/Mul_9_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_1_output_0 + - /language_model/layers.11/self_attn/Mul_2_output_0 + - /language_model/layers.11/self_attn/Mul_3_output_0 + - /language_model/layers.11/self_attn/Mul_4_output_0 + - /language_model/layers.11/self_attn/Mul_5_output_0 + - /language_model/layers.11/self_attn/Mul_6_output_0 + - /language_model/layers.11/self_attn/Mul_7_output_0 + - /language_model/layers.11/self_attn/Mul_8_output_0 + - /language_model/layers.11/self_attn/Mul_9_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_1_output_0 + - /language_model/layers.12/self_attn/Mul_2_output_0 + - /language_model/layers.12/self_attn/Mul_3_output_0 + - /language_model/layers.12/self_attn/Mul_4_output_0 + - /language_model/layers.12/self_attn/Mul_5_output_0 + - /language_model/layers.12/self_attn/Mul_6_output_0 + - /language_model/layers.12/self_attn/Mul_7_output_0 + - /language_model/layers.12/self_attn/Mul_8_output_0 + - /language_model/layers.12/self_attn/Mul_9_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_1_output_0 + - /language_model/layers.13/self_attn/Mul_2_output_0 + - /language_model/layers.13/self_attn/Mul_3_output_0 + - /language_model/layers.13/self_attn/Mul_4_output_0 + - /language_model/layers.13/self_attn/Mul_5_output_0 + - /language_model/layers.13/self_attn/Mul_6_output_0 + - /language_model/layers.13/self_attn/Mul_7_output_0 + - /language_model/layers.13/self_attn/Mul_8_output_0 + - /language_model/layers.13/self_attn/Mul_9_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_1_output_0 + - /language_model/layers.14/self_attn/Mul_2_output_0 + - 
/language_model/layers.14/self_attn/Mul_3_output_0 + - /language_model/layers.14/self_attn/Mul_4_output_0 + - /language_model/layers.14/self_attn/Mul_5_output_0 + - /language_model/layers.14/self_attn/Mul_6_output_0 + - /language_model/layers.14/self_attn/Mul_7_output_0 + - /language_model/layers.14/self_attn/Mul_8_output_0 + - /language_model/layers.14/self_attn/Mul_9_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_1_output_0 + - /language_model/layers.15/self_attn/Mul_2_output_0 + - /language_model/layers.15/self_attn/Mul_3_output_0 + - /language_model/layers.15/self_attn/Mul_4_output_0 + - /language_model/layers.15/self_attn/Mul_5_output_0 + - /language_model/layers.15/self_attn/Mul_6_output_0 + - /language_model/layers.15/self_attn/Mul_7_output_0 + - /language_model/layers.15/self_attn/Mul_8_output_0 + - /language_model/layers.15/self_attn/Mul_9_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_1_output_0 + - /language_model/layers.16/self_attn/Mul_2_output_0 + - /language_model/layers.16/self_attn/Mul_3_output_0 + - /language_model/layers.16/self_attn/Mul_4_output_0 + - /language_model/layers.16/self_attn/Mul_5_output_0 + - /language_model/layers.16/self_attn/Mul_6_output_0 + - /language_model/layers.16/self_attn/Mul_7_output_0 + - /language_model/layers.16/self_attn/Mul_8_output_0 + - /language_model/layers.16/self_attn/Mul_9_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_1_output_0 + - /language_model/layers.17/self_attn/Mul_2_output_0 + - /language_model/layers.17/self_attn/Mul_3_output_0 + - /language_model/layers.17/self_attn/Mul_4_output_0 + - /language_model/layers.17/self_attn/Mul_5_output_0 + - /language_model/layers.17/self_attn/Mul_6_output_0 + - /language_model/layers.17/self_attn/Mul_7_output_0 + - /language_model/layers.17/self_attn/Mul_8_output_0 + - /language_model/layers.17/self_attn/Mul_9_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_1_output_0 + - /language_model/layers.18/self_attn/Mul_2_output_0 + - /language_model/layers.18/self_attn/Mul_3_output_0 + - /language_model/layers.18/self_attn/Mul_4_output_0 + - /language_model/layers.18/self_attn/Mul_5_output_0 + - /language_model/layers.18/self_attn/Mul_6_output_0 + - /language_model/layers.18/self_attn/Mul_7_output_0 + - /language_model/layers.18/self_attn/Mul_8_output_0 + - /language_model/layers.18/self_attn/Mul_9_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_1_output_0 + - /language_model/layers.19/self_attn/Mul_2_output_0 + - /language_model/layers.19/self_attn/Mul_3_output_0 + - /language_model/layers.19/self_attn/Mul_4_output_0 + - /language_model/layers.19/self_attn/Mul_5_output_0 + - /language_model/layers.19/self_attn/Mul_6_output_0 + - /language_model/layers.19/self_attn/Mul_7_output_0 + - /language_model/layers.19/self_attn/Mul_8_output_0 + - /language_model/layers.19/self_attn/Mul_9_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_1_output_0 + - /language_model/layers.20/self_attn/Mul_2_output_0 + - /language_model/layers.20/self_attn/Mul_3_output_0 + - /language_model/layers.20/self_attn/Mul_4_output_0 + - /language_model/layers.20/self_attn/Mul_5_output_0 + - /language_model/layers.20/self_attn/Mul_6_output_0 + - 
/language_model/layers.20/self_attn/Mul_7_output_0 + - /language_model/layers.20/self_attn/Mul_8_output_0 + - /language_model/layers.20/self_attn/Mul_9_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_1_output_0 + - /language_model/layers.21/self_attn/Mul_2_output_0 + - /language_model/layers.21/self_attn/Mul_3_output_0 + - /language_model/layers.21/self_attn/Mul_4_output_0 + - /language_model/layers.21/self_attn/Mul_5_output_0 + - /language_model/layers.21/self_attn/Mul_6_output_0 + - /language_model/layers.21/self_attn/Mul_7_output_0 + - /language_model/layers.21/self_attn/Mul_8_output_0 + - /language_model/layers.21/self_attn/Mul_9_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_1_output_0 + - /language_model/layers.22/self_attn/Mul_2_output_0 + - /language_model/layers.22/self_attn/Mul_3_output_0 + - /language_model/layers.22/self_attn/Mul_4_output_0 + - /language_model/layers.22/self_attn/Mul_5_output_0 + - /language_model/layers.22/self_attn/Mul_6_output_0 + - /language_model/layers.22/self_attn/Mul_7_output_0 + - /language_model/layers.22/self_attn/Mul_8_output_0 + - /language_model/layers.22/self_attn/Mul_9_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_1_output_0 + - /language_model/layers.23/self_attn/Mul_2_output_0 + - /language_model/layers.23/self_attn/Mul_3_output_0 + - /language_model/layers.23/self_attn/Mul_4_output_0 + - /language_model/layers.23/self_attn/Mul_5_output_0 + - /language_model/layers.23/self_attn/Mul_6_output_0 + - /language_model/layers.23/self_attn/Mul_7_output_0 + - /language_model/layers.23/self_attn/Mul_8_output_0 + - /language_model/layers.23/self_attn/Mul_9_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_1_output_0 + - /language_model/layers.24/self_attn/Mul_2_output_0 + - /language_model/layers.24/self_attn/Mul_3_output_0 + - /language_model/layers.24/self_attn/Mul_4_output_0 + - /language_model/layers.24/self_attn/Mul_5_output_0 + - /language_model/layers.24/self_attn/Mul_6_output_0 + - /language_model/layers.24/self_attn/Mul_7_output_0 + - /language_model/layers.24/self_attn/Mul_8_output_0 + - /language_model/layers.24/self_attn/Mul_9_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_1_output_0 + - /language_model/layers.25/self_attn/Mul_2_output_0 + - /language_model/layers.25/self_attn/Mul_3_output_0 + - /language_model/layers.25/self_attn/Mul_4_output_0 + - /language_model/layers.25/self_attn/Mul_5_output_0 + - /language_model/layers.25/self_attn/Mul_6_output_0 + - /language_model/layers.25/self_attn/Mul_7_output_0 + - /language_model/layers.25/self_attn/Mul_8_output_0 + - /language_model/layers.25/self_attn/Mul_9_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_1_output_0 + - /language_model/layers.26/self_attn/Mul_2_output_0 + - /language_model/layers.26/self_attn/Mul_3_output_0 + - /language_model/layers.26/self_attn/Mul_4_output_0 + - /language_model/layers.26/self_attn/Mul_5_output_0 + - /language_model/layers.26/self_attn/Mul_6_output_0 + - /language_model/layers.26/self_attn/Mul_7_output_0 + - /language_model/layers.26/self_attn/Mul_8_output_0 + - /language_model/layers.26/self_attn/Mul_9_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - 
/language_model/layers.27/self_attn/Mul_1_output_0 + - /language_model/layers.27/self_attn/Mul_2_output_0 + - /language_model/layers.27/self_attn/Mul_3_output_0 + - /language_model/layers.27/self_attn/Mul_4_output_0 + - /language_model/layers.27/self_attn/Mul_5_output_0 + - /language_model/layers.27/self_attn/Mul_6_output_0 + - /language_model/layers.27/self_attn/Mul_7_output_0 + - /language_model/layers.27/self_attn/Mul_8_output_0 + - /language_model/layers.27/self_attn/Mul_9_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_1_output_0 + - /language_model/layers.28/self_attn/Mul_2_output_0 + - /language_model/layers.28/self_attn/Mul_3_output_0 + - /language_model/layers.28/self_attn/Mul_4_output_0 + - /language_model/layers.28/self_attn/Mul_5_output_0 + - /language_model/layers.28/self_attn/Mul_6_output_0 + - /language_model/layers.28/self_attn/Mul_7_output_0 + - /language_model/layers.28/self_attn/Mul_8_output_0 + - /language_model/layers.28/self_attn/Mul_9_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_1_output_0 + - /language_model/layers.29/self_attn/Mul_2_output_0 + - /language_model/layers.29/self_attn/Mul_3_output_0 + - /language_model/layers.29/self_attn/Mul_4_output_0 + - /language_model/layers.29/self_attn/Mul_5_output_0 + - /language_model/layers.29/self_attn/Mul_6_output_0 + - /language_model/layers.29/self_attn/Mul_7_output_0 + - /language_model/layers.29/self_attn/Mul_8_output_0 + - /language_model/layers.29/self_attn/Mul_9_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_1_output_0 + - /language_model/layers.30/self_attn/Mul_2_output_0 + - /language_model/layers.30/self_attn/Mul_3_output_0 + - /language_model/layers.30/self_attn/Mul_4_output_0 + - /language_model/layers.30/self_attn/Mul_5_output_0 + - /language_model/layers.30/self_attn/Mul_6_output_0 + - /language_model/layers.30/self_attn/Mul_7_output_0 + - /language_model/layers.30/self_attn/Mul_8_output_0 + - /language_model/layers.30/self_attn/Mul_9_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_1_output_0 + - /language_model/layers.31/self_attn/Mul_2_output_0 + - /language_model/layers.31/self_attn/Mul_3_output_0 + - /language_model/layers.31/self_attn/Mul_4_output_0 + - /language_model/layers.31/self_attn/Mul_5_output_0 + - /language_model/layers.31/self_attn/Mul_6_output_0 + - /language_model/layers.31/self_attn/Mul_7_output_0 + - /language_model/layers.31/self_attn/Mul_8_output_0 + - /language_model/layers.31/self_attn/Mul_9_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_1_output_0 + - /language_model/layers.32/self_attn/Mul_2_output_0 + - /language_model/layers.32/self_attn/Mul_3_output_0 + - /language_model/layers.32/self_attn/Mul_4_output_0 + - /language_model/layers.32/self_attn/Mul_5_output_0 + - /language_model/layers.32/self_attn/Mul_6_output_0 + - /language_model/layers.32/self_attn/Mul_7_output_0 + - /language_model/layers.32/self_attn/Mul_8_output_0 + - /language_model/layers.32/self_attn/Mul_9_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_1_output_0 + - /language_model/layers.33/self_attn/Mul_2_output_0 + - /language_model/layers.33/self_attn/Mul_3_output_0 + - /language_model/layers.33/self_attn/Mul_4_output_0 + - 
/language_model/layers.33/self_attn/Mul_5_output_0 + - /language_model/layers.33/self_attn/Mul_6_output_0 + - /language_model/layers.33/self_attn/Mul_7_output_0 + - /language_model/layers.33/self_attn/Mul_8_output_0 + - /language_model/layers.33/self_attn/Mul_9_output_0 + - /language_model/layers.0/self_attn/Softmax_output_0 + - /language_model/layers.1/self_attn/Softmax_output_0 + - /language_model/layers.2/self_attn/Softmax_output_0 + - /language_model/layers.3/self_attn/Softmax_output_0 + - /language_model/layers.4/self_attn/Softmax_output_0 + - /language_model/layers.5/self_attn/Softmax_output_0 + - /language_model/layers.6/self_attn/Softmax_output_0 + - /language_model/layers.7/self_attn/Softmax_output_0 + - /language_model/layers.8/self_attn/Softmax_output_0 + - /language_model/layers.9/self_attn/Softmax_output_0 + - /language_model/layers.10/self_attn/Softmax_output_0 + - /language_model/layers.11/self_attn/Softmax_output_0 + - /language_model/layers.12/self_attn/Softmax_output_0 + - /language_model/layers.13/self_attn/Softmax_output_0 + - /language_model/layers.14/self_attn/Softmax_output_0 + - /language_model/layers.15/self_attn/Softmax_output_0 + - /language_model/layers.16/self_attn/Softmax_output_0 + - /language_model/layers.17/self_attn/Softmax_output_0 + - /language_model/layers.18/self_attn/Softmax_output_0 + - /language_model/layers.19/self_attn/Softmax_output_0 + - /language_model/layers.20/self_attn/Softmax_output_0 + - /language_model/layers.21/self_attn/Softmax_output_0 + - /language_model/layers.22/self_attn/Softmax_output_0 + - /language_model/layers.23/self_attn/Softmax_output_0 + - /language_model/layers.24/self_attn/Softmax_output_0 + - /language_model/layers.25/self_attn/Softmax_output_0 + - /language_model/layers.26/self_attn/Softmax_output_0 + - /language_model/layers.27/self_attn/Softmax_output_0 + - /language_model/layers.28/self_attn/Softmax_output_0 + - /language_model/layers.29/self_attn/Softmax_output_0 + - /language_model/layers.30/self_attn/Softmax_output_0 + - /language_model/layers.31/self_attn/Softmax_output_0 + - /language_model/layers.32/self_attn/Softmax_output_0 + - /language_model/layers.33/self_attn/Softmax_output_0 + diff --git a/examples/performance/compute_context_length/gemma3.py b/examples/performance/compute_context_length/gemma3.py new file mode 100644 index 000000000..1dcec5c81 --- /dev/null +++ b/examples/performance/compute_context_length/gemma3.py @@ -0,0 +1,133 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +# Change model_id to "google/gemma-3-27b-it" for 27B model +model_id = "google/gemma-3-4b-it" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +# config.text_config.num_hidden_layers = 1 +# config.vision_config.num_hidden_layers = 2 +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id) + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). 
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the appropriate value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 8192 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. +comp_ctx_lengths_prefill = [3072] +comp_ctx_lengths_decode = [4096, ctx_len] + +# Pass HF_TOKEN if the model is gated +# For the single-QPC approach, use kv_offload=False; for the dual-QPC approach, use kv_offload=True +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + config=config, + attn_implementation="eager", + kv_offload=True, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) + +### Use skip_vision=True to run text only; set it to False to run vision + text ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=896, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + node_precision_info="examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml", + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe the transformers architecture in LLMs."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=896, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + mos=1, + node_precision_info="examples/performance/compute_context_length/fp32_nodes_gemma3_4b.yaml", + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail?"}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + output = qeff_model.generate(inputs=inputs, 
generation_len=100) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_27b.yaml b/examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_27b.yaml new file mode 100755 index 000000000..d2a4bf164 --- /dev/null +++ b/examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_27b.yaml @@ -0,0 +1,685 @@ +FP32NodeInstanceNames: + + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.0/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.5/Add_2_output_0 + - /language_model/layers.5/Add_3_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.11/Add_2_output_0 + - /language_model/layers.11/Add_3_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.17/Add_2_output_0 + - 
/language_model/layers.17/Add_3_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.23/Add_2_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.29/Add_2_output_0 + - /language_model/layers.29/Add_3_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.34/Add_1_output_0 + - /language_model/layers.34/Add_2_output_0 + - /language_model/layers.34/Add_3_output_0 + - /language_model/layers.34/Add_output_0 + - /language_model/layers.35/Add_1_output_0 + - /language_model/layers.35/Add_2_output_0 + - /language_model/layers.35/Add_3_output_0 + - /language_model/layers.35/Add_output_0 + - /language_model/layers.36/Add_1_output_0 + - /language_model/layers.36/Add_2_output_0 + - /language_model/layers.36/Add_3_output_0 + - /language_model/layers.36/Add_output_0 + - /language_model/layers.37/Add_1_output_0 + - /language_model/layers.37/Add_2_output_0 + - 
/language_model/layers.37/Add_3_output_0 + - /language_model/layers.37/Add_output_0 + - /language_model/layers.38/Add_1_output_0 + - /language_model/layers.38/Add_2_output_0 + - /language_model/layers.38/Add_3_output_0 + - /language_model/layers.38/Add_output_0 + - /language_model/layers.39/Add_1_output_0 + - /language_model/layers.39/Add_2_output_0 + - /language_model/layers.39/Add_3_output_0 + - /language_model/layers.39/Add_output_0 + - /language_model/layers.40/Add_1_output_0 + - /language_model/layers.40/Add_2_output_0 + - /language_model/layers.40/Add_3_output_0 + - /language_model/layers.40/Add_output_0 + - /language_model/layers.41/Add_1_output_0 + - /language_model/layers.41/Add_2_output_0 + - /language_model/layers.41/Add_3_output_0 + - /language_model/layers.41/Add_output_0 + - /language_model/layers.42/Add_1_output_0 + - /language_model/layers.42/Add_2_output_0 + - /language_model/layers.42/Add_3_output_0 + - /language_model/layers.42/Add_output_0 + - /language_model/layers.43/Add_1_output_0 + - /language_model/layers.43/Add_2_output_0 + - /language_model/layers.43/Add_3_output_0 + - /language_model/layers.43/Add_output_0 + - /language_model/layers.44/Add_1_output_0 + - /language_model/layers.44/Add_2_output_0 + - /language_model/layers.44/Add_3_output_0 + - /language_model/layers.44/Add_output_0 + - /language_model/layers.45/Add_1_output_0 + - /language_model/layers.45/Add_2_output_0 + - /language_model/layers.45/Add_3_output_0 + - /language_model/layers.45/Add_output_0 + - /language_model/layers.46/Add_1_output_0 + - /language_model/layers.46/Add_2_output_0 + - /language_model/layers.46/Add_3_output_0 + - /language_model/layers.46/Add_output_0 + - /language_model/layers.47/Add_1_output_0 + - /language_model/layers.47/Add_2_output_0 + - /language_model/layers.47/Add_3_output_0 + - /language_model/layers.47/Add_output_0 + - /language_model/layers.48/Add_1_output_0 + - /language_model/layers.48/Add_2_output_0 + - /language_model/layers.48/Add_3_output_0 + - /language_model/layers.48/Add_output_0 + - /language_model/layers.49/Add_1_output_0 + - /language_model/layers.49/Add_2_output_0 + - /language_model/layers.49/Add_3_output_0 + - /language_model/layers.49/Add_output_0 + - /language_model/layers.50/Add_1_output_0 + - /language_model/layers.50/Add_2_output_0 + - /language_model/layers.50/Add_3_output_0 + - /language_model/layers.50/Add_output_0 + - /language_model/layers.51/Add_1_output_0 + - /language_model/layers.51/Add_2_output_0 + - /language_model/layers.51/Add_3_output_0 + - /language_model/layers.51/Add_output_0 + - /language_model/layers.52/Add_1_output_0 + - /language_model/layers.52/Add_2_output_0 + - /language_model/layers.52/Add_3_output_0 + - /language_model/layers.52/Add_output_0 + - /language_model/layers.53/Add_1_output_0 + - /language_model/layers.53/Add_2_output_0 + - /language_model/layers.53/Add_3_output_0 + - /language_model/layers.53/Add_output_0 + - /language_model/layers.54/Add_1_output_0 + - /language_model/layers.54/Add_2_output_0 + - /language_model/layers.54/Add_3_output_0 + - /language_model/layers.54/Add_output_0 + - /language_model/layers.55/Add_1_output_0 + - /language_model/layers.55/Add_2_output_0 + - /language_model/layers.55/Add_3_output_0 + - /language_model/layers.55/Add_output_0 + - /language_model/layers.56/Add_1_output_0 + - /language_model/layers.56/Add_2_output_0 + - /language_model/layers.56/Add_3_output_0 + - /language_model/layers.56/Add_output_0 + - /language_model/layers.57/Add_1_output_0 + - 
/language_model/layers.57/Add_2_output_0 + - /language_model/layers.57/Add_3_output_0 + - /language_model/layers.57/Add_output_0 + - /language_model/layers.58/Add_1_output_0 + - /language_model/layers.58/Add_2_output_0 + - /language_model/layers.58/Add_3_output_0 + - /language_model/layers.58/Add_output_0 + - /language_model/layers.59/Add_1_output_0 + - /language_model/layers.59/Add_2_output_0 + - /language_model/layers.59/Add_3_output_0 + - /language_model/layers.59/Add_output_0 + - /language_model/layers.60/Add_1_output_0 + - /language_model/layers.60/Add_2_output_0 + - /language_model/layers.60/Add_3_output_0 + - /language_model/layers.60/Add_output_0 + - /language_model/layers.61/Add_1_output_0 + - /language_model/layers.61/Add_2_output_0 + - /language_model/layers.61/Add_3_output_0 + - /language_model/layers.61/Add_output_0 + - /language_model/norm/Add_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.34/self_attn/Mul_output_0 + - /language_model/layers.35/self_attn/Mul_output_0 + - /language_model/layers.36/self_attn/Mul_output_0 + - /language_model/layers.37/self_attn/Mul_output_0 + - /language_model/layers.38/self_attn/Mul_output_0 + - /language_model/layers.39/self_attn/Mul_output_0 + - /language_model/layers.40/self_attn/Mul_output_0 + - /language_model/layers.41/self_attn/Mul_output_0 + - /language_model/layers.42/self_attn/Mul_output_0 + - /language_model/layers.43/self_attn/Mul_output_0 + - /language_model/layers.44/self_attn/Mul_output_0 + - /language_model/layers.45/self_attn/Mul_output_0 + - /language_model/layers.46/self_attn/Mul_output_0 + - /language_model/layers.47/self_attn/Mul_output_0 + - /language_model/layers.48/self_attn/Mul_output_0 + - /language_model/layers.49/self_attn/Mul_output_0 + - /language_model/layers.50/self_attn/Mul_output_0 + - 
/language_model/layers.51/self_attn/Mul_output_0 + - /language_model/layers.52/self_attn/Mul_output_0 + - /language_model/layers.53/self_attn/Mul_output_0 + - /language_model/layers.54/self_attn/Mul_output_0 + - /language_model/layers.55/self_attn/Mul_output_0 + - /language_model/layers.56/self_attn/Mul_output_0 + - /language_model/layers.57/self_attn/Mul_output_0 + - /language_model/layers.58/self_attn/Mul_output_0 + - /language_model/layers.59/self_attn/Mul_output_0 + - /language_model/layers.60/self_attn/Mul_output_0 + - /language_model/layers.61/self_attn/Mul_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/norm/CustomRMSNorm_output_0 + diff --git a/examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_4b.yaml b/examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_4b.yaml new file mode 100755 index 000000000..1c8aa1c41 --- /dev/null +++ b/examples/performance/compute_context_length/gemma3/fp32_nodes_gemma3_4b.yaml @@ -0,0 +1,698 @@ +FP32NodeInstanceNames: + + - /language_model/layers.0/Add_output_0 + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/Add_2_output_0 + - 
/language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - 
/language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - 
/language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - 
/language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.0/self_attn/Mul_1_output_0 + - /language_model/layers.0/self_attn/Mul_2_output_0 + - /language_model/layers.0/self_attn/Mul_3_output_0 + - /language_model/layers.0/self_attn/Mul_4_output_0 + - /language_model/layers.0/self_attn/Mul_5_output_0 + - /language_model/layers.0/self_attn/Mul_6_output_0 + - /language_model/layers.0/self_attn/Mul_7_output_0 + - /language_model/layers.0/self_attn/Mul_8_output_0 + - /language_model/layers.1/self_attn/Mul_9_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_1_output_0 + - /language_model/layers.2/self_attn/Mul_2_output_0 + - /language_model/layers.2/self_attn/Mul_3_output_0 + - /language_model/layers.2/self_attn/Mul_4_output_0 + - /language_model/layers.2/self_attn/Mul_5_output_0 + - /language_model/layers.2/self_attn/Mul_6_output_0 + - /language_model/layers.2/self_attn/Mul_7_output_0 + - /language_model/layers.2/self_attn/Mul_8_output_0 + - /language_model/layers.2/self_attn/Mul_9_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_1_output_0 + - /language_model/layers.3/self_attn/Mul_2_output_0 + - /language_model/layers.3/self_attn/Mul_3_output_0 + - /language_model/layers.3/self_attn/Mul_4_output_0 + - /language_model/layers.3/self_attn/Mul_5_output_0 + - /language_model/layers.3/self_attn/Mul_6_output_0 + - /language_model/layers.3/self_attn/Mul_7_output_0 + - /language_model/layers.3/self_attn/Mul_8_output_0 + - /language_model/layers.3/self_attn/Mul_9_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_1_output_0 + - /language_model/layers.4/self_attn/Mul_2_output_0 + - /language_model/layers.4/self_attn/Mul_3_output_0 + - /language_model/layers.4/self_attn/Mul_4_output_0 + - /language_model/layers.4/self_attn/Mul_5_output_0 + - /language_model/layers.4/self_attn/Mul_6_output_0 + - /language_model/layers.4/self_attn/Mul_7_output_0 + - /language_model/layers.4/self_attn/Mul_8_output_0 + - /language_model/layers.4/self_attn/Mul_9_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_1_output_0 + - /language_model/layers.5/self_attn/Mul_2_output_0 + - /language_model/layers.5/self_attn/Mul_3_output_0 + - /language_model/layers.5/self_attn/Mul_4_output_0 + - /language_model/layers.5/self_attn/Mul_5_output_0 + - /language_model/layers.5/self_attn/Mul_6_output_0 + - /language_model/layers.5/self_attn/Mul_7_output_0 + - /language_model/layers.5/self_attn/Mul_8_output_0 + - /language_model/layers.5/self_attn/Mul_9_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_1_output_0 + - /language_model/layers.6/self_attn/Mul_2_output_0 + - /language_model/layers.6/self_attn/Mul_3_output_0 + - /language_model/layers.6/self_attn/Mul_4_output_0 + - /language_model/layers.6/self_attn/Mul_5_output_0 + - /language_model/layers.6/self_attn/Mul_6_output_0 + - /language_model/layers.6/self_attn/Mul_7_output_0 + - /language_model/layers.6/self_attn/Mul_8_output_0 + - /language_model/layers.6/self_attn/Mul_9_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_1_output_0 + - 
/language_model/layers.7/self_attn/Mul_2_output_0 + - /language_model/layers.7/self_attn/Mul_3_output_0 + - /language_model/layers.7/self_attn/Mul_4_output_0 + - /language_model/layers.7/self_attn/Mul_5_output_0 + - /language_model/layers.7/self_attn/Mul_6_output_0 + - /language_model/layers.7/self_attn/Mul_7_output_0 + - /language_model/layers.7/self_attn/Mul_8_output_0 + - /language_model/layers.7/self_attn/Mul_9_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_1_output_0 + - /language_model/layers.8/self_attn/Mul_2_output_0 + - /language_model/layers.8/self_attn/Mul_3_output_0 + - /language_model/layers.8/self_attn/Mul_4_output_0 + - /language_model/layers.8/self_attn/Mul_5_output_0 + - /language_model/layers.8/self_attn/Mul_6_output_0 + - /language_model/layers.8/self_attn/Mul_7_output_0 + - /language_model/layers.8/self_attn/Mul_8_output_0 + - /language_model/layers.8/self_attn/Mul_9_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_1_output_0 + - /language_model/layers.9/self_attn/Mul_2_output_0 + - /language_model/layers.9/self_attn/Mul_3_output_0 + - /language_model/layers.9/self_attn/Mul_4_output_0 + - /language_model/layers.9/self_attn/Mul_5_output_0 + - /language_model/layers.9/self_attn/Mul_6_output_0 + - /language_model/layers.9/self_attn/Mul_7_output_0 + - /language_model/layers.9/self_attn/Mul_8_output_0 + - /language_model/layers.9/self_attn/Mul_9_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_1_output_0 + - /language_model/layers.10/self_attn/Mul_2_output_0 + - /language_model/layers.10/self_attn/Mul_3_output_0 + - /language_model/layers.10/self_attn/Mul_4_output_0 + - /language_model/layers.10/self_attn/Mul_5_output_0 + - /language_model/layers.10/self_attn/Mul_6_output_0 + - /language_model/layers.10/self_attn/Mul_7_output_0 + - /language_model/layers.10/self_attn/Mul_8_output_0 + - /language_model/layers.10/self_attn/Mul_9_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_1_output_0 + - /language_model/layers.11/self_attn/Mul_2_output_0 + - /language_model/layers.11/self_attn/Mul_3_output_0 + - /language_model/layers.11/self_attn/Mul_4_output_0 + - /language_model/layers.11/self_attn/Mul_5_output_0 + - /language_model/layers.11/self_attn/Mul_6_output_0 + - /language_model/layers.11/self_attn/Mul_7_output_0 + - /language_model/layers.11/self_attn/Mul_8_output_0 + - /language_model/layers.11/self_attn/Mul_9_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_1_output_0 + - /language_model/layers.12/self_attn/Mul_2_output_0 + - /language_model/layers.12/self_attn/Mul_3_output_0 + - /language_model/layers.12/self_attn/Mul_4_output_0 + - /language_model/layers.12/self_attn/Mul_5_output_0 + - /language_model/layers.12/self_attn/Mul_6_output_0 + - /language_model/layers.12/self_attn/Mul_7_output_0 + - /language_model/layers.12/self_attn/Mul_8_output_0 + - /language_model/layers.12/self_attn/Mul_9_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_1_output_0 + - /language_model/layers.13/self_attn/Mul_2_output_0 + - /language_model/layers.13/self_attn/Mul_3_output_0 + - /language_model/layers.13/self_attn/Mul_4_output_0 + - /language_model/layers.13/self_attn/Mul_5_output_0 + - /language_model/layers.13/self_attn/Mul_6_output_0 + - 
/language_model/layers.13/self_attn/Mul_7_output_0 + - /language_model/layers.13/self_attn/Mul_8_output_0 + - /language_model/layers.13/self_attn/Mul_9_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_1_output_0 + - /language_model/layers.14/self_attn/Mul_2_output_0 + - /language_model/layers.14/self_attn/Mul_3_output_0 + - /language_model/layers.14/self_attn/Mul_4_output_0 + - /language_model/layers.14/self_attn/Mul_5_output_0 + - /language_model/layers.14/self_attn/Mul_6_output_0 + - /language_model/layers.14/self_attn/Mul_7_output_0 + - /language_model/layers.14/self_attn/Mul_8_output_0 + - /language_model/layers.14/self_attn/Mul_9_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_1_output_0 + - /language_model/layers.15/self_attn/Mul_2_output_0 + - /language_model/layers.15/self_attn/Mul_3_output_0 + - /language_model/layers.15/self_attn/Mul_4_output_0 + - /language_model/layers.15/self_attn/Mul_5_output_0 + - /language_model/layers.15/self_attn/Mul_6_output_0 + - /language_model/layers.15/self_attn/Mul_7_output_0 + - /language_model/layers.15/self_attn/Mul_8_output_0 + - /language_model/layers.15/self_attn/Mul_9_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_1_output_0 + - /language_model/layers.16/self_attn/Mul_2_output_0 + - /language_model/layers.16/self_attn/Mul_3_output_0 + - /language_model/layers.16/self_attn/Mul_4_output_0 + - /language_model/layers.16/self_attn/Mul_5_output_0 + - /language_model/layers.16/self_attn/Mul_6_output_0 + - /language_model/layers.16/self_attn/Mul_7_output_0 + - /language_model/layers.16/self_attn/Mul_8_output_0 + - /language_model/layers.16/self_attn/Mul_9_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_1_output_0 + - /language_model/layers.17/self_attn/Mul_2_output_0 + - /language_model/layers.17/self_attn/Mul_3_output_0 + - /language_model/layers.17/self_attn/Mul_4_output_0 + - /language_model/layers.17/self_attn/Mul_5_output_0 + - /language_model/layers.17/self_attn/Mul_6_output_0 + - /language_model/layers.17/self_attn/Mul_7_output_0 + - /language_model/layers.17/self_attn/Mul_8_output_0 + - /language_model/layers.17/self_attn/Mul_9_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_1_output_0 + - /language_model/layers.18/self_attn/Mul_2_output_0 + - /language_model/layers.18/self_attn/Mul_3_output_0 + - /language_model/layers.18/self_attn/Mul_4_output_0 + - /language_model/layers.18/self_attn/Mul_5_output_0 + - /language_model/layers.18/self_attn/Mul_6_output_0 + - /language_model/layers.18/self_attn/Mul_7_output_0 + - /language_model/layers.18/self_attn/Mul_8_output_0 + - /language_model/layers.18/self_attn/Mul_9_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_1_output_0 + - /language_model/layers.19/self_attn/Mul_2_output_0 + - /language_model/layers.19/self_attn/Mul_3_output_0 + - /language_model/layers.19/self_attn/Mul_4_output_0 + - /language_model/layers.19/self_attn/Mul_5_output_0 + - /language_model/layers.19/self_attn/Mul_6_output_0 + - /language_model/layers.19/self_attn/Mul_7_output_0 + - /language_model/layers.19/self_attn/Mul_8_output_0 + - /language_model/layers.19/self_attn/Mul_9_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - 
/language_model/layers.20/self_attn/Mul_1_output_0 + - /language_model/layers.20/self_attn/Mul_2_output_0 + - /language_model/layers.20/self_attn/Mul_3_output_0 + - /language_model/layers.20/self_attn/Mul_4_output_0 + - /language_model/layers.20/self_attn/Mul_5_output_0 + - /language_model/layers.20/self_attn/Mul_6_output_0 + - /language_model/layers.20/self_attn/Mul_7_output_0 + - /language_model/layers.20/self_attn/Mul_8_output_0 + - /language_model/layers.20/self_attn/Mul_9_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_1_output_0 + - /language_model/layers.21/self_attn/Mul_2_output_0 + - /language_model/layers.21/self_attn/Mul_3_output_0 + - /language_model/layers.21/self_attn/Mul_4_output_0 + - /language_model/layers.21/self_attn/Mul_5_output_0 + - /language_model/layers.21/self_attn/Mul_6_output_0 + - /language_model/layers.21/self_attn/Mul_7_output_0 + - /language_model/layers.21/self_attn/Mul_8_output_0 + - /language_model/layers.21/self_attn/Mul_9_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_1_output_0 + - /language_model/layers.22/self_attn/Mul_2_output_0 + - /language_model/layers.22/self_attn/Mul_3_output_0 + - /language_model/layers.22/self_attn/Mul_4_output_0 + - /language_model/layers.22/self_attn/Mul_5_output_0 + - /language_model/layers.22/self_attn/Mul_6_output_0 + - /language_model/layers.22/self_attn/Mul_7_output_0 + - /language_model/layers.22/self_attn/Mul_8_output_0 + - /language_model/layers.22/self_attn/Mul_9_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_1_output_0 + - /language_model/layers.23/self_attn/Mul_2_output_0 + - /language_model/layers.23/self_attn/Mul_3_output_0 + - /language_model/layers.23/self_attn/Mul_4_output_0 + - /language_model/layers.23/self_attn/Mul_5_output_0 + - /language_model/layers.23/self_attn/Mul_6_output_0 + - /language_model/layers.23/self_attn/Mul_7_output_0 + - /language_model/layers.23/self_attn/Mul_8_output_0 + - /language_model/layers.23/self_attn/Mul_9_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_1_output_0 + - /language_model/layers.24/self_attn/Mul_2_output_0 + - /language_model/layers.24/self_attn/Mul_3_output_0 + - /language_model/layers.24/self_attn/Mul_4_output_0 + - /language_model/layers.24/self_attn/Mul_5_output_0 + - /language_model/layers.24/self_attn/Mul_6_output_0 + - /language_model/layers.24/self_attn/Mul_7_output_0 + - /language_model/layers.24/self_attn/Mul_8_output_0 + - /language_model/layers.24/self_attn/Mul_9_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_1_output_0 + - /language_model/layers.25/self_attn/Mul_2_output_0 + - /language_model/layers.25/self_attn/Mul_3_output_0 + - /language_model/layers.25/self_attn/Mul_4_output_0 + - /language_model/layers.25/self_attn/Mul_5_output_0 + - /language_model/layers.25/self_attn/Mul_6_output_0 + - /language_model/layers.25/self_attn/Mul_7_output_0 + - /language_model/layers.25/self_attn/Mul_8_output_0 + - /language_model/layers.25/self_attn/Mul_9_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_1_output_0 + - /language_model/layers.26/self_attn/Mul_2_output_0 + - /language_model/layers.26/self_attn/Mul_3_output_0 + - /language_model/layers.26/self_attn/Mul_4_output_0 + - 
/language_model/layers.26/self_attn/Mul_5_output_0 + - /language_model/layers.26/self_attn/Mul_6_output_0 + - /language_model/layers.26/self_attn/Mul_7_output_0 + - /language_model/layers.26/self_attn/Mul_8_output_0 + - /language_model/layers.26/self_attn/Mul_9_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_1_output_0 + - /language_model/layers.27/self_attn/Mul_2_output_0 + - /language_model/layers.27/self_attn/Mul_3_output_0 + - /language_model/layers.27/self_attn/Mul_4_output_0 + - /language_model/layers.27/self_attn/Mul_5_output_0 + - /language_model/layers.27/self_attn/Mul_6_output_0 + - /language_model/layers.27/self_attn/Mul_7_output_0 + - /language_model/layers.27/self_attn/Mul_8_output_0 + - /language_model/layers.27/self_attn/Mul_9_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_1_output_0 + - /language_model/layers.28/self_attn/Mul_2_output_0 + - /language_model/layers.28/self_attn/Mul_3_output_0 + - /language_model/layers.28/self_attn/Mul_4_output_0 + - /language_model/layers.28/self_attn/Mul_5_output_0 + - /language_model/layers.28/self_attn/Mul_6_output_0 + - /language_model/layers.28/self_attn/Mul_7_output_0 + - /language_model/layers.28/self_attn/Mul_8_output_0 + - /language_model/layers.28/self_attn/Mul_9_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_1_output_0 + - /language_model/layers.29/self_attn/Mul_2_output_0 + - /language_model/layers.29/self_attn/Mul_3_output_0 + - /language_model/layers.29/self_attn/Mul_4_output_0 + - /language_model/layers.29/self_attn/Mul_5_output_0 + - /language_model/layers.29/self_attn/Mul_6_output_0 + - /language_model/layers.29/self_attn/Mul_7_output_0 + - /language_model/layers.29/self_attn/Mul_8_output_0 + - /language_model/layers.29/self_attn/Mul_9_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_1_output_0 + - /language_model/layers.30/self_attn/Mul_2_output_0 + - /language_model/layers.30/self_attn/Mul_3_output_0 + - /language_model/layers.30/self_attn/Mul_4_output_0 + - /language_model/layers.30/self_attn/Mul_5_output_0 + - /language_model/layers.30/self_attn/Mul_6_output_0 + - /language_model/layers.30/self_attn/Mul_7_output_0 + - /language_model/layers.30/self_attn/Mul_8_output_0 + - /language_model/layers.30/self_attn/Mul_9_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_1_output_0 + - /language_model/layers.31/self_attn/Mul_2_output_0 + - /language_model/layers.31/self_attn/Mul_3_output_0 + - /language_model/layers.31/self_attn/Mul_4_output_0 + - /language_model/layers.31/self_attn/Mul_5_output_0 + - /language_model/layers.31/self_attn/Mul_6_output_0 + - /language_model/layers.31/self_attn/Mul_7_output_0 + - /language_model/layers.31/self_attn/Mul_8_output_0 + - /language_model/layers.31/self_attn/Mul_9_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_1_output_0 + - /language_model/layers.32/self_attn/Mul_2_output_0 + - /language_model/layers.32/self_attn/Mul_3_output_0 + - /language_model/layers.32/self_attn/Mul_4_output_0 + - /language_model/layers.32/self_attn/Mul_5_output_0 + - /language_model/layers.32/self_attn/Mul_6_output_0 + - /language_model/layers.32/self_attn/Mul_7_output_0 + - /language_model/layers.32/self_attn/Mul_8_output_0 + - 
/language_model/layers.32/self_attn/Mul_9_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_1_output_0 + - /language_model/layers.33/self_attn/Mul_2_output_0 + - /language_model/layers.33/self_attn/Mul_3_output_0 + - /language_model/layers.33/self_attn/Mul_4_output_0 + - /language_model/layers.33/self_attn/Mul_5_output_0 + - /language_model/layers.33/self_attn/Mul_6_output_0 + - /language_model/layers.33/self_attn/Mul_7_output_0 + - /language_model/layers.33/self_attn/Mul_8_output_0 + - /language_model/layers.33/self_attn/Mul_9_output_0 + - /language_model/layers.0/self_attn/Softmax_output_0 + - /language_model/layers.1/self_attn/Softmax_output_0 + - /language_model/layers.2/self_attn/Softmax_output_0 + - /language_model/layers.3/self_attn/Softmax_output_0 + - /language_model/layers.4/self_attn/Softmax_output_0 + - /language_model/layers.5/self_attn/Softmax_output_0 + - /language_model/layers.6/self_attn/Softmax_output_0 + - /language_model/layers.7/self_attn/Softmax_output_0 + - /language_model/layers.8/self_attn/Softmax_output_0 + - /language_model/layers.9/self_attn/Softmax_output_0 + - /language_model/layers.10/self_attn/Softmax_output_0 + - /language_model/layers.11/self_attn/Softmax_output_0 + - /language_model/layers.12/self_attn/Softmax_output_0 + - /language_model/layers.13/self_attn/Softmax_output_0 + - /language_model/layers.14/self_attn/Softmax_output_0 + - /language_model/layers.15/self_attn/Softmax_output_0 + - /language_model/layers.16/self_attn/Softmax_output_0 + - /language_model/layers.17/self_attn/Softmax_output_0 + - /language_model/layers.18/self_attn/Softmax_output_0 + - /language_model/layers.19/self_attn/Softmax_output_0 + - /language_model/layers.20/self_attn/Softmax_output_0 + - /language_model/layers.21/self_attn/Softmax_output_0 + - /language_model/layers.22/self_attn/Softmax_output_0 + - /language_model/layers.23/self_attn/Softmax_output_0 + - /language_model/layers.24/self_attn/Softmax_output_0 + - /language_model/layers.25/self_attn/Softmax_output_0 + - /language_model/layers.26/self_attn/Softmax_output_0 + - /language_model/layers.27/self_attn/Softmax_output_0 + - /language_model/layers.28/self_attn/Softmax_output_0 + - /language_model/layers.29/self_attn/Softmax_output_0 + - /language_model/layers.30/self_attn/Softmax_output_0 + - /language_model/layers.31/self_attn/Softmax_output_0 + - /language_model/layers.32/self_attn/Softmax_output_0 + - /language_model/layers.33/self_attn/Softmax_output_0 + diff --git a/examples/performance/compute_context_length/gpt_oss.py b/examples/performance/compute_context_length/gpt_oss.py new file mode 100644 index 000000000..92bef9148 --- /dev/null +++ b/examples/performance/compute_context_length/gpt_oss.py @@ -0,0 +1,55 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. 
If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 4096 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. +# In MoE models like gpt-oss, since prefill_seq_len=1, comp_ctx_lengths_prefill and comp_ctx_lengths_decode can share the same list. +comp_ctx_lengths_prefill = comp_ctx_lengths_decode = [1024, ctx_len] + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + qaic_config={ + "ccl_enabled": True, + }, +) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +qpc_path = qeff_model.compile( + prefill_seq_len=1, # Currently the best performance is achieved with PL=1, i.e. a decode-only model; prefill optimizations are in progress. + ctx_len=ctx_len, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + prompts="Who is your creator? and What all you are allowed to do?", + generation_len=256, +) diff --git a/examples/performance/compute_context_length/granite_vision.py b/examples/performance/compute_context_length/granite_vision.py new file mode 100644 index 000000000..ef5dc3a51 --- /dev/null +++ b/examples/performance/compute_context_length/granite_vision.py @@ -0,0 +1,129 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=5500, + ctx_len=6000, + ccl_enabled=False, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=384, + num_cores=16, + num_devices=1, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=token) + + # `kv_offload` is used to compile the model into 2 QPCs. A single-QPC configuration is currently not supported, so setting this flag to False is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. 
+ + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=token, + kv_offload=kv_offload, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1610 x 1109) so that any image can work on the model irrespective of its dimensions; + # we use a fixed size of height 1109 and width 1610 + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1610, 1109)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "ibm-granite/granite-vision-3.2-2b" + + # Please add prompt here + query = "Describe the image" + + # Please pass an image URL or image path. The image should be in jpg format. + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 5500 + ctx_len = 8192 + generation_len = 128 + img_size = 384 + num_cores = 16 + num_devices = 4 + ccl_enabled = True + # Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. If both are None, the lists will be generated automatically based on the context length. + comp_ctx_lengths_prefill = [5500] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + ccl_enabled=ccl_enabled, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + +The image depicts two cats lying on a pink blanket that is spread out on a red couch. The cats are positioned in a relaxed manner, with their bodies stretched out and their heads resting on the blanket. +The cat on the left is a smaller, tabby cat with a mix of black, gray, and white fur. It has a long, slender body and a distinctive tail that is curled up near its tail end. The cat on the right is a larger, +tabby cat with a mix of gray, black, and brown fur. It has + +""" diff --git a/examples/performance/compute_context_length/internvl.py b/examples/performance/compute_context_length/internvl.py new file mode 100644 index 000000000..02e965e0d --- /dev/null +++ b/examples/performance/compute_context_length/internvl.py @@ -0,0 +1,296 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import requests +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.logging_utils import logger + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +# Process the input messages to generate prompt for the model. +def get_prompt(messages) -> str: + """Get the prompt for generation.""" + ## Chat template used for InternVL + system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + sep = "<|im_end|>\n" + + ret = system_prompt + sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + sep + else: + ret += role + return ret + + +# Processor class for InternVL models +class InternProcessor: + """ + InternVL model only has an AutoTokenizer so this class performs the processing tasks similar to an AutoProcessor. + The methods used here are borrowed from the original InternVL modelling files. + "https://huggingface.co/OpenGVLab/InternVL2_5-1B/" + """ + + def __init__(self, model: nn.Module, tokenizer): + self.model = model + image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size + patch_size = self.model.config.vision_config.patch_size + self.template = model.config.template + self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2)) + self.tokenizer = tokenizer + + def build_transform(self, input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # 
resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + def load_image(self, image, input_size=448, max_num=12): + transform = self.build_transform(input_size=input_size) + images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + def __call__( + self, + pixel_values, + question, + messages, + roles, + history=None, + num_patches_list=None, + IMG_START_TOKEN="", + IMG_END_TOKEN="", + IMG_CONTEXT_TOKEN="", + verbose=False, + ) -> str: + if history is None and pixel_values is not None and "" not in question: + question = "\n" + question + if num_patches_list is None: + num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.model.img_context_token_id = img_context_token_id + + messages.append([roles[0], question]) + messages.append([roles[1], None]) + query = get_prompt(messages) + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + logger.info(f"dynamic ViT batch size: {image_bs}") + + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace("", image_tokens, 1) + return query + + +def run_intern_on_aic( + model_name, + prompt, + image_url, + messages, + roles, + kv_offload=False, + prefill_seq_len=3840, + num_devices=1, + num_cores=16, + ctx_len=512, + ccl_enabled=False, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, +): + ## STEP 1 -- LOAD THE MODEL + + # The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface. + # To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`. 
+ + model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + kv_offload=kv_offload, + trust_remote_code=True, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, + ) + + ## STEP 2 -- EXPORT & COMPILE THE MODEL + + model.compile( + num_cores=num_cores, + num_devices=num_devices, + ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, + mxfp6_matmul=False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ## STEP 3 -- SETUP THE PROCESSOR + + # InternVL doesn't have an AutoProcessor yet, so we will use our own processor class "InternProcessor" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + internProcessor = InternProcessor(model.model, tokenizer) + + ## STEP 4 -- PREPROCESS THE INPUTS + + img = requests.get(image_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + # Images are resized to (1000, 747) for inference + image = image.resize((1000, 747)) + + # preprocess the resized image + pixel_values = internProcessor.load_image(image, max_num=12) + question = "\n" + prompt + query = internProcessor(pixel_values, question, messages, roles) + inputs = tokenizer( + query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right" + ) + + inputs["pixel_values"] = pixel_values + + ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION + streamer = TextStreamer(tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=128) + + +if __name__ == "__main__": + model_name = "OpenGVLab/InternVL2_5-1B" + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + + # Inputs for the model + prompt = "Please describe the image in detail." + image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg" + + ## Compilation parameters + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + kv_offload = True + + # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with + # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to + # incorporate the memory for the merged embeddings. + + prefill_seq_len = 3840 + num_devices = 4 + num_cores = 16 + + ctx_len = 8192 + ccl_enabled = True + # Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. If both are None, the lists will be generated automatically based on the context length. + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_intern_on_aic( + model_name=model_name, + prompt=prompt, + image_url=image_url, + messages=messages, + roles=roles, + kv_offload=kv_offload, + prefill_seq_len=prefill_seq_len, + num_devices=num_devices, + num_cores=num_cores, + ctx_len=ctx_len, + ccl_enabled=ccl_enabled, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + +""" +Expected Response: + +The image is a promotional graphic for Microsoft Azure. 
It features a blue background with a hexagonal pattern on the left side. The hexagons are white and are arranged in a way that suggests a network or connectivity theme. + +On the right side of the image, the Microsoft Azure logo is prominently displayed. The logo consists of the Azure name in white, with the Microsoft logo above it, which includes four colored squares (blue, green, yellow, and red). Below the logo, the word "Azure" is written in large white letters. + +Below the logo, there is text that reads: +- "By Dinesh Kumar Wick +""" diff --git a/examples/performance/compute_context_length/llama4.py b/examples/performance/compute_context_length/llama4.py new file mode 100644 index 000000000..a867e1bd3 --- /dev/null +++ b/examples/performance/compute_context_length/llama4.py @@ -0,0 +1,140 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 8192 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. 
+# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [3072] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [4096, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=True, if want to run only text, or false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=8, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you describe the image in detail.", + }, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=8, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + print() diff --git a/examples/performance/compute_context_length/llama4_cb.py b/examples/performance/compute_context_length/llama4_cb.py new file mode 100644 index 000000000..f97160693 --- /dev/null +++ b/examples/performance/compute_context_length/llama4_cb.py @@ -0,0 +1,118 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 4096 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. 
+# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [3072] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [ctx_len] + +continious_batching = True +if continious_batching: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) +else: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) + +# print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) diff --git a/examples/performance/compute_context_length/llama4_multi_image.py b/examples/performance/compute_context_length/llama4_multi_image.py new file mode 100644 index 000000000..314aa49b3 --- /dev/null +++ b/examples/performance/compute_context_length/llama4_multi_image.py @@ -0,0 +1,99 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). 
+## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 8192 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. +# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [5376] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### For multi-image, the value of max_num_tiles should be the sum of the num_tiles values across all the images ### +qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=34, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, +) + +### Multi_image Prompt ### +image_url_1 = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" +) + + +image_url_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url_1}, + {"type": "image", "url": image_url_2}, + { + "type": "text", + "text": "Analyze the key elements, colors, and objects in the two images. Discuss their similarities, differences, and how they complement or contrast each other. 
Reflect on the emotions or ideas they convey, considering the context, light, shadow, and composition.", + }, + ], + }, +] + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +) + +inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) +streamer = TextStreamer(tokenizer) +output = qeff_model.generate(inputs=inputs, generation_len=100) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/performance/compute_context_length/mistral3.py b/examples/performance/compute_context_length/mistral3.py new file mode 100644 index 000000000..a773ddfd9 --- /dev/null +++ b/examples/performance/compute_context_length/mistral3.py @@ -0,0 +1,130 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + kv_offload=False, + prefill_seq_len=128, + ctx_len=4096, + ccl_enabled=False, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=1540, + num_cores=16, + num_devices=4, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` is used to compile the model into 2 QPCs. A single-QPC configuration is currently not supported, so setting this flag to False is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. 
+ + config = AutoConfig.from_pretrained(model_name) + config.vision_config._attn_implementation = "eager" + # For Testing Purpose Only + config.text_config.num_hidden_layers = 4 + config.vision_config.num_hidden_layers = 2 + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + kv_offload=kv_offload, + config=config, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1540 x 1540) so that any image can work on the model irrespective of its dimensions; + # we use a fixed size of height 1540 and width 1540 as defined in the config + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1540, 1540)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + + # Please add prompt here + query = "Describe the image" + + # Please pass an image URL or image path. The image should be in jpg format. + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 128 + ctx_len = 8192 + generation_len = 128 + num_cores = 16 + num_devices = 4 + ccl_enabled = True + # Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. If both are None, the lists will be generated automatically based on the context length. + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + ccl_enabled=ccl_enabled, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: +The image depicts a street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese archway, known as a paifang, which is intricately designed with red columns and ornate details. The archway features Chinese characters at the top, which translate to "Chinatown Gate." +In the foreground, there is a red stop sign mounted on a pole. The street is relatively quiet, with a single dark-colored SUV driving through the archway. On either side of the archway, there are stone lion statues, which are common decorative elements in Chinese architecture and symbolize protection. 
+ + +""" diff --git a/examples/performance/compute_context_length/molmo.py b/examples/performance/compute_context_length/molmo.py new file mode 100644 index 000000000..8d773f5fe --- /dev/null +++ b/examples/performance/compute_context_length/molmo.py @@ -0,0 +1,112 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import torch +import transformers +from PIL import Image +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "allenai/Molmo-7B-D-0924" +config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) +# For Testing Purpose Only +# config.num_hidden_layers = 2 + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +# load the model +ctx_len = 8192 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. 
+comp_ctx_lengths_prefill = [3072] # None # +comp_ctx_lengths_decode = [4096, 8192] # None # + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + kv_offload=True, + trust_remote_code=True, + config=config, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +### use skip_vision=True, if want to run only text, or false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + inputs = processor.process(text="Tell me about yourself") + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + inputs["input_ids"] = inputs["input_ids"].to(torch.int64) + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((536, 354)) + + inputs = processor.process(images=[image], text="Can you describe the image in detail.") + + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + inputs["pixel_values"] = inputs.pop("images") + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + + valid = inputs["image_input_idx"] > 0 + valid = valid.reshape(1, -1) + inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + print() diff --git a/examples/performance/compute_context_length/qwen2_5_vl.py b/examples/performance/compute_context_length/qwen2_5_vl.py new file mode 100644 index 000000000..5a6818930 --- /dev/null +++ b/examples/performance/compute_context_length/qwen2_5_vl.py @@ -0,0 +1,165 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# To enable QBlocking, run the command below; by default it runs without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl.py + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ models, update the PyTorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 2 + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 8192 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
+comp_ctx_lengths_prefill = [4096] # None # +comp_ctx_lengths_decode = [6144, ctx_len] # None # + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=True, if want to run only text, or false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100, device_ids=[0, 1, 2, 3]) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + }, + ] + + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe about the color of the dog."}, + ], + }, + ] + + messages = [messages_2] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100, device_ids=[0, 1, 2, 3]) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/performance/compute_context_length/qwen2_5_vl_cb.py b/examples/performance/compute_context_length/qwen2_5_vl_cb.py new file mode 100644 index 000000000..c247a1e58 --- /dev/null +++ b/examples/performance/compute_context_length/qwen2_5_vl_cb.py @@ -0,0 +1,93 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. 
and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# To enable QBlocking, run the command below; by default it runs without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_cb.py + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ models, update the PyTorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. + +ctx_len = 8192 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding.
+comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=8192, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, + device_ids=[0, 1, 2, 3], +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/performance/compute_context_length/qwen3moe.py b/examples/performance/compute_context_length/qwen3moe.py new file mode 100644 index 000000000..93849fa5a --- /dev/null +++ b/examples/performance/compute_context_length/qwen3moe.py @@ -0,0 +1,55 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" +""" +# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mxint8 argument in compile function +# We will use prompt_len=1 for compilation for both cb and non-cb inference +""" + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode to provide two lists of context lengths for the prefilling and decoding processes. If both are None, the lists will be generated automatically based on the context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the generated token's cache index exceeds the current CCL value. 
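+# Illustrative only, using the values set below: with comp_ctx_lengths_prefill = [256, 512, 1024],
+# prefill begins with a 256-token compute window and steps up to 512 and then 1024 as the
+# position_id of the prompt chunk being processed crosses each boundary.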
+ +ctx_len = 1024 +prefill_seq_len = 1 +ccl_enabled = True +# Two optional lists, comp_ctx_lengths_prefill and comp_ctx_lengths_decode, define CCL values for prefilling and decoding. +# In MoE models, when compiling with prefill_seq_len=1 in non-continuous-batching mode, prefill and decode share the same CCL specializations. +comp_ctx_lengths_prefill = comp_ctx_lengths_decode = [256, 512, ctx_len] + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + continuous_batching=False, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, +) + +model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, +) +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/examples/performance/compute_context_length/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/performance/compute_context_length/qwen3moe_example/ccl_qwen3moe_inference.py new file mode 100644 index 000000000..9fb4c4d43 --- /dev/null +++ b/examples/performance/compute_context_length/qwen3moe_example/ccl_qwen3moe_inference.py @@ -0,0 +1,54 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" +""" +# For CB inference, set continuous_batching to True and add the full_batch_size, mxfp6, and mxint8 arguments in the compile function +# We will use prompt_len=1 for compilation for both cb and non-cb inference +""" + +## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained(). +## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length. +## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process. +## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk. +## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process. +## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index. +## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold. + +ctx_len = 1024 +prefill_seq_len = 1 +# In MoE models, when compiling with prefill_seq_len=1 in non-continuous-batching mode, prefill and decode share the same CCL specializations.
+comp_ctx_lengths_prefill = [256, 512, ctx_len] # None # +comp_ctx_lengths_decode = [256, 512, ctx_len] # None # + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + continuous_batching=False, + ccl_enabled=True, + num_hidden_layers=4, +) + +model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, +) +# mos=1, +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/examples/performance/compute_context_length/vlm_inference.py b/examples/performance/compute_context_length/vlm_inference.py new file mode 100644 index 000000000..294632fe3 --- /dev/null +++ b/examples/performance/compute_context_length/vlm_inference.py @@ -0,0 +1,239 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Vision-Language Model (VLM) inference with Compute Context Length (CCL) optimization. + +This example demonstrates how to use CCL optimization for vision-language models. +CCL allows using different context lengths during prefill and decode phases, +reducing memory footprint and computation while maintaining support for longer contexts. +""" + +import argparse + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + hf_token=None, + kv_offload=True, + prefill_seq_len=32, + ctx_len=8192, + ccl_enabled=False, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=560, + num_cores=16, + num_devices=4, +): + """ + Run VLM inference with CCL optimization. 
+ + Args: + model_name: HuggingFace model ID + query: Text query about the image + image_url: URL of the image to process + hf_token: HuggingFace token for gated models + kv_offload: Enable Dual QPC mode (vision encoder and LM in separate QPCs) + prefill_seq_len: Prefill sequence length + ctx_len: Maximum context length + comp_ctx_lengths_prefill: List of context lengths for prefill phase + comp_ctx_lengths_decode: List of context lengths for decode phase + generation_len: Number of tokens to generate + img_size: Image size for processing + num_cores: Number of cores for compilation + num_devices: Number of devices to use + """ + print(f"Loading model: {model_name}") + print(f"KV offload (Dual QPC mode): {kv_offload}") + + ## STEP 1: Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=hf_token) + + # `kv_offload` determines Single QPC vs Dual QPC mode: + # - Single QPC (kv_offload=False): Entire model runs in one QPC + # - Dual QPC (kv_offload=True): Vision encoder and language model run in separate QPCs + # with outputs transferred via host for flexibility + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=hf_token, + attn_implementation="eager", + kv_offload=kv_offload, + qaic_config={ + "ccl_enabled": ccl_enabled, + }, + ) + + ## STEP 2: Export & Compile the Model + + print("\nCompiling model...") + qpc_path = model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + print(f"Model compiled successfully to: {qpc_path}") + + ## STEP 3: Load and Process the Inputs for Inference + + print(f"\nLoading image from: {image_url}") + image = Image.open(requests.get(image_url, stream=True).raw) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP 4: Run Inference on the Compiled Model + + print(f"\nQuery: {query}") + print("Generated response:") + streamer = TextStreamer(processor.tokenizer) + output_statistics = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + + print(f"Tokens generated: {len(output_statistics.generated_ids[0])}") + + +def main(): + parser = argparse.ArgumentParser( + description="Vision-Language Model (VLM) inference with Compute Context Length (CCL) optimization" + ) + parser.add_argument( + "--model-name", + type=str, + default="meta-llama/Llama-3.2-11B-Vision-Instruct", + help="HuggingFace VLM model ID", + ) + parser.add_argument( + "--query", + type=str, + default="Describe this image.", + help="Text query/question about the image", + ) + parser.add_argument( + "--image-url", + type=str, + default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + help="URL of the image to process", + ) + parser.add_argument( + "--hf-token", + type=str, + default=None, + help="HuggingFace token for accessing gated models", + ) + parser.add_argument( + "--kv-offload", + action="store_true", + default=True, + help="Enable Dual QPC mode (vision encoder and LM in 
separate QPCs)", + ) + parser.add_argument( + "--prefill-seq-len", + type=int, + default=32, + help="Prefill sequence length", + ) + parser.add_argument( + "--ctx-len", + type=int, + default=8192, + help="Maximum context length", + ) + parser.add_argument( + "--ccl-enabled", + action="store_true", + help="Enable compute-context-length (CCL) feature", + ) + parser.add_argument( + "--comp-ctx-lengths-prefill", + type=lambda x: [int(i) for i in x.split(",")], + default=None, + help="Comma-separated list of context lengths for prefill phase (e.g., '4096')", + ) + parser.add_argument( + "--comp-ctx-lengths-decode", + type=lambda x: [int(i) for i in x.split(",")], + default=None, + help="Comma-separated list of context lengths for decode phase (e.g., '6144,8192')", + ) + parser.add_argument( + "--generation-len", + type=int, + default=128, + help="Number of tokens to generate", + ) + parser.add_argument( + "--img-size", + type=int, + default=560, + help="Image size for processing", + ) + parser.add_argument( + "--num-cores", + type=int, + default=16, + help="Number of cores for compilation", + ) + parser.add_argument( + "--num-devices", + type=int, + default=4, + help="Number of devices to use", + ) + args = parser.parse_args() + + run_model( + model_name=args.model_name, + query=args.query, + image_url=args.image_url, + hf_token=args.hf_token, + kv_offload=args.kv_offload, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + ccl_enabled=args.ccl_enabled, + comp_ctx_lengths_prefill=args.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=args.comp_ctx_lengths_decode, + generation_len=args.generation_len, + img_size=args.img_size, + num_cores=args.num_cores, + num_devices=args.num_devices, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/cpp_execution/CMakeLists.txt b/examples/performance/cpp_execution/CMakeLists.txt similarity index 100% rename from examples/cpp_execution/CMakeLists.txt rename to examples/performance/cpp_execution/CMakeLists.txt diff --git a/examples/cpp_execution/InferenceSetIOBuffer.cpp b/examples/performance/cpp_execution/InferenceSetIOBuffer.cpp similarity index 100% rename from examples/cpp_execution/InferenceSetIOBuffer.cpp rename to examples/performance/cpp_execution/InferenceSetIOBuffer.cpp diff --git a/examples/cpp_execution/README.md b/examples/performance/cpp_execution/README.md similarity index 81% rename from examples/cpp_execution/README.md rename to examples/performance/cpp_execution/README.md index 386921657..2d1c604e5 100644 --- a/examples/cpp_execution/README.md +++ b/examples/performance/cpp_execution/README.md @@ -24,7 +24,7 @@ make -j 8 cd ../../../ # Need to be in base folder - efficient-transformers to run below cmd # Run the python script to get the generated text -python examples/cpp_execution/text_inference_using_cpp.py --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 14 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first +python examples/performance/cpp_execution/text_inference_cpp.py --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 14 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first ``` diff --git a/examples/cpp_execution/text_inference_using_cpp.py b/examples/performance/cpp_execution/text_inference_cpp.py similarity index 99% rename from examples/cpp_execution/text_inference_using_cpp.py rename to examples/performance/cpp_execution/text_inference_cpp.py index 072f2c57c..8355c1e44 100644 --- 
a/examples/cpp_execution/text_inference_using_cpp.py +++ b/examples/performance/cpp_execution/text_inference_cpp.py @@ -229,7 +229,7 @@ def tokenize_decode_output(tokenizer, generated_ids, prompt): "--prompts_txt_file_path", "--prompts-txt-file-path", type=str, - help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder", + help="File path for taking input prompts from txt file, sample prompts.txt file present in examples/sample_prompts folder", ) parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate") parser.add_argument( diff --git a/examples/on_device_sampling.py b/examples/performance/on_device_sampling.py similarity index 83% rename from examples/on_device_sampling.py rename to examples/performance/on_device_sampling.py index 00d8c2430..da9c5b43b 100644 --- a/examples/on_device_sampling.py +++ b/examples/performance/on_device_sampling.py @@ -21,6 +21,7 @@ def main(args, **kwargs): include_sampler = None return_pdfs = None max_top_k_ids = None + include_guided_decoding = None sampling_params = None bs = args.full_batch_size if args.full_batch_size is not None else args.batch_size if args.override_qaic_config is not None: @@ -28,6 +29,8 @@ def main(args, **kwargs): if include_sampler is not None: return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true" max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512)) + np.random.seed(int(args.random_number)) + include_guided_decoding = args.override_qaic_config.get("aic_include_guided_decoding", None) == "true" sampling_params = { "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1), @@ -36,7 +39,9 @@ def main(args, **kwargs): "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1), "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1), "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1), - "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype( + np.float32 + ), } qaic_config = { k: v @@ -44,13 +49,12 @@ def main(args, **kwargs): "include_sampler": include_sampler, "return_pdfs": return_pdfs, "max_top_k_ids": max_top_k_ids, + "include_guided_decoding": include_guided_decoding, }.items() if v is not None } print("qaic_config:") pprint(qaic_config) - print("sampling_params:") - pprint(sampling_params) # Load model with On Device Sampler enabled qeff_model = AutoModelForCausalLM.from_pretrained( @@ -60,6 +64,19 @@ def main(args, **kwargs): ) print(f"{args.model_name} optimized for AI 100 \n", qeff_model) + if include_guided_decoding: + # Ideally this should come from a logits processor like xgrammar, but for the sake of the + # example, we generate a random bitmask + sampling_params.update( + { + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(qeff_model.model.config.vocab_size,)), (bs, 1) + ) + } + ) + print("sampling_params:") + pprint(sampling_params) + # Compile the model for inference generated_qpc_path = qeff_model.compile( prefill_seq_len=args.prompt_len, @@ -88,6 +105,7 @@ def main(args, **kwargs): generation_len=args.generation_len, include_sampler=include_sampler, return_pdfs=return_pdfs, + include_guided_decoding=include_guided_decoding, 
sampling_params=sampling_params, ) @@ -106,14 +124,14 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 2. For non-continuous batching: python3.10 examples/on_device_sampling.py \ @@ -126,14 +144,34 @@ def main(args, **kwargs): --num-cores 16 \ --mxint8-kv-cache \ --mxfp6-matmul \ - --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512" \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:false" \ + --repetition-penalty 1.9 \ + --presence-penalty 0.8 \ + --temperature 0.67 \ + --top-k 54 \ + --top-p 0.89 \ + --min-p 0.6 \ + --random-number 26 + + 3. With guided decoding: + python3.10 examples/on_device_sampling.py \ + --model-name 'meta-llama/Llama-3.1-8B' \ + --prompt-len 128 \ + --ctx-len 256 \ + --generation-len 20 \ + --full-batch-size 2 \ + --device-group [0,1,2,3] \ + --num-cores 16 \ + --mxint8-kv-cache \ + --mxfp6-matmul \ + --override-qaic-config "aic_include_sampler:true aic_return_pdfs:false max_top_k_ids:512 aic_include_guided_decoding:true" \ --repetition-penalty 1.9 \ --presence-penalty 0.8 \ --temperature 0.67 \ - --top-k 54720 \ + --top-k 54 \ --top-p 0.89 \ --min-p 0.6 \ - --random-number 0.26 + --random-number 26 """ parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling") @@ -177,7 +215,7 @@ def main(args, **kwargs): "--prompts_txt_file_path", "--prompts-txt-file-path", type=str, - help="File path for taking input prompts from txt file, sample prompts.txt file present in examples folder", + help="File path for taking input prompts from txt file, sample prompts.txt file present in examples/sample_prompts folder", ) parser.add_argument("--generation_len", "--generation-len", type=int, help="Number of tokens to generate") diff --git a/examples/performance/speculative_decoding/README.md b/examples/performance/speculative_decoding/README.md new file mode 100644 index 000000000..e03eb45be --- /dev/null +++ b/examples/performance/speculative_decoding/README.md @@ -0,0 +1,181 @@ +# Speculative Decoding Examples + +Accelerate text generation using speculative decoding techniques on Qualcomm Cloud AI 100. + +Speculative decoding improves inference speed by generating multiple candidate tokens in parallel and validating them with the target model, reducing sequential forward passes required for text generation. 
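+At a high level, each iteration cheaply proposes a few candidate tokens and lets the target model verify them in a single pass. The sketch below only illustrates that propose/verify/accept loop; helper names such as `draft_propose` and `target_greedy` are placeholders, not APIs used by the scripts in this folder.
+
+```python
+# Illustrative greedy speculative-decoding step (not the actual implementation in these scripts).
+def speculative_step(context, k, draft_propose, target_greedy):
+    candidates = draft_propose(context, k)         # draft phase: k cheap candidate tokens
+    verified = target_greedy(context, candidates)  # validation: k+1 greedy target tokens in one pass
+    accepted = []
+    for cand, tgt in zip(candidates, verified):
+        if cand != tgt:
+            accepted.append(tgt)                   # first mismatch: take the target's token and stop
+            break
+        accepted.append(cand)
+    else:
+        accepted.append(verified[k])               # all k accepted: keep the bonus target token
+    return context + accepted
+```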
+ +## Authentication + +For private/gated models, export your HuggingFace token: +```bash +export HF_TOKEN= +``` + +## Quick Start + +```bash +# Draft-based: Use small draft model + large target model +python draft_based.py \ + --draft-model-name "meta-llama/Llama-3.2-1B" \ + --target-model-name "meta-llama/Llama-3.2-1B" \ + --num-speculative-tokens 4 + +# Prompt Lookup: N-gram matching without draft model +python prompt_lookup.py \ + --target-model-name "meta-llama/Llama-3.2-1B" \ + --num-speculative-tokens 3 \ + --max-ngram-size 3 + +# Multi-Projection: Built-in speculation for Turbo models (requires speculator_config.json) +# Note: TinyLlama does not support multi-projection - use actual Turbo models +python multi_projection.py \ + --pretrained-model-name-or-path "meta-llama/Llama-3.1-8B-Turbo" +``` + +## Available Scripts + +### draft_based.py - Two-Model Speculative Decoding + +**How It Works:** +1. **Draft Phase**: Small, fast model generates `N` candidate tokens sequentially +2. **Validation Phase**: Large target model scores all candidates in a single forward pass +3. **Acceptance**: Greedily accept tokens until first mismatch, then sample from target distribution +4. **Iteration**: Repeat with accepted tokens + one additional target token + +This approach achieves speedup when draft model is 3-8x faster than target model. + +**Basic Usage:** +```bash +python draft_based.py \ + --draft-model-name "meta-llama/Llama-3.2-1B" \ + --target-model-name "meta-llama/Llama-3.2-8B" \ + --num-speculative-tokens 4 \ + --prefill-seq-len 32 \ + --ctx-len 128 +``` + +**Multi-Device Deployment:** +```bash +python draft_based.py \ + --draft-model-name "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \ + --target-model-name "meta-llama/Llama-3.1-70B" \ + --target-device-group 0,1,2,3 \ + --draft-device-group 4,5 \ + --num-speculative-tokens 6 +``` + +**Key Features:** +- Uses `qaic_config={"speculative_model_type": "target"}` for target model compilation +- Draft model uses fewer cores (5) vs target model (11) by default +- Supports both regular batching and continuous batching modes +- Implements "bonus token" handling for multi-batch scenarios + +**Recommended Model Pairs:** +- `TinyLlama-1.1B` → `Llama-3.1-8B` (8x size ratio) +- `Llama-3.2-1B` → `Llama-3.2-8B` (8x size ratio) +- `Llama-3.1-8B` → `Llama-3.1-70B` (9x size ratio) + +### prompt_lookup.py - N-gram Pattern Matching + +**How It Works:** +1. **Pattern Search**: Sliding window searches input context for n-gram matches +2. **Candidate Generation**: When match found, extract following tokens as candidates +3. **Fallback**: If no match, pad with dummy tokens (no speculation benefit) +4. **Validation**: Target model scores candidates like draft-based approach + +Most effective for repetitive text patterns, code with common structures, or templated content. 
+ +**Basic Usage:** +```bash +python prompt_lookup.py \ + --target-model-name "meta-llama/Llama-3.2-8B" \ + --num-speculative-tokens 3 \ + --max-ngram-size 3 \ + --prefill-seq-len 256 \ + --ctx-len 1024 +``` + +**Optimized for Repetitive Content:** +```bash +python prompt_lookup.py \ + --target-model-name "meta-llama/Llama-3.1-8B" \ + --prompts "Write code with repeated patterns: for i in range(10): print(i)" \ + --num-speculative-tokens 5 \ + --max-ngram-size 4 \ + --ctx-len 2048 +``` + +**Key Features:** +- Implements `find_candidate_pred_tokens()` for n-gram matching +- Maintains `all_ids` array to track full context for pattern matching +- Default prompts designed for repetitive patterns (e.g., "hello, good morning to you") +- Uses `fill_tok=-1` for padding when no matches found +- No separate draft model required - uses n-gram pattern matching instead + +**Key Parameters:** +- `--max-ngram-size`: Larger values (3-5) better for structured text +- `--num-speculative-tokens`: Reduce if acceptance rate is low +- Longer context lengths improve pattern matching opportunities + +### multi_projection.py - Turbo Model Speculation + +**How It Works:** +1. **Multi-Head Projection**: Model has multiple projection heads generating token candidates +2. **Single Forward Pass**: All candidates generated simultaneously in one inference +3. **Built-in Validation**: Model internally scores and ranks candidates +4. **Optimized Architecture**: Specifically designed for speculative decoding + +Requires models with `speculative_config` and multi-projection architecture. + +**Basic Usage:** +```bash +python multi_projection.py \ + --pretrained-model-name-or-path "meta-llama/Llama-3.1-8B-Turbo" \ + --prefill-seq-len 32 \ + --ctx-len 128 +``` + +**Continuous Batching:** +```bash +python multi_projection.py \ + --pretrained-model-name-or-path "meta-llama/Llama-3.1-8B-Turbo" \ + --full-batch-size 4 \ + --device-group 0,1,2,3 \ + --ignore-eos-token +``` + +**Key Features:** +- Uses `qaic_config={"speculative_model_type": "turbo"}` for compilation +- Automatically extracts `num_speculative_tokens` from model's `speculative_config` +- Generates 4D logits tensor: `[batch, num_logits, num_logits, vocab_size]` +- No separate draft model required - speculation built into architecture + + +## Common Parameters + +| Parameter | Description | Default | Recommended | +|-----------|-------------|---------|-------------| +| `--prefill-seq-len` | Prefill chunk size | 32 | 128-256 | +| `--ctx-len` | Max context length | 128 | 512-2048 | +| `--num-speculative-tokens` | Candidates per iteration | 3-4 | 3-6 | +| `--device-group` | Device allocation | `[0]` | Multi-device for large models | +| `--full-batch-size` | Continuous batching | None | 2-8 for throughput | + +## Performance Metrics Explained + +All scripts output detailed metrics: + +``` +Avg TLM+DLM TTFT = 0.15 # Time to first token (seconds) +Decode Throughput = 125.67 # Tokens/second during generation +E2E Throughput = 98.23 # Overall tokens/second including prefill +Avg number of accepted tokens = 2.8 # Speculation effectiveness +``` + + + +## Documentation + +- [Speculative Decoding Guide](https://quic.github.io/efficient-transformers/source/features_enablement.html#speculative-decoding) +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) +- [Performance Optimization](https://quic.github.io/efficient-transformers/source/features_enablement.html) diff --git a/examples/draft_spd_inference.py 
b/examples/performance/speculative_decoding/draft_based.py similarity index 98% rename from examples/draft_spd_inference.py rename to examples/performance/speculative_decoding/draft_based.py index 9dccc2a1d..9e617663c 100644 --- a/examples/draft_spd_inference.py +++ b/examples/performance/speculative_decoding/draft_based.py @@ -200,7 +200,7 @@ def draft_spec_decode_inference( continuous_batching = full_batch_size is not None if target_model_session is None: target_model = AutoModelForCausalLM.from_pretrained( - target_model_name, continuous_batching=continuous_batching, is_tlm=True + target_model_name, continuous_batching=continuous_batching, qaic_config={"speculative_model_type": "target"} ) target_num_devices = len(target_device_group) target_model_qpc_path: str = target_model.compile( @@ -248,6 +248,7 @@ def draft_spec_decode_inference( p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1) p_tok["position_ids"] = position_ids + p_tok["num_logits_to_keep"] = np.array([[1]], dtype=np.int64) prompts_tokenized.append(p_tok) # create caches to hold generated ids and input prompt lengths generated_ids = [[] for i in range(decode_batch_size)] @@ -264,6 +265,7 @@ def draft_spec_decode_inference( input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.arange(num_speculative_tokens + 1, dtype=np.int64).reshape(-1, 1), ) max_gen_len = [ctx_len] * decode_batch_size num_logits_to_keep = num_speculative_tokens + 1 diff --git a/examples/multiprojs_spd_inference.py b/examples/performance/speculative_decoding/multi_projection.py similarity index 100% rename from examples/multiprojs_spd_inference.py rename to examples/performance/speculative_decoding/multi_projection.py diff --git a/examples/pld_spd_inference.py b/examples/performance/speculative_decoding/prompt_lookup.py similarity index 98% rename from examples/pld_spd_inference.py rename to examples/performance/speculative_decoding/prompt_lookup.py index 2b5baba18..53b1f4e85 100644 --- a/examples/pld_spd_inference.py +++ b/examples/performance/speculative_decoding/prompt_lookup.py @@ -103,7 +103,7 @@ def run_prefill_on_draft_and_target( prefill_seq_len: int, slot_idx: int, ): - input_len = inputs.input_ids.shape[1] + input_len = inputs["input_ids"].shape[1] num_chunks = input_len // prefill_seq_len cache_index = np.array([[0]], np.int64) batch_index = np.array([[slot_idx]], np.int64) @@ -234,7 +234,7 @@ def pld_spec_decode_inference( # export_and_compile tlm and dlm continuous_batching = full_batch_size is not None target_model = AutoModelForCausalLM.from_pretrained( - target_model_name, continuous_batching=continuous_batching, is_tlm=True + target_model_name, continuous_batching=continuous_batching, qaic_config={"speculative_model_type": "target"} ) num_devices = len(device_group) @@ -270,6 +270,7 @@ def pld_spec_decode_inference( p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded) position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1) p_tok["position_ids"] = position_ids + p_tok["num_logits_to_keep"] = np.array([[1]], dtype=np.int64) prompts_tokenized.append(p_tok) # create caches to hold generated ids and input prompt lengths 
generated_ids = [[] for i in range(decode_batch_size)] @@ -280,6 +281,7 @@ def pld_spec_decode_inference( input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.arange(num_speculative_tokens + 1, dtype=np.int64).reshape(-1, 1), ) num_logits_to_keep = num_speculative_tokens + 1 max_gen_len = [ctx_len] * decode_batch_size diff --git a/examples/qwen3moe_example/qwen3moe_inference.py b/examples/qwen3moe_example/qwen3moe_inference.py deleted file mode 100644 index 3bef3a1dc..000000000 --- a/examples/qwen3moe_example/qwen3moe_inference.py +++ /dev/null @@ -1,21 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.utils.constants import Constants - -model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" -""" -# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function -# We will use prompt_len=1 for compilation for both cb and non-cb inference -""" -model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=False) -model.compile(prefill_seq_len=1, ctx_len=256, num_cores=16, num_devices=4, mxfp6_matmul=False, mxint8_kv_cache=False) -tokenizer = AutoTokenizer.from_pretrained(model_name) -exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/examples/prompts.txt b/examples/sample_prompts/prompts.txt similarity index 100% rename from examples/prompts.txt rename to examples/sample_prompts/prompts.txt diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md deleted file mode 100644 index 4b091347b..000000000 --- a/examples/speech_to_text/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# Speech Seq2Seq -This directory contains an example script of how to use the AutoModelForSpeechSeq2Seq class. (for now, Whisper models on audio <30 seconds only has been validated) - -## Required packages: -- `librosa==0.10.2` -- `soundfile==0.13.1` - -You can install them using pip: -```sh -pip install librosa==0.10.2 soundfile==0.13.1 -``` - -To run example script after package installations: -```sh -python speech_seq2seq_models.py -``` - -Expected output for given data sample: -```sh -<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|> -``` \ No newline at end of file diff --git a/examples/speech_to_text/run_whisper_speech_to_text.py b/examples/speech_to_text/run_whisper_speech_to_text.py deleted file mode 100644 index d24389e9e..000000000 --- a/examples/speech_to_text/run_whisper_speech_to_text.py +++ /dev/null @@ -1,36 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from datasets import load_dataset -from transformers import AutoProcessor - -from QEfficient import QEFFAutoModelForSpeechSeq2Seq - -base_model_name = "openai/whisper-tiny" -ctx_len = 25 - -## STEP 1 -- load audio sample, using a standard english dataset, can load specific files if longer audio needs to be tested; also load initial processor -ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") -data = ds[0]["audio"]["array"] -# reshape to so shape corresponds to data with batch size 1 -data = data.reshape(-1) -sample_rate = ds[0]["audio"]["sampling_rate"] -processor = AutoProcessor.from_pretrained(base_model_name) - -## STEP 2 -- init base model -qeff_model = QEFFAutoModelForSpeechSeq2Seq.from_pretrained(base_model_name) - -## STEP 3 -- export and compile model -qeff_model.compile() - -## STEP 4 -- generate output for loaded input and processor -exec_info = qeff_model.generate( - inputs=processor(data, sampling_rate=sample_rate, return_tensors="pt"), generation_len=ctx_len -) - -## STEP 5 (optional) -- use processor to decode output -print(processor.batch_decode(exec_info.generated_ids)[0]) diff --git a/examples/text_generation/README.md b/examples/text_generation/README.md new file mode 100644 index 000000000..6b80442c2 --- /dev/null +++ b/examples/text_generation/README.md @@ -0,0 +1,314 @@ +# Text Generation Examples + +Examples for running inference on text-only language models on Qualcomm Cloud AI 100. + + +## Authentication + +For private/gated models, export your HuggingFace token: +```bash +export HF_TOKEN= +``` + +## Supported Models + +**QEff Auto Class:** `QEFFAutoModelForCausalLM` + +For the complete list of supported text generation models, see the [Validated Models - Text Generation Section](../../docs/source/validate.md#text-only-language-models). + +Popular model families include: +- Llama (2, 3, 3.1, 3.2, 3.3) +- Mistral, Mixtral, Codestral +- Qwen, Qwen2, Qwen3-MoE +- Gemma, CodeGemma +- GPT-2, GPT-J +- Falcon, MPT, Phi-3 +- Granite, StarCoder + +--- + +## Python Examples + +### basic_inference.py +Simple text generation with any supported language model. + +**Usage:** +```bash +python basic_inference.py \ + --model-name Qwen/Qwen2-1.5B-Instruct \ + --prompt "Hello, how are you?" \ + --prefill-seq-len 32 \ + --ctx-len 128 \ + --num-cores 16 +``` + +This example: +- Demonstrates basic text generation workflow +- Loads any HuggingFace text model +- Compiles and runs inference on Cloud AI 100 + +### continuous_batching.py +Dynamic batching for processing multiple prompts efficiently. + +**Usage:** +```bash +python continuous_batching.py \ + --model-name meta-llama/Llama-3.1-8B \ + --prompts "Hello|Hi there|Good morning|How are you" \ + --full-batch-size 4 \ + --prefill-seq-len 128 \ + --ctx-len 512 \ + --num-cores 16 +``` + +This example: +- Demonstrates continuous batching mode +- Processes multiple prompts in parallel +- Improves throughput for multi-request scenarios +- Uses pipe-separated prompts + +### gguf_models.py +GGUF format model support (quantized models). 
To run GGUF format models, you need to install the `gguf` package: + +```bash +pip install gguf +``` + +**Usage:** +```bash +# With default parameters +python gguf_models.py + +# With custom parameters +python gguf_models.py \ + --model-name MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF \ + --gguf-file Mistral-7B-Instruct-v0.3.fp16.gguf \ + --prompt "How are you?" \ + --prefill-seq-len 32 \ + --ctx-len 128 \ + --num-cores 16 +``` + +This example: +- Loads models in GGUF format (quantized models) +- Demonstrates GGUF file loading from HuggingFace +- Compiles and runs inference on Cloud AI 100 +- Supports custom GGUF files and prompts + +--- + + +### moe_inference.py +Mixture of Experts (MoE) model inference. + +**Usage:** +```bash +python moe_inference.py \ + --model-name Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --prompt "Explain quantum computing" \ + --ctx-len 256 \ + --num-cores 16 +``` + +This example: +- Demonstrates MoE model inference +- Uses sparse expert activation for efficiency +- Works with Qwen, Mixtral, and other MoE models + + +## CLI Workflow + +The QEfficient CLI provides a streamlined workflow for running text generation models on Cloud AI 100. You can use individual commands for each step or the all-in-one `infer` command. + +### Quick Start: All-in-One Inference (Recommended) + +The `infer` command handles export, compile, and execute in a single step: + +```bash +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --batch_size 1 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --prompt "Write a short story about AI" \ + --mxfp6 \ + --mxint8_kv_cache \ + --mos 1 \ + --aic_enable_depth_first +``` + +**What it does:** +1. Downloads and exports the model to ONNX +2. Compiles to QPC +3. Executes inference with your prompt + +**CLI API Reference:** [`QEfficient.cloud.infer`](https://quic.github.io/efficient-transformers/source/cli_api.html#qefficient-cloud-infer) + +### Step-by-Step Workflow + +For more control, you can execute each step individually: + +#### Step 1: Export Model to ONNX + +Export the HuggingFace model to ONNX format optimized for Cloud AI 100: + +```bash +python -m QEfficient.cloud.export \ + --model_name meta-llama/Llama-3.1-8B \ + --cache_dir ~/.cache/qeff_cache +``` + +This downloads the model and converts it to ONNX format. The ONNX model is saved in the QEfficient cache directory. + +**CLI API Reference:** [`QEfficient.cloud.export`](https://quic.github.io/efficient-transformers/source/cli_api.html#qefficient-cloud-export) + +#### Step 2: Compile Model to QPC + +Compile the ONNX model to Qualcomm Program Container (QPC) format: + +```bash +python -m QEfficient.cloud.compile \ + --onnx_path ~/.cache/qeff_cache/meta-llama/Llama-3.1-8B/onnx/model.onnx \ + --qpc_path ./qpc_output \ + --batch_size 1 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --mxfp6 \ + --mos 1 \ + --aic_enable_depth_first +``` + +**Note:** The `compile` API is deprecated for direct use. Use the unified `infer` API instead for most use cases. 
+ +**CLI API Reference:** [`QEfficient.cloud.compile`](https://quic.github.io/efficient-transformers/source/cli_api.html#qefficient-cloud-compile) + +#### Step 3: Execute Inference + +Run inference using the pre-compiled QPC: + +```bash +python -m QEfficient.cloud.execute \ + --model_name meta-llama/Llama-3.1-8B \ + --qpc_path ./qpc_output/qpcs \ + --prompt "Write a short story about AI" \ + --device_group [0] +``` + +This uses the pre-compiled QPC for fast inference. You can run this multiple times with different prompts without recompiling. + +**CLI API Reference:** [`QEfficient.cloud.execute`](https://quic.github.io/efficient-transformers/source/cli_api.html#qefficient-cloud-execute) + +### Common CLI Parameters + +| Parameter | Description | Default | Example | +|-----------|-------------|---------|---------| +| `--model_name` | HuggingFace model ID | Required | `meta-llama/Llama-3.1-8B` | +| `--prompt` | Input text prompt | Required | `"Hello, how are you?"` | +| `--prompt_len` | Maximum input sequence length | 32 | `128` | +| `--ctx_len` | Maximum context length (input + output) | 128 | `512` | +| `--batch_size` | Batch size for inference | 1 | `1` | +| `--num_cores` | AI 100 cores to use | 16 | `16` or `14` | +| `--device_group` | Device IDs to use | `[0]` | `[0]` or `[0,1,2,3]` | +| `--mxfp6` | Enable MXFP6 quantization | False | Add flag to enable | +| `--mxint8_kv_cache` | Enable MXINT8 KV cache | False | Add flag to enable | +| `--mos` | Memory optimization strategy | 1 | `1` or `2` | +| `--aic_enable_depth_first` | Enable depth-first execution | False | Add flag to enable | + + +### Advanced Features + +#### Multi-Device Inference (Multi-Qranium) + +Run models across multiple devices for better performance: + +```bash +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --batch_size 1 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0,1,2,3] \ + --prompt "Explain quantum computing" \ + --mxfp6 \ + --mxint8_kv_cache \ + --aic_enable_depth_first +``` + +**Documentation:** [Multi-Qranium Inference](https://quic.github.io/efficient-transformers/source/features_enablement.html#multi-qranium-inference) + +#### Continuous Batching + +Process multiple prompts efficiently with continuous batching: + +```bash +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --full_batch_size 4 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --prompt "Hello|Hi there|Good morning|How are you" \ + --mxfp6 \ + --mxint8_kv_cache +``` + +**Note:** Use pipe (`|`) to separate multiple prompts. When using continuous batching, do not specify `--batch_size`. 
+ +**Documentation:** [Continuous Batching](https://quic.github.io/efficient-transformers/source/features_enablement.html#continuous-batching) + +#### Batch Processing from File + +Process multiple prompts from a text file: + +```bash +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --full_batch_size 8 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --prompts_txt_file_path examples/sample_prompts/prompts.txt \ + --mxfp6 \ + --mxint8_kv_cache +``` + +### CLI Examples Script + +For a comprehensive collection of copy-paste ready CLI commands, run: + +```bash +bash cli_examples.sh +``` + +This script demonstrates: +- Complete 4-step workflow (Export → Compile → Execute → Infer) +- Multi-device inference +- Continuous batching +- Batch processing from file +- Parameter explanations and best practices + +--- + + +## Additional Resources + +### Documentation +- [CLI API Reference](https://quic.github.io/efficient-transformers/source/cli_api.html) - Complete CLI command documentation +- [Quick Start Guide](https://quic.github.io/efficient-transformers/source/quick_start.html) - Getting started with QEfficient +- [Features Enablement](https://quic.github.io/efficient-transformers/source/features_enablement.html) - Advanced features guide +- [QEff Auto Classes](https://quic.github.io/efficient-transformers/source/qeff_autoclasses.html) - Python API reference +- [Validated Models](https://quic.github.io/efficient-transformers/source/validate.html) - Supported models list + + +### Model Storage +By default, exported models and QPC files are stored in `~/.cache/qeff_cache`. Customize this with: +- `QEFF_HOME`: Primary cache directory +- `XDG_CACHE_HOME`: Alternative cache location + diff --git a/examples/text_generation/basic_inference.py b/examples/text_generation/basic_inference.py new file mode 100644 index 000000000..6340ec725 --- /dev/null +++ b/examples/text_generation/basic_inference.py @@ -0,0 +1,57 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="Basic text generation inference") + parser.add_argument("--model-name", type=str, default="Qwen/Qwen2-1.5B-Instruct", help="HuggingFace model ID") + parser.add_argument("--prompt", type=str, default="Hello, how are you?", help="Input prompt") + parser.add_argument("--prefill-seq-len", type=int, default=32, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=128, help="Context length") + parser.add_argument("--generation-len", type=int, default=100, help="Number of tokens to generate") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument( + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + default=None, + help="Device IDs (comma-separated) e.g. 
[0,1]", + ) + args = parser.parse_args() + + # Load tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name) + + # Compile the model + qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=(1 if args.device_group is None else len(args.device_group)), + ) + print(f"Model compiled to: {qpc_path}") + + # Generate text + exec_info = model.generate( + tokenizer=tokenizer, + prompts=[args.prompt], + device_id=args.device_group, + generation_len=args.generation_len, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"Generated: {exec_info.generated_texts[0]}") + + +if __name__ == "__main__": + main() diff --git a/examples/text_generation/cli_examples.sh b/examples/text_generation/cli_examples.sh new file mode 100755 index 000000000..12a426ebe --- /dev/null +++ b/examples/text_generation/cli_examples.sh @@ -0,0 +1,209 @@ +#!/bin/bash + +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# QEfficient CLI Examples for Text Generation +# This script provides a simplified workflow for running text generation on Cloud AI 100 + +echo "QEfficient CLI Workflow for Text Generation" +echo "===========================================" +echo "" +echo "This example demonstrates the complete workflow using Llama-3.1-8B" +echo "" + +# ============================================================================ +# STEP 1: EXPORT MODEL TO ONNX +# ============================================================================ + +echo "Step 1: Export Model to ONNX" +echo "-----------------------------" +echo "Export the HuggingFace model to ONNX format optimized for Cloud AI 100" +echo "" +cat << 'EOF' +python -m QEfficient.cloud.export \ + --model_name meta-llama/Llama-3.1-8B \ + --cache_dir ~/.cache/qeff_cache +EOF +echo "" +echo "This will download the model and convert it to ONNX format." +echo "The ONNX model will be saved in the QEfficient cache directory." 
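+# Note: exported models (and later, compiled QPCs) are cached under ~/.cache/qeff_cache by default;
+# set the QEFF_HOME environment variable to redirect this cache to a different directory.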
+echo "" + +# ============================================================================ +# STEP 2: COMPILE MODEL TO QPC +# ============================================================================ + +echo "Step 2: Compile Model to QPC" +echo "-----------------------------" +echo "Compile the ONNX model to Qualcomm Program Container (QPC) format" +echo "" +cat << 'EOF' +python -m QEfficient.cloud.compile \ + --onnx_path ~/.cache/qeff_cache/meta-llama/Llama-3.1-8B/onnx/model.onnx \ + --qpc_path ./qpc_output \ + --batch_size 1 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --mxfp6 \ + --mos 1 \ + --aic_enable_depth_first +EOF +echo "" +echo "Compilation parameters:" +echo " --batch_size: Number of prompts to process simultaneously" +echo " --prompt_len: Maximum input prompt length" +echo " --ctx_len: Maximum context length (prompt + generation)" +echo " --num_cores: Number of AI 100 cores to use (typically 14 or 16)" +echo " --device_group: Device IDs to use (e.g., [0] for single device, [0,1,2,3] for multi-device)" +echo " --mxfp6: Enable MXFP6 quantization for better performance" +echo " --mos: Memory optimization strategy" +echo " --aic_enable_depth_first: Enable depth-first execution" +echo "" + +# ============================================================================ +# STEP 3: EXECUTE WITH COMPILED QPC +# ============================================================================ + +echo "Step 3: Execute Inference with Compiled QPC" +echo "--------------------------------------------" +echo "Run inference using the pre-compiled QPC" +echo "" +cat << 'EOF' +python -m QEfficient.cloud.execute \ + --model_name meta-llama/Llama-3.1-8B \ + --qpc_path ./qpc_output/qpcs \ + --prompt "Write a short story about AI" \ + --device_group [0] +EOF +echo "" +echo "This uses the pre-compiled QPC for fast inference." +echo "You can run this multiple times with different prompts without recompiling." +echo "" + +# ============================================================================ +# STEP 4: END-TO-END INFERENCE (ALL-IN-ONE) +# ============================================================================ + +echo "Step 4: End-to-End Inference (Recommended)" +echo "-------------------------------------------" +echo "The 'infer' command handles export, compile, and execute in one step" +echo "" +cat << 'EOF' +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --batch_size 1 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --prompt "Write a short story about AI" \ + --mxfp6 \ + --mxint8_kv_cache \ + --mos 1 \ + --aic_enable_depth_first +EOF +echo "" +echo "This is the recommended approach for most use cases." +echo "It automatically:" +echo " 1. Downloads and exports the model to ONNX (if not cached)" +echo " 2. Compiles to QPC (if not already compiled with these settings)" +echo " 3. 
Executes inference with your prompt" +echo "" + +# ============================================================================ +# ADDITIONAL EXAMPLES +# ============================================================================ + +echo "" +echo "Additional Examples" +echo "===================" +echo "" + +echo "Multi-Device Inference (Multi-Qranium)" +echo "---------------------------------------" +cat << 'EOF' +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --batch_size 1 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0,1,2,3] \ + --prompt "Explain quantum computing" \ + --mxfp6 \ + --mxint8_kv_cache \ + --aic_enable_depth_first +EOF +echo "" + +echo "Continuous Batching (Multiple Prompts)" +echo "---------------------------------------" +cat << 'EOF' +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --full_batch_size 4 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --prompt "Hello|Hi there|Good morning|How are you" \ + --mxfp6 \ + --mxint8_kv_cache +EOF +echo "" +echo "Note: Use pipe (|) to separate multiple prompts for continuous batching" +echo "" + +echo "Batch Processing from File" +echo "---------------------------" +cat << 'EOF' +python -m QEfficient.cloud.infer \ + --model_name meta-llama/Llama-3.1-8B \ + --full_batch_size 8 \ + --prompt_len 128 \ + --ctx_len 512 \ + --num_cores 16 \ + --device_group [0] \ + --prompts_txt_file_path examples/sample_prompts/prompts.txt \ + --mxfp6 \ + --mxint8_kv_cache +EOF +echo "" + +# ============================================================================ +# NOTES AND DOCUMENTATION +# ============================================================================ + +echo "" +echo "Important Notes" +echo "===============" +echo "" +echo "Terminal Compatibility:" +echo " - Use bash terminal for best compatibility" +echo " - If using ZSH, wrap device_group in single quotes: '--device_group [0]'" +echo "" +echo "Common Parameters:" +echo " --model_name: HuggingFace model ID (e.g., meta-llama/Llama-3.1-8B)" +echo " --prompt: Input text prompt" +echo " --prompt_len: Maximum input sequence length" +echo " --ctx_len: Maximum context length (input + output)" +echo " --num_cores: AI 100 cores (typically 14 or 16)" +echo " --device_group: Device IDs [0] for single, [0,1,2,3] for multi-device" +echo " --mxfp6: Enable MXFP6 quantization (recommended)" +echo " --mxint8_kv_cache: Enable MXINT8 KV cache (recommended)" +echo " --aic_enable_depth_first: Enable depth-first execution" +echo "" +echo "For More Information:" +echo " - Full CLI API Reference: https://quic.github.io/efficient-transformers/cli_api.html" +echo " - Quick Start Guide: https://quic.github.io/efficient-transformers/quick_start.html" +echo " - Features Guide: https://quic.github.io/efficient-transformers/features_enablement.html" +echo " - Supported Models: https://quic.github.io/efficient-transformers/validate.html" +echo " - Examples README: examples/text_generation/README.md" +echo "" diff --git a/examples/text_generation/continuous_batching.py b/examples/text_generation/continuous_batching.py new file mode 100644 index 000000000..ec3a36ea9 --- /dev/null +++ b/examples/text_generation/continuous_batching.py @@ -0,0 +1,72 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="Continuous batching inference") + parser.add_argument("--model-name", type=str, default="Qwen/Qwen2-1.5B-Instruct", help="HuggingFace model ID") + parser.add_argument( + "--prompts", + type=str, + default="Hello! How can I help?|Hi there! What’s up?|Hey! Need assistance?|Welcome! How can I support you today?", + help="Pipe-separated prompts for batch processing", + ) + parser.add_argument("--prefill-seq-len", type=int, default=128, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=512, help="Context length") + parser.add_argument("--full-batch-size", type=int, default=4, help="Full batch size for continuous batching") + parser.add_argument("--generation-len", type=int, default=100, help="Number of tokens to generate") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument( + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + default=None, + help="Device IDs (comma-separated) e.g. [0,1]", + ) + args = parser.parse_args() + + # Parse prompts + prompt_list = args.prompts.split("|") + print(f"Processing {len(prompt_list)} prompts with continuous batching") + + # Load tokenizer and model with continuous batching enabled + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name, continuous_batching=True) + + # Compile the model with full_batch_size for continuous batching + qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + full_batch_size=args.full_batch_size, + num_cores=args.num_cores, + num_devices=(1 if args.device_group is None else len(args.device_group)), + ) + print(f"Model compiled to: {qpc_path}") + + # Generate text for all prompts + exec_info = model.generate( + tokenizer=tokenizer, + prompts=prompt_list, + device_id=args.device_group, + generation_len=args.generation_len, + ) + + # Display results + print("\n" + "=" * 80) + for i, (prompt, generated) in enumerate(zip(prompt_list, exec_info.generated_texts)): + print(f"\nPrompt {i + 1}: {prompt}") + print(f"Generated: {generated}") + print("-" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/text_generation/gguf_models.py b/examples/text_generation/gguf_models.py new file mode 100644 index 000000000..2f81ef031 --- /dev/null +++ b/examples/text_generation/gguf_models.py @@ -0,0 +1,59 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="GGUF model inference") + parser.add_argument( + "--model-name", + type=str, + default="Qwen/Qwen2-1.5B-Instruct-GGUF", + help="HuggingFace model ID for GGUF model", + ) + parser.add_argument( + "--gguf-file", + type=str, + default="qwen2-1_5b-instruct-q8_0.gguf", + help="GGUF file name within the model repository", + ) + parser.add_argument("--prompt", type=str, default="Hello! How are you?", help="Input prompt") + parser.add_argument("--prefill-seq-len", type=int, default=32, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=128, help="Context length") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument("--num-devices", type=int, default=1, help="Number of devices") + args = parser.parse_args() + + # Load the model and tokenizer + print(f"Loading GGUF model: {args.model_name}") + print(f"GGUF file: {args.gguf_file}") + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, gguf_file=args.gguf_file) + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name, gguf_file=args.gguf_file) + + # Compile the model + generated_qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=args.num_devices, + ) + print(f"Model compiled to: {generated_qpc_path}") + + # Generate text + exec_info = model.generate(prompts=[args.prompt], tokenizer=tokenizer) + print(f"\nPrompt: {args.prompt}") + print(f"Generated: {exec_info.generated_texts[0]}") + + +if __name__ == "__main__": + main() diff --git a/examples/text_generation/moe_inference.py b/examples/text_generation/moe_inference.py new file mode 100644 index 000000000..276c766dd --- /dev/null +++ b/examples/text_generation/moe_inference.py @@ -0,0 +1,66 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +import argparse + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="MoE model inference") + parser.add_argument( + "--model-name", + type=str, + default="Qwen/Qwen3-30B-A3B-Instruct-2507", + help="HuggingFace MoE model ID", + ) + parser.add_argument("--prompt", type=str, default="Explain quantum computing", help="Input prompt") + parser.add_argument("--prefill-seq-len", type=int, default=32, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=256, help="Context length") + parser.add_argument("--generation-len", type=int, default=None, help="Number of tokens to generate") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument( + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + default=None, + help="Device IDs (comma-separated) e.g. 
[0,1]", + ) + args = parser.parse_args() + + print(f"Loading MoE model: {args.model_name}") + print("Note: MoE models use sparse expert activation for efficient inference") + + # Load tokenizer and model + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = QEFFAutoModelForCausalLM.from_pretrained(args.model_name) + + # Compile the model + qpc_path = model.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=(1 if args.device_group is None else len(args.device_group)), + ) + print(f"Model compiled to: {qpc_path}") + + # Generate text + exec_info = model.generate( + tokenizer=tokenizer, + prompts=[args.prompt], + device_id=args.device_group, + generation_len=args.generation_len, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"Generated: {exec_info.generated_texts[0]}") + + +if __name__ == "__main__": + main() diff --git a/examples/wav2vec2_example/README.md b/examples/wav2vec2_example/README.md deleted file mode 100644 index fba8d9ad2..000000000 --- a/examples/wav2vec2_example/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# Speech Recognition with Wav2Vec2 -This directory contains an example script of how to use the AutoModelForCTC class. (for now, Wav2Vec2 models on audio <30 seconds only has been validated) - -## Required packages: -- `librosa==0.10.2` -- `soundfile==0.13.1` - -You can install them using pip: -```sh -pip install librosa==0.10.2 soundfile==0.13.1 -``` - -To run example script after package installations: -```sh -python run_wav2vec2_inference.py -``` - -Expected output for given data sample: -```sh -MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL -``` \ No newline at end of file diff --git a/examples/wav2vec2_example/run_wav2vec2_inference.py b/examples/wav2vec2_example/run_wav2vec2_inference.py deleted file mode 100644 index 961aabeb8..000000000 --- a/examples/wav2vec2_example/run_wav2vec2_inference.py +++ /dev/null @@ -1,24 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from datasets import load_dataset -from transformers import AutoProcessor - -from QEfficient import QEFFAutoModelForCTC - -base_model_name = "facebook/wav2vec2-base-960h" - -ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") -data = ds[0]["audio"]["array"] -# reshape to so shape corresponds to data with batch size 1 -data = data.reshape(-1) -sample_rate = ds[0]["audio"]["sampling_rate"] -processor = AutoProcessor.from_pretrained(base_model_name) - -model = QEFFAutoModelForCTC.from_pretrained(base_model_name) -model.compile(num_cores=16) -print(model.generate(processor, inputs=data)) diff --git a/notebooks/QEfficientGPT2.ipynb b/notebooks/QEfficientGPT2.ipynb index 74e8097bb..350f8bc31 100644 --- a/notebooks/QEfficientGPT2.ipynb +++ b/notebooks/QEfficientGPT2.ipynb @@ -33,6 +33,9 @@ "outputs": [], "source": [ "# Initiate the Original Transformer model\n", + "# Initiate the tokenizer for transformers library\n", + "from transformers import AutoTokenizer\n", + "\n", "from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM\n", "\n", "# Please uncomment and use appropriate Cache Directory for transformers, in case you don't want to use default ~/.cache dir.\n", @@ -92,11 +95,7 @@ "# Compile the model for provided compilation arguments\n", "# Please use platform SDK to Check num_cores for your card.\n", "\n", - "qeff_model.compile(\n", - " num_cores=14,\n", - " mxfp6=True,\n", - " device_group=[0],\n", - ")" + "qeff_model.compile(num_cores=14, mxfp6_matmul=True)" ] }, { @@ -116,8 +115,8 @@ "source": [ "# post compilation, we can print the latency stats for the kv models, We provide API to print token and Latency stats on Cloud AI 100\n", "# We need the compiled prefill and decode qpc to compute the token generated, This is based on Greedy Sampling Approach\n", - "\n", - "qeff_model.generate(prompts=[\"My name is\"])" + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "qeff_model.generate(prompts=[\"My name is\"], tokenizer=tokenizer)" ] } ], diff --git a/notebooks/QEfficientMPT.ipynb b/notebooks/QEfficientMPT.ipynb index d1a1f3c5f..3bb99ecbc 100644 --- a/notebooks/QEfficientMPT.ipynb +++ b/notebooks/QEfficientMPT.ipynb @@ -32,6 +32,8 @@ "outputs": [], "source": [ "# Initiate the Original Transformer model\n", + "# Initiate the tokenizer for transformers library\n", + "from transformers import AutoTokenizer\n", "\n", "from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM\n", "\n", @@ -91,11 +93,7 @@ "# Compile the model for provided compilation arguments\n", "# Please use platform SDK to Check num_cores for your card.\n", "\n", - "qeff_model.compile(\n", - " num_cores=14,\n", - " mxfp6=True,\n", - " device_group=[0],\n", - ")" + "qeff_model.compile(num_cores=14, mxfp6_matmul=True)" ] }, { @@ -116,7 +114,8 @@ "# post compilation, we can print the latency stats for the kv models, We provide API to print token and Latency stats on Cloud AI 100\n", "# We need the compiled prefill and decode qpc to compute the token generated, This is based on Greedy Sampling Approach\n", "\n", - "qeff_model.generate(prompts=[\"My name is\"])" + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "qeff_model.generate(prompts=[\"My name is\"], tokenizer=tokenizer)" ] } ], diff --git a/pyproject.toml b/pyproject.toml index ea3c3405d..9da98f71d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,9 
+20,10 @@ classifiers = [ requires-python = ">=3.8,<3.11" dependencies = [ "transformers==4.55.0", + "diffusers== 0.35.1", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", - "peft==0.13.2", + "peft==0.17.0", "datasets==2.20.0", "fsspec==2023.6.0", "multidict==6.0.4", @@ -39,18 +40,22 @@ dependencies = [ "fire", "py7zr", "torchmetrics==1.7.0", + "ftfy==6.3.1", + "imageio==2.37.2", + "imageio-ffmpeg==0.6.0", "torch==2.7.0; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", "torch@https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", + "torchvision@https://download.pytorch.org/whl/cpu/torchvision-0.22.0%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", + "torchvision@https://download.pytorch.org/whl/cpu/torchvision-0.22.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", ] [project.optional-dependencies] test = ["pytest","pytest-mock"] docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"] quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"] - [build-system] requires = ["setuptools>=62.0.0"] build-backend = "setuptools.build_meta" @@ -72,3 +77,16 @@ target-version = "py310" addopts = "-W ignore -s -v" junit_logging = "all" doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS" +markers = [ + "on_qaic: marks tests as requiring QAIC hardware", + "diffusion_models: marks tests for diffusion models", + "wan: marks tests for WAN model", + "flux: marks tests for Flux model", + "regular: marks regular tests", + "nightly: marks nightly tests", + "multimodal: marks multimodal tests", + "qnn: marks QNN tests", + "cli: marks CLI tests", + "finetune: marks finetune tests", + "vllm: marks vLLM tests" +] diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index d9d391d47..3420c025b 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -25,7 +25,6 @@ pipeline { pip install junitparser pytest-xdist && pip install librosa==0.10.2 soundfile==0.13.1 && #packages needed to load example for whisper testing pip install --extra-index-url https://download.pytorch.org/whl/cpu timm==1.0.14 torchvision==0.22.0+cpu einops==0.8.1 && #packages to load VLMs - pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl && # For finetuning tests rm -rf QEfficient" ''' } @@ -34,7 +33,7 @@ pipeline { parallel { stage('Run Non-CLI Non-QAIC Tests') { steps { - timeout(time: 25, unit: 'MINUTES') { + timeout(time: 40, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && @@ -42,7 +41,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic && - pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml && + pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml && junitparser merge 
tests/tests_log1.xml tests/tests_log.xml && deactivate" ''' @@ -59,7 +58,7 @@ pipeline { mkdir -p $PWD/Non_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_qaic && - pytest tests -m '(not cli) and (on_qaic) and (not nightly) and (not multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log2.xml && + pytest tests -m '(not cli) and (on_qaic) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2.xml && junitparser merge tests/tests_log2.xml tests/tests_log.xml && deactivate" ''' @@ -68,9 +67,9 @@ pipeline { } } } - stage('QAIC MultiModal Tests') { + stage('QAIC MultiModal Tests') { steps { - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 120, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && @@ -78,20 +77,38 @@ pipeline { mkdir -p $PWD/Non_cli_qaic_multimodal && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic_multimodal && - pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log6.xml && + pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log6.xml && junitparser merge tests/tests_log6.xml tests/tests_log.xml && deactivate" ''' } } } + stage('QAIC Diffusion Models Tests') { + steps { + timeout(time: 120, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + cd /efficient-transformers && + . preflight_qeff/bin/activate && + mkdir -p $PWD/Non_cli_qaic_diffusion && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Non_cli_qaic_diffusion && + export HF_HUB_CACHE=/huggingface_hub && + pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml && + junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml && + deactivate" + ''' + } + } + } stage('Inference Tests') { steps { - timeout(time: 60, unit: 'MINUTES') { + timeout(time: 120, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " - source /qnn_sdk/bin/envsetup.sh && - source /qnn_sdk/bin/envcheck -c && + #source /qnn_sdk/bin/envsetup.sh && + #source /qnn_sdk/bin/envcheck -c && cd /efficient-transformers && . preflight_qeff/bin/activate && mkdir -p $PWD/cli && @@ -163,11 +180,13 @@ pipeline { // } stage('Finetune CLI Tests') { steps { - timeout(time: 5, unit: 'MINUTES') { + timeout(time: 20, unit: 'MINUTES') { sh ''' sudo docker exec ${BUILD_TAG} bash -c " cd /efficient-transformers && . 
preflight_qeff/bin/activate && + pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl && + pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cpu && mkdir -p $PWD/cli_qaic_finetuning && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/cli_qaic_finetuning && @@ -181,15 +200,15 @@ pipeline { } post { - success { - // Trigger downstream job only if this pipeline succeeds - build job: 'qefficient_vllm_upstream', - parameters: [ - string(name: 'NAME', value: "${BUILD_TAG}"), - string(name: 'QEFF_WORKSPACE', value: "${env.WORKSPACE}") - ], - wait: false - } + // success { + // // Trigger downstream job only if this pipeline succeeds + // build job: 'qefficient_vllm_upstream', + // parameters: [ + // string(name: 'NAME', value: "${BUILD_TAG}"), + // string(name: 'QEFF_WORKSPACE', value: "${env.WORKSPACE}") + // ], + // wait: false + // } always { script { try { @@ -200,9 +219,13 @@ pipeline { echo "Failed to change ownership: ${error}" } } - junit testResults: 'tests/tests_log.xml' - } - unsuccessful { + script { + try { + junit testResults: 'tests/tests_log.xml', allowEmptyResults: true + } catch (error) { + echo "No test results file found or parsing failed: ${error}" + } + } script { try { sh ''' @@ -215,5 +238,18 @@ pipeline { echo 'Cleaning Workspace' deleteDir() } + // unsuccessful { + // script { + // try { + // sh ''' + // sudo docker rm -f ${BUILD_TAG} + // ''' + // } catch (error) { + // echo "Failed to delete container ${BUILD_TAG}: ${error}" + // } + // } + // echo 'Cleaning Workspace' + // deleteDir() + // } } } \ No newline at end of file diff --git a/scripts/memory_profiling/README.md b/scripts/memory_profiling/README.md new file mode 100644 index 000000000..efb995815 --- /dev/null +++ b/scripts/memory_profiling/README.md @@ -0,0 +1,199 @@ +# QEfficient Memory Profiling + +A memory profiling solution for QEfficient workflows with manual operation marking. 
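+
+Depending on how you run your script, the profiler can be imported either directly from this
+directory (as in the Quick Start below) or through the `scripts.memory_profiling` package
+exported by `__init__.py`:
+
+```python
+# Option 1: working directory is scripts/memory_profiling/
+from profiler import QEffMemoryProfiler
+
+# Option 2: running from the repository root, via the package
+from scripts.memory_profiling import QEffMemoryProfiler
+```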
+ + + +## Quick Start + +```python +from profiler import QEffMemoryProfiler +from QEfficient import QEFFAutoModelForCausalLM +from transformers import AutoTokenizer + +# Initialize profiler with verbose output to see detailed memory tracking information +profiler = QEffMemoryProfiler(verbose=True) +# Start monitoring memory usage - this begins tracking memory consumption +profiler.start_monitoring() + +# Mark the start of model loading operation for memory profiling, this will help to create stage wise partitioning the output graph +profiler.mark_operation("Loading model") + +model = QEFFAutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + +# Mark the export operation +profiler.mark_operation("Export") +model.export() + +# Mark the compilation operation +profiler.mark_operation("Compile") +model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16) + +# Mark the text generation operation +profiler.mark_operation("Generation") +output = model.generate(prompts=["Hello world"], tokenizer=tokenizer, generation_len=100) + +# Stop memory monitoring and generate reports +profiler.stop_monitoring() + +# Print a detailed memory usage report to the console showing peak memory and operation-wise breakdown (optional) +print(profiler.get_memory_report()) + +# Generate a visual graph of memory usage over time and save it as an image file +profiler.generate_memory_graph("profile.png") +``` + +## Configuration + +### Basic Configuration + +```python +profiler = QEffMemoryProfiler( + sampling_interval=0.1, # Sample every 100ms + output_file="my_profile.png", # Custom output file + verbose=True, # Enable detailed logging + enable_cpu_monitoring=True, # Monitor CPU usage + enable_disk_monitoring=True, # Monitor disk I/O +) +``` + +### Manual Operation Marking + +```python +profiler = QEffMemoryProfiler() +profiler.start_monitoring() + +# Manual operation marking +profiler.mark_operation("Custom Operation 1") +# ... your code ... + +profiler.mark_operation("Custom Operation 2") +# ... more code ... 
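+# Each marked operation becomes its own colored phase in the generated graph and its own
+# entry in the operations timeline of get_memory_report()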
+ +profiler.stop_monitoring() +``` + +## API Reference + +### QEffMemoryProfiler + +#### Constructor Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `sampling_interval` | `float` | `0.05` | Time between samples (seconds) | +| `output_file` | `str` | `"qeff_memory_profile.png"` | Output file path | +| `verbose` | `bool` | `False` | Enable verbose logging | +| `enable_cpu_monitoring` | `bool` | `True` | Monitor CPU usage | +| `enable_disk_monitoring` | `bool` | `True` | Monitor disk I/O | + +#### Methods + +- **`start_monitoring()`**: Start background monitoring +- **`stop_monitoring()`**: Stop monitoring and mark completion +- **`mark_operation(name: str)`**: Manually mark operation start +- **`get_memory_report() -> str`**: Generate comprehensive text report +- **`generate_memory_graph(filename: str)`**: Create visualization +- **`stop_and_save(filename: str) -> str`**: Convenience method to stop and save + +#### Properties + +- **`peak_rss`**: Peak RSS memory usage (MB) +- **`peak_operation`**: Operation during peak memory +- **`samples`**: List of collected profiling samples +- **`operations`**: List of marked operations with timestamps + +## Operation Types + +The profiler supports marking these common QEfficient operations: + +- **Model Loading**: `from_pretrained`, `AutoModel`, `AutoTokenizer` +- **Export**: `model.export()`, ONNX transforms, PyTorch transforms +- **Compilation**: `model.compile()`, QNN compilation +- **Generation**: `model.generate()`, inference execution +- **Cleanup**: Memory cleanup, garbage collection + +## Output + +### Console Report +``` +QEFFICIENT PERFORMANCE MONITORING REPORT +============================================================ +Peak Memory Usage: + • RSS (Physical): 18.7 GB at 14:23:45 + • Peak during: Compilation + +Memory Statistics: + • Current RSS: 16.2 GB (Delta: +15.8 GB) + • Duration: 185.3 seconds + • Operations: 4 + +QEfficient Operations Timeline: + 1. 0.0s - Model Loading (25.2s) [+8.2 GB] + 2. 25.2s - Export (15.4s) [+2.1 GB] + 3. 40.6s - Compilation (120.8s) [+6.3 GB] <- Peak + 4. 161.4s - Generation (18.7s) [+1.2 GB] +``` + +### Visualization + +The profiler generates a comprehensive 4-panel visualization: + +1. **Memory Timeline**: RSS usage with colored operation phases +2. **CPU Usage**: CPU utilization with performance zones +3. **Disk I/O**: Read/write activity per operation phase +4. **Phase Duration**: Timing analysis with duration labels + +#### Sample Output + +![Sample Memory Profile](memory_profile_llama3.2.png) + +*Example memory profiling output showing QEfficient workflow phases including model loading, ONNX transforms, compilation, and generation phases with detailed memory, CPU, and disk I/O metrics.* + +## Advanced Usage + + +### Accessing Raw Data + +```python +# Get synchronized data arrays +data = profiler.get_synchronized_data() +timestamps = data['timestamps'] +memory_usage = data['rss_memory'] +cpu_usage = data['cpu_usage'] + +# Access individual samples +for sample in profiler.samples: + print(f"Time: {sample.timestamp}, RSS: {sample.rss_mb} MB") +``` + +## Integration Examples + +### With Existing QEfficient Scripts + +```python +# Add to existing QEfficient workflow +profiler = QEffMemoryProfiler(output_file="workflow_profile.png") +profiler.start_monitoring() + +# Existing QEfficient code unchanged +model = QEFFAutoModelForCausalLM.from_pretrained(model_name) +# ... rest of workflow ... 
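+# Optionally wrap major steps (load, export, compile, generate) with profiler.mark_operation("...")
+# so the report and graph are broken down per phase, as in the Quick Start above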
+ +# Add at end +report = profiler.stop_and_save() +print(report) +``` + + +## Limitations + +### Disk I/O Tracking + +**Subprocess I/O Limitation**: Disk I/O tracking captures parent process I/O only. Subprocess I/O (e.g., compilation reading ONNX files via `subprocess.run()`) is not captured due to Linux I/O accounting limitations. During compilation phases, expect lower I/O readings than actual file operations performed by subprocesses. + +## Compatibility + +- **Python**: 3.7+ +- **Dependencies**: `psutil`, `matplotlib`, `numpy` diff --git a/scripts/memory_profiling/__init__.py b/scripts/memory_profiling/__init__.py new file mode 100644 index 000000000..dc1377d0b --- /dev/null +++ b/scripts/memory_profiling/__init__.py @@ -0,0 +1,53 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +QEfficient Memory Profiling + +A production-ready memory profiling solution specifically designed for QEfficient workflows. +Provides manual operation marking, comprehensive metrics collection, and professional visualization. + +Usage Example: + +```python +from scripts.memory_profiling import QEffMemoryProfiler + +profiler = QEffMemoryProfiler(verbose=True) +profiler.start_monitoring() +# ... your QEfficient code ... +profiler.stop_monitoring() +print(profiler.get_memory_report()) +profiler.generate_memory_graph() +``` +""" + +__version__ = "2.0.0" +__author__ = "Qualcomm Technologies, Inc." + +# Core profiler components +from .profiler import ( + MetricsCollector, + ProfilerConfig, + ProfileSample, + QEffMemoryProfiler, +) + +# Visualization component (imported on-demand) +try: + from .visualizer import QEffMemoryVisualizer +except ImportError: + # Handle case where matplotlib is not available + QEffMemoryVisualizer = None + +__all__ = [ + "QEffMemoryProfiler", + "ProfilerConfig", + "ProfileSample", + "MetricsCollector", + "QEffMemoryVisualizer", + "__version__", +] diff --git a/scripts/memory_profiling/memory_profile_llama3.2.png b/scripts/memory_profiling/memory_profile_llama3.2.png new file mode 100644 index 000000000..e91c1d04a Binary files /dev/null and b/scripts/memory_profiling/memory_profile_llama3.2.png differ diff --git a/scripts/memory_profiling/profiler.py b/scripts/memory_profiling/profiler.py new file mode 100644 index 000000000..cfd53e4d7 --- /dev/null +++ b/scripts/memory_profiling/profiler.py @@ -0,0 +1,729 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +QEfficient Memory Profiler - Production-Ready Memory Monitoring + +This module provides comprehensive memory profiling capabilities specifically +designed for QEfficient workflows. 
+""" + +import os +import threading +import time +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, List, Optional, Tuple + +import psutil + +from QEfficient.utils.logging_utils import logger + + +@dataclass +class ProfilerConfig: + """Configuration for memory profiler.""" + + sampling_interval: float = 0.2 + output_file: Optional[str] = None + verbose: bool = False + enable_cpu_monitoring: bool = True + enable_disk_monitoring: bool = True + track_child_processes: bool = True + child_scan_interval: float = 1.0 + + +@dataclass +class ProfileSample: + """Single profiling sample containing all metrics.""" + + timestamp: datetime + rss_mb: float + vms_mb: float + cpu_percent: float = 0.0 + disk_read_mb: float = 0.0 + disk_write_mb: float = 0.0 + disk_read_rate: float = 0.0 + disk_write_rate: float = 0.0 + + +class MetricsCollector: + """Handles collection of system metrics with child process support.""" + + def __init__(self, config: ProfilerConfig): + self.config = config + self.process = psutil.Process(os.getpid()) + self._last_disk_counters = None + self._last_disk_time = None + self._cpu_initialized = False + self._last_cpu_ema = 0.0 + self._cpu_ema_alpha = 0.3 + + # Child process tracking + self._track_children = config.track_child_processes + self._child_processes: Dict[int, psutil.Process] = {} + self._last_child_scan = 0.0 + self._child_scan_interval = config.child_scan_interval + self._child_cpu_cache: Dict[int, float] = {} + + if self._track_children and self.config.verbose: + logger.info("Child process tracking enabled") + + def initialize_cpu_monitoring(self) -> None: + """Initialize CPU monitoring.""" + try: + self.process.cpu_percent() # First call to establish baseline + self._cpu_initialized = True + + # Initialize child process CPU monitoring + if self._track_children: + self._update_child_processes() + for child_proc in self._child_processes.values(): + try: + child_proc.cpu_percent() # Initialize baseline for children + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + if self.config.verbose: + logger.info("CPU measurement initialized") + except Exception as e: + if self.config.verbose: + logger.warning(f"CPU initialization warning: {e}") + self._cpu_initialized = False + + def _update_child_processes(self) -> None: + """Discover and track child processes (compilation subprocesses).""" + current_time = time.time() + # Only scan for children if we don't have any, or every 5 seconds + scan_interval = 5.0 if self._child_processes else self._child_scan_interval + if current_time - self._last_child_scan < scan_interval: + return + + try: + # Get current children (recursive to catch subprocess chains) + children = self.process.children(recursive=True) + + # Add new children + new_children_count = 0 + for child in children: + if child.pid not in self._child_processes: + try: + # Verify child is still running and accessible + if child.is_running(): + self._child_processes[child.pid] = child + self._child_cpu_cache[child.pid] = 0.0 + + # Initialize CPU monitoring for new child + try: + child.cpu_percent() # First call to establish baseline + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass # Child may have terminated quickly + + new_children_count += 1 + + if self.config.verbose: + try: + cmd_name = child.name() + logger.info(f"Tracking new subprocess: PID {child.pid} ({cmd_name})") + except (psutil.NoSuchProcess, psutil.AccessDenied): + logger.info(f"Tracking new subprocess: PID {child.pid}") + except 
(psutil.NoSuchProcess, psutil.AccessDenied): + continue + + # Remove terminated children + terminated_pids = [] + for pid, proc in self._child_processes.items(): + try: + if not proc.is_running(): + terminated_pids.append(pid) + except (psutil.NoSuchProcess, psutil.AccessDenied): + terminated_pids.append(pid) + + for pid in terminated_pids: + if pid in self._child_processes: + del self._child_processes[pid] + if pid in self._child_cpu_cache: + del self._child_cpu_cache[pid] + if self.config.verbose: + logger.info(f"Removed terminated subprocess: PID {pid}") + + if new_children_count > 0 and self.config.verbose: + logger.info(f"Now tracking {len(self._child_processes)} child processes") + + except Exception as e: + if self.config.verbose: + logger.warning(f"Child process scan error: {e}") + + self._last_child_scan = current_time + + def get_memory_usage(self) -> Tuple[float, float]: + """Get current memory usage in MB (parent + children).""" + try: + # Parent process memory + mem_info = self.process.memory_info() + total_rss = mem_info.rss / 1024 / 1024 + total_vms = mem_info.vms / 1024 / 1024 + + # Add child process memory (if tracking enabled) + if self._track_children: + child_rss = 0.0 + child_vms = 0.0 + active_children = 0 + stale_children = [] + + # Iterate through current child processes + for pid, child_proc in self._child_processes.items(): + try: + child_mem = child_proc.memory_info() + child_rss += child_mem.rss / 1024 / 1024 + child_vms += child_mem.vms / 1024 / 1024 + active_children += 1 + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + # Mark child as stale for cleanup + stale_children.append(pid) + continue + + # Clean up stale children (don't do this during iteration) + for pid in stale_children: + if pid in self._child_processes: + del self._child_processes[pid] + if pid in self._child_cpu_cache: + del self._child_cpu_cache[pid] + + total_rss += child_rss + total_vms += child_vms + + if self.config.verbose and active_children > 0: + logger.debug( + f"Memory: Parent {mem_info.rss / 1024 / 1024:.1f}MB + " + f"Children {child_rss:.1f}MB = Total {total_rss:.1f}MB RSS" + ) + + return total_rss, total_vms + except Exception as e: + if self.config.verbose: + logger.warning(f"Memory collection error: {e}") + return 0.0, 0.0 + + def get_cpu_usage(self) -> float: + """Get CPU usage with child processes included and smoothing.""" + if not self.config.enable_cpu_monitoring: + return 0.0 + + try: + import multiprocessing + + num_cores = multiprocessing.cpu_count() + + parent_cpu_raw = 0.0 + child_cpu_raw_total = 0.0 + + # Parent CPU (raw percentage, can be >100% on multi-core) + if self._cpu_initialized: + parent_cpu_raw = self.process.cpu_percent() + if parent_cpu_raw < 0: + parent_cpu_raw = 0.0 + + # Child CPU (if tracking enabled) + if self._track_children: + active_children = 0 + + for pid, child_proc in list(self._child_processes.items()): + try: + child_cpu_raw = child_proc.cpu_percent() + if child_cpu_raw >= 0: + # Cache raw CPU value + self._child_cpu_cache[pid] = child_cpu_raw + child_cpu_raw_total += child_cpu_raw + active_children += 1 + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + # Use cached value if available, otherwise skip + if pid in self._child_cpu_cache: + child_cpu_raw_total += self._child_cpu_cache[pid] + continue + + if self.config.verbose and active_children > 0: + # Convert to system-wide percentage for logging + parent_system_pct = parent_cpu_raw / num_cores + child_system_pct = child_cpu_raw_total / 
num_cores + logger.debug( + f"CPU: Parent {parent_system_pct:.1f}% + " + f"Children {child_system_pct:.1f}% (from {active_children} processes) " + f"= {parent_system_pct + child_system_pct:.1f}% system-wide" + ) + + # Calculate system-wide CPU percentage + # psutil.Process.cpu_percent() returns per-process CPU time percentage + # To get system-wide percentage: divide by number of cores + total_process_cpu = parent_cpu_raw + child_cpu_raw_total + system_wide_cpu = total_process_cpu / num_cores + + # Cap at 100% (shouldn't exceed this in normal cases) + system_wide_cpu = min(system_wide_cpu, 100.0) + + # Apply exponential moving average smoothing + if system_wide_cpu > 0 or self._last_cpu_ema > 0: + smoothed_cpu = self._cpu_ema_alpha * system_wide_cpu + (1 - self._cpu_ema_alpha) * self._last_cpu_ema + self._last_cpu_ema = smoothed_cpu + return smoothed_cpu + + return 0.0 + except Exception as e: + if self.config.verbose: + logger.warning(f"CPU collection error: {e}") + return self._last_cpu_ema + + def get_disk_io_stats(self) -> Tuple[float, float, float, float]: + """Get disk I/O statistics with rate calculation (parent + children).""" + if not self.config.enable_disk_monitoring: + return 0.0, 0.0, 0.0, 0.0 + + try: + current_time = time.time() + + # Parent process I/O + parent_io = self.process.io_counters() + + # Determine which counters to use + use_chars = hasattr(parent_io, "read_chars") and hasattr(parent_io, "write_chars") + + if use_chars: + total_read_bytes = parent_io.read_chars + total_write_bytes = parent_io.write_chars + else: + total_read_bytes = parent_io.read_bytes + total_write_bytes = parent_io.write_bytes + + # Add child process I/O (if tracking enabled) + if self._track_children: + child_read_total = 0 + child_write_total = 0 + active_io_children = 0 + + for pid, child_proc in list(self._child_processes.items()): + try: + child_io = child_proc.io_counters() + if use_chars and hasattr(child_io, "read_chars") and hasattr(child_io, "write_chars"): + child_read_total += child_io.read_chars + child_write_total += child_io.write_chars + else: + child_read_total += child_io.read_bytes + child_write_total += child_io.write_bytes + active_io_children += 1 + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + # Child process terminated or inaccessible + continue + + total_read_bytes += child_read_total + total_write_bytes += child_write_total + + if self.config.verbose and active_io_children > 0: + parent_read_mb = ( + parent_io.read_chars / 1024 / 1024 if use_chars else parent_io.read_bytes / 1024 / 1024 + ) + parent_write_mb = ( + parent_io.write_chars / 1024 / 1024 if use_chars else parent_io.write_bytes / 1024 / 1024 + ) + child_read_mb = child_read_total / 1024 / 1024 + child_write_mb = child_write_total / 1024 / 1024 + logger.debug( + f"Disk I/O: Parent R:{parent_read_mb:.1f}MB W:{parent_write_mb:.1f}MB + " + f"Children R:{child_read_mb:.1f}MB W:{child_write_mb:.1f}MB " + f"(from {active_io_children} processes)" + ) + + # Convert to MB + read_mb = total_read_bytes / 1024 / 1024 + write_mb = total_write_bytes / 1024 / 1024 + + # Calculate rates + read_rate = 0.0 + write_rate = 0.0 + + if self._last_disk_counters is not None and self._last_disk_time is not None: + time_delta = current_time - self._last_disk_time + if time_delta > 0: + # Calculate delta from last measurement + if use_chars: + last_read = self._last_disk_counters.get("read_chars", 0) + last_write = self._last_disk_counters.get("write_chars", 0) + else: + last_read = 
self._last_disk_counters.get("read_bytes", 0) + last_write = self._last_disk_counters.get("write_bytes", 0) + + read_delta = (total_read_bytes - last_read) / 1024 / 1024 # MB + write_delta = (total_write_bytes - last_write) / 1024 / 1024 # MB + + read_rate = read_delta / time_delta # MB/s + write_rate = write_delta / time_delta # MB/s + + # Update counters (store as dict to handle both counter types) + if use_chars: + self._last_disk_counters = {"read_chars": total_read_bytes, "write_chars": total_write_bytes} + else: + self._last_disk_counters = {"read_bytes": total_read_bytes, "write_bytes": total_write_bytes} + self._last_disk_time = current_time + + return read_mb, write_mb, read_rate, write_rate + + except Exception as e: + if self.config.verbose: + logger.warning(f"Disk I/O collection error: {e}") + return 0.0, 0.0, 0.0, 0.0 + + def collect_sample(self) -> ProfileSample: + """Collect a complete profiling sample.""" + timestamp = datetime.now() + rss_mb, vms_mb = self.get_memory_usage() + cpu_percent = self.get_cpu_usage() + read_bytes, write_bytes, read_rate, write_rate = self.get_disk_io_stats() + + return ProfileSample( + timestamp=timestamp, + rss_mb=rss_mb, + vms_mb=vms_mb, + cpu_percent=cpu_percent, + disk_read_mb=read_bytes, + disk_write_mb=write_bytes, + disk_read_rate=read_rate, + disk_write_rate=write_rate, + ) + + +class QEffMemoryProfiler: + """ + Production-ready memory profiler for QEfficient workflows. + + Features: + - Manual operation marking for QEfficient workflows + - Production-quality visualization with detailed segment analysis + - Precise memory attribution and performance metrics + - Professional-grade reporting suitable for debugging and optimization + """ + + # Segment colors for visualization + SEGMENT_COLORS = { + "Initialization": "#E8E8E8", + "Model Loading": "#FF6B6B", + "Export": "#FFEAA7", + "Model Export": "#FFEAA7", + "Compilation": "#98D8C8", + "Model Compilation": "#98D8C8", + "Generation": "#F7DC6F", + "Text Generation": "#F7DC6F", + "Cleanup": "#AED6F1", + "Completion": "#D5DBDB", + } + + def __init__( + self, sampling_interval: float = 0.05, output_file: Optional[str] = None, verbose: bool = False, **kwargs + ): + """ + Initialize the QEfficient Memory Profiler. 
+ + Args: + sampling_interval: Time between memory samples in seconds + output_file: Output file for memory profile graph + verbose: Enable verbose output for monitoring operations + """ + # Create configuration + self.config = ProfilerConfig( + sampling_interval=sampling_interval, + output_file=output_file or "qeff_memory_profile.png", + verbose=verbose, + **kwargs, + ) + + # Initialize components + self.metrics_collector = MetricsCollector(self.config) + + # Monitoring state + self.monitoring = False + self.monitor_thread = None + + # self.samples = deque(maxlen=5000) # Auto-evicts old samples + self.samples: List[ProfileSample] = [] # This could slow down for very long runs + self.operations: List[Tuple[datetime, str]] = [] + + # Peak tracking + self.peak_rss = 0.0 + self.peak_vms = 0.0 + self.peak_rss_time: Optional[datetime] = None + self.peak_vms_time: Optional[datetime] = None + self.peak_operation: Optional[str] = None + + # Operation tracking + self.current_operation = "Initialization" + self.operation_start_time = datetime.now() + self.operation_durations: Dict[str, float] = {} + self.operation_memory_deltas: Dict[str, float] = {} + + # Legacy property accessors for backward compatibility + @property + def timestamps(self) -> List[datetime]: + """Get timestamps from samples.""" + return [sample.timestamp for sample in self.samples] + + @property + def rss_memory(self) -> List[float]: + """Get RSS memory values from samples.""" + return [sample.rss_mb for sample in self.samples] + + @property + def vms_memory(self) -> List[float]: + """Get VMS memory values from samples.""" + return [sample.vms_mb for sample in self.samples] + + @property + def cpu_usage(self) -> List[float]: + """Get CPU usage values from samples.""" + return [sample.cpu_percent for sample in self.samples] + + @property + def disk_read_bytes(self) -> List[float]: + """Get disk read bytes from samples.""" + return [sample.disk_read_mb for sample in self.samples] + + @property + def disk_write_bytes(self) -> List[float]: + """Get disk write bytes from samples.""" + return [sample.disk_write_mb for sample in self.samples] + + @property + def disk_read_rate(self) -> List[float]: + """Get disk read rates from samples.""" + return [sample.disk_read_rate for sample in self.samples] + + @property + def disk_write_rate(self) -> List[float]: + """Get disk write rates from samples.""" + return [sample.disk_write_rate for sample in self.samples] + + @property + def sampling_interval(self) -> float: + """Get sampling interval.""" + return self.config.sampling_interval + + @property + def output_file(self) -> str: + """Get output file path.""" + return self.config.output_file + + @property + def verbose(self) -> bool: + """Get verbose flag.""" + return self.config.verbose + + def start_monitoring(self) -> None: + """Start continuous memory monitoring in background thread.""" + if self.monitoring: + return + + # Initialize CPU measurement + self.metrics_collector.initialize_cpu_monitoring() + + self.monitoring = True + self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) + self.monitor_thread.start() + + if self.config.verbose: + logger.info(f"QEff Memory monitoring started (sampling every {self.config.sampling_interval}s)") + + def stop_monitoring(self) -> None: + """Stop memory monitoring and generate reports.""" + if not self.monitoring: + return + + self.monitoring = False + if self.monitor_thread: + self.monitor_thread.join(timeout=1.0) + + # Mark completion + self.mark_operation("Completion") + + 
if self.config.verbose: + logger.info("QEff Memory monitoring stopped") + + def _monitor_loop(self) -> None: + """Background monitoring loop.""" + while self.monitoring: + try: + # Update child processes periodically (throttled internally) + if self.metrics_collector._track_children: + self.metrics_collector._update_child_processes() + + # Collect sample + sample = self.metrics_collector.collect_sample() + self.samples.append(sample) + + # Update peaks + self._update_peaks(sample) + + time.sleep(self.config.sampling_interval) + + except Exception as e: + if self.config.verbose: + logger.warning(f"Monitoring error: {e}") + break + + def _update_peaks(self, sample: ProfileSample) -> None: + """Update peak memory tracking.""" + if sample.rss_mb > self.peak_rss: + self.peak_rss = sample.rss_mb + self.peak_rss_time = sample.timestamp + self.peak_operation = self.current_operation + + if sample.vms_mb > self.peak_vms: + self.peak_vms = sample.vms_mb + self.peak_vms_time = sample.timestamp + + def mark_operation(self, operation_name: str) -> None: + """Mark the start of a new operation.""" + current_time = datetime.now() + current_rss = self.samples[-1].rss_mb if self.samples else 0.0 + + # Record previous operation duration and memory delta + if self.current_operation != "Initialization" and self.samples: + duration = (current_time - self.operation_start_time).total_seconds() + self.operation_durations[self.current_operation] = duration + + # Calculate memory delta from start of operation + start_idx = max(0, len(self.samples) - max(1, int(duration / self.config.sampling_interval))) + start_rss = self.samples[start_idx].rss_mb if start_idx < len(self.samples) else current_rss + memory_delta = current_rss - start_rss + self.operation_memory_deltas[self.current_operation] = memory_delta + + # Start new operation + self.current_operation = operation_name + self.operation_start_time = current_time + self.operations.append((current_time, operation_name)) + + if self.config.verbose: + logger.info(f"{operation_name} | Memory: {current_rss:.1f} MB RSS") + + def get_synchronized_data(self) -> Dict[str, List[float]]: + """Get synchronized data arrays.""" + if not self.samples: + return {} + + start_time = self.samples[0].timestamp + return { + "timestamps": [(s.timestamp - start_time).total_seconds() for s in self.samples], + "rss_memory": [s.rss_mb for s in self.samples], + "vms_memory": [s.vms_mb for s in self.samples], + "cpu_usage": [s.cpu_percent for s in self.samples], + "disk_read_bytes": [s.disk_read_mb for s in self.samples], + "disk_write_bytes": [s.disk_write_mb for s in self.samples], + "disk_read_rate": [s.disk_read_rate for s in self.samples], + "disk_write_rate": [s.disk_write_rate for s in self.samples], + } + + def mark_segment(self, segment_name: str) -> None: + """Convenience method for manual segment marking (API mode).""" + self.mark_operation(segment_name) + + def stop_and_save(self, filename: Optional[str] = None) -> str: + """Stop monitoring and save results (API mode convenience).""" + self.stop_monitoring() + self.generate_memory_graph(filename) + return self.get_memory_report() + + def get_memory_report(self) -> str: + """Generate comprehensive memory usage report.""" + if not self.samples: + return "No memory data collected" + + current_sample = self.samples[-1] + initial_sample = self.samples[0] + + # Calculate statistics + rss_values = [s.rss_mb for s in self.samples] + avg_rss = sum(rss_values) / len(rss_values) + max_rss = max(rss_values) + min_rss = min(rss_values) + + # 
Auto-scale units
+        rss_scale, rss_unit = (1024, "GB") if max_rss > 2048 else (1, "MB")
+
+        # Calculate disk I/O statistics
+        disk_io_stats = ""
+        if self.samples and len(self.samples) > 1:
+            total_read = current_sample.disk_read_mb - initial_sample.disk_read_mb
+            total_write = current_sample.disk_write_mb - initial_sample.disk_write_mb
+            max_read_rate = max(s.disk_read_rate for s in self.samples)
+            max_write_rate = max(s.disk_write_rate for s in self.samples)
+            avg_read_rate = sum(s.disk_read_rate for s in self.samples) / len(self.samples)
+            avg_write_rate = sum(s.disk_write_rate for s in self.samples) / len(self.samples)
+
+            disk_io_stats = f"""
+Disk I/O Statistics:
+    • Total Read: {total_read:.2f} MB
+    • Total Write: {total_write:.2f} MB
+    • Peak Read Rate: {max_read_rate:.2f} MB/s
+    • Peak Write Rate: {max_write_rate:.2f} MB/s
+    • Avg Read Rate: {avg_read_rate:.2f} MB/s
+    • Avg Write Rate: {avg_write_rate:.2f} MB/s"""
+
+        report = f"""
+QEFFICIENT PERFORMANCE MONITORING REPORT
+{"=" * 60}
+Peak Memory Usage:
+    • RSS (Physical): {self.peak_rss / rss_scale:.2f} {rss_unit} at {self.peak_rss_time.strftime("%H:%M:%S") if self.peak_rss_time else "N/A"}
+    • VMS (Virtual): {self.peak_vms / rss_scale:.2f} {rss_unit} at {self.peak_vms_time.strftime("%H:%M:%S") if self.peak_vms_time else "N/A"}
+    • Peak during: {self.peak_operation}
+
+Memory Statistics:
+    • Current RSS: {current_sample.rss_mb / rss_scale:.2f} {rss_unit} (Delta: {(current_sample.rss_mb - initial_sample.rss_mb) / rss_scale:+.2f} {rss_unit})
+    • Current VMS: {current_sample.vms_mb / rss_scale:.2f} {rss_unit} (Delta: {(current_sample.vms_mb - initial_sample.vms_mb) / rss_scale:+.2f} {rss_unit})
+    • Average RSS: {avg_rss / rss_scale:.2f} {rss_unit}
+    • Min/Max RSS: {min_rss / rss_scale:.2f} / {max_rss / rss_scale:.2f} {rss_unit}
+    • Memory Range: {(max_rss - min_rss) / rss_scale:.2f} {rss_unit}{disk_io_stats}
+
+Monitoring Info:
+    • Duration: {(current_sample.timestamp - initial_sample.timestamp).total_seconds():.1f} seconds
+    • Data Points: {len(self.samples)}
+    • Operations: {len(self.operations)}
+    • Sampling Interval: {self.config.sampling_interval}s
+
+QEfficient Operations Timeline:"""
+
+        # Add operation timeline
+        if self.operations:
+            start_time = self.samples[0].timestamp
+            for i, (op_time, op_name) in enumerate(self.operations):
+                relative_time = (op_time - start_time).total_seconds()
+                duration = self.operation_durations.get(op_name, 0)
+                memory_delta = self.operation_memory_deltas.get(op_name, 0)
+
+                duration_str = f"({duration:.1f}s)" if duration > 0 else ""
+                memory_str = f"[{memory_delta / rss_scale:+.1f} {rss_unit}]" if abs(memory_delta) > 10 else ""
+
+                report += f"\n    {i + 1:2d}. 
{relative_time:6.1f}s - {op_name} {duration_str} {memory_str}" + + return report + + def generate_memory_graph(self, filename: Optional[str] = None) -> None: + """Generate professional memory usage graph with QEfficient operation segments.""" + if not self.samples: + logger.warning("No data to plot") + return + + output_file = filename or self.config.output_file + + # Import visualization module + from visualizer import QEffMemoryVisualizer + + visualizer = QEffMemoryVisualizer(self) + visualizer.generate_professional_graph(output_file) + + if self.config.verbose: + logger.info(f"QEfficient memory profile saved as: {output_file}") + + # Legacy methods for backward compatibility + def get_memory_usage(self) -> Tuple[float, float]: + """Get current memory usage in MB (legacy method).""" + return self.metrics_collector.get_memory_usage() diff --git a/scripts/memory_profiling/visualizer.py b/scripts/memory_profiling/visualizer.py new file mode 100644 index 000000000..c16c0c0ef --- /dev/null +++ b/scripts/memory_profiling/visualizer.py @@ -0,0 +1,604 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +QEfficient Memory Visualizer - Production Quality Enhanced Visualization + +This module provides production-quality visualization with detailed segment analysis, +clear operation boundaries, and comprehensive memory metrics. +""" + +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np + +if TYPE_CHECKING: + from .profiler import QEffMemoryProfiler + +from QEfficient.utils.logging_utils import logger + + +class QEffMemoryVisualizer: + """Production-quality memory visualization with enhanced segment analysis.""" + + def __init__(self, profiler: "QEffMemoryProfiler"): + """Initialize visualizer with profiler data.""" + self.profiler = profiler + self._setup_matplotlib_style() + + def _setup_matplotlib_style(self) -> None: + """Configure matplotlib for professional styling.""" + plt.style.use("default") + plt.rcParams.update( + { + "font.size": 10, + "font.family": ["DejaVu Sans", "sans-serif"], + "axes.linewidth": 1.2, + "figure.facecolor": "white", + "axes.facecolor": "white", + "grid.alpha": 0.3, + "lines.linewidth": 2.0, + "axes.spines.top": False, + "axes.spines.right": False, + "axes.edgecolor": "#333333", + "text.color": "#333333", + "axes.labelcolor": "#333333", + "xtick.color": "#333333", + "ytick.color": "#333333", + } + ) + + def generate_professional_graph(self, filename: str) -> None: + """Generate enhanced multi-panel memory profile with synchronized visualization.""" + if not self.profiler.samples: + logger.warning("No data to plot") + return + + # Get synchronized data + sync_data = self.profiler.get_synchronized_data() + + # Create figure with professional layout - Fixed spacing to prevent title overlap + fig = plt.figure(figsize=(20, 12), facecolor="white") + gs = fig.add_gridspec( + 3, + 2, + height_ratios=[2.5, 1.8, 1.2], + width_ratios=[1, 1], + hspace=0.35, + wspace=0.2, + left=0.05, + right=0.98, + top=0.90, + bottom=0.08, + ) + + # Create subplots + ax_memory = fig.add_subplot(gs[0, :]) # Memory usage (full width) + ax_cpu = fig.add_subplot(gs[1, :]) # CPU usage (full width) + ax_disk = fig.add_subplot(gs[2, 0]) # 
Disk I/O (left) + ax_timing = fig.add_subplot(gs[2, 1]) # Phase Duration (right) + + # Prepare data + relative_times = sync_data["timestamps"] + max_rss = max(sync_data["rss_memory"]) if sync_data["rss_memory"] else 0 + use_gb = max_rss > 2048 + scale = 1024 if use_gb else 1 + unit = "GB" if use_gb else "MB" + rss_scaled = [x / scale for x in sync_data["rss_memory"]] + + # Normalize CPU usage to prevent > 100% values (multi-core issue) + normalized_cpu = [min(cpu, 100.0) for cpu in sync_data["cpu_usage"]] + + # Setup plots + self._setup_memory_plot(ax_memory, relative_times, rss_scaled, scale, unit) + self._setup_cpu_plot(ax_cpu, relative_times, normalized_cpu) + self._setup_disk_io_plot(ax_disk, sync_data) + self._setup_timing_plot(ax_timing) + + # Add main title with proper spacing + fig.suptitle( + "QEfficient Enhanced Memory & Performance Analysis - Synchronized View", + fontsize=18, + fontweight="bold", + color="#2E86AB", + y=0.95, + ) + + # Save with high quality + plt.savefig( + filename, dpi=300, bbox_inches="tight", facecolor="white", edgecolor="none", format="png", pad_inches=0.2 + ) + plt.close() + + logger.info(f"Enhanced synchronized memory profile saved: {filename}") + + def _setup_memory_plot( + self, ax, relative_times: List[float], rss_scaled: List[float], scale: float, unit: str + ) -> None: + """Setup the main memory usage plot with enhanced visualization.""" + if not relative_times or not rss_scaled: + ax.text( + 0.5, + 0.5, + "No memory data available", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="#666666", + ) + return + + start_time = self.profiler.samples[0].timestamp + + # Draw segment backgrounds + self._draw_segment_backgrounds(ax, relative_times, rss_scaled, start_time) + + # Main memory line + ax.plot( + relative_times, rss_scaled, color="#2E86AB", linewidth=3.5, label="Memory Usage (RSS)", alpha=0.9, zorder=5 + ) + ax.fill_between(relative_times, rss_scaled, alpha=0.15, color="#2E86AB", zorder=1) + + # Add segment boundaries and annotations + self._draw_segment_boundaries(ax, start_time, max(rss_scaled)) + self._mark_peak_memory(ax, start_time, scale, unit) + + # Format axes + ax.set_xlabel("Time (seconds)", fontsize=13, fontweight="bold") + ax.set_ylabel(f"Memory Usage ({unit})", fontsize=13, fontweight="bold") + ax.set_xlim(0, max(relative_times) * 1.02) + ax.set_ylim(0, max(rss_scaled) * 1.15) + ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.8, color="#CCCCCC") + ax.set_axisbelow(True) + + # Enhanced title + total_duration = relative_times[-1] if relative_times else 0 + peak_memory = max(rss_scaled) if rss_scaled else 0 + ax.set_title( + f"Memory Usage Over Time | Peak: {peak_memory:.1f} {unit} | Duration: {total_duration:.1f}s", + fontsize=14, + fontweight="bold", + color="#2E86AB", + pad=15, + ) + + # Add legend + self._add_segment_legend(ax) + + def _setup_cpu_plot(self, ax, relative_times: List[float], cpu_usage: List[float]) -> None: + """Setup CPU plot with perfect synchronization to memory plot.""" + if not relative_times or not cpu_usage or len(cpu_usage) != len(relative_times): + ax.text( + 0.5, + 0.5, + "CPU data not available or not synchronized", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="#666666", + ) + ax.set_title("CPU Usage Over Time", fontsize=14, fontweight="bold") + if relative_times: + ax.set_xlim(0, max(relative_times) * 1.02) + return + + start_time = self.profiler.samples[0].timestamp + + # Draw segment backgrounds for consistency + 
self._draw_segment_backgrounds(ax, relative_times, cpu_usage, start_time, max_val=100) + + # Main CPU line + ax.plot(relative_times, cpu_usage, color="#FF6B35", linewidth=3, label="CPU Usage", alpha=0.9, zorder=5) + ax.fill_between(relative_times, cpu_usage, alpha=0.2, color="#FF6B35", zorder=1) + + # Add segment boundaries + self._draw_segment_boundaries(ax, start_time, max(cpu_usage) if cpu_usage else 100) + + # Add average line + avg_cpu = sum(cpu_usage) / len(cpu_usage) + ax.axhline( + y=avg_cpu, + color="#E74C3C", + linestyle="-", + alpha=0.8, + linewidth=2.5, + label=f"Average: {avg_cpu:.1f}%", + zorder=4, + ) + + # Add performance zones + ax.axhspan(0, 25, alpha=0.08, color="#4CAF50", zorder=0) + ax.axhspan(25, 50, alpha=0.08, color="#FFC107", zorder=0) + ax.axhspan(50, 75, alpha=0.08, color="#FF9800", zorder=0) + ax.axhspan(75, 100, alpha=0.08, color="#F44336", zorder=0) + + # Format axes + ax.set_ylabel("CPU Usage (%)", fontsize=13, fontweight="bold") + ax.set_xlabel("Time (seconds)", fontsize=12, fontweight="bold") + ax.set_xlim(0, max(relative_times) * 1.02) + ax.set_ylim(0, max(cpu_usage) * 1.1 if cpu_usage else 100) + ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.8, color="#CCCCCC") + ax.set_axisbelow(True) + + # Enhanced title + max_cpu = max(cpu_usage) + ax.set_title( + f"CPU Usage Over Time | Peak: {max_cpu:.1f}% | Average: {avg_cpu:.1f}%", + fontsize=14, + fontweight="bold", + color="#FF6B35", + pad=15, + ) + + # Compact legend + ax.legend(loc="upper right", fontsize=10, framealpha=0.9) + + def _setup_disk_io_plot(self, ax, sync_data: Dict[str, List[float]]) -> None: + """Setup enhanced disk I/O plot showing phase-based analysis.""" + if not self.profiler.operations or len(self.profiler.operations) < 2: + ax.text( + 0.5, + 0.5, + "No operation phases available", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="#666666", + ) + ax.set_title("Disk I/O per Phase", fontsize=14, fontweight="bold") + return + + # Calculate I/O per phase + operations, read_totals, write_totals = self._calculate_io_per_phase(sync_data) + + if not operations: + ax.text( + 0.5, + 0.5, + "No significant disk I/O detected", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="#666666", + ) + ax.set_title("Disk I/O per Phase", fontsize=14, fontweight="bold") + return + + # Create enhanced bar chart + x_pos = np.arange(len(operations)) + bar_width = 0.35 + + bars_read = ax.bar( + x_pos - bar_width / 2, + read_totals, + bar_width, + label="Read (MB)", + color="#2196F3", + alpha=0.8, + edgecolor="white", + linewidth=1.5, + ) + bars_write = ax.bar( + x_pos + bar_width / 2, + write_totals, + bar_width, + label="Write (MB)", + color="#FF5722", + alpha=0.8, + edgecolor="white", + linewidth=1.5, + ) + + # Add value labels + self._add_bar_labels(ax, bars_read, bars_write, read_totals, write_totals) + + # Format axes + ax.set_ylabel("Total I/O (MB)", fontsize=12, fontweight="bold") + ax.set_xlabel("Operation Phase", fontsize=11, fontweight="bold") + ax.set_xticks(x_pos) + ax.set_xticklabels(operations, rotation=45, ha="right", fontsize=10) + + max_val = max(max(read_totals) if read_totals else [0], max(write_totals) if write_totals else [0]) + ax.set_ylim(0, max_val * 1.25 if max_val > 0 else 1) + ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5, color="#CCCCCC", axis="y") + ax.set_title("Disk I/O per Operation Phase", fontsize=14, fontweight="bold", pad=15) + ax.legend(loc="upper right", fontsize=10, framealpha=0.9) + + # Summary statistics + 
total_read = sum(read_totals) + total_write = sum(write_totals) + ax.text( + 0.02, + 0.98, + f"Total I/O: {total_read:.1f} MB read, {total_write:.1f} MB write", + transform=ax.transAxes, + fontsize=10, + va="top", + ha="left", + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.9, edgecolor="gray", linewidth=1), + ) + + def _setup_timing_plot(self, ax) -> None: + """Setup enhanced timing analysis plot.""" + operations, durations, colors = self._get_timing_data() + + if not operations: + ax.text( + 0.5, + 0.5, + "No timing data available", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + color="#666666", + ) + ax.set_title("Phase Duration Analysis", fontsize=14, fontweight="bold") + return + + # Enhanced horizontal bar chart + y_pos = np.arange(len(operations)) + bars = ax.barh(y_pos, durations, color=colors, alpha=0.8, edgecolor="white", linewidth=1.5, height=0.6) + + # Add duration labels + self._add_duration_labels(ax, bars, durations) + + # Format axes + ax.set_yticks(y_pos) + ax.set_yticklabels(operations, fontsize=11) + ax.set_xlabel("Duration (seconds)", fontsize=12, fontweight="bold") + ax.set_title("Phase Duration Analysis", fontsize=14, fontweight="bold", pad=15) + ax.grid(True, alpha=0.3, linestyle="-", linewidth=0.5, color="#CCCCCC", axis="x") + ax.set_xlim(0, max(durations) * 1.2) + + # Add total duration summary + total_duration = sum(durations) + ax.text( + 0.98, + 0.02, + f"Total: {total_duration:.1f}s", + transform=ax.transAxes, + fontsize=10, + va="bottom", + ha="right", + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.9, edgecolor="gray", linewidth=1), + ) + + def _draw_segment_backgrounds( + self, + ax, + relative_times: List[float], + data_values: List[float], + start_time: datetime, + max_val: Optional[float] = None, + ) -> None: + """Draw colored background segments for each operation.""" + if len(self.profiler.operations) < 2: + return + + max_value = max_val or (max(data_values) * 1.1 if data_values else 100) + + for i in range(len(self.profiler.operations) - 1): + op_start_time = (self.profiler.operations[i][0] - start_time).total_seconds() + op_end_time = (self.profiler.operations[i + 1][0] - start_time).total_seconds() + op_name = self.profiler.operations[i][1] + + color = self.profiler.SEGMENT_COLORS.get(op_name, "#F0F0F0") + + rect = patches.Rectangle( + (op_start_time, 0), + op_end_time - op_start_time, + max_value, + linewidth=0, + facecolor=color, + alpha=0.15, + zorder=0, + ) + ax.add_patch(rect) + + def _draw_segment_boundaries(self, ax, start_time: datetime, max_value: float) -> None: + """Draw vertical lines at segment boundaries.""" + for i, (op_time, op_name) in enumerate(self.profiler.operations): + if i == 0: + continue + + boundary_time = (op_time - start_time).total_seconds() + ax.axvline(x=boundary_time, color="#666666", linestyle="--", alpha=0.6, linewidth=2, zorder=3) + + def _mark_peak_memory(self, ax, start_time: datetime, scale: float, unit: str) -> None: + """Mark peak memory with enhanced annotation.""" + if not self.profiler.peak_rss_time: + return + + peak_time_rel = (self.profiler.peak_rss_time - start_time).total_seconds() + peak_rss_scaled = self.profiler.peak_rss / scale + + # Enhanced peak marker + ax.plot( + peak_time_rel, + peak_rss_scaled, + "o", + color="#E74C3C", + markersize=14, + markeredgecolor="white", + markeredgewidth=3, + zorder=10, + label="Peak Memory", + ) + + # Enhanced annotation + peak_text = f"Peak: {peak_rss_scaled:.1f} {unit}\nPhase: 
{self.profiler.peak_operation}" + ax.annotate( + peak_text, + xy=(peak_time_rel, peak_rss_scaled), + xytext=(25, 25), + textcoords="offset points", + bbox=dict(boxstyle="round,pad=0.6", facecolor="#E74C3C", alpha=0.95, edgecolor="white", linewidth=2), + arrowprops=dict(arrowstyle="->", color="#E74C3C", lw=2.5), + fontsize=11, + fontweight="bold", + color="white", + ha="left", + va="bottom", + zorder=15, + ) + + def _add_segment_legend(self, ax) -> None: + """Add enhanced segment legend with better styling.""" + if not self.profiler.operations: + return + + unique_ops = [] + seen_ops = set() + for _, op_name in self.profiler.operations: + if op_name not in seen_ops and op_name not in ["Initialization", "Completion"]: + unique_ops.append(op_name) + seen_ops.add(op_name) + + if not unique_ops: + return + + legend_elements = [] + for op_name in unique_ops: + color = self.profiler.SEGMENT_COLORS.get(op_name, "#666666") + duration = self.profiler.operation_durations.get(op_name, 0) + + label = f"{op_name} ({duration:.1f}s)" if duration > 0 else op_name + legend_elements.append(patches.Patch(color=color, alpha=0.8, label=label)) + + legend = ax.legend( + handles=legend_elements, + loc="upper left", + bbox_to_anchor=(1.01, 1.0), + fontsize=11, + title="QEfficient Phases", + title_fontsize=12, + framealpha=0.95, + edgecolor="#2E86AB", + fancybox=True, + ) + legend.get_frame().set_facecolor("#F8F9FA") + + def _calculate_io_per_phase(self, sync_data: Dict[str, List[float]]) -> Tuple[List[str], List[float], List[float]]: + """Calculate I/O totals per operation phase.""" + operations = [] + read_totals = [] + write_totals = [] + + valid_operations = [ + (op_time, op_name) + for op_time, op_name in self.profiler.operations + if op_name not in ["Initialization", "Completion"] + ] + + if not valid_operations: + return operations, read_totals, write_totals + + relative_times = sync_data["timestamps"] + start_time = self.profiler.samples[0].timestamp + + for i, (op_time, op_name) in enumerate(valid_operations): + op_start_time = (op_time - start_time).total_seconds() + + if i + 1 < len(valid_operations): + op_end_time = (valid_operations[i + 1][0] - start_time).total_seconds() + else: + op_end_time = max(relative_times) if relative_times else op_start_time + 1 + + # Find data indices + start_idx = next((j for j, t in enumerate(relative_times) if t >= op_start_time), 0) + end_idx = next((j for j, t in enumerate(relative_times) if t >= op_end_time), len(relative_times) - 1) + + if start_idx < len(sync_data["disk_read_bytes"]) and end_idx < len(sync_data["disk_read_bytes"]): + read_total = sync_data["disk_read_bytes"][end_idx] - sync_data["disk_read_bytes"][start_idx] + write_total = sync_data["disk_write_bytes"][end_idx] - sync_data["disk_write_bytes"][start_idx] + + if read_total > 0.01 or write_total > 0.01: + operations.append(op_name) + read_totals.append(max(0, read_total)) + write_totals.append(max(0, write_total)) + + return operations, read_totals, write_totals + + def _get_timing_data(self) -> Tuple[List[str], List[float], List[str]]: + """Get timing data for operations.""" + operations = [] + durations = [] + colors = [] + + for op_time, op_name in self.profiler.operations: + if op_name in ["Initialization", "Completion"]: + continue + duration = self.profiler.operation_durations.get(op_name, 0) + if duration > 0: + operations.append(op_name) + durations.append(duration) + colors.append(self.profiler.SEGMENT_COLORS.get(op_name, "#666666")) + + return operations, durations, colors + + def 
_add_bar_labels(self, ax, bars_read, bars_write, read_totals: List[float], write_totals: List[float]) -> None: + """Add value labels on bars.""" + max_val = max(max(read_totals) if read_totals else [0], max(write_totals) if write_totals else [0]) + + for i, (read_bar, write_bar, read_val, write_val) in enumerate( + zip(bars_read, bars_write, read_totals, write_totals) + ): + if read_val > 0.01: + ax.text( + read_bar.get_x() + read_bar.get_width() / 2, + read_bar.get_height() + max_val * 0.02, + f"{read_val:.1f}", + ha="center", + va="bottom", + fontsize=9, + fontweight="bold", + color="#2196F3", + ) + + if write_val > 0.01: + ax.text( + write_bar.get_x() + write_bar.get_width() / 2, + write_bar.get_height() + max_val * 0.02, + f"{write_val:.1f}", + ha="center", + va="bottom", + fontsize=9, + fontweight="bold", + color="#FF5722", + ) + + def _add_duration_labels(self, ax, bars, durations: List[float]) -> None: + """Add duration labels on timing bars.""" + max_duration = max(durations) + + for i, (bar, duration) in enumerate(zip(bars, durations)): + width = bar.get_width() + minutes = int(duration // 60) + seconds = duration % 60 + + if minutes > 0: + duration_text = f"{minutes}m {seconds:.1f}s" + else: + duration_text = f"{seconds:.1f}s" + + ax.text( + width + max_duration * 0.02, + bar.get_y() + bar.get_height() / 2, + duration_text, + ha="left", + va="center", + fontsize=10, + fontweight="bold", + ) diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py index e2e78105a..01cadaa5b 100644 --- a/scripts/replicate_kv_head/replicate_kv_heads.py +++ b/scripts/replicate_kv_head/replicate_kv_heads.py @@ -11,7 +11,7 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from QEfficient import QEFFAutoModelForCausalLM, export +from QEfficient import QEFFAutoModelForCausalLM from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ @@ -160,11 +160,8 @@ def replicate_kv_heads( # Export the modified model q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if full_batch_size else False)) - export( - model_name, - q_model, - tokenizer=tokenizer, - onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads", + q_model.export( + export_dir=f"{model_base_name}-{new_kv_heads}kvheads", full_batch_size=(full_batch_size if full_batch_size else None), ) diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py index d1b7a4653..f63b18f1a 100644 --- a/tests/base/test_export_memory_offload.py +++ b/tests/base/test_export_memory_offload.py @@ -27,7 +27,7 @@ @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/base/test_onnx_transforms.py b/tests/base/test_onnx_transforms.py index 7e3ec066e..25a3b15d9 100644 --- a/tests/base/test_onnx_transforms.py +++ b/tests/base/test_onnx_transforms.py @@ -8,7 +8,11 @@ import numpy as np import onnx -from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.base.onnx_transforms import ( + FP16ClipTransform, + OnnxTransformPipeline, + SplitTensorsTransform, +) def test_fp16clip_transform(): @@ -32,7 +36,9 @@ def test_fp16clip_transform(): } """) 
onnx.checker.check_model(test_onnx, True, True, True) - transformed_onnx, transformed = FP16ClipTransform.apply(test_onnx) + + onnx_transforms = OnnxTransformPipeline(transforms=[FP16ClipTransform]) + transformed_onnx, transformed = onnx_transforms.apply(test_onnx, model_name="") assert transformed assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == 65504.0 assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[1]) == 2147483647 @@ -63,7 +69,8 @@ def test_fp16clip_transform_external(tmp_path): np.array(-1e10, dtype="float32").tofile(tmp_path / external_tensors_file) onnx.checker.check_model(onnx_path, True, True, True) - transformed_onnx, transformed = FP16ClipTransform.apply(test_onnx, onnx_base_dir=str(tmp_path)) + onnx_transforms = OnnxTransformPipeline(transforms=[FP16ClipTransform]) + transformed_onnx, transformed = onnx_transforms.apply(test_onnx, model_name="", onnx_base_dir=str(tmp_path)) assert transformed assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == -65504.0 @@ -92,7 +99,8 @@ def test_split_tensors_transform(tmp_path): tensors.tofile(tmp_path / external_tensors_file) onnx.checker.check_model(onnx_path, True, True, True) - trans_onnx, transformed = SplitTensorsTransform.apply( + onnx_transforms = OnnxTransformPipeline(transforms=[SplitTensorsTransform]) + trans_onnx, transformed = onnx_transforms.apply( test_onnx, model_name="test_split", onnx_base_dir=str(tmp_path), diff --git a/tests/cloud/test_export_compile_execute.py b/tests/cloud/test_export_compile_execute.py index f1c80a6b0..c2e77578a 100644 --- a/tests/cloud/test_export_compile_execute.py +++ b/tests/cloud/test_export_compile_execute.py @@ -76,7 +76,7 @@ def check_export_compile_execute(mocker, model_name, full_batch_size=None, enabl model_name=model_name, qpc_path=qpc_path, prompt="My name is", - prompts_txt_file_path="examples/prompts.txt", + prompts_txt_file_path="examples/sample_prompts/prompts.txt", generation_len=20, full_batch_size=full_batch_size, ) diff --git a/tests/cloud/test_infer.py b/tests/cloud/test_infer.py index 9addc0a7b..e11f69017 100644 --- a/tests/cloud/test_infer.py +++ b/tests/cloud/test_infer.py @@ -24,7 +24,7 @@ def check_infer( num_cores=16, prompt=prompt, local_model_dir=None, - prompts_txt_file_path="examples/prompts.txt", + prompts_txt_file_path="examples/sample_prompts/prompts.txt", aic_enable_depth_first=True, mos=1, hf_token=None, diff --git a/tests/diffusers/diffusers_utils.py b/tests/diffusers/diffusers_utils.py new file mode 100644 index 000000000..4e407c5aa --- /dev/null +++ b/tests/diffusers/diffusers_utils.py @@ -0,0 +1,174 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Common utilities for diffusion pipeline testing. +Provides essential functions for MAD validation, image validation +hash verification, and other testing utilities. +""" + +import os +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch +from PIL import Image + + +class DiffusersTestUtils: + """Essential utilities for diffusion pipeline testing""" + + @staticmethod + def validate_image_generation( + image: Image.Image, expected_size: Tuple[int, int], min_variance: float = 1.0 + ) -> Dict[str, Any]: + """ + Validate generated image properties. 
+
+        Args:
+            image: Generated PIL Image
+            expected_size: Expected (width, height) tuple
+            min_variance: Minimum pixel variance to ensure image is not blank
+
+        Returns:
+            Dict containing validation results
+        Raises:
+            AssertionError: If image validation fails
+        """
+        # Basic image validation
+        assert isinstance(image, Image.Image), f"Expected PIL Image, got {type(image)}"
+        assert image.size == expected_size, f"Expected size {expected_size}, got {image.size}"
+        assert image.mode in ["RGB", "RGBA"], f"Unexpected image mode: {image.mode}"
+
+        # Variance check (ensure image is not blank)
+        img_array = np.array(image)
+        image_variance = float(img_array.std())
+        assert image_variance > min_variance, f"Generated image appears blank (variance: {image_variance:.2f})"
+
+        return {
+            "size": image.size,
+            "mode": image.mode,
+            "variance": image_variance,
+            "mean_pixel_value": float(img_array.mean()),
+            "min_pixel": int(img_array.min()),
+            "max_pixel": int(img_array.max()),
+            "valid": True,
+        }
+
+    @staticmethod
+    def check_file_exists(file_path: str, file_type: str = "file") -> bool:
+        """
+        Check if file exists and log result.
+        Args:
+            file_path: Path to check
+            file_type: Description of file type for logging
+        Returns:
+            bool: True if file exists
+        """
+        exists = os.path.exists(file_path)
+        print(f"file exists: {exists}; {file_type}: {file_path}")
+        return exists
+
+    @staticmethod
+    def print_test_header(title: str, config: Dict[str, Any]) -> None:
+        """
+        Print formatted test header with configuration details.
+
+        Args:
+            title: Test title
+            config: Test configuration dictionary
+        """
+        print(f"\n{'=' * 80}")
+        print(f"{title}")
+        print(f"{'=' * 80}")
+
+        if "model_setup" in config:
+            setup = config["model_setup"]
+            for k, v in setup.items():
+                print(f"{k} : {v}")
+
+        if "functional_testing" in config:
+            func = config["functional_testing"]
+            print(f"Test Prompt: {func.get('test_prompt', 'N/A')}")
+            print(f"Inference Steps: {func.get('num_inference_steps', 'N/A')}")
+            print(f"Guidance Scale: {func.get('guidance_scale', 'N/A')}")
+
+        print(f"{'=' * 80}")
+
+
+class MADValidator:
+    """Specialized class for MAD validation - always enabled, always reports, always fails on exceed"""
+
+    def __init__(self, tolerances: Dict[str, float] = None):
+        """
+        Initialize MAD validator.
+        MAD validation is always enabled, always reports values, and always fails if tolerance is exceeded.
+
+        Args:
+            tolerances: Dictionary of module_name -> tolerance mappings
+        """
+        # Fall back to an empty dict so the per-module default in .get() still applies when no tolerances are given
+        self.tolerances = tolerances or {}
+        self.results = {}
+
+    def calculate_mad(
+        self, tensor1: Union[torch.Tensor, np.ndarray], tensor2: Union[torch.Tensor, np.ndarray]
+    ) -> float:
+        """
+        Calculate Max Absolute Deviation between two tensors.
+
+        Args:
+            tensor1: First tensor (PyTorch or NumPy)
+            tensor2: Second tensor (PyTorch or NumPy)
+
+        Returns:
+            float: Maximum absolute difference between tensors
+        """
+        if isinstance(tensor1, torch.Tensor):
+            tensor1 = tensor1.detach().numpy()
+        if isinstance(tensor2, torch.Tensor):
+            tensor2 = tensor2.detach().numpy()
+
+        return float(np.max(np.abs(tensor1 - tensor2)))
+
+    def validate_module_mad(
+        self,
+        pytorch_output: Union[torch.Tensor, np.ndarray],
+        qaic_output: Union[torch.Tensor, np.ndarray],
+        module_name: str,
+        step_info: str = "",
+    ) -> bool:
+        """
+        Validate MAD for a specific module.
+        Always validates, always reports, always fails if tolerance exceeded.
+ + Args: + pytorch_output: PyTorch reference output + qaic_output: QAIC inference output + module_name: Name of the module + step_info: Additional step information for logging + + Returns: + bool: True if validation passed + + Raises: + AssertionError: If MAD exceeds tolerance + """ + mad_value = self.calculate_mad(pytorch_output, qaic_output) + + # Always report MAD value + step_str = f" {step_info}" if step_info else "" + print(f"{module_name.upper()} MAD{step_str}: {mad_value:.8f}") + + # Always validate - fail if tolerance exceeded + tolerance = self.tolerances.get(module_name, 1e-2) + if mad_value > tolerance: + raise AssertionError(f"{module_name} MAD {mad_value:.6f} exceeds tolerance {tolerance:.6f}") + + # Store result + if module_name not in self.results: + self.results[module_name] = [] + self.results[module_name].append({"mad": mad_value, "step_info": step_info, "tolerance": tolerance}) + return True diff --git a/tests/diffusers/flux_test_config.json b/tests/diffusers/flux_test_config.json new file mode 100644 index 000000000..9f13daca0 --- /dev/null +++ b/tests/diffusers/flux_test_config.json @@ -0,0 +1,123 @@ +{ + "model_setup": { + "height": 256, + "width": 256, + "num_transformer_layers": 2, + "num_single_layers": 2, + "use_onnx_subfunctions": false + }, + "mad_validation": { + "tolerances": { + "clip_text_encoder": 0.1, + "t5_text_encoder": 5.5, + "transformer": 2.0, + "vae_decoder": 1.0 + } + }, + "pipeline_params": { + "test_prompt": "A cat holding a sign that says hello world", + "num_inference_steps": 2, + "guidance_scale": 0.0, + "max_sequence_length": 256, + "validate_gen_img": true, + "min_image_variance": 1.0, + "custom_config_path": null + }, + "validation_checks": { + "image_generation": true, + "onnx_export": true, + "compilation": true + }, + "modules": + { + "text_encoder": + { + "specializations":{ + "batch_size": 1, + "seq_len": 77 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + + }, + "text_encoder_2": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + }, + "transformer": + { + "specializations": + { + "batch_size": 1, + "seq_len": 256, + "steps": 1 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts-mos": 1, + "aic-enable-depth-first": true + }, + "execute": + { + "device_ids": null + } + }, + "vae_decoder": + { + "specializations": + { + "batch_size": 1, + "channels": 16 + }, + "compilation": + { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": false, + "convert_to_fp16": true, + "aic_num_cores": 16 + }, + "execute": + { + "device_ids": null + } + } + } + +} diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py new file mode 100644 index 000000000..721850257 --- /dev/null +++ b/tests/diffusers/test_flux.py @@ -0,0 +1,448 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import pytest +import torch +from diffusers import FluxPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from QEfficient import QEffFluxPipeline +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ModulePerf, + QEffPipelineOutput, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils._utils import load_json +from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator + +# Test Configuration for 256x256 resolution with 2 layers # update mad tolerance +CONFIG_PATH = "tests/diffusers/flux_test_config.json" +INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) + + +def flux_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height: int = 256, + width: int = 256, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + parallel_compile: bool = False, + mad_tolerances: Dict[str, float] = None, +): + """ + Pipeline call function that replicates the exact flow of pipeline_flux.py.__call__() + while adding comprehensive MAD validation at each step. + + This function follows the EXACT same structure as QEffFluxPipeline.__call__() + but adds MAD validation hooks throughout the process. 
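+
+    MAD here is the maximum absolute deviation between each QAIC module's output and the
+    corresponding PyTorch reference output; validate_module_mad raises an AssertionError
+    whenever a module exceeds the tolerance configured under "mad_validation" in the test config.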
+ """ + # Initialize MAD validator + mad_validator = MADValidator(tolerances=mad_tolerances) + + device = "cpu" + + # Step 1: Load configuration, compile models + pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width) + + # Set device IDs for all modules based on configuration + set_module_device_ids(pipeline) + + # Validate all inputs + pipeline.model.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + # Set pipeline attributes + pipeline._guidance_scale = guidance_scale + pipeline._interrupt = False + batch_size = INITIAL_TEST_CONFIG["modules"]["transformer"]["specializations"]["batch_size"] + + # Step 3: Encode prompts with both text encoders + # Use pipeline's encode_prompt method + (t5_qaic_prompt_embeds, clip_qaic_pooled_prompt_embeds, text_ids, text_encoder_perf) = pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + (t5_torch_prompt_embeds, clip_torch_pooled_prompt_embeds, text_ids) = pytorch_pipeline.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + # Deactivate text encoder qpc sessions + pipeline.text_encoder.qpc_session.deactivate() + pipeline.text_encoder_2.qpc_session.deactivate() + + # MAD Validation for Text Encoders + print(" Performing MAD validation for text encoders...") + mad_validator.validate_module_mad( + clip_qaic_pooled_prompt_embeds, clip_torch_pooled_prompt_embeds, module_name="clip_text_encoder" + ) + mad_validator.validate_module_mad(t5_torch_prompt_embeds, t5_qaic_prompt_embeds, "t5_text_encoder") + + # Step 4: Prepare timesteps for denoising + timesteps, num_inference_steps = retrieve_timesteps(pipeline.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0) + pipeline._num_timesteps = len(timesteps) + + # Step 5: Prepare initial latents + num_channels_latents = pipeline.transformer.model.config.in_channels // 4 + latents, latent_image_ids = pipeline.model.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + t5_qaic_prompt_embeds.dtype, + device, + generator, + latents, + ) + + # Step 6: Initialize transformer inference session + if pipeline.transformer.qpc_session is None: + pipeline.transformer.qpc_session = QAICInferenceSession( + str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids + ) + + # Calculate compressed latent dimension (cl) for transformer buffer allocation + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + cl, _, _ = calculate_compressed_latent_dimension(height, width, pipeline.model.vae_scale_factor) + + # Allocate output buffer for transformer + output_buffer = { + "output": np.random.rand(batch_size, cl, 
pipeline.transformer.model.config.in_channels).astype(np.float32), + } + pipeline.transformer.qpc_session.set_buffers(output_buffer) + + transformer_perf = [] + pipeline.scheduler.set_begin_index(0) + + # Step 7: Denoising loop + with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if pipeline._interrupt: + continue + + # Prepare timestep embedding + timestep = t.expand(latents.shape[0]).to(latents.dtype) + temb = pipeline.transformer.model.time_text_embed(timestep, clip_qaic_pooled_prompt_embeds) + + # Compute AdaLN embeddings for dual transformer blocks + adaln_emb = [] + for block_idx in range(len(pipeline.transformer.model.transformer_blocks)): + block = pipeline.transformer.model.transformer_blocks[block_idx] + f1 = block.norm1.linear(block.norm1.silu(temb)).chunk(6, dim=1) + f2 = block.norm1_context.linear(block.norm1_context.silu(temb)).chunk(6, dim=1) + adaln_emb.append(torch.cat(list(f1) + list(f2))) + adaln_dual_emb = torch.stack(adaln_emb) + + # Compute AdaLN embeddings for single transformer blocks + adaln_emb = [] + for block_idx in range(len(pipeline.transformer.model.single_transformer_blocks)): + block = pipeline.transformer.model.single_transformer_blocks[block_idx] + f1 = block.norm.linear(block.norm.silu(temb)).chunk(3, dim=1) + adaln_emb.append(torch.cat(list(f1))) + adaln_single_emb = torch.stack(adaln_emb) + + # Compute output AdaLN embedding + temp = pipeline.transformer.model.norm_out + adaln_out = temp.linear(temp.silu(temb)) + + # Normalize timestep to [0, 1] range + timestep = timestep / 1000 + + # Prepare all inputs for transformer inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": t5_qaic_prompt_embeds.detach().numpy(), + "pooled_projections": clip_qaic_pooled_prompt_embeds.detach().numpy(), + "timestep": timestep.detach().numpy(), + "img_ids": latent_image_ids.detach().numpy(), + "txt_ids": text_ids.detach().numpy(), + "adaln_emb": adaln_dual_emb.detach().numpy(), + "adaln_single_emb": adaln_single_emb.detach().numpy(), + "adaln_out": adaln_out.detach().numpy(), + } + + # MAD Validation for Transformer - PyTorch reference inference + noise_pred_torch = pytorch_pipeline.transformer( + hidden_states=latents, + encoder_hidden_states=t5_torch_prompt_embeds, + pooled_projections=clip_torch_pooled_prompt_embeds, + timestep=torch.tensor(timestep), + img_ids=latent_image_ids, + txt_ids=text_ids, + return_dict=False, + )[0] + + # Run transformer inference and measure time + start_transformer_step_time = time.time() + outputs = pipeline.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.time() + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + + noise_pred = torch.from_numpy(outputs["output"]) + + # Transformer MAD validation + mad_validator.validate_module_mad( + noise_pred_torch.detach().cpu().numpy(), + outputs["output"], + "transformer", + f"step {i} (t={t.item():.1f})", + ) + + # Update latents using scheduler + latents_dtype = latents.dtype + latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Handle dtype mismatch + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + # Step 8: Decode latents to images + if output_type == "latent": + image = latents 
+ vae_decode_perf = 0.0 # No VAE decoding for latent output + else: + # Unpack and denormalize latents + latents = pipeline.model._unpack_latents(latents, height, width, pipeline.model.vae_scale_factor) + + # Denormalize latents + latents = (latents / pipeline.vae_decode.model.scaling_factor) + pipeline.vae_decode.model.shift_factor + # Initialize VAE decoder inference session + if pipeline.vae_decode.qpc_session is None: + pipeline.vae_decode.qpc_session = QAICInferenceSession( + str(pipeline.vae_decode.qpc_path), device_ids=pipeline.vae_decode.device_ids + ) + + # Allocate output buffer for VAE decoder + output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)} + pipeline.vae_decode.qpc_session.set_buffers(output_buffer) + + # MAD Validation for VAE + # PyTorch reference inference + image_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0] + + # Run VAE decoder inference and measure time + inputs = {"latent_sample": latents.numpy()} + start_decode_time = time.time() + image = pipeline.vae_decode.qpc_session.run(inputs) + end_decode_time = time.time() + vae_decode_perf = end_decode_time - start_decode_time + + # VAE MAD validation + mad_validator.validate_module_mad(image_torch.detach().cpu().numpy(), image["sample"], "vae_decoder") + + # Post-process image + image_tensor = torch.from_numpy(image["sample"]) + image = pipeline.model.image_processor.postprocess(image_tensor, output_type=output_type) + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="text_encoder", perf=text_encoder_perf[0]), + ModulePerf(module_name="text_encoder_2", perf=text_encoder_perf[1]), + ModulePerf(module_name="transformer", perf=transformer_perf), + ModulePerf(module_name="vae_decoder", perf=vae_decode_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=image, + ) + + +@pytest.fixture(scope="session") +def flux_pipeline(): + """Setup compiled Flux pipeline for testing""" + config = INITIAL_TEST_CONFIG["model_setup"] + + pipeline = QEffFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + use_onnx_subfunctions=config["use_onnx_subfunctions"], + ) + + # Reduce to 2 layers for testing + original_blocks = pipeline.transformer.model.transformer_blocks + org_single_blocks = pipeline.transformer.model.single_transformer_blocks + + pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"] + pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"] + pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])] + ) + pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList( + [org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])] + ) + + ### Pytorch pipeline + pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") + original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks + org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks + pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList( + [original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])] + ) + pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList( + [org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])] + ) + return pipeline, pytorch_pipeline + + 
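+
+# Note: the fixture above truncates both the QEfficient and the PyTorch reference pipelines to the
+# same reduced block counts from the test config, so the MAD comparison in the test below compares
+# like-for-like reduced models rather than the full FLUX transformer.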
+@pytest.mark.diffusion_models +@pytest.mark.on_qaic +def test_flux_pipeline(flux_pipeline): + """ + Comprehensive Flux pipeline test that follows the exact same flow as pipeline_flux.py: + - 256x256 resolution - 2 transformer layers + - MAD validation + - Functional image generation test + - Export/compilation checks + - Returns QEffPipelineOutput with performance metrics + """ + pipeline, pytorch_pipeline = flux_pipeline + config = INITIAL_TEST_CONFIG + + # Print test header + DiffusersTestUtils.print_test_header( + f"FLUX PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_transformer_layers']} Layers", + config, + ) + + # Test parameters + test_prompt = config["pipeline_params"]["test_prompt"] + num_inference_steps = config["pipeline_params"]["num_inference_steps"] + guidance_scale = config["pipeline_params"]["guidance_scale"] + max_sequence_length = config["pipeline_params"]["max_sequence_length"] + + # Generate with MAD validation + generator = torch.manual_seed(42) + start_time = time.time() + + try: + # Run the pipeline with integrated MAD validation (follows exact pipeline flow) + result = flux_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height=config["model_setup"]["height"], + width=config["model_setup"]["width"], + prompt=test_prompt, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + max_sequence_length=max_sequence_length, + custom_config_path=CONFIG_PATH, + generator=generator, + mad_tolerances=config["mad_validation"]["tolerances"], + parallel_compile=True, + return_dict=True, + ) + + execution_time = time.time() - start_time + + # Validate image generation + if config["pipeline_params"]["validate_gen_img"]: + assert result is not None, "Pipeline returned None" + assert hasattr(result, "images"), "Result missing 'images' attribute" + assert len(result.images) > 0, "No images generated" + + generated_image = result.images[0] + expected_size = (config["model_setup"]["height"], config["model_setup"]["width"]) + # Validate image properties using utilities + image_validation = DiffusersTestUtils.validate_image_generation( + generated_image, expected_size, config["pipeline_params"]["min_image_variance"] + ) + + print("\n IMAGE VALIDATION PASSED") + print(f" - Size: {image_validation['size']}") + print(f" - Mode: {image_validation['mode']}") + print(f" - Variance: {image_validation['variance']:.2f}") + print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}") + file_path = "test_flux_256x256_2layers.png" + # Save test image + generated_image.save(file_path) + + if os.path.exists(file_path): + print(f"Image saved successfully at: {file_path}") + else: + print("Image was not saved.") + + if config["validation_checks"]["onnx_export"]: + # Check if ONNX files exist (basic check) + print("\n ONNX Export Validation:") + for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "onnx_path") and module_obj.onnx_path: + DiffusersTestUtils.check_file_exists(str(module_obj.onnx_path), f"{module_name} ONNX") + + if config["validation_checks"]["compilation"]: + # Check if QPC files exist (basic check) + print("\n Compilation Validation:") + for module_name in ["text_encoder", "text_encoder_2", "transformer", "vae_decode"]: + module_obj = getattr(pipeline, module_name, None) + if module_obj and hasattr(module_obj, "qpc_path") and 
module_obj.qpc_path: + DiffusersTestUtils.check_file_exists(str(module_obj.qpc_path), f"{module_name} QPC") + + # Print test summary using utilities + print(f"\nTotal execution time: {execution_time:.4f}s") + except Exception as e: + print(f"\nTEST FAILED: {e}") + raise + + +if __name__ == "__main__": + # This allows running the test file directly for debugging + pytest.main([__file__, "-v", "-s", "-m", "flux"]) +# pytest tests/diffusers/test_flux.py -m flux -v -s --tb=short diff --git a/tests/diffusers/test_wan.py b/tests/diffusers/test_wan.py new file mode 100644 index 000000000..f11db826b --- /dev/null +++ b/tests/diffusers/test_wan.py @@ -0,0 +1,535 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Test for wan pipeline +# TODO : 1. Add pytest for call method + 2. See if we reduce height and width + 3. Keep test for Sub fn as default once sdk supports +""" + +import time +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import pytest +import safetensors.torch +import torch +from diffusers import WanPipeline +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline +from QEfficient.diffusers.pipelines.pipeline_utils import ( + ModulePerf, + QEffPipelineOutput, + calculate_latent_dimensions_with_frames, + set_module_device_ids, +) +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants +from QEfficient.utils._utils import load_json +from tests.diffusers.diffusers_utils import DiffusersTestUtils, MADValidator + +# Test Configuration for 192x320 resolution with 1 layer +CONFIG_PATH = "tests/diffusers/wan_test_config.json" +INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) + + +def wan_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height: int = 192, + width: int = 320, + num_frames: int = 81, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 2, + guidance_scale: float = 1.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + custom_config_path: Optional[str] = None, + use_onnx_subfunctions: bool = False, + parallel_compile: bool = True, + mad_tolerances: Dict[str, float] = None, +): + """ + Pipeline call function that replicates the exact flow of pipeline_wan.py.__call__() + while adding comprehensive MAD validation for transformer modules only. + + This function follows the EXACT same structure as QEffWanPipeline.__call__() + but adds MAD validation hooks for transformer testing. 
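+
+    Only the transformer outputs are MAD-validated in this test; prompt encoding and VAE
+    decoding run on the host (CPU for now), as noted in the step comments below.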
+ """ + # Initialize MAD validator + mad_validator = MADValidator(tolerances=mad_tolerances) + + device = "cpu" + + # Step 1: Compile() (export and compile) + pipeline.cl, pipeline.latent_height, pipeline.latent_width, pipeline.latent_frames = ( + calculate_latent_dimensions_with_frames( + height, + width, + num_frames, + pipeline.model.vae.config.scale_factor_spatial, + pipeline.model.vae.config.scale_factor_temporal, + pipeline.patch_height, + pipeline.patch_width, + ) + ) + pipeline.compile( + compile_config=custom_config_path, + parallel=parallel_compile, + height=height, + width=width, + num_frames=num_frames, + use_onnx_subfunctions=use_onnx_subfunctions, + ) + + set_module_device_ids(pipeline) + + # Step 2: Check inputs + pipeline.model.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % pipeline.model.vae.config.scale_factor_temporal != 1: + num_frames = ( + num_frames + // pipeline.model.vae.config.scale_factor_temporal + * pipeline.model.vae.config.scale_factor_temporal + + 1 + ) + num_frames = max(num_frames, 1) + + if pipeline.model.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + pipeline._guidance_scale = guidance_scale + pipeline._guidance_scale_2 = guidance_scale_2 + pipeline._attention_kwargs = attention_kwargs + pipeline._current_timestep = None + pipeline._interrupt = False + + # Step 3: Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Step 4: Encode input prompt(using CPU text encoder for now) + prompt_embeds, negative_prompt_embeds = pipeline.model.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=pipeline.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Get PyTorch reference prompt embeddings + # For standard WAN pipeline, CFG is determined by presence of negative prompts + do_classifier_free_guidance = negative_prompt is not None + pytorch_prompt_embeds, pytorch_negative_prompt_embeds = pytorch_pipeline.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = pipeline.transformer.model.transformer_high.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + pytorch_prompt_embeds = pytorch_prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + pytorch_negative_prompt_embeds = pytorch_negative_prompt_embeds.to(transformer_dtype) + + # Step 5: Prepare timesteps + pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = pipeline.scheduler.timesteps + + # Step 6: Prepare latent variables + num_channels_latents = pipeline.transformer.model.config.in_channels + latents = pipeline.model.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + 
num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # Step 7: Setup transformer inference session + if pipeline.transformer.qpc_session is None: + pipeline.transformer.qpc_session = QAICInferenceSession( + str(pipeline.transformer.qpc_path), device_ids=pipeline.transformer.device_ids + ) + + output_buffer = { + "output": np.random.rand( + batch_size, + pipeline.cl, + constants.WAN_DIT_OUT_CHANNELS, + ).astype(np.int32), + } + pipeline.transformer.qpc_session.set_buffers(output_buffer) + transformer_perf = [] + + # Step 8: Denoising loop with transformer MAD validation + if pipeline.model.config.boundary_ratio is not None: + boundary_timestep = pipeline.model.config.boundary_ratio * pipeline.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order + + with pipeline.model.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if pipeline._interrupt: + continue + + pipeline._current_timestep = t + + # Determine which transformer to use (high or low noise) + if boundary_timestep is None or t >= boundary_timestep: + # High-noise stage + current_model = pipeline.transformer.model.transformer_high + pytorch_current_model = pytorch_pipeline.transformer + model_type = torch.ones(1, dtype=torch.int64) + model_name = "transformer_high" + else: + # Low-noise stage + current_model = pipeline.transformer.model.transformer_low + pytorch_current_model = pytorch_pipeline.transformer_2 + model_type = torch.ones(2, dtype=torch.int64) + model_name = "transformer_low" + + latent_model_input = latents.to(transformer_dtype) + if pipeline.model.config.expand_timesteps: + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + batch_size, num_channels, num_frames, height, width = latents.shape + p_t, p_h, p_w = current_model.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Prepare transformer inputs + rotary_emb = current_model.rope(latent_model_input) + rotary_emb = torch.cat(rotary_emb, dim=0) + ts_seq_len = None + timestep = timestep.flatten() + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = current_model.condition_embedder( + timestep, prompt_embeds, encoder_hidden_states_image=None, timestep_seq_len=ts_seq_len + ) + + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # Prepare inputs for QAIC inference + inputs_aic = { + "hidden_states": latents.detach().numpy(), + "encoder_hidden_states": encoder_hidden_states.detach().numpy(), + "rotary_emb": rotary_emb.detach().numpy(), + "temb": temb.detach().numpy(), + "timestep_proj": timestep_proj.detach().numpy(), + "tsp": model_type.detach().numpy(), + } + + # PyTorch reference inference (standard WAN transformer has different signature) + noise_pred_torch = pytorch_current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=pytorch_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # QAIC inference + with current_model.cache_context("cond"): + start_transformer_step_time = time.time() + outputs = pipeline.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.time() + transformer_perf.append(end_transformer_step_time - 
start_transformer_step_time) + + hidden_states = torch.tensor(outputs["output"]) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + noise_pred = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + # Transformer MAD validation + print(f" Performing MAD validation for {model_name} at step {i}...") + mad_validator.validate_module_mad( + noise_pred_torch.detach().cpu().numpy(), + noise_pred.detach().cpu().numpy(), + model_name, + f"step {i} (t={t.item():.1f})", + ) + + # Update latents using scheduler + latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Update progress bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + + # Step 9: Decode latents to video (using CPU VAE for now) + if not output_type == "latent": + latents = latents.to(pipeline.vae_decode.dtype) + latents_mean = ( + torch.tensor(pipeline.vae_decode.config.latents_mean) + .view(1, pipeline.vae_decode.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(pipeline.vae_decode.config.latents_std).view( + 1, pipeline.vae_decode.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + video = pipeline.model.vae.decode(latents, return_dict=False)[0] + + video = pipeline.model.video_processor.postprocess_video(video.detach()) + else: + video = latents + + # Build performance metrics + perf_metrics = [ + ModulePerf(module_name="transformer", perf=transformer_perf), + ] + + return QEffPipelineOutput( + pipeline_module=perf_metrics, + images=video, + ) + + +@pytest.fixture(scope="session") +def wan_pipeline(): + """Setup compiled WAN pipeline for testing with LoRA adapters and 2 layers total""" + config = INITIAL_TEST_CONFIG["model_setup"] + + def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + # Download and load LoRA adapters + high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", + ) + low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", + ) + + # Load PyTorch reference pipeline + pytorch_pipeline = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + + # Load into the transformers + pytorch_pipeline.transformer.load_lora_adapter(load_wan_lora(high_noise_lora_path), adapter_name="high_noise") + pytorch_pipeline.transformer.set_adapters(["high_noise"], weights=[1.0]) + + pytorch_pipeline.transformer_2.load_lora_adapter(load_wan_lora(low_noise_lora_path), adapter_name="low_noise") + pytorch_pipeline.transformer_2.set_adapters(["low_noise"], weights=[1.0]) + + # ### for 2 layer model + pytorch_pipeline.transformer.config.num_layers = config["num_transformer_layers_high"] + pytorch_pipeline.transformer_2.config.num_layers = config["num_transformer_layers_low"] + original_blocks = pytorch_pipeline.transformer.blocks + org_blocks = pytorch_pipeline.transformer_2.blocks + pytorch_pipeline.transformer.blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pytorch_pipeline.transformer.config.num_layers)] + ) + pytorch_pipeline.transformer_2.blocks = 
torch.nn.ModuleList( + [org_blocks[i] for i in range(0, pytorch_pipeline.transformer_2.config.num_layers)] + ) + + # Load QEff WAN pipeline + pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + + # Load LoRA adapters into transformers + pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" + ) + pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" + ) + pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + + # Reduce to 1 layer (1 high, 1 low) for testing + pipeline.transformer.model.transformer_high.config.num_layers = config["num_transformer_layers_high"] + pipeline.transformer.model.transformer_low.config.num_layers = config["num_transformer_layers_low"] + + original_blocks_high = pipeline.transformer.model.transformer_high.blocks + original_blocks_low = pipeline.transformer.model.transformer_low.blocks + + pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( + [original_blocks_high[i] for i in range(0, config["num_transformer_layers_high"])] + ) + pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( + [original_blocks_low[i] for i in range(0, config["num_transformer_layers_low"])] + ) + + return pipeline, pytorch_pipeline + + +@pytest.mark.diffusion_models +@pytest.mark.on_qaic +@pytest.mark.wan +def test_wan_pipeline(wan_pipeline): + """ + Comprehensive WAN pipeline test that focuses on transformer validation: + - 192x320 resolution - 2 transformer layers total (1 high + 1 low) + - MAD validation for transformer modules only + - Functional video generation test + - Export/compilation checks for transformer + - Returns QEffPipelineOutput with performance metrics + """ + pipeline, pytorch_pipeline = wan_pipeline + config = INITIAL_TEST_CONFIG + + # Print test header + DiffusersTestUtils.print_test_header( + f"WAN PIPELINE TEST - {config['model_setup']['height']}x{config['model_setup']['width']} Resolution, {config['model_setup']['num_frames']} Frames, 2 Layers Total", + config, + ) + + # Test parameters + test_prompt = config["pipeline_params"]["test_prompt"] + num_inference_steps = config["pipeline_params"]["num_inference_steps"] + guidance_scale = config["pipeline_params"]["guidance_scale"] + guidance_scale_2 = config["pipeline_params"]["guidance_scale_2"] + max_sequence_length = config["pipeline_params"]["max_sequence_length"] + num_frames = config["model_setup"]["num_frames"] + + # Generate with MAD validation + generator = torch.manual_seed(42) + start_time = time.time() + + try: + # Run the pipeline with integrated MAD validation (focuses on transformer) + result = wan_pipeline_call_with_mad_validation( + pipeline, + pytorch_pipeline, + height=config["model_setup"]["height"], + width=config["model_setup"]["width"], + num_frames=num_frames, + prompt=test_prompt, + guidance_scale=guidance_scale, + guidance_scale_2=guidance_scale_2, + num_inference_steps=num_inference_steps, + max_sequence_length=max_sequence_length, + custom_config_path=CONFIG_PATH, + generator=generator, + mad_tolerances=config["mad_validation"]["tolerances"], + parallel_compile=True, + return_dict=True, + ) + + execution_time = time.time() - start_time + + # Validate video generation + if config["pipeline_params"]["validate_gen_video"]: + assert result is not None, "Pipeline 
returned None" + assert hasattr(result, "images"), "Result missing 'images' attribute" + assert len(result.images) > 0, "No video frames generated" + + generated_video = result.images[0] + assert len(generated_video) == num_frames, f"Expected {num_frames} frames, got {len(generated_video)}" + + # Validate first frame properties + first_frame = generated_video[0] + expected_size = (config["model_setup"]["width"], config["model_setup"]["height"]) + + # Convert numpy array to PIL Image if needed for validation + if isinstance(first_frame, np.ndarray): + from PIL import Image + + if first_frame.dtype != np.uint8: + first_frame = (first_frame * 255).astype(np.uint8) + if len(first_frame.shape) == 3 and first_frame.shape[0] == 3: + first_frame = first_frame.transpose(1, 2, 0) + first_frame = Image.fromarray(first_frame) + + # Validate video frame properties + frame_validation = DiffusersTestUtils.validate_image_generation( + first_frame, expected_size, config["pipeline_params"]["min_video_variance"] + ) + + print("\n VIDEO VALIDATION PASSED") + print(f" - Frame count: {len(generated_video)}") + print(f" - Frame size: {frame_validation['size']}") + print(f" - Frame mode: {frame_validation['mode']}") + print(f" - Frame variance: {frame_validation['variance']:.2f}") + print(f" - Mean pixel value: {frame_validation['mean_pixel_value']:.2f}") + + # Save result as video + frames = result.images[0] + export_to_video(frames, "test_wan_output_t2v.mp4", fps=16) + print("\n VIDEO SAVED: test_wan_output_t2v.mp4") + print(result) + + if config["validation_checks"]["onnx_export"]: + # Check if transformer ONNX file exists + print("\n ONNX Export Validation:") + if hasattr(pipeline.transformer, "onnx_path") and pipeline.transformer.onnx_path: + DiffusersTestUtils.check_file_exists(str(pipeline.transformer.onnx_path), "transformer ONNX") + + if config["validation_checks"]["compilation"]: + # Check if transformer QPC file exists + print("\n Compilation Validation:") + if hasattr(pipeline.transformer, "qpc_path") and pipeline.transformer.qpc_path: + DiffusersTestUtils.check_file_exists(str(pipeline.transformer.qpc_path), "transformer QPC") + + # Print test summary + print(f"\nTotal execution time: {execution_time:.4f}s") + print(" WAN TRANSFORMER TEST COMPLETED SUCCESSFULLY") + + except Exception as e: + print(f"\nTEST FAILED: {e}") + raise + + +if __name__ == "__main__": + # This allows running the test file directly for debugging + pytest.main([__file__, "-v", "-s", "-m", "wan"]) +# pytest tests/diffusers/test_wan.py -m wan -v -s --tb=short diff --git a/tests/diffusers/wan_test_config.json b/tests/diffusers/wan_test_config.json new file mode 100644 index 000000000..1ed36294a --- /dev/null +++ b/tests/diffusers/wan_test_config.json @@ -0,0 +1,63 @@ +{ + "model_setup": { + "height": 192, + "width": 320, + "num_frames": 81, + "num_transformer_layers_high": 1, + "num_transformer_layers_low": 1, + "use_onnx_subfunctions": false + }, + "mad_validation": { + "tolerances": { + "transformer_high": 0.3, + "transformer_low": 0.2 + } + }, + "pipeline_params": { + "test_prompt": "A cat walking in a garden", + "num_inference_steps": 2, + "guidance_scale": 1.0, + "guidance_scale_2": 1.0, + "max_sequence_length": 512, + "validate_gen_video": true, + "min_video_variance": 1.0 + }, + "validation_checks": { + "video_generation": true, + "onnx_export": true, + "compilation": true + }, + "modules": { + "transformer": { + "specializations": [ + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": 
"512", + "model_type": 1 + }, + { + "batch_size": "1", + "num_channels": "16", + "steps": "1", + "sequence_length": "512", + "model_type": 2 + } + ], + "compilation": { + "onnx_path": null, + "compile_dir": null, + "mdp_ts_num_devices": 1, + "mxfp6_matmul": true, + "convert_to_fp16": true, + "aic_num_cores": 16, + "mos": 1, + "mdts_mos": 1 + }, + "execute": { + "device_ids": null + } + } + } +} diff --git a/tests/finetune/reference_data.py b/tests/finetune/reference_data.py index a2a5438f5..c94c03b0b 100644 --- a/tests/finetune/reference_data.py +++ b/tests/finetune/reference_data.py @@ -13,206 +13,206 @@ "llama_3.2_1B_config_alpaca_single_device": { "description": "Baseline for Llama on Alpaca single-device", "train_step_losses": [ - 1.5112206935882568, - 1.2211230993270874, - 1.9942185878753662, - 2.093623161315918, - 0.9168124198913574, - 1.2125635147094727, - 0.3648962676525116, - 1.6231939792633057, - 0.8259601593017578, - 0.7741442918777466, - 1.7359141111373901, - 2.118462085723877, - 2.061161994934082, - 0.8256913423538208, - 0.8088029623031616, - 1.761340618133545, - 1.6828027963638306, - 1.3538823127746582, - 2.0672550201416016, - 3.1532647609710693, + 1.5110896825790405, + 1.2206485271453857, + 1.9950776100158691, + 2.091615676879883, + 0.9182446599006653, + 1.1993569135665894, + 0.36413607001304626, + 1.6241482496261597, + 0.8270177245140076, + 0.7749958634376526, + 1.73696768283844, + 2.120077610015869, + 2.061460256576538, + 0.8267984390258789, + 0.8105809688568115, + 1.7627557516098022, + 1.6819559335708618, + 1.3528242111206055, + 2.0654125213623047, + 3.156151294708252, ], "eval_step_losses": [ - 1.462059736251831, - 0.24527676403522491, - 1.046107292175293, - 1.6403586864471436, - 1.395291805267334, - 2.8664817810058594, - 1.035412311553955, - 1.8670039176940918, - 3.8079662322998047, - 0.6516809463500977, + 1.4607517719268799, + 0.24302150309085846, + 1.0471211671829224, + 1.642044186592102, + 1.3949533700942993, + 2.8850066661834717, + 1.0366586446762085, + 1.8661959171295166, + 3.81632924079895, + 0.6577113270759583, ], "train_step_metrics": [ - 4.532259941101074, - 3.390994071960449, - 7.34645938873291, - 8.114261627197266, - 2.5013046264648438, - 3.3620924949645996, - 1.4403645992279053, - 5.069255828857422, - 2.2840728759765625, - 2.1687355041503906, - 5.674112319946289, - 8.318334579467773, - 7.855090141296387, - 2.283458948135376, - 2.2452187538146973, - 5.820234775543213, - 5.380615711212158, - 3.872429847717285, - 7.903097629547119, - 23.412376403808594, + 4.531666278839111, + 3.389385223388672, + 7.352773189544678, + 8.09798812866211, + 2.504889488220215, + 3.3179824352264404, + 1.43927001953125, + 5.074095249176025, + 2.286489486694336, + 2.1705832481384277, + 5.680093288421631, + 8.33178424835205, + 7.857433319091797, + 2.2859883308410645, + 2.2492144107818604, + 5.828476905822754, + 5.376060962677002, + 3.8683345317840576, + 7.8885498046875, + 23.480052947998047, ], "eval_step_metrics": [ # steps 0-9 - 4.31483793258667, - 1.2779749631881714, - 2.8465487957000732, - 5.157018661499023, - 4.036152362823486, - 17.575077056884766, - 2.816267251968384, - 6.468885898590088, - 45.05870819091797, - 1.9187631607055664, + 4.309197902679443, + 1.27509605884552, + 2.8494362831115723, + 5.1657185554504395, + 4.034786224365234, + 17.9036865234375, + 2.819779396057129, + 6.463661193847656, + 45.437110900878906, + 1.9303690195083618, ], }, # Scenario 2: Single-device llama 3.2-1B training on GSM8k dataset. 
"llama_3.2_1B_config_gsm8k_single_device": { "description": "Baseline for Llama on GSM8k single-device", "train_step_losses": [ - 2.250276803970337, - 2.3231687545776367, - 1.9379945993423462, - 1.5981022119522095, - 1.9867562055587769, - 1.4573354721069336, - 1.8969658613204956, - 1.2177824974060059, - 1.6489791870117188, - 1.5380687713623047, - 1.4025083780288696, - 1.5301083326339722, - 1.6858205795288086, - 1.383747935295105, - 1.7968919277191162, - 1.4075607061386108, - 1.6447738409042358, - 1.2807793617248535, - 0.8450672030448914, - 1.5795941352844238, + 2.250361204147339, + 2.3252110481262207, + 1.9360781908035278, + 1.5984115600585938, + 1.9874038696289062, + 1.4579044580459595, + 1.8975679874420166, + 1.2175723314285278, + 1.6473736763000488, + 1.537960410118103, + 1.4019465446472168, + 1.5310447216033936, + 1.6878201961517334, + 1.3849903345108032, + 1.7976438999176025, + 1.4060133695602417, + 1.646375060081482, + 1.2835280895233154, + 0.8465587496757507, + 1.5783095359802246, ], "eval_step_losses": [ - 1.7081595659255981, - 1.719305157661438, - 1.153528094291687, - 2.0051634311676025, - 1.3372926712036133, - 1.3009852170944214, - 1.2207027673721313, - 1.3452664613723755, - 1.329830288887024, - 1.307450532913208, + 1.707140326499939, + 1.7226355075836182, + 1.1531383991241455, + 2.0035903453826904, + 1.3362350463867188, + 1.3013248443603516, + 1.2195535898208618, + 1.3454742431640625, + 1.3299248218536377, + 1.3073854446411133, ], "train_step_metrics": [ - 9.490362167358398, - 10.207969665527344, - 6.944809913635254, - 4.943641662597656, - 7.291841506958008, - 4.294501304626465, - 6.6656389236450195, - 3.3796849250793457, - 5.201667308807373, - 4.655590534210205, - 4.065384864807129, - 4.618677139282227, - 5.396877765655518, - 3.989826202392578, - 6.030873775482178, - 4.0859761238098145, - 5.179838180541992, - 3.5994436740875244, - 2.328134298324585, - 4.852985858917236, + 9.49116325378418, + 10.228837966918945, + 6.93151330947876, + 4.945170879364014, + 7.296566009521484, + 4.296945571899414, + 6.66965389251709, + 3.378974676132202, + 5.193322658538818, + 4.655086040496826, + 4.063101291656494, + 4.623003959655762, + 5.407680034637451, + 3.994786262512207, + 6.0354108810424805, + 4.0796589851379395, + 5.188138961791992, + 3.60935115814209, + 2.3316092491149902, + 4.846755504608154, ], "eval_step_metrics": [ # steps 0-9 - 5.518795013427734, - 5.580649375915527, - 3.1693549156188965, - 7.42730712890625, - 3.8087174892425537, - 3.672913074493408, - 3.38956880569458, - 3.8392088413238525, - 3.7804012298583984, - 3.6967368125915527, + 5.5131731033325195, + 5.599266052246094, + 3.1681201457977295, + 7.415632247924805, + 3.8046915531158447, + 3.674160957336426, + 3.3856759071350098, + 3.8400065898895264, + 3.7807586193084717, + 3.69649600982666, ], }, # Scenario 3: Single-device google-bert/bert-base-uncased training on IMDB dataset. 
"bert_base_uncased_config_imdb_single_device": { "description": "Baseline for google-bert/bert-base-uncased on IMDB single-device", "train_step_losses": [ - 0.357421875, - 0.546875, - 0.98486328125, - 0.35302734375, - 1.23828125, - 0.60791015625, - 0.44384765625, - 0.791015625, - 0.7861328125, - 0.51318359375, - 0.50244140625, - 0.90087890625, - 0.8818359375, - 0.86279296875, - 0.6396484375, - 0.49267578125, - 0.97119140625, - 0.7451171875, - 0.798828125, - 0.7080078125, + 0.390625, + 0.51220703125, + 0.9208984375, + 0.4052734375, + 1.1640625, + 0.6533203125, + 0.5087890625, + 0.76171875, + 0.63525390625, + 0.50146484375, + 0.5439453125, + 0.947265625, + 0.89013671875, + 0.80419921875, + 0.6533203125, + 0.4580078125, + 0.92041015625, + 0.7412109375, + 0.7197265625, + 0.62158203125, ], "eval_step_losses": [ - 0.634765625, - 0.8173828125, + 0.6044921875, + 0.798828125, 0.9072265625, - 0.7177734375, - 0.59423828125, - 0.69921875, - 0.7109375, - 0.7216796875, - 0.6064453125, - 0.7041015625, + 0.70361328125, + 0.59912109375, + 0.66357421875, + 0.6962890625, + 0.75390625, + 0.61328125, + 0.6806640625, ], "train_step_metrics": [ 1.0, 1.0, 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.5, - 0.449951171875, - 0.4091796875, + 0.49999988079071045, + 0.49999988079071045, + 0.5, + 0.5000002384185791, + 0.5000002384185791, + 0.6250002384185791, + 0.6249998807907104, + 0.625, + 0.6000000238418579, + 0.5833332538604736, + 0.5714285373687744, + 0.5714285373687744, + 0.5714285373687744, + 0.5625, + 0.555555522441864, + 0.5055557489395142, + 0.5101010203361511, ], - "eval_step_metrics": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + "eval_step_metrics": [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0], }, # Scenario 4: Distributed google-bert/bert-base-uncased training (world_size=2) "bert_base_uncased_config_imdb_distributed_ws2": { diff --git a/tests/finetune/test_finetune.py b/tests/finetune/test_finetune.py index 300ade704..dc9acf1ca 100644 --- a/tests/finetune/test_finetune.py +++ b/tests/finetune/test_finetune.py @@ -21,7 +21,7 @@ from tests.finetune import constants as constant from tests.finetune import reference_data as ref_data -alpaca_json_path = os.path.join(os.getcwd(), "alpaca_data.json") +alpaca_json_path = os.path.join(os.getcwd(), "./dataset/alpaca_data.json") def clean_up(path): @@ -34,7 +34,8 @@ def clean_up(path): def download_alpaca(): alpaca_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json" response = requests.get(alpaca_url) - + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(alpaca_json_path), exist_ok=True) with open(alpaca_json_path, "wb") as f: f.write(response.content) @@ -140,15 +141,7 @@ def assert_list_close(ref_list, actual_list, atol, name, scenario_key, current_w ] -@pytest.mark.skip() # remove when it's clear why diff val_step_loss values are observed in diff runs on existing code (even without PR #478 changes) -@pytest.mark.cli -@pytest.mark.on_qaic -@pytest.mark.finetune -@pytest.mark.parametrize( - "model_name,task_mode,max_eval_step,max_train_step,dataset_name,data_path,intermediate_step_save,context_length,run_validation,use_peft,device,scenario_key", # This parameter will be used to look up reference data - configs, -) -def test_finetune( +def train_function( model_name, task_mode, max_eval_step, @@ -211,93 +204,190 @@ def test_finetune( download_alpaca() results = finetune(**kwargs) + all_ref_metrices = { + "ref_train_losses": 
ref_train_losses, + "ref_eval_losses": ref_eval_losses, + "ref_train_metrics": ref_train_metrics, + "ref_eval_metrics": ref_eval_metrics, + } - # Assertions for step-level values using the helper function - assert_list_close( - ref_train_losses, - results["train_step_loss"], - constant.LOSS_ATOL, - "Train Step Losses", - scenario_key, - current_world_size, - current_rank, - ) - assert_list_close( - ref_eval_losses, - results["eval_step_loss"], - constant.LOSS_ATOL, - "Eval Step Losses", - scenario_key, - current_world_size, - current_rank, - ) - assert_list_close( - ref_train_metrics, - results["train_step_metric"], - constant.METRIC_ATOL, - "Train Step Metrics", - scenario_key, - current_world_size, - current_rank, - ) - assert_list_close( - ref_eval_metrics, - results["eval_step_metric"], - constant.METRIC_ATOL, - "Eval Step Metrics", + all_config_spy = { + "train_config_spy": train_config_spy, + "generate_dataset_config_spy": generate_dataset_config_spy, + "generate_peft_config_spy": generate_peft_config_spy, + "get_dataloader_kwargs_spy": get_dataloader_kwargs_spy, + "update_config_spy": update_config_spy, + "get_custom_data_collator_spy": get_custom_data_collator_spy, + "get_preprocessed_dataset_spy": get_preprocessed_dataset_spy, + "get_longest_seq_length_spy": get_longest_seq_length_spy, + "print_model_size_spy": print_model_size_spy, + "train_spy": train_spy, + "current_world_size": current_world_size, + "current_rank": current_rank, + } + return results, all_ref_metrices, all_config_spy + + +@pytest.mark.cli +@pytest.mark.on_qaic +@pytest.mark.finetune +@pytest.mark.parametrize( + "model_name,task_mode,max_eval_step,max_train_step,dataset_name,data_path,intermediate_step_save,context_length,run_validation,use_peft,device,scenario_key", # This parameter will be used to look up reference data + configs, +) +def test_finetune_functional( + model_name, + task_mode, + max_eval_step, + max_train_step, + dataset_name, + data_path, + intermediate_step_save, + context_length, + run_validation, + use_peft, + device, + scenario_key, + mocker, +): + results, all_ref_metrices, all_config_spy = train_function( + model_name, + task_mode, + max_eval_step, + max_train_step, + dataset_name, + data_path, + intermediate_step_save, + context_length, + run_validation, + use_peft, + device, scenario_key, - current_world_size, - current_rank, + mocker, ) + # Assertions for step-level values using the helper function assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds." 
- - train_config_spy.assert_called_once() - generate_dataset_config_spy.assert_called_once() + all_config_spy["train_config_spy"].assert_called_once() + all_config_spy["generate_dataset_config_spy"].assert_called_once() if task_mode == Task_Mode.GENERATION: - generate_peft_config_spy.assert_called_once() - get_longest_seq_length_spy.assert_called_once() - print_model_size_spy.assert_called_once() - train_spy.assert_called_once() - - assert update_config_spy.call_count == 1 - assert get_custom_data_collator_spy.call_count == 2 - assert get_dataloader_kwargs_spy.call_count == 2 - assert get_preprocessed_dataset_spy.call_count == 2 - - args, kwargs = train_spy.call_args + all_config_spy["generate_peft_config_spy"].assert_called_once() + all_config_spy["get_longest_seq_length_spy"].assert_called_once() + all_config_spy["print_model_size_spy"].assert_called_once() + all_config_spy["train_spy"].assert_called_once() + assert all_config_spy["update_config_spy"].call_count == 1 + assert all_config_spy["get_custom_data_collator_spy"].call_count == 2 + assert all_config_spy["get_dataloader_kwargs_spy"].call_count == 2 + assert all_config_spy["get_preprocessed_dataset_spy"].call_count == 2 + args, kwargs = all_config_spy["train_spy"].call_args train_dataloader = args[2] eval_dataloader = args[3] optimizer = args[4] - batch = next(iter(train_dataloader)) assert "labels" in batch.keys() assert "input_ids" in batch.keys() assert "attention_mask" in batch.keys() - assert isinstance(optimizer, optim.AdamW) - assert isinstance(train_dataloader, DataLoader) if run_validation: assert isinstance(eval_dataloader, DataLoader) else: assert eval_dataloader is None - - args, kwargs = update_config_spy.call_args_list[0] + args, kwargs = all_config_spy["update_config_spy"].call_args_list[0] train_config = args[0] assert max_train_step >= train_config.gradient_accumulation_steps, ( "Total training step should be more than " f"{train_config.gradient_accumulation_steps} which is gradient accumulation steps." 
) - if use_peft: saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors") else: saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/model.safetensors") assert os.path.isfile(saved_file) - clean_up(train_config.output_dir) clean_up("qaic-dumps") if dataset_name == "alpaca_dataset": clean_up(alpaca_json_path) + + +@pytest.mark.skip() # remove when it's clear why diff val_step_loss values are observed in diff runs on existing code (even without PR #478 changes) +@pytest.mark.cli +@pytest.mark.on_qaic +@pytest.mark.finetune +@pytest.mark.parametrize( + "model_name,task_mode,max_eval_step,max_train_step,dataset_name,data_path,intermediate_step_save,context_length,run_validation,use_peft,device,scenario_key", # This parameter will be used to look up reference data + configs, +) +def test_finetune_assert( + model_name, + task_mode, + max_eval_step, + max_train_step, + dataset_name, + data_path, + intermediate_step_save, + context_length, + run_validation, + use_peft, + device, + scenario_key, + mocker, +): + results, all_ref_metrices, all_config_spy = train_function( + model_name, + task_mode, + max_eval_step, + max_train_step, + dataset_name, + data_path, + intermediate_step_save, + context_length, + run_validation, + use_peft, + device, + scenario_key, + mocker, + ) + + # Assertions for step-level values using the helper function + assert_list_close( + all_ref_metrices["ref_train_losses"], + results["train_step_loss"], + constant.LOSS_ATOL, + "Train Step Losses", + scenario_key, + all_config_spy["current_world_size"], + all_config_spy["current_rank"], + ) + assert_list_close( + all_ref_metrices["ref_eval_losses"], + results["eval_step_loss"], + constant.LOSS_ATOL, + "Eval Step Losses", + scenario_key, + all_config_spy["current_world_size"], + all_config_spy["current_rank"], + ) + assert_list_close( + all_ref_metrices["ref_train_metrics"], + results["train_step_metric"], + constant.METRIC_ATOL, + "Train Step Metrics", + scenario_key, + all_config_spy["current_world_size"], + all_config_spy["current_rank"], + ) + assert_list_close( + all_ref_metrices["ref_eval_metrics"], + results["eval_step_metric"], + constant.METRIC_ATOL, + "Eval Step Metrics", + scenario_key, + all_config_spy["current_world_size"], + all_config_spy["current_rank"], + ) + clean_up("qaic-dumps") + + if dataset_name == "alpaca_dataset": + clean_up(alpaca_json_path) diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 00a4216b7..46b33c60b 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( # export start = perf_counter() - qeff_model.export(export_dir=tmp_path) + onnx_path = qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_0 = end - start model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash) @@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( assert export_time_1 < export_time_0 # test compile - qeff_model.compile(prefill_seq_len=32, ctx_len=64) + qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64) assert Path(qeff_model.qpc_path).is_dir() assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py index cc94467db..c3bb2f140 100644 --- a/tests/peft/test_peft_model.py +++ b/tests/peft/test_peft_model.py @@ 
-178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path): _, lora_model = create_peft_model(base_config, adapter_config) qeff_model = QEffAutoPeftModelForCausalLM(lora_model) - qeff_model.export(tmp_path) + onnx_path = qeff_model.export(tmp_path) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_0 = end - start @@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con ) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_1 = end - start assert compile_time_1 < 0.01 * compile_time_0 diff --git a/tests/transformers/models/image_text_to_text/test_continuous_batching.py b/tests/transformers/models/image_text_to_text/test_continuous_batching.py new file mode 100644 index 000000000..2f33b7ee8 --- /dev/null +++ b/tests/transformers/models/image_text_to_text/test_continuous_batching.py @@ -0,0 +1,720 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import pytest +import requests +from PIL import Image +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, + AutoTokenizer, + GenerationConfig, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText +from QEfficient.utils import hf_download +from QEfficient.utils._utils import get_num_layers_vlm +from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.run_utils import ApiRunnerInternVL, ApiRunnerMolmo, ApiRunnerVlm +from QEfficient.utils.test_utils import InternProcessor + +NEW_GENERATION_TOKENS = 10 + +# TODO: Add CB support for kv_offload=False case +test_models_config = [ + # CONFIG PARAMS NEEDED FOR A MODEL TO BE TESTED + # ( + # model_name, + # kv_offload, + # batch_size, + # prompt_len, + # ctx_len, + # img_size, + # img_url_list", + # text_prompt_list, + # number of layers of the model, + # full_batch_size + # ), + ( + "llava-hf/llava-1.5-7b-hf", + True, + 1, + 784, + 1024, + 336, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 1, + 4, + ), + # Disabled in CI due to performance issues + # ( + # 
"meta-llama/Llama-4-Scout-17B-16E-Instruct", + # True, + # 1, + # 128, + # 3072, + # 336, + # ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",], + # ["Can you describe the image in detail?", + # "What are the objects in the image?", + # "What is the main subject of the image?", + # "What colors are predominant in the image?"], + # 4, + # 4, + # ), + ( + "google/gemma-3-4b-it", + True, + 1, + 128, + 3072, + 896, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 1, + 4, + ), + ( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + True, + 1, + 128, + 4096, + 1540, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 1, + 4, + ), + ( + "Qwen/Qwen2.5-VL-3B-Instruct", + True, + 1, + 128, + 4096, + 1540, + [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 2, + 4, + ), + # ( + # "meta-llama/Llama-3.2-11B-Vision-Instruct", + # True, + # 1, + # 32, + # 512, + # 560, + # ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",], + # ["Can you describe the image in detail?", + # "What are the objects in the image?", + # "What 
is the main subject of the image?", + # "What colors are predominant in the image?"], + # 7, + # 4, + # ), +] + +intern_model_config = [ + ( + "OpenGVLab/InternVL2_5-1B", + True, + 1, + 384, + 512, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 2, + 4, + ), + ( + "OpenGVLab/InternVL3_5-1B", + True, + 1, + 384, + 512, + [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + ], + [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", + ], + 2, + 4, + ), +] + +molmo_model_config = [ + # Disabled in CI due to HF issues + # ( + # "allenai/Molmo-7B-D-0924", + # True, + # 1, + # 128, + # 4096, + # ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + # "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",], + # ["Can you describe the image in detail?", + # "What are the objects in the image?", + # "What is the main subject of the image?", + # "What colors are predominant in the image?"], + # 2, + # 4, + # ), +] + + +def load_image_text_to_text_model(model_config): + model_path = hf_download( + repo_id=model_config._name_or_path, + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + try: + model_hf = AutoModelForImageTextToText.from_pretrained( + model_path, + low_cpu_mem_usage=False, + config=model_config, + ) + except ValueError: + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=model_config, + ) + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def set_num_layers(config, n_layer=1): + ## -1 indicates use all the layers of the model. 
+ if n_layer == -1: + return config + elif hasattr(config, "model_type") and "mllama" in config.model_type: + config.text_config.num_hidden_layers = n_layer + config.text_config.cross_attention_layers = [ + x for x in config.text_config.cross_attention_layers if x < n_layer + ] + elif hasattr(config, "text_config"): + config.text_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + elif hasattr(config, "llm_config"): + config.llm_config.num_hidden_layers = n_layer + config.vision_config.num_hidden_layers = n_layer + else: + config.num_hidden_layers = n_layer + return config + + +def check_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name: str, + img_size: int, + image_urls: List[str], + queries: List[str], + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + num_devices: int = 1, + full_batch_size: int = 4, + kv_offload: bool = True, +): + model_config = {"model_name": model_name} + model_config["img_size"] = img_size + config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config = set_num_layers(config, n_layer=n_layer) + model_hf, _ = load_image_text_to_text_model(config) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + + n_layer = get_num_layers_vlm(config) + + image_height = None + image_width = None + + images = [] + for img_url in image_urls: + image = Image.open(requests.get(img_url, stream=True).raw) + if model_name == "mistralai/Mistral-Small-3.1-24B-Instruct-2503": + image_height = 1540 + image_width = 1540 + image = image.resize((image_height, image_width)) + images.append(image) + + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": queries[0]}, + {"type": "image"}, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + api_runner = ApiRunnerVlm( + batch_size, + processor, + config, + images[0], + conversation, + prompt, + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + # For same prompt + image_list = [images[0]] * full_batch_size + prompt_list = [queries[0]] * full_batch_size + + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, image_list, prompt_list) + + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_config["model_name"], + kv_offload=kv_offload, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + qeff_model.compile( + img_size=model_config["img_size"], + num_cores=16, + num_devices=num_devices, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + batch_size=batch_size, + full_batch_size=full_batch_size, + mxfp6_matmul=False, + ) + + print("QPC Outputs (QAIC):") + exec_info = qeff_model.generate( + tokenizer=processor.tokenizer, + processor=processor, + images=[image_urls[0]] * full_batch_size, + prompts=prompt_list, + generation_len=max_gen_len, + image_height=image_height, + image_width=image_width, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching with same prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for same prompts" + ) + + # For different prompts + pytorch_hf_tokens = 
api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries) + + print("QPC Outputs (QAIC):") + exec_info = qeff_model.generate( + tokenizer=processor.tokenizer, + processor=processor, + images=image_urls, + prompts=queries, + generation_len=max_gen_len, + image_height=image_height, + image_width=image_width, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching with different prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for different prompts" + ) + return + + +def check_molmo_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name: str, + image_urls: List[str], + queries: List[str], + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + num_devices: int = 1, + full_batch_size: int = 4, + kv_offload: bool = True, +): + model_config = {"model_name": model_name} + + config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config._attn_implementation = "eager" + config = set_num_layers(config, n_layer=n_layer) + model_hf, _ = load_image_text_to_text_model(config) + n_layer = (n_layer, n_layer) + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + images = [] + for img_url in image_urls: + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + image = image.resize((536, 354)) + images.append(image) + + api_runner = ApiRunnerMolmo( + batch_size, + processor, + config, + images[0], + queries[0], + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + generation_config = GenerationConfig(max_new_tokens=NEW_GENERATION_TOKENS, stop_strings="<|endoftext|>") + + # For same prompt + image_list = [images[0]] * full_batch_size + prompt_list = [queries[0]] * full_batch_size + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, image_list, prompt_list, generation_config) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + trust_remote_code=True, + attn_implementation="eager", + kv_offload=kv_offload, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_devices=4, + batch_size=1, + full_batch_size=full_batch_size, + mxfp6_matmul=False, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=[image_urls[0]] * full_batch_size, + prompts=prompt_list, + generation_len=max_gen_len, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching with same prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for same prompts" + ) + + # For different prompts + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries, generation_config) + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=image_urls, + prompts=queries, + generation_len=max_gen_len, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + 
print("QPC Outputs (QAIC) for Continuous Batching with different prompt:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for different prompts" + ) + return + + +def check_intern_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name: str, + image_urls: str, + queries: str, + prompt_len: int, + ctx_len: int, + max_gen_len: int = 20, + batch_size: int = 1, + n_layer: int = 1, + kv_offload: bool = True, + num_devices: int = 1, + full_batch_size: int = 4, +): + model_config = {"model_name": model_name} + + config = AutoConfig.from_pretrained(model_config["model_name"], trust_remote_code=True) + config._attn_implementation = "eager" + config = set_num_layers(config, n_layer=n_layer) + model_hf = AutoModelForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=config, + ) + n_layer = get_num_layers_vlm(config) + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + processor = InternProcessor(model_hf, tokenizer) + + generation_config = dict(max_new_tokens=max_gen_len, do_sample=False) + generation_config["eos_token_id"] = tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip()) + + images = [] + for img_url in image_urls: + img = requests.get(img_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + image = image.resize((448, 448)) + images.append(image) + + api_runner = ApiRunnerInternVL( + batch_size, + processor, + config, + images[0], + queries[0], + prompt_len, + ctx_len, + max_gen_len, + n_layer, + ) + + # For same prompt + image_list = [images[0]] * full_batch_size + prompt_list = [queries[0]] * full_batch_size + + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, image_list, prompt_list) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + trust_remote_code=True, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + ) + + qeff_model.export() + + qeff_model.compile( + num_patches=1, + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_devices=4, + batch_size=1, + full_batch_size=full_batch_size, + mxfp6_matmul=False, + ) + + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=[image_urls[0]] * full_batch_size, + prompts=prompt_list, + generation_len=max_gen_len, + image_height=448, + image_width=448, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching for same prompts:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for same prompts" + ) + + # For different prompts + pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries) + + exec_info = qeff_model.generate( + tokenizer=tokenizer, + processor=processor, + images=image_urls, + prompts=queries, + generation_len=max_gen_len, + image_height=448, + image_width=448, + ) + + qpc_tokens = exec_info.generated_ids[:, :max_gen_len] + print("QPC Outputs (QAIC) for Continuous Batching for different prompts:") + print(exec_info.generated_texts) + + for i in range(full_batch_size): + assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), ( + f"Tokens don't match for prompt {i} between HF and QPC output for different 
prompts" + ) + return + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_urls, queries, n_layer, full_batch_size", + test_models_config, +) +def test_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_urls, queries, n_layer, full_batch_size +): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + check_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + img_size=img_size, + image_urls=img_urls, + queries=queries, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + full_batch_size=full_batch_size, + ) + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_urls, queries, n_layer, full_batch_size", + molmo_model_config, +) +def test_image_text_to_text_molmo_pytorch_vs_ai100_continuous_batching( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_urls, queries, n_layer, full_batch_size +): + check_molmo_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + image_urls=img_urls, + queries=queries, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + full_batch_size=full_batch_size, + ) + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.parametrize( + "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, queries, n_layer, full_batch_size", + intern_model_config, +) +def test_image_text_to_text_intern_pytorch_vs_ai100_continuous_batching( + model_name, kv_offload, batch_size, prompt_len, ctx_len, img_url, queries, n_layer, full_batch_size +): + check_intern_image_text_to_text_pytorch_vs_ai100_continuous_batching( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + max_gen_len=NEW_GENERATION_TOKENS, + image_urls=img_url, + queries=queries, + n_layer=n_layer, + batch_size=batch_size, + kv_offload=kv_offload, + full_batch_size=full_batch_size, + ) diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py similarity index 97% rename from tests/transformers/models/test_image_text_to_text_models.py rename to tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py index a7b4162aa..e6a145195 100644 --- a/tests/transformers/models/test_image_text_to_text_models.py +++ b/tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py @@ -134,6 +134,17 @@ "Can you describe the image in detail.", 1, ), + ( + "Qwen/Qwen2.5-VL-3B-Instruct", + True, + 1, + 128, + 4096, + 1540, + "https://picsum.photos/id/237/536/354", + "Can you describe the image in detail.", + 1, + ), # ( # "meta-llama/Llama-3.2-11B-Vision-Instruct", # True, @@ -320,6 +331,10 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100( qnn_config=qnn_config, ) inputs = processor(images=image, text=prompt, return_tensors="pt") + if hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl": 
+ inputs = qeff_model.model.prepare_inputs_for_generation( + inputs=inputs, prefill_seq_len=prompt_len, batch_size=batch_size + ) if "pixel_values" in inputs: inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) print("QPC Outputs (QAIC):") diff --git a/tests/transformers/models/qnn_config.json b/tests/transformers/models/qnn_config.json new file mode 100644 index 000000000..b1f249e2b --- /dev/null +++ b/tests/transformers/models/qnn_config.json @@ -0,0 +1,10 @@ +{ + "SKIP_QNN_CONVERTER_STEP":false, + "context_binary_generator_args_extension":"--log_level debug", + "converter_args_extension":"--onnx_defer_loading", + "qnn_compilation_backend":{ + "compiler_enable_depth_first":true, + "compiler_printDDRStats":false, + "compiler_printPerfMetrics":false + } +} \ No newline at end of file diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 86bce4441..ead636759 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -25,6 +25,7 @@ from QEfficient.utils.test_utils import ModelConfig test_models_causal = [ + "openai/gpt-oss-20b", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "gpt2", "Salesforce/codegen-350M-mono", @@ -66,6 +67,11 @@ "Qwen/Qwen2-0.5B", ] +test_models_blockedKV = [ + # "meta-llama/Llama-3.3-70B-Instruct", + "meta-llama/Llama-3.2-1B", +] + def get_custom_n_layers(model_name): """ @@ -76,11 +82,11 @@ def get_custom_n_layers(model_name): :return n_layer """ - if model_name in {"microsoft/Phi-3-mini-4k-instruct", "neuralmagic/Qwen2-0.5B-Instruct-FP8"}: + if model_name in {"microsoft/Phi-3-mini-4k-instruct", "neuralmagic/Qwen2-0.5B-Instruct-FP8", "openai/gpt-oss-20b"}: return 2 elif model_name in ModelConfig.SWIFTKV_MODELS: return None - return 16 + return 1 def load_causal_lm_model(model_name, n_layer=1, config=None): @@ -146,6 +152,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, pytorch_hf_tokens: Optional[list] = None, + qaic_config: Optional[dict] = None, ): """ Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. 
@@ -157,6 +164,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     """
     replace_transformers_quantizers()
     if config is None:
+        n_layer = get_custom_n_layers(model_name)
         model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer)
     else:
         model_hf, _ = load_causal_lm_model(model_name, config=config)
@@ -177,7 +185,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     is_tlm = False if num_speculative_tokens is None else True

     qeff_model = QEFFAutoModelForCausalLM(
-        copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name
+        copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config
     )

     pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
@@ -241,7 +249,11 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         pytorch_hf_tokens = [pytorch_hf_tokens for _ in range(full_batch_size)]

     qeff_model = QEFFAutoModelForCausalLM(
-        model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name
+        model_hf,
+        continuous_batching=True,
+        is_tlm=is_tlm,
+        pretrained_model_name_or_path=model_name,
+        qaic_config=qaic_config,
     )
     onnx_model_path = qeff_model.export()
@@ -486,3 +498,30 @@ def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100_qnn():
     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         model_name, n_layer=n_layer, prefill_only=False, enable_qnn=True, qnn_config=qnn_config_json_path
     )
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", test_models_blockedKV)
+def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
+    """
+    Test function to validate the PyTorch model with KV blocking enabled, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    n_layer = get_custom_n_layers(model_name)
+
+    qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS)
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config)
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", test_models_blockedKV)
+def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
+    """
+    Test function to validate the PyTorch model without KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    n_layer = get_custom_n_layers(model_name)
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py
new file mode 100644
index 000000000..6358940df
--- /dev/null
+++ b/tests/transformers/models/test_disagg_mode.py
@@ -0,0 +1,192 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers + +model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 + +prompt2 = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +prompt1 = "Once upon a time" + +prompts = [prompt1, prompt2] + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 256 + CTX_LEN = 256 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + qeff_out = qeff_model.model(**inputs) + + # 
Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + qpc_out = prefill_session.run(inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 + + +@pytest.mark.skip(reason="no way of currently testing this without the assert sdk") +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill_chunked(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 128 + CTX_LEN = 128 * 3 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True, enable_chunking=True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = CTX_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + + qeff_out = qeff_model.model(**chunk_inputs) + inputs["past_key_values"] = qeff_out["past_key_values"] + + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path 
= qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + ) + prefill_session = QAICInferenceSession(prefill_qpc_path) + prefill_session.skip_buffers( + [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] + ) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2 diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 9335e1d91..26cb6fda9 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -5,15 +5,18 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional, Tuple, Union import numpy as np import pytest +from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer -from QEfficient import QEFFAutoModelForCausalLM +from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import load_hf_tokenizer from QEfficient.utils.constants import Constants +from QEfficient.utils.test_utils import InternProcessor +from tests.transformers.models.image_text_to_text.test_continuous_batching import set_num_layers sampler_transform_configs = [ pytest.param( @@ -24,6 +27,20 @@ 20, # generation_len 2, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "OpenGVLab/InternVL2_5-1B", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] greedy_sampling_configs = [ @@ -35,6 +52,20 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "OpenGVLab/InternVL2_5-1B", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] random_sampling_configs = [ @@ -46,23 +77,98 @@ 20, # generation_len 4, # full_batch_size 1, # spec_length + False, # is_vlm + ), + pytest.param( + "OpenGVLab/InternVL2_5-1B", # model + ( + ["https://picsum.photos/id/237/536/354"] * 4, + ["Can you describe the image in detail."] * 4, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 4, # full_batch_size + None, # 
spec_length + True, # is_vlm + ), +] +guided_decoding_configs = [ + pytest.param( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # model + Constants.INPUT_STR * 4, # prompts + 32, # prefill_seq_len + 64, # ctx_len + 20, # generation_len + 4, # full_batch_size + 1, # spec_length + False, # is_vlm + ), + pytest.param( + "OpenGVLab/InternVL2_5-1B", # model + ( + ["https://picsum.photos/id/237/536/354"] * 2, + ["Can you describe the image in detail."] * 2, + ), # images and prompts + 128, # prefill_seq_len + 4096, # ctx_len + 20, # generation_len + 2, # full_batch_size + None, # spec_length + True, # is_vlm ), ] +def prepare_model_setup( + model: str, is_vlm: bool, num_hidden_layers: int, prompts: Union[List, Tuple], spec_length: Optional[int] +): + additional_configs = {} + additional_params = {} + if is_vlm: + config = AutoConfig.from_pretrained(model, trust_remote_code=True) + config = set_num_layers(config, n_layer=num_hidden_layers) + additional_configs["config"] = config + additional_configs["kv_offload"] = True + assert isinstance(prompts, tuple), "For VLMs, both image and text prompts must be provided." + additional_params["images"] = prompts[0] + prompts = prompts[1] + + if "InternVL" in model: + additional_configs["trust_remote_code"] = True + model_hf = AutoModelForCausalLM.from_pretrained( + model, + config=config, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True, use_fast=False) + additional_params["processor"] = InternProcessor(model_hf, tokenizer) + qeff_class = QEFFAutoModelForCausalLM + else: + additional_params["processor"] = AutoProcessor.from_pretrained(model) + qeff_class = QEFFAutoModelForImageTextToText + else: + if num_hidden_layers != -1: + additional_configs["num_hidden_layers"] = num_hidden_layers + spec_length = (spec_length or 1) - 1 + qeff_class = QEFFAutoModelForCausalLM + return additional_configs, additional_params, prompts, spec_length, qeff_class + + @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", sampler_transform_configs, ) def test_sampler_transform( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, - spec_length: int, + spec_length: Optional[int], + is_vlm: bool, ): """ Test if `SamplerTransform` adds nodes at the output of a `QEffForCausalLM model` to enable the @@ -70,48 +176,78 @@ def test_sampler_transform( next tokens and/or probability distributions. 
""" # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + num_hidden_layers = 2 + additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( + model, is_vlm, num_hidden_layers, prompts, spec_length + ) + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_w_sampler_w_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 512, + "include_guided_decoding": True, + }, + **additional_configs, + ) + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=2, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, + ) + model_w_sampler_qpc_path = model_w_sampler.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, ) - model_w_sampler_qpc_path: str = model_w_sampler.compile( + model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) - model_wo_sampler_qpc_path: str = model_wo_sampler.compile( + model_wo_sampler_qpc_path = model_wo_sampler.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) + if is_vlm: + model_w_sampler_qpc_path = model_w_sampler_qpc_path[1] + model_w_sampler_w_guided_decoding_qpc_path = model_w_sampler_w_guided_decoding_qpc_path[1] + model_wo_sampler_qpc_path = model_wo_sampler_qpc_path[1] # Init qaic session model_w_sampler_session = QAICInferenceSession(model_w_sampler_qpc_path) + model_w_sampler_w_guided_decoding_session = QAICInferenceSession(model_w_sampler_w_guided_decoding_qpc_path) model_wo_sampler_session = QAICInferenceSession(model_wo_sampler_qpc_path) # Skip inputs/outputs buffers @@ -119,6 +255,12 @@ def test_sampler_transform( model_w_sampler_session.skip_buffers( set([x for x in model_w_sampler_session.output_names if x.endswith("_RetainedState")]) ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.input_names if x.startswith("past_")]) + ) + model_w_sampler_w_guided_decoding_session.skip_buffers( + set([x for x in model_w_sampler_w_guided_decoding_session.output_names if x.endswith("_RetainedState")]) + ) model_wo_sampler_session.skip_buffers( set([x for x in model_wo_sampler_session.input_names if x.startswith("past_")]) ) @@ -132,47 +274,58 @@ def test_sampler_transform( assert input_name in model_w_sampler_session.input_names, ( f"Sampler input {input_name} not found in QPC compiled with On Device Sampler" ) + assert input_name in model_w_sampler_w_guided_decoding_session.input_names, ( + f"Sampler input {input_name} not found in QPC compiled with On Device Sampler and Guided Decoding" + 
) assert input_name not in model_wo_sampler_session.input_names, ( f"Sampler input {input_name} found in QPC compiled without On Device Sampler" ) + assert "token_bitmasks" in model_w_sampler_w_guided_decoding_session.input_names, ( + "Sampler input token_bitmasks not found in QPC compiled with On Device Sampler and Guided Decoding" + ) @pytest.mark.on_qaic @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", greedy_sampling_configs, ) def test_greedy_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, full_batch_size: int, - spec_length: int, + spec_length: Optional[int], + is_vlm: bool, ): """ - Test greedy sampling with QPC compiled with and without On Device Sampling. + Test greedy sampling with QPCs compiled with and without On Device Sampling. """ # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + num_hidden_layers = 4 + additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( + model, is_vlm, num_hidden_layers, prompts, spec_length + ) + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": True, "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, - num_hidden_layers=4, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -180,7 +333,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -190,7 +343,7 @@ def test_greedy_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -211,8 +364,9 @@ def test_greedy_sampling( "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -221,6 +375,7 @@ def test_greedy_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts and ids @@ -233,25 +388,29 @@ def test_greedy_sampling( @pytest.mark.on_qaic -@pytest.mark.skip @pytest.mark.parametrize( - "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", random_sampling_configs, ) def test_random_sampling( model: str, - prompts: List[str], + prompts: Union[List[str], tuple[List[str], List[str]]], prefill_seq_len: int, ctx_len: int, generation_len: int, 
full_batch_size: int, - spec_length: int, + spec_length: Optional[int], + is_vlm: bool, ): """ - Test random sampling with QPC compiled with and without On Device Sampling. + Test random sampling with QPCs compiled with and without On Device Sampling. """ # Export and compile QEfficient models - model_w_sampler = QEFFAutoModelForCausalLM.from_pretrained( + num_hidden_layers = -1 + additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( + model, is_vlm, num_hidden_layers, prompts, spec_length + ) + model_w_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ @@ -259,14 +418,16 @@ def test_random_sampling( "return_pdfs": False, "max_top_k_ids": 512, }, + **additional_configs, ) - model_wo_sampler = QEFFAutoModelForCausalLM.from_pretrained( + model_wo_sampler = qeff_class.from_pretrained( model, continuous_batching=True, qaic_config={ "include_sampler": False, "return_pdfs": False, }, + **additional_configs, ) model_w_sampler.compile( prefill_seq_len=prefill_seq_len, @@ -274,7 +435,7 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) @@ -284,13 +445,14 @@ def test_random_sampling( full_batch_size=full_batch_size, num_devices=1, num_cores=16, - num_speculative_tokens=spec_length - 1, + num_speculative_tokens=spec_length, mxint8_kv_cache=True, mxfp6_matmul=True, ) # Generate texts from prompts tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) model_w_sampler_exec_info = model_w_sampler.generate( tokenizer=tokenizer, prompts=prompts, @@ -301,12 +463,15 @@ def test_random_sampling( "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), - "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype( + np.float32 + ), }, + **additional_params, ) model_wo_sampler_exec_info = model_wo_sampler.generate( tokenizer=tokenizer, @@ -315,63 +480,120 @@ def test_random_sampling( include_sampler=False, return_pdfs=False, sampling_params=None, + **additional_params, ) # Compare generated texts - golden_texts = { - "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both", - "wo_sampler": "John Smith and I am a software engineer. 
I have been working in the industry for the past ", - } - golden_ids = { - "w_sampler": [ - [ - 21380, - 322, - 590, - 25448, - 2927, - 29892, - 19963, - 2654, - 29879, - 470, - 3708, - 2701, - 313, - 29902, - 508, - 30010, - 29873, - 505, - 963, - 1716, - ] - ], - "wo_sampler": [ - [ - 2259, - 7075, - 322, - 306, - 626, - 263, - 7047, - 22055, - 29889, - 306, - 505, - 1063, - 1985, - 297, - 278, - 13661, - 363, - 278, - 4940, - 29871, - ] - ], - } + if model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0": + golden_texts = { + "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over", + "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ", + } + golden_ids = { + "w_sampler": [ + [ + 319, + 3615, + 322, + 306, + 626, + 263, + 3005, + 295, + 749, + 9227, + 1058, + 12355, + 267, + 304, + 26987, + 278, + 3186, + 29889, + 2973, + 975, + ] + ], + "wo_sampler": [ + [ + 2259, + 7075, + 322, + 306, + 626, + 263, + 7047, + 22055, + 29889, + 306, + 505, + 1063, + 1985, + 297, + 278, + 13661, + 363, + 278, + 4940, + 29871, + ] + ], + } + elif model == "OpenGVLab/InternVL2_5-1B": + golden_texts = { + "w_sampler": "The description of this picture would be as follows:\n\nAn adorable black puppy is sitting on a wooden surface", + "wo_sampler": "The image features a black puppy sitting on a wooden surface. The puppy has a shiny, glossy coat", + } + golden_ids = { + "w_sampler": [ + [ + 785, + 4008, + 315, + 419, + 6802, + 1035, + 387, + 438, + 11017, + 1447, + 2082, + 40608, + 3691, + 41189, + 374, + 11699, + 389, + 264, + 22360, + 7329, + ] + ], + "wo_sampler": [ + [ + 785, + 2168, + 4419, + 264, + 3691, + 41189, + 11699, + 389, + 264, + 22360, + 7329, + 13, + 576, + 41189, + 702, + 264, + 41199, + 11, + 73056, + 22875, + ] + ], + } for i in range(full_batch_size): assert ( tokenizer.decode(model_w_sampler_exec_info.generated_ids[i][:generation_len]) == golden_texts["w_sampler"] @@ -385,3 +607,118 @@ def test_random_sampling( assert (model_wo_sampler_exec_info.generated_ids[i][:generation_len] == golden_ids["wo_sampler"]).all(), ( "Without sampler generated ids do not match" ) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize( + "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length, is_vlm", + guided_decoding_configs, +) +def test_guided_decoding( + model: str, + prompts: Union[List[str], tuple[List[str], List[str]]], + prefill_seq_len: int, + ctx_len: int, + generation_len: int, + full_batch_size: int, + spec_length: Optional[int], + is_vlm: bool, +): + """ + Test QPCs compiled with and without guided decoding. 
+ """ + # Export and compile QEfficient models + num_hidden_layers = 2 + additional_configs, additional_params, prompts, spec_length, qeff_class = prepare_model_setup( + model, is_vlm, num_hidden_layers, prompts, spec_length + ) + model_w_sampler_w_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + "include_guided_decoding": True, + }, + **additional_configs, + ) + model_w_sampler_wo_guided_decoding = qeff_class.from_pretrained( + model, + continuous_batching=True, + qaic_config={ + "include_sampler": True, + "return_pdfs": False, + "max_top_k_ids": 1024, + }, + **additional_configs, + ) + model_w_sampler_w_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + model_w_sampler_wo_guided_decoding.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + num_devices=1, + num_cores=16, + num_speculative_tokens=spec_length, + mxint8_kv_cache=True, + mxfp6_matmul=True, + ) + + # Generate texts from prompts + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model) + np.random.seed(0) + sampling_params = { + "repetition_penalties": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "presence_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + # "frequency_penalties": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "temperatures": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "top_ks": np.array(1024, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), + "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), + "random_numbers": np.zeros((full_batch_size, 1024), dtype=np.float32), + } + if is_vlm: + vocab_size = model_w_sampler_w_guided_decoding.model.language_model.config.vocab_size + else: + vocab_size = model_w_sampler_w_guided_decoding.model.config.vocab_size + model_w_sampler_w_guided_decoding_exec_info = model_w_sampler_w_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + include_guided_decoding=True, + sampling_params={ + **sampling_params, + **{ + "token_bitmasks": np.tile( + np.random.choice([True, False], size=(vocab_size,)), + (full_batch_size, 1), + ) + }, + }, + **additional_params, + ) + model_w_sampler_wo_guided_decoding_exec_info = model_w_sampler_wo_guided_decoding.generate( + tokenizer=tokenizer, + prompts=prompts, + generation_len=generation_len, + include_sampler=True, + return_pdfs=False, + sampling_params=sampling_params, + **additional_params, + ) + assert ( + model_w_sampler_w_guided_decoding_exec_info.generated_ids + != model_w_sampler_wo_guided_decoding_exec_info.generated_ids + ).any(), "Sampler outputs with and without guided decoding should not match" diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index bdc15519e..72477d56a 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -14,10 +14,11 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import 
QEFFAutoModelForCausalLM +from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.hash_utils import hash_dict_params -configs = [ +test_configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params ("gpt2", 256, 2, 4, 128, 512, 127, {}), ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), @@ -33,32 +34,46 @@ ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] -configs = [ - AutoConfig.for_model( - model_name, - max_position_embeddings=max_position_embeddings, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - vocab_size=vocab_size, - **additional_params, - ) - for ( - model_name, - max_position_embeddings, - num_hidden_layers, - num_attention_heads, - hidden_size, - intermediate_size, - vocab_size, - additional_params, - ) in configs +test_prefill_only_specialized_models_configs = [ + ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}), ] + + +def get_auto_config_from_test_config(configs): + auto_configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs + ] + return auto_configs + + +configs = get_auto_config_from_test_config(test_configs) config_ids = [x.model_type for x in configs] +prefill_only_configs = get_auto_config_from_test_config(test_prefill_only_specialized_models_configs) +prefill_only_config_ids = [x.model_type for x in prefill_only_configs] + model_kwargs = {"attn_implementation": "eager"} @@ -143,20 +158,21 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path): @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_hash_creation(config, cb, tmp_path): +def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path): model = AutoModelForCausalLM.from_config(config, **model_kwargs) qeff_model = QEFFAutoModelForCausalLM(model, cb) - qeff_model.export(tmp_path) + qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc) hash_params = {} hash_params["config"] = qeff_model.model.config.to_diff_dict() hash_params["peft_config"] = None hash_params["applied_transform_names"] = qeff_model._transform_names() hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["max_seq_len_cached"] = None hash_params["qaic_config"] = None # Create parameters separately for hash creation - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS @@ -177,13 +193,24 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): 0: "full_batch_size" if qeff_model.continuous_batching else 
"batch_size", 2: "ctx_len", } + pkv_dynamic_axes = ( + qeff_model.model.get_pkv_dynamic_axes() + if hasattr(qeff_model.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes + ) + pkv_dynamic_axes = ( + [pkv_dynamic_axes] * qeff_model.model.config.num_hidden_layers + if isinstance(pkv_dynamic_axes, dict) + else pkv_dynamic_axes + ) output_names = [] output_names.append("logits") - + onnx_out_name_suffix = "InternalRetainedState" if subfunc else "RetainedState" for i in range(qeff_model.num_layers): + pkv_dynamic_axes[i][0] = "full_batch_size" if qeff_model.continuous_batching else "batch_size" for kv in ["key", "value"]: - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes - output_names.append(f"past_{kv}.{i}_RetainedState") + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] + output_names.append(f"past_{kv}.{i}_{onnx_out_name_suffix}") if qeff_model.continuous_batching: dynamic_axes["batch_index"] = {0: "batch_size"} @@ -192,14 +219,35 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): export_params["output_names"] = output_names export_params["dynamic_axes"] = dynamic_axes hash_params["export_params"] = export_params + if subfunc: + hash_params["export_modules_as_functions"] = get_decoder_layer_classes_for_export(qeff_model.model) + manual_hash = hash_dict_params(hash_params) assert manual_hash == qeff_model.export_hash +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", prefill_only_configs, ids=prefill_only_config_ids) +def test_prefill_only_specialized_models(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + if cb: + with pytest.raises(NotImplementedError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + else: + with pytest.raises(ValueError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False) + first_export_hash = qeff_model.export_hash + qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False) + second_export_hash = qeff_model.export_hash + assert first_export_hash != second_export_hash + + @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/transformers/test_speech_seq2seq.py b/tests/transformers/test_speech_seq2seq.py index 59281b73b..bc53cb539 100644 --- a/tests/transformers/test_speech_seq2seq.py +++ b/tests/transformers/test_speech_seq2seq.py @@ -141,7 +141,7 @@ def test_seq2seq_hash_creation(config, tmp_path): @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils.export_utils.QEFF_HOME", tmp_path) yield tmp_path diff --git a/tests/transformers/test_subfunction.py b/tests/transformers/test_subfunction.py new file mode 100644 index 000000000..53ddbb474 --- /dev/null +++ b/tests/transformers/test_subfunction.py @@ -0,0 +1,120 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +from collections import Counter + +import onnx +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +torch.manual_seed(42) + +configs = [ + ("gpt2", 256, 2, 4, 128, 512, 127, {}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] + +model_kwargs = {"attn_implementation": "eager"} +config_ids = [x.model_type for x in configs] + + +def has_gpt2block_function(onnx_path): + """Check if ONNX model contains QEffGPT2Block function definition.""" + model = onnx.load(onnx_path, load_external_data=False) + function_names = [f.name for f in model.functions] + gpt2block_functions = [name for name in function_names if "QEffGPT2Block" in name] + return len(gpt2block_functions) > 0, gpt2block_functions + + +def get_gpt2block_call_count(onnx_path): + """Get count of QEffGPT2Block function calls in the ONNX model graph.""" + model = onnx.load(onnx_path, load_external_data=False) + calls = Counter([n.op_type for n in model.graph.node]) + gpt2block_calls = {k: v for k, v in calls.items() if "QEffGPT2Block" in k} + return gpt2block_calls + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_subfunction_vs_nonsubfunction(config, tmp_path): + # tokenizer = AutoTokenizer.from_pretrained(config.model_type) + model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) + + # Export with subfunctions enabled + with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) + + # Export without subfunctions + without_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=False) + + # Verify that the model with subfunctions has QEffGPT2Block function definition + has_gpt2block, gpt2block_names = has_gpt2block_function(with_sub_func_onnx) + assert has_gpt2block, ( + "Model exported with use_onnx_subfunctions=True should contain QEffGPT2Block function definition" + ) + print(f"\nGpt2Block functions found: {gpt2block_names}") + + # Verify that the model without subfunctions has no QEffGPT2Block function definition + has_gpt2block_without, _ = has_gpt2block_function(without_sub_func_onnx) + assert not has_gpt2block_without, ( + "Model exported with use_onnx_subfunctions=False should not contain QEffGPT2Block function definition" + ) + + # Get QEffGPT2Block call counts + gpt2block_calls_with_sub = get_gpt2block_call_count(with_sub_func_onnx) + gpt2block_calls_without_sub = get_gpt2block_call_count(without_sub_func_onnx) + + print(f"\nGpt2Block call counts with subfunctions: {gpt2block_calls_with_sub}") + print(f"QEffGPT2Block call counts without subfunctions: {gpt2block_calls_without_sub}") + + # Verify that QEffGPT2Block function calls exist in the subfunction model + assert len(gpt2block_calls_with_sub) > 0, ( + "Expected to find QEffGPT2Block function calls in graph when use_onnx_subfunctions=True" + ) + + # 
Verify that QEffGPT2Block function calls do NOT exist in the non-subfunction model + assert len(gpt2block_calls_without_sub) == 0, ( + "Expected NO QEffGPT2Block function calls in graph when use_onnx_subfunctions=False" + ) + + # TODO: Re-enable this check when generation is fully deterministic + # Compile and test generation to ensure functional equivalence + # compile_params = {"prefill_seq_len": 8, "ctx_len": 16} + + # model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True) + # generation_00 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer) + + # model_0_0.compile(onnx_path=without_sub_func_onnx, **compile_params) + # generation_01 = model_0_0.generate(prompts=["Help me with this"], tokenizer=tokenizer) + + # Verify that both models produce the same output + # assert generation_00.generated_texts == generation_01.generated_texts, ( + # "Models with and without subfunctions should produce identical outputs" + # ) diff --git a/tests/utils/test_hash_utils.py b/tests/utils/test_hash_utils.py index fefa73973..b7a5495c6 100644 --- a/tests/utils/test_hash_utils.py +++ b/tests/utils/test_hash_utils.py @@ -41,7 +41,7 @@ def test_to_hashable_float_nan(value): def test_json_serializable(): # Test with a set - assert json_serializable({1, 2, 3}) == [1, 2, 3] + assert json_serializable({1, 2, 3}) == ["1", "2", "3"] # Test with an unsupported type with pytest.raises(TypeError): json_serializable({1, 2, 3, {4, 5}})