diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 15376996b6..e6d76f0508 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -45,6 +45,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners /examples/llm_ptq @NVIDIA/modelopt-examples-llm_ptq-codeowners /examples/llm_qat @NVIDIA/modelopt-examples-llm_qat-codeowners /examples/llm_sparsity @NVIDIA/modelopt-torch-sparsity-codeowners +/examples/megatron_bridge @NVIDIA/modelopt-examples-megatron-codeowners /examples/model_hub @NVIDIA/modelopt-examples-model_hub-codeowners /examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners /examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml new file mode 100644 index 0000000000..b8f5cfe4b6 --- /dev/null +++ b/.github/workflows/example_tests.yml @@ -0,0 +1,170 @@ +name: Example tests + +on: + push: + branches: ["pull-request/[0-9]+"] + # NOTE: paths cannot be used since push happens to copied PR and only latest commit to PR is used + schedule: + - cron: "0 0 * * *" # Nightly + workflow_dispatch: # On-demand + +# Cancel previous runs if new commit is pushed to the same PR +concurrency: + group: ${{ github.workflow }}-${{ startsWith(github.ref, 'refs/heads/pull-request/') && github.ref || github.sha }} + cancel-in-progress: true + +jobs: + check-file-changes: + if: startsWith(github.ref, 'refs/heads/pull-request/') + runs-on: ubuntu-latest + outputs: + any_changed: ${{ steps.changed-tests.outputs.any_changed }} + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + - id: get-pr-info + uses: nv-gha-runners/get-pr-info@main + # Get commit from main branch that is present in the PR to use as base for changed files + - id: calculate-merge-base + env: + PR_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }} + BASE_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).base.sha }} + run: | + (echo -n "merge-base="; git merge-base "$BASE_SHA" "$PR_SHA") | tee --append "${GITHUB_OUTPUT}" + - name: Check for changes in test-relevant directories + id: changed-tests + uses: step-security/changed-files@v46.0.5 + with: + base_sha: ${{ steps.calculate-merge-base.outputs.merge-base }} + sha: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }} + files: | + .github/workflows/example_tests.yml + examples/** + modelopt/** + setup.py + tests/examples/** + fail_on_initial_diff_error: true + wait-checks: + needs: [check-file-changes] + if: needs.check-file-changes.outputs.any_changed == 'true' + uses: ./.github/workflows/_wait_for_checks.yml + permissions: + checks: read + secrets: inherit + with: + match_pattern: "^DCO$|^linux$" # Wait for DCO and Unit tests / linux to pass + delay: 300s + + ##### PyTorch Example Tests (speculative_decoding requires 26.01 image) ##### + torch-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + strategy: + fail-fast: false + matrix: + example: [llm_distill, llm_qat, llm_sparsity] + include: + - example: speculative_decoding + docker_image: "26.01" + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3" + example: ${{ matrix.example }} + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-l4-latest-1 + + torch-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: + fail-fast: false + matrix: + example: [llm_distill, llm_qat, llm_sparsity] + include: + - example: speculative_decoding + docker_image: "26.01" + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/pytorch:${{ matrix.docker_image || '26.01' }}-py3" + example: ${{ matrix.example }} + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-h100-latest-2 + + ##### TensorRT-LLM Example Tests ##### + trtllm-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + strategy: + fail-fast: false + matrix: + example: [llm_ptq] # vlm_ptq temporarily disabled due to pipeline error + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc6.post3" + example: ${{ matrix.example }} + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-h100-latest-1 + + trtllm-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: + fail-fast: false + matrix: + example: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq] + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc6.post3" + example: ${{ matrix.example }} + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-h100-latest-2 + + ##### ONNX/TensorRT Example Tests ##### + onnx-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + strategy: + fail-fast: false + matrix: + example: [diffusers, torch_onnx] + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt:26.01-py3" + example: ${{ matrix.example }} + pip_install_extras: "[all,dev-test]" + runner: linux-amd64-gpu-l4-latest-1 + + onnx-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: + fail-fast: false + matrix: + example: [diffusers, torch_onnx] + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/tensorrt:26.01-py3" + example: ${{ matrix.example }} + pip_install_extras: "[all,dev-test]" + runner: linux-amd64-gpu-l4-latest-1 + + ##### Required Check for PR ##### + example-pr-required-check: + # Run even if example tests are skipped + if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} + needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr] + runs-on: ubuntu-latest + steps: + - name: Required GPU tests did not succeed + if: | + needs.check-file-changes.result != 'success' || + (needs.check-file-changes.outputs.any_changed == 'true' && ( + needs.torch-pr.result != 'success' || + needs.trtllm-pr.result != 'success' || + needs.onnx-pr.result != 'success' + )) + run: exit 1 diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index 7adf32a122..22fc6cb27e 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -1,4 +1,4 @@ -# NOTE: Make sure this file is consistent with .gitlab/tests.yml +# TODO: Optimize gpu tests runtime! name: GPU tests on: @@ -59,10 +59,18 @@ jobs: gpu-tests-pr: needs: [check-file-changes, wait-checks] if: needs.check-file-changes.outputs.any_changed == 'true' + strategy: + fail-fast: false + matrix: + include: + - example: cuda13-gpu + timeout: 90 + - example: cuda13-gpu-megatron + timeout: 120 runs-on: linux-amd64-gpu-l4-latest-1 - timeout-minutes: 120 + timeout-minutes: ${{ matrix.timeout }} container: &gpu_container - image: nvcr.io/nvidia/pytorch:25.06-py3 + image: nvcr.io/nvidia/pytorch:26.01-py3 env: GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages @@ -76,11 +84,19 @@ jobs: - name: Install dependencies for mip run: apt-get update && apt-get install -y libffi-dev - name: Run gpu tests - run: pip install tox-current-env && tox -e py312-cuda12-gpu --current-env + run: pip install tox-current-env && tox -e ${{ matrix.example }} --current-env gpu-tests-non-pr: if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: + fail-fast: false + matrix: + include: + - example: cuda13-gpu + timeout: 90 + - example: cuda13-gpu-megatron + timeout: 120 runs-on: linux-amd64-gpu-h100-latest-2 - timeout-minutes: 120 + timeout-minutes: ${{ matrix.timeout }} container: *gpu_container steps: *gpu_steps gpu-pr-required-check: diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 6f7fad3a79..252d4b7195 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -37,7 +37,7 @@ jobs: - uses: actions/checkout@v6 - uses: ./.github/actions/ubuntu-setup - name: Run unit tests - run: pip install tox && COV_ARGS="--cov" tox -e py312-torch29-tf_latest-unit + run: pip install tox && COV_ARGS="--cov" tox -e py312-torch210-tf_latest-unit - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: @@ -55,6 +55,7 @@ jobs: with: python-version: "3.12" - name: Run unit tests (without coverage) + # Some issues with torch 2.10 on Windows, so using 2.9 for now run: pip install tox && tox -e py312-torch29-tf_latest-unit multi-py: if: github.event_name == 'pull_request' @@ -70,7 +71,7 @@ jobs: with: python-version: "3.${{ matrix.py }}" - name: Run unit tests - run: pip install tox && tox -e py3${{ matrix.py }}-torch29-tf_latest-unit + run: pip install tox && tox -e py3${{ matrix.py }}-torch210-tf_latest-unit multi-torch: if: github.event_name == 'pull_request' needs: [linux] @@ -78,7 +79,7 @@ jobs: timeout-minutes: 30 strategy: matrix: - torch: [26, 27, 28] + torch: [26, 27, 28, 29] steps: - uses: actions/checkout@v6 - uses: ./.github/actions/ubuntu-setup @@ -96,7 +97,7 @@ jobs: - uses: actions/checkout@v6 - uses: ./.github/actions/ubuntu-setup - name: Run unit tests - run: pip install tox && tox -e py312-torch29-tf_${{ matrix.tf }}-unit + run: pip install tox && tox -e py312-torch210-tf_${{ matrix.tf }}-unit partial-install: if: github.event_name == 'pull_request' needs: [linux] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b936106f5..9234596a0c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -109,7 +109,8 @@ repos: examples/speculative_decoding/main.py| examples/speculative_decoding/medusa_utils.py| examples/speculative_decoding/server_generate.py| - examples/puzzletron/evaluation/hf_deployable_anymodel\.py| + examples/puzzletron/evaluation/lm_eval_anymodel.py| + modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py| modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_.*\.py| )$ diff --git a/.vscode/settings.json b/.vscode/settings.json index 0a3a2353ea..0e8465ad38 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,7 +40,7 @@ "--no-cov", ], "evenBetterToml.schema.enabled": false, // disable toml/json schema since we have custom fields - "python.analysis.extraPaths": [ + "cursorpyright.analysis.extraPaths": [ "./tests/" // add tests to python path just like pytest does in pyproject.toml ], "git.alwaysSignOff": true, diff --git a/CHANGELOG-Windows.rst b/CHANGELOG-Windows.rst index cea2aac1d4..279e5e6781 100644 --- a/CHANGELOG-Windows.rst +++ b/CHANGELOG-Windows.rst @@ -1,6 +1,19 @@ NVIDIA Model Optimizer Changelog (Windows) ========================================== +0.41 (TBD) +^^^^^^^^^^ + +**Bug Fixes** + +- Fix ONNX 1.19 compatibility issues with CuPy during ONNX INT4 AWQ quantization. ONNX 1.19 uses ml_dtypes.int4 instead of numpy.int8 which caused CuPy failures. + +**New Features** + +- Add support for ONNX Mixed Precision Weight-only quantization using INT4 and INT8 precisions. Refer quantization `example for GenAI LLMs `_. +- Add support for some diffusion models' quantization on Windows. Refer `example script `_ for details. +- Add `Perplexity `_ and `KL-Divergence `_ accuracy benchmarks. + 0.33 (2025-07-21) ^^^^^^^^^^^^^^^^^ @@ -25,8 +38,8 @@ NVIDIA Model Optimizer Changelog (Windows) - This is the first official release of Model Optimizer for Windows - **ONNX INT4 Quantization:** :meth:`modelopt.onnx.quantization.quantize_int4 ` now supports ONNX INT4 quantization for DirectML and TensorRT* deployment. See :ref:`Support_Matrix` for details about supported features and models. -- **LLM Quantization with Olive:** Enabled LLM quantization through Olive, streamlining model optimization workflows. Refer `example `_ -- **DirectML Deployment Guide:** Added DML deployment guide. Refer :ref:`DirectML_Deployment`. +- **LLM Quantization with Olive:** Enabled LLM quantization through Olive, streamlining model optimization workflows. Refer `Olive example `_. +- **DirectML Deployment Guide:** Added DML deployment guide. Refer :ref:`Onnxruntime_Deployment` deployment guide for details. - **MMLU Benchmark for Accuracy Evaluations:** Introduced `MMLU benchmarking `_ for accuracy evaluation of ONNX models on DirectML (DML). - **Published quantized ONNX models collection:** Published quantized ONNX models at HuggingFace `NVIDIA collections `_. diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 978ac209d4..9a0d70916a 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,39 @@ NVIDIA Model Optimizer Changelog (Linux) ======================================== +0.43 (2026-03-xx) +^^^^^^^^^^^^^^^^^ + +**New Features** + +- User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow. +- ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory. +- Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add support for rotating the input before quantization for RHT. + +0.42 (2026-02-xx) +^^^^^^^^^^^^^^^^^ + +**Bug Fixes** + +- Fix calibration data generation with multiple samples in the ONNX workflow. + +**New Features** + +- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead. +- Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint. +- Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. +- New example for Minitron pruning with Megatron-Bridge framework along with advanced pruning usage with new ``params`` constraint based pruning. Also add example for distillation with Megatron-Bridge framework. Check `examples/megatron_bridge/README.md `_ for example scripts. +- Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow. +- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model. +- Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models. +- Add unified Hugging Face export support for diffusers pipelines/components. +- Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow. +- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is. +- Add support for image-text data calibration in PTQ for Nemotron VL models. +- Add PTQ support for Nemotron Parse. +- Add distillation support for LTX-2. See `examples/diffusers/distillation/README.md `_ for more details. + 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ @@ -84,7 +117,7 @@ NVIDIA Model Optimizer Changelog (Linux) **Documentation** -- Add general guidelines for Minitron pruning and distillation. See `examples/pruning/README.md `_ for more details. +- Add general guidelines for Minitron pruning and distillation. See `pruning guidelines `_ for more details. - Added example for exporting QLoRA checkpoint for vLLM deployment. Refer to `examples/llm_qat/README.md `_ for more details 0.37 (2025-10-08) @@ -209,7 +242,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for UNet ONNX quantization. - Enable ``concat_elimination`` pass by default to improve the performance of quantized ONNX models. - Enable Redundant Cast elimination pass by default in :meth:`moq.quantize `. -- Add new attribute ``parallel_state`` to :class:`DynamicModule ` to support distributed parallelism such as data parallel and tensor parallel. +- Add new attribute ``parallel_state`` to :class:`QuantModule ` to support distributed parallelism such as data parallel and tensor parallel. - Add MXFP8, NVFP4 quantized ONNX export support. - Add new example for torch quantization to ONNX for MXFP8, NVFP4 precision. diff --git a/README.md b/README.md index 14e91fd8b8..323c6ebf0b 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,9 @@ ______________________________________________________________________ **[Input]** Model Optimizer currently supports inputs of a [Hugging Face](https://huggingface.co/), [PyTorch](https://github.com/pytorch/pytorch) or [ONNX](https://github.com/onnx/onnx) model. **[Optimize]** Model Optimizer provides Python APIs for users to easily compose the above model optimization techniques and export an optimized quantized checkpoint. -Model Optimizer is also integrated with [NVIDIA NeMo](https://github.com/NVIDIA-NeMo/NeMo), [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and [Hugging Face Accelerate](https://github.com/huggingface/accelerate) for training required inference optimization techniques. +Model Optimizer is also integrated with [NVIDIA Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge), [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and [Hugging Face Accelerate](https://github.com/huggingface/accelerate) for training required inference optimization techniques. -**[Export for deployment]** Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like [SGLang](https://github.com/sgl-project/sglang), [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization), [TensorRT](https://github.com/NVIDIA/TensorRT), or [vLLM](https://github.com/vllm-project/vllm). +**[Export for deployment]** Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like [SGLang](https://github.com/sgl-project/sglang), [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization), [TensorRT](https://github.com/NVIDIA/TensorRT), or [vLLM](https://github.com/vllm-project/vllm). The unified Hugging Face export API now supports both transformers and diffusers models. ## Latest News diff --git a/docs/source/deployment/2_directml.rst b/docs/source/deployment/2_onnxruntime.rst similarity index 66% rename from docs/source/deployment/2_directml.rst rename to docs/source/deployment/2_onnxruntime.rst index 90a4a31a9d..266190c084 100644 --- a/docs/source/deployment/2_directml.rst +++ b/docs/source/deployment/2_onnxruntime.rst @@ -1,11 +1,19 @@ -.. _DirectML_Deployment: +.. _Onnxruntime_Deployment: -=================== -DirectML -=================== +=========== +Onnxruntime +=========== +Once an ONNX FP16 model is quantized using Model Optimizer on Windows, the resulting quantized ONNX model can be deployed via the `ONNX Runtime GenAI `_ or `ONNX Runtime `_. -Once an ONNX FP16 model is quantized using Model Optimizer on Windows, the resulting quantized ONNX model can be deployed on the DirectML (DML) backend via the `ONNX Runtime GenAI `_ or `ONNX Runtime `_. +ONNX Runtime uses execution providers (EPs) to run models efficiently across a range of backends, including: + +- **CUDA EP:** Utilizes NVIDIA GPUs for fast inference with CUDA and cuDNN libraries. +- **DirectML EP:** Enables deployment on a wide range of GPUs. +- **TensorRT-RTX EP:** Targets NVIDIA RTX GPUs, leveraging TensorRT for further optimized inference. +- **CPU EP:** Provides a fallback to run inference on the system's CPU when specialized hardware is unavailable. + +Choose the EP that best matches your model, hardware and deployment requirements. .. note:: Currently, DirectML backend doesn't support 8-bit precision. So, 8-bit quantized models should be deployed on other backends like ORT-CUDA etc. However, DML path does support INT4 quantized models. @@ -21,6 +29,10 @@ ONNX Runtime GenAI offers a streamlined solution for deploying generative AI mod - **Control Options**: Use the high-level ``generate()`` method for rapid deployment or execute each iteration of the model in a loop for fine-grained control. - **Multi-Language API Support**: Provides APIs for Python, C#, and C/C++, allowing seamless integration across a range of applications. +.. note:: + + ONNX Runtime GenAI models are typically tied to the execution provider (EP) they were built with; a model exported for one EP (e.g., CUDA or DirectML) is generally not compatible with other EPs. To run inference on a different backend, re-export or convert the model specifically for that target EP. + **Getting Started**: Refer to the `ONNX Runtime GenAI documentation `_ for an in-depth guide on installation, setup, and usage. @@ -42,4 +54,4 @@ For further details and examples, please refer to the `ONNX Runtime documentatio Collection of optimized ONNX models =================================== -The ready-to-deploy optimized ONNX models from ModelOpt-Windows are available at HuggingFace `NVIDIA collections `_. These models can be deployed using DirectML backend. Follow the instructions provided along with the published models for deployment. +The ready-to-deploy optimized ONNX models from ModelOpt-Windows are available at HuggingFace `NVIDIA collections `_. Follow the instructions provided along with the published models for deployment. diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst index ee8c5bc9be..9124164b57 100644 --- a/docs/source/deployment/3_unified_hf.rst +++ b/docs/source/deployment/3_unified_hf.rst @@ -2,7 +2,7 @@ Unified HuggingFace Checkpoint ================================================================= -We support exporting modelopt-optimized Huggingface models and Megatron Core models to a unified checkpoint format that can be deployed in various inference frameworks such as TensorRT-LLM, vLLM, and SGLang. +We support exporting modelopt-optimized Hugging Face models (transformers and diffusers pipelines/components) and Megatron Core models to a unified checkpoint format that can be deployed in various inference frameworks such as TensorRT-LLM, vLLM, and SGLang. The workflow is as follows: @@ -32,6 +32,10 @@ The export API (:meth:`export_hf_checkpoint `_ (referred to as Model Optimizer, or ModelOpt) is a library comprising state-of-the-art model optimization techniques including quantization and sparsity to compress model. It accepts a torch or ONNX model as input and provides Python APIs for users to easily stack different model optimization -techniques to produce optimized & quantized checkpoints. Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like `TensorRT-LLM `_ or `TensorRT `_ (Linux). ModelOpt is integrated with `NVIDIA NeMo `_ and `Megatron-LM `_ for training-in-the-loop optimization techniques. For enterprise users, the 8-bit quantization with Stable Diffusion is also available on `NVIDIA NIM `_. +techniques to produce optimized & quantized checkpoints. Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like `TensorRT-LLM `_ or `TensorRT `_ (Linux). The unified Hugging Face export API supports both transformers and diffusers models. ModelOpt is integrated with `NVIDIA NeMo `_ and `Megatron-LM `_ for training-in-the-loop optimization techniques. For enterprise users, the 8-bit quantization with Stable Diffusion is also available on `NVIDIA NIM `_. -For Windows users, the `Model Optimizer for Windows `_ (ModelOpt-Windows) delivers model compression techniques, including quantization, on Windows RTX PC systems. ModelOpt-Windows is optimized for efficient quantization, featuring local GPU calibration, reduced system and video memory consumption, and swift processing times. It integrates seamlessly with the Windows ecosystem, with optimized ONNX models as output for `Microsoft DirectML `_ backends. Furthermore, ModelOpt-Windows supports SDKs such as `Microsoft Olive `_ and `ONNX Runtime `_, enabling the deployment of quantized models across various independent hardware vendors through the DirectML path. +For Windows users, the `Model Optimizer for Windows `_ (ModelOpt-Windows) delivers model compression techniques, including quantization, on Windows RTX PC systems. ModelOpt-Windows is optimized for efficient quantization, featuring local GPU calibration, reduced system and video memory consumption, and swift processing times. It integrates seamlessly with the Windows ecosystem, with optimized ONNX models as output for `Microsoft DirectML `_ and `TensorRT-RTX `_ backends. Furthermore, ModelOpt-Windows supports SDKs such as `Microsoft Olive `_ and `ONNX Runtime `_, enabling the deployment of quantized models across various independent hardware vendors through the DirectML path. Model Optimizer for both Linux and Windows are available for free for all developers on `NVIDIA PyPI `_. Visit the `Model Optimizer GitHub repository `_ for end-to-end example scripts and recipes optimized for NVIDIA GPUs. diff --git a/docs/source/getting_started/_installation_for_Linux.rst b/docs/source/getting_started/_installation_for_Linux.rst index 0a82ecd1ed..74276aa3b0 100644 --- a/docs/source/getting_started/_installation_for_Linux.rst +++ b/docs/source/getting_started/_installation_for_Linux.rst @@ -14,11 +14,11 @@ Latest Model Optimizer (``nvidia-modelopt``) currently has the following system +-------------------------+-----------------------------+ | Python | >=3.10,<3.13 | +-------------------------+-----------------------------+ -| CUDA | >=12.0 | +| CUDA | 12.x, 13.x | +-------------------------+-----------------------------+ | PyTorch | >=2.6 | +-------------------------+-----------------------------+ -| TensorRT-LLM (Optional) | 1.2.0rc4 | +| TensorRT-LLM (Optional) | >=1.0 | +-------------------------+-----------------------------+ | ONNX Runtime (Optional) | 1.22 | +-------------------------+-----------------------------+ @@ -126,6 +126,10 @@ Additionally, we support installing dependencies for following 3rd-party package * - Huggingface (``transformers``, ``diffusers``, etc.) - ``[hf]`` +**CUDA specific dependencies** + +* By default, ``cupy-cuda12x`` is installed for INT4 ONNX quantization. If you have CUDA 13, you need to run ``pip uninstall -y cupy-cuda12x`` and ``pip install cupy-cuda13x`` after installing ``nvidia-modelopt[onnx]``. + **Accelerated Quantization with Triton Kernels** ModelOpt includes optimized quantization kernels implemented with Triton language that accelerate quantization diff --git a/docs/source/getting_started/windows/_installation_for_Windows.rst b/docs/source/getting_started/windows/_installation_for_Windows.rst index a386fd30f7..f68ee90b5d 100644 --- a/docs/source/getting_started/windows/_installation_for_Windows.rst +++ b/docs/source/getting_started/windows/_installation_for_Windows.rst @@ -25,7 +25,7 @@ The following system requirements are necessary to install and use Model Optimiz +-------------------------+-----------------------------+ .. note:: - - Make sure to use GPU-compatible driver and other dependencies (e.g. torch etc.). For instance, support for Blackwell GPU might be present in Nvidia 570+ driver, and CUDA-12.8. + - Make sure to use GPU-compatible driver and other dependencies (e.g. torch etc.). For instance, support for Blackwell GPU might be present in Nvidia 570+ driver, and CUDA-12.8+. - We currently support *Single-GPU* configuration. The Model Optimizer - Windows can be used in following ways: diff --git a/docs/source/getting_started/windows/_installation_standalone.rst b/docs/source/getting_started/windows/_installation_standalone.rst index 47f36050c2..500b480e12 100644 --- a/docs/source/getting_started/windows/_installation_standalone.rst +++ b/docs/source/getting_started/windows/_installation_standalone.rst @@ -13,6 +13,7 @@ Before using ModelOpt-Windows, the following components must be installed: - NVIDIA GPU and Graphics Driver - Python version >= 3.10 and < 3.13 - Visual Studio 2022 / MSVC / C/C++ Build Tools + - CUDA Toolkit, CuDNN for using CUDA path during calibration (e.g. for calibration of ONNX models using `onnxruntime-gpu` or CUDA EP) Update ``PATH`` environment variable as needed for above prerequisites. @@ -26,45 +27,38 @@ It is recommended to use a virtual environment for managing Python dependencies. $ python -m venv .\myEnv $ .\myEnv\Scripts\activate -In the newly created virtual environment, none of the required packages (e.g., onnx, onnxruntime, onnxruntime-directml, onnxruntime-gpu, nvidia-modelopt) will be pre-installed. +In the newly created virtual environment, none of the required packages (e.g., onnx, onnxruntime, onnxruntime-directml, onnxruntime-gpu, nvidia-modelopt etc.) will be pre-installed. **3. Install ModelOpt-Windows Wheel** -To install the ModelOpt-Windows wheel, run the following command: +To install the ONNX module of ModelOpt-Windows, run the following command: .. code-block:: bash pip install "nvidia-modelopt[onnx]" -This command installs ModelOpt-Windows and its ONNX module, along with the *onnxruntime-directml* (v1.20.0) package. If ModelOpt-Windows is installed without the additional parameter, only the bare minimum dependencies will be installed, without the relevant module and dependencies. +If you install ModelOpt-Windows without the extra ``[onnx]`` option, only the minimal core dependencies and the PyTorch module (``torch``) will be installed. Support for ONNX model quantization requires installing with ``[onnx]``. -**4. Setup ONNX Runtime (ORT) for Calibration** +**4. ONNX Model Quantization: Setup ONNX Runtime Execution Provider for Calibration** -The ONNX Post-Training Quantization (PTQ) process involves running the base model with user-supplied inputs, a process called calibration. The user-supplied model inputs are referred to as calibration data. To perform calibration, the base model must be run using a suitable ONNX Execution Provider (EP), such as *DmlExecutionProvider* (DirectML EP) or *CUDAExecutionProvider* (CUDA EP). There are different ONNX Runtime packages for each EP: +The Post-Training Quantization (PTQ) process for ONNX models usually involves running the base model with user-supplied inputs, a process called calibration. The user-supplied model inputs are referred to as calibration data. To perform calibration, the base model must be run using a suitable ONNX Execution Provider (EP), such as *DmlExecutionProvider* (DirectML EP) or *CUDAExecutionProvider* (CUDA EP). There are different ONNX Runtime packages for each EP: - *onnxruntime-directml* provides the DirectML EP. +- *onnxruntime-trt-rtx* provides TensorRT-RTX EP. - *onnxruntime-gpu* provides the CUDA EP. - *onnxruntime* provides the CPU EP. -By default, ModelOpt-Windows installs *onnxruntime-directml* and uses the DirectML EP (v1.20.0) for calibration. No additional dependencies are required. -If you prefer to use the CUDA EP for calibration, uninstall the existing *onnxruntime-directml* package and install the *onnxruntime-gpu* package, which requires CUDA and cuDNN dependencies: - -- Uninstall *onnxruntime-directml*: - - .. code-block:: bash - - pip uninstall onnxruntime-directml +By default, ModelOpt-Windows installs *onnxruntime-gpu*. The default CUDA version needed for *onnxruntime-gpu* since v1.19.0 is 12.x. The *onnxruntime-gpu* package (i.e. CUDA EP) has CUDA and cuDNN dependencies: - Install CUDA and cuDNN: - For the ONNX Runtime GPU package, you need to install the appropriate version of CUDA and cuDNN. Refer to the `CUDA Execution Provider requirements `_ for compatible versions of CUDA and cuDNN. -- Install ONNX Runtime GPU (CUDA 12.x): +If you need to use any other EP for calibration, you can uninstall the existing *onnxruntime-gpu* package and install the corresponding package. For example, to use the DirectML EP, you can uninstall the existing *onnxruntime-gpu* package and install the *onnxruntime-directml* package: .. code-block:: bash - pip install onnxruntime-gpu - - - The default CUDA version for *onnxruntime-gpu* since v1.19.0 is 12.x. + pip uninstall onnxruntime-gpu + pip install onnxruntime-directml **5. Setup GPU Acceleration Tool for Quantization** @@ -75,8 +69,9 @@ By default, ModelOpt-Windows utilizes the `cupy-cuda12x `_ t Ensure the following steps are verified: - **Task Manager**: Check that the GPU appears in the Task Manager, indicating that the graphics driver is installed and functioning. - **Python Interpreter**: Open the command line and type python. The Python interpreter should start, displaying the Python version. - - **Onnxruntime Package**: Ensure that one of the following is installed: + - **Onnxruntime Package**: Ensure that exactly one of the following is installed: - *onnxruntime-directml* (DirectML EP) + - *onnxruntime-trt-rtx* (TensorRT-RTX EP) - *onnxruntime-gpu* (CUDA EP) - *onnxruntime* (CPU EP) - **Onnx and Onnxruntime Import**: Ensure that following python command runs successfully. diff --git a/docs/source/getting_started/windows/_installation_with_olive.rst b/docs/source/getting_started/windows/_installation_with_olive.rst index a05155278f..544ecd2df8 100644 --- a/docs/source/getting_started/windows/_installation_with_olive.rst +++ b/docs/source/getting_started/windows/_installation_with_olive.rst @@ -4,7 +4,7 @@ Install ModelOpt-Windows with Olive =================================== -ModelOpt-Windows can be installed and used through Olive to quantize Large Language Models (LLMs) in ONNX format for deployment with DirectML. Follow the steps below to configure Olive for use with ModelOpt-Windows. +ModelOpt-Windows can be installed and used through Olive to perform model optimization using quantization technique. Follow the steps below to configure Olive for use with ModelOpt-Windows. Setup Steps for Olive with ModelOpt-Windows ------------------------------------------- @@ -17,7 +17,7 @@ Setup Steps for Olive with ModelOpt-Windows pip install olive-ai[nvmo] - - **Install Prerequisites:** Ensure all required dependencies are installed. Use the following commands to install the necessary packages: + - **Install Prerequisites:** Ensure all required dependencies are installed. For example, to use DirectML Execution-Provider (EP) based onnxruntime and onnxruntime-genai packages, run the following commands: .. code-block:: shell @@ -31,11 +31,11 @@ Setup Steps for Olive with ModelOpt-Windows **2. Configure Olive for Model Optimizer – Windows** - **New Olive Pass:** Olive introduces a new pass, ``NVModelOptQuantization`` (or “nvmo”), specifically designed for model quantization using Model Optimizer – Windows. - - **Add to Configuration:** To apply quantization to your target model, include this pass in the Olive configuration file. [Refer `phi3 `_ Olive example]. + - **Add to Configuration:** To apply quantization to your target model, include this pass in the Olive configuration file. [Refer `this `_ guide for details about this pass.]. **3. Setup Other Passes in Olive Configuration** - - **Add Other Passes:** Add additional passes to the Olive configuration file as needed for the desired Olive workflow of your input model. [Refer `phi3 `_ Olive example] + - **Add Other Passes:** Add additional passes to the Olive configuration file as needed for the desired Olive workflow of your input model. **4. Install other dependencies** @@ -62,4 +62,4 @@ Setup Steps for Olive with ModelOpt-Windows **Note**: #. Currently, the Model Optimizer - Windows only supports Onnx Runtime GenAI based LLM models in the Olive workflow. -#. To try out different LLMs and EPs in the Olive workflow of ModelOpt-Windows, refer the details provided in `phi3 `_ Olive example. +#. To get started with Olive, refer to the official `Olive documentation `_. diff --git a/docs/source/guides/0_support_matrix.rst b/docs/source/guides/0_support_matrix.rst index 69e860e4b7..0e5ddeea5b 100644 --- a/docs/source/guides/0_support_matrix.rst +++ b/docs/source/guides/0_support_matrix.rst @@ -63,7 +63,7 @@ Feature Support Matrix * Uses AWQ Algorithm * GPUs: Ampere and Later - PyTorch*, ONNX - - ORT-DirectML, TensorRT*, TensorRT-LLM* + - ORT-DML, ORT-CUDA, ORT-TRT-RTX, TensorRT*, TensorRT-LLM* * - W4A8 (INT4 Weights, FP8 Activations) - * Block-wise INT4 Weights, Per-Tensor FP8 Activations * Uses AWQ Algorithm @@ -84,7 +84,9 @@ Feature Support Matrix - PyTorch*, ONNX - TensorRT*, TensorRT-LLM*, ORT-CUDA -.. note:: Features marked with an asterisk (*) are considered experimental. +.. note:: + - Features marked with an asterisk (*) are considered experimental. + - ``ORT-CUDA``, ``ORT-DML``, and ``ORT-TRT-RTX`` are ONNX Runtime Execution Providers (EPs) for CUDA, DirectML, and TensorRT-RTX respectively. Support for different deployment backends can vary across models. Model Support Matrix @@ -96,87 +98,4 @@ Model Support Matrix .. tab:: Windows - .. list-table:: - :header-rows: 1 - - * - Model - - ONNX INT4 AWQ (W4A16) - - ONNX INT8 Max (W8A8) - - ONNX FP8 Max (W8A8) - * - Llama3.1-8B-Instruct - - Yes - - No - - No - * - Phi3.5-mini-Instruct - - Yes - - No - - No - * - Mistral-7B-Instruct-v0.3 - - Yes - - No - - No - * - Llama3.2-3B-Instruct - - Yes - - No - - No - * - Gemma-2b-it - - Yes - - No - - No - * - Gemma-2-2b - - Yes - - No - - No - * - Gemma-2-9b - - Yes - - No - - No - * - Nemotron Mini 4B Instruct - - Yes - - No - - No - * - Qwen2.5-7B-Instruct - - Yes - - No - - No - * - DeepSeek-R1-Distill-Llama-8B - - Yes - - No - - No - * - DeepSeek-R1-Distil-Qwen-1.5B - - Yes - - No - - No - * - DeepSeek-R1-Distil-Qwen-7B - - Yes - - No - - No - * - DeepSeek-R1-Distill-Qwen-14B - - Yes - - No - - No - * - Mistral-NeMo-Minitron-2B-128k-Instruct - - Yes - - No - - No - * - Mistral-NeMo-Minitron-4B-128k-Instruct - - Yes - - No - - No - * - Mistral-NeMo-Minitron-8B-128k-Instruct - - Yes - - No - - No - * - whisper-large - - No - - Yes - - Yes - * - sam2-hiera-large - - No - - Yes - - Yes - - .. note:: - - ``ONNX INT8 Max`` means INT8 (W8A8) quantization of ONNX model using Max calibration. Similar holds true for the term ``ONNX FP8 Max``. - - The LLMs in above table are `GenAI `_ built LLMs unless specified otherwise. - - Check `examples `_ for specific instructions and scripts. + Please checkout the model support matrix `details `_. diff --git a/docs/source/guides/2_save_load.rst b/docs/source/guides/2_save_load.rst index d0c0b8cb8b..e097e3f806 100644 --- a/docs/source/guides/2_save_load.rst +++ b/docs/source/guides/2_save_load.rst @@ -129,6 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architectu model = ... # Restore the model architecture using the saved `modelopt_state` + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load("modelopt_state.pth", weights_only=False) model = mto.restore_from_modelopt_state(model, modelopt_state) diff --git a/docs/source/guides/8_autocast.rst b/docs/source/guides/8_autocast.rst index 4ad39e969c..0701f2f1f9 100644 --- a/docs/source/guides/8_autocast.rst +++ b/docs/source/guides/8_autocast.rst @@ -42,6 +42,7 @@ AutoCast can also be used programmatically through its Python API: trt_plugins=[], # list of TensorRT plugin library paths in .so format max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16) + use_standalone_type_inference=False, # use standalone type inference instead of ONNX's infer_shapes (WAR) ) # Save the converted model @@ -82,6 +83,9 @@ AutoCast follows these steps to convert a model: - Converts eligible nodes to lower precision - Automatically inserts necessary cast operations - Automatically replaces initializers with lower precision values + - Performs type inference to propagate types through the graph + - By default, uses ONNX's ``infer_shapes`` which performs both shape and type inference using the ONNX infer_shapes API. + - Use ``use_standalone_type_inference=True`` to use a standalone type-only inference implementation (experimental). #. **Validation and Export**: @@ -145,6 +149,14 @@ Best Practices - A warning will be issued if you specify an opset lower than the original model's opset, as downgrading opset versions may cause compatibility issues. - The opset may be automatically increased beyond your specified value if certain operations require it (e.g., quantization nodes require opset >= 19). +#. **Type Inference Control** + + - By default, AutoCast uses ONNX's ``infer_shapes`` which performs both shape and type inference. + - Use ``--use_standalone_type_inference`` to enable a standalone type-only inference implementation. + - This is a workaround for cases where shape inference fails for any reason, which allows us to bypass the dependency in ONNX's shape inference logic. + - The standalone implementation uses graphsurgeon for topological sorting and handles special operators like Cast, QuantizeLinear, DequantizeLinear, Constant and ConstantOfShape. + - Note: The standalone type inference may be less robust than ONNX's implementation for edge cases, but avoids unnecessary shape inference overhead and possible failures. + Limitations and Restrictions ---------------------------- - AutoCast does not yet support quantized models. @@ -198,3 +210,9 @@ Convert to BF16 with a specific opset: .. code-block:: bash python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22 + +Use standalone type inference instead of ONNX's infer_shapes: + +.. code-block:: bash + + python -m modelopt.onnx.autocast --onnx_path model.onnx --use_standalone_type_inference diff --git a/docs/source/guides/windows_guides/_ONNX_PTQ_guide.rst b/docs/source/guides/windows_guides/_ONNX_PTQ_guide.rst index 9e60611c07..b79415178c 100644 --- a/docs/source/guides/windows_guides/_ONNX_PTQ_guide.rst +++ b/docs/source/guides/windows_guides/_ONNX_PTQ_guide.rst @@ -155,4 +155,4 @@ To save a quantized ONNX model with external data, use the following code: Deploy Quantized ONNX Model --------------------------- -Inference of the quantized models can be done using tools like `GenAI `_, `OnnxRunTime (ORT) `_. These APIs can do inference on backends like DML. For details about DirectML deployment of quantized models, see :ref:`DirectML_Deployment`. Also, refer `example scripts `_ for any possible model-specific inference guidance or script (if any). +Inference of the quantized models can be done using tools like `GenAI `_, `OnnxRunTime (ORT) `_. These APIs can do inference on backends like DML, CUDA, TensorRT-RTX. For details about onnxruntime deployment of quantized models, see :ref:`Onnxruntime_Deployment` deployment guide. Also, refer `example scripts `_ for any possible model-specific inference guidance or script (if any). diff --git a/docs/source/reference/2_security.rst b/docs/source/reference/2_security.rst new file mode 100644 index 0000000000..5a6e37af0e --- /dev/null +++ b/docs/source/reference/2_security.rst @@ -0,0 +1,78 @@ +Security Considerations +======================= + +Overview +-------- + +NVIDIA Model Optimizer (ModelOpt) is a library used to optimize ML models and +may load and process user-provided artifacts (models, weights, configs, +calibration data) and their dependencies. Secure deployment depends on how you +source artifacts, validate inputs, and harden the environment where ModelOpt +runs. + +What to Be Aware Of +------------------- + +**Untrusted model and data inputs** + +- Models, weights, configs and data may be malicious or corrupted. + +**Deserialization and code-execution risks** + +- Unsafe deserialization can lead to arbitrary code execution if fed untrusted + inputs. +- Avoid using serialization formats/settings that can deserialize arbitrary + objects. + +**Input validation and resource exhaustion** + +- Large or malformed inputs can trigger crashes or excessive CPU/GPU/memory use. +- Missing size/type checks can increase DoS risk. + +**Data in transit and at rest** + +- If fetching models or dependencies over the network, insecure transport can + enable tampering. +- Stored artifacts, logs, and caches may contain sensitive data. + +**Logging and observability** + +- Logs may inadvertently contain sensitive inputs, paths, tokens, or proprietary + model details. +- Overly verbose logs can leak operational and security-relevant information. + +**Supply chain and third-party components** + +- Dependencies may include known vulnerabilities or be compromised. +- Third-party plugins/components loaded at runtime may not have the same + security assurances. + +Example Security Approaches +--------------------------- + +**Artifact integrity** + +- Only load artifacts from trusted sources. +- Prefer signed artifacts; verify signatures before loading. + +**Safe parsing and deserialization** + +- Prefer safer storage formats (avoid object deserialization for untrusted + inputs). +- Avoid ``pickle``, ``torch.load()`` with untrusted weights, or YAML + ``unsafe_load``. +- Treat any unverified artifact as untrusted and block/guard its loading. + +**Hardening and least privilege** + +- Run with least privilege and isolate workloads. + +**Data protection** + +- Encrypt sensitive data at rest; use TLS 1.3 for data in transit. +- Never hardcode or log credentials. + +**Resilience** + +- Validate inputs and enforce limits (file size, timeouts, quotas,..). +- Keep OS, containers, and dependencies patched; scan for known vulnerabilities. diff --git a/docs/source/support/2_faqs.rst b/docs/source/support/2_faqs.rst index 02970a1f91..3d0afaa3c8 100644 --- a/docs/source/support/2_faqs.rst +++ b/docs/source/support/2_faqs.rst @@ -15,7 +15,7 @@ ModelOpt-Windows Awq-scale search should complete in minutes with NVIDIA GPU acceleration. If stalled: -- **GPU acceleration may be disabled.** If CUDA 12.x is not available, quantization will fall back to slower ``numpy`` implementation instead of ``cupy-cuda12x``. +- **GPU acceleration may be disabled.** If CUDA 12.x is not available, quantization will fall back to slower ``numpy`` implementation instead of ``cupy-cuda12x``. Make sure that ``cupy`` package is compatible with installed CUDA toolkit. - **Low GPU memory.** Quantization needs 20-24GB VRAM; low memory forces slower shared memory usage. - **Using CPU for quantization.** Install ORT-GPU (supports CUDA EP) or ORT-DML (supports DML EP) for better speed. @@ -45,21 +45,21 @@ Make sure that the output directory is clean before each quantization run otherw `Error Unrecognized attribute: block_size for operator DequantizeLinear` -ModelOpt-Windows uses ONNX's `DequantizeLinear `_ (DQ) nodes. The int4 data-type support in DeQuantizeLinear node came in opset-21. And, *block_size* attribute was added in DeQuantizeLinear node in Opset-21. Make sure that quantized model's opset version is 21 or higher. Refer :ref:`Apply_ONNX_PTQ` for details. +ModelOpt-Windows uses ONNX's `DequantizeLinear `_ (DQ) nodes. The int4 data-type support in DeQuantizeLinear node came in opset-21. And, *block_size* attribute was added in DequantizeLinear node in Opset-21. Make sure that quantized model's opset version is 21 or higher. Refer :ref:`Apply_ONNX_PTQ` for details. 6. Running INT4 quantized ONNX model on DirectML backend gives following kind of error. What can be the issue? -------------------------------------------------------------------------------------------------------------- `Error: Type 'tensor(int4)' of input parameter (onnx::MatMul_6508_i4) of operator (DequantizeLinear) in node (onnx::MatMul_6508_DequantizeLinear) is invalid.` -One possible reason for above error is that INT4 quantized ONNX model's opset version (default or onnx domain) is less than 21. Ensure the INT4 quantized model's opset version is 21 or higher since INT4 data-type support in DeQuantizeLinear ONNX node came in opset-21. +One possible reason for above error is that INT4 quantized ONNX model's opset version (default or onnx domain) is less than 21. Ensure the INT4 quantized model's opset version is 21 or higher since INT4 data-type support in DequantizeLinear ONNX node came in opset-21. 7. Running 8-bit quantized ONNX model with ORT-DML gives onnxruntime error about using 8-bit data-type (e.g. INT8/FP8). What can be the issue? ----------------------------------------------------------------------------------------------------------------------------------------------- Currently, DirectML backend (ORT-DML) doesn't support 8-bit precision. So, it expectedly complains about 8-bit data-type. Try using ORT-CUDA or other 8-bit compatible backend. -8. How to resolve onnxruntime error about invalid use of FP8 type in QuantizeLinear / DeQuantizeLinear node? +8. How to resolve onnxruntime error about invalid use of FP8 type in QuantizeLinear / DequantizeLinear node? ------------------------------------------------------------------------------------------------------------- The FP8 type support in QuantizeLinear / DeQuantizeLinear node came with Opset-19. So, ensure that opset of ONNX model is 19+. diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index 7beb2c9200..bcfd9de409 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -56,6 +56,7 @@ from modelopt.torch.export.model_config import KV_CACHE_FP8 from modelopt.torch.export.quant_utils import get_quant_config from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.triton import weight_dequant from modelopt.torch.quantization.utils import ( is_quantized_column_parallel_linear, is_quantized_parallel_linear, @@ -77,7 +78,6 @@ ) import model as deekseep_model # noqa: E402 -from ds_kernel import weight_dequant # noqa: E402 from kernel import act_quant, fp8_gemm # noqa: E402 @@ -99,7 +99,7 @@ def linear( weight = weight_quantizer(weight) return F.linear(x, weight, bias) elif gemm_impl == "bf16": - weight = weight_dequant(weight, weight.scale) + weight = weight_dequant(weight, weight.scale, dtype=torch.bfloat16) if act_quantizer is not None: x = act_quantizer(x) if weight_quantizer is not None: @@ -311,7 +311,7 @@ def calibrate_loop(model): # disable head that corresponds to lm_head (for the huggingface checkpoint) mtq_cfg["quant_cfg"]["*head*"] = {"enable": False} - allowed_mla_quant = [None, "per_tensor_fp8"] + allowed_mla_quant = [None, "per_tensor_fp8", "nvfp4"] assert mla_quant in allowed_mla_quant, f"mla_quant must be {allowed_mla_quant}" if not mla_quant: @@ -319,12 +319,32 @@ def calibrate_loop(model): elif mla_quant == "per_tensor_fp8": mtq_cfg["quant_cfg"]["*attn*weight_quantizer"] = {"num_bits": (4, 3), "axis": None} mtq_cfg["quant_cfg"]["*attn*input_quantizer"] = {"num_bits": (4, 3), "axis": None} + elif mla_quant == "nvfp4": # for DeepSeek-R1-0528-NVFP4-Turbo + mla_linear_layers = ["*wq_a*", "*wq_b*", "*wkv_a*", "*wkv_b*", "*wo*"] + mla_nvfp4_linear_layers = ["*wq_a*", "*wkv_a*", "*wq_b*", "*wo*"] + for layer in mla_linear_layers: + if layer in mla_nvfp4_linear_layers: + # wq_a, wkv_a, wq_b, wo use NVFP4 quantization + mtq_cfg["quant_cfg"][layer + "_quantizer"] = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + } + else: + mtq_cfg["quant_cfg"][layer + "_quantizer"] = {"enable": False} + + # Disable BMM quantizers + mtq_cfg["quant_cfg"]["*attn.kv_bmm_quantizer*"] = {"enable": False} + mtq_cfg["quant_cfg"]["*attn.pe_bmm_quantizer*"] = {"enable": False} if not args.disable_wo_quant and "FP4" in quant_cfg: mtq_cfg["quant_cfg"]["*wo*weight_quantizer"] = mtq_cfg["quant_cfg"]["*input_quantizer"] mtq_cfg["quant_cfg"]["*wo*input_quantizer"] = mtq_cfg["quant_cfg"]["*weight_quantizer"] + ## ptq transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop) + if int(os.environ["LOCAL_RANK"]) == 0: mtq.print_quant_summary(transformer) @@ -407,11 +427,17 @@ def state_dict_filter(state_dict): parser.add_argument("--disable_fp8_kvcache", action="store_true", help="disable fp8 kvcache.") parser.add_argument("--disable_wo_quant", action="store_true", help="disable MLA wo quant.") parser.add_argument("--trust_remote_code", action="store_true", help="trust remote code.") + parser.add_argument( + "--mla_quant", + type=str, + default=None, + help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4", + ) args = parser.parse_args() model = load_deepseek_model(args.config, args.model_path, args.batch_size) tokenizer = AutoTokenizer.from_pretrained( args.model_path, trust_remote_code=args.trust_remote_code ) - model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size) + model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant) save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache) diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index d94f48fce2..af387fce5b 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -44,11 +44,11 @@ from typing import Any import torch -from ds_kernel import weight_dequant from safetensors.torch import load_file, save_file from tqdm import tqdm from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.triton import weight_dequant def _remap_key(key_dict: dict[str, Any]): @@ -82,6 +82,20 @@ def _remap_key(key_dict: dict[str, Any]): key_dict.update(new_dict) +def remove_quantization_config_from_original_config(export_dir: str) -> None: + """Remove `quantization_config` from exported HF `config.json`. + + Assumes the exported checkpoint directory has a `config.json` containing `quantization_config`. + """ + config_path = os.path.join(export_dir, "config.json") + with open(config_path) as f: + cfg = json.load(f) + del cfg["quantization_config"] + with open(config_path, "w") as f: + json.dump(cfg, f, indent=2, sort_keys=True) + f.write("\n") + + def load_and_preprocess_state_dict(modelopt_state_root, world_size=8): state_dict_list = [ torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt") @@ -302,3 +316,5 @@ def get_tensor(tensor_name): save_root=args.fp4_path, per_layer_quant_config=per_layer_quant_config, ) + + remove_quantization_config_from_original_config(args.fp4_path) diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md index 51ba929b0e..17b0105543 100644 --- a/examples/diffusers/README.md +++ b/examples/diffusers/README.md @@ -89,44 +89,45 @@ We support calibration for INT8, FP8 and FP4 precision and for both weights and We also provide instructions on deploying and running E2E diffusion pipelines with Model Optimizer quantized INT8 and FP8 Backbone to generate images and measure latency on target GPUs. Note, Jetson devices are not supported at this time due to the incompatibility of the software. > [!NOTE] -> Model calibration requires relatively more GPU computing power then deployment.It does not need to be on the same GPUs as the deployment target GPUs. Using the command line below will execute both calibration and ONNX export. +> Model calibration requires relatively more GPU computing power then deployment. It does not need to be on the same GPUs as the deployment target GPUs. ONNX export and TensorRT engine instructions live in [`quantization/ONNX-TRT-Deployment.md`](./quantization/ONNX-TRT-Deployment.md). -### Quantize and export scripts +### Quantize scripts -#### 8-bit Quantize and ONNX Export [Script](./quantization/build_sdxl_8bit_engine.sh) - -You can run the following script to quantize SDXL backbone to INT8 or FP8 and generate an onnx model built with default settings for SDXL. You can then directly head to the [Build the TRT engine for the Quantized ONNX Backbone](#build-the-trt-engine-for-the-quantized-onnx-backbone) section to run E2E pipeline and generate images. - -```sh -bash build_sdxl_8bit_engine.sh --format {FORMAT} # FORMAT can be int8 or fp8 -``` - -If you prefer to customize parameters in calibration or run other models, please follow the instructions below. - -#### FLUX-Dev|SD3-Medium|SDXL|SDXL-Turbo INT8 [Script](./quantization/quantize.py) +#### FLUX|SD3|SDXL INT8 [Script](./quantization/quantize.py) ```sh python quantize.py \ - --model {flux-dev|sdxl-1.0|sdxl-turbo|sd3-medium} \ + --model {flux-dev|flux-schnell|sdxl-1.0|sdxl-turbo|sd3-medium|sd3.5-medium} \ --format int8 --batch-size 2 \ --calib-size 32 --alpha 0.8 --n-steps 20 \ - --model-dtype {Half/BFloat16} --trt-high-precision-dtype {Half|BFloat16} \ - --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --onnx-dir {ONNX_DIR} + --model-dtype {Half/BFloat16} \ + --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt \ + --hf-ckpt-dir ./hf_ckpt ``` -#### FLUX-Dev|SDXL|SDXL-Turbo|LTX-Video FP8/FP4 [Script](./quantization/quantize.py) - -*In our example code, FP4 is only supported for Flux. However, you can modify our script to enable FP4 format support for your own model.* +#### FLUX|SD3|SDXL|LTX|WAN2.2 FP8/FP4 [Script](./quantization/quantize.py) ```sh python quantize.py \ - --model {flux-dev|sdxl-1.0|sdxl-turbo|ltx-video-dev} --model-dtype {Half|BFloat16} --trt-high-precision-dtype {Half|BFloat16} \ + --model {flux-dev|flux-schnell|sdxl-1.0|sdxl-turbo|sd3-medium|sd3.5-medium|ltx-video-dev|wan2.2-t2v-14b|wan2.2-t2v-5b} \ + --model-dtype {Half|BFloat16} \ --format {fp8|fp4} --batch-size 2 --calib-size {128|256} --quantize-mha \ --n-steps 20 --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --collect-method default \ - --onnx-dir {ONNX_DIR} + --hf-ckpt-dir ./hf_ckpt ``` -We recommend using a device with a minimum of 48GB of combined CPU and GPU memory for exporting ONNX models. If not, please use CPU for onnx export. +#### [LTX-2](https://github.com/Lightricks/LTX-2) FP4 (torch checkpoint export) + +```sh +python quantize.py \ + --model ltx-2 --format fp4 --batch-size 1 --calib-size 32 --n-steps 40 \ + --extra-param checkpoint_path=./ltx-2-19b-dev-fp8.safetensors \ + --extra-param distilled_lora_path=./ltx-2-19b-distilled-lora-384.safetensors \ + --extra-param spatial_upsampler_path=./ltx-2-spatial-upscaler-x2-1.0.safetensors \ + --extra-param gemma_root=./gemma-3-12b-it-qat-q4_0-unquantized \ + --extra-param fp8transformer=true \ + --quantized-torch-ckpt-save-path ./ltx-2-transformer.pt +``` #### Important Parameters @@ -135,7 +136,7 @@ We recommend using a device with a minimum of 48GB of combined CPU and GPU memor - `calib-size`: For SDXL INT8, we recommend 32 or 64, for SDXL FP8, 128 is recommended. - `n_steps`: Recommendation: SD/SDXL 20 or 30, SDXL-Turbo 4. -**Then, we can load the generated checkpoint and export the INT8/FP8 quantized model in the next step. For FP8, we only support the TRT deployment on Ada/Hopper GPUs.** +**You can use the generated checkpoint directly in PyTorch, export a Hugging Face checkpoint (`--hf-ckpt-dir`) to deploy the model on SGLang/vLLM/TRTLLM, or follow the ONNX/TensorRT workflow in [`quantization/ONNX-TRT-Deployment.md`](./quantization/ONNX-TRT-Deployment.md).** ## Quantization Aware Training (QAT) @@ -222,113 +223,7 @@ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( ## Build and Run with TensorRT Compiler Framework -### Build the TRT engine for the Quantized ONNX Backbone - -> [!IMPORTANT] -> TensorRT environment must be setup prior -- Please see [Pre-Requisites](#pre-requisites) -> INT8 requires **TensorRT version >= 9.2.0**. If you prefer to use the FP8 TensorRT, ensure you have **TensorRT version 10.2.0 or higher**. You can download the latest version of TensorRT at [here](https://developer.nvidia.com/tensorrt/download). Deployment of SVDQuant is currently not supported. - -Generate INT8/FP8 Backbone Engine - -```bash -# For SDXL -trtexec --builderOptimizationLevel=4 --stronglyTyped --onnx=./model.onnx \ - --minShapes=sample:2x4x128x128,timestep:1,encoder_hidden_states:2x77x2048,text_embeds:2x1280,time_ids:2x6 \ - --optShapes=sample:16x4x128x128,timestep:1,encoder_hidden_states:16x77x2048,text_embeds:16x1280,time_ids:16x6 \ - --maxShapes=sample:16x4x128x128,timestep:1,encoder_hidden_states:16x77x2048,text_embeds:16x1280,time_ids:16x6 \ - --saveEngine=model.plan - -# For SD3-Medium -trtexec --builderOptimizationLevel=4 --stronglyTyped --onnx=./model.onnx \ - --minShapes=hidden_states:2x16x128x128,timestep:2,encoder_hidden_states:2x333x4096,pooled_projections:2x2048 \ - --optShapes=hidden_states:16x16x128x128,timestep:16,encoder_hidden_states:16x333x4096,pooled_projections:16x2048 \ - --maxShapes=hidden_states:16x16x128x128,timestep:16,encoder_hidden_states:16x333x4096,pooled_projections:16x2048 \ - --saveEngine=model.plan - -# For FLUX-Dev FP8 -trtexec --onnx=./model.onnx --fp8 --bf16 --stronglyTyped \ - --minShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \ - --optShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \ - --maxShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \ - --saveEngine=model.plan -``` - -**Please note that `maxShapes` represents the maximum shape of the given tensor. If you want to use a larger batch size or any other dimensions, feel free to adjust the value accordingly.** - -### Run End-to-end Stable Diffusion Pipeline with Model Optimizer Quantized ONNX Model and demoDiffusion - -#### demoDiffusion - -If you want to run end-to-end SD/SDXL pipeline with Model Optimizer quantized UNet to generate images and measure latency on target GPUs, here are the steps: - -- Clone a copy of [demo/Diffusion repo](https://github.com/NVIDIA/TensorRT/tree/release/10.2/demo/Diffusion). - -- Following the README from demoDiffusion to set up the pipeline, and run a baseline txt2img example (fp16): - -```sh -# SDXL -python demo_txt2img_xl.py "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow" --negative-prompt "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" --version xl-1.0 --scheduler Euler --denoising-steps 30 --seed 2946901 -# Please refer to the examples provided in the demoDiffusion SD/SDXL pipeline. -``` - -Note, it will take some time to build TRT engines for the first time - -- Replace the fp16 backbone TRT engine with int8 engine generated in [Build the TRT engine for the Quantized ONNX Backbone](#build-the-trt-engine-for-the-quantized-onnx-backbone), e.g.,: - -```sh -cp -r {YOUR_UNETXL}.plan ./engine/ -``` - -Note, the engines must be built on the same GPU, and ensure that the INT8 engine name matches the names of the FP16 engines to enable compatibility with the demoDiffusion pipeline. - -- Run the above txt2img example command again. You can compare the generated images and latency for fp16 vs int8. - Similarly, you could run end-to-end pipeline with Model Optimizer quantized backbone and corresponding examples in demoDiffusion with other diffusion models. - -### Running the inference pipeline with DeviceModel - -DeviceModel is an interface designed to run TensorRT engines like torch models. It takes torch inputs and returns torch outputs. Under the hood, DeviceModel exports a torch checkpoint to ONNX and then generates a TensorRT engine from it. This allows you to swap the backbone of the diffusion pipeline with DeviceModel and execute the pipeline for your desired prompt. - -Generate a quantized torch checkpoint using the [Script](./quantization/quantize.py) shown below: - -```bash -python quantize.py \ - --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \ - --format fp8 \ - --batch-size {1|2} \ - --calib-size 128 \ - --n-steps 20 \ - --quantized-torch-ckpt-save-path ./{MODEL}_fp8.pt \ - --collect-method default -``` - -Generate images for the quantized checkpoint with the following [Script](./quantization/diffusion_trt.py): - -```bash -python diffusion_trt.py \ - --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \ - --prompt "A cat holding a sign that says hello world" \ - [--override-model-path /path/to/model] \ - [--restore-from ./{MODEL}_fp8.pt] \ - [--onnx-load-path {ONNX_DIR}] \ - [--trt-engine-load-path {ENGINE_DIR}] \ - [--dq-only] \ - [--torch] \ - [--save-image-as /path/to/image] \ - [--benchmark] \ - [--torch-compile] \ - [--skip-image] -``` - -This script will save the output image as `./{MODEL}.png` and report the latency of the TensorRT backbone. -To generate the image with FP16|BF16 precision, you can run the command shown above without the `--restore-from` argument. - -While loading a TensorRT engine using the --trt-engine-load-path argument, it is recommended to load only engines generated using this pipeline. - -#### Demo Images - -| SDXL FP16 | SDXL INT8 | -|:---------:|:---------:| -| ![FP16](./quantization/assets/xl_base-fp16.png) | ![INT8](./quantization/assets/xl_base-int8.png) | +ONNX export and TensorRT engine instructions are documented in [`quantization/ONNX-TRT-Deployment.md`](./quantization/ONNX-TRT-Deployment.md). ### LoRA diff --git a/examples/diffusers/distillation/README.md b/examples/diffusers/distillation/README.md new file mode 100644 index 0000000000..ce57c60363 --- /dev/null +++ b/examples/diffusers/distillation/README.md @@ -0,0 +1,153 @@ +# LTX-2 Distillation Training with ModelOpt + +Knowledge distillation for LTX-2 DiT models using NVIDIA ModelOpt. A frozen **teacher** guides a trainable **student** through a combined loss: + +```text +L_total = α × L_task + (1-α) × L_distill +``` + +Currently supported: + +- **Quantization-Aware Distillation (QAD)** — student uses ModelOpt fake quantization + +Planned: + +- **Sparsity-Aware Distillation (SAD)** — student uses ModelOpt sparsity + +## Installation + +```bash +# From the distillation example directory +cd examples/diffusers/distillation + +# Install Model-Optimizer (from repo root) +pip install -e ../../.. + +# Install all dependencies (ltx-trainer, ltx-core, ltx-pipelines, omegaconf) +pip install -r requirements.txt +``` + +## Quick Start + +### 1. Prepare Your Dataset + +Use the ltx-trainer preprocessing to extract latents and text embeddings: + +```bash +python -m ltx_trainer.preprocess \ + --input_dir /path/to/videos \ + --output_dir /path/to/preprocessed \ + --model_path /path/to/ltx2/checkpoint.safetensors +``` + +### 2. Configure + +Copy and edit the example config: + +```bash +cp configs/distillation_example.yaml configs/my_experiment.yaml +``` + +Key settings to update: + +```yaml +model: + model_path: "/path/to/ltx2/checkpoint.safetensors" + text_encoder_path: "/path/to/gemma/model" + +data: + preprocessed_data_root: "/path/to/preprocessed/data" + +distillation: + distillation_alpha: 0.5 # 1.0 = pure task loss, 0.0 = pure distillation + quant_cfg: "FP8_DEFAULT_CFG" # or INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, null + +# IMPORTANT: disable ltx-trainer's built-in quantization +acceleration: + quantization: null +``` + +### 3. Run Training + +#### Single GPU + +```bash +python distillation_trainer.py --config configs/my_experiment.yaml +``` + +#### Multi-GPU (Single Node) with Accelerate + +```bash +accelerate launch \ + --config_file configs/accelerate/fsdp.yaml \ + --num_processes 8 \ + distillation_trainer.py --config configs/my_experiment.yaml +``` + +#### Multi-node Training with Accelerate + +To launch on multiple nodes, make sure to set the following environment variables on each node: + +- `NUM_NODES`: Total number of nodes +- `GPUS_PER_NODE`: Number of GPUs per node +- `NODE_RANK`: Unique rank/index of this node (0-based) +- `MASTER_ADDR`: IP address of the master node (rank 0) +- `MASTER_PORT`: Communication port (e.g., 29500) + +Then run this (on every node): + +```bash +accelerate launch \ + --config_file configs/accelerate/fsdp.yaml \ + --num_machines $NUM_NODES \ + --num_processes $((NUM_NODES * GPUS_PER_NODE)) \ + --machine_rank $NODE_RANK \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + distillation_trainer.py --config configs/my_experiment.yaml +``` + +**Config overrides** can be passed via CLI using dotted notation: + +```bash +accelerate launch ... distillation_trainer.py \ + --config configs/my_experiment.yaml \ + ++distillation.distillation_alpha=0.6 \ + ++distillation.quant_cfg=INT8_DEFAULT_CFG \ + ++optimization.learning_rate=1e-5 +``` + +## Configuration Reference + +### Calibration + +Before training begins, calibration runs full denoising inference to collect activation statistics for accurate quantizer scales. This is cached as a step-0 checkpoint and reused on subsequent runs. + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `calibration_prompts_file` | null | Text file with one prompt per line. Use the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' if null. | +| `calibration_size` | 128 | Number of prompts (each runs a full denoising loop) | +| `calibration_n_steps` | 30 | Denoising steps per prompt | +| `calibration_guidance_scale` | 4.0 | CFG scale (should match inference-time) | + +### Checkpoint Resume + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `resume_from_checkpoint` | null | `"latest"` to auto-detect, or explicit path | +| `must_save_by` | null | Minutes after which to save and exit (for Slurm time limits) | +| `restore_quantized_checkpoint` | null | Restore a pre-quantized model (skips calibration) | +| `save_quantized_checkpoint` | null | Path to save the final quantized model | + +### Custom Quantization Configs + +To define custom quantization configs, add entries to `CUSTOM_QUANT_CONFIGS` in `distillation_trainer.py`: + +```python +CUSTOM_QUANT_CONFIGS["MY_FP8_CFG"] = { + "quant_cfg": mtq.FP8_DEFAULT_CFG["quant_cfg"], + "algorithm": "max", +} +``` + +Then reference it in your YAML: `quant_cfg: MY_FP8_CFG`. diff --git a/examples/diffusers/distillation/configs/accelerate/fsdp.yaml b/examples/diffusers/distillation/configs/accelerate/fsdp.yaml new file mode 100644 index 0000000000..35e3edf778 --- /dev/null +++ b/examples/diffusers/distillation/configs/accelerate/fsdp.yaml @@ -0,0 +1,45 @@ +# FSDP Configuration +# +# FULL_SHARD across all GPUs for maximum memory efficiency. +# For multi-node training with `accelerate launch`. +# +# Usage: +# accelerate launch \ +# --config_file configs/accelerate/fsdp.yaml \ +# --num_processes 16 \ +# --num_machines 2 \ +# --machine_rank $MACHINE_RANK \ +# --main_process_ip $MASTER_IP \ +# --main_process_port 29500 \ +# distillation_trainer.py --config configs/distillation_example.yaml + +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false + +fsdp_config: + # FULL_SHARD: Shard optimizer states, gradients, and parameters across ALL GPUs + # This provides maximum memory efficiency for large models like LTX-2 19B + # Parameters are fully sharded across all nodes (not replicated) + fsdp_sharding_strategy: FULL_SHARD + + # Enable activation checkpointing to reduce memory during backward pass + # Critical for 19B model training + fsdp_activation_checkpointing: true + + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock + fsdp_use_orig_params: true + fsdp_version: 1 + +# Note: num_machines and num_processes are overridden by accelerate launch command-line args +# These are just defaults for local testing +num_machines: 1 +num_processes: 8 diff --git a/examples/diffusers/distillation/configs/distillation_example.yaml b/examples/diffusers/distillation/configs/distillation_example.yaml new file mode 100644 index 0000000000..6f7778a35c --- /dev/null +++ b/examples/diffusers/distillation/configs/distillation_example.yaml @@ -0,0 +1,142 @@ +# LTX-2 Distillation Training Configuration with ModelOpt + +# Model Configuration +model: + # Path to the LTX-2 checkpoint (used for both teacher and student) + model_path: "/path/to/ltx2/checkpoint.safetensors" + + # Path to Gemma text encoder (required for LTX-2) + text_encoder_path: "/path/to/gemma/model" + + # Training mode: "lora" is not supported yet + training_mode: "full" + +# Distillation Configuration +distillation: + # Path to teacher model (if different from model.model_path) + # Set to null to use the same checkpoint as student (loaded without quantization) + teacher_model_path: + + # Weight for task loss: L_total = α * L_task + (1-α) * L_distill + # α = 1.0: pure task loss (no distillation) + # α = 0.0: pure distillation loss + distillation_alpha: 0.0 + + # Type of distillation loss + # "mse": Mean squared error (recommended - transformer outputs are continuous velocity predictions) + # "cosine": Cosine similarity loss (matches direction only, ignores magnitude) + distillation_loss_type: "mse" + + # Data type for teacher model (bfloat16 recommended for memory efficiency) + teacher_dtype: "bfloat16" + + # ModelOpt Quantization Settings + # Name of the mtq config, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG. + # Custom configs defined in CUSTOM_QUANT_CONFIGS (distillation_trainer.py) are also supported. + quant_cfg: + + # Full-inference calibration settings (matching PTQ workflow). + # Each prompt runs a complete denoising loop through the DiT, covering all noise levels. + # Path to a text file with one prompt per line. If null, uses the default + # HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (same as PTQ). + calibration_prompts_file: + # Total number of calibration prompts (set to 0 to skip calibration) + calibration_size: 128 + # Number of denoising steps per prompt (matches PTQ --n-steps) + calibration_n_steps: 30 + # CFG guidance scale during calibration (4.0 = PTQ default, calls transformer + # twice per step for positive + negative prompt; 1.0 = no CFG, saves memory) + calibration_guidance_scale: 4.0 + + # Path to restore a previously quantized model (from mto.save) + restore_quantized_checkpoint: + + # Path to save the final quantized model checkpoint + save_quantized_checkpoint: + + # Resume from a full training state checkpoint (saves model + optimizer + RNG + step) + # Set to "latest" to auto-find the most recent checkpoint in output_dir/checkpoints/ + # Or set to an explicit path like "/path/to/checkpoints/step_001000" + resume_from_checkpoint: latest + + # Time-limit-aware saving for Slurm jobs. + # Minutes after which training must save a checkpoint and exit gracefully. + # Set slightly below your Slurm --time limit (e.g. time=30min -> must_save_by: 25). + # Timer starts when train() is called (after model loading/calibration). + must_save_by: + + # Debug/Test: Use mock data instead of real preprocessed data + # Useful for testing the training pipeline without preparing a dataset + use_mock_data: false + mock_data_samples: 100 + +# Training Strategy +training_strategy: + name: "text_to_video" + first_frame_conditioning_p: 0.1 + with_audio: false + +# Optimization Configuration +optimization: + learning_rate: 2.0e-6 + steps: 10000 + batch_size: 1 + gradient_accumulation_steps: 4 + max_grad_norm: 1.0 + optimizer_type: "adamw" # # Use "adamw8bit" for memory efficiency + scheduler_type: "cosine" + enable_gradient_checkpointing: true # Essential for memory savings + +# Acceleration Configuration +acceleration: + mixed_precision_mode: "bf16" + + # NOTE: Set to null - we use ModelOpt quantization instead of ltx-trainer's quanto + quantization: + + # 8-bit text encoder for memory savings + load_text_encoder_in_8bit: false + +# Data Configuration +data: + # Path to preprocessed training data (created by process_dataset.py) + preprocessed_data_root: "/path/to/preprocessed/data" + num_dataloader_workers: 2 + +# Validation Configuration +validation: + prompts: + - "A beautiful sunset over the ocean with gentle waves" + - "A cat playing with a ball of yarn in a cozy living room" + negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted" + video_dims: [512, 320, 33] # [width, height, frames] + frame_rate: 25.0 + inference_steps: 30 + interval: 500 # Validate every 500 steps + guidance_scale: 4.0 + seed: 42 + +# Checkpointing Configuration +checkpoints: + interval: 1000 # Save checkpoint every 1000 steps + keep_last_n: 3 # Keep only last 3 checkpoints + precision: "bfloat16" + +# Weights & Biases Logging +wandb: + enabled: true + project: "ltx2-distillation" + entity: # Your W&B username or team + tags: + - "distillation" + - "modelopt" + log_validation_videos: true + +# Flow Matching Configuration +flow_matching: + timestep_sampling_mode: "shifted_logit_normal" + timestep_sampling_params: {} + +# General Settings +seed: 42 +output_dir: "./outputs/distillation_experiment" diff --git a/examples/diffusers/distillation/distillation_trainer.py b/examples/diffusers/distillation/distillation_trainer.py new file mode 100644 index 0000000000..d98278b9af --- /dev/null +++ b/examples/diffusers/distillation/distillation_trainer.py @@ -0,0 +1,1832 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Distillation Trainer for LTX-2 DiT Model with ModelOpt Quantization + +This module implements sparsity/quantization-aware distillation training where: +- Teacher: Original unsparsified/unquantized model (inference only) +- Student: Quantized model using ModelOpt's fake quantization (trainable) + +The distillation loss combines: +- L_task: Standard flow matching MSE loss (student_pred vs velocity_target) +- L_distill: Distillation MSE loss (student_pred vs teacher_pred) + +Usage: + python distillation_trainer.py --config configs/distillation_example.yaml +""" + +from __future__ import annotations + +import argparse +import gc +import json +import os +import time +from pathlib import Path +from typing import Literal + +import torch +import torch.distributed as dist +from ltx_trainer import logger +from ltx_trainer.config import ConfigBaseModel, LtxTrainerConfig +from ltx_trainer.model_loader import load_transformer +from ltx_trainer.trainer import IS_MAIN_PROCESS, LtxvTrainer +from omegaconf import OmegaConf +from pydantic import Field +from torch import Tensor + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq + +# Custom quantization configs. Checked before mtq built-in configs. +# Add your own configs here; they take precedence over mtq.* attributes. +CUSTOM_QUANT_CONFIGS: dict[str, dict] = { + # Example: override NVFP4 with a different algorithm + # "MY_NVFP4_CFG": { + # "quant_cfg": mtq.NVFP4_DEFAULT_CFG["quant_cfg"], + # "algorithm": "max", + # }, +} + + +# IS_MAIN_PROCESS (from ltx_trainer) checks LOCAL_RANK == 0, which is True on +# every node in multi-node training. For file writes on a shared filesystem +# (Lustre) we need a global-rank-0 check so that only a single process writes. +def is_global_rank0() -> bool: + """Check if this is global rank 0. Safe to call before or after dist init.""" + if dist.is_initialized(): + return dist.get_rank() == 0 + return os.environ.get("RANK", "0") == "0" + + +def get_quant_config(quant_cfg_name: str) -> dict: + """ + Resolve a quantization config by name. + + Lookup order: + 1. CUSTOM_QUANT_CONFIGS (user-defined overrides in this file) + 2. mtq. (built-in ModelOpt configs, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG) + + Args: + quant_cfg_name: Name of the quantization config, e.g. "FP8_DEFAULT_CFG". + + Returns: + A copy of the quantization config dict. + """ + # Check custom configs first + if quant_cfg_name in CUSTOM_QUANT_CONFIGS: + logger.info(f"Using custom quant config: {quant_cfg_name}") + return CUSTOM_QUANT_CONFIGS[quant_cfg_name].copy() + + # Fall back to mtq built-in configs + cfg = getattr(mtq, quant_cfg_name, None) + if cfg is None: + available_custom = list(CUSTOM_QUANT_CONFIGS.keys()) + available_mtq = [ + attr for attr in dir(mtq) if attr.endswith("_CFG") and not attr.startswith("_") + ] + raise ValueError( + f"Unknown quant_cfg: '{quant_cfg_name}'. " + f"Available custom: {available_custom}. " + f"Available mtq built-in: {available_mtq}" + ) + logger.info(f"Using mtq built-in quant config: {quant_cfg_name}") + return cfg.copy() + + +class MockDataset(torch.utils.data.Dataset): + """ + Mock dataset that produces random data matching the expected training format. + + This is useful for testing the training pipeline without preparing real data. + The output format matches what PrecomputedDataset produces, with keys: + - "latents": video latent tensors and metadata + - "conditions": text embeddings and attention masks + + Note: prompt_embed_dim should be 3840 (the connector's inner_dim = 30 heads * 128 dim), + NOT 4096 (Gemma's raw hidden size). The PrecomputedDataset stores embeddings that have + already been projected through the feature_extractor_linear layer. + """ + + def __init__( + self, + width: int = 512, + height: int = 320, + num_frames: int = 33, + dataset_length: int = 100, + latent_dim: int = 128, + latent_spatial_compression_ratio: int = 32, + latent_temporal_compression_ratio: int = 8, + prompt_embed_dim: int = 3840, # Connector inner_dim, not Gemma's 4096 + prompt_sequence_length: int = 256, + fps: int = 25, + dtype: torch.dtype = torch.bfloat16, # Must match model dtype + ): + """ + Initialize mock dataset. + + Args: + width: Video width in pixels (must be divisible by 32) + height: Video height in pixels (must be divisible by 32) + num_frames: Number of video frames (should be 8k+1 for proper compression) + dataset_length: Number of samples in the dataset + latent_dim: Latent channel dimension (128 for LTX-2) + latent_spatial_compression_ratio: Spatial compression ratio (32 for LTX-2) + latent_temporal_compression_ratio: Temporal compression ratio (8 for LTX-2) + prompt_embed_dim: Text embedding dimension after projection (3840 for LTX-2, + which is connector's inner_dim = 30 heads * 128 dim_head) + prompt_sequence_length: Max text sequence length + fps: Frames per second + dtype: Data type for floating point tensors (must match model dtype, default bfloat16) + """ + self.width = width + self.height = height + self.num_frames = num_frames + self.dataset_length = dataset_length + self.latent_dim = latent_dim + self.num_latent_frames = (num_frames - 1) // latent_temporal_compression_ratio + 1 + self.latent_height = height // latent_spatial_compression_ratio + self.latent_width = width // latent_spatial_compression_ratio + self.prompt_embed_dim = prompt_embed_dim + self.prompt_sequence_length = prompt_sequence_length + self.fps = fps + self.dtype = dtype + + def __len__(self) -> int: + return self.dataset_length + + def __getitem__(self, idx: int) -> dict: + """ + Get a mock sample. + + Returns format expected by training strategy: + - latents: dict with "latents" tensor [C, F, H, W] and metadata + - conditions: dict with "prompt_embeds" and "prompt_attention_mask" + """ + return { + # Video latents (key: "latents" to match PrecomputedDataset) + "latents": { + "latents": torch.randn( + self.latent_dim, + self.num_latent_frames, + self.latent_height, + self.latent_width, + dtype=self.dtype, # Must match model dtype + ), + "num_frames": torch.tensor(self.num_latent_frames), + "height": torch.tensor(self.latent_height), + "width": torch.tensor(self.latent_width), + "fps": torch.tensor(self.fps), + }, + # Text conditions (key: "conditions" to match PrecomputedDataset) + "conditions": { + "prompt_embeds": torch.randn( + self.prompt_sequence_length, + self.prompt_embed_dim, + dtype=self.dtype, # Must match model dtype + ), + # Attention mask must be numeric (not bool) for _run_connectors + # Using int8 to save memory (1 byte vs 8 bytes for long) + "prompt_attention_mask": torch.ones( + self.prompt_sequence_length, + dtype=torch.int8, + ), + }, + "idx": idx, + } + + +class DistillationConfig(ConfigBaseModel): + """Configuration for distillation-specific parameters.""" + + teacher_model_path: str | Path | None = Field( + default=None, + description="Path to the teacher model checkpoint. If None, uses the same as model.model_path " + "(teacher is loaded without quantization).", + ) + + distillation_alpha: float = Field( + default=0.5, + description="Weight for the task loss. Distillation loss weight is (1 - alpha). " + "alpha=1.0 means no distillation (pure task loss), " + "alpha=0.0 means pure distillation loss.", + ge=0.0, + le=1.0, + ) + + distillation_loss_type: Literal["mse", "cosine"] = Field( + default="mse", + description="Type of distillation loss. 'mse' is recommended since transformer outputs " + "are continuous velocity predictions in latent space (not probabilities). " + "'cosine' matches direction only, ignoring magnitude.", + ) + + teacher_dtype: Literal["bfloat16", "float16", "float32"] = Field( + default="bfloat16", + description="Data type for teacher model. BFloat16 is recommended for memory efficiency.", + ) + + # ModelOpt Quantization Settings + quant_cfg: str | None = Field( + default=None, + description="Name of the ModelOpt quantization config to apply to the student model. " + "Looked up first in CUSTOM_QUANT_CONFIGS (distillation_trainer.py), then as mtq.. " + "Examples: 'FP8_DEFAULT_CFG', 'INT8_DEFAULT_CFG', 'NVFP4_DEFAULT_CFG'. " + "Set to None to disable quantization.", + ) + + # Calibration settings (full-inference calibration, matching PTQ workflow) + calibration_prompts_file: str | Path | None = Field( + default=None, + description="Path to a text file with one calibration prompt per line. " + "If None, uses the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' ", + ) + + calibration_size: int = Field( + default=128, + description="Total number of calibration prompts to use. Each prompt runs a full " + "denoising inference through the DiT, covering all noise levels. ", + ge=0, + ) + + calibration_n_steps: int = Field( + default=30, + description="Number of denoising steps per calibration prompt. Each step calls the " + "transformer at a different noise level.", + ge=1, + ) + + calibration_guidance_scale: float = Field( + default=4.0, + description="CFG guidance scale during calibration. Default 4.0.", + ge=1.0, + ) + + restore_quantized_checkpoint: str | Path | None = Field( + default=None, + description="Path to restore a previously quantized model from mto.save().", + ) + + save_quantized_checkpoint: str | Path | None = Field( + default=None, + description="Path to save the final quantized model checkpoint.", + ) + + # Checkpoint resume settings + resume_from_checkpoint: str | Path | None = Field( + default=None, + description="Path to a training state checkpoint directory (from save_training_state) to resume " + "training from. Restores model weights, optimizer, LR scheduler, RNG states, and step counter. " + "Set to 'latest' to auto-find the latest checkpoint in output_dir/checkpoints/.", + ) + + must_save_by: float | None = Field( + default=None, + description="Minutes after which training must save a checkpoint and exit. " + "Use this when running under a Slurm time limit — set to a value slightly less " + "than the time limit (e.g., time_limit=30min → must_save_by=25) to ensure " + "a checkpoint is saved before the job is killed. Timer starts at train() entry. " + "Set to None to disable.", + gt=0, + ) + + # Debug/Test options + use_mock_data: bool = Field( + default=False, + description="Use mock data instead of real preprocessed data for testing.", + ) + + mock_data_samples: int = Field( + default=100, + description="Number of mock samples to generate when use_mock_data is True.", + ge=1, + ) + + +class DistillationTrainerConfig(LtxTrainerConfig): + """Extended trainer config with distillation settings.""" + + distillation: DistillationConfig = Field( + default_factory=DistillationConfig, + description="Distillation-specific configuration.", + ) + + +class DistillationTrainer(LtxvTrainer): + """ + Distillation trainer that extends LtxvTrainer with: + - Teacher model loading and inference + - ModelOpt quantization for student + - Combined task + distillation loss + """ + + def __init__(self, config: DistillationTrainerConfig) -> None: + # Store distillation config before parent init (needed by overrides called during super().__init__) + self._distillation_config = config.distillation + # Will be populated by _load_text_encoder_and_cache_embeddings() during super().__init__ + self._cached_calibration_embeddings: list | None = None + + # Create base trainer config (without distillation section) + trainer_config = LtxTrainerConfig( + **{k: v for k, v in config.model_dump().items() if k != "distillation"} + ) + + # Initialize parent (loads student model, sets up accelerator) + # Note: _prepare_models_for_training() is overridden to NOT call + # accelerator.prepare() on the student — we defer that to _init_optimizer() + # so model+optimizer can be prepared together (required by FSDP2). + super().__init__(trainer_config) + + # Load teacher model (after parent init so we have accelerator) + # Teacher is loaded, frozen, and prepared with a dummy optimizer. + self._load_teacher_model() + + logger.info( + f"Distillation training initialized with alpha={self._distillation_config.distillation_alpha:.2f}" + ) + + def _prepare_models_for_training(self) -> None: + """ + Override parent to defer accelerator.prepare() for the student model. + + The parent calls accelerator.prepare(transformer) here, but FSDP2 requires + model and optimizer to be prepared together. So we do everything the parent + does EXCEPT the accelerator.prepare() call — that happens in _init_optimizer() + where we can call prepare(model, optimizer, scheduler) together. + """ + from accelerate.utils import DistributedType + + # For FSDP + LoRA: Cast entire model to FP32 for uniform dtype + if ( + self._accelerator.distributed_type == DistributedType.FSDP + and self._config.model.training_mode == "lora" + ): + logger.debug("FSDP: casting transformer to FP32 for uniform dtype") + self._transformer = self._transformer.to(dtype=torch.float32) + + # Enable gradient checkpointing if requested + transformer = ( + self._transformer.get_base_model() + if hasattr(self._transformer, "get_base_model") + else self._transformer + ) + transformer.set_gradient_checkpointing( + self._config.optimization.enable_gradient_checkpointing + ) + + # Keep frozen models on CPU for memory efficiency + self._vae_decoder = self._vae_decoder.to("cpu") + if self._vae_encoder is not None: + self._vae_encoder = self._vae_encoder.to("cpu") + + # NOTE: We intentionally do NOT call accelerator.prepare(self._transformer) here. + # It will be called in _init_optimizer() together with the optimizer, which is + # required for FSDP2 compatibility. This also works fine with FSDP1. + + # Log GPU memory usage + vram_usage_gb = torch.cuda.memory_allocated() / 1024**3 + logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB") + + def _load_text_encoder_and_cache_embeddings(self): + """ + Override parent to also cache calibration prompt embeddings before Gemma is unloaded. + + The parent method loads the full Gemma text encoder, caches validation prompt embeddings, + then UNLOADS the heavy Gemma model (sets model/tokenizer/feature_extractor_linear to None) + to free VRAM. Only the lightweight embedding connectors remain. + + We hook in here to also cache calibration prompt embeddings while the full text encoder + is still available. These cached embeddings are later used by _run_inference_calibration() + via the ValidationSampler's CachedPromptEmbeddings mechanism. + """ + from ltx_trainer.model_loader import load_text_encoder + from ltx_trainer.validation_sampler import CachedPromptEmbeddings + + # Call parent to load text encoder, cache validation embeddings, and unload Gemma. + # But we need to intercept BEFORE the unload. We re-implement the parent logic + # with our addition in the middle. + + logger.debug("Loading text encoder...") + self._text_encoder = load_text_encoder( + checkpoint_path=self._config.model.model_path, + gemma_model_path=self._config.model.text_encoder_path, + device="cuda", + dtype=torch.bfloat16, + load_in_8bit=self._config.acceleration.load_text_encoder_in_8bit, + ) + + # Cache validation embeddings (same as parent) + cached_validation = None + if self._config.validation.prompts: + logger.info( + f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts..." + ) + cached_validation = [] + with torch.inference_mode(): + for prompt in self._config.validation.prompts: + v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt) + v_ctx_neg, a_ctx_neg, _ = self._text_encoder( + self._config.validation.negative_prompt + ) + cached_validation.append( + CachedPromptEmbeddings( + video_context_positive=v_ctx_pos.cpu(), + audio_context_positive=a_ctx_pos.cpu(), + video_context_negative=v_ctx_neg.cpu() + if v_ctx_neg is not None + else None, + audio_context_negative=a_ctx_neg.cpu() + if a_ctx_neg is not None + else None, + ) + ) + + # Cache calibration prompt embeddings while the heavy text encoder is still loaded. + # Only needed if we'll actually run fresh calibration (Path C). Skip if a + # resumable checkpoint, user-specified checkpoint, or step 0 checkpoint exists. + calib_cfg = self._distillation_config + if ( + calib_cfg.quant_cfg is not None + and calib_cfg.calibration_size > 0 + and self._needs_fresh_calibration() + ): + prompts = self._load_calibration_prompts() + negative_prompt = getattr( + self._config.validation, + "negative_prompt", + "worst quality, inconsistent motion, blurry, jittery, distorted", + ) + logger.info( + f"Pre-computing embeddings for {len(prompts)} calibration prompts " + f"(guidance_scale={calib_cfg.calibration_guidance_scale})..." + ) + self._cached_calibration_embeddings = [] + use_cfg = calib_cfg.calibration_guidance_scale != 1.0 + with torch.inference_mode(): + for prompt in prompts: + v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt) + v_ctx_neg, a_ctx_neg = None, None + if use_cfg: + v_ctx_neg, a_ctx_neg, _ = self._text_encoder(negative_prompt) + self._cached_calibration_embeddings.append( + CachedPromptEmbeddings( + video_context_positive=v_ctx_pos.cpu(), + audio_context_positive=a_ctx_pos.cpu(), + video_context_negative=v_ctx_neg.cpu() + if v_ctx_neg is not None + else None, + audio_context_negative=a_ctx_neg.cpu() + if a_ctx_neg is not None + else None, + ) + ) + logger.info(f"Cached {len(self._cached_calibration_embeddings)} calibration embeddings") + + # Unload heavy components to free VRAM, keeping only the embedding connectors + self._text_encoder.model = None + self._text_encoder.tokenizer = None + self._text_encoder.feature_extractor_linear = None + gc.collect() + torch.cuda.empty_cache() + logger.debug("Validation/calibration prompt embeddings cached. Gemma model unloaded") + + return cached_validation + + def _load_models(self) -> None: + """ + Load the LTX-2 model components with ModelOpt quantization for student. + + This overrides the parent method to: + 1. Load models as usual (without ltx-trainer's quantization) + 2. Apply ModelOpt fake quantization to the student transformer + """ + # Call parent to load all models normally + super()._load_models() + + # Apply ModelOpt quantization to student if configured + if self._distillation_config.quant_cfg is not None: + self._apply_modelopt_quantization() + gc.collect() + torch.cuda.empty_cache() + logger.info(f"Quantized model: {self._transformer}") + + def _needs_fresh_calibration(self) -> bool: + """Check whether fresh quantization calibration will be needed. + + Returns False if an existing checkpoint can be restored instead + (Path A, B, or B2 in _apply_modelopt_quantization), meaning we can + skip the expensive calibration embedding caching. + """ + cfg = self._distillation_config + + # Path A: resume checkpoint with modelopt_state.pt + if cfg.resume_from_checkpoint is not None: + checkpoint_dir = self._find_resume_checkpoint(cfg.resume_from_checkpoint) + if checkpoint_dir is not None: + if (checkpoint_dir / "modelopt_state.pt").exists(): + return False + + # Path B: user-specified quantized checkpoint + if cfg.restore_quantized_checkpoint is not None: + return False + + # Path B2: auto-detected step 0 checkpoint + step0_path = self._get_checkpoints_dir() / "step_000000_quantized" / "backbone.pt" + return not step0_path.exists() + + def _apply_modelopt_quantization(self) -> None: + """ + Apply ModelOpt fake quantization to the student transformer. + + Four paths are supported (checked in order): + + Path A - Resume from training checkpoint: + If resume_from_checkpoint is set, restore only the quantization module + architecture (fake quantizer modules) from the saved modelopt_state.pt. + The actual trained weights (including quantizer scales) will be loaded + later by accelerator.load_state() in _load_training_state(). + + Path B - Restore from a user-specified quantized checkpoint: + If restore_quantized_checkpoint is set, restore both architecture and + weights from a complete mto.save() checkpoint. + + Path B2 - Auto-detect step 0 quantized checkpoint: + If a previous run already completed calibration and saved the step 0 + checkpoint (step_000000_quantized/backbone.pt), restore from it + automatically. This avoids re-running the expensive calibration. + + Path C - Fresh quantization with full-inference calibration: + Apply mtq.quantize() with a forward_loop that runs full denoising + inference (like the PTQ workflow), covering all noise levels. + After calibration, saves the result as step 0 checkpoint for future runs. + """ + quant_cfg_name = self._distillation_config.quant_cfg + if not quant_cfg_name: + logger.info("No quant_cfg specified, skipping quantization") + return + + # Path A: Resume from training checkpoint — restore architecture only. + # The trained weights (including quantizer scales) are loaded later by + # accelerator.load_state() in _load_training_state(). + resume_path = self._distillation_config.resume_from_checkpoint + if resume_path is not None: + checkpoint_dir = self._find_resume_checkpoint(resume_path) + if checkpoint_dir is not None: + modelopt_state_path = checkpoint_dir / "modelopt_state.pt" + if modelopt_state_path.exists(): + logger.info( + f"Resuming: restoring quantization architecture from " + f"{modelopt_state_path} (weights loaded later by accelerator)" + ) + # Security NOTE: weights_only=False is used on ModelOpt-generated state, + # not on untrusted user input. + state = torch.load(modelopt_state_path, weights_only=False, map_location="cpu") + self._transformer = mto.restore_from_modelopt_state(self._transformer, state) + logger.info("Quantization architecture restored for resume") + return + else: + logger.warning( + f"modelopt_state.pt not found in {checkpoint_dir}, " + "falling through to fresh quantization" + ) + + # Path B: Restore from a standalone quantized checkpoint (architecture + weights). + if self._distillation_config.restore_quantized_checkpoint is not None: + restore_path = str(self._distillation_config.restore_quantized_checkpoint) + logger.info(f"Restoring quantized model from {restore_path}") + mto.restore(self._transformer, restore_path) + return + + # Path B2: Auto-detect step 0 quantized checkpoint from a previous run. + # If calibration was already completed and saved, reuse it instead of + # re-running the expensive calibration process. + step0_path = self._get_checkpoints_dir() / "step_000000_quantized" / "backbone.pt" + if step0_path.exists(): + logger.info( + f"Found existing step 0 quantized checkpoint at {step0_path}, " + "restoring instead of re-running calibration" + ) + try: + mto.restore(self._transformer, str(step0_path)) + return + except Exception as e: + logger.warning( + f"Failed to restore step 0 checkpoint (file may be corrupted): {e}. " + "Falling through to fresh quantization." + ) + + # Path C: Fresh quantization with full-inference calibration. + logger.info(f"Applying ModelOpt quantization ({quant_cfg_name}) to student transformer...") + + quant_config = get_quant_config(quant_cfg_name) + + def forward_loop(model): + """Run full-inference calibration covering all noise levels.""" + self._run_inference_calibration(model) + + mtq.quantize(self._transformer, quant_config, forward_loop=forward_loop) + + # Free cached calibration embeddings — no longer needed after quantization + self._cached_calibration_embeddings = None + + logger.info(f"ModelOpt quantization ({quant_cfg_name}) applied successfully") + + # Save the freshly quantized+calibrated model as "step 0" checkpoint. + # This avoids re-running calibration if training is interrupted before the + # first regular checkpoint. On resume, Path B2 auto-detects and loads this. + # Only model + quantizer scales are saved (no optimizer/scheduler state at step 0). + # We use atomic save (write to tmp, then rename) to prevent corrupt checkpoints. + step0_dir = self._get_checkpoints_dir() / "step_000000_quantized" + step0_path = step0_dir / "backbone.pt" + # Only global rank 0 saves (all ranks have identical models pre-FSDP); + # others wait at the barrier. Atomic save (tmp + rename) prevents corruption. + if is_global_rank0(): + step0_dir.mkdir(parents=True, exist_ok=True) + step0_tmp_path = step0_dir / "backbone.pt.tmp" + logger.info(f"Saving quantized model (step 0) to {step0_path}") + mto.save(self._transformer, str(step0_tmp_path)) + step0_tmp_path.rename(step0_path) + logger.info("Step 0 quantized checkpoint saved") + if dist.is_initialized(): + dist.barrier() + + def _create_mock_dataset(self) -> MockDataset: + """Create a mock dataset for testing without real data.""" + # Get video dimensions from validation config or use defaults + video_dims = getattr(self._config.validation, "video_dims", [512, 320, 33]) + width, height, num_frames = video_dims + + logger.info( + f"Creating mock dataset with {self._distillation_config.mock_data_samples} samples " + f"(video: {width}x{height}x{num_frames})" + ) + + return MockDataset( + width=width, + height=height, + num_frames=num_frames, + dataset_length=self._distillation_config.mock_data_samples, + ) + + def _load_calibration_prompts(self) -> list[str]: + """ + Load calibration prompts for full-inference quantization calibration. + + Follows the same pattern as the PTQ workflow (examples/diffusers/quantization/): + - If calibration_prompts_file is set: reads a text file with one prompt per line + - Otherwise: loads from the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' + + Returns: + List of calibration prompts, truncated to calibration_size. + """ + calib_size = self._distillation_config.calibration_size + prompts_file = self._distillation_config.calibration_prompts_file + + if prompts_file is not None: + prompts_path = Path(prompts_file) + if not prompts_path.exists(): + raise FileNotFoundError(f"Calibration prompts file not found: {prompts_path}") + logger.info(f"Loading calibration prompts from {prompts_path}") + with open(prompts_path) as f: + prompts = [line.strip() for line in f if line.strip()] + else: + logger.info( + "Loading calibration prompts from HuggingFace dataset " + "'Gustavosta/Stable-Diffusion-Prompts'..." + ) + from datasets import load_dataset + + dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts") + prompts = list(dataset["train"]["Prompt"]) + + # Truncate to requested size + prompts = prompts[:calib_size] + logger.info(f"Loaded {len(prompts)} calibration prompts") + return prompts + + def _run_inference_calibration(self, model: torch.nn.Module) -> None: + """ + Run full-inference calibration through the DiT, covering all noise levels. + + This replaces the old training-style calibration with full denoising inference, + matching the PTQ workflow. For each calibration prompt, a complete denoising loop + is run (e.g. 30 steps), so the transformer sees activations at every noise level. + + With CFG guidance_scale > 1.0 (default 4.0), each denoising step calls the + transformer twice (positive + negative prompt), matching real inference patterns. + + Note: Text embeddings were pre-computed and cached in + _load_text_encoder_and_cache_embeddings() BEFORE the Gemma model was unloaded. + We pass these cached embeddings to the ValidationSampler via GenerationConfig. + + Args: + model: The transformer model being calibrated (same reference as self._transformer, + with statistics collection enabled by mtq.quantize). + """ + from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler + + calib_cfg = self._distillation_config + if calib_cfg.calibration_size == 0: + logger.info("Skipping calibration (calibration_size=0)") + return + + if not self._cached_calibration_embeddings: + raise RuntimeError( + "No cached calibration embeddings available! " + "Probably the saved checkpoint has no modelopt_state.pt or corrupted." + ) + + # Get video dimensions from validation config + video_dims = getattr(self._config.validation, "video_dims", [512, 320, 33]) + width, height, num_frames = video_dims + negative_prompt = getattr( + self._config.validation, + "negative_prompt", + "worst quality, inconsistent motion, blurry, jittery, distorted", + ) + num_prompts = len(self._cached_calibration_embeddings) + + logger.info( + f"Running full-inference calibration: {num_prompts} prompts, " + f"{calib_cfg.calibration_n_steps} steps/prompt, " + f"guidance_scale={calib_cfg.calibration_guidance_scale}, " + f"video={width}x{height}x{num_frames}" + ) + + # Create a ValidationSampler with the model being calibrated. + # The exact model reference matters: mtq.quantize() sets up statistics + # collection on this instance, so all forward passes must go through it. + # text_encoder=None because we use pre-cached embeddings (Gemma is unloaded). + sampler = ValidationSampler( + transformer=model, + vae_decoder=self._vae_decoder, + vae_encoder=self._vae_encoder, + text_encoder=None, # Gemma unloaded; using cached embeddings + audio_decoder=None, # Skip audio for calibration + vocoder=None, + ) + + device = "cuda" + model.eval() + + with torch.no_grad(): + for i, cached_emb in enumerate(self._cached_calibration_embeddings): + gen_config = GenerationConfig( + prompt="", # Not used when cached_embeddings is provided + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + frame_rate=getattr(self._config.validation, "frame_rate", 25.0), + num_inference_steps=calib_cfg.calibration_n_steps, + guidance_scale=calib_cfg.calibration_guidance_scale, + seed=42 + i, # Vary seed per prompt for diverse activations + generate_audio=False, + tiled_decoding=False, # Skip tiling overhead + cached_embeddings=cached_emb, # Pre-computed text embeddings + ) + + try: + sampler.generate(config=gen_config, device=device) + except Exception as e: + logger.warning(f"Calibration prompt {i} failed: {e}") + continue + + if (i + 1) % 10 == 0 or (i + 1) == len(self._cached_calibration_embeddings): + logger.info(f"Calibration progress: {i + 1}/{num_prompts} prompts") + + model.train() + logger.info("Full-inference calibration complete") + + def _init_optimizer(self) -> None: + """ + Override parent to prepare student model + optimizer + scheduler together. + + FSDP2 requires model and optimizer to be passed to accelerator.prepare() + in a single call. This override: + 1. Creates the optimizer (pointing at self._transformer parameters) + 2. Creates the LR scheduler + 3. Calls accelerator.prepare(model, optimizer, scheduler) together + + This is compatible with both FSDP1 and FSDP2. + """ + from torch.optim import AdamW + + opt_cfg = self._config.optimization + + lr = opt_cfg.learning_rate + if opt_cfg.optimizer_type == "adamw": + optimizer = AdamW(self._trainable_params, lr=lr) + elif opt_cfg.optimizer_type == "adamw8bit": + from bitsandbytes.optim import AdamW8bit + + optimizer = AdamW8bit(self._trainable_params, lr=lr) + else: + raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}") + + lr_scheduler = self._create_scheduler(optimizer) + + # Prepare student model + optimizer + scheduler together (FSDP2 requirement) + logger.info("Preparing student model + optimizer + scheduler with accelerator...") + if lr_scheduler is not None: + self._transformer, self._optimizer, self._lr_scheduler = self._accelerator.prepare( + self._transformer, optimizer, lr_scheduler + ) + else: + self._transformer, self._optimizer = self._accelerator.prepare( + self._transformer, optimizer + ) + self._lr_scheduler = None + + # Log memory after preparation + if torch.cuda.is_available(): + mem_gb = torch.cuda.memory_allocated() / 1024**3 + logger.info(f"GPU memory after model+optimizer preparation: {mem_gb:.2f} GB") + + def _init_dataloader(self) -> None: + """Override to support mock data for training.""" + if self._distillation_config.use_mock_data: + from torch.utils.data import DataLoader + + self._dataset = self._create_mock_dataset() + self._dataloader = DataLoader( + self._dataset, + batch_size=self._config.optimization.batch_size, + shuffle=True, + num_workers=self._config.data.num_dataloader_workers, + pin_memory=True, + drop_last=True, + ) + # Wrap with accelerator + self._dataloader = self._accelerator.prepare(self._dataloader) + else: + # Use parent implementation for real data + super()._init_dataloader() + + def _load_teacher_model(self) -> None: + """ + Load the teacher transformer model for distillation. + + The teacher is loaded, frozen, and prepared with the accelerator using a + dummy SGD optimizer (lr=0, never stepped). The dummy optimizer is needed + because FSDP2 requires model+optimizer together in prepare(). For FSDP1, + this also works fine (prepare just wraps the model). + """ + from torch.optim import SGD + + teacher_path = self._distillation_config.teacher_model_path + if teacher_path is None: + teacher_path = self._config.model.model_path + + # Map dtype string to torch dtype + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + teacher_dtype = dtype_map[self._distillation_config.teacher_dtype] + + logger.info( + f"Loading teacher model from {teacher_path} with dtype={self._distillation_config.teacher_dtype}" + ) + + # Load teacher transformer to CPU first + self._teacher_transformer = load_transformer( + checkpoint_path=str(teacher_path), + device="cpu", + dtype=teacher_dtype, + ) + + # Teacher is inference-only, freeze it + self._teacher_transformer.requires_grad_(False) + self._teacher_transformer.eval() + + # Prepare teacher with accelerator using a dummy optimizer. + # FSDP2 requires model+optimizer together in prepare(). We use a minimal + # SGD with lr=0 that will never be stepped — just to satisfy the API. + logger.info( + f"Preparing teacher model with accelerator (distributed_type={self._accelerator.distributed_type})" + ) + teacher_params = list(self._teacher_transformer.parameters()) + dummy_optimizer = SGD(teacher_params, lr=0.0) + + self._teacher_transformer, wrapped_dummy_optimizer = self._accelerator.prepare( + self._teacher_transformer, dummy_optimizer + ) + + # Remove the teacher model and dummy optimizer from accelerator's internal + # tracking lists. This prevents save_state()/load_state() from saving/loading + # the teacher (which is frozen and loaded fresh from the original checkpoint + # on each run). The FSDP wrapping is already done at this point, so the + # teacher doesn't need to stay registered. + # Note: _models and _optimizers must stay 1:1 aligned for FSDP optimizer + # save/load (load_fsdp_optimizer uses _models[i] to pair with _optimizers[i]). + # We use the wrapped objects returned by prepare() since _optimizers stores + # AcceleratedOptimizer wrappers, not raw optimizers. + self._accelerator._models.remove(self._teacher_transformer) + self._accelerator._optimizers.remove(wrapped_dummy_optimizer) + + # Re-freeze teacher after prepare (FSDP wrapping may reset requires_grad) + self._teacher_transformer.requires_grad_(False) + self._teacher_transformer.eval() + + # Log memory after teacher loading + if torch.cuda.is_available(): + mem_gb = torch.cuda.memory_allocated() / 1024**3 + logger.info(f"GPU memory after teacher preparation: {mem_gb:.2f} GB") + + logger.info( + "Teacher model loaded and prepared (unregistered from accelerator state tracking)" + ) + + def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor: + """ + Perform a single distillation training step. + + Computes combined loss: + L_total = alpha * L_task + (1 - alpha) * L_distill + + where: + - L_task: MSE between student prediction and flow matching target + - L_distill: MSE between student prediction and teacher prediction + """ + alpha = self._distillation_config.distillation_alpha + + # Apply embedding connectors to transform pre-computed text embeddings + conditions = batch["conditions"] + video_embeds, audio_embeds, attention_mask = self._text_encoder._run_connectors( + conditions["prompt_embeds"], conditions["prompt_attention_mask"] + ) + conditions["video_prompt_embeds"] = video_embeds + conditions["audio_prompt_embeds"] = audio_embeds + conditions["prompt_attention_mask"] = attention_mask + + # Use strategy to prepare training inputs + model_inputs = self._training_strategy.prepare_training_inputs( + batch, self._timestep_sampler + ) + + # Run student forward pass + student_video_pred, student_audio_pred = self._transformer( + video=model_inputs.video, + audio=model_inputs.audio, + perturbations=None, + ) + + # Compute task loss only if alpha > 0 + if alpha > 0: + task_loss = self._training_strategy.compute_loss( + student_video_pred, student_audio_pred, model_inputs + ) + else: + task_loss = torch.tensor(0.0, device=student_video_pred.device) + + # Compute distillation loss only if alpha < 1 + if alpha < 1.0: + # Run teacher forward pass (no gradients) + with torch.no_grad(): + teacher_video_pred, _teacher_audio_pred = self._teacher_transformer( + video=model_inputs.video, + audio=model_inputs.audio, + perturbations=None, + ) + + # Compute distillation loss + distill_loss = self._compute_distillation_loss( + student_video_pred, + teacher_video_pred, + loss_mask=model_inputs.video_loss_mask, + ) + else: + distill_loss = torch.tensor(0.0, device=student_video_pred.device) + + # Combine losses + total_loss = alpha * task_loss + (1.0 - alpha) * distill_loss + + # Log individual losses using parent's _log_metrics pattern (no explicit step) + # This avoids step conflicts with wandb's auto-incrementing step counter + if hasattr(self, "_accelerator") and self._accelerator.is_main_process: + self._log_metrics( + { + "loss/task": task_loss.item(), + "loss/distillation": distill_loss.item(), + "loss/total": total_loss.item(), + } + ) + + return total_loss + + def _compute_distillation_loss( + self, + student_pred: Tensor, + teacher_pred: Tensor, + loss_mask: Tensor | None = None, + ) -> Tensor: + """Compute distillation loss between student and teacher predictions.""" + loss_type = self._distillation_config.distillation_loss_type + + if loss_type == "mse": + loss = torch.nn.functional.mse_loss(student_pred, teacher_pred, reduction="none") + elif loss_type == "cosine": + student_flat = student_pred.flatten(start_dim=2) + teacher_flat = teacher_pred.flatten(start_dim=2) + cos_sim = torch.nn.functional.cosine_similarity(student_flat, teacher_flat, dim=-1) + loss = 1.0 - cos_sim.mean() + else: + raise ValueError(f"Unknown distillation loss type: {loss_type}") + + # Apply loss mask if provided + # loss_mask is [B, seq_len], need to unsqueeze to [B, seq_len, 1] for broadcasting + # with loss shape [B, seq_len, C] + if loss_mask is not None: + # Unsqueeze and convert to float for multiplication + loss_mask = loss_mask.unsqueeze(-1).float() + # Apply mask and normalize (same as original trainer) + loss = loss.mul(loss_mask).div(loss_mask.mean()) + loss = loss.mean() + else: + loss = loss.mean() + + return loss + + def save_quantized_model(self, path: str | Path | None = None) -> None: + """Save the quantized model using ModelOpt (global rank 0 only).""" + if not is_global_rank0(): + return + if path is None: + path = self._distillation_config.save_quantized_checkpoint + if path is None: + path = Path(self._config.output_dir) / "quantized_model" + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving quantized model to {path}") + mto.save(self._transformer, str(path)) + logger.info("Quantized model saved successfully") + + # ── Overrides to fix multi-node shared-FS writes ────────────────────── + # The parent trainer guards file writes with IS_MAIN_PROCESS (LOCAL_RANK==0), + # which is True on every node. We override to use is_global_rank0() so that + # only a single process writes on a shared filesystem. + + def _save_checkpoint(self) -> Path | None: + """Save model weights (override: use global rank 0 for file writes).""" + from accelerate.utils import DistributedType + from safetensors.torch import save_file + + is_lora = self._config.model.training_mode == "lora" + is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP + + save_dir = Path(self._config.output_dir) / "checkpoints" + prefix = "lora" if is_lora else "model" + filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors" + saved_weights_path = save_dir / filename + + # Collective operation — all ranks must participate + self._accelerator.wait_for_everyone() + full_state_dict = self._accelerator.get_state_dict(self._transformer) + + if not is_global_rank0(): + return None + + save_dir.mkdir(exist_ok=True, parents=True) + save_dtype = ( + torch.bfloat16 if self._config.checkpoints.precision == "bfloat16" else torch.float32 + ) + + if is_lora: + from peft import get_peft_model_state_dict + + unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False) + state_dict = get_peft_model_state_dict( + unwrapped, state_dict=full_state_dict if is_fsdp else None + ) + state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()} + state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()} + state_dict = { + k: v.to(save_dtype) if isinstance(v, Tensor) else v for k, v in state_dict.items() + } + metadata = self._build_checkpoint_metadata() + save_file(state_dict, saved_weights_path, metadata=metadata) + else: + full_state_dict = { + k: v.to(save_dtype) if isinstance(v, Tensor) else v + for k, v in full_state_dict.items() + } + self._accelerator.save(full_state_dict, saved_weights_path) + + rel_path = saved_weights_path.relative_to(self._config.output_dir) + logger.info(f"Model weights for step {self._global_step} saved in {rel_path}") + + self._checkpoint_paths.append(saved_weights_path) + return saved_weights_path + + def _save_config(self) -> None: + """Save training config (override: use global rank 0 for file writes).""" + if not is_global_rank0(): + return + import yaml + + config_path = Path(self._config.output_dir) / "training_config.yaml" + with open(config_path, "w") as f: + yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2) + logger.info( + f"Training configuration saved to: {config_path.relative_to(self._config.output_dir)}" + ) + + def _init_wandb(self) -> None: + """Initialize W&B (override: use global rank 0 to avoid duplicate runs).""" + if not self._config.wandb.enabled or not is_global_rank0(): + self._wandb_run = None + return + # Delegate to parent's implementation on global rank 0 + super()._init_wandb() + + def _get_checkpoints_dir(self) -> Path: + """Return the directory used for full training state checkpoints.""" + return Path(self._config.output_dir) / "checkpoints" + + def _save_training_state(self) -> Path | None: + """ + Save the full training state using accelerator.save_state(). + + This saves everything needed to resume training exactly: + - Student model weights (FSDP-sharded) + - Optimizer state + - LR scheduler state + - RNG states (Python, NumPy, PyTorch CPU/CUDA per device) + - Gradient scaler state (if using mixed precision) + - ModelOpt state (quantization architecture for restore on resume) + - Custom metadata (global_step, distillation config) + + Atomic save strategy: + 1. Save everything into step_XXXXXX_tmp/ + 2. After all writes complete, rename to step_XXXXXX/ + Directory rename is atomic on the same filesystem, so either + the final directory exists (complete) or it doesn't. If the + process is killed mid-save, only the _tmp directory remains, + which is cleaned up on the next run. + + Note: The teacher model is NOT saved here — it was unregistered from + the accelerator's tracking lists after prepare() (see _load_teacher_model). + On resume, the teacher is loaded fresh from the original pretrained checkpoint. + + Returns: + Path to the saved state directory, or None on non-main processes. + """ + final_dir = self._get_checkpoints_dir() / f"step_{self._global_step:06d}" + tmp_dir = self._get_checkpoints_dir() / f"step_{self._global_step:06d}_tmp" + + logger.info(f"Saving full training state at step {self._global_step}...") + + # Ensure the checkpoints directory exists before save_state. + if is_global_rank0(): + tmp_dir.mkdir(parents=True, exist_ok=True) + self._accelerator.wait_for_everyone() + + # Save into the _tmp directory first (all ranks participate for FSDP). + self._accelerator.save_state(str(tmp_dir)) + + # Additional saves only on global rank 0 to avoid file write races. + if is_global_rank0(): + # Save modelopt state for quantization architecture restoration on resume. + if self._distillation_config.quant_cfg is not None: + try: + modelopt_state_dict = mto.modelopt_state(self._transformer) + torch.save(modelopt_state_dict, tmp_dir / "modelopt_state.pt") + logger.debug("Saved modelopt_state.pt for resume") + except Exception as e: + logger.warning(f"Failed to save modelopt_state: {e}") + + # Save custom metadata. + metadata = { + "global_step": self._global_step, + "distillation_alpha": self._distillation_config.distillation_alpha, + "quant_cfg": self._distillation_config.quant_cfg, + } + metadata_path = tmp_dir / "distillation_metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + # Barrier: ensure all ranks finished writing before rename + self._accelerator.wait_for_everyone() + + # Atomic rename _tmp → final (only global rank 0) + if is_global_rank0(): + if tmp_dir.exists(): + tmp_dir.rename(final_dir) + logger.info(f"Training state saved to {final_dir} (step={self._global_step})") + else: + logger.error(f"Save directory {tmp_dir} not found after save_state — skipping") + + # Cleanup old / incomplete checkpoints + self._accelerator.wait_for_everyone() + self._cleanup_checkpoints() + + self._accelerator.wait_for_everyone() + return final_dir if is_global_rank0() else None + + def _cleanup_checkpoints(self) -> None: + """Remove old checkpoints, keeping only the last N. + + Also removes any *_tmp directories left behind by interrupted saves. + """ + if not is_global_rank0(): + return + + import shutil + + ckpt_dir = self._get_checkpoints_dir() + if not ckpt_dir.exists(): + return + + # Remove leftover _tmp directories from interrupted saves + for tmp_dir in ckpt_dir.glob("step_*_tmp"): + shutil.rmtree(tmp_dir, ignore_errors=True) + logger.info(f"Removed incomplete checkpoint: {tmp_dir.name}") + + # Keep only last N complete training checkpoints. + # Exclude _tmp (incomplete) and _quantized (calibration-only, not training state). + keep_n = self._config.checkpoints.keep_last_n + if keep_n <= 0: + return + + step_dirs = sorted(ckpt_dir.glob("step_[0-9]*"), key=lambda p: p.name) + step_dirs = [ + d + for d in step_dirs + if not d.name.endswith("_tmp") and not d.name.endswith("_quantized") + ] + if len(step_dirs) <= keep_n: + return + + dirs_to_remove = step_dirs[:-keep_n] + for old_dir in dirs_to_remove: + shutil.rmtree(old_dir, ignore_errors=True) + logger.info(f"Removed old checkpoint: {old_dir.name}") + + def _find_resume_checkpoint(self, path_or_keyword: str | Path) -> Path | None: + """ + Find the checkpoint directory to resume from. + + Only considers fully saved checkpoints (step_XXXXXX, not step_*_tmp). + Incomplete _tmp checkpoints are ignored and cleaned up. + + Args: + path_or_keyword: Either "latest" to auto-find, or an explicit path. + + Returns: + Path to the checkpoint directory, or None if not found. + """ + if str(path_or_keyword).lower() == "latest": + ckpt_dir = self._get_checkpoints_dir() + if not ckpt_dir.exists(): + logger.warning(f"No checkpoints directory found at {ckpt_dir}") + return None + + # Only match step_XXXXXX (6 digits), excluding _tmp (incomplete saves) + # and _quantized (step 0 calibration-only checkpoint, no training state). + step_dirs = sorted(ckpt_dir.glob("step_[0-9]*"), key=lambda p: p.name) + step_dirs = [ + d + for d in step_dirs + if not d.name.endswith("_tmp") and not d.name.endswith("_quantized") + ] + if not step_dirs: + logger.warning(f"No complete checkpoints found in {ckpt_dir}") + return None + + latest = step_dirs[-1] + logger.info(f"Auto-detected latest checkpoint: {latest}") + return latest + else: + path = Path(path_or_keyword) + if not path.exists(): + raise FileNotFoundError(f"Resume checkpoint not found: {path}") + return path + + def _load_training_state(self, checkpoint_dir: Path) -> int: + """ + Load full training state from a checkpoint directory. + + Note: The quantization architecture (fake quantizer modules) must already be + restored BEFORE this method is called. This happens in _apply_modelopt_quantization() + (Path A) which uses mto.restore_from_modelopt_state() to set up the module structure. + This method then loads the trained weights (including quantizer scales) into that + structure via accelerator.load_state(). + + This restores (all via accelerator.load_state()): + - Model weights (student, FSDP-sharded, including quantizer scales) + - Optimizer state + - LR scheduler state + - Dataloader iteration position (auto-skips consumed batches) + - RNG states (Python, NumPy, PyTorch CPU/CUDA per device) + - Gradient scaler (mixed precision) + - global_step (from custom metadata file) + + Args: + checkpoint_dir: Path to the training state checkpoint directory. + + Returns: + The global_step to resume from. + """ + logger.info(f"Resuming training state from {checkpoint_dir}...") + + # accelerator.load_state() is a collective op — all ranks must call it. + # It restores all objects registered via accelerator.prepare() in order: + # 1. Student model weights (self._transformer) — including quantizer scales + # 2. Optimizer state (self._optimizer) + # 3. LR scheduler state (self._lr_scheduler) + # 4. Dataloader iteration position (via skip_first_batches internally) + # 5. RNG states (Python, NumPy, PyTorch CPU/CUDA per device) + # 6. Gradient scaler (mixed precision) + # Note: Teacher model was unregistered from accelerator (see _load_teacher_model), + # so it is NOT loaded here — it is loaded fresh from pretrained on each run. + self._accelerator.load_state(str(checkpoint_dir)) + logger.info( + "Restored: student model (with quantizer scales), optimizer, LR scheduler, " + "dataloader position, RNG states, and gradient scaler via accelerator.load_state()" + ) + + # Load custom metadata to get global_step + metadata_path = checkpoint_dir / "distillation_metadata.json" + if metadata_path.exists(): + with open(metadata_path) as f: + metadata = json.load(f) + resumed_step = metadata.get("global_step", 0) + logger.info(f"Restored global_step={resumed_step} from metadata") + else: + # Fallback: try to parse step from directory name + try: + resumed_step = int(checkpoint_dir.name.split("_")[-1]) + logger.warning( + f"Metadata file not found, parsed step from dir name: {resumed_step}" + ) + except (ValueError, IndexError): + resumed_step = 0 + logger.warning("Could not determine step from checkpoint, resuming from step 0") + + return resumed_step + + def train( + self, + disable_progress_bars: bool = False, + step_callback=None, + ) -> tuple[Path | None, dict]: + """ + Override parent train() to add full checkpoint resume support. + + When `distillation.resume_from_checkpoint` is set, this: + 1. Initializes optimizer/dataloader/scheduler as normal + 2. Loads full training state (model, optimizer, scheduler, RNG) + 3. Skips already-completed steps + 4. Saves full training state at checkpoint intervals + """ + from accelerate.utils import DistributedType, set_seed + from ltx_trainer.gpu_utils import get_gpu_memory_gb + from ltx_trainer.hf_hub_utils import push_to_hub + from ltx_trainer.progress import TrainingProgress + from ltx_trainer.trainer import TrainingStats + + MEMORY_CHECK_INTERVAL = 200 # noqa: N806 + + device = self._accelerator.device + cfg = self._config + start_mem = get_gpu_memory_gb(device) + + train_start_time = time.time() + + # Use the same seed for all processes and ensure deterministic operations + set_seed(cfg.seed) + logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}") + + self._init_optimizer() + self._init_dataloader() + self._init_timestep_sampler() + + # Synchronize all processes after initialization + self._accelerator.wait_for_everyone() + + Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) + + # Save the training configuration as YAML + self._save_config() + + # ===================================================================== + # Resume from checkpoint if configured + # ===================================================================== + resume_step = 0 + resume_path = self._distillation_config.resume_from_checkpoint + if resume_path is not None: + checkpoint_dir = self._find_resume_checkpoint(resume_path) + if checkpoint_dir is not None: + resume_step = self._load_training_state(checkpoint_dir) + logger.info(f"Resuming training from step {resume_step}") + else: + logger.warning("No checkpoint found to resume from, starting from scratch") + + # Create the dataloader iterator AFTER load_state() so it picks up the + # resumed dataloader state. accelerator.load_state() automatically replaces + # self._dataloader with a version that skips already-consumed batches + # (via skip_first_batches), so iter() here starts at the correct position. + data_iter = iter(self._dataloader) + + # Timer for Slurm time-limit-aware checkpointing + must_save_by_minutes = self._distillation_config.must_save_by + if must_save_by_minutes is not None: + save_deadline = train_start_time + must_save_by_minutes * 60 + logger.info( + f"Time-limit save enabled: will save and exit after " + f"{must_save_by_minutes:.1f} minutes" + ) + else: + save_deadline = None + + logger.info("Starting training...") + config_msg = ( + f"Config: steps={cfg.optimization.steps}, " + f"grad_accum={cfg.optimization.gradient_accumulation_steps}, " + f"checkpoints.interval={cfg.checkpoints.interval}, " + f"checkpoints.keep_last_n={cfg.checkpoints.keep_last_n}, " + f"output_dir={cfg.output_dir}, " + f"must_save_by={must_save_by_minutes}" + ) + logger.info(config_msg) + # Also print to stdout (logger goes to stderr via RichHandler, + # which lands in .err files in Slurm) + if IS_MAIN_PROCESS: + print(f"[distillation_trainer] {config_msg}", flush=True) + + # Create progress tracking + progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars + progress = TrainingProgress( + enabled=progress_enabled, + total_steps=cfg.optimization.steps, + ) + + if IS_MAIN_PROCESS and disable_progress_bars: + logger.warning( + "Progress bars disabled. Intermediate status messages will be logged instead." + ) + + self._transformer.train() + self._global_step = resume_step + + peak_mem_during_training = start_mem + + sampled_videos_paths = None + + # Calculate how many raw steps to skip and how many to run + total_raw_steps = cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps + skip_raw_steps = resume_step * cfg.optimization.gradient_accumulation_steps + + with progress: + # Initial validation before training starts (skip if resuming) + if ( + resume_step == 0 + and cfg.validation.interval + and not cfg.validation.skip_initial_validation + ): + sampled_videos_paths = self._sample_videos(progress) + if ( + IS_MAIN_PROCESS + and sampled_videos_paths + and self._config.wandb.log_validation_videos + ): + self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts) + + self._accelerator.wait_for_everyone() + + # Accumulators for averaging metrics across gradient accumulation steps + grad_accum_steps = cfg.optimization.gradient_accumulation_steps + accum_loss = 0.0 + accum_step_time = 0.0 + + for step in range(skip_raw_steps, total_raw_steps): + # Get next batch, reset the dataloader if needed + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(self._dataloader) + batch = next(data_iter) + + step_start_time = time.time() + with self._accelerator.accumulate(self._transformer): + is_optimization_step = (step + 1) % grad_accum_steps == 0 + if is_optimization_step: + self._global_step += 1 + + loss = self._training_step(batch) + self._accelerator.backward(loss) + + # Accumulate metrics for this micro-batch + accum_loss += loss.item() + accum_step_time += time.time() - step_start_time + + if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0: + self._accelerator.clip_grad_norm_( + self._trainable_params, + cfg.optimization.max_grad_norm, + ) + + self._optimizer.step() + self._optimizer.zero_grad() + + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + # Run validation if needed + if ( + cfg.validation.interval + and self._global_step > 0 + and self._global_step % cfg.validation.interval == 0 + and is_optimization_step + ): + if self._accelerator.distributed_type == DistributedType.FSDP: + sampled_videos_paths = self._sample_videos(progress) + if ( + IS_MAIN_PROCESS + and sampled_videos_paths + and self._config.wandb.log_validation_videos + ): + self._log_validation_samples( + sampled_videos_paths, cfg.validation.prompts + ) + elif IS_MAIN_PROCESS: + sampled_videos_paths = self._sample_videos(progress) + if sampled_videos_paths and self._config.wandb.log_validation_videos: + self._log_validation_samples( + sampled_videos_paths, cfg.validation.prompts + ) + + # Save training state for resuming (model, optimizer, scheduler, + # dataloader position, RNG states — all handled by accelerator) + saved_this_step = False + ckpt_interval = cfg.checkpoints.interval + if ( + ckpt_interval + and self._global_step > 0 + and self._global_step % ckpt_interval == 0 + and is_optimization_step + ): + logger.info( + f"Saving checkpoint at step {self._global_step} " + f"(interval={ckpt_interval})..." + ) + self._save_training_state() + saved_this_step = True + + # Time-limit save: if we're approaching the Slurm time limit, + # save a checkpoint and exit gracefully. + if ( + save_deadline is not None + and is_optimization_step + and time.time() >= save_deadline + ): + elapsed_min = (time.time() - train_start_time) / 60 + logger.info( + f"Time limit reached ({elapsed_min:.1f} min >= " + f"{must_save_by_minutes:.1f} min). " + f"Saving checkpoint at step {self._global_step} and exiting..." + ) + if not saved_this_step: + self._save_training_state() + # Break out of the training loop; post-loop code + # will collect stats and return. + break + + self._accelerator.wait_for_everyone() + + # Call step callback if provided + if step_callback and is_optimization_step: + step_callback( + self._global_step, cfg.optimization.steps, sampled_videos_paths + ) + + self._accelerator.wait_for_everyone() + + # On optimization steps: compute averaged metrics, log, then reset + if is_optimization_step: + avg_loss = accum_loss / grad_accum_steps + total_step_time = accum_step_time + + current_lr = self._optimizer.param_groups[0]["lr"] + + progress.update_training( + loss=avg_loss, + lr=current_lr, + step_time=total_step_time, + advance=True, + ) + + # Log averaged metrics to W&B + if IS_MAIN_PROCESS: + self._log_metrics( + { + "train/loss": avg_loss, + "train/learning_rate": current_lr, + "train/step_time": total_step_time, + "train/global_step": self._global_step, + } + ) + + # Periodic step logging to console/Slurm logs + if IS_MAIN_PROCESS and self._global_step % 10 == 0: + elapsed = time.time() - train_start_time + progress_pct = self._global_step / cfg.optimization.steps + if progress_pct > 0: + eta = (elapsed / progress_pct) - elapsed + eta_str = f"{eta // 3600:.0f}h {(eta % 3600) // 60:.0f}m" + else: + eta_str = "calculating..." + logger.info( + f"Step {self._global_step}/{cfg.optimization.steps} | " + f"Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | " + f"Time/Step: {total_step_time:.2f}s | ETA: {eta_str}", + ) + + # Reset accumulators + accum_loss = 0.0 + accum_step_time = 0.0 + + # Sample GPU memory periodically + if step % MEMORY_CHECK_INTERVAL == 0: + current_mem = get_gpu_memory_gb(device) + peak_mem_during_training = max(peak_mem_during_training, current_mem) + + # Collect final stats + train_end_time = time.time() + end_mem = get_gpu_memory_gb(device) + peak_mem = max(start_mem, end_mem, peak_mem_during_training) + + total_time_seconds = train_end_time - train_start_time + actual_steps = self._global_step - resume_step + steps_per_second = actual_steps / total_time_seconds if total_time_seconds > 0 else 0 + samples_per_second = ( + steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size + ) + + stats = TrainingStats( + total_time_seconds=total_time_seconds, + steps_per_second=steps_per_second, + samples_per_second=samples_per_second, + peak_gpu_memory_gb=peak_mem, + num_processes=self._accelerator.num_processes, + global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes, + ) + + # Save final training state (for potential resume) + self._save_training_state() + + # Save inference-ready model weights (standalone safetensors file) + saved_path = self._save_checkpoint() + + if is_global_rank0(): + self._log_training_stats(stats) + + if cfg.hub.push_to_hub: + push_to_hub(saved_path, sampled_videos_paths, self._config) + + if self._wandb_run is not None: + self._log_metrics( + { + "stats/total_time_minutes": stats.total_time_seconds / 60, + "stats/steps_per_second": stats.steps_per_second, + "stats/samples_per_second": stats.samples_per_second, + "stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb, + } + ) + self._wandb_run.finish() + + self._accelerator.wait_for_everyone() + self._accelerator.end_training() + + return saved_path, stats + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="LTX-2 Distillation Training with ModelOpt Quantization", + # Allow OmegaConf-style overrides to pass through + allow_abbrev=False, + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the YAML configuration file", + ) + + # Parse known args to allow for OmegaConf overrides + args, overrides = parser.parse_known_args() + return args, overrides + + +def main(): + """Main entry point for distillation training.""" + # CRITICAL: Set CUDA device BEFORE any model loading. + # + # The LTX trainer loads the text encoder in __init__ BEFORE _setup_accelerator(), + # using device="cuda" which defaults to GPU 0. We must set the device early + # so that "cuda" maps to the correct GPU for each process. + # + # Note: We do NOT call init_process_group() here - let accelerate handle that. + # We only set the CUDA device based on LOCAL_RANK. + + # Read distributed environment variables (set by accelerate launch / torchrun) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", 0)) + master_addr = os.environ.get("MASTER_ADDR", "localhost") + master_port = os.environ.get("MASTER_PORT", "29500") + + # Debug: Print all relevant environment variables + print( + f"[DEBUG] PID={os.getpid()} RANK={rank} LOCAL_RANK={local_rank} " + f"WORLD_SIZE={world_size} MASTER_ADDR={master_addr} MASTER_PORT={master_port}" + ) + print(f"[DEBUG] torch.cuda.device_count()={torch.cuda.device_count()}") + + # Set CUDA device based on LOCAL_RANK - this ensures device="cuda" uses correct GPU + if torch.cuda.is_available() and local_rank < torch.cuda.device_count(): + torch.cuda.set_device(local_rank) + print( + f"[DEBUG] Set CUDA device to {local_rank}, current device: {torch.cuda.current_device()}" + ) + else: + print(f"[WARNING] LOCAL_RANK={local_rank} but device_count={torch.cuda.device_count()}") + + logger.info(f"Process RANK={rank}, LOCAL_RANK={local_rank}, WORLD_SIZE={world_size}") + + args, cli_overrides = parse_args() + + # Load base config from YAML using OmegaConf + base_config = OmegaConf.load(args.config) + + # Parse CLI overrides using OmegaConf + # Supports formats like: + # distillation.distillation_alpha=0.6 + # ++distillation.quant_cfg=FP8_DEFAULT_CFG + # model.training_mode=lora + if cli_overrides: + # Clean up override strings (remove leading ++, +, etc.) + cleaned_overrides = [] + for override in cli_overrides: + # Strip leading + or ++ (Hydra-style) + clean = override.lstrip("+") + if "=" in clean: + cleaned_overrides.append(clean) + elif IS_MAIN_PROCESS: + logger.warning(f"Ignoring malformed override: {override}") + + if cleaned_overrides: + cli_config = OmegaConf.from_dotlist(cleaned_overrides) + # Merge CLI overrides into base config (CLI takes precedence) + config = OmegaConf.merge(base_config, cli_config) + if IS_MAIN_PROCESS: + logger.info(f"Applied {len(cleaned_overrides)} config overrides:") + for override in cleaned_overrides: + logger.info(f" {override}") + else: + config = base_config + else: + config = base_config + + # Convert OmegaConf to plain dict for Pydantic + config_dict = OmegaConf.to_container(config, resolve=True) + + # Create typed config object + config = DistillationTrainerConfig(**config_dict) + + # Create trainer and run + trainer = DistillationTrainer(config) + + # Train + saved_path, stats = trainer.train() + + # Save quantized model if configured + if config.distillation.quant_cfg is not None: + trainer.save_quantized_model() + + if IS_MAIN_PROCESS: + logger.info(f"Training complete. Model saved to: {saved_path}") + logger.info(f"Training stats: {stats}") + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/distillation/requirements.txt b/examples/diffusers/distillation/requirements.txt new file mode 100644 index 0000000000..964edf625d --- /dev/null +++ b/examples/diffusers/distillation/requirements.txt @@ -0,0 +1,4 @@ +ltx-core @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core +ltx-pipelines @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-pipelines +ltx-trainer @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-trainer +omegaconf diff --git a/examples/diffusers/quantization/ONNX-TRT-Deployment.md b/examples/diffusers/quantization/ONNX-TRT-Deployment.md new file mode 100644 index 0000000000..57448b8a38 --- /dev/null +++ b/examples/diffusers/quantization/ONNX-TRT-Deployment.md @@ -0,0 +1,149 @@ +# ONNX Export and TensorRT Engine Build + +This page covers the optional ONNX export + TensorRT engine workflow for diffusion models. +For quantization-only workflows, refer to `../README.md`. + +## Quantize and export ONNX + +### 8-bit Quantize and ONNX Export [Script](./build_sdxl_8bit_engine.sh) + +You can run the following script to quantize SDXL backbone to INT8 or FP8 and generate an ONNX model built with default settings for SDXL. You can then directly head to the [Build the TRT engine for the Quantized ONNX Backbone](#build-the-trt-engine-for-the-quantized-onnx-backbone) section to run E2E pipeline and generate images. + +```sh +bash build_sdxl_8bit_engine.sh --format {FORMAT} # FORMAT can be int8 or fp8 +``` + +If you prefer to customize parameters in calibration or run other models, please follow the instructions below. + +#### FLUX-Dev|SD3-Medium|SDXL|SDXL-Turbo INT8 [Script](./quantize.py) + +```sh +python quantize.py \ + --model {flux-dev|sdxl-1.0|sdxl-turbo|sd3-medium} \ + --format int8 --batch-size 2 \ + --calib-size 32 --alpha 0.8 --n-steps 20 \ + --model-dtype {Half/BFloat16} --trt-high-precision-dtype {Half|BFloat16} \ + --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --onnx-dir {ONNX_DIR} +``` + +#### FLUX-Dev|SDXL|SDXL-Turbo|LTX-Video FP8/FP4 [Script](./quantize.py) + +*In our example code, FP4 is only supported for Flux. However, you can modify our script to enable FP4 format support for your own model.* + +```sh +python quantize.py \ + --model {flux-dev|sdxl-1.0|sdxl-turbo|ltx-video-dev} --model-dtype {Half|BFloat16} --trt-high-precision-dtype {Half|BFloat16} \ + --format {fp8|fp4} --batch-size 2 --calib-size {128|256} --quantize-mha \ + --n-steps 20 --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --collect-method default \ + --onnx-dir {ONNX_DIR} +``` + +We recommend using a device with a minimum of 48GB of combined CPU and GPU memory for exporting ONNX models. If not, please use CPU for ONNX export. + +## Build the TRT engine for the Quantized ONNX Backbone + +> [!IMPORTANT] +> TensorRT environment must be setup prior -- Please see [Pre-Requisites](../README.md#pre-requisites) +> INT8 requires **TensorRT version >= 9.2.0**. If you prefer to use the FP8 TensorRT, ensure you have **TensorRT version 10.2.0 or higher**. You can download the latest version of TensorRT at [here](https://developer.nvidia.com/tensorrt/download). Deployment of SVDQuant is currently not supported. + +Generate INT8/FP8 Backbone Engine + +```bash +# For SDXL +trtexec --builderOptimizationLevel=4 --stronglyTyped --onnx=./model.onnx \ + --minShapes=sample:2x4x128x128,timestep:1,encoder_hidden_states:2x77x2048,text_embeds:2x1280,time_ids:2x6 \ + --optShapes=sample:16x4x128x128,timestep:1,encoder_hidden_states:16x77x2048,text_embeds:16x1280,time_ids:16x6 \ + --maxShapes=sample:16x4x128x128,timestep:1,encoder_hidden_states:16x77x2048,text_embeds:16x1280,time_ids:16x6 \ + --saveEngine=model.plan + +# For SD3-Medium +trtexec --builderOptimizationLevel=4 --stronglyTyped --onnx=./model.onnx \ + --minShapes=hidden_states:2x16x128x128,timestep:2,encoder_hidden_states:2x333x4096,pooled_projections:2x2048 \ + --optShapes=hidden_states:16x16x128x128,timestep:16,encoder_hidden_states:16x333x4096,pooled_projections:16x2048 \ + --maxShapes=hidden_states:16x16x128x128,timestep:16,encoder_hidden_states:16x333x4096,pooled_projections:16x2048 \ + --saveEngine=model.plan + +# For FLUX-Dev FP8 +trtexec --onnx=./model.onnx --fp8 --bf16 --stronglyTyped \ + --minShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \ + --optShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \ + --maxShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \ + --saveEngine=model.plan +``` + +**Please note that `maxShapes` represents the maximum shape of the given tensor. If you want to use a larger batch size or any other dimensions, feel free to adjust the value accordingly.** + +## Run End-to-end Stable Diffusion Pipeline with Model Optimizer Quantized ONNX Model and demoDiffusion + +### demoDiffusion + +If you want to run end-to-end SD/SDXL pipeline with Model Optimizer quantized UNet to generate images and measure latency on target GPUs, here are the steps: + +- Clone a copy of [demo/Diffusion repo](https://github.com/NVIDIA/TensorRT/tree/release/10.2/demo/Diffusion). + +- Following the README from demoDiffusion to set up the pipeline, and run a baseline txt2img example (fp16): + +```sh +# SDXL +python demo_txt2img_xl.py "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow" --negative-prompt "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" --version xl-1.0 --scheduler Euler --denoising-steps 30 --seed 2946901 +# Please refer to the examples provided in the demoDiffusion SD/SDXL pipeline. +``` + +Note, it will take some time to build TRT engines for the first time + +- Replace the fp16 backbone TRT engine with int8 engine generated in [Build the TRT engine for the Quantized ONNX Backbone](#build-the-trt-engine-for-the-quantized-onnx-backbone), e.g.,: + +```sh +cp -r {YOUR_UNETXL}.plan ./engine/ +``` + +Note, the engines must be built on the same GPU, and ensure that the INT8 engine name matches the names of the FP16 engines to enable compatibility with the demoDiffusion pipeline. + +- Run the above txt2img example command again. You can compare the generated images and latency for fp16 vs int8. + Similarly, you could run end-to-end pipeline with Model Optimizer quantized backbone and corresponding examples in demoDiffusion with other diffusion models. + +## Running the inference pipeline with DeviceModel + +DeviceModel is an interface designed to run TensorRT engines like torch models. It takes torch inputs and returns torch outputs. Under the hood, DeviceModel exports a torch checkpoint to ONNX and then generates a TensorRT engine from it. This allows you to swap the backbone of the diffusion pipeline with DeviceModel and execute the pipeline for your desired prompt. + +Generate a quantized torch checkpoint using the [Script](./quantize.py) shown below: + +```bash +python quantize.py \ + --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \ + --format fp8 \ + --batch-size {1|2} \ + --calib-size 128 \ + --n-steps 20 \ + --quantized-torch-ckpt-save-path ./{MODEL}_fp8.pt \ + --collect-method default +``` + +Generate images for the quantized checkpoint with the following [Script](./diffusion_trt.py): + +```bash +python diffusion_trt.py \ + --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \ + --prompt "A cat holding a sign that says hello world" \ + [--override-model-path /path/to/model] \ + [--restore-from ./{MODEL}_fp8.pt] \ + [--onnx-load-path {ONNX_DIR}] \ + [--trt-engine-load-path {ENGINE_DIR}] \ + [--dq-only] \ + [--torch] \ + [--save-image-as /path/to/image] \ + [--benchmark] \ + [--torch-compile] \ + [--skip-image] +``` + +This script will save the output image as `./{MODEL}.png` and report the latency of the TensorRT backbone. +To generate the image with FP16|BF16 precision, you can run the command shown above without the `--restore-from` argument. + +While loading a TensorRT engine using the --trt-engine-load-path argument, it is recommended to load only engines generated using this pipeline. + +### Demo Images + +| SDXL FP16 | SDXL INT8 | +|:---------:|:---------:| +| ![FP16](./assets/xl_base-fp16.png) | ![INT8](./assets/xl_base-int8.png) | diff --git a/examples/diffusers/quantization/calibration.py b/examples/diffusers/quantization/calibration.py new file mode 100644 index 0000000000..aa5d378475 --- /dev/null +++ b/examples/diffusers/quantization/calibration.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import Any + +from models_utils import MODEL_DEFAULTS, ModelType +from pipeline_manager import PipelineManager +from quantize_config import CalibrationConfig +from tqdm import tqdm +from utils import load_calib_prompts + + +class Calibrator: + """Handles model calibration for quantization.""" + + def __init__( + self, + pipeline_manager: PipelineManager, + config: CalibrationConfig, + model_type: ModelType, + logger: logging.Logger, + ): + """ + Initialize calibrator. + + Args: + pipeline_manager: Pipeline manager with main and upsampler pipelines + config: Calibration configuration + model_type: Type of model being calibrated + logger: Logger instance + """ + self.pipeline_manager = pipeline_manager + self.pipe = pipeline_manager.pipe + self.pipe_upsample = pipeline_manager.pipe_upsample + self.config = config + self.model_type = model_type + self.logger = logger + + def load_and_batch_prompts(self) -> list[list[str]]: + """ + Load calibration prompts from file. + + Returns: + List of batched calibration prompts + """ + self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}") + if isinstance(self.config.prompts_dataset, Path): + return load_calib_prompts( + self.config.batch_size, + self.config.prompts_dataset, + ) + + return load_calib_prompts( + self.config.batch_size, + self.config.prompts_dataset["name"], + self.config.prompts_dataset["split"], + self.config.prompts_dataset["column"], + ) + + def run_calibration(self, batched_prompts: list[list[str]]) -> None: + """ + Run calibration steps on the pipeline. + + Args: + batched_prompts: List of batched calibration prompts + """ + self.logger.info(f"Starting calibration with {self.config.num_batches} batches") + extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {}) + + with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar: + for i, prompt_batch in enumerate(batched_prompts): + if i >= self.config.num_batches: + break + + if self.model_type == ModelType.LTX2: + self._run_ltx2_calibration(prompt_batch, extra_args) + elif self.model_type == ModelType.LTX_VIDEO_DEV: + # Special handling for LTX-Video + self._run_ltx_video_calibration(prompt_batch, extra_args) + elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]: + # Special handling for WAN video models + self._run_wan_video_calibration(prompt_batch, extra_args) + else: + common_args = { + "prompt": prompt_batch, + "num_inference_steps": self.config.n_steps, + } + self.pipe(**common_args, **extra_args).images + pbar.update(1) + self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") + self.logger.info("Calibration completed successfully") + + def _run_wan_video_calibration( + self, prompt_batch: list[str], extra_args: dict[str, Any] + ) -> None: + kwargs = {} + kwargs["negative_prompt"] = extra_args["negative_prompt"] + kwargs["height"] = extra_args["height"] + kwargs["width"] = extra_args["width"] + kwargs["num_frames"] = extra_args["num_frames"] + kwargs["guidance_scale"] = extra_args["guidance_scale"] + if "guidance_scale_2" in extra_args: + kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"] + kwargs["num_inference_steps"] = self.config.n_steps + + self.pipe(prompt=prompt_batch, **kwargs).frames + + def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None: + from ltx_core.model.video_vae import TilingConfig + + prompt = prompt_batch[0] + extra_params = self.pipeline_manager.config.extra_params + kwargs = { + "negative_prompt": extra_args.get( + "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted" + ), + "seed": extra_params.get("seed", 0), + "height": extra_params.get("height", extra_args.get("height", 1024)), + "width": extra_params.get("width", extra_args.get("width", 1536)), + "num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)), + "frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)), + "num_inference_steps": self.config.n_steps, + "cfg_guidance_scale": extra_params.get( + "cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0) + ), + "images": extra_params.get("images", []), + "tiling_config": extra_params.get("tiling_config", TilingConfig.default()), + } + self.pipe(prompt=prompt, **kwargs) + + def _run_ltx_video_calibration( + self, prompt_batch: list[str], extra_args: dict[str, Any] + ) -> None: + """ + Run calibration for LTX-Video model using the full multi-stage pipeline. + + Args: + prompt_batch: Batch of prompts + extra_args: Model-specific arguments + """ + # Extract specific args for LTX-Video + expected_height = extra_args.get("height", 512) + expected_width = extra_args.get("width", 704) + num_frames = extra_args.get("num_frames", 121) + negative_prompt = extra_args.get( + "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted" + ) + + def round_to_nearest_resolution_acceptable_by_vae(height, width): + height = height - (height % self.pipe.vae_spatial_compression_ratio) + width = width - (width % self.pipe.vae_spatial_compression_ratio) + return height, width + + downscale_factor = 2 / 3 + # Part 1: Generate video at smaller resolution + downscaled_height, downscaled_width = ( + int(expected_height * downscale_factor), + int(expected_width * downscale_factor), + ) + downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae( + downscaled_height, downscaled_width + ) + + # Generate initial latents at lower resolution + latents = self.pipe( + conditions=None, + prompt=prompt_batch, + negative_prompt=negative_prompt, + width=downscaled_width, + height=downscaled_height, + num_frames=num_frames, + num_inference_steps=self.config.n_steps, + output_type="latent", + ).frames + + # Part 2: Upscale generated video using latent upsampler (if available) + if self.pipe_upsample is not None: + _ = self.pipe_upsample(latents=latents, output_type="latent").frames + + # Part 3: Denoise the upscaled video with few steps to improve texture + # However, in this example code, we will omit the upscale step since its optional. diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 0c151f8d52..d8d8b198b2 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -21,7 +21,6 @@ "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, "*input_quantizer": {"num_bits": (4, 3), "axis": None}, "*output_quantizer": {"enable": False}, - "*[qkv]_bmm_quantizer": {"num_bits": (4, 3), "axis": None}, "*softmax_quantizer": { "num_bits": (4, 3), "axis": None, @@ -56,7 +55,6 @@ "enable": True, }, "*output_quantizer": {"enable": False}, - "*[qkv]_bmm_quantizer": {"num_bits": (4, 3), "axis": None}, "*softmax_quantizer": { "num_bits": (4, 3), "axis": None, diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py new file mode 100644 index 0000000000..9a061622e0 --- /dev/null +++ b/examples/diffusers/quantization/models_utils.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections.abc import Callable +from enum import Enum +from typing import Any + +from diffusers import ( + DiffusionPipeline, + FluxPipeline, + LTXConditionPipeline, + StableDiffusion3Pipeline, + WanPipeline, +) +from utils import ( + filter_func_default, + filter_func_flux_dev, + filter_func_ltx_video, + filter_func_wan_video, +) + + +class ModelType(str, Enum): + """Supported model types.""" + + SDXL_BASE = "sdxl-1.0" + SDXL_TURBO = "sdxl-turbo" + SD3_MEDIUM = "sd3-medium" + SD35_MEDIUM = "sd3.5-medium" + FLUX_DEV = "flux-dev" + FLUX_SCHNELL = "flux-schnell" + LTX_VIDEO_DEV = "ltx-video-dev" + LTX2 = "ltx-2" + WAN22_T2V_14b = "wan2.2-t2v-14b" + WAN22_T2V_5b = "wan2.2-t2v-5b" + + +def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: + """ + Get the appropriate filter function for a given model type. + + Args: + model_type: The model type enum + + Returns: + A filter function appropriate for the model type + """ + filter_func_map = { + ModelType.FLUX_DEV: filter_func_flux_dev, + ModelType.FLUX_SCHNELL: filter_func_default, + ModelType.SDXL_BASE: filter_func_default, + ModelType.SDXL_TURBO: filter_func_default, + ModelType.SD3_MEDIUM: filter_func_default, + ModelType.SD35_MEDIUM: filter_func_default, + ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, + ModelType.LTX2: filter_func_ltx_video, + ModelType.WAN22_T2V_14b: filter_func_wan_video, + ModelType.WAN22_T2V_5b: filter_func_wan_video, + } + + return filter_func_map.get(model_type, filter_func_default) + + +# Model registry with HuggingFace model IDs +MODEL_REGISTRY: dict[ModelType, str] = { + ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0", + ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo", + ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers", + ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium", + ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", + ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", + ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", + ModelType.LTX2: "Lightricks/LTX-2", + ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", +} + +MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = { + ModelType.SDXL_BASE: DiffusionPipeline, + ModelType.SDXL_TURBO: DiffusionPipeline, + ModelType.SD3_MEDIUM: StableDiffusion3Pipeline, + ModelType.SD35_MEDIUM: StableDiffusion3Pipeline, + ModelType.FLUX_DEV: FluxPipeline, + ModelType.FLUX_SCHNELL: FluxPipeline, + ModelType.LTX_VIDEO_DEV: LTXConditionPipeline, + ModelType.LTX2: None, + ModelType.WAN22_T2V_14b: WanPipeline, + ModelType.WAN22_T2V_5b: WanPipeline, +} + +# Shared dataset configurations +_SD_PROMPTS_DATASET = { + "name": "Gustavosta/Stable-Diffusion-Prompts", + "split": "train", + "column": "Prompt", +} + +_OPENVID_DATASET = { + "name": "nkp37/OpenVid-1M", + "split": "train", + "column": "caption", +} + +# Model family base configurations +_SDXL_BASE_CONFIG: dict[str, Any] = { + "backbone": "unet", + "dataset": _SD_PROMPTS_DATASET, +} + +_SD3_BASE_CONFIG: dict[str, Any] = { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, +} + +_FLUX_BASE_CONFIG: dict[str, Any] = { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1024, + "guidance_scale": 3.5, + "max_sequence_length": 512, + }, +} + +_WAN_BASE_CONFIG: dict[str, Any] = { + "backbone": "transformer", + "dataset": _OPENVID_DATASET, +} + +# Model-specific default arguments for calibration +MODEL_DEFAULTS: dict[ModelType, dict[str, Any]] = { + ModelType.SDXL_BASE: _SDXL_BASE_CONFIG, + ModelType.SDXL_TURBO: _SDXL_BASE_CONFIG, + ModelType.SD3_MEDIUM: _SD3_BASE_CONFIG, + ModelType.SD35_MEDIUM: _SD3_BASE_CONFIG, + ModelType.FLUX_DEV: _FLUX_BASE_CONFIG, + ModelType.FLUX_SCHNELL: _FLUX_BASE_CONFIG, + ModelType.LTX_VIDEO_DEV: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 512, + "width": 704, + "num_frames": 121, + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + }, + }, + ModelType.LTX2: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1536, + "num_frames": 121, + "frame_rate": 24.0, + "cfg_guidance_scale": 4.0, + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + }, + }, + ModelType.WAN22_T2V_14b: { + **_WAN_BASE_CONFIG, + "from_pretrained_extra_args": { + "boundary_ratio": 0.875, + }, + "inference_extra_args": { + "height": 720, + "width": 1280, + "num_frames": 81, + "fps": 16, + "guidance_scale": 4.0, + "guidance_scale_2": 3.0, + "negative_prompt": ( + "vivid colors, overexposed, static, blurry details, subtitles, style, " + "work of art, painting, picture, still, overall grayish, worst quality, " + "low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, " + "poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, " + "static image, cluttered background, three legs, many people in the background, " + "walking backwards" + ), + }, + }, + ModelType.WAN22_T2V_5b: { + **_WAN_BASE_CONFIG, + "inference_extra_args": { + "height": 512, + "width": 768, + "num_frames": 81, + "fps": 16, + "guidance_scale": 5.0, + "negative_prompt": ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留" # noqa: RUF001 + ",丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体," # noqa: RUF001 + "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" # noqa: RUF001 + ), + }, + }, +} + + +def _coerce_extra_param_value(value: str) -> Any: + lowered = value.lower() + if lowered in {"true", "false"}: + return lowered == "true" + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + return value + + +def parse_extra_params( + kv_args: list[str], unknown_args: list[str], logger: logging.Logger +) -> dict[str, Any]: + extra_params: dict[str, Any] = {} + for item in kv_args: + if "=" not in item: + raise ValueError(f"Invalid --extra-param value: '{item}'. Expected KEY=VALUE.") + key, value = item.split("=", 1) + extra_params[key] = _coerce_extra_param_value(value) + + i = 0 + while i < len(unknown_args): + token = unknown_args[i] + if token.startswith("--extra_param."): + key = token[len("--extra_param.") :] + value = "true" + if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"): + value = unknown_args[i + 1] + i += 1 + extra_params[key] = _coerce_extra_param_value(value) + elif token.startswith("--extra_param"): + raise ValueError( + "Use --extra_param.KEY VALUE or --extra-param KEY=VALUE for extra parameters." + ) + else: + logger.warning("Ignoring unknown argument: %s", token) + i += 1 + + return extra_params diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py new file mode 100644 index 0000000000..a507d24fbd --- /dev/null +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections.abc import Iterator +from typing import Any + +import torch +from diffusers import DiffusionPipeline, LTXLatentUpsamplePipeline +from models_utils import MODEL_DEFAULTS, MODEL_PIPELINE, MODEL_REGISTRY, ModelType +from quantize_config import ModelConfig + +import modelopt.torch.quantization as mtq + + +class PipelineManager: + """Manages diffusion pipeline creation and configuration.""" + + def __init__(self, config: ModelConfig, logger: logging.Logger): + """ + Initialize pipeline manager. + + Args: + config: Model configuration + logger: Logger instance + """ + self.config = config + self.logger = logger + self.pipe: Any | None = None + self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling + self._transformer: torch.nn.Module | None = None + + @staticmethod + def create_pipeline_from( + model_type: ModelType, + torch_dtype: torch.dtype | dict[str, str | torch.dtype] = torch.bfloat16, + override_model_path: str | None = None, + ) -> DiffusionPipeline: + """ + Create and return an appropriate pipeline based on configuration. + + Returns: + Configured diffusion pipeline + + Raises: + ValueError: If model type is unsupported + """ + try: + pipeline_cls = MODEL_PIPELINE[model_type] + if pipeline_cls is None: + raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.") + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path + ) + pipe = pipeline_cls.from_pretrained( + model_id, + torch_dtype=torch_dtype, + use_safetensors=True, + **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}), + ) + pipe.set_progress_bar_config(disable=True) + return pipe + except Exception as e: + raise e + + def create_pipeline(self) -> Any: + """ + Create and return an appropriate pipeline based on configuration. + + Returns: + Configured diffusion pipeline + + Raises: + ValueError: If model type is unsupported + """ + self.logger.info(f"Creating pipeline for {self.config.model_type.value}") + self.logger.info(f"Model path: {self.config.model_path}") + self.logger.info(f"Data type: {self.config.model_dtype}") + + try: + if self.config.model_type == ModelType.LTX2: + from modelopt.torch.quantization.plugins.diffusion import ltx2 as ltx2_plugin + + ltx2_plugin.register_ltx2_quant_linear() + self.pipe = self._create_ltx2_pipeline() + self.logger.info("LTX-2 pipeline created successfully") + return self.pipe + + pipeline_cls = MODEL_PIPELINE[self.config.model_type] + if pipeline_cls is None: + raise ValueError( + f"Model type {self.config.model_type.value} does not use diffusers pipelines." + ) + self.pipe = pipeline_cls.from_pretrained( + self.config.model_path, + torch_dtype=self.config.model_dtype, + use_safetensors=True, + **MODEL_DEFAULTS[self.config.model_type].get("from_pretrained_extra_args", {}), + ) + if self.config.model_type == ModelType.LTX_VIDEO_DEV: + # Optionally load the upsampler pipeline for LTX-Video + if not self.config.ltx_skip_upsampler: + self.logger.info("Loading LTX-Video upsampler pipeline...") + self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( + "Lightricks/ltxv-spatial-upscaler-0.9.7", + vae=self.pipe.vae, + torch_dtype=self.config.model_dtype, + ) + self.pipe_upsample.set_progress_bar_config(disable=True) + else: + self.logger.info("Skipping upsampler pipeline for faster calibration") + self.pipe.set_progress_bar_config(disable=True) + + self.logger.info("Pipeline created successfully") + return self.pipe + + except Exception as e: + self.logger.error(f"Failed to create pipeline: {e}") + raise + + def setup_device(self) -> None: + """Configure pipeline device placement.""" + if not self.pipe: + raise RuntimeError("Pipeline not created. Call create_pipeline() first.") + + if self.config.model_type == ModelType.LTX2: + self.logger.info("Skipping device setup for LTX-2 pipeline (handled internally)") + return + + if self.config.cpu_offloading: + self.logger.info("Enabling CPU offloading for memory efficiency") + self.pipe.enable_model_cpu_offload() + if self.pipe_upsample: + self.pipe_upsample.enable_model_cpu_offload() + else: + self.logger.info("Moving pipeline to CUDA") + self.pipe.to("cuda") + if self.pipe_upsample: + self.logger.info("Moving upsampler pipeline to CUDA") + self.pipe_upsample.to("cuda") + # Enable VAE tiling for LTX-Video to save memory + if self.config.model_type == ModelType.LTX_VIDEO_DEV: + if hasattr(self.pipe, "vae") and hasattr(self.pipe.vae, "enable_tiling"): + self.logger.info("Enabling VAE tiling for LTX-Video") + self.pipe.vae.enable_tiling() + + def get_backbone(self) -> torch.nn.Module: + """ + Get the backbone model (transformer or UNet). + + Returns: + Backbone model module + """ + if not self.pipe: + raise RuntimeError("Pipeline not created. Call create_pipeline() first.") + + backbone_pairs = list(self.iter_backbones()) + if len(backbone_pairs) == 1: + return backbone_pairs[0][1] + return torch.nn.ModuleList([module for _, module in backbone_pairs]) + + def iter_backbones(self) -> Iterator[tuple[str, torch.nn.Module]]: + """ + Yield backbone modules by name, based on a backbone spec. + + Yields: + (backbone_name, module) pairs + """ + if not self.pipe: + raise RuntimeError("Pipeline not created. Call create_pipeline() first.") + + names = list(self.config.backbone) + + if self.config.model_type == ModelType.LTX2: + self._ensure_ltx2_transformer_cached() + name = names[0] if names else "transformer" + yield name, self._transformer + return + + if not names: + raise RuntimeError("No backbone names provided.") + + for name in names: + module = getattr(self.pipe, name, None) + if module is None: + raise RuntimeError(f"Pipeline missing backbone module '{name}'.") + yield name, module + + def _ensure_ltx2_transformer_cached(self) -> None: + if not self.pipe: + raise RuntimeError("Pipeline not created. Call create_pipeline() first.") + if self._transformer is None: + transformer = self.pipe.stage_1_model_ledger.transformer() + self.pipe.stage_1_model_ledger.transformer = lambda: transformer + self._transformer = transformer + + def _create_ltx2_pipeline(self) -> Any: + params = dict(self.config.extra_params) + checkpoint_path = params.pop("checkpoint_path", None) + distilled_lora_path = params.pop("distilled_lora_path", None) + distilled_lora_strength = params.pop("distilled_lora_strength", 0.8) + spatial_upsampler_path = params.pop("spatial_upsampler_path", None) + gemma_root = params.pop("gemma_root", None) + fp8transformer = params.pop("fp8transformer", False) + + if not checkpoint_path: + raise ValueError("Missing required extra_param: checkpoint_path.") + if not distilled_lora_path: + raise ValueError("Missing required extra_param: distilled_lora_path.") + if not spatial_upsampler_path: + raise ValueError("Missing required extra_param: spatial_upsampler_path.") + if not gemma_root: + raise ValueError("Missing required extra_param: gemma_root.") + + from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps + from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline + + distilled_lora = [ + LoraPathStrengthAndSDOps( + str(distilled_lora_path), + float(distilled_lora_strength), + LTXV_LORA_COMFY_RENAMING_MAP, + ) + ] + pipeline_kwargs = { + "checkpoint_path": str(checkpoint_path), + "distilled_lora": distilled_lora, + "spatial_upsampler_path": str(spatial_upsampler_path), + "gemma_root": str(gemma_root), + "loras": [], + "fp8transformer": bool(fp8transformer), + } + pipeline_kwargs.update(params) + return TI2VidTwoStagesPipeline(**pipeline_kwargs) + + def print_quant_summary(self): + backbone_pairs = list(self.iter_backbones()) + for name, backbone in backbone_pairs: + self.logger.info(f"{name} quantization info:") + mtq.print_quant_summary(backbone) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index df2de4faec..bfff207afc 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -17,14 +17,11 @@ import logging import sys import time as time -from collections.abc import Callable -from dataclasses import dataclass, field -from enum import Enum from pathlib import Path from typing import Any import torch -import torch.nn as nn +from calibration import Calibrator from config import ( FP8_DEFAULT_CONFIG, INT8_DEFAULT_CONFIG, @@ -33,335 +30,25 @@ reset_set_int8_config, set_quant_config_attr, ) - -# This is a workaround for making the onnx export of models that use the torch RMSNorm work. We will -# need to move on to use dynamo based onnx export to properly fix the problem. The issue has been hit -# by both external users https://github.com/NVIDIA/Model-Optimizer/issues/262, and our -# internal users from MLPerf Inference. -# -if __name__ == "__main__": - from diffusers.models.normalization import RMSNorm as DiffuserRMSNorm - - torch.nn.RMSNorm = DiffuserRMSNorm - torch.nn.modules.normalization.RMSNorm = DiffuserRMSNorm - -from diffusers import ( - DiffusionPipeline, - FluxPipeline, - LTXConditionPipeline, - LTXLatentUpsamplePipeline, - StableDiffusion3Pipeline, - WanPipeline, -) +from diffusers import DiffusionPipeline +from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params from onnx_utils.export import generate_fp8_scales, modelopt_export_sd -from tqdm import tqdm -from utils import ( - check_conv_and_mha, - check_lora, - filter_func_default, - filter_func_ltx_video, - filter_func_wan_video, - load_calib_prompts, +from pipeline_manager import PipelineManager +from quantize_config import ( + CalibrationConfig, + CollectMethod, + DataType, + ExportConfig, + ModelConfig, + QuantAlgo, + QuantFormat, + QuantizationConfig, ) +from utils import check_conv_and_mha, check_lora import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq - - -class ModelType(str, Enum): - """Supported model types.""" - - SDXL_BASE = "sdxl-1.0" - SDXL_TURBO = "sdxl-turbo" - SD3_MEDIUM = "sd3-medium" - SD35_MEDIUM = "sd3.5-medium" - FLUX_DEV = "flux-dev" - FLUX_SCHNELL = "flux-schnell" - LTX_VIDEO_DEV = "ltx-video-dev" - WAN22_T2V = "wan2.2-t2v-14b" - - -class DataType(str, Enum): - """Supported data types for model loading.""" - - HALF = "Half" - BFLOAT16 = "BFloat16" - FLOAT = "Float" - - @property - def torch_dtype(self) -> torch.dtype: - return self._dtype_map[self.value] - - -DataType._dtype_map = { - DataType.HALF: torch.float16, - DataType.BFLOAT16: torch.bfloat16, - DataType.FLOAT: torch.float32, -} - - -class QuantFormat(str, Enum): - """Supported quantization formats.""" - - INT8 = "int8" - FP8 = "fp8" - FP4 = "fp4" - - -class QuantAlgo(str, Enum): - """Supported quantization algorithms.""" - - MAX = "max" - SVDQUANT = "svdquant" - SMOOTHQUANT = "smoothquant" - - -class CollectMethod(str, Enum): - """Calibration collection methods.""" - - GLOBAL_MIN = "global_min" - MIN_MAX = "min-max" - MIN_MEAN = "min-mean" - MEAN_MAX = "mean-max" - DEFAULT = "default" - - -def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: - """ - Get the appropriate filter function for a given model type. - - Args: - model_type: The model type enum - - Returns: - A filter function appropriate for the model type - """ - filter_func_map = { - ModelType.FLUX_DEV: filter_func_default, - ModelType.FLUX_SCHNELL: filter_func_default, - ModelType.SDXL_BASE: filter_func_default, - ModelType.SDXL_TURBO: filter_func_default, - ModelType.SD3_MEDIUM: filter_func_default, - ModelType.SD35_MEDIUM: filter_func_default, - ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, - ModelType.WAN22_T2V: filter_func_wan_video, - } - - return filter_func_map.get(model_type, filter_func_default) - - -# Model registry with HuggingFace model IDs -MODEL_REGISTRY: dict[ModelType, str] = { - ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0", - ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo", - ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers", - ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium", - ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", - ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", - ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", - ModelType.WAN22_T2V: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", -} - -MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = { - ModelType.SDXL_BASE: DiffusionPipeline, - ModelType.SDXL_TURBO: DiffusionPipeline, - ModelType.SD3_MEDIUM: StableDiffusion3Pipeline, - ModelType.SD35_MEDIUM: StableDiffusion3Pipeline, - ModelType.FLUX_DEV: FluxPipeline, - ModelType.FLUX_SCHNELL: FluxPipeline, - ModelType.LTX_VIDEO_DEV: LTXConditionPipeline, - ModelType.WAN22_T2V: WanPipeline, -} - -# Model-specific default arguments for calibration -MODEL_DEFAULTS: dict[ModelType, dict[str, Any]] = { - ModelType.SDXL_BASE: { - "backbone": "unet", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - }, - ModelType.SDXL_TURBO: { - "backbone": "unet", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - }, - ModelType.SD3_MEDIUM: { - "backbone": "transformer", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - }, - ModelType.SD35_MEDIUM: { - "backbone": "transformer", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - }, - ModelType.FLUX_DEV: { - "backbone": "transformer", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - "inference_extra_args": { - "height": 1024, - "width": 1024, - "guidance_scale": 3.5, - "max_sequence_length": 512, - }, - }, - ModelType.FLUX_SCHNELL: { - "backbone": "transformer", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - "inference_extra_args": { - "height": 1024, - "width": 1024, - "guidance_scale": 3.5, - "max_sequence_length": 512, - }, - }, - ModelType.LTX_VIDEO_DEV: { - "backbone": "transformer", - "dataset": { - "name": "Gustavosta/Stable-Diffusion-Prompts", - "split": "train", - "column": "Prompt", - }, - "inference_extra_args": { - "height": 512, - "width": 704, - "num_frames": 121, - "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", - }, - }, - ModelType.WAN22_T2V: { - "backbone": "transformer", - "dataset": {"name": "nkp37/OpenVid-1M", "split": "train", "column": "caption"}, - "from_pretrained_extra_args": { - "boundary_ratio": 0.875, - }, - "inference_extra_args": { - "height": 720, - "width": 1280, - "num_frames": 81, - "fps": 16, - "guidance_scale": 4.0, - "guidance_scale_2": 3.0, - "negative_prompt": ( - "vivid colors, overexposed, static, blurry details, subtitles, style, " - "work of art, painting, picture, still, overall grayish, worst quality, " - "low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, " - "poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, " - "static image, cluttered background, three legs, many people in the background, " - "walking backwards" - ), - }, - }, -} - - -@dataclass -class QuantizationConfig: - """Configuration for model quantization.""" - - format: QuantFormat = QuantFormat.INT8 - algo: QuantAlgo = QuantAlgo.MAX - percentile: float = 1.0 - collect_method: CollectMethod = CollectMethod.DEFAULT - alpha: float = 1.0 # SmoothQuant alpha - lowrank: int = 32 # SVDQuant lowrank - quantize_mha: bool = False - compress: bool = False - - def validate(self) -> None: - """Validate configuration consistency.""" - if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT: - raise NotImplementedError("Only 'default' collect method is implemented for FP8.") - if self.quantize_mha and self.format == QuantFormat.INT8: - raise ValueError("MHA quantization is only supported for FP8, not INT8.") - if self.compress and self.format == QuantFormat.INT8: - raise ValueError("Compression is only supported for FP8 and FP4, not INT8.") - - -@dataclass -class CalibrationConfig: - """Configuration for calibration process.""" - - prompts_dataset: dict | Path - batch_size: int = 2 - calib_size: int = 128 - n_steps: int = 30 - - def validate(self) -> None: - """Validate calibration configuration.""" - if self.batch_size <= 0: - raise ValueError("Batch size must be positive.") - if self.calib_size <= 0: - raise ValueError("Calibration size must be positive.") - if self.n_steps <= 0: - raise ValueError("Number of steps must be positive.") - - @property - def num_batches(self) -> int: - """Calculate number of calibration batches.""" - return self.calib_size // self.batch_size - - -@dataclass -class ModelConfig: - """Configuration for model loading and inference.""" - - model_type: ModelType = ModelType.FLUX_DEV - model_dtype: dict[str, torch.dtype] = field(default_factory=lambda: {"default": torch.float16}) - backbone: str = "" - trt_high_precision_dtype: DataType = DataType.HALF - override_model_path: Path | None = None - cpu_offloading: bool = False - ltx_skip_upsampler: bool = False # Skip upsampler for LTX-Video (faster calibration) - - @property - def model_path(self) -> str: - """Get the model path (override or default).""" - if self.override_model_path: - return str(self.override_model_path) - return MODEL_REGISTRY[self.model_type] - - -@dataclass -class ExportConfig: - """Configuration for model export.""" - - quantized_torch_ckpt_path: Path | None = None - onnx_dir: Path | None = None - restore_from: Path | None = None - - def validate(self) -> None: - """Validate export configuration.""" - if self.restore_from and not self.restore_from.exists(): - raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}") - - if self.quantized_torch_ckpt_path: - parent_dir = self.quantized_torch_ckpt_path.parent - if not parent_dir.exists(): - parent_dir.mkdir(parents=True, exist_ok=True) - - if self.onnx_dir and not self.onnx_dir.exists(): - self.onnx_dir.mkdir(parents=True, exist_ok=True) +from modelopt.torch.export import export_hf_checkpoint def setup_logging(verbose: bool = False) -> logging.Logger: @@ -397,281 +84,6 @@ def setup_logging(verbose: bool = False) -> logging.Logger: return logger -class PipelineManager: - """Manages diffusion pipeline creation and configuration.""" - - def __init__(self, config: ModelConfig, logger: logging.Logger): - """ - Initialize pipeline manager. - - Args: - config: Model configuration - logger: Logger instance - """ - self.config = config - self.logger = logger - self.pipe: DiffusionPipeline | None = None - self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling - - @staticmethod - def create_pipeline_from( - model_type: ModelType, - torch_dtype: torch.dtype | dict[str, str | torch.dtype] = torch.bfloat16, - override_model_path: str | None = None, - ) -> DiffusionPipeline: - """ - Create and return an appropriate pipeline based on configuration. - - Returns: - Configured diffusion pipeline - - Raises: - ValueError: If model type is unsupported - """ - try: - model_id = ( - MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path - ) - pipe = MODEL_PIPELINE[model_type].from_pretrained( - model_id, - torch_dtype=torch_dtype, - use_safetensors=True, - **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}), - ) - pipe.set_progress_bar_config(disable=True) - return pipe - except Exception as e: - raise e - - def create_pipeline(self) -> DiffusionPipeline: - """ - Create and return an appropriate pipeline based on configuration. - - Returns: - Configured diffusion pipeline - - Raises: - ValueError: If model type is unsupported - """ - self.logger.info(f"Creating pipeline for {self.config.model_type.value}") - self.logger.info(f"Model path: {self.config.model_path}") - self.logger.info(f"Data type: {self.config.model_dtype}") - - try: - self.pipe = MODEL_PIPELINE[self.config.model_type].from_pretrained( - self.config.model_path, - torch_dtype=self.config.model_dtype, - use_safetensors=True, - **MODEL_DEFAULTS[self.config.model_type].get("from_pretrained_extra_args", {}), - ) - if self.config.model_type == ModelType.LTX_VIDEO_DEV: - # Optionally load the upsampler pipeline for LTX-Video - if not self.config.ltx_skip_upsampler: - self.logger.info("Loading LTX-Video upsampler pipeline...") - self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( - "Lightricks/ltxv-spatial-upscaler-0.9.7", - vae=self.pipe.vae, - torch_dtype=self.config.model_dtype, - ) - self.pipe_upsample.set_progress_bar_config(disable=True) - else: - self.logger.info("Skipping upsampler pipeline for faster calibration") - self.pipe.set_progress_bar_config(disable=True) - - self.logger.info("Pipeline created successfully") - return self.pipe - - except Exception as e: - self.logger.error(f"Failed to create pipeline: {e}") - raise - - def setup_device(self) -> None: - """Configure pipeline device placement.""" - if not self.pipe: - raise RuntimeError("Pipeline not created. Call create_pipeline() first.") - - if self.config.cpu_offloading: - self.logger.info("Enabling CPU offloading for memory efficiency") - self.pipe.enable_model_cpu_offload() - if self.pipe_upsample: - self.pipe_upsample.enable_model_cpu_offload() - else: - self.logger.info("Moving pipeline to CUDA") - self.pipe.to("cuda") - if self.pipe_upsample: - self.logger.info("Moving upsampler pipeline to CUDA") - self.pipe_upsample.to("cuda") - # Enable VAE tiling for LTX-Video to save memory - if self.config.model_type == ModelType.LTX_VIDEO_DEV: - if hasattr(self.pipe, "vae") and hasattr(self.pipe.vae, "enable_tiling"): - self.logger.info("Enabling VAE tiling for LTX-Video") - self.pipe.vae.enable_tiling() - - def get_backbone(self) -> torch.nn.Module: - """ - Get the backbone model (transformer or UNet). - - Returns: - Backbone model module - """ - if not self.pipe: - raise RuntimeError("Pipeline not created. Call create_pipeline() first.") - - return getattr(self.pipe, self.config.backbone) - - -class Calibrator: - """Handles model calibration for quantization.""" - - def __init__( - self, - pipeline_manager: PipelineManager, - config: CalibrationConfig, - model_type: ModelType, - logger: logging.Logger, - ): - """ - Initialize calibrator. - - Args: - pipeline_manager: Pipeline manager with main and upsampler pipelines - config: Calibration configuration - model_type: Type of model being calibrated - logger: Logger instance - """ - self.pipeline_manager = pipeline_manager - self.pipe = pipeline_manager.pipe - self.pipe_upsample = pipeline_manager.pipe_upsample - self.config = config - self.model_type = model_type - self.logger = logger - - def load_and_batch_prompts(self) -> list[list[str]]: - """ - Load calibration prompts from file. - - Returns: - List of batched calibration prompts - """ - self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}") - if isinstance(self.config.prompts_dataset, Path): - return load_calib_prompts( - self.config.batch_size, - self.config.prompts_dataset, - ) - - return load_calib_prompts( - self.config.batch_size, - self.config.prompts_dataset["name"], - self.config.prompts_dataset["split"], - self.config.prompts_dataset["column"], - ) - - def run_calibration(self, batched_prompts: list[list[str]]) -> None: - """ - Run calibration steps on the pipeline. - - Args: - batched_prompts: List of batched calibration prompts - """ - self.logger.info(f"Starting calibration with {self.config.num_batches} batches") - extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {}) - - with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar: - for i, prompt_batch in enumerate(batched_prompts): - if i >= self.config.num_batches: - break - - if self.model_type == ModelType.LTX_VIDEO_DEV: - # Special handling for LTX-Video - self._run_ltx_video_calibration(prompt_batch, extra_args) - elif self.model_type == ModelType.WAN22_T2V: - # Special handling for LTX-Video - self._run_wan_video_calibration(prompt_batch, extra_args) - else: - common_args = { - "prompt": prompt_batch, - "num_inference_steps": self.config.n_steps, - } - self.pipe(**common_args, **extra_args).images # type: ignore[misc] - pbar.update(1) - self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") - self.logger.info("Calibration completed successfully") - - def _run_wan_video_calibration( - self, prompt_batch: list[str], extra_args: dict[str, Any] - ) -> None: - negative_prompt = extra_args["negative_prompt"] - height = extra_args["height"] - width = extra_args["width"] - num_frames = extra_args["num_frames"] - guidance_scale = extra_args["guidance_scale"] - guidance_scale_2 = extra_args["guidance_scale_2"] - - self.pipe( - prompt=prompt_batch, - negative_prompt=negative_prompt, - height=height, - width=width, - num_frames=num_frames, - guidance_scale=guidance_scale, - guidance_scale_2=guidance_scale_2, - num_inference_steps=self.config.n_steps, - ).frames # type: ignore[misc] - - def _run_ltx_video_calibration( - self, prompt_batch: list[str], extra_args: dict[str, Any] - ) -> None: - """ - Run calibration for LTX-Video model using the full multi-stage pipeline. - - Args: - prompt_batch: Batch of prompts - extra_args: Model-specific arguments - """ - # Extract specific args for LTX-Video - expected_height = extra_args.get("height", 512) - expected_width = extra_args.get("width", 704) - num_frames = extra_args.get("num_frames", 121) - negative_prompt = extra_args.get( - "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted" - ) - - def round_to_nearest_resolution_acceptable_by_vae(height, width): - height = height - (height % self.pipe.vae_spatial_compression_ratio) # type: ignore[union-attr] - width = width - (width % self.pipe.vae_spatial_compression_ratio) # type: ignore[union-attr] - return height, width - - downscale_factor = 2 / 3 - # Part 1: Generate video at smaller resolution - downscaled_height, downscaled_width = ( - int(expected_height * downscale_factor), - int(expected_width * downscale_factor), - ) - downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae( - downscaled_height, downscaled_width - ) - - # Generate initial latents at lower resolution - latents = self.pipe( # type: ignore[misc] - conditions=None, - prompt=prompt_batch, - negative_prompt=negative_prompt, - width=downscaled_width, - height=downscaled_height, - num_frames=num_frames, - num_inference_steps=self.config.n_steps, - output_type="latent", - ).frames - - # Part 2: Upscale generated video using latent upsampler (if available) - if self.pipe_upsample is not None: - _ = self.pipe_upsample(latents=latents, output_type="latent").frames - - # Part 3: Denoise the upscaled video with few steps to improve texture - # However, in this example code, we will omit the upscale step since its optional. - - class Quantizer: """Handles model quantization operations.""" @@ -724,6 +136,8 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: quant_config = NVFP4_DEFAULT_CONFIG else: raise NotImplementedError(f"Unknown format {self.config.format}") + if self.config.quantize_mha: + quant_config["quant_cfg"]["*[qkv]_bmm_quantizer"] = {"num_bits": (4, 3), "axis": None} # type: ignore[index] set_quant_config_attr( quant_config, self.model_config.trt_high_precision_dtype.value, @@ -731,7 +145,7 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: alpha=self.config.alpha, lowrank=self.config.lowrank, ) - + self.logger.info(f"Quant config {quant_config}") return quant_config def quantize_model( @@ -739,7 +153,7 @@ def quantize_model( backbone: torch.nn.Module, quant_config: Any, forward_loop: callable, # type: ignore[valid-type] - ) -> None: + ) -> torch.nn.Module: """ Apply quantization to the model. @@ -761,21 +175,29 @@ def quantize_model( mtq.disable_quantizer(backbone, model_filter_func) self.logger.info("Quantization completed successfully") + return backbone class ExportManager: """Handles model export operations.""" - def __init__(self, config: ExportConfig, logger: logging.Logger): + def __init__( + self, + config: ExportConfig, + logger: logging.Logger, + pipeline_manager: PipelineManager | None = None, + ): """ Initialize export manager. Args: config: Export configuration logger: Logger instance + pipeline_manager: Pipeline manager for per-backbone IO """ self.config = config self.logger = logger + self.pipeline_manager = pipeline_manager def _has_conv_layers(self, model: torch.nn.Module) -> bool: """ @@ -799,13 +221,18 @@ def save_checkpoint(self, backbone: torch.nn.Module) -> None: Save quantized model checkpoint. Args: - backbone: Model backbone to save + backbone: The quantized backbone module to save (must be the same instance + that was passed to mtq.quantize, as it carries the _modelopt_state). """ if not self.config.quantized_torch_ckpt_path: return - self.logger.info(f"Saving quantized checkpoint to {self.config.quantized_torch_ckpt_path}") - mto.save(backbone, str(self.config.quantized_torch_ckpt_path)) + ckpt_path = self.config.quantized_torch_ckpt_path + ckpt_path.mkdir(parents=True, exist_ok=True) + target_path = ckpt_path / "backbone.pt" + self.logger.info(f"Saving backbone to {target_path}") + mto.save(backbone, str(target_path)) + self.logger.info("Checkpoint saved successfully") def export_onnx( @@ -848,19 +275,41 @@ def export_onnx( self.logger.info("ONNX export completed successfully") - def restore_checkpoint(self, backbone: nn.Module) -> None: + def restore_checkpoint(self) -> None: """ Restore a previously quantized model. - Args: - backbone: Model backbone to restore into """ if not self.config.restore_from: return - self.logger.info(f"Restoring model from {self.config.restore_from}") - mto.restore(backbone, str(self.config.restore_from)) - self.logger.info("Model restored successfully") + restore_path = self.config.restore_from + if self.pipeline_manager is None: + raise RuntimeError("Pipeline manager is required for per-backbone checkpoints.") + + backbone = self.pipeline_manager.get_backbone() + if restore_path.exists() and restore_path.is_dir(): + source_path = restore_path / "backbone.pt" + if not source_path.exists(): + raise FileNotFoundError(f"Backbone checkpoint not found: {source_path}") + self.logger.info(f"Restoring backbone from {source_path}") + mto.restore(backbone, str(source_path)) + self.logger.info("Backbone checkpoints restored successfully") + + # TODO: should not do the any data type + def export_hf_ckpt(self, pipe: Any) -> None: + """ + Export quantized model to HuggingFace checkpoint format. + + Args: + pipe: Diffusion pipeline containing the quantized model + """ + if not self.config.hf_ckpt_dir: + return + + self.logger.info(f"Exporting HuggingFace checkpoint to {self.config.hf_ckpt_dir}") + export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir) + self.logger.info("HuggingFace checkpoint export completed successfully") def create_argument_parser() -> argparse.ArgumentParser: @@ -904,9 +353,13 @@ def create_argument_parser() -> argparse.ArgumentParser: ) model_group.add_argument( "--backbone", - type=str, + nargs="+", default=None, - help="model backbone in the DiffusionPipeline to work on, if not provided use default based on model type", + help=( + "Model backbone(s) in the DiffusionPipeline to work on. " + "Provide one name or multiple names separated by space or comma. " + "If not provided use default based on model type." + ), ) model_group.add_argument( "--model-dtype", @@ -935,6 +388,16 @@ def create_argument_parser() -> argparse.ArgumentParser: action="store_true", help="Skip upsampler pipeline for LTX-Video (faster calibration, only quantizes main transformer)", ) + model_group.add_argument( + "--extra-param", + action="append", + default=[], + metavar="KEY=VALUE", + help=( + "Extra model-specific parameters in KEY=VALUE form. Can be provided multiple times. " + "These override model-specific CLI arguments when present." + ), + ) quant_group = parser.add_argument_group("Quantization Configuration") quant_group.add_argument( "--format", @@ -994,6 +457,11 @@ def create_argument_parser() -> argparse.ArgumentParser: help="Path to save quantized PyTorch checkpoint", ) export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export") + export_group.add_argument( + "--hf-ckpt-dir", + type=str, + help="Directory for HuggingFace checkpoint export", + ) export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" ) @@ -1010,12 +478,17 @@ def create_argument_parser() -> argparse.ArgumentParser: def main() -> None: + from diffusers.models.normalization import RMSNorm as DiffuserRMSNorm + + torch.nn.RMSNorm = DiffuserRMSNorm + torch.nn.modules.normalization.RMSNorm = DiffuserRMSNorm + parser = create_argument_parser() - args = parser.parse_args() + args, unknown_args = parser.parse_known_args() model_type = ModelType(args.model) if args.backbone is None: - args.backbone = MODEL_DEFAULTS[model_type]["backbone"] + args.backbone = [MODEL_DEFAULTS[model_type]["backbone"]] s = time.time() model_dtype = {"default": DataType(args.model_dtype).torch_dtype} @@ -1027,6 +500,7 @@ def main() -> None: logger.info("Starting Enhanced Diffusion Model Quantization") try: + extra_params = parse_extra_params(args.extra_param, unknown_args, logger) model_config = ModelConfig( model_type=model_type, model_dtype=model_dtype, @@ -1037,6 +511,7 @@ def main() -> None: else None, cpu_offloading=args.cpu_offloading, ltx_skip_upsampler=args.ltx_skip_upsampler, + extra_params=extra_params, ) quant_config = QuantizationConfig( @@ -1070,6 +545,7 @@ def main() -> None: if args.quantized_torch_ckpt_save_path else None, onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, + hf_ckpt_dir=Path(args.hf_ckpt_dir) if args.hf_ckpt_dir else None, restore_from=Path(args.restore_from) if args.restore_from else None, ) @@ -1084,15 +560,11 @@ def main() -> None: pipeline_manager.setup_device() backbone = pipeline_manager.get_backbone() - export_manager = ExportManager(export_config, logger) + export_manager = ExportManager(export_config, logger, pipeline_manager) if export_config.restore_from and export_config.restore_from.exists(): - export_manager.restore_checkpoint(backbone) + export_manager.restore_checkpoint() - if export_config.quantized_torch_ckpt_path and not export_config.restore_from.samefile( - export_config.restore_from - ): - export_manager.save_checkpoint(backbone) else: logger.info("Initializing calibration...") calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger) @@ -1101,6 +573,7 @@ def main() -> None: quantizer = Quantizer(quant_config, model_config, logger) backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) + # Pipe loads the ckpt just before the inference. def forward_loop(mod): calibrator.run_calibration(batched_prompts) @@ -1114,10 +587,12 @@ def forward_loop(mod): export_manager.save_checkpoint(backbone) + # TODO (Jingyu): To update this function, as we are focusing more on the torch deployment side. check_conv_and_mha( backbone, quant_config.format == QuantFormat.FP4, quant_config.quantize_mha ) - mtq.print_quant_summary(backbone) + + pipeline_manager.print_quant_summary() export_manager.export_onnx( pipe, @@ -1125,6 +600,9 @@ def forward_loop(mod): model_config.model_type, quant_config.format, ) + + export_manager.export_hf_ckpt(pipe) + logger.info( f"Quantization process completed successfully! Time taken = {time.time() - s} seconds" ) diff --git a/examples/diffusers/quantization/quantize_config.py b/examples/diffusers/quantization/quantize_config.py new file mode 100644 index 0000000000..980d39f31f --- /dev/null +++ b/examples/diffusers/quantization/quantize_config.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any + +import torch +from models_utils import MODEL_REGISTRY, ModelType + + +class DataType(str, Enum): + """Supported data types for model loading.""" + + HALF = "Half" + BFLOAT16 = "BFloat16" + FLOAT = "Float" + + @property + def torch_dtype(self) -> torch.dtype: + return self._dtype_map[self.value] + + +DataType._dtype_map = { + DataType.HALF: torch.float16, + DataType.BFLOAT16: torch.bfloat16, + DataType.FLOAT: torch.float32, +} + + +class QuantFormat(str, Enum): + """Supported quantization formats.""" + + INT8 = "int8" + FP8 = "fp8" + FP4 = "fp4" + + +class QuantAlgo(str, Enum): + """Supported quantization algorithms.""" + + MAX = "max" + SVDQUANT = "svdquant" + SMOOTHQUANT = "smoothquant" + + +class CollectMethod(str, Enum): + """Calibration collection methods.""" + + GLOBAL_MIN = "global_min" + MIN_MAX = "min-max" + MIN_MEAN = "min-mean" + MEAN_MAX = "mean-max" + DEFAULT = "default" + + +@dataclass +class QuantizationConfig: + """Configuration for model quantization.""" + + format: QuantFormat = QuantFormat.INT8 + algo: QuantAlgo = QuantAlgo.MAX + percentile: float = 1.0 + collect_method: CollectMethod = CollectMethod.DEFAULT + alpha: float = 1.0 # SmoothQuant alpha + lowrank: int = 32 # SVDQuant lowrank + quantize_mha: bool = False + compress: bool = False + + def validate(self) -> None: + """Validate configuration consistency.""" + if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT: + raise NotImplementedError("Only 'default' collect method is implemented for FP8.") + if self.quantize_mha and self.format == QuantFormat.INT8: + raise ValueError("MHA quantization is only supported for FP8, not INT8.") + if self.compress and self.format == QuantFormat.INT8: + raise ValueError("Compression is only supported for FP8 and FP4, not INT8.") + + +@dataclass +class CalibrationConfig: + """Configuration for calibration process.""" + + prompts_dataset: dict | Path + batch_size: int = 2 + calib_size: int = 128 + n_steps: int = 30 + + def validate(self) -> None: + """Validate calibration configuration.""" + if self.batch_size <= 0: + raise ValueError("Batch size must be positive.") + if self.calib_size <= 0: + raise ValueError("Calibration size must be positive.") + if self.n_steps <= 0: + raise ValueError("Number of steps must be positive.") + + @property + def num_batches(self) -> int: + """Calculate number of calibration batches.""" + return self.calib_size // self.batch_size + + +@dataclass +class ModelConfig: + """Configuration for model loading and inference.""" + + model_type: ModelType = ModelType.FLUX_DEV + model_dtype: dict[str, torch.dtype] = field(default_factory=lambda: {"default": torch.float16}) + backbone: str = "" + trt_high_precision_dtype: DataType = DataType.HALF + override_model_path: Path | None = None + cpu_offloading: bool = False + ltx_skip_upsampler: bool = False # Skip upsampler for LTX-Video (faster calibration) + extra_params: dict[str, Any] = field(default_factory=dict) + + @property + def model_path(self) -> str: + """Get the model path (override or default).""" + if self.override_model_path: + return str(self.override_model_path) + return MODEL_REGISTRY[self.model_type] + + +@dataclass +class ExportConfig: + """Configuration for model export.""" + + quantized_torch_ckpt_path: Path | None = None + onnx_dir: Path | None = None + hf_ckpt_dir: Path | None = None + restore_from: Path | None = None + + def validate(self) -> None: + """Validate export configuration.""" + if self.restore_from and not self.restore_from.exists(): + raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}") + + if self.quantized_torch_ckpt_path: + parent_dir = self.quantized_torch_ckpt_path.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + + if self.onnx_dir and not self.onnx_dir.exists(): + self.onnx_dir.mkdir(parents=True, exist_ok=True) + + if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists(): + self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 7ec49379e2..21fcd87d0b 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -25,7 +25,7 @@ from diffusers.utils import load_image import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.plugins.diffusers import AttentionModuleMixin +from modelopt.torch.quantization.plugins.diffusion.diffusers import AttentionModuleMixin USE_PEFT = True try: @@ -69,13 +69,23 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha): def filter_func_ltx_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" - pattern = re.compile(r".*(proj_in|time_embed|caption_projection|proj_out).*") + pattern = re.compile( + r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single).*" + ) + return pattern.match(name) is not None + + +def filter_func_flux_dev(name: str) -> bool: + """Filter function specifically for Flux-dev models.""" + pattern = re.compile(r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out).*)") return pattern.match(name) is not None def filter_func_wan_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" - pattern = re.compile(r".*(patch_embedding|condition_embedder).*") + pattern = re.compile( + r".*(patch_embedding|condition_embedder|proj_out|blocks.0\.|blocks.1\.|blocks.39|blocks.38).*" + ) return pattern.match(name) is not None diff --git a/examples/llm_autodeploy/api_server.py b/examples/llm_autodeploy/api_server.py index 6e7f9d53c7..0498ed8739 100644 --- a/examples/llm_autodeploy/api_server.py +++ b/examples/llm_autodeploy/api_server.py @@ -20,8 +20,7 @@ import uvicorn from fastapi import FastAPI, HTTPException -from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig -from tensorrt_llm.builder import BuildConfig +from tensorrt_llm._torch.auto_deploy import LLM from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.sampling_params import SamplingParams from tensorrt_llm.serve.openai_protocol import ( @@ -45,11 +44,8 @@ def build_runner_from_config(args) -> LLM: """Builds a model runner from our config.""" mto.enable_huggingface_checkpointing() model_kwargs = {"max_position_embeddings": args.max_seq_len, "use_cache": False} - build_config = BuildConfig(max_seq_len=args.max_seq_len, max_batch_size=args.max_batch_size) - build_config.plugin_config.tokens_per_block = args.max_seq_len - # setup AD config - ad_config = AutoDeployConfig( + llm = LLM( model=args.ckpt_path, compile_backend=args.compile_backend, device=args.device, @@ -58,9 +54,8 @@ def build_runner_from_config(args) -> LLM: max_seq_len=args.max_seq_len, max_num_tokens=args.max_num_tokens, model_kwargs=model_kwargs, - attn_backend="triton", + attn_backend="flashinfer", ) - llm = LLM(**ad_config.to_llm_kwargs()) return llm diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 31103ff869..405e8590a5 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -43,9 +43,11 @@ from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM from quantization_utils import quantize_model +from sparse_attention_utils import sparsify_model import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: @@ -60,6 +62,9 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) + # Sparse attention arguments + sparse_cfg = arg_dict.pop("sparse_cfg", None) + additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} @@ -91,6 +96,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | auto_quantize_checkpoint=auto_quantize_checkpoint, ) + if sparse_cfg: + if is_attn_sparsified(model_obj.model): + warnings.warn("Skipping sparse attention: model already has sparse attention applied.") + else: + sparsify_model( + model=model_obj, + sparse_cfg=sparse_cfg, + ) + return model_obj @@ -152,6 +166,11 @@ def setup_parser_with_modelopt_args(): action="store_true", help="Compress the model after quantization", ) + parser.add_argument( + "--sparse_cfg", + type=str, + help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)", + ) return parser @@ -177,6 +196,7 @@ def setup_parser_with_modelopt_args(): "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, + "sparse_cfg": args.sparse_cfg, } ) diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index ca244052b8..316f443bb0 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -48,6 +48,7 @@ from fire import Fire from modeling import EvalModel, select_model from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model +from sparse_attention_utils import sparsify_model from tqdm import tqdm try: @@ -56,6 +57,7 @@ LLM = None # type: ignore[misc] import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -230,6 +232,7 @@ def main( auto_quantize_method: str = "gradient", auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, + sparse_cfg: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -289,6 +292,20 @@ def main( auto_quantize_checkpoint=auto_quantize_checkpoint, ) + # Apply sparse attention if requested + if sparse_cfg: + model.load() + + if is_attn_sparsified(model.model): + warnings.warn( + "Skipping sparse attention: model already has sparse attention applied." + ) + else: + sparsify_model( + model=model, + sparse_cfg=sparse_cfg, + ) + for subject in tqdm(subjects): dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[ :ntrain diff --git a/examples/llm_eval/modeling.py b/examples/llm_eval/modeling.py index 747b95d5b2..d06d055603 100644 --- a/examples/llm_eval/modeling.py +++ b/examples/llm_eval/modeling.py @@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel): lora_path: str = "" device: str = "cuda" load_8bit: bool = False + attn_implementation: str | None = None def load(self): if self.model is None: @@ -188,6 +189,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args) print_gpu_utilization() if self.lora_path: @@ -241,6 +244,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, **args ) diff --git a/examples/llm_eval/sparse_attention_utils.py b/examples/llm_eval/sparse_attention_utils.py new file mode 100644 index 0000000000..8dc560851f --- /dev/null +++ b/examples/llm_eval/sparse_attention_utils.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for sparse attention integration with llm_eval.""" + +import modelopt.torch.sparsity.attention_sparsity as mtsa + + +def _extract_model(model_obj): + """Extract actual model from wrapper (HFLM or EvalModel).""" + if hasattr(model_obj, "gpt2"): + return model_obj.gpt2 + elif hasattr(model_obj, "model"): + return model_obj.model + else: + return model_obj + + +def sparsify_model( + model, + sparse_cfg: str, + backend=None, +): + """Apply sparse attention to model with optional RULER calibration. + + Args: + model: Model wrapper (HFLM or EvalModel) or raw model + sparse_cfg: Sparse attention config name or dict + backend: Backend to use (optional, overrides config backend) + + Returns: + The model with sparse attention applied + + Note: + Calibration is automatically triggered if the config contains a 'calibration' field. + The calibration will auto-generate RULER dataset from the model's tokenizer. + """ + # Extract actual model + net = _extract_model(model) + + # Resolve config + if isinstance(sparse_cfg, str): + # Get config from mtsa module (e.g., SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT) + mtsa_cfg = getattr(mtsa, sparse_cfg, None) + if mtsa_cfg is None: + raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}.") + else: + mtsa_cfg = sparse_cfg + + # Override backend if specified + if backend: + if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg: + modified_sparse_cfg = {} + for pattern, cfg in mtsa_cfg["sparse_cfg"].items(): + modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg + if isinstance(modified_cfg, dict): + modified_cfg["backend"] = backend + modified_sparse_cfg[pattern] = modified_cfg + mtsa_cfg = {"sparse_cfg": modified_sparse_cfg} + + # Apply sparsification + print(f"\nApplying sparse attention with config: {sparse_cfg}") + mtsa.sparsify(net, mtsa_cfg) + print("Sparse attention applied successfully!") + + return model diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 23ab1ecf90..187eed7f15 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -109,6 +109,9 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http | QWen3 MOE, Next 6 | ✅ | - | - | - | ✅ | | QwQ | ✅ | - | - | - | ✅ | | DeepSeek V3, R1, V3.1, V3.27 | - | - | - | - | ✅ | +| GLM-4.78 | ✅ | - | - | - | ✅ | +| Kimi K2 | - | - | - | - | ✅ | +| MiniMax M2.1 | - | - | - | - | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | - | | Whisper | ✅ | ❌ | ❌ | ❌ | - | @@ -120,7 +123,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http > *4.For some models, KV cache quantization may result in a higher accuracy penalty.* \ > *5.A selective set of the popular models are internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later* \ > *6.Some models currently support export to HF format only.* \ -> *7.[PTQ for DeepSeek](../deepseek/README.md)* +> *7.[PTQ for DeepSeek](../deepseek/README.md)* \ +> *8.GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.* @@ -161,6 +165,23 @@ scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_ [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM. +#### VLM calibration with image-text pairs (e.g., Nemotron VL) + +For vision-language models, calibration quality can likely improve by using image-text pairs instead of text-only data, especially on visual understanding tasks: + +```bash +python hf_ptq.py \ + --pyt_ckpt_path \ + --qformat nvfp4 \ + --export_path \ + --trust_remote_code \ + --calib_with_images \ + --calib_size 512 +``` + +> Note: when `--calib_with_images` is set, `--calib_size` must be a single value, and the calibration dataset is nvidia/nemotron_vlm_dataset_v2. +This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`. + ### NeMo Example Script NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo. Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details. @@ -226,7 +247,7 @@ export HF_PATH= tuple[list[str], dict[str, torch.Tensor]]: + """Load MTP weights from the model checkpoint. - return None + Some models store additional layers in separate safetensors files with non-standard + names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these + files even though they're referenced in model.safetensors.index.json. + + This function detects such cases and explicitly loads the missing weights. + + Args: + model: The loaded model that may be missing weights + model_path: Path to the model directory + + Returns: + List of layer prefixes that were loaded from non-standard safetensors files. + These layers should typically be excluded from quantization. + Empty list if no additional weights were loaded. + Dictionary of MTP weights that were not loaded into the model state dict. + """ + model_path = Path(model_path) + index_file = model_path / "model.safetensors.index.json" + + if not index_file.exists(): + return [], {} + + # Load the index to find all referenced safetensors files + index = json.load(open(index_file)) + weight_map = index["weight_map"] + # Find all files in weight_map whose key or value contains "mtp" + mtp_weight_map = {} + for k, v in weight_map.items(): + if "mtp" in k or "mtp" in v: + mtp_weight_map.setdefault(v, []).append(k) + + if not mtp_weight_map: + return [], {} + + def _extract_layer_prefixes(keys): + mtp_layer_prefixes = set() + for key in keys: + parts = key.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): + prefix = ".".join(parts[: i + 2]) + mtp_layer_prefixes.add(prefix) + break + + return mtp_layer_prefixes + + # Flatten mtp_weight_map.values() (list of list of str) to a single list of str + mtp_keys = [k for keys in mtp_weight_map.values() for k in keys] + mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys) + + # Check which non-standard files exist and have missing weights + model_state = model.state_dict() + total_loaded = 0 + + not_in_state_dict = {} + + for filename, mtp_keys in mtp_weight_map.items(): + filepath = model_path / filename + if not filepath.exists(): + continue + + print(f"Loading {len(mtp_keys)} mtp weights from {filename}...") + weights = load_file(str(filepath), device="cpu") + weights = {k: v for k, v in weights.items() if k in mtp_keys} + # Load the MTP weights to the model state dict + in_state_dict = {k: weights[k] for k in weights if k in model_state} + not_in_state_dict = not_in_state_dict | { + k: weights[k] for k in weights if k not in model_state + } + + if in_state_dict: + model.load_state_dict(in_state_dict, strict=False) + total_loaded += len(in_state_dict) + + if total_loaded > 0: + print( + f"✓ Successfully loaded {total_loaded} MTP weights, " + f"{len(not_in_state_dict)} MTP weights not in model.state_dict" + ) + + if mtp_layer_prefixes: + print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}") + + return list(mtp_layer_prefixes), not_in_state_dict def get_dtype(dtype): @@ -301,6 +467,7 @@ def get_model( # Load config once and handle VL model detection try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " @@ -320,8 +487,6 @@ def get_model( model_kwargs.setdefault("torch_dtype", "auto") if "vila" in ckpt_path.lower(): - from transformers import AutoModel - hf_vila = AutoModel.from_pretrained( ckpt_path, device_map=device_map, @@ -346,6 +511,17 @@ def get_model( device_map=device_map, **model_kwargs, ) + elif ( + hasattr(hf_config, "quantization_config") + and hf_config.quantization_config.get("format", None) == "pack-quantized" + ): + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + device_map="auto", + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + ) else: architecture = hf_config.architectures[0] @@ -353,24 +529,28 @@ def get_model( if not hasattr(transformers, architecture): warnings.warn( f"Architecture {architecture} not found in transformers: {transformers.__version__}. " - "Falling back to AutoModelForCausalLM." + "Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)." ) assert trust_remote_code, ( "Please set trust_remote_code to True if you want to use this architecture" ) - auto_model_module = AutoModelForCausalLM + # Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models + if getattr(hf_config, "is_encoder_decoder", False): + auto_model_module = AutoModel + else: + auto_model_module = AutoModelForCausalLM from_config = auto_model_module.from_config else: auto_model_module = getattr(transformers, architecture) from_config = auto_model_module._from_config with init_empty_weights(): - # When computing the device_map, assuming half precision by default, + # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. - torch_dtype = getattr(hf_config, "torch_dtype", torch.float16) + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model_kwargs2 = model_kwargs.copy() - if auto_model_module != AutoModelForCausalLM: + if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) model_kwargs2["torch_dtype"] = torch_dtype model_kwargs2.pop("max_memory", None) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a9862a742b..d7aadf994f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -25,11 +25,13 @@ from example_utils import ( build_quant_cfg, copy_custom_model_files, + create_vlm_calibration_loop, get_model, get_processor, get_tokenizer, is_enc_dec, is_nemotron_vl, + load_mtp_weights, run_nemotron_vl_preview, ) from torch.utils.data import DataLoader @@ -51,6 +53,7 @@ export_hf_checkpoint, export_tensorrt_llm_checkpoint, get_model_type, + save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration @@ -83,6 +86,8 @@ "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, + "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, + "mxfp8": mtq.MXFP8_DEFAULT_CFG, } KV_QUANT_CFG_CHOICES = { @@ -97,6 +102,39 @@ mto.enable_huggingface_checkpointing() +def extract_and_prepare_language_model_from_vl(full_model): + """Extract language model from VL model and disable quantization for non-language components. + + Args: + full_model: The full VLM model + + Returns: + tuple: (language_model, model_type) or (None, None) if not a VLM + """ + language_model_lineage = get_language_model_from_vl(full_model) + if language_model_lineage is not None: + language_model = language_model_lineage.pop(-1) + ancestors = language_model_lineage + # Apply disabled quant to all modules that are not part of language_model + # This excludes them during HF export + disabled_quant_cfg = { + "quant_cfg": {"default": {"enable": False}}, + "algorithm": "max", + } + + memo = set(ancestors) | {language_model} + for ancestor in ancestors: + for _, module in ancestor.named_children(): + if module not in memo: + mtq.quantize(module, disabled_quant_cfg, forward_loop=None) + memo.add(module) + + model_type = get_model_type(language_model) + return language_model, model_type + + return None, None + + def make_calib_dataloader( args: argparse.Namespace, language_model: torch.nn.Module, @@ -107,7 +145,30 @@ def make_calib_dataloader( ) -> tuple[DataLoader, str | None]: calib_dataloader = None first_text_speech_dataset = None - if model_type == "mllama": + if args.calib_with_images: + # VLM image-text calibration path: assume Nemotron VLM dataset by default. + assert processor is not None, ( + "Please provide a processor (e.g., AutoProcessor) for image calibration." + ) + assert len(args.calib_size) == 1, ( + "Image calibration currently supports a single dataset. " + "Please pass --calib_size with one value (e.g., --calib_size 256)." + ) + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name="nemotron_vlm_dataset_v2", + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + device=device, + max_length=args.calib_seq, + require_image=True, + subsets=["sparsetables", "plotqa_cot", "wiki_en"], + shuffle_buffer_size=10_000, + seed=42, + use_media_shards=True, + max_shards=1, + ) + elif model_type == "mllama": assert processor is not None and isinstance(processor, MllamaImageProcessor), ( "The MllamaImageProcessor must be set." ) @@ -164,6 +225,12 @@ def auto_quantize( ): """Auto search quantization of multiple formats.""" + if args.calib_with_images: + raise NotImplementedError( + "AutoQuantize with image-text calibration is not supported yet. " + "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images." + ) + assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( "Auto Quantization is not supported for pipeline parallel size > 1" ) @@ -184,6 +251,7 @@ def auto_quantize( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "mxfp8", ] for args.qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" @@ -291,6 +359,14 @@ def load_model(args: argparse.Namespace): tokenizer = None language_model = full_model default_padding_side = None + default_pad_token = None + + is_nemotron_vl_model = is_nemotron_vl(full_model) + + # Default to image-text calibration for VLM models + if is_nemotron_vl_model and not args.calib_with_images: + print("Nemotron VL model detected. Enabling image-text calibration by default.") + args.calib_with_images = True if model_type == "mllama": processor = get_processor( @@ -307,6 +383,31 @@ def load_model(args: argparse.Namespace): device, trust_remote_code=args.trust_remote_code, ) + elif is_nemotron_vl_model and args.calib_with_images: + # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs. + processor = AutoProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left" + ) + + if hasattr(processor, "tokenizer") and processor.tokenizer is not None: + tokenizer = processor.tokenizer + else: + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + + default_pad_token = tokenizer.pad_token + # Some Nemotron tokenizers may not define pad_token by default; but we use padding=True during calibration. + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!" + + default_padding_side = tokenizer.padding_side + tokenizer.padding_side = "left" + + # Quantize only the language model, but keep the full_model for calibration forward. + extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model) + if extracted_lm is not None: + language_model = extracted_lm + model_type = extracted_model_type else: if args.dataset is None: args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] @@ -320,29 +421,15 @@ def load_model(args: argparse.Namespace): tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) default_padding_side = tokenizer.padding_side + default_pad_token = tokenizer.pad_token # Left padding usually provides better calibration result. tokenizer.padding_side = "left" # We only quantize the language model for VLMs other than the type supported above. - language_model_lineage = get_language_model_from_vl(full_model) - if language_model_lineage is not None: - language_model = language_model_lineage.pop(-1) - ancestors = language_model_lineage - # Apply disabled quant to all modules that are not part of language_model so we can exclude them during - # HF export. - disabled_quant_cfg = { - "quant_cfg": {"default": {"enable": False}}, - "algorithm": "max", - } - - memo = set(ancestors) | {language_model} - for ancestor in ancestors: - for _, module in ancestor.named_children(): - if module not in memo: - mtq.quantize(module, disabled_quant_cfg, forward_loop=None) - memo.add(module) - - model_type = get_model_type(language_model) + extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model) + if extracted_lm is not None: + language_model = extracted_lm + model_type = extracted_model_type if model_type == "phi4mm": warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.") @@ -355,6 +442,7 @@ def load_model(args: argparse.Namespace): processor, tokenizer, default_padding_side, + default_pad_token, device, ) @@ -418,9 +506,12 @@ def mono_quantize( print("Disabling quantization for vision components in Nemotron VL model") quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - # Also disable radio model components specifically + # Also disable radio model components specifically (for Nemotron-Parse) quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder + quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} # Nemotron-Parse specific + print("Quantization will only be applied to the decoder (text generation) component") if not model_is_already_quantized or calibration_only: if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": @@ -432,9 +523,14 @@ def mono_quantize( if not use_calibration: warnings.warn("Dynamic quantization. Calibration skipped.") - calibrate_loop = ( - create_forward_loop(dataloader=calib_dataloader) if use_calibration else None - ) + calibrate_loop = None + if use_calibration: + # For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values). + # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model. + if args.calib_with_images and is_nemotron_vl_model: + calibrate_loop = create_vlm_calibration_loop(full_model, calib_dataloader) + else: + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) if calibration_only: language_model = mtq.calibrate( @@ -461,6 +557,7 @@ def export_quantized( model_type: str | None, tokenizer: PreTrainedTokenizerBase | None, default_padding_side, + default_pad_token, ): with torch.inference_mode(): if model_type is None: @@ -506,6 +603,10 @@ def export_quantized( or args.sparsity_fmt != "dense" or "int8_sq" in args.qformat ): + if ( + args.inference_tensor_parallel != 1 or args.inference_pipeline_parallel != 1 + ) and args.qformat == "nvfp4_svdquant": + raise NotImplementedError("Svdquant does not support multiple GPUs yet.") warnings.warn( "Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime." ) @@ -535,9 +636,17 @@ def export_quantized( "They will be set at deployment time." ) + # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) + # Store the MTP layer prefixes on the model for later exclusion from quantization + mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) + + if mtp_layer_prefixes: + full_model._mtp_layer_prefixes = mtp_layer_prefixes + export_hf_checkpoint( full_model, export_dir=export_path, + extra_state_dict=mtp_state_dict, ) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used @@ -546,6 +655,8 @@ def export_quantized( # Restore default padding and export the tokenizer as well. if tokenizer is not None: tokenizer.padding_side = default_padding_side + if default_pad_token is not None: + tokenizer.pad_token = default_pad_token tokenizer.save_pretrained(export_path) end_time = time.time() @@ -575,14 +686,17 @@ def pre_quantize( ][0:1] # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: + if model_type == "deepseek": + # DeepSeek generation may go OOM, so we skip it + generated_ids_before_ptq = None + elif is_nemotron_vl_model and tokenizer is not None: generated_ids_before_ptq = run_nemotron_vl_preview( full_model, tokenizer, preview_input_ids, args.pyt_ckpt_path, "before quantization", - allow_fallback=True, + allow_fallback=False, ) else: # Standard generation for non-Nemotron VL models @@ -613,12 +727,19 @@ def post_quantize( """ if args.verbose: - mtq.print_quant_summary(full_model) + try: + mtq.print_quant_summary(full_model, args.export_path) + save_expert_token_count_table(full_model, args.export_path) + except Exception as e: + print(f"Error saving quant summary: {e}") + print("Continuing with generation...") # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4" and not is_nemotron_vl_model: + if generated_ids_before_ptq is None: + pass + elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: @@ -690,39 +811,46 @@ def quantize_main( processor: BaseImageProcessor | ProcessorMixin | None, tokenizer: PreTrainedTokenizerBase | None, default_padding_side, + default_pad_token, device: torch.device, ): if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. - sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = language_model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( - language_model.device - ) - * 100 - ) + # For VL models with image-text calibration, skip automatic batch size detection + # since get_max_batch_size can't handle multimodal inputs + if args.calib_with_images: + print("Image-text calibration enabled. Using default batch_size=1 for calibration.") + args.batch_size = 1 else: - sample_input_single_batch = None + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. + sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 + # Whisper model expects mel-spectrogram input features of length 3000 + # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) + # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float + # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() + if model_type == "whisper": + max_sample_length = 3000 + num_mel_bins = language_model.config.num_mel_bins + sample_input_single_batch = ( + torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( + language_model.device + ) + * 100 + ) + else: + sample_input_single_batch = None - run_auto_quant = args.auto_quantize_bits is not None + run_auto_quant = args.auto_quantize_bits is not None - args.batch_size = get_max_batch_size( - language_model, - max_sample_length=args.calib_seq, - sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, - sample_input_single_batch=sample_input_single_batch, - enable_grad=run_auto_quant, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + args.batch_size = get_max_batch_size( + language_model, + max_sample_length=args.calib_seq, + sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, + sample_input_single_batch=sample_input_single_batch, + enable_grad=run_auto_quant, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) print(f"Use calib batch_size {args.batch_size}") @@ -766,6 +894,7 @@ def quantize_main( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "mxfp8", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES ), f"Plain quantization format {args.qformat} not supported for HF export path" @@ -779,6 +908,19 @@ def quantize_main( KV_QUANT_CFG_CHOICES, ) + # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92) + # These layers are typically speculative decoding layers that should be exported as-is + mtp_layer_prefixes = getattr(full_model, "_mtp_layer_prefixes", None) + if mtp_layer_prefixes: + import copy + + quant_cfg = copy.deepcopy(quant_cfg) + for prefix in mtp_layer_prefixes: + # Add exclusion pattern for this MTP layer (e.g., "*layers.92*") + pattern = f"*{prefix.split('.')[-2]}.{prefix.split('.')[-1]}*" + quant_cfg["quant_cfg"][pattern] = {"enable": False} + print(f"Excluding MTP layer from quantization: {pattern}") + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, @@ -805,7 +947,15 @@ def quantize_main( is_nemotron_vl_model, first_text_speech_dataset, ) - export_quantized(args, full_model, language_model, model_type, tokenizer, default_padding_side) + export_quantized( + args, + full_model, + language_model, + model_type, + tokenizer, + default_padding_side, + default_pad_token, + ) def parse_args() -> argparse.Namespace: @@ -856,6 +1006,14 @@ def parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--calib_with_images", + action="store_true", + help=( + "Calibrate with image-text pairs (for VLMs). " + "This uses nemotron_vlm_dataset_v2 with default subsets (sparsetables, plotqa_cot, wiki_en)." + ), + ) parser.add_argument("--inference_tensor_parallel", type=int, default=1) parser.add_argument("--inference_pipeline_parallel", type=int, default=1) parser.add_argument("--awq_block_size", default=0, type=int) @@ -993,6 +1151,7 @@ def main(args: argparse.Namespace): processor, tokenizer, default_padding_side, + default_pad_token, device, ) = load_model(args) @@ -1010,6 +1169,7 @@ def main(args: argparse.Namespace): processor, tokenizer, default_padding_side, + default_pad_token, device, ) @@ -1020,6 +1180,6 @@ def main(args: argparse.Namespace): if args.export_fmt != "hf": warnings.warn("Deprecated. --export_fmt forced to hf.") - args.dataset = args.dataset.split(",") if args.dataset else None + args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] main(args) diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index 2ae7dde4a3..c2194111ca 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -36,7 +36,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export import get_model_type from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format -from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint +from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.quantization.config import need_calibration from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets @@ -243,7 +243,7 @@ def export_model( export_dir = Path(export_path) export_dir.mkdir(parents=True, exist_ok=True) - post_state_dict, hf_quant_config = _export_hf_checkpoint( + post_state_dict, hf_quant_config = _export_transformers_checkpoint( model, torch.bfloat16, accelerator=accelerator ) diff --git a/examples/llm_ptq/nemotron_vl_calib.py b/examples/llm_ptq/nemotron_vl_calib.py new file mode 100644 index 0000000000..398ef67fe5 --- /dev/null +++ b/examples/llm_ptq/nemotron_vl_calib.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Nemotron VL calibration helpers. + +Nemotron Nano VL v2 remote-code wrapper `forward()` is not ideal to call during PTQ calibration because it may: +- Call `torch.distributed.get_rank()` unconditionally +- Assume `past_key_values` exists in the language model output + +Instead, we run a "safe multimodal forward" that exercises: +- Vision encoder feature extraction (C-RADIOv2-H) +- Insertion of vision embeddings into token embeddings at `img_context_token_id` +- Language model forward pass (to trigger quantizer calibration) +""" + +from __future__ import annotations + +import contextlib +from typing import Any + +import torch + + +def safe_nemotron_vl_forward(full_model: torch.nn.Module, batch: dict[str, Any]) -> None: + """Run a minimal multimodal forward for Nemotron VL that avoids wrapper output packaging.""" + pixel_values = batch.get("pixel_values") + input_ids = batch.get("input_ids") + attention_mask = batch.get("attention_mask") + position_ids = batch.get("position_ids") + image_flags = batch.get("image_flags") + + if pixel_values is None or input_ids is None: + return + + # Nemotron Nano VL v2 expects `image_flags` in forward(), but the processor doesn't always emit it. + # `pixel_values` is flattened across batch*images, so `image_flags` should align with pixel_values.shape[0]. + if image_flags is None and torch.is_tensor(pixel_values): + image_flags = torch.ones( + (pixel_values.shape[0], 1), device=pixel_values.device, dtype=torch.long + ) + if image_flags is None: + return + + # Match the model's preferred vision dtype (usually bf16). + vision_dtype = None + with contextlib.suppress(AttributeError, TypeError): + vision_dtype = getattr(full_model.vision_model.config, "torch_dtype", None) + if vision_dtype is None: + with contextlib.suppress(AttributeError, TypeError): + vision_dtype = getattr(full_model.language_model.config, "torch_dtype", None) + if ( + vision_dtype is not None + and torch.is_tensor(pixel_values) + and pixel_values.dtype != vision_dtype + ): + pixel_values = pixel_values.to(dtype=vision_dtype) + + # Token embeddings + inputs_embeds = full_model.language_model.get_input_embeddings()(input_ids) + image_flags_s = image_flags.squeeze(-1) + + b, n, c = inputs_embeds.shape + flat_embeds = inputs_embeds.reshape(b * n, c) + flat_ids = input_ids.reshape(b * n) + selected = flat_ids == full_model.img_context_token_id + + # Vision embeddings + vit_embeds = full_model.extract_feature(pixel_values) + vit_embeds = vit_embeds[image_flags_s == 1] + try: + flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c) + except Exception: + vit_embeds = vit_embeds.reshape(-1, c) + n_token = selected.sum() + flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds[:n_token] + + inputs_embeds = flat_embeds.reshape(b, n, c) + + # LLM forward (drives activation stats) + full_model.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=False, + ) diff --git a/examples/llm_ptq/requirements.txt b/examples/llm_ptq/requirements.txt index 3485f10e10..1469d5552b 100644 --- a/examples/llm_ptq/requirements.txt +++ b/examples/llm_ptq/requirements.txt @@ -1,3 +1,4 @@ +compressed-tensors==0.12.0 fire flash-attn>=2.6.0 rouge_score>=0.1.2 diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 043b690e5f..a6455ed37a 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -53,9 +53,9 @@ esac IFS="," for qformat in $QFORMAT; do case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;; + fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant | mxfp8) ;; *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2 + echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant, mxfp8]" >&2 exit 1 ;; esac diff --git a/examples/llm_ptq/vlm_utils.py b/examples/llm_ptq/vlm_utils.py index 6c9d921b83..9919e405ba 100644 --- a/examples/llm_ptq/vlm_utils.py +++ b/examples/llm_ptq/vlm_utils.py @@ -105,27 +105,31 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): else: processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) - messages = [ - {"role": "system", "content": "/no_think"}, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "", - }, - { - "type": "text", - "text": question, - }, - ], - }, - ] - - # Apply chat template - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + # Use chat template if available, otherwise fall back to default task prompt + if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None: + messages = [ + {"role": "system", "content": "/no_think"}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "", + }, + { + "type": "text", + "text": question, + }, + ], + }, + ] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + else: + # For models without chat templates (e.g., encoder-decoder VL models), + # use the tokenizer's bos/eos tokens as a minimal prompt + prompt = (tokenizer.bos_token or "") + question # Process inputs using the processor with single image inputs = processor( @@ -139,6 +143,12 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): inputs = inputs.to(model_device) print(f" Moved inputs to {model_device}") + # Verify we have pixel_values for the vision encoder + if not hasattr(inputs, "pixel_values") or inputs.pixel_values is None: + raise ValueError( + "Processor did not generate pixel_values. Check processor configuration." + ) + # Generate response using model.generate generated_ids = model.generate( pixel_values=inputs.pixel_values, @@ -148,12 +158,23 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): ) # Decode the response (trim input tokens like in the working example) + if generated_ids is None: + raise ValueError("Model generate returned None") + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + # Use processor.batch_decode if available, otherwise fall back to tokenizer + decoder = processor if hasattr(processor, "batch_decode") else tokenizer + output_text = decoder.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, ) + + if output_text is None or len(output_text) == 0: + raise ValueError("Decoding returned empty output") + response = output_text[0] print(f"✅ VL generation {stage_name} successful!") diff --git a/examples/llm_qad/qad.sh b/examples/llm_qad/qad.sh index 52ec2bd6ae..ac416ad355 100644 --- a/examples/llm_qad/qad.sh +++ b/examples/llm_qad/qad.sh @@ -181,7 +181,7 @@ CHECKPOINT_ARGS=" \ ${LOAD_OPTIM_ARGS} \ --load ${LOAD_CHECKPOINT_DIR} \ --export-kd-teacher-load ${TEACHER_CKPT} \ - --teacher-model-config ${TEACHER_MODEL_CONFIG}" + --export-kd-teacher-model-config ${TEACHER_MODEL_CONFIG}" # KD config (optional) if [[ -n "$KD_CFG_PATH" && -f "$KD_CFG_PATH" ]]; then diff --git a/examples/llm_qat/export.py b/examples/llm_qat/export.py index 77d75d47ab..1c9e6f4b11 100644 --- a/examples/llm_qat/export.py +++ b/examples/llm_qat/export.py @@ -23,7 +23,7 @@ import modelopt.torch.opt as mto from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format -from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint +from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.opt.conversion import restore_from_modelopt_state from modelopt.torch.quantization.utils import set_quantizer_state_dict from modelopt.torch.utils import print_rank_0 @@ -51,6 +51,7 @@ def get_model( # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this if hasattr(model, "peft_config"): + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False) restore_from_modelopt_state(model, modelopt_state) print_rank_0("Restored modelopt state") @@ -80,7 +81,9 @@ def main(args): base_model_dir = export_dir try: - post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora) + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + model, is_modelopt_qlora=is_qlora + ) with open(f"{base_model_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) diff --git a/examples/llm_sparsity/attention_sparsity/.gitignore b/examples/llm_sparsity/attention_sparsity/.gitignore new file mode 100644 index 0000000000..480901bac3 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/.gitignore @@ -0,0 +1,2 @@ +# Data directory for calibration +data diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md new file mode 100644 index 0000000000..e9d50ae10c --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -0,0 +1,165 @@ +# Attention Sparsity for HuggingFace Models + +In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. + +## Getting Started + +### Quick Example + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +# Load your model +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-8B", + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, +) + +# Apply sparse attention +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +> [!Note] +> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 or SDPA would bypass the softmax patching needed for stats collection. + +## Configuration Options + +Two pre-defined configurations are available: + +### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) + +Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) + +Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) +``` + +## Prerequisites + +### Local Installation + +For Hugging Face models, install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example: + +```bash +pip install nvidia-modelopt[hf] +``` + +### Download RULER Calibration Data (Required for Calibration) + +If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first: + +```bash +bash ./download_ruler_data.sh +``` + +This downloads the Paul Graham essays dataset used for generating calibration samples. + +## Run Sparse Attention on HuggingFace Models + +### Basic Usage (Without Calibration) + +Apply sparse attention with a fixed threshold: + +```bash +python hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax +``` + +### With RULER Calibration + +Apply sparse attention with calibrated thresholds for optimal sparsity: + +```bash +python hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib +``` + +The calibration process: + +1. Generates RULER calibration samples +2. Collects attention statistics during forward passes +3. Determines optimal threshold scale factor for target sparsity ratio + +### Command Line Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--pyt_ckpt_path` | Required | HuggingFace model path or name | +| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | +| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) | +| `--seq_len` | `2048` | Maximum sequence length for input prompts | +| `--export_dir` | `None` | Directory to export the sparsified model | + +## Output Comparison + +The script automatically compares outputs before and after applying sparse attention: + +1. Loads a test sample from the NarrativeQA dataset +2. Generates text before sparse attention is applied +3. Applies sparse attention (with optional calibration) +4. Generates text after sparse attention is applied +5. Compares and displays both outputs + +## Export Model + +Export the sparsified model to a HuggingFace checkpoint: + +```bash +python hf_sa.py \ + --pyt_ckpt_path Qwen/Qwen3-8B \ + --sparse_attn skip_softmax_calib \ + --export_dir ./exported_sparse_model +``` + +The exported model can be loaded and used with standard HuggingFace APIs. + +## Custom Configuration + +You can create custom sparse attention configurations: + +```python +custom_config = { + "sparse_cfg": { + "calibration": { # Optional: omit for fixed threshold + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, # Target 50% sparsity + "samples": 128, # Number of calibration samples + "max_seqlen": 8192, # Maximum sequence length + # Optional: customize threshold trials for calibration + "threshold_trials": [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1], + }, + "*attn*": { # Pattern to match attention modules + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, # Phase-specific thresholds (ignored if calibration is used) + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=custom_config) +``` + +## References + +- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/) +- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER) diff --git a/examples/llm_sparsity/attention_sparsity/download_ruler_data.sh b/examples/llm_sparsity/attention_sparsity/download_ruler_data.sh new file mode 100644 index 0000000000..54797f2a58 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/download_ruler_data.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Download RULER calibration data for attention sparsity. + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${SCRIPT_DIR}/data" +ESSAYS_DIR="${DATA_DIR}/essays" +URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" +URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" + +mkdir -p "${ESSAYS_DIR}" + +# Download URL list if not exists +if [ ! -f "${URLS_FILE}" ]; then + echo "Downloading URL list..." + curl -fsSL "${URLS_URL}" -o "${URLS_FILE}" +fi + +# Download essays from GitHub URLs +echo -n "Downloading essays" +count=0 +while IFS= read -r url || [ -n "$url" ]; do + if [[ "${url}" == https://github.com*.txt ]]; then + filename=$(basename "${url}") + filepath="${ESSAYS_DIR}/${filename}" + if [ ! -f "${filepath}" ]; then + raw_url="${url/github.com/raw.githubusercontent.com}" + raw_url="${raw_url/\/raw\//\/}" + curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null && echo -n "." + count=$((count + 1)) + fi + fi +done < "${URLS_FILE}" +echo " done" + +echo "Downloaded ${count} essays to ${ESSAYS_DIR}" diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 11564a4ece..74c5e9a540 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -17,20 +17,21 @@ """Example script for applying sparse attention to HuggingFace models.""" import argparse +import copy import random from pathlib import Path import numpy as np import torch -from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer import modelopt.torch.opt as mto import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint -from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig -from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT -from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) from modelopt.torch.utils.memory_monitor import launch_memory_monitor RAND_SEED = 1234 @@ -38,47 +39,20 @@ # Enable HuggingFace checkpointing support mto.enable_huggingface_checkpointing() -# You can define custom configurations or use the default +# Sparse attention configuration choices SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, } -def get_narrativeqa_samples(num_samples=3): - """Load samples from NarrativeQA dataset for testing. - - Args: - num_samples: Number of samples to generate - - Raises: - RuntimeError: If dataset loading fails - ValueError: If no valid samples could be loaded - """ - # Load NarrativeQA dataset with retry logic - try: - dataset = load_dataset("narrativeqa", split="test", streaming=True) - except Exception as e: - raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}") - - samples = [] - for i, item in enumerate(dataset): - if i >= num_samples: - break - - # Combine document context and question - context = item.get("document", {}).get("text", "") - question = item.get("question", {}).get("text", "") - - if context and question: - # Use the full context as-is - prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" - samples.append(prompt) - - if not samples: - raise ValueError("Could not load NarrativeQA samples") - - print(f"Loaded {len(samples)} NarrativeQA samples") - return samples +def get_test_prompts(): + """Get simple test prompts for sample output generation.""" + return [ + "What is the capital of France? Answer:", + "Explain the theory of relativity in simple terms:", + "Write a short poem about the ocean:", + ] def truncate_text(text: str, tokenizer, max_length: int): @@ -116,30 +90,23 @@ def truncate_text(text: str, tokenizer, max_length: int): return begin_text + " [...] " + end_text -def verify_outputs(model, tokenizer, args): - """Compare outputs between baseline and sparse attention models.""" - # Update seq_len to match calibration max_seqlen if calibration was used - base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) - if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: - calib_max_seqlen = base_config["calibration"]["max_seqlen"] - if args.seq_len != calib_max_seqlen: - print( - f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " - f"to match calibration config" - ) - args.seq_len = calib_max_seqlen +def generate_sample_output(model, tokenizer, args): + """Generate sample output for comparison. + + Args: + model: The model to generate with + tokenizer: Tokenizer for encoding/decoding + args: Command line arguments - # Load and prepare a single test prompt - print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") - prompts = get_narrativeqa_samples(num_samples=1) + Returns: + Tuple of (generated_text, input_prompt, input_ids) + """ + # Load test sample + prompts = get_test_prompts() prompt = prompts[0] # Prepare inputs truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) - display_prompt = ( - truncated_prompt[:150] + "..." if len(truncated_prompt) > 150 else truncated_prompt - ) - inputs = tokenizer( truncated_prompt, return_tensors="pt", @@ -150,14 +117,7 @@ def verify_outputs(model, tokenizer, args): if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} - print("\n" + "=" * 60) - print("BASELINE vs SPARSE ATTENTION COMPARISON") - print("=" * 60) - print(f"\nTest prompt: {display_prompt}") - print(f"Input tokens: {inputs['input_ids'].shape[1]}") - - # Helper function to generate text - def generate_text(model, inputs, args, tokenizer): + # Generate with torch.no_grad(): outputs = model.generate( **inputs, @@ -168,60 +128,9 @@ def generate_text(model, inputs, args, tokenizer): ) input_length = inputs["input_ids"].shape[1] generated_ids = outputs[0][input_length:] - return tokenizer.decode(generated_ids, skip_special_tokens=True) - - # Find all sparse attention modules - sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] - - # Generate baseline by temporarily disabling sparse attention - print("\n" + "-" * 60) - print("Generating baseline (sparse attention disabled)...") - for module in sparse_modules: - module.disable() - baseline_text = generate_text(model, inputs, args, tokenizer) - - # Generate with sparse attention enabled - print("\nGenerating with sparse attention (calibrated thresholds)...") - for module in sparse_modules: - module.enable() - sparse_text = generate_text(model, inputs, args, tokenizer) - - # Display comparison - print("\n" + "-" * 60) - print("RESULTS:") - baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text - sparse_display = sparse_text[:300] + "..." if len(sparse_text) > 300 else sparse_text - - print(f"\nBaseline: {baseline_display}") - print(f"With Sparse: {sparse_display}") - - if baseline_text == sparse_text: - print("\nOutputs are identical") - else: - print("\nOutputs differ") - + generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) -def sparsify_model(model, args): - """Apply sparse attention to the model with optional calibration.""" - print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") - base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] - - # Create modified config with selected backend - modified_sparse_cfg = {} - for pattern, cfg in base_config["sparse_cfg"].items(): - modified_cfg = cfg.copy() - modified_cfg["backend"] = args.backend - modified_sparse_cfg[pattern] = modified_cfg - - # Create new config with modified settings - sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg) - - # Sparsify the model - model = mtsa.sparsify(model, config=sparse_config) - - print("Sparse attention applied successfully!") - - return model + return generated_text, truncated_prompt, inputs["input_ids"] def main(args): @@ -254,12 +163,54 @@ def main(args): model = model.cuda() print("Model moved to CUDA") - # Apply sparse attention to the model (with calibration if configured) - model = sparsify_model(model, args) + # Generate sample output BEFORE sparse attention + print("\nGenerating sample output before sparse attention...") + output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args) + + # Apply sparse attention with optional calibration + print(f"\nApplying sparse attention: {args.sparse_attn}") + sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + + # Override calibration options if provided via CLI + if args.target_sparse_ratio is not None: + sparse_config = copy.deepcopy(sparse_config) + sparse_cfg = sparse_config.get("sparse_cfg", {}) + if isinstance(sparse_cfg, dict) and "calibration" in sparse_cfg: + calibration_cfg = sparse_cfg["calibration"] + if isinstance(calibration_cfg, dict): + calibration_cfg["target_sparse_ratio"] = { + "prefill": args.target_sparse_ratio, + "decode": args.target_sparse_ratio, + } + print(f"Overriding target_sparse_ratio to {args.target_sparse_ratio}") + + model = mtsa.sparsify(model, config=sparse_config) + print("Sparse attention applied successfully!") - # Verify outputs if requested (compares baseline vs calibrated sparse model) - if args.verify_output: - verify_outputs(model, tokenizer, args) + # Generate sample output AFTER sparse attention + print("\nGenerating sample output after sparse attention...") + output_after, _, _ = generate_sample_output(model, tokenizer, args) + + # Display comparison + print("\n" + "=" * 60) + print("OUTPUT COMPARISON (Before vs After Sparse Attention)") + print("=" * 60) + display_prompt = test_prompt[:150] + "..." if len(test_prompt) > 150 else test_prompt + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {input_ids.shape[1]}") + + output_before_display = ( + output_before[:300] + "..." if len(output_before) > 300 else output_before + ) + output_after_display = output_after[:300] + "..." if len(output_after) > 300 else output_after + + print(f"\nBefore sparse attention: {output_before_display}") + print(f"After sparse attention: {output_after_display}") + + if output_before == output_after: + print("\nOutputs are identical") + else: + print("\nOutputs differ") # Export if requested if args.export_dir: @@ -306,12 +257,6 @@ def main(args): default=2048, help="Maximum sequence length for input prompts (will be truncated if longer)", ) - parser.add_argument( - "--num_samples", - type=int, - default=3, - help="Number of samples to use from NarrativeQA dataset", - ) # Generation arguments parser.add_argument( @@ -321,11 +266,6 @@ def main(args): parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") # Operation arguments - parser.add_argument( - "--verify_output", - action="store_true", - help="Verify that sparse attention outputs match baseline", - ) parser.add_argument( "--export_dir", type=str, @@ -333,5 +273,13 @@ def main(args): help="Directory to export the model with sparse attention applied", ) + # Calibration arguments + parser.add_argument( + "--target_sparse_ratio", + type=float, + default=None, + help="Target sparsity ratio for calibration (0.0 to 1.0). Overrides config value.", + ) + args = parser.parse_args() main(args) diff --git a/examples/llm_sparsity/weight_sparsity/README.md b/examples/llm_sparsity/weight_sparsity/README.md index 4fba746d37..ca4df236ff 100644 --- a/examples/llm_sparsity/weight_sparsity/README.md +++ b/examples/llm_sparsity/weight_sparsity/README.md @@ -84,7 +84,7 @@ python data_prep.py --save_path data The following command demonstrates how to perform SAT on the Llama2-7B model on 8 GPUs. The model is finetuned on the [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset for 3 epochs. -The input data is tokenized to a maximum length of 1024 tokens. The tokenized data is saved as a pickle file for faster data loading. The one-time process takes less than an hour to finish depending on the CPU. The resulting pickle file can be utilized for future training sessions. +The input data is tokenized to a maximum length of 1024 tokens. ```sh bash launch_finetune.sh --model meta-llama/Llama-2-7b-hf \ diff --git a/examples/llm_sparsity/weight_sparsity/data_prep.py b/examples/llm_sparsity/weight_sparsity/data_prep.py index b37212f6a1..d91caaba8b 100644 --- a/examples/llm_sparsity/weight_sparsity/data_prep.py +++ b/examples/llm_sparsity/weight_sparsity/data_prep.py @@ -19,7 +19,7 @@ from datasets import load_dataset -dataset_id = "cnn_dailymail" +dataset_id = "abisee/cnn_dailymail" dataset_config = "3.0.0" text_column = "article" summary_column = "highlights" diff --git a/examples/llm_sparsity/weight_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py index d13b43fde9..7110846683 100644 --- a/examples/llm_sparsity/weight_sparsity/finetune.py +++ b/examples/llm_sparsity/weight_sparsity/finetune.py @@ -1,5 +1,6 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py + +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py - -# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -32,7 +26,6 @@ import argparse import copy import os -import pickle from collections.abc import Sequence from dataclasses import dataclass, field @@ -232,27 +225,17 @@ def __init__( ): super().__init__() - pickle_name = f"dict_{split}_{tokenizer.model_max_length}.pickle" with training_args.main_process_first(): - if os.path.isfile(pickle_name): - with open(pickle_name, "rb") as f: - print_rank_0("Reuse pickled data") - data_dict = pickle.load(f) - else: - print_rank_0("Loading data...") - list_data_dict = utils.jload(data_path) - - print_rank_0("Formatting inputs...") - prompt_input = PROMPT_DICT["prompt_input"] - sources = [prompt_input.format_map(example) for example in list_data_dict] - targets = [ - f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict - ] - - print_rank_0("Tokenizing inputs... This may take some time...") - data_dict = preprocess(sources, targets, tokenizer) - with open(pickle_name, "wb") as f: - pickle.dump(data_dict, f, pickle.HIGHEST_PROTOCOL) + print_rank_0("Loading data...") + list_data_dict = utils.jload(data_path) + + print_rank_0("Formatting inputs...") + prompt_input = PROMPT_DICT["prompt_input"] + sources = [prompt_input.format_map(example) for example in list_data_dict] + targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] + + print_rank_0("Tokenizing inputs... This may take some time...") + data_dict = preprocess(sources, targets, tokenizer) self.input_ids = data_dict["input_ids"] self.labels = data_dict["labels"] diff --git a/examples/llm_sparsity/weight_sparsity/hf_pts.py b/examples/llm_sparsity/weight_sparsity/hf_pts.py index 12e80b0b2c..ad8061211d 100644 --- a/examples/llm_sparsity/weight_sparsity/hf_pts.py +++ b/examples/llm_sparsity/weight_sparsity/hf_pts.py @@ -35,7 +35,7 @@ def get_calib_dataloader( ): print("Loading calibration dataset") if data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + dataset = load_dataset("abisee/cnn_dailymail", name="3.0.0", split="train") dataset = dataset["article"][:calib_size] else: raise NotImplementedError diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md new file mode 100644 index 0000000000..ce63f2612f --- /dev/null +++ b/examples/megatron_bridge/README.md @@ -0,0 +1,218 @@ +# Megatron Bridge + +This directory contains examples of using Model Optimizer with [NeMo Megatron-Bridge](https://github.com/NVIDIA-Nemo/Megatron-Bridge) framework for pruning, distillation, quantization, etc. + +
+ +| **Section** | **Description** | **Link** | +| :------------: | :------------: | :------------: | +| Pre-Requisites | Development environment setup | \[[Link](#pre-requisites)\] | +| Pruning | Examples of pruning a model using Minitron algorithm | \[[Link](#pruning)\] | +| Distillation | Examples of distillation a pruned or quantized model | \[[Link](#distillation)\] | +| Quantization | Examples of quantizing a model | \[[Link](#quantization)\] | +| Resources | Extra links to relevant resources | \[[Link](#resources)\] | + +
+ +## Pre-Requisites + +Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed. + +To get the latest ModelOpt features and examples scripts, mount your Model-Optimizer repo to the container. + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it +if [ ! -d "${MODELOPT_DIR}" ]; then + git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +fi + +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +docker run \ + --gpus all \ + --shm-size=16GB \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -w /opt/Model-Optimizer/examples/megatron_bridge \ + ${DOCKER_IMAGE} bash +``` + +Once inside the container, you need to login with your HuggingFace token to download gated datasets / models. +Note that the default dataset for pruning and quantization is [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2), which is gated. + +```bash +hf auth login --token +``` + +## Pruning + +This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md). + +Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults: + 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration, + at-most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...), + top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_target_params 6e9 \ + --hparams_to_skip num_attention_heads \ + --output_hf_path /tmp/Qwen3-8B-Pruned-6B +``` + +Example usage for manually pruning to a specific architecture using following defaults: + 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration. + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_export_config '{"hidden_size": 3584, "ffn_hidden_size": 9216}' \ + --output_hf_path /tmp/Qwen3-8B-Pruned-6B-manual +``` + +To see the full usage for advanced configurations, run: + +```bash +torchrun --nproc_per_node 1 prune_minitron.py --help +``` + +> [!TIP] +> If number of layers in the model is not divisible by number of GPUs i.e. pipeline parallel (PP) size, you can configure +> uneven PP by setting `--num_layers_in_first_pipeline_stage` and `--num_layers_in_last_pipeline_stage`. +> E.g. for Qwen3-8B with 36 layers and 8 GPUs, you can set both to 3 to get 3-5-5-5-5-5-5-3 layers per GPU. + +## Distillation + +This section shows how to distill a student model from a teacher model in the Megatron-Bridge framework. + +This can be used stand-alone or after pruning (see [Pruning](#pruning)) / quantization (see [Quantization](#quantization)) to recover accuracy of the model by distilling from the original model (teacher). + +The [distill.py](distill.py) script loads student and teacher models from HuggingFace checkpoints and saves the distilled model to `/checkpoints` in Megatron distributed checkpoint format. + +### Data Preparation + +The distillation script expects pre-tokenized data in Megatron's binary format (`.bin` / `.idx` files). + +You can tokenize your JSONL datasets using the following command: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --jsonl_paths /path/to/data1.jsonl /path/to/data2.jsonl ... \ + --json_keys text \ + --tokenizer Qwen/Qwen3-0.6B \ + --output_dir /path/to/tokenized/data/qwen3 \ + --workers 32 \ + --max_sequence_length 256_000 +``` + +Instead of `--jsonl_paths`, you can also pass a directory path to the `--input_dir` argument to tokenize all JSONL files in the directory. +We are setting a maximum sequence length of 256k to avoid rare OOM errors in tokenization if text is too long. + +If you want to download and tokenize a dataset from Hugging Face Hub directly, you can use the following command: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset nvidia/Nemotron-Pretraining-SFT-v1 \ + --hf_name Nemotron-SFT-General \ + --hf_split train \ + --hf_max_samples_per_split 10_000_000 \ + --json_keys text \ + --tokenizer Qwen/Qwen3-0.6B \ + --output_dir /path/to/tokenized/data/qwen3 \ + --workers 32 \ + --max_sequence_length 256_000 +``` + +The [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) dataset is huge, so it will take a while to download and tokenize. You can also split the large `.jsonl` into multiple files (e.g. 10M samples per file using `split -l 10000000 -d --additional-suffix=.jsonl .jsonl _part`) and tokenize them parallelly. +To quickly test the script, you can try the [nvidia/Nemotron-Pretraining-Dataset-sample](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-Dataset-sample) dataset. + +If you skip `--hf_name`, it will download and tokenize all subsets for the dataset. +If you skip `--hf_split`, it will download and tokenize all splits for the subset. +If you skip `--hf_max_samples_per_split`, it will download and tokenize all samples for the split. + +### Distillation with Real Data + +Example usage to distill a 4B student (HF) from an 8B teacher (HF) on 8 GPUs (TP=8, PP=1): + +```bash +torchrun --nnodes 1 --nproc_per_node 8 distill.py \ + --tp_size 8 \ + --teacher_hf_path Qwen/Qwen3-8B \ + --student_hf_path Qwen/Qwen3-4B \ + --data_paths 1.0 /path/to/tokenized/data/qwen3 \ + --data_path_to_cache /path/to/cache/dataset_indices_qwen3 \ + --seq_length 8192 \ + --mbs 1 \ + --gbs 768 \ + --train_iters 15000 \ + --lr 1e-4 \ + --min_lr 1e-5 \ + --lr_warmup_iters 50 \ + --eval_interval 100 \ + --eval_iters 32 \ + --log_interval 10 \ + --output_dir /output/qwen3_8b_to_4b_distill +``` + +Tensorboard logging is enabled by default and logs are saved to `/tensorboard` directory. +To use Weights & Biases for logging, set the `WANDB_API_KEY` environment variable and pass the `--wandb_project` argument. +Optionally, you can also pass `--wandb_entity` and `--wandb_exp_name` arguments to group runs under a project and experiment name. + +To see all available arguments: + +```bash +torchrun --nproc_per_node 1 distill.py --help +``` + +### Quick Test with Mock Data + +Example usage with mock data for quick testing (no pre-tokenized data needed): + +```bash +torchrun --nproc_per_node 8 distill.py \ + --tp_size 8 \ + --teacher_hf_path Qwen/Qwen3-0.6B \ + --student_hf_path Qwen/Qwen3-0.6B \ + --use_mock_data \ + --seq_length 512 \ + --mbs 1 \ + --gbs 8 \ + --train_iters 100 \ + --eval_interval 10 \ + --eval_iters 4 \ + --output_dir /tmp/test_distill +``` + +### Slurm Usage + +To run the distillation script on a Slurm cluster for multi-node training, you just need use `python` instead of `torchrun` and set the number of nodes using `#SBATCH --nodes=` clause in your Slurm script. + +### Convert Megatron checkpoint to Hugging Face format + +To convert the Megatron checkpoint from last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from `prune_minitron.py` script) and the distilled megatron checkpoint dir (`/checkpoints/iter_`) to run the following command: + +```bash +uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ + --hf-model \ + --megatron-path /checkpoints/iter_ \ + --hf-path +``` + +For more details, you can refer to the checkpoint conversion scripts in the [Megatron-Bridge README](https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/conversion). + +## Quantization + +TODO + +## Resources + +- 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) +- 📖 [Documentation](https://nvidia.github.io/Model-Optimizer) +- 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) +- 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) +- ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py new file mode 100644 index 0000000000..8fc5cff6f8 --- /dev/null +++ b/examples/megatron_bridge/distill.py @@ -0,0 +1,252 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Distillation script for Megatron-Bridge. + +Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model +to `/checkpoints` in megatron distributed checkpoint format. + +See `README.md` in this directory for example usage and data preparation instructions. +""" + +import argparse +import os + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.distillation_provider import convert_to_distillation_provider +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + GPTDatasetConfig, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.distill import distill +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.distributed import DistributedDataParallelConfig + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import print_rank_0 + +SEED = 1234 + + +def get_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") + # Model arguments (accepts HuggingFace input only at the moment) + parser.add_argument( + "--student_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", + ) + parser.add_argument( + "--teacher_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)", + ) + # Parallelism arguments + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + # Dataset arguments + parser.add_argument( + "--data_paths", + nargs="+", + help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)", + ) + parser.add_argument( + "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data" + ) + parser.add_argument( + "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices" + ) + parser.add_argument( + "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths" + ) + # Training & Eval arguments + parser.add_argument( + "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving" + ) + parser.add_argument( + "--seq_length", + type=int, + default=4096, + help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.", + ) + parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size") + parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size") + parser.add_argument( + "--train_iters", type=int, required=True, help="Number of training iterations" + ) + parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate") + parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate") + parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps") + parser.add_argument( + "--eval_interval", type=int, default=100, help="Validate + checkpoint every steps" + ) + parser.add_argument( + "--eval_iters", type=int, default=32, help="Number of batches per validation stage" + ) + # Logging arguments + parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps") + parser.add_argument( + "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)" + ) + parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)") + parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") + args = parser.parse_args() + + # Sanity checks + if not args.use_mock_data and not args.data_paths: + raise ValueError("Must provide either --data_paths or set --use_mock_data.") + + print_rank_0("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(args: argparse.Namespace): + checkpoint_dir = os.path.join(args.output_dir, "checkpoints") + tensorboard_dir = os.path.join(args.output_dir, "tb_logs") + + # Build student and teacher model providers + def _build_model_provider(hf_path): + bridge = AutoBridge.from_hf_pretrained(hf_path) + provider = bridge.to_megatron_provider(load_weights=True) + + # Override parallelism / training settings + provider.tensor_model_parallel_size = args.tp_size + provider.pipeline_model_parallel_size = args.pp_size + provider.context_parallel_size = 1 + provider.sequence_parallel = args.tp_size > 1 + provider.seq_length = args.seq_length + provider.pipeline_dtype = torch.bfloat16 + return provider + + # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000) + # Still requires an HF model name or path to build provider correctly + student_provider = _build_model_provider(args.student_hf_path) + teacher_provider = _build_model_provider(args.teacher_hf_path) + + # Wrap into DistillationProvider + kd_config = ModelOptDistillConfig() + distill_provider = convert_to_distillation_provider( + student_provider, teacher_provider, kd_config + ) + + # Build optimizer and scheduler + optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=args.lr_warmup_iters, + max_lr=args.lr, + min_lr=args.min_lr, + adam_beta2=0.95, + ) + + # Build dataset config + dataset_kwargs = { + "seq_length": args.seq_length, + "path_to_cache": args.data_path_to_cache, + "random_seed": SEED, + "reset_attention_mask": False, + "reset_position_ids": False, + "eod_mask_loss": False, + "num_dataset_builder_threads": 1, + "data_sharding": True, + "dataloader_type": "single", + "skip_getting_attention_mask_from_dataset": True, + } + if args.use_mock_data: + dataset_config = MockGPTDatasetConfig(**dataset_kwargs) + else: + # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format + blend = get_blend_from_list(args.data_paths) + dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs) + + # Assemble ConfigContainer and run distillation + config = ConfigContainer( + model=distill_provider, + train=TrainingConfig( + train_iters=args.train_iters, + eval_interval=args.eval_interval, + eval_iters=args.eval_iters, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + manual_gc=True, + manual_gc_interval=100, + ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters), + optimizer=optimizer_config, + scheduler=scheduler_config, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dataset=dataset_config, + logger=LoggerConfig( + log_interval=args.log_interval, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + # Weights & Biases logging + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, # optional + wandb_exp_name=args.wandb_exp_name, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size + ), + checkpoint=CheckpointConfig( + save_interval=args.eval_interval, + save=checkpoint_dir, + load=checkpoint_dir, # Resume from this directory (if exists) + most_recent_k=5, # Keeps 5 most recent checkpoints (not metric-based) + ckpt_format="torch_dist", + async_save=True, + fully_parallel_save=True, + ), + rng=RNGConfig(seed=SEED), + mixed_precision="bf16_mixed", + ) + + print_rank_0("\nStarting distillation...") + distill(config) + print_rank_0( + f"\nDistillation done! Saved checkpoint to {checkpoint_dir} in megatron distributed checkpoint format.\n" + ) + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + finally: + dist.cleanup() diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py new file mode 100644 index 0000000000..c4da627f14 --- /dev/null +++ b/examples/megatron_bridge/prune_minitron.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example script for pruning a GPT / Mamba model using Minitron algorithm on a Megatron-Bridge model (load from HF). + +Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) +while skipping pruning of num_attention_heads using following defaults: + 1024 samples from nemotron-post-training-dataset-v2 for calibration, + at-most 20% depth (num_layers) and 40% width is pruned per prunable hparam (hidden_size, ffn_hidden_size, ...), + top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. + + torchrun --nproc_per_node 2 prune_minitron.py \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_target_params 6e9 \ + --hparams_to_skip num_attention_heads \ + --output_hf_path /tmp/Qwen3-8B-Pruned-6B + +To see the full usage for advanced configurations, run: + torchrun --nproc_per_node 1 prune_minitron.py --help + +See `README.md` in this directory for more details. +""" + +# TODO: Test multi-node pruning +import argparse +import json +import os + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider +from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider +from transformers import AutoConfig, AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.prune as mtp +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import get_supported_datasets, num2hrb, print_rank_0, warn_rank_0 +from modelopt.torch.utils.plugins.mbridge import ( + get_hf_mbridge_calibration_loop, + load_mbridge_model_from_hf, +) +from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--hf_model_name_or_path", type=str, required=True) + parser.add_argument("--trust_remote_code", action="store_true") + + target_group = parser.add_mutually_exclusive_group(required=True) + target_group.add_argument( + "--output_megatron_path", + type=str, + help="Path to save the pruned model in Megatron checkpoint format", + ) + target_group.add_argument( + "--output_hf_path", type=str, help="Path to save the pruned model in HF checkpoint format" + ) + + # Parallelism arguments + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + parser.add_argument( + "--num_layers_in_first_pipeline_stage", + type=int, + default=None, + help="Number of layers in the first pipeline stage (Uneven Pipeline Parallelism)", + ) + parser.add_argument( + "--num_layers_in_last_pipeline_stage", + type=int, + default=None, + help="Number of layers in the last pipeline stage (Uneven Pipeline Parallelism)", + ) + + # Calibration dataset parameters + parser.add_argument( + "--calib_dataset_name", + type=str, + default="nemotron-post-training-dataset-v2", + choices=get_supported_datasets(), + help="Dataset name for calibration", + ) + parser.add_argument( + "--calib_num_samples", type=int, default=1024, help="Number of samples for calibration" + ) + # TODO: Add support for pre-training dataset (pre-tokenized) + # TODO: only allow mbs>1 for pretraining dataset + parser.add_argument( + "--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size" + ) + parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size") + parser.add_argument("--seq_length", type=int, default=4096) + + # Pruning parameters + parser.add_argument( + "--prune_intermediate_ckpt", + type=str, + default=None, + help=( + "Path to save/restore intermediate pruning scores for resuming / faster re-run. " + "If not provided, it will default to `/modelopt_pruning_scores.pth`" + ), + ) + + target_group = parser.add_mutually_exclusive_group(required=True) + target_group.add_argument( + "--prune_export_config", + type=str, + help=( + 'Target pruned config as JSON e.g., \'{"hidden_size": 512, "ffn_hidden_size": 2048}\'. ' + f"Supported hyperparameters: {mtp.mcore_minitron.SUPPORTED_HPARAMS}. " + "Cannot be used with --prune_target_params." + ), + ) + target_group.add_argument( + "--prune_target_params", + type=float, + help=( + "Target parameter count for pruning e.g., 6e9 for pruning to 6B params (total params, not active params). " + "Uses Neural Architecture Search (NAS) to find the best pruned model that maximizes the --prune_score_func." + "Cannot be used with --prune_export_config." + ), + ) + + parser.add_argument( + "--prune_score_func", + type=str, + choices=["mmlu_5pct"], + default="mmlu_5pct", + help=( + "Score function to use for NAS-based pruning (--prune_target_params). Currently supported: " + "mmlu_5pct (MMLU on 5% sampled data per subject for faster eval). " + ), + ) + parser.add_argument( + "--max_width_pruning", + type=float, + default=0.4, + help=( + f"Maximum width pruning percentage ({mtp.mcore_minitron.SUPPORTED_HPARAMS - {'num_layers'}}) " + "for NAS-based pruning (--prune_target_params)" + ), + ) + parser.add_argument( + "--max_depth_pruning", + type=float, + default=0.2, + help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning (--prune_target_params)", + ) + parser.add_argument( + "--hparams_to_skip", + nargs="*", + type=str, + default=[], + choices=mtp.mcore_minitron.SUPPORTED_HPARAMS, + help=( + "Space-separated list of hparams to skip for NAS-based pruning (--prune_target_params) " + "e.g. dont prune 'num_attention_heads'" + ), + ) + parser.add_argument( + "--top_k", + type=int, + default=10, + help=( + "Number of top candidates to consider for NAS-based pruning (--prune_target_params). " + "Higher values will take longer to prune but may find a better model." + ), + ) + + args = parser.parse_args() + + # Post-process arguments + if args.prune_intermediate_ckpt is None: + if args.output_megatron_path: + args.prune_intermediate_ckpt = ( + f"{args.output_megatron_path}/modelopt_pruning_scores.pth" + ) + elif args.output_hf_path: + args.prune_intermediate_ckpt = f"{args.output_hf_path}/modelopt_pruning_scores.pth" + print_rank_0( + "No checkpoint provided to cache intermediate pruning scores. " + f"Setting to: {args.prune_intermediate_ckpt}" + ) + + if args.prune_export_config: + try: + prune_export_config = json.loads(args.prune_export_config) + except json.JSONDecodeError as exc: + raise ValueError( + f"Invalid JSON for --prune_export_config: {args.prune_export_config}" + ) from exc + if not isinstance(prune_export_config, dict): + raise ValueError("--prune_export_config must parse to a dictionary.") + args.prune_export_config = prune_export_config + + print_rank_0("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(args: argparse.Namespace): + assert dist.size() == args.pp_size, "Only Pipeline parallelism is supported for pruning." + + if args.output_megatron_path and os.path.exists( + f"{args.output_megatron_path}/latest_checkpointed_iteration.txt" + ): + warn_rank_0(f"\nPruned model already exists at {args.output_megatron_path}. Exiting...") + return + elif args.output_hf_path and os.path.exists(f"{args.output_hf_path}/config.json"): + warn_rank_0(f"\nPruned model already exists at {args.output_hf_path}. Exiting...") + return + + bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf( + hf_model_name_or_path=args.hf_model_name_or_path, + trust_remote_code=args.trust_remote_code, + provider_overrides={ + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": args.pp_size, + "num_layers_in_first_pipeline_stage": args.num_layers_in_first_pipeline_stage, + "num_layers_in_last_pipeline_stage": args.num_layers_in_last_pipeline_stage, + "pipeline_dtype": torch.bfloat16, + "seq_length": args.seq_length, + }, + init_model_parallel=True, + ) + print_rank_0(f"\nPruning {unwrapped_model=}") + print_rank_0( + f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" + ) + + forward_loop = get_hf_mbridge_calibration_loop( + model=model, + provider=provider, + tokenizer=tokenizer, + hf_model_name_or_path=args.hf_model_name_or_path, + trust_remote_code=args.trust_remote_code, + dataset_name=args.calib_dataset_name, + num_samples=args.calib_num_samples, + micro_batch_size=args.calib_mbs, + global_batch_size=args.calib_gbs, + ) + + pruning_config = { + "forward_loop": forward_loop, + "checkpoint": args.prune_intermediate_ckpt, + } + if args.prune_target_params is not None: + # Restrict search space to a smaller set of candidates + # NOTE: You can reduce the divisors and increase config['top_k'] to potentially find a better model. + ss_config = mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=256, + ffn_hidden_size_divisor=512, + mamba_head_dim_divisor=8, + num_moe_experts_divisor=8, + num_layers_divisor=2, + ) + + pruning_constraints = {"params": args.prune_target_params} + print_rank_0( + f"Using NAS-based automatic pruning with score function: {args.prune_score_func}" + "You can change this to be any other metric you want to maximize (e.g. negative validation loss)." + ) + + def score_func_mmlu(m): + return megatron_mmlu(m, tokenizer, percentage=0.05) + + pruning_config["score_func"] = score_func_mmlu + pruning_config["max_width_pruning"] = args.max_width_pruning + pruning_config["max_depth_pruning"] = args.max_depth_pruning + pruning_config["hparams_to_skip"] = args.hparams_to_skip + pruning_config["top_k"] = args.top_k + elif args.prune_export_config is not None: + # Less restrictive search space for manual pruning + ss_config = mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=64, + ffn_hidden_size_divisor=64, + mamba_head_dim_divisor=8, + num_moe_experts_divisor=8, + num_layers_divisor=1, + ) + + pruning_constraints = {"export_config": args.prune_export_config} + print_rank_0(f"Pruning constraints: {pruning_constraints}") + + unwrapped_model, pruning_scores = mtp.prune( # in-place pruning + unwrapped_model, + mode=[("mcore_minitron", ss_config)], # type: ignore[arg-type] + constraints=pruning_constraints, + dummy_input=None, + config=pruning_config, + ) + # Remove unnecessary modelopt_state since ckpt is homogeneous + if mto.ModeloptStateManager.has_state_for_mode_type("prune", model=unwrapped_model): + mto.ModeloptStateManager.remove_state(unwrapped_model) + if isinstance(provider, MambaModelProvider): + provider.hybrid_override_pattern = unwrapped_model.hybrid_override_pattern + print_rank_0(f"\nPruned {unwrapped_model=}") + print_rank_0( + f"Pruned model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" + ) + + if args.output_megatron_path is not None: + print_rank_0( + f"Saved pruned model to {args.output_megatron_path} in Megatron checkpoint format" + ) + + # NOTE: Issue with NemotronH tokenizer's len() hence using use_fast=True as a WAR + use_fast_tokenizer = isinstance(provider, NemotronHModelProvider) + bridge.save_megatron_model( + model, + args.output_megatron_path, + hf_tokenizer_path=args.hf_model_name_or_path, + hf_tokenizer_kwargs={ + "trust_remote_code": args.trust_remote_code, + "use_fast": use_fast_tokenizer, + }, + ) + print_rank_0( + f"Saved pruned model to {args.output_megatron_path} in Megatron checkpoint format" + ) + else: + print_rank_0(f"Saving pruned model to {args.output_hf_path} in HF checkpoint format") + + # [WAR] Hacky way to save pruned HF model until Megatron-Bridge natively supports it + bridge.hf_pretrained.save_artifacts(args.output_hf_path) + hf_cfg = AutoConfig.from_pretrained( + args.output_hf_path, trust_remote_code=args.trust_remote_code + ) + mcore_cfg = unwrapped_model.config + + hf_cfg.hidden_size = mcore_cfg.hidden_size + hf_cfg.intermediate_size = mcore_cfg.ffn_hidden_size + hf_cfg.num_attention_heads = mcore_cfg.num_attention_heads + hf_cfg.head_dim = mcore_cfg.kv_channels + hf_cfg.num_key_value_heads = mcore_cfg.num_query_groups + if hasattr(hf_cfg, "mamba_num_heads"): + hf_cfg.mamba_num_heads = mcore_cfg.mamba_num_heads + if hasattr(hf_cfg, "mamba_head_dim"): + hf_cfg.mamba_head_dim = mcore_cfg.mamba_head_dim + if hasattr(hf_cfg, "moe_intermediate_size"): + hf_cfg.moe_intermediate_size = mcore_cfg.moe_ffn_hidden_size + if hasattr(hf_cfg, "moe_shared_expert_intermediate_size"): + hf_cfg.moe_shared_expert_intermediate_size = ( + mcore_cfg.moe_shared_expert_intermediate_size + ) + if hasattr(hf_cfg, "num_experts"): + hf_cfg.num_experts = mcore_cfg.num_moe_experts + if hasattr(hf_cfg, "n_routed_experts"): + hf_cfg.n_routed_experts = mcore_cfg.num_moe_experts + if hasattr(hf_cfg, "n_shared_experts"): + hf_cfg.n_shared_experts = ( + mcore_cfg.moe_shared_expert_intermediate_size // mcore_cfg.moe_ffn_hidden_size + ) + if hasattr(hf_cfg, "layer_types"): + kept_layer_nums = pruning_scores["sorted_layers"][: mcore_cfg.num_layers] # 1-indexed + hf_cfg.layer_types = [ + lt for i, lt in enumerate(hf_cfg.layer_types) if i + 1 in kept_layer_nums + ] + if hasattr(hf_cfg, "hybrid_override_pattern"): + hf_cfg.hybrid_override_pattern = unwrapped_model.hybrid_override_pattern + hf_cfg.num_hidden_layers = mcore_cfg.num_layers + + # Save dummy pruned HF model to get the correct bridge for saving pruned weights + AutoModelForCausalLM.from_config( + hf_cfg, trust_remote_code=args.trust_remote_code + ).save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code) + pruned_bridge = AutoBridge.from_hf_pretrained( + args.output_hf_path, trust_remote_code=args.trust_remote_code + ) + pruned_bridge.save_hf_weights(model, args.output_hf_path) + print_rank_0(f"Saved pruned model to {args.output_hf_path} in HF checkpoint format") + + print_rank_0("Done!") + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + finally: + dist.cleanup() diff --git a/examples/nemo_run/common/process_climbmix.py b/examples/nemo_run/common/process_climbmix.py index 18fd35f2d2..a6f91cc11b 100644 --- a/examples/nemo_run/common/process_climbmix.py +++ b/examples/nemo_run/common/process_climbmix.py @@ -67,7 +67,7 @@ def get_args(): print("Tokenizing ClimbMix dataset...") input_paths = [raw_dir / name for name in subset_filenames] megatron_preprocess_data( - input_paths, + jsonl_paths=input_paths, output_dir=proc_dir, tokenizer_name_or_path=args.tokenizer, append_eod=True, diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 9792f2932c..c34b957723 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -18,27 +18,36 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | | Getting Started | Learn how to use the pruning API | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/3_pruning.html)\] | | Support Matrix | View the support matrix to see available pruning algorithms and their compatibility with different models and frameworks | \[[Link](#support-matrix)\] | | -| Pruning Guidelines | Guidelines for choosing how and how much to prune for best results | \[[Link](#pruning-guidelines)\] | | | Examples | Examples of different pruning methods | \[[Link](#examples)\] | | +| Pruning Guidelines | Guidelines for choosing how and how much to prune for best results | \[[Link](#pruning-guidelines)\] | | | Resources | Extra links to relevant resources | \[[Link](#resources)\] | | ## Pre-Requisites -For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.09`) which has all the dependencies installed. +For Minitron pruning for Megatron-LM / NeMo models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:25.11`) which has all the dependencies installed. Make sure to upgrade Model Optimizer to the latest version using `pip`. For FastNAS pruning for PyTorch Computer Vision models, no additional dependencies are required. -For GradNAS pruning for Hugging Face BERT / GPT-J, no additional dependencies are requisred. +For GradNAS pruning for Hugging Face BERT / GPT-J, no additional dependencies are required. ## Getting Started -As part of the pruning process, you will need to set up the training and/or validation data loaders, and optionally define a validation score function (FastNAS) or loss function (GradNAS) and specify the desired pruning constraints (See [Support Matrix](#support-matrix) for available pruning constraints). +As part of the pruning process, you will need to set up the training and/or validation data loaders, and optionally define a validation score function (Minitron, FastNAS) or loss function (GradNAS) and specify the desired pruning constraints (See [Support Matrix](#support-matrix) for available pruning constraints). + +To prune your model, you can simply call the `mtp.prune` API and save the pruned model. If the model is pruned using Minitron, you can use your standard saving and loading functions since it is a homogeneous pruning; while for FastNAS or GradNAS, you need to use `mto.save` and `mto.restore` to save and restore the heterogeneous pruned model. + +### Minitron + +Minitron pruning supports two modes: + +1. **Manual Pruning**: Manually specify the target dimensions for each pruning axis (e.g., `constraints = {"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}}`) +2. **NAS-based Auto Pruning (New)**: Specify a target parameter count (e.g., `constraints = {"params": 6e9}`) and let the algorithm automatically search for the best architecture that maximizes a user-defined score function (e.g. MMLU, negative validation loss, etc.) -To prune your model, you can simply call the `mtp.prune` API and save the pruned model. If the model is pruned using FastNAS or GradNAS, you need to use `mto.save` and `mto.restore` to save and restore the pruned model; while for Minitron pruning, you can use your standard saving and loading functions since it is a homogeneous pruning. +Please see example snippets of both modes for Minitron pruning on Megatron-Core GPT model below. For end-to-end examples script (M-LM / NeMo framework), please refer to the examples below. -Please see an example snippet of Minitron pruning for Megatron-Core GPT model below (for other algorithms, please refer to the examples below). +#### Common Setup ```python import modelopt.torch.prune as mtp @@ -46,11 +55,11 @@ from megatron.core.models.gpt import GPTModel from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec from megatron.core.transformer.transformer_config import TransformerConfig -# Load the Megatron-Core GPTModel with ModelOpt transformer layer spec -config = TransformerConfig(...) +# Load the Megatron-Core GPTModel MambaModel with ModelOpt transformer layer spec +model_config = TransformerConfig(...) model = GPTModel( - config=config, - transformer_layer_spec=get_gpt_modelopt_spec(config, remap_te_layernorm=True), + config=model_config, + transformer_layer_spec=get_gpt_modelopt_spec(model_config, remap_te_layernorm=True), ... ) @@ -61,41 +70,141 @@ from megatron.training.training import evaluate_and_print_results def forward_loop(_): evaluate_and_print_results(prefix, forward_step, train_iterator, model, ...) - -# Specify the pruning constraints (Check Support Matrix for available pruning dimensions) -export_config = { - "hidden_size": 3072, - "ffn_hidden_size": 9216, -} - - # Run the pruning process (if model is a list then pass model[0] to the prune API) -# Save minitron scores at scores_path so we can re-run pruning with different export configs without running the forward loop again -# NOTE: Skip scores_path on re-running if you want to change the dataset and re-calibrate +# Save minitron scores at checkpoint so we can re-run pruning with different constraints without running the forward loop again +# NOTE: Skip checkpoint on re-running if you want to change the dataset and re-calibrate model, pruning_scores = mtp.prune( model, mode="mcore_minitron", - constraints={"export_config": export_config}, + constraints=constraints, dummy_input=None, # Not used - config={"forward_loop": forward_loop, "scores_path": "modelopt_minitron_scores.pth"}, + config=config, ) ``` -If your model parameters are already sorted, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`. - > [!Note] > Fine-tuning / distillation is required after pruning to recover the accuracy. Please refer to [end-to-end pruning and distillation tutorial](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) for more details. +#### 1. Manual Pruning + +This mode can be useful when you know the exact dimensions you want to prune to (e.g. fitting a specific latency / memory budget). + +```python +# Specify the pruning constraints (Check Support Matrix for available pruning dimensions) +constraints = {"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}} +config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"} + +mtp.prune(...) +``` + +**Under the Hood:** + +1. **Importance Scoring**: Runs forward passes on calibration data (512-1024 samples) to compute activation magnitudes for each neuron/head/layer (takes ~5 minutes for an 8B model) +2. **Ranking**: Ranks all parameters within each pruning dimension (e.g., all hidden dimensions, all attention heads) by their importance scores +3. **Pruning**: Removes the least important parameters to meet the specified target dimensions in `export_config` +4. **Weight Slicing**: Slices the model weights according to the pruned architecture (homogeneous pruning - all layers pruned uniformly) + +> [!TIP] +> Checkout the [Pruning Guidelines](#pruning-guidelines) section for more details on how to choose the best pruning strategy and distillation hyperparameters. + +#### 2. NAS-based Auto Pruning + +This mode can be useful when you don't know the exact dimensions you want to prune to and want the algorithm to search for the best architecture that maximizes a user-defined score function at the cost of longer runtime. + +```python +# Define the score function to maximize (e.g., MMLU, negative validation loss, etc.) +# The algorithm will search for the best architecture that maximizes this score +from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu + +def score_func(m): + return megatron_mmlu(m, tokenizer, percentage=0.05) # 5% sampled data for faster eval + +# Specify target parameter count and configure the auto pruning algorithm +constraints = {"params": 6e9} # Prune to 6B parameters +config = { + "forward_loop": forward_loop, + "checkpoint": "/path/to/cache/pruning/scores.pth", + "score_func": score_func, + # Optional: Configure search space constraints (showing defaults) + "max_width_pruning": 0.4, # Maximum 40% per width pruning hparam + "max_depth_pruning": 0.2, # Maximum 20% per depth pruning hparam (num_layers) + "hparams_to_skip": [], # Disable pruning specific hparams, e.g., ["num_attention_heads"] + "top_k": 10, # Number of top architectures to evaluate (use 20 for better results at the cost of 2x time) +} + +mtp.prune(...) +``` + +**Under the Hood:** + +1. **Importance Scoring**: Same as manual pruning - computes activation magnitudes for all parameters (takes ~5 minutes for an 8B model) +2. **Search Space Construction**: Generates a search space of possible architectures based search space config and other configs (`max_width_pruning`, `max_depth_pruning`, `hparams_to_skip`) +3. **Architecture Search**: Find candidate architectures that meet the parameter constraint and evaluate `top_k` (based on number of parameters) of them using `score_func` e.g. MMLU, negative validation loss, etc. (takes ~10 mins per candidate for an 8B model pruning) +4. **Best Architecture Selection**: Returns the architecture (best `export_config`) with the highest actual score from the top-K evaluated architectures +5. **Weight Slicing**: Slices the model weights according to the best pruned architecture found + +> [!Note] +> As per the [original paper](https://arxiv.org/pdf/2407.14679), ideally we need to perform a short Knowledge Distillation on ~2B tokens for all top-K candidate architectures before evaluating the score function, which will take a lot longer to prune, require splitting the pruning process into multiple stages and a lot more compute for pruning but can lead to better pruned model. If you are interested to do this, you can take the top-K candidate's `export_config` from the pruning logs and then export all models separately and perform Knowledge Distillation on each of them before evaluating the score function. + +#### Advanced Configuration + +For finer control over the search space (e.g., granularity of pruning choices), you can configure the divisors: + +```python +# Configure search space granularity (showing defaults) +ss_config = mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=256, + ffn_hidden_size_divisor=512, + mamba_head_dim_divisor=8, + num_moe_experts_divisor=8, + num_layers_divisor=2, +) + +# Use the custom search space config +mtp.prune(model, mode=[("mcore_minitron", ss_config)], ...) +``` + +If your model parameters are already sorted and you just want to prune the weights, you can skip the sorting step by setting `"skip_sorting": True` in `config` instead of passing `forward_loop`. + ## Support Matrix | **Algorithm** | **Model** | **Pruning Constraints** | | :---: | :---: | :---: | -| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid LLM Models1 | Export config with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) values | -| FastNAS | Computer Vision models | flops, parameters | -| GradNAS | HuggingFace BERT, GPT-J | flops, parameters | +| Minitron | Megatron-core / NeMo based GPT / Mamba / MoE / Hybrid LLM Models1 | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values
**Auto:** `params` (requires `score_func` in config) | +| FastNAS | Computer Vision models | `flops`, `params` | +| GradNAS | HuggingFace BERT, GPT-J | `flops`, `params` | > *1.Only Pipeline Parallel models are supported. Hugging Face models can be converted to Megatron-LM/NeMo format and used subsequently.* +## Examples + +### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano) + +Checkout the Minitron pruning example for the [Megatron-LM Framework](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-pruning) or [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama-3.1-8B, Qwen3-8B, Nemotron-Nano-9B-v2, Nemotron-3-Nano-30B-A3B, etc. +Both frameworks support importing from a Hugging Face pretrained checkpoint. + +You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen3-8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial. + +Some of the models pruned using Minitron method followed by distillation and post-training are: + +- [Minitron Collection on Hugging Face](https://huggingface.co/collections/nvidia/minitron) +- [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) + +### FastNAS Pruning for PyTorch Computer Vision Models + +Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search). + +You can also take a look at FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory +which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook +also shows how to profile the model to understand the search space of possible pruning options and demonstrates +how to save and restore pruned models. + +### GradNAS Pruning for HuggingFace Language Models (e.g. BERT) + +Checkout the BERT pruning example in [chained_optimizations](../chained_optimizations/README.md) directory +which showcases the usage of GradNAS for pruning BERT model for Question Answering followed by fine-tuning +with distillation and quantization. The example also demonstrates how to save and restore pruned models. + ## Pruning Guidelines ### Minitron @@ -174,35 +283,6 @@ After pruning, distillation is required to recover model accuracy. Below are rec > [!TIP] > If you know the maximum learning rate used during the original training, a good rule of thumb for knowledge distillation is to use **1/5th of that maximum LR** when compressing by ~50%. -## Examples - -### Minitron Pruning for Megatron-LM / NeMo Framework LLMs (e.g. Qwen 3, Nemotron Nano) - -Checkout the Minitron pruning example for the [Megatron-LM Framework](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-pruning) or [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html) which showcases the usage of the powerful Minitron pruning algorithm developed by NVIDIA Research for pruning LLMs like Llama 3.1 8B, Qwen 3 8B, Nemotron Nano 12B v2, etc. -Both frameworks support importing from a Hugging Face pretrained checkpoint. - -You can also look at the NeMo tutorial notebooks [here](https://github.com/NVIDIA-NeMo/NeMo/tree/main/tutorials/llm/qwen/pruning-distillation) which showcase the usage of Minitron pruning followed by distillation for Qwen 3 8B step-by-step in NeMo framework. Hugging Face models can also be converted to NeMo format and used subsequently as shown in the tutorial. - -Some of the models pruned using Minitron method followed by distillation and post-training are: - -- [Minitron Collection on Hugging Face](https://huggingface.co/collections/nvidia/minitron) -- [NVIDIA-Nemotron-Nano-9B-v2](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2) - -### FastNAS Pruning for PyTorch Computer Vision Models - -Check out the FastNAS pruning example usage in the [documentation](https://nvidia.github.io/Model-Optimizer/guides/3_pruning.html#pruning-and-subnet-search). - -You can also take a look at FastNAS pruning interactive notebook [cifar_resnet](./cifar_resnet.ipynb) in this directory -which showcases the usage of FastNAS for pruning a ResNet 20 model for the CIFAR-10 dataset. The notebook -also shows how to profile the model to understand the search space of possible pruning options and demonstrates -how to save and restore pruned models. - -### GradNAS Pruning for HuggingFace Language Models (e.g. BERT) - -Checkout the BERT pruning example in [chained_optimizations](../chained_optimizations/README.md) directory -which showcases the usage of GradNAS for pruning BERT model for Question Answering followed by fine-tuning -with distillation and quantization. The example also demonstrates how to save and restore pruned models. - ## Resources - 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) diff --git a/examples/puzzletron/GPTOSS.md b/examples/puzzletron/GPTOSS.md new file mode 100644 index 0000000000..c996363b6c --- /dev/null +++ b/examples/puzzletron/GPTOSS.md @@ -0,0 +1,14 @@ + +## GptOss - 20b + +With this release Puzzle algorithm supports only experts removal for `Gpt-Oss-20b`. + +This model comes as a quantized checkpoint i.e. MoE experts matrices are quantized with _MXFP4_ format. +In the prunning steps puzzle utilizes decompressed model (back to BF16) for statistics and scores computation. +This means, during the conversion to puzzle format we decompress the model and store it as a BF16. +Once the pruning is done i.e. experts to be removed are identified and the process is finished, user may want to get back the _MXFP4_ format of the checkpoint. +To do so, there is an additional script, that takes the original and the pruned checkpoint and outputs pruned checkpoint in _MXFP4_ format. + +```bash +python -m modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_pruned_to_mxfp4 --student-path /workspaces/any_model_gpt_oss_20b/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ --output-path /workspaces/any_model_gpt_oss_20b/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ --num-layers 24 +``` diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index 48f64d3c41..d8f5164023 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -9,7 +9,7 @@ The supported modifications are: To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the memory. The final stage is based on Mixed-Integer Programming (MIP) algorithm to find the most optimal combination of layer modifications that satisfy the target requirements. -In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. +In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md). > **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). @@ -275,21 +275,9 @@ vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 -- ## Knowledge Distillation -To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container. +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. -First, convert the HF model to NeMo format: - -```bash -python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo -``` - -Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html). - -[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format. - -```bash -python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF -``` +See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation. ## Advanced Usage diff --git a/examples/puzzletron/automodel_distillation/README.md b/examples/puzzletron/automodel_distillation/README.md new file mode 100644 index 0000000000..35621eca3c --- /dev/null +++ b/examples/puzzletron/automodel_distillation/README.md @@ -0,0 +1,122 @@ +# Knowledge Distillation with NeMo AutoModel + +This guide shows how to run knowledge distillation on Puzzletron-compressed AnyModel (heterogeneous) checkpoints using **NeMo AutoModel**. AutoModel enables efficient training of any HuggingFace model with a unified API; here we extend it to load heterogeneous checkpoints and use TP-friendly KD loss. + +## Overview + +1. **AutoModel + AnyModel**: We monkey-patch NeMo AutoModel so `from_pretrained(..., anymodel_descriptor=..., block_configs_path=...)` can load heterogeneous checkpoints. The patch uses ModelOpt’s `ModelDescriptorFactory` and `deci_x_patcher` to apply per-layer configs during model init. +2. **Custom KD recipe**: For distillation we use a custom recipe (`recipe.py`) that adds pipeline-parallel (PP) support, better logging, and TP-friendly KD loss. Pretraining is unchanged and uses AutoModel’s built-in recipe. Once the AutoModel repo gains these features, the custom recipe can be dropped and the upstream KD recipe used instead. +3. **KD loss** (`loss.py`): We provide a TP-aware KD on precomputed logits only; CE is computed separately and mixed with `kd_ratio`. + +**Supported parallelisms** +FSDP is fully supported. Pipeline parallelism (PP) is supported for most models; exceptions are those whose layer naming does not follow the usual HuggingFace convention. Tensor parallelism (TP) and sequence parallelism (SP) are mostly supported—a known exception is GPT-OSS due to sink tokens (AutoModel has the same limitation; it is not specific to AnyModel). Context parallelism (CP) is supported for all models tested. Expert parallelism (EP) is not supported: AutoModel relies on custom (non–HuggingFace) model implementations for EP, which conflicts with the goal of supporting any HF model. + +## Setup + +**Requirements** + +- NeMo AutoModel (install from source or use a container that provides it). +- ModelOpt installed (`pip install nvidia-modelopt` or install from the Model-Optimizer repo). +- For KD: this example’s `recipe.py`, `loss.py`, and `patch_automodel.py` (the run entrypoint always applies the patch before loading models). + +**Environment** + +Set `PYTHONPATH` so that the Model-Optimizer root is on the path (for ModelOpt and, if you run this example as a module, for `automodel_distillation`): + +```bash +export PYTHONPATH="/path/to/Model-Optimizer:${PYTHONPATH}" +``` + +If you use a NeMo AutoModel container, ensure the AutoModel package is installed (e.g. clone AutoModel and `pip install -e .`). Upgrade HuggingFace Transformers if needed (e.g. for compatibility): + +```bash +python -m pip install -e /path/to/AutoModel +python -m pip install -U omegaconf fire transformers +``` + +## Configuration + +- **pretrain.yaml** – Pretrain/finetune on an AnyModel checkpoint. Set `model.pretrained_model_name_or_path` and `model.anymodel_descriptor` (e.g. `gpt_oss_20b`, `llama`, `qwen2`, `qwen3`). Optional: `model.block_configs_path`; if omitted, block configs are auto-detected from `/block_configs.json`. +- **kd.yaml** – Knowledge distillation. Set `model.pretrained_model_name_or_path` and `model.anymodel_descriptor` for the student, and `teacher_model.pretrained_model_name_or_path` and `teacher_model.anymodel_descriptor` for the teacher. + +Paths and descriptors can be overridden from the command line (see below). + +## Run + +**Apply the patch and run KD** + +Before loading models, the run entrypoint calls `apply_patch()` so that `from_pretrained` accepts `anymodel_descriptor` and `block_configs_path`. Then it loads the config and runs the chosen recipe. + +Run from the **automodel_distillation** directory so that `run.py` can import `patch_automodel` and `recipe`: + +```bash +cd /path/to/Model-Optimizer/examples/puzzletron/automodel_distillation +torchrun --nproc_per_node=2 \ + -m run \ + --mode kd \ + -c kd.yaml +``` + +Override config (e.g. paths and descriptor) on the command line: + +```bash +torchrun --nproc_per_node=2 \ + -m run \ + --mode kd \ + -c kd.yaml \ + model.pretrained_model_name_or_path=/path/to/student \ + model.anymodel_descriptor=gpt_oss_20b \ + teacher_model.pretrained_model_name_or_path=/path/to/teacher \ + teacher_model.anymodel_descriptor=gpt_oss_20b +``` + +**Pretrain (uses AutoModel’s built-in recipe)** + +```bash +torchrun --nproc_per_node=2 \ + -m run \ + --mode pretrain \ + -c pretrain.yaml \ + model.pretrained_model_name_or_path=/path/to/checkpoint \ + model.anymodel_descriptor=gpt_oss_20b +``` + +**Note:** If you run from a different layout (e.g. from the Model-Optimizer repo root or under another package name), set `PYTHONPATH` to include this directory so `run` can import `patch_automodel` and `recipe`, and ensure the config `kd_loss_fn._target_` (e.g. `loss.KDLoss`) resolves to the correct module. + +## Example: Running on a cluster + +Below is an example job setup: NeMo AutoModel container, clone AutoModel main, install it and upgrade Transformers, then run KD from a directory that contains your config and run script (e.g. a copy of this example or the RealAnyModel layout). + +```bash +# Submit interactive job +srun --partition=interactive --time=2:00:00 --gres=gpu:2 \ + --container-image=nvcr.io/nvidia/nemo-automodel:25.11.00 \ + --container-mounts="/path/to/AutoModel/:/opt/Automodel/" \ + --pty bash + +# Inside the container +source /opt/venv/bin/activate +cd /opt/Automodel/ +python -m pip install -e . +python -m pip install -U omegaconf fire transformers +python -m pip uninstall nvidia-modelopt +cd /path/to/Model-Optimizer +python -m pip install -e . + +# Run KD (from your project dir that has run.py, kd.yaml, patch_automodel, loss, recipe) +cd ./examples/puzzletron/automodel_distillation/ +torchrun --nproc_per_node 2 -m run --mode kd -c kd.yaml 2>&1 | tee logs +``` + +Use your own paths for mounts, checkpoint dirs, and config overrides as needed. + +## Files in this example + +| File | Purpose | +|------|--------| +| `patch_automodel.py` | Monkey-patch so `from_pretrained` accepts `anymodel_descriptor` and `block_configs_path`; uses ModelOpt’s `deci_x_patcher`. | +| `loss.py` | KDLoss: TP-aware KD on precomputed logits (CE is mixed via `kd_ratio` in the recipe). | +| `recipe.py` | Custom KD recipe (PP support, logging, TP-friendly KD). | +| `run.py` | Entrypoint: applies patch, then runs pretrain or KD using the config. | +| `pretrain.yaml` | Pretrain config (no hardcoded paths; override on CLI). | +| `kd.yaml` | KD config (no hardcoded paths; override on CLI). | diff --git a/examples/puzzletron/automodel_distillation/kd.yaml b/examples/puzzletron/automodel_distillation/kd.yaml new file mode 100644 index 0000000000..ef81a88871 --- /dev/null +++ b/examples/puzzletron/automodel_distillation/kd.yaml @@ -0,0 +1,133 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Knowledge distillation: student and teacher are AnyModel checkpoints. +# Requires apply_patch() from patch_automodel. Set model and teacher_model paths and descriptors. +# anymodel_descriptor must match a ModelOpt ModelDescriptorFactory name (e.g. gpt_oss_20b, llama, qwen2, qwen3). +# +# KD loss (kd_loss_fn._target_): use loss.KDLoss for TP-aware KD on precomputed logits. +# CE is computed by loss_fn and mixed with KD via kd_ratio in the recipe. +# If running under a different package name, use that module path (e.g. automodel_distillation.loss.KDLoss). +# +# To run: +# torchrun --nproc_per_node -m automodel_distillation.run --mode kd -c kd.yaml +# Override: model.pretrained_model_name_or_path=/path/to/student model.anymodel_descriptor=llama ... + +step_scheduler: + global_batch_size: 128 + local_batch_size: 4 + ckpt_every_steps: 200 + val_every_steps: 100 + num_epochs: 2 + +dist_env: + backend: nccl + timeout_minutes: 5 + +rng: + _target_: nemo_automodel.components.training.rng.StatefulRNG + seed: 1111 + ranked: true + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: ./heterogeneous_ckpts/meta-llama-Llama-3.1-8B-Instruct/ # student checkpoint dir + anymodel_descriptor: llama # e.g. gpt_oss_20b, llama, qwen2, qwen3 + force_hf: true + torch_dtype: bf16 + trust_remote_code: true + +teacher_model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: ./heterogeneous_ckpts/meta-llama-Llama-3.1-8B-Instruct-teacher/ # teacher checkpoint dir + anymodel_descriptor: llama # same format as model.anymodel_descriptor + force_hf: true + torch_dtype: bf16 + trust_remote_code: true + +checkpoint: + enabled: true + checkpoint_dir: checkpoints/ + model_save_format: safetensors + save_consolidated: false + +distributed: + dp_size: none + tp_size: 2 + cp_size: 1 + ep_size: 1 + sequence_parallel: false + pp_size: 1 + pipeline: + pp_schedule: interleaved1f1b + pp_microbatch_size: 1 + scale_grads_in_schedule: false + round_virtual_stages_to_pp_multiple: up + dtype: bf16 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + activation_checkpointing: false + +compile_config: + enabled: true + +packed_sequence: + packed_sequence_size: 1024 + split_across_pack: false + +loss_fn: + _target_: nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy + +# 0 = pure CE (better to run pretrain instead of loading a teacher and not using it) +# 1 = pure KD (common practice for puzzletron distillation) +kd_ratio: 1.0 + +kd_loss_fn: + _target_: loss.KDLoss + ignore_index: -100 + temperature: 1.0 + fp32_upcast: true + +optimizer: + _target_: torch.optim.Adam + betas: [0.9, 0.999] + eps: 1.0e-8 + lr: 1.0e-5 + weight_decay: 0 + +dataset: + _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset + dataset_name: rajpurkar/squad + split: train + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.default_collater + shuffle: false + +validation_dataset: + _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset + dataset_name: rajpurkar/squad + split: validation + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.default_collater + +# wandb: +# project: +# entity: +# name: +# save_dir: diff --git a/examples/puzzletron/automodel_distillation/loss.py b/examples/puzzletron/automodel_distillation/loss.py new file mode 100644 index 0000000000..68d5a92ef8 --- /dev/null +++ b/examples/puzzletron/automodel_distillation/loss.py @@ -0,0 +1,141 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.tensor import DTensor, Shard + + +def _infer_tp_group_from_dtensor(tensor: "torch.Tensor"): + """Return device_mesh process group if tensor is a DTensor sharded on vocab (logits last dim, lm_head dim 0).""" + if not isinstance(tensor, DTensor): + return None + # Vocab sharding: Shard on last dim (logits) or Shard(0) (weight matrix) + has_shard = any(isinstance(p, Shard) for p in tensor.placements) + if not has_shard: + return None + return tensor.device_mesh.get_group() + + +def _kl_forward_tp( + t_logits: torch.Tensor, + s_logits: torch.Tensor, + tp_group, +) -> torch.Tensor: + """ + Compute KL (negative cross entropy sum(P*log Q)) with tensor parallelism. + Returns per-token negative cross entropy (sum over vocab). + """ + teacher_max = t_logits.max(dim=-1, keepdim=True).values + dist.all_reduce(teacher_max, op=dist.ReduceOp.MAX, group=tp_group) + output_teacher = t_logits - teacher_max + + denom_teacher = torch.exp(output_teacher).sum(dim=-1, keepdim=True) + dist.all_reduce(denom_teacher, op=dist.ReduceOp.SUM, group=tp_group) + teacher_prob = torch.exp(output_teacher - torch.log(denom_teacher.clamp(min=1e-12))) + + student_max = s_logits.max(dim=-1, keepdim=True).values + dist.all_reduce(student_max, op=dist.ReduceOp.MAX, group=tp_group) + output_student = s_logits - student_max.detach() + + denom_student = torch.exp(output_student).sum(dim=-1, keepdim=True) + dist.all_reduce(denom_student, op=dist.ReduceOp.SUM, group=tp_group) + student_log_prob = output_student - torch.log(denom_student.clamp(min=1e-12)) + + term = teacher_prob * student_log_prob + inf_mask = torch.isinf(s_logits) + term = torch.masked_fill(term, inf_mask, 0.0) + ce_local = term.sum(dim=-1) + dist.all_reduce(ce_local, op=dist.ReduceOp.SUM, group=tp_group) + return ce_local.view(-1) + + +class KDLoss(nn.Module): + """TP-aware KD on precomputed logits.""" + + def __init__( + self, + ignore_index: int = -100, + temperature: float = 1.0, + fp32_upcast: bool = True, + tp_group=None, + **kwargs, + ): + super().__init__() + self.ignore_index = ignore_index + self.temperature = temperature + self.fp32_upcast = fp32_upcast + self.tp_group = tp_group + + def forward( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + labels: torch.Tensor, + num_batch_labels: int | None = None, + ) -> torch.Tensor: + valid_mask = (labels != self.ignore_index).view(-1) + if valid_mask.sum() == 0: + return student_logits.new_tensor(0.0) + + if student_logits.ndim > 2: + student_logits = student_logits.view(-1, student_logits.shape[-1]) + if teacher_logits.ndim > 2: + teacher_logits = teacher_logits.view(-1, teacher_logits.shape[-1]) + if labels.ndim > 1: + labels = labels.view(-1) + + tp_group = self.tp_group + if isinstance(student_logits, DTensor) and tp_group is None: + tp_group = _infer_tp_group_from_dtensor(student_logits) + + if tp_group is not None: + if isinstance(student_logits, DTensor): + student_logits = student_logits.to_local() + if isinstance(teacher_logits, DTensor): + teacher_logits = teacher_logits.to_local() + else: + if isinstance(student_logits, DTensor): + student_logits = student_logits.full_tensor() + if isinstance(teacher_logits, DTensor): + teacher_logits = teacher_logits.full_tensor() + + t_logits = teacher_logits[valid_mask] + s_logits = student_logits[valid_mask] + + if self.fp32_upcast: + t_logits = t_logits.float() + s_logits = s_logits.float() + if self.temperature != 1.0: + t_logits = t_logits.mul(1.0 / self.temperature) + s_logits = s_logits.mul(1.0 / self.temperature) + + if tp_group is not None: + kl_per_token = _kl_forward_tp(t_logits, s_logits, tp_group) + else: + teacher_prob = F.softmax(t_logits, dim=-1, dtype=torch.float32) + student_logprob = F.log_softmax(s_logits, dim=-1, dtype=torch.float32) + inf_mask = torch.isinf(s_logits) + kl_per_token = ( + torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0.0).sum(-1).view(-1) + ) + + if self.temperature != 1.0: + kl_per_token = kl_per_token * (self.temperature**2) + + if num_batch_labels is not None: + return -torch.sum(kl_per_token) / num_batch_labels + return -torch.mean(kl_per_token) diff --git a/examples/puzzletron/automodel_distillation/patch_automodel.py b/examples/puzzletron/automodel_distillation/patch_automodel.py new file mode 100644 index 0000000000..72c2c18119 --- /dev/null +++ b/examples/puzzletron/automodel_distillation/patch_automodel.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime patch so NeMo AutoModel.from_pretrained(..., anymodel_descriptor=..., block_configs_path=...) +uses ModelOpt's AnyModel support (ModelDescriptorFactory + deci_x_patcher). + +Requires ModelOpt to be installed. Call apply_patch() before loading models; call remove_patch() to restore. +""" + +import functools +import json +import logging +import threading +from contextlib import nullcontext +from pathlib import Path + +logger = logging.getLogger(__name__) + +_anymodel_ctx = threading.local() + + +def _get_ctx_stack(): + if not hasattr(_anymodel_ctx, "stack"): + _anymodel_ctx.stack = [] + return _anymodel_ctx.stack + + +def load_block_configs(block_configs_path: str | Path) -> list[dict]: + path = Path(block_configs_path) + if not path.exists(): + raise FileNotFoundError(f"Block configs not found: {path}") + with open(path) as f: + out = json.load(f) + logger.info("Loaded %d block configs from %s", len(out), path) + return out + + +def auto_detect_block_configs(checkpoint_dir: str | Path) -> list[dict] | None: + checkpoint_dir = Path(checkpoint_dir) + block_configs_path = checkpoint_dir / "block_configs.json" + if block_configs_path.exists(): + return load_block_configs(block_configs_path) + return None + + +def apply_patch() -> None: + """Patch nemo_automodel so from_pretrained(..., anymodel_descriptor=..., block_configs_path=...) + uses ModelOpt's deci_x_patcher for heterogeneous (AnyModel) checkpoints. + """ + import nemo_automodel._transformers.auto_model as _auto_model + + if getattr(_auto_model, "_anymodel_patch_applied", False): + logger.debug("AutoModel AnyModel patch already applied") + return + + from modelopt.torch.puzzletron.anymodel import ModelDescriptorFactory, deci_x_patcher + + _orig_init_model = _auto_model._init_model + _orig_from_pretrained = _auto_model._BaseNeMoAutoModelClass.from_pretrained.__func__ + + def _patched_init_model(cls, *model_args, **kwargs): + stack = _get_ctx_stack() + block_configs, anymodel_descriptor = stack[-1] if stack else (None, None) + + patcher_ctx = nullcontext() + if block_configs is not None and anymodel_descriptor is not None: + descriptor = ModelDescriptorFactory.get(anymodel_descriptor) + if descriptor is not None: + patcher_ctx = deci_x_patcher( + model_descriptor=descriptor, + block_configs=block_configs, + ) + logger.info( + "Using deci_x_patcher with %d heterogeneous layer configs (descriptor=%s)", + len(block_configs), + anymodel_descriptor, + ) + else: + logger.warning( + "anymodel_descriptor=%r not found in ModelDescriptorFactory; skipping deci_x_patcher", + anymodel_descriptor, + ) + + with patcher_ctx: + return _orig_init_model(cls, *model_args, **kwargs) + + def _patched_from_pretrained_impl(cls, *args, **kwargs): + kwargs = dict(kwargs) + pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", None) + anymodel_descriptor = kwargs.pop("anymodel_descriptor", None) + block_configs_path = kwargs.pop("block_configs_path", None) + if args: + pretrained_model_name_or_path = pretrained_model_name_or_path or args[0] + model_args = args[1:] + else: + model_args = () + if pretrained_model_name_or_path is None: + raise TypeError( + "from_pretrained() missing 1 required argument: 'pretrained_model_name_or_path'" + ) + + block_configs = None + if anymodel_descriptor is not None: + if block_configs_path is not None: + block_configs = load_block_configs(block_configs_path) + else: + checkpoint_dir = Path(pretrained_model_name_or_path) + if checkpoint_dir.is_dir(): + block_configs = auto_detect_block_configs(checkpoint_dir) + if block_configs: + logger.info( + "Auto-detected %d block configs from %s/block_configs.json", + len(block_configs), + checkpoint_dir, + ) + + stack = _get_ctx_stack() + stack.append((block_configs, anymodel_descriptor)) + kwargs_for_orig = { + k: v + for k, v in kwargs.items() + if k not in ("anymodel_descriptor", "block_configs_path") + } + if isinstance(pretrained_model_name_or_path, type): + raise TypeError( + "pretrained_model_name_or_path must be a path (str or PathLike), got a type. " + "Ensure the config model.pretrained_model_name_or_path is the checkpoint path." + ) + try: + return _orig_from_pretrained( + cls, + pretrained_model_name_or_path, + *model_args, + **kwargs_for_orig, + ) + finally: + stack.pop() + + class _FromPretrainedDescriptor: + def __get__(self, obj, owner): + if owner is None: + return self + return functools.partial(_patched_from_pretrained_impl, owner) + + _auto_model._init_model = _patched_init_model + _auto_model._BaseNeMoAutoModelClass.from_pretrained = _FromPretrainedDescriptor() + _auto_model._anymodel_patch_applied = True + _auto_model._anymodel_orig_init_model = _orig_init_model + _auto_model._anymodel_orig_from_pretrained = _orig_from_pretrained + logger.info("Applied AnyModel patch to nemo_automodel._transformers.auto_model (ModelOpt)") + + +def remove_patch() -> None: + """Restore nemo_automodel to its original state.""" + import nemo_automodel._transformers.auto_model as _auto_model + + if not getattr(_auto_model, "_anymodel_patch_applied", False): + logger.debug("AutoModel AnyModel patch was not applied") + return + + _auto_model._init_model = _auto_model._anymodel_orig_init_model + _auto_model._BaseNeMoAutoModelClass.from_pretrained = _auto_model._anymodel_orig_from_pretrained + del _auto_model._anymodel_orig_init_model + del _auto_model._anymodel_orig_from_pretrained + _auto_model._anymodel_patch_applied = False + logger.info("Removed AnyModel patch from nemo_automodel._transformers.auto_model") diff --git a/examples/puzzletron/automodel_distillation/pretrain.yaml b/examples/puzzletron/automodel_distillation/pretrain.yaml new file mode 100644 index 0000000000..99e6346c15 --- /dev/null +++ b/examples/puzzletron/automodel_distillation/pretrain.yaml @@ -0,0 +1,108 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Pretrain/finetune with a heterogeneous (AnyModel) checkpoint. +# Requires apply_patch() from patch_automodel so from_pretrained accepts anymodel_descriptor. +# +# Set pretrained_model_name_or_path to your checkpoint directory. +# anymodel_descriptor must match a ModelOpt ModelDescriptorFactory name, e.g.: +# gpt_oss_20b, llama, qwen2, qwen3, mistral_small, nemotron_h, nemotron_h_v2, qwen3_vl +# block_configs_path is optional; if omitted, block_configs are auto-detected from +# /block_configs.json when present. +# +# To run (from repo root or examples/puzzletron/automodel_distillation): +# torchrun --nproc_per_node -m automodel_distillation.run --mode pretrain -c pretrain.yaml +# Override config on the command line, e.g.: +# model.pretrained_model_name_or_path=/path/to/checkpoint model.anymodel_descriptor=llama + +step_scheduler: + global_batch_size: 16 + local_batch_size: 2 + ckpt_every_steps: 2000 + num_epochs: 1 + max_steps: 60 + +dist_env: + backend: nccl + timeout_minutes: 5 + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + # Set to your AnyModel checkpoint directory (must contain config and weights; optional block_configs.json) + pretrained_model_name_or_path: null # e.g. /path/to/heterogeneous_checkpoint + # ModelOpt descriptor name (see list in header comment) + anymodel_descriptor: null # e.g. gpt_oss_20b or llama + force_hf: true + torch_dtype: bf16 + trust_remote_code: true + # block_configs_path: null # optional; default auto-detect from checkpoint dir + +checkpoint: + enabled: false + checkpoint_dir: checkpoints/ + model_save_format: torch_save + save_consolidated: false + +dataset: + _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset + dataset_name: rajpurkar/squad + split: train + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + shuffle: false + collate_fn: nemo_automodel.components.datasets.utils.default_collater + +loss_fn: + _target_: nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy + +optimizer: + _target_: torch.optim.AdamW + lr: 2.0e-5 + betas: [0.9, 0.95] + weight_decay: 0.1 + +distributed: + dp_size: none + tp_size: 1 + sequence_parallel: false + cp_size: 1 + ep_size: 1 + pp_size: 1 + pipeline: + pp_schedule: interleaved1f1b + pp_microbatch_size: 1 + scale_grads_in_schedule: false + round_virtual_stages_to_pp_multiple: up + dtype: bf16 + +distributed_config: + _target_: nemo_automodel.components.distributed.config.FSDP2Config + sequence_parallel: false + activation_checkpointing: true + +packed_sequence: + packed_sequence_size: 2048 + split_across_pack: false + +lr_scheduler: + lr_decay_style: cosine + min_lr: 1.0e-6 + lr_warmup_steps: 15 + +# wandb: +# project: +# entity: +# name: +# save_dir: diff --git a/examples/puzzletron/automodel_distillation/recipe.py b/examples/puzzletron/automodel_distillation/recipe.py new file mode 100644 index 0000000000..cc225ffa16 --- /dev/null +++ b/examples/puzzletron/automodel_distillation/recipe.py @@ -0,0 +1,950 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Knowledge Distillation recipe for next-token prediction with NeMo-AutoModel. + +This recipe fine-tunes a *student* model using the logits of a frozen *teacher* model. It +extends ``FinetuneRecipeForNextTokenPrediction`` adding: + +1. teacher_model – an additional HF/NeMo model loaded in ``eval`` mode +2. kd_loss_fn – KL-divergence between temperature-scaled distributions +3. kd_ratio – linear mix between CE loss and KD loss + +The training loop is copied from the parent class but the loss becomes: + loss = (1-kd_ratio) * ce_loss + kd_ratio * kd_loss + +The file exposes ``KnowledgeDistillationRecipeForNextTokenPrediction`` and a +``main`` entry-point so it can be launched exactly the same way as other +recipes: + + python -m torch.distributed.run --nproc-per-node=8 \ + nemo_automodel/recipes/llm/knowledge_distillation.py \ + -c examples/llm/llama_3_2_1b_kd.yaml +""" + +from __future__ import annotations + +import logging +import time +from contextlib import nullcontext +from typing import Any + +import torch +import wandb +from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer +from nemo_automodel.components.config._arg_parser import parse_args_and_load_config +from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx +from nemo_automodel.components.distributed.pipelining.config import PipelineConfig +from nemo_automodel.components.distributed.utils import get_sync_ctx +from nemo_automodel.components.loggers.metric_logger import MetricsSample +from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy +from nemo_automodel.components.training.rng import ScopedRNG +from nemo_automodel.components.training.utils import ( + ScopedModuleOffloading, + count_tail_padding, + prepare_after_first_microbatch, + prepare_for_final_backward, + prepare_for_grad_accumulation, + scale_grads_and_clip_grad_norm, +) +from nemo_automodel.recipes.llm.train_ft import ( + TrainFinetuneRecipeForNextTokenPrediction, + _get_num_thd_chunks, + _uses_te_dot_product_attention, + _uses_thd_collater, + build_model, + calculate_loss, +) +from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + +logger = logging.getLogger(__name__) + + +def _get_lm_head_weight(model) -> torch.Tensor: + """Return the lm_head weight from the model (possibly sharded/DTensor). Do not gather to full.""" + lm_head = None + if hasattr(model, "get_output_embeddings"): + emb = model.get_output_embeddings() + if emb is not None and hasattr(emb, "weight"): + lm_head = emb.weight + if lm_head is None: + for n, p in model.named_parameters(remove_duplicate=False): + if "lm_head" in n and n.endswith(".weight"): + lm_head = p + break + if lm_head is None: + raise ValueError("lm_head.weight not found in model") + return lm_head + + +def _build_kd_loss_fn(cfg_kd): + if cfg_kd is None: + logger.info("No KD loss function provided, using KLDivLoss") + return torch.nn.KLDivLoss(reduction="batchmean") + return cfg_kd.instantiate() + + +def _build_teacher_model( + cfg_teacher, + seed, + has_packed_sequence, + device_mesh=None, + moe_mesh=None, + distributed_config=None, + device=None, +): + """Build and initialize the teacher model for knowledge distillation. + + Uses the same infrastructure as student model (NeMoAutoModelForCausalLM) but without + PEFT, FP8, or QAT since the teacher should be frozen in full precision. + + Args: + cfg_teacher: Configuration for teacher model instantiation. + seed: Random seed for reproducibility. + has_packed_sequence: Whether using packed sequences. + device_mesh: Device mesh for distributed training. + moe_mesh: MOE mesh for expert parallelism. + distributed_config: Strategy-specific distributed config. + device: Device to place the teacher model on. + + Returns: + The frozen teacher model ready for inference. + + Note: + The `offload_teacher_model` config option is not supported with this approach. + Device placement is handled internally by NeMoAutoModelForCausalLM infrastructure. + """ + + assert cfg_teacher is not None, "`teacher_model` section missing from YAML config" + logger.info("Instantiating teacher model") + + # Build teacher model using the same infrastructure as student + # but without PEFT/FP8/QAT (teacher should be frozen in full precision) + with ScopedRNG(seed=seed, ranked=True): + kwargs: dict[str, Any] = { + "has_packed_sequence": has_packed_sequence, + "device_mesh": device_mesh, + "moe_mesh": moe_mesh, + "distributed_config": distributed_config, + } + + teacher_model = cfg_teacher.instantiate(**kwargs) + + # Ensure the teacher model is on the correct device + teacher_model = teacher_model.to(device) + + # Set teacher to eval mode and freeze parameters + teacher_model.eval() + for p in teacher_model.parameters(): + p.requires_grad_(False) + + return teacher_model + + +def _build_teacher_model_with_pp( + cfg_teacher, + seed: int, + has_packed_sequence: bool, + device_mesh, + moe_mesh, + distributed_config, + pipeline_config: PipelineConfig, + dist_setup, +) -> Any: + """Build teacher model with same parallelization as student (TP/EP/SP/PP). + + Teacher is built via build_model with pipeline_config so it is an AutoPipeline + when PP is enabled. No PEFT/FP8/QAT. Teacher is frozen and set to eval. + """ + assert cfg_teacher is not None, "`teacher_model` section missing from YAML config" + logger.info("Instantiating teacher model (parallelized with TP/EP/SP/PP)") + + # Copy pipeline config and use a capture loss_fn so we can read teacher logits after eval + teacher_pipeline_config = PipelineConfig( + pp_schedule=pipeline_config.pp_schedule, + pp_schedule_csv=pipeline_config.pp_schedule_csv, + pp_microbatch_size=pipeline_config.pp_microbatch_size, + pp_batch_size=pipeline_config.pp_batch_size, + layers_per_stage=pipeline_config.layers_per_stage, + round_virtual_stages_to_pp_multiple=pipeline_config.round_virtual_stages_to_pp_multiple, + module_fqns_per_model_part=pipeline_config.module_fqns_per_model_part, + patch_inner_model=pipeline_config.patch_inner_model, + patch_causal_lm_model=pipeline_config.patch_causal_lm_model, + patch_stage_backward_maybe_with_nosync=pipeline_config.patch_stage_backward_maybe_with_nosync, + dtype=pipeline_config.dtype, + scale_grads_in_schedule=pipeline_config.scale_grads_in_schedule, + loss_fn=None, # Set below via closure + ) + + # Mutable container for teacher logits (set by capture fn when last stage runs) + teacher_logits_capture = [None] + + def _teacher_capture_loss_fn(logits, target, **kwargs): + teacher_logits_capture[0] = logits.detach().clone() + return logits.new_tensor(0.0, dtype=logits.dtype) + + teacher_pipeline_config.loss_fn = _teacher_capture_loss_fn + + with ScopedRNG(seed=seed, ranked=True): + teacher_model = build_model( + cfg_teacher, + cfg_peft=None, + has_packed_sequence=has_packed_sequence, + seed=seed, + cfg_fp8=None, + cfg_compile=None, + cfg_quantization=None, + device_mesh=device_mesh, + moe_mesh=moe_mesh, + distributed_config=distributed_config, + pipeline_config=teacher_pipeline_config, + cfg_qat=None, + cfg_moe=dist_setup.moe_config, + activation_checkpointing=dist_setup.activation_checkpointing, + ) + + # Freeze teacher + for part in getattr(teacher_model, "parts", [teacher_model]): + part.eval() + for p in part.parameters(): + p.requires_grad_(False) + + # Attach capture ref so recipe can read teacher logits after eval + teacher_model._teacher_logits_capture = teacher_logits_capture + return teacher_model + + +def _verify_tokenizer_compatibility(student_cfg, teacher_cfg, trust_remote_code=True): + if student_cfg is None or teacher_cfg is None: + raise ValueError("Student and teacher model configs are required") + student_tokenizer = NeMoAutoTokenizer.from_pretrained( + student_cfg.pretrained_model_name_or_path, trust_remote_code=trust_remote_code + ) + teacher_tokenizer = NeMoAutoTokenizer.from_pretrained( + teacher_cfg.pretrained_model_name_or_path, trust_remote_code=trust_remote_code + ) + if student_tokenizer.vocab_size != teacher_tokenizer.vocab_size: + raise ValueError( + "Student and teacher tokenizers have different vocab sizes; Support will be added in the future" + ) + if student_tokenizer.pad_token != teacher_tokenizer.pad_token: + raise ValueError("Student and teacher tokenizers have different pad tokens") + del student_tokenizer, teacher_tokenizer + + +class KnowledgeDistillationRecipeForNextTokenPrediction(TrainFinetuneRecipeForNextTokenPrediction): + """Fine-tune a student model via knowledge distillation.""" + + def setup(self): + """Build student & teacher, dataloaders, optimizers, etc.""" + # Right now, we only support tokenizer compatibility for the same tokenizer. + # We will add support for different tokenizers in the future. + _verify_tokenizer_compatibility( + self.cfg.get("model", None), self.cfg.get("teacher_model", None) + ) + + # Let the parent class build *everything* for the student first + super().setup() + + self._offload_teacher_model = self.cfg.get("offload_teacher_model", False) + teacher_device = self.dist_env.device if not self._offload_teacher_model else "cpu" + + if self.pp_enabled: + # PP + FusedLinearCrossEntropy requires hidden_states at loss; pipeline last stage only has logits. + if isinstance(self.loss_fn, FusedLinearCrossEntropy): + raise ValueError( + "Pipeline parallelism with KD requires a loss that uses only logits and labels " + "(e.g. MaskedCrossEntropy). FusedLinearCrossEntropy is not supported for PP KD." + ) + self.teacher_model = _build_teacher_model_with_pp( + cfg_teacher=self.cfg.get("teacher_model", None), + seed=self.cfg.get("seed", 42), + has_packed_sequence=self.cfg.get("packed_sequence.packed_sequence_size", 0) > 0, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + distributed_config=self.distributed_config, + pipeline_config=self.pipeline_config, + dist_setup=self.dist_setup, + ) + self.teacher_pp = self.teacher_model + else: + self.teacher_model = _build_teacher_model( + cfg_teacher=self.cfg.get("teacher_model", None), + seed=self.cfg.get("seed", 42), + has_packed_sequence=self.cfg.get("packed_sequence.packed_sequence_size", 0) > 0, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + distributed_config=self.distributed_config, + device=teacher_device, + ) + self.teacher_pp = None + + logger.info("Teacher Model: " + str(self.teacher_model)) + # KD + self.kd_loss_fn = _build_kd_loss_fn(self.cfg.get("kd_loss_fn", None)) + self.kd_ratio: float = float(self.cfg.get("kd_ratio", 0.5)) + logger.info("KD Loss config: " + str(self.cfg.get("kd_loss_fn", None))) + temperature = getattr(self.kd_loss_fn, "temperature", "N/A") + logger.info(f"Knowledge-distillation enabled: ratio={self.kd_ratio}, T={temperature}") + + # Buffers for logging + self._kd_loss_buffer = [] + self._ce_loss_buffer = [] + + if self.pp_enabled: + schedule = self.pp.info.schedule + # Schedule objects use _loss_fn (e.g. ScheduleInterleaved1F1B), not loss_fn + self._original_pp_loss_fn = getattr(schedule, "_loss_fn", None) + schedule._loss_fn = self._make_pp_kd_loss_wrapper() + + def _make_pp_kd_loss_wrapper(self): + """Return a callable used as the student pipeline loss_fn; reads _current_teacher_logits from self.""" + recipe_ref = self + + def pp_kd_loss_fn(logits, target, **kwargs): + teacher_logits = getattr(recipe_ref, "_current_teacher_logits", None) + num_label_tokens = getattr(recipe_ref, "_current_num_label_tokens", None) + if teacher_logits is None: + raise RuntimeError( + "KD loss wrapper: _current_teacher_logits not set. " + "Teacher pipeline eval must run before student step." + ) + if recipe_ref.kd_ratio >= 1.0: + ce_loss = logits.new_tensor(0.0, dtype=logits.dtype) + else: + ce_loss = calculate_loss( + recipe_ref.loss_fn, + logits=logits, + labels=target, + num_label_tokens=num_label_tokens, + ) + kd_loss = recipe_ref.kd_loss_fn( + logits, + teacher_logits, + target, + num_batch_labels=num_label_tokens, + ) + recipe_ref._ce_loss_buffer.append(ce_loss.detach().clone()) + recipe_ref._kd_loss_buffer.append(kd_loss.detach().clone()) + return (1.0 - recipe_ref.kd_ratio) * ce_loss + recipe_ref.kd_ratio * kd_loss + + return pp_kd_loss_fn + + # Override the forward backward step to inject KD loss + def _forward_backward_step( + self, + idx, + batch, + *, + num_label_tokens, + num_batches, + is_train: bool = True, + ): + """Override the forward backward step to include knowledge distillation loss.""" + if self.pp_enabled: + raise RuntimeError( + "_forward_backward_step should not be called when pp_enabled; use _forward_backward_step_pp" + ) + batch = {k: v.to(self.dist_env.device, non_blocking=True) for k, v in batch.items()} + labels = batch.pop("labels") + train_ctx, batch = make_cp_batch_and_ctx(self.device_mesh, batch, labels) + + model = self.model_parts[0] + sync_ctx = ( + get_sync_ctx( + model, + idx == num_batches - 1, + defer_fsdp_grad_sync=getattr(self.distributed_config, "defer_fsdp_grad_sync", True), + ) + if is_train + else nullcontext() + ) + with train_ctx(), sync_ctx: + with ( + ScopedModuleOffloading(self.teacher_model, enabled=self._offload_teacher_model), + torch.inference_mode(), + ): + teacher_logits = self.teacher_model(**batch) + teacher_logits = getattr(teacher_logits, "logits", teacher_logits).detach().clone() + + # Student forward + student_keep_last = isinstance(self.loss_fn, FusedLinearCrossEntropy) + if student_keep_last: + # Student forward keeping only last token logits to match loss_fn + student_out = model(logits_to_keep=1, **batch) + else: + student_out = model(**batch) + + student_logits = getattr(student_out, "logits", student_out) # shape (B, S, V) + # Cross-entropy loss against true labels (skip when kd_ratio >= 1.0) + if self.kd_ratio >= 1.0: + ce_loss = student_logits.new_tensor(0.0, dtype=student_logits.dtype) + else: + ce_loss = calculate_loss( + self.loss_fn, + logits=student_logits, + labels=labels, + model=model, + hidden_states=student_out.hidden_states[-1] + if "hidden_states" in student_out + else None, + num_label_tokens=num_label_tokens, + ) + # Reminder: kd_loss is normalized by num_label_tokens, + # which typically is larger than the number of labels in this batch, + # because it contains the total number of labels for all batches contained + # in one optimization step (grad_acc_steps = gbs / mbs). + kd_loss = self.kd_loss_fn( + student_logits, + teacher_logits, + labels, + num_batch_labels=num_label_tokens, + ) + local_loss = (1.0 - self.kd_ratio) * ce_loss + self.kd_ratio * kd_loss + if is_train: + (local_loss * self._get_dp_group_size(include_cp=True)).backward() + # return the losses for logging + detached_local = local_loss.detach().clone() + return detached_local, kd_loss.detach().clone(), ce_loss.detach().clone() + + def _forward_backward_step_pp( + self, + idx, + batch, + *, + loss_buffer, + num_label_tokens, + num_batches, + is_train: bool = True, + ): + """PP path: runs 1T then 1F1B and appends to loss_buffer.""" + batch = { + k: ( + { + dk: dv.to(self.dist_env.device, non_blocking=True) + for dk, dv in v.items() + if dv is not None + } + if isinstance(v, dict) + else ( + v.to(self.dist_env.device, non_blocking=True) + if isinstance(v, torch.Tensor) + else v + ) + ) + for k, v in batch.items() + } + train_ctx, batch = make_cp_batch_and_ctx( + self.device_mesh, + batch, + use_te=_uses_te_dot_product_attention(self.cfg.model) + and _uses_thd_collater(self.cfg.dataloader), + padding_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0, + num_chunks=_get_num_thd_chunks(True, self.cfg), + ) + labels = batch.pop("labels") + input_ids = batch.pop("input_ids") + batch_filtered = { + k: v + for k, v in batch.items() + if v is not None and not (isinstance(v, dict) and len(v) == 0) + } + + if self.pp.info.has_last_stage: + targets = labels.clone() + else: + targets = None + + fp8_ctx = self.te_fp8.maybe_te_autocast() if self.te_fp8 is not None else nullcontext() + + with train_ctx(), fp8_ctx: + with torch.inference_mode(): + teacher_losses = [] if self.teacher_pp.info.has_last_stage else None + if self.teacher_pp.info.has_first_stage: + self.teacher_pp.info.schedule.eval( + input_ids, target=targets, losses=teacher_losses, **batch_filtered + ) + else: + self.teacher_pp.info.schedule.eval( + target=targets, losses=teacher_losses, **batch_filtered + ) + capture = getattr(self.teacher_model, "_teacher_logits_capture", None) + if capture is not None and capture[0] is not None: + self._current_teacher_logits = capture[0] + capture[0] = None + else: + self._current_teacher_logits = None + self._current_num_label_tokens = num_label_tokens + + student_losses = [] if self.pp.info.has_last_stage else None + if is_train: + if self.pp.info.has_first_stage: + self.pp.info.schedule.step( + input_ids, target=targets, losses=student_losses, **batch_filtered + ) + else: + self.pp.info.schedule.step( + target=targets, losses=student_losses, **batch_filtered + ) + elif self.pp.info.has_first_stage: + self.pp.info.schedule.eval( + input_ids, target=targets, losses=student_losses, **batch_filtered + ) + else: + self.pp.info.schedule.eval(target=targets, losses=student_losses, **batch_filtered) + + if self.pp.info.has_last_stage: + loss_buffer.append(torch.sum(torch.stack(student_losses)).detach().clone()) + else: + loss_buffer.append(torch.tensor(0.0, device=self.dist_env.device)) + + def _run_train_optim_step(self, batches, max_grad_norm: float | None = None): + """Execute a single training step. + + Args: + batches: List of batches of training data. + max_grad_norm: Gradient clipping norm. Optional, if None will not clip gradients. + """ + if self.pp_enabled: + return self._run_train_optim_step_pp(batches, max_grad_norm) + + num_label_tokens = torch.tensor( + sum((batch["labels"] != -100).sum().item() for batch in batches), dtype=torch.long + ) + num_label_tokens = self._dp_allreduce(num_label_tokens).item() + loss_buffer = [] + + # number of tokens in the batch, excluding any tail padding. + num_tokens_in_batch = torch.tensor( + sum(batch["labels"].numel() - count_tail_padding(batch["labels"]) for batch in batches), + dtype=torch.long, + ) + num_tokens_in_batch = self._dp_allreduce(num_tokens_in_batch).item() + num_batches = len(batches) + for i, batch in enumerate(batches): + local_loss, kd_loss, ce_loss = self._forward_backward_step( + i, batch, num_label_tokens=num_label_tokens, num_batches=num_batches + ) + loss_buffer.append(local_loss) + self._ce_loss_buffer.append(ce_loss) + self._kd_loss_buffer.append(kd_loss) + + grad_norm = 0 + # Clip gradients **after** any rescaling. + # TODO(@boxiangw): Fix TP gradient clipping + if max_grad_norm is not None: + if not self.device_mesh or self.device_mesh["tp"].size() == 1: + grad_norm = torch.nn.utils.clip_grad_norm_( + [p for p in self.model_parts[0].parameters() if p.requires_grad], max_grad_norm + ) + if hasattr(grad_norm, "full_tensor"): + grad_norm = grad_norm.full_tensor() # collect the summed grad norm across ranks + + if isinstance(grad_norm, torch.Tensor): + grad_norm = grad_norm.item() + + self.checkpointer.maybe_wait_for_staging() + for opt in self.optimizer: + opt.step() + opt.zero_grad() + + if self.lr_scheduler is not None: + for scheduler in self.lr_scheduler: + scheduler.step(1) + + # Precompute FP8 scales + fp8_config = self.cfg.get("fp8", None) + if ( + fp8_config is not None + and fp8_config.get("enabled", False) + and fp8_config.get("precompute_float8_dynamic_scale_for_fsdp", False) + and not self.pp_enabled + and self.device_mesh is not None + and self.device_mesh["dp_shard"].size() > 1 + ): + precompute_float8_dynamic_scale_for_fsdp(self.model_parts[0]) + + # Note(MegatronFSDP): Need to call these functions for MegatronFSDP if not using latest api + # self.model_parts[0].install_optimized_model_weights() + # self.model_parts[0].zero_grad_buffer() + + t = time.perf_counter() + time_delta = t - self.timestamp + self.timestamp = t + tps = num_tokens_in_batch / time_delta + reporting_loss = torch.sum(torch.stack(loss_buffer)) + reporting_loss = self._dp_allreduce(reporting_loss, include_cp=True) + reporting_loss = reporting_loss.cpu().item() + # fix reporting_loss, tps across ranks + + ce_loss = self._dp_allreduce( + torch.stack(self._ce_loss_buffer).sum(), include_cp=True + ).item() + kd_loss = self._dp_allreduce( + torch.stack(self._kd_loss_buffer).sum(), include_cp=True + ).item() + # Clear buffers for next step + self._ce_loss_buffer.clear() + self._kd_loss_buffer.clear() + + # return reporting_loss, grad_norm, tps, num_tokens_in_batch, num_label_tokens + return MetricsSample( + step=self.step_scheduler.step, + epoch=self.step_scheduler.epoch, + metrics={ + "loss": reporting_loss, + "ce_loss": ce_loss, + "kd_loss": kd_loss, + "grad_norm": grad_norm, + "lr": self.optimizer[0].param_groups[0]["lr"], + "mem": torch.cuda.max_memory_allocated() / 1024**3, + "tps": tps, + "tps_per_gpu": tps / max(self._get_dp_group_size(), 1), + "num_tokens_per_step": num_tokens_in_batch, + "num_label_tokens": num_label_tokens, + "kd_ratio": self.kd_ratio, + "temperature": getattr(self.kd_loss_fn, "temperature", float("nan")), + }, + ) + + def _run_train_optim_step_pp(self, batches, max_grad_norm: float | None = None): + """Execute a single training step when pipeline parallelism is enabled.""" + num_label_tokens = torch.tensor( + sum((b["labels"] != -100).sum().item() for b in batches), dtype=torch.long + ) + num_label_tokens = self._dp_allreduce(num_label_tokens).item() + loss_buffer = [] + + num_tokens_in_batch = torch.tensor( + sum(b["labels"].numel() - count_tail_padding(b["labels"]) for b in batches), + dtype=torch.long, + ) + num_tokens_in_batch = self._dp_allreduce(num_tokens_in_batch).item() + num_batches = len(batches) + + prepare_for_grad_accumulation(self.model_parts, pp_enabled=True) + + for i, batch in enumerate(batches): + if i == num_batches - 1: + prepare_for_final_backward(self.model_parts, pp_enabled=True) + self._forward_backward_step_pp( + i, + batch, + loss_buffer=loss_buffer, + num_label_tokens=num_label_tokens, + num_batches=num_batches, + ) + if i == 0: + prepare_after_first_microbatch() + + grad_norm = scale_grads_and_clip_grad_norm( + max_grad_norm, + self.model_parts, + norm_type=2.0, + pp_enabled=True, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name="pp", + foreach=True, + num_label_tokens=num_label_tokens, + dp_group_size=self._get_dp_group_size(include_cp=True), + ) + + self.checkpointer.maybe_wait_for_staging() + for opt in self.optimizer: + opt.step() + opt.zero_grad() + + if hasattr(self.model_parts[0], "update_moe_gate_bias"): + for mp in self.model_parts: + mp.update_moe_gate_bias() + + if self.lr_scheduler is not None: + for scheduler in self.lr_scheduler: + scheduler.step(1) + + fp8_config = self.cfg.get("fp8", None) + if ( + fp8_config is not None + and fp8_config.get("enabled", False) + and fp8_config.get("precompute_float8_dynamic_scale_for_fsdp", False) + and self.device_mesh is not None + and self.device_mesh["dp_shard"].size() > 1 + ): + precompute_float8_dynamic_scale_for_fsdp(self.model_parts[0]) + + t = time.perf_counter() + time_delta = t - self.timestamp + self.timestamp = t + tps = num_tokens_in_batch / time_delta + reporting_loss = torch.sum(torch.stack(loss_buffer)) + reporting_loss = self._dp_allreduce(reporting_loss, include_cp=True) + reporting_loss = reporting_loss / num_label_tokens + reporting_loss = reporting_loss.to(self.dist_env.device) + src_rank = self.device_mesh.mesh.reshape(-1)[-1].item() + if self.dist_env.rank == src_rank: + torch.distributed.send(reporting_loss, dst=0) + elif self.dist_env.is_main: + torch.distributed.recv(reporting_loss, src=src_rank) + reporting_loss = reporting_loss.cpu().item() + + # CE/KD buffers are only populated on the last pipeline stage (in the loss wrapper). + # Match train_ft: allreduce within DP group (last-stage ranks get sum, others get 0), + # then send from last stage to rank 0 for logging (same as reporting_loss). + ce_tensor = ( + torch.stack(self._ce_loss_buffer).sum() + if self._ce_loss_buffer + else torch.tensor(0.0, device=self.dist_env.device) + ) + kd_tensor = ( + torch.stack(self._kd_loss_buffer).sum() + if self._kd_loss_buffer + else torch.tensor(0.0, device=self.dist_env.device) + ) + ce_tensor = self._dp_allreduce(ce_tensor, include_cp=True) + kd_tensor = self._dp_allreduce(kd_tensor, include_cp=True) + ce_tensor = ce_tensor.to(self.dist_env.device) + kd_tensor = kd_tensor.to(self.dist_env.device) + if self.dist_env.rank == src_rank and not self.dist_env.is_main: + torch.distributed.send(ce_tensor, dst=0) + torch.distributed.send(kd_tensor, dst=0) + elif self.dist_env.is_main and self.dist_env.rank != src_rank: + torch.distributed.recv(ce_tensor, src=src_rank) + torch.distributed.recv(kd_tensor, src=src_rank) + ce_loss = ce_tensor.cpu().item() + kd_loss = kd_tensor.cpu().item() + self._ce_loss_buffer.clear() + self._kd_loss_buffer.clear() + + if isinstance(grad_norm, torch.Tensor): + grad_norm = grad_norm.item() + grad_norm = float(grad_norm) + + return MetricsSample( + step=self.step_scheduler.step, + epoch=self.step_scheduler.epoch, + metrics={ + "loss": reporting_loss, + "ce_loss": ce_loss, + "kd_loss": kd_loss, + "grad_norm": grad_norm, + "lr": self.optimizer[0].param_groups[0]["lr"], + "mem": torch.cuda.max_memory_allocated() / 1024**3, + "tps": tps, + "tps_per_gpu": tps / self._get_cp_group_size() / max(self._get_dp_group_size(), 1), + "num_tokens_per_step": num_tokens_in_batch, + "num_label_tokens": num_label_tokens, + "kd_ratio": self.kd_ratio, + "temperature": getattr(self.kd_loss_fn, "temperature", float("nan")), + }, + ) + + def run_train_validation_loop(self): + """Run training loop; skip validation when PP is enabled.""" + if not self.pp_enabled: + return super().run_train_validation_loop() + + # PP path: same as parent but skip validation block + for mp in self.model_parts: + mp.train() + self.timestamp = time.perf_counter() + for epoch in self.step_scheduler.epochs: + self.step_scheduler.set_epoch(epoch) + for batches in self.step_scheduler: + self._enable_qat_if_delayed(self.step_scheduler.step) + train_log_data = self._run_train_optim_step(batches, self.max_grad_norm) + self._collect_moe_load_balance() + self.log_train_metrics(train_log_data) + val_losses = {} + if self.step_scheduler.is_val_step: + logger.warning("Validation is not supported for pipeline parallelism; skipping") + if self.step_scheduler.is_ckpt_step: + self.save_checkpoint( + epoch, + self.step_scheduler.step, + train_log_data.metrics["loss"], + val_losses, + best_metric_key=self.best_metric_key, + ) + self.metric_logger_train.close() + for v in self.metric_logger_valid.values(): + v.close() + self.checkpointer.close() + + @torch.no_grad() + def _run_validation_epoch(self, val_dataloader): + """Run one pass over `self.val_dataloader`.""" + if self.pp_enabled: + logger.warning("Validation is not supported for pipeline parallelism") + return + + with ScopedRNG(seed=1, ranked=True): + for mp in self.model_parts: + mp.eval() + + total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.dist_env.device) + ce_loss = torch.tensor(0.0, dtype=torch.float32, device=self.dist_env.device) + kd_loss = torch.tensor(0.0, dtype=torch.float32, device=self.dist_env.device) + total_num_label_tokens = 0 + + for batch in val_dataloader: + num_label_tokens = (batch["labels"] != -100).sum().item() + local_loss, _kd_loss, _ce_loss = self._forward_backward_step( + 0, + batch, + num_label_tokens=num_label_tokens, + num_batches=1, + is_train=False, + ) + total_num_label_tokens += num_label_tokens + ce_loss += _ce_loss + kd_loss += _kd_loss + total_loss += local_loss + + total_loss = self._dp_allreduce(total_loss, include_cp=True).item() + ce_loss = self._dp_allreduce(ce_loss, include_cp=True).item() + kd_loss = self._dp_allreduce(kd_loss, include_cp=True).item() + total_num_label_tokens = self._dp_allreduce( + torch.tensor(total_num_label_tokens, dtype=torch.long) + ).item() + + val_loss = total_loss / max(total_num_label_tokens, 1e-8) + return MetricsSample( + step=self.step_scheduler.step, + epoch=self.step_scheduler.epoch, + metrics={ + "val_loss": val_loss, + "ce_loss": ce_loss, + "kd_loss": kd_loss, + "lr": self.optimizer[0].param_groups[0]["lr"], + "num_label_tokens": total_num_label_tokens, + "mem": torch.cuda.max_memory_allocated() / 1024**3, + }, + ) + + def log_val_metrics(self, val_name, log_data, metric_logger=None): + if not self.dist_env.is_main or log_data is None: + return + + if wandb.run is not None: + wandb.log(log_data.to_dict() | {"val_name": val_name}, step=log_data.step) + + # JSONL validation log + if metric_logger is not None: + metric_logger.log(log_data) + + # assumes all model parts' optimizers have the same learning rate + if self.kd_ratio >= 1.0: + logging.info( + "[val] {} | step {} | epoch {} | loss {:.4f} | kd_loss {:.4f} | lr {:.2e} | num_label_tokens {}".format( + val_name, + log_data.step, + log_data.epoch, + log_data.metrics["val_loss"], + log_data.metrics["kd_loss"], + log_data.metrics["lr"], + log_data.metrics["num_label_tokens"], + ) + ) + else: + logging.info( + "[val] {} | step {} | epoch {} | loss {:.4f} | ce_loss {:.4f} | kd_loss {:.4f} | " + "lr {:.2e} | num_label_tokens {}".format( + val_name, + log_data.step, + log_data.epoch, + log_data.metrics["val_loss"], + log_data.metrics["ce_loss"], + log_data.metrics["kd_loss"], + log_data.metrics["lr"], + log_data.metrics["num_label_tokens"], + ) + ) + + def log_train_metrics(self, log_data) -> float: + """Log metrics to wandb and other loggers. + + Args: + log_data: MetricsSample object, containing: + step: int, the current step. + epoch: int, the current epoch. + metrics: Dict[str, float], containing: + "loss": Training loss. + "grad_norm": Grad norm from the training step. + "lr": Learning rate. + "mem": Memory allocated. + "tps": Tokens per second. + "tps_per_gpu": Tokens per second per GPU. + "num_label_tokens": Number of label tokens. + """ + if not self.dist_env.is_main: + return + + # Log to remote services (WandB) according to step_scheduler frequency + if self.step_scheduler.is_remote_logging_step: + if wandb.run is not None: + wandb.log(log_data.to_dict(), step=log_data.step) + + # JSONL training log (always log for detailed local records) + self.metric_logger_train.log(log_data) + + if self.kd_ratio >= 1.0: + logging.info( + "step {} | epoch {} | " + "loss {:.4f} | lr {:.2e} | mem {:.2f} GiB | " + "tps {:.2f} | kd_ratio {:.2f} | temperature {:.2f}".format( + log_data.step, + log_data.epoch, + log_data.metrics["loss"], + log_data.metrics["lr"], + log_data.metrics["mem"], + log_data.metrics["tps"], + log_data.metrics["kd_ratio"], + log_data.metrics["temperature"], + ) + ) + else: + logging.info( + "step {} | epoch {} | " + "loss {:.4f} | ce_loss {:.4f} | kd_loss {:.4f} | " + "lr {:.2e} | mem {:.2f} GiB | tps {:.2f} | kd_ratio {:.2f} | temperature {:.2f}".format( + log_data.step, + log_data.epoch, + log_data.metrics["loss"], + log_data.metrics["ce_loss"], + log_data.metrics["kd_loss"], + log_data.metrics["lr"], + log_data.metrics["mem"], + log_data.metrics["tps"], + log_data.metrics["kd_ratio"], + log_data.metrics["temperature"], + ) + ) + torch.cuda.reset_peak_memory_stats() + + +# Entry point +def main(config_path="examples/llm_kd/llama3_2/llama3_2_1b_kd.yaml"): + """Run the KD recipe from CLI or directly.""" + cfg = parse_args_and_load_config(config_path) + trainer = KnowledgeDistillationRecipeForNextTokenPrediction(cfg) + trainer.setup() + trainer.run_train_validation_loop() + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/examples/puzzletron/automodel_distillation/run.py b/examples/puzzletron/automodel_distillation/run.py new file mode 100644 index 0000000000..eb94849bd7 --- /dev/null +++ b/examples/puzzletron/automodel_distillation/run.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single entrypoint for AnyModel training pipelines. + +Modes: pretrain | kd + +Run from this directory so that patch_automodel and recipe are importable. Example: + torchrun --nproc_per_node=2 -m run --mode pretrain -c ./pretrain.yaml + torchrun --nproc_per_node=2 -m run --mode kd -c ./kd.yaml + +If -c is omitted, a default config path is used per mode. +""" + +from __future__ import annotations + +import sys + +from nemo_automodel.components.config._arg_parser import parse_args_and_load_config +from patch_automodel import apply_patch + +# Default config path per mode (used when -c is not passed) +_DEFAULT_CONFIG = { + "pretrain": "./pretrain.yaml", + "kd": "./kd.yaml", +} + + +def _parse_mode() -> str: + """Parse --mode from argv; remove it so parse_args_and_load_config does not see it.""" + argv = sys.argv[1:] + mode = None + new_argv = [] + i = 0 + while i < len(argv): + tok = argv[i] + if tok in ("--mode", "-m"): + if i + 1 >= len(argv): + raise ValueError("Expected a value after --mode (pretrain | kd)") + mode = argv[i + 1] + i += 2 + continue + new_argv.append(tok) + i += 1 + if mode is None: + raise ValueError( + "Missing --mode. Choose one of: pretrain, kd. " + "Example: python -m run --mode kd -c kd.yaml" + ) + if mode not in _DEFAULT_CONFIG: + raise ValueError(f"Invalid mode '{mode}'. Choose one of: pretrain, kd") + sys.argv = [sys.argv[0], *new_argv] + return mode + + +def main() -> None: + mode = _parse_mode() + default_config = _DEFAULT_CONFIG[mode] + apply_patch() + cfg = parse_args_and_load_config(default_config) + + if mode == "pretrain": + from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction + + recipe = TrainFinetuneRecipeForNextTokenPrediction(cfg) + elif mode == "kd": + from recipe import KnowledgeDistillationRecipeForNextTokenPrediction + + recipe = KnowledgeDistillationRecipeForNextTokenPrediction(cfg) + else: + raise ValueError(f"Invalid mode '{mode}'. Choose one of: pretrain, kd") + + recipe.setup() + recipe.run_train_validation_loop() + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml new file mode 100644 index 0000000000..ded4f65140 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: gpt_oss_20b +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml new file mode 100644 index 0000000000..8ed06e9568 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml @@ -0,0 +1,17 @@ +defaults: + - gptoss-20b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 16_000 # 45 GiB diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..00d7829e01 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml @@ -0,0 +1,21 @@ +defaults: + - pruning_defaults + +eval_samples: 2500 #10 +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_20b_model_descriptor.GptOss20bExpertRemovalLayerDescriptor + target_name: "mlp.router" + +hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.RankedChoiceVotingHook} +activation_hooks_kwargs: # Additional kwargs to pass to the hook init + +num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 128 +mlp_init_mode: "ExpertRemoval" +mlp_init_config_yaml: + expert_scores_key: "expert_ranks" + layer_prefix_template: "model.layers.{layer_idx}.mlp.router" + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..0eff799d7e --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +descriptor: ${descriptor} + +# Data: +eval_samples: 10_000 +micro_batch_size: 1 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml new file mode 100644 index 0000000000..b80faea5f5 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ab8c892182 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/examples/puzzletron/evaluation/hf_deployable_anymodel.py b/examples/puzzletron/evaluation/hf_deployable_anymodel.py index f4fd4e4148..ec61bd4698 100644 --- a/examples/puzzletron/evaluation/hf_deployable_anymodel.py +++ b/examples/puzzletron/evaluation/hf_deployable_anymodel.py @@ -1,10 +1,35 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c + +# MIT License +# +# Copyright (c) 2020 EleutherAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +38,7 @@ # limitations under the License. +import json import logging from typing import Any @@ -28,6 +54,11 @@ from peft import PeftModel from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + resolve_descriptor_from_pretrained, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + try: from pytriton.decorators import batch from pytriton.model_config import Tensor @@ -139,18 +170,12 @@ def _load( # Wraps model loading with deci_x_patcher for heterogeneous layer configs. # See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py # ========================================================================= - import os - import sys - modelopt_workdir = os.environ.get("MODELOPT_WORKDIR") or os.environ.get( - "PUZZLE_WORKDIR" + descriptor = resolve_descriptor_from_pretrained( + self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False) ) - if modelopt_workdir and modelopt_workdir not in sys.path: - sys.path.insert(0, modelopt_workdir) - from modelopt.torch.puzzletron.anymodel.models.llama import LlamaModelDescriptor - from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher - with deci_x_patcher(model_descriptor=LlamaModelDescriptor): + with deci_x_patcher(model_descriptor=descriptor): self.model = AutoModelForCausalLM.from_pretrained( self.hf_model_id_path, torch_dtype=torch_dtype, @@ -587,8 +612,6 @@ def ray_infer_fn(self, inputs: dict[Any, Any]): - log_probs: Optional list of log probabilities if compute_logprob is True - top_logprobs: Optional list of top log probabilities if n_top_logprobs > 0 """ - import json - try: prompts = inputs.pop("prompts") temperature = inputs.pop("temperature", 1.0) diff --git a/examples/puzzletron/evaluation/lm_eval_anymodel.py b/examples/puzzletron/evaluation/lm_eval_anymodel.py index 6d6fcd44e7..94e31b001e 100644 --- a/examples/puzzletron/evaluation/lm_eval_anymodel.py +++ b/examples/puzzletron/evaluation/lm_eval_anymodel.py @@ -1,18 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c # MIT License @@ -37,6 +22,21 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Run lm-eval directly on AnyModel (Puzzletron) checkpoints without a deployment server. @@ -48,55 +48,21 @@ from lm_eval.__main__ import cli_evaluate from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM -from transformers import AutoConfig # Trigger factory registration for all model descriptors import modelopt.torch.puzzletron.anymodel.models # noqa: F401 -from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory +from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import ( + resolve_descriptor_from_pretrained, +) from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher -# Map from HuggingFace config.model_type (in checkpoint config.json) to ModelDescriptorFactory name. -# Local to this script; add entries when supporting new model types for auto-detection. -_MODEL_TYPE_TO_DESCRIPTOR = { - "llama": "llama", - "mistral": "mistral_small", - "qwen2": "qwen2", - "qwen3": "qwen3", - "nemotron_h": "nemotron_h", - "nemotron_h_v2": "nemotron_h_v2", - "gpt_oss_20b": "gpt_oss_20b", -} - - -def _resolve_descriptor_from_pretrained(pretrained: str | None): - """Resolve the model descriptor by loading the checkpoint config and mapping model_type.""" - if not pretrained: - raise ValueError( - "pretrained must be set in --model_args " - "(e.g. --model_args pretrained=/path/to/checkpoint,dtype=bfloat16)." - ) - - config = AutoConfig.from_pretrained(pretrained, trust_remote_code=True) - model_type = getattr(config, "model_type", None) - - if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR: - detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type] - print( - f"[lm_eval_anymodel] Auto-detected model_type='{model_type}' → descriptor='{detected}'" - ) - return ModelDescriptorFactory.get(detected) - - known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys()) - raise ValueError( - f"Cannot auto-detect descriptor for model_type='{model_type}'. " - f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported." - ) - def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: """Override HFLM.create_from_arg_obj to wrap model loading with deci_x_patcher.""" pretrained = arg_dict.get("pretrained") - descriptor = _resolve_descriptor_from_pretrained(pretrained) + descriptor = resolve_descriptor_from_pretrained( + pretrained, trust_remote_code=arg_dict.get("trust_remote_code", False) + ) additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} diff --git a/examples/puzzletron/evaluation/nemo_evaluator_instructions.md b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md index b8c97af5d5..f8b53889c6 100644 --- a/examples/puzzletron/evaluation/nemo_evaluator_instructions.md +++ b/examples/puzzletron/evaluation/nemo_evaluator_instructions.md @@ -8,17 +8,45 @@ Evaluate AnyModel checkpoints by deploying a local OpenAI-compatible completions This flow requires Ray for serving the model and NeMo Export-Deploy (included in NeMo containers): ```bash -pip install ray +pip install -r examples/puzzletron/requirements.txt ``` **1. Deploy the model (2 GPUs example):** +We need to patch the `hf_deployable.py` script from Export-Deploy. Best way is to do it as a mount in docker run: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it +if [ ! -d "${MODELOPT_DIR}" ]; then + git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +fi + +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +docker run \ + --gpus all \ + --shm-size=16GB \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -v ${MODELOPT_DIR}/examples/puzzletron/evaluation/hf_deployable_anymodel.py:/opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py \ + -w /opt/Model-Optimizer/examples/megatron_bridge \ + ${DOCKER_IMAGE} bash +``` + +Alternatively you can manually update the file + ```bash # Install the AnyModel-patched deployable (first time only: backs up the original) # /opt/Export-Deploy is the default path in NeMo containers — adjust if needed cp /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py.bak cp examples/puzzletron/evaluation/hf_deployable_anymodel.py /opt/Export-Deploy/nemo_deploy/llm/hf_deployable.py +``` +Now start ray server and deploy the model + +```bash # Start the server (blocks while running — use a separate terminal) ray start --head --num-gpus 2 --port 6379 --disable-usage-stats python /opt/Export-Deploy/scripts/deploy/nlp/deploy_ray_hf.py \ diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md new file mode 100644 index 0000000000..deae06e616 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -0,0 +1,151 @@ +# Knowledge Distillation with Megatron-Bridge + +This guide shows how to perform knowledge distillation on Puzzletron-compressed AnyModel checkpoints using Megatron-Bridge. + +## Overview + +1. Set up the environment with Megatron-Bridge +2. Prepare tokenized dataset +3. Run knowledge distillation training directly from HuggingFace checkpoints +4. Review MMLU evaluation results (before/after distillation) + +## Setup + +**Clone Model-Optimizer repo:** + +The NeMo container does not include Model-Optimizer examples, so you need to clone the Model-Optimizer repo: + +```bash +export MODELOPT_DIR=${PWD}/Model-Optimizer +git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} +``` + +**Start Docker container:** + +Use the [NeMo 26.02 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02): + +```bash +# Recommended to mount a workspace directory for storing datasets and distilled models +docker run --gpus all -it --rm \ + -v /path/to/your/project:/workspace \ + -v ${MODELOPT_DIR}:/opt/Model-Optimizer \ + -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \ + -w /opt/Model-Optimizer \ + nvcr.io/nvidia/nemo:26.02 \ + /bin/bash +``` + +## Dataset Preparation + +This section describes how to prepare datasets for knowledge distillation. We provide examples using WikiText-103, which is a small dataset that can still produce decent results (see the Qwen3-8B example below showing +10.11 percentage point improvement). For production use, larger datasets like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) are recommended. + +### Download and Tokenize Dataset + +Download and tokenize the dataset in a single step. This downloads the dataset from HuggingFace, tokenizes it, and saves it in the Megatron format (`.bin` and `.idx` files): + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset Salesforce/wikitext \ + --hf_name wikitext-103-v1 \ + --hf_split train \ + --output_dir path/to/hf_datasets/wikitext-103-v1 \ + --tokenizer meta-llama/Llama-3.1-8B-Instruct \ + --json_keys text \ + --workers 32 +``` + +This will create: + +- `Salesforce--wikitext_wikitext-103-v1_train_text_document.bin` - Binary tokenized data +- `Salesforce--wikitext_wikitext-103-v1_train_text_document.idx` - Index file for the binary data +- `Salesforce--wikitext_wikitext-103-v1_train_text_document/cache/` - Cache directory (created after running distillation) + +## Run Knowledge Distillation + +Run distillation directly from HuggingFace checkpoints (student and teacher) with tokenized dataset: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \ + --student_hf_path /path/to/student/huggingface/checkpoint \ + --teacher_hf_path /path/to/teacher/huggingface/checkpoint \ + --data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \ + --output_dir /path/to/distilled/checkpoint \ + --hf-export-path /path/to/exported/hf/model \ + --hf-model meta-llama/Llama-3.1-8B-Instruct \ + --seq_length 4096 \ + --tp_size 8 \ + --pp_size 1 \ + --mbs 1 \ + --gbs 4 \ + --train_iters 100 \ + --lr 0.0001 \ + --min_lr 1e-05 \ + --lr_warmup_iters 10 \ + --eval_interval 10 \ + --eval_iters 10 \ + --log_interval 1 +``` + +**Notes:** + +- The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_`. +- Add `--hf-export-path` to automatically export the final checkpoint to HuggingFace format after distillation. When using `--hf-export-path`, you must also provide `--hf-model` to specify the HuggingFace model ID to use as a template for export (e.g., `meta-llama/Llama-3.1-8B-Instruct`). The `--hf-model` should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation). +- For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices. + +## MMLU Evaluation Results + +This section presents MMLU evaluation results for knowledge distillation experiments compressing Qwen3-8B and Llama-3.1-8B-Instruct. + +### Successful Case: Qwen3-8B (80% of original) + +Distillation results for a memory-compressed Qwen3-8B checkpoint (80% of original size): + +| Model | MMLU | Humanities | Other | Social Sci | STEM | +|-------|------|------------|-------|------------|------| +| 80% pre-distillation | 0.5910 | 0.5046 | 0.6363 | 0.6831 | 0.5855 | +| 80% post-distillation | 0.6921 | 0.5906 | 0.7316 | 0.7975 | 0.7016 | +| Original Qwen3-8B | 0.7493 | 0.6648 | 0.7856 | 0.8385 | 0.7526 | + +**Key observations:** + +- MMLU accuracy improved from 59.10% to 69.21% (+10.11 percentage points) after distillation +- Achieved with just 100 iterations on WikiText-103, demonstrating efficient knowledge transfer +- Recovery of 64% of the gap to the teacher model (from 59.10% to 69.21%, closing 64% of the gap from 59.10% to 74.93%) +- All individual category scores (Humanities, Other, Social Sciences, STEM) improved significantly + +### Successful Case: Llama-3.1-8B-Instruct (50% of original, 56,810 MiB) + +Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (50% of original size, 56,810 MiB memory constraint): + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Before distillation | 0.2316 | 0.2462 | 0.2292 | 0.2250 | 0.2274 | +| After distillation | 0.2960 | 0.3146 | 0.3085 | 0.2925 | 0.2768 | +| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +**Key observations:** + +- MMLU accuracy (average across all categories) improved from 23.16% to 29.60% (+6.44 percentage points) +- All individual category scores (Humanities, Other, Social Sciences, STEM) improved, demonstrating effective knowledge transfer from teacher to student + +### Regression Case: Llama-3.1-8B-Instruct (69% of original, 78,000 MiB) + +Distillation results for a pruned Llama-3.1-8B-Instruct checkpoint (approximately 69% of original size, 78,000 MiB memory constraint) showing regression due to overfitting on the small WikiText-103 dataset (evaluated with limit 100): + +| Model | MMLU | Humanities | Other | Social Sciences | STEM | +|-------|------|------------|-------|-----------------|------| +| Before distillation | 0.6626 | 0.7069 | 0.6892 | 0.7525 | 0.5574 | +| After distillation | 0.6496 | 0.6862 | 0.6677 | 0.7433 | 0.5532 | +| Original Llama-3.1-8B-Instruct | 0.6839 | 0.7231 | 0.7038 | 0.7667 | 0.5911 | + +**Key observations:** + +- MMLU accuracy (average across all categories) decreased from 66.26% to 64.96% (-1.30 percentage points) after distillation +- The model overfitted to the small WikiText-103 dataset, causing performance regression +- This demonstrates the critical importance of using larger, more diverse datasets for knowledge distillation + +### Recommendations + +- **For production distillation:** Use larger production datasets like [nvidia/Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) for better results and to avoid overfitting (see regression case above) +- **Training duration:** Train for more iterations to ensure proper convergence +- **See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices** diff --git a/examples/puzzletron/mbridge_distillation/distill_hf.py b/examples/puzzletron/mbridge_distillation/distill_hf.py new file mode 100644 index 0000000000..a981e355d1 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/distill_hf.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Distillation script for Megatron-Bridge. + +Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model +to `/checkpoints` in megatron distributed checkpoint format. + +See `README.md` in this directory for example usage and data preparation instructions. +""" + +import argparse +import os +import traceback + +import megatron.bridge.models.distillation_provider +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + GPTDatasetConfig, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.distributed import DistributedDataParallelConfig + +# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure +# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers +# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge +# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration. +# +# Note: Currently, bridges are also registered when distillation_provider is imported +# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider +# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron. +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 +import modelopt.torch.utils.distributed as dist + +# Use local copy of distillation_provider with fix for heterogeneous models +# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge +from modelopt.torch.puzzletron.export.mbridge.distillation_provider import ( + DistillationProvider, + convert_to_distillation_provider, +) +from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import ( + export_to_hf_and_copy_config, +) +from modelopt.torch.utils import print_rank_0 + +# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider +# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time +megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider + +# Import distill() AFTER patching so it uses the patched DistillationProvider +from megatron.bridge.training.distill import distill # noqa: E402 + +SEED = 1234 + + +def get_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") + # Model arguments (accepts HuggingFace input only at the moment) + parser.add_argument( + "--student_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", + ) + parser.add_argument( + "--teacher_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)", + ) + # Parallelism arguments + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + # Dataset arguments + parser.add_argument( + "--data_paths", + nargs="+", + help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)", + ) + parser.add_argument( + "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data" + ) + parser.add_argument( + "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices" + ) + parser.add_argument( + "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths" + ) + # Training & Eval arguments + parser.add_argument( + "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving" + ) + parser.add_argument( + "--seq_length", + type=int, + default=4096, + help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.", + ) + parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size") + parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size") + parser.add_argument( + "--train_iters", type=int, required=True, help="Number of training iterations" + ) + parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate") + parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate") + parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps") + parser.add_argument( + "--eval_interval", type=int, default=100, help="Validate + checkpoint every steps" + ) + parser.add_argument( + "--eval_iters", type=int, default=32, help="Number of batches per validation stage" + ) + # Logging arguments + parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps") + parser.add_argument( + "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)" + ) + parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)") + parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") + # Export arguments + parser.add_argument( + "--hf-export-path", + type=str, + default=None, + help=( + "Path where to save the HuggingFace export. " + "If provided, exports checkpoint to HF format after distillation." + ), + ) + parser.add_argument( + "--hf-model", + type=str, + required=True, + help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). " + "Should match the base architecture of the student model.", + ) + args = parser.parse_args() + + # Sanity checks + if not args.use_mock_data and not args.data_paths: + raise ValueError("Must provide either --data_paths or set --use_mock_data.") + + print_rank_0("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(args: argparse.Namespace): + checkpoint_dir = os.path.join(args.output_dir, "checkpoints") + tensorboard_dir = os.path.join(args.output_dir, "tb_logs") + + # Build student and teacher model providers + def _build_model_provider(hf_path): + bridge = AutoBridge.from_hf_pretrained(hf_path) + provider = bridge.to_megatron_provider(load_weights=True) + + # Override parallelism / training settings + provider.tensor_model_parallel_size = args.tp_size + provider.pipeline_model_parallel_size = args.pp_size + provider.context_parallel_size = 1 + provider.sequence_parallel = args.tp_size > 1 + provider.seq_length = args.seq_length + provider.pipeline_dtype = torch.bfloat16 + return provider + + # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000) + # Still requires an HF model name or path to build provider correctly + student_provider = _build_model_provider(args.student_hf_path) + teacher_provider = _build_model_provider(args.teacher_hf_path) + + # Wrap into DistillationProvider + kd_config = ModelOptDistillConfig() + distill_provider = convert_to_distillation_provider( + student_provider, teacher_provider, kd_config + ) + + # Build optimizer and scheduler + optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=args.lr_warmup_iters, + max_lr=args.lr, + min_lr=args.min_lr, + adam_beta2=0.98, + ) + + # Build dataset config + dataset_kwargs = { + "seq_length": args.seq_length, + "path_to_cache": args.data_path_to_cache, + "random_seed": SEED, + "reset_attention_mask": False, + "reset_position_ids": False, + "eod_mask_loss": False, + "num_dataset_builder_threads": 1, + "data_sharding": True, + "dataloader_type": "single", + "skip_getting_attention_mask_from_dataset": True, + } + if args.use_mock_data: + dataset_config = MockGPTDatasetConfig(**dataset_kwargs) + else: + # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format + blend = get_blend_from_list(args.data_paths) + dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs) + + # Assemble ConfigContainer and run distillation + config = ConfigContainer( + model=distill_provider, + train=TrainingConfig( + train_iters=args.train_iters, + eval_interval=args.eval_interval, + eval_iters=args.eval_iters, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + manual_gc=True, + manual_gc_interval=100, + ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters), + optimizer=optimizer_config, + scheduler=scheduler_config, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dataset=dataset_config, + logger=LoggerConfig( + log_interval=args.log_interval, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + # Weights & Biases logging + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, # optional + wandb_exp_name=args.wandb_exp_name, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size + ), + checkpoint=CheckpointConfig( + save_interval=args.eval_interval, + save=checkpoint_dir, + load=checkpoint_dir, # Resume from this directory (if exists) + most_recent_k=3, # Keeps 3 most recent checkpoints (not metric-based) + ckpt_format="torch_dist", + async_save=True, + fully_parallel_save=True, + ), + rng=RNGConfig(seed=SEED), + mixed_precision="bf16_mixed", + ) + + print_rank_0("\nStarting distillation...") + distill(config) + print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n") + + # Export to HuggingFace format if hf_export_path is provided + if args.hf_export_path: + # Wait for all ranks to finish distillation before export + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + # Save rank before destroying process group (dist.rank() won't work after destruction) + is_rank_0 = dist.rank() == 0 + + # Destroy process group on all ranks - export_ckpt will create its own temporary one + # This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone) + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + # Only rank 0 exports + if is_rank_0: + try: + export_to_hf_and_copy_config( + student_hf_path=args.student_hf_path, + checkpoint_dir=checkpoint_dir, + train_iters=args.train_iters, + hf_export_path=args.hf_export_path, + hf_model=args.hf_model, + ) + except Exception as e: + print(f"⚠️ Export failed: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + except Exception as e: + print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}") + print_rank_0(f"Traceback:\n{traceback.format_exc()}") + raise + finally: + dist.cleanup() diff --git a/examples/puzzletron/requirements.txt b/examples/puzzletron/requirements.txt index db6894d631..0511fb473b 100644 --- a/examples/puzzletron/requirements.txt +++ b/examples/puzzletron/requirements.txt @@ -1,2 +1,3 @@ lm-eval==0.4.10 math-verify +ray diff --git a/examples/specdec_bench/README.md b/examples/specdec_bench/README.md index 770edf8d75..20b7202299 100644 --- a/examples/specdec_bench/README.md +++ b/examples/specdec_bench/README.md @@ -28,8 +28,19 @@ MTBench is available [here](https://huggingface.co/datasets/HuggingFaceH4/mt_ben Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`. ```bash -python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --mtbench question.jsonl --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 80 --engine TRTLLM --concurrency 1 --postprocess gptoss - +python3 run.py \ + --model_dir openai/gpt-oss-120b \ + --tokenizer openai/gpt-oss-120b \ + --draft_model_dir /path/to/eagle \ + --mtbench question.jsonl \ + --tp_size 1 \ + --ep_size 1 \ + --draft_length 3 \ + --output_length 4096 \ + --num_requests 80 \ + --engine TRTLLM \ + --concurrency 1 \ + --postprocess gptoss ``` ### Running Random ids on GPT OSS + Eagle3 @@ -37,8 +48,101 @@ python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b - Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`. ```bash -python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --random_isl 1024 --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 40 --engine TRTLLM --concurrency 1 +python3 run.py \ + --model_dir openai/gpt-oss-120b \ + --tokenizer openai/gpt-oss-120b \ + --draft_model_dir /path/to/eagle \ + --random_isl 1024 \ + --tp_size 1 \ + --ep_size 1 \ + --draft_length 3 \ + --output_length 4096 \ + --num_requests 40 \ + --engine TRTLLM \ + --concurrency 1 +``` + +### Running [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) on Llama 3.3 70B + Eagle 3 + +1. Install the requirements file using `pip install -r requirements_speed.txt` + +2. Prepare the data using the provided script: + +```bash +python3 prepare_data.py --dataset speed --config all +``` + +The data will be saved to `data/` directory, each config type (qualitative, throughput_1k, ...) to each own directory. + +#### License + +GOVERNING TERMS: This dataset is governed by the NVIDIA Evaluation Dataset License Agreement. + +ADDITIONAL INFORMATION: MIT for bigcode/humanevalpack, RUCAIBox/MMATH, RUCAIBox/BAMBOO and EQ-Bench. Apache 2.0 for Writing Bench and Spec-Bench. CC BY 4.0 for FBK-MT/MCIF. MIT and Apache 2.0 for tianyang/repobench_python_v1.1, JetBrains-Research/lca-project-level-code-completion and tianyang/repobench_java_v1.1. + +NOTICE: For each dataset a user elects to use, the user is responsible for checking if the dataset license is fit for the intended purpose. The `prepare_data.py` script automatically fetches data from all the source datasets. + +Additional details are in [HuggingFace dataset repository](https://huggingface.co/datasets/nvidia/SPEED-Bench). + +#### Qualitative split + +```bash +python3 run.py \ + --model_dir meta-llama/Llama-3.3-70B-Instruct \ + --tokenizer meta-llama/Llama-3.3-70B-Instruct \ + --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B \ + --dataset speed \ + --dataset_path data/speed/qualitative \ + --tp_size 8 \ + --ep_size 1 \ + --draft_length 3 \ + --output_length 4096 \ + --engine TRTLLM \ + --concurrency 32 \ + --show_progress +``` + +#### Throughput split +```bash +python3 run.py \ + --model_dir meta-llama/Llama-3.3-70B-Instruct \ + --tokenizer meta-llama/Llama-3.3-70B-Instruct \ + --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B \ + --dataset speed \ + --dataset_path data/speed/throughput_1k \ + --tp_size 8 \ + --ep_size 1 \ + --draft_length 3 \ + --output_length 4096 \ + --engine TRTLLM \ + --concurrency 32 \ + --show_progress +``` + +For longer context (>8192 tokens), please use the following configuration when using TRTLLM: + +```yaml +engine_args: + max_seq_len: 131072 # Model max context length (for Llama 3.3 70B) + enable_chunked_prefill: true +``` + +```bash +python3 run.py \ + --model_dir meta-llama/Llama-3.3-70B-Instruct \ + --tokenizer meta-llama/Llama-3.3-70B-Instruct \ + --draft_model_dir yuhuili/EAGLE3-LLaMA3.3-Instruct-70B \ + --dataset speed \ + --dataset_path data/speed/throughput_16k \ + --tp_size 8 \ + --ep_size 1 \ + --draft_length 3 \ + --output_length 4096 \ + --engine TRTLLM \ + --concurrency 32 \ + --show_progress \ + --runtime_params runtime_args_long_context.yaml ``` ## Notes diff --git a/examples/specdec_bench/SPECBENCH_PORTING.md b/examples/specdec_bench/SPECBENCH_PORTING.md new file mode 100644 index 0000000000..bdc2be5c31 --- /dev/null +++ b/examples/specdec_bench/SPECBENCH_PORTING.md @@ -0,0 +1,329 @@ +# Porting Spec-Bench Inference Runners to specdec_bench + +This guide explains how to convert any `inference_*.py` runner from [Spec-Bench](https://github.com/hemingkx/Spec-Bench) to a model class compatible with `specdec_bench`. + +## Overview + +Spec-Bench inference runners follow a pattern where: + +1. A `*_forward()` function handles the speculative decoding logic +2. The `run_eval()` function orchestrates evaluation with tokenized inputs +3. Models are loaded in `__main__` and passed to `run_eval()` + +In contrast, `specdec_bench` uses a class-based approach where: + +1. Models inherit from the `Model` base class +2. `__init__()` handles model loading +3. `run()` is an async method that processes single requests +4. `stop()` handles cleanup + +## The specdec_bench Model Interface + +```python +class Model: + def __init__(self, model_dir, tokenizer, max_draft_length): + raise NotImplementedError + + async def run(self, prompt_ids, sampling_params, request_id, turn_id): + """ + prompt_ids: list of token IDs (not a tensor!) + Returns dict with: + - output_ids: list of list of token chunks per step [[chunk1, chunk2, ...]] + - output_logits: optional logits (usually None) + - token_times: list of timestamps per decoding step + """ + raise NotImplementedError + + def stop(self): + pass +``` + +## Step-by-Step Porting Guide + +### Step 1: Identify the Key Components in Spec-Bench + +Look at the `inference_*.py` file and identify: + +1. **The forward function** (e.g., `medusa_forward`, `ea_forward`) + - This contains the core speculative decoding loop + - Signature: `forward_func(inputs, model, tokenizer, max_new_tokens, **kwargs)` + - Returns: `(output_ids, new_token_count, num_steps, accept_length_list)` + +2. **The model class** (e.g., `MedusaModel`, `EaModel`) + - Found in `model//` directory + - Has a `from_pretrained()` class method + +3. **Required utilities** from the method's module: + - Buffer generation (e.g., `generate_medusa_buffers`) + - Initialization functions (e.g., `initialize_medusa`, `initialize_past_key_values`) + - Decoding functions (e.g., `tree_decoding`, `generate_candidates`) + - State update functions (e.g., `update_inference_inputs`) + +4. **Method-specific choices/configs** (e.g., `mc_sim_7b_63` for Medusa) + +### Step 2: Create the specdec_bench Model Class + +```python +# specdec_bench/specdec_bench/models/specbench_.py + +from .base import Model +import asyncio +import time +import torch + +# Import dependencies from Spec-Bench +try: + import sys + import os + spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench") + sys.path.insert(0, spec_bench_path) + from model.. import + from model..kv_cache import initialize_past_key_values + from model..utils import ( + # Import all required utilities + ) + from model.. import +except ImportError as e: + print(f" dependencies not found: {e}") + = None + + +class SpecBenchModel(Model): + def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs): + # 1. Validate dependencies + if is None: + raise ImportError(" dependencies not found.") + + # 2. Extract configuration from kwargs + self.dtype = kwargs.get("dtype", "float16") + self.max_steps = kwargs.get("max_steps", 512) + self.temperature = sampling_kwargs.get("temperature", 0.0) + # ... other method-specific parameters + + # 3. Set up device (avoid device_map="auto" for multi-GPU issues) + self.device = torch.device(kwargs.get("device", "cuda:0")) + + # 4. Convert dtype string to torch dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map.get(self.dtype, torch.float16) + + # 5. Load the model + self.model = .from_pretrained( + model_dir, + # ... other args from Spec-Bench's __main__ + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(self.device) + + self.sampling_kwargs = sampling_kwargs +``` + +### Step 3: Port the Forward Function + +Convert the standalone `*_forward()` function to an internal method: + +```python + def _forward(self, input_ids, max_new_tokens, end_id): + """ + Port of the original *_forward function. + + Key changes from Spec-Bench: + 1. input_ids is already a tensor (converted in run()) + 2. Add timing list to track per-step timestamps + 3. Use self.device instead of model.base_model.device + 4. Return timing along with other outputs + """ + accept_length_list = [] + timing = [time.perf_counter()] # ADD: Track timing + + # === COPY THE FORWARD LOGIC FROM SPEC-BENCH === + # Replace: device=model.base_model.device + # With: device=self.device + + # Initialize buffers... + # Initialize KV cache... + # Main decoding loop... + + for idx in range(self.max_steps): + # Generate candidates... + # Tree decoding... + # Evaluate posterior... + # Update inputs... + + timing.append(time.perf_counter()) # ADD: Record time per step + + # Check for EOS + if end_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + + return input_ids, new_token, idx + 1, accept_length_list, timing # ADD timing +``` + +### Step 4: Implement the run() Method + +```python + async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + """ + Async interface for specdec_bench. + + Args: + prompt_ids: List of input token IDs (NOT a tensor) + max_length: Maximum new tokens to generate + end_id: EOS token ID + request_id: Request identifier + turn_id: Turn identifier + + Returns: + dict with output_ids, output_logits, token_times + """ + output_dict = {} + + # Convert prompt_ids list to tensor + input_ids = torch.tensor( + [prompt_ids], dtype=torch.long, device=self.device + ) + + # Run forward pass (use asyncio.to_thread for sync code) + result = await asyncio.to_thread( + self._forward, input_ids, max_length, end_id + ) + input_ids_out, new_token, num_steps, accept_length_list, timing = result + + # Extract generated tokens (excluding prompt) + original_len = len(prompt_ids) + generated_tokens = input_ids_out[0, original_len:].tolist() + + # Remove EOS token if present + if end_id in generated_tokens: + eos_idx = generated_tokens.index(end_id) + generated_tokens = generated_tokens[:eos_idx] + + # Format output_ids as list of token chunks per step + # This matches specdec_bench's expected format + reformatted_output_ids = [[]] + start = 0 + for accept_len in accept_length_list: + if accept_len > 0 and start < len(generated_tokens): + chunk = generated_tokens[start:start + accept_len] + if chunk: + reformatted_output_ids[0].append(chunk) + start += accept_len + + # Handle remaining tokens + if start < len(generated_tokens): + reformatted_output_ids[0].append(generated_tokens[start:]) + + output_dict['output_ids'] = reformatted_output_ids + output_dict['output_logits'] = None + output_dict['token_times'] = timing + + return output_dict +``` + +### Step 5: Implement stop() for Cleanup + +```python + def stop(self): + """Clean up resources.""" + # Clear any cached states + if hasattr(self.model, "past_key_values"): + del self.model.past_key_values + del self.model.past_key_values_data + del self.model.current_length_data + + # Clear method-specific buffers + if hasattr(self.model, "_buffers"): + del self.model._buffers + + # Free GPU memory + if hasattr(self, 'model') and self.model is not None: + del self.model + torch.cuda.empty_cache() +``` + +### Step 6: Register the Model (Optional) + +Add to `specdec_bench/specdec_bench/models/__init__.py`: + +```python +from .specbench_ import SpecBenchModel +``` + +## Key Differences Summary + +| Aspect | Spec-Bench | specdec_bench | +|--------|-----------|---------------| +| Input format | `inputs.input_ids` (tensor from tokenizer) | `prompt_ids` (list of ints) | +| Output format | `(output_ids, new_token, steps, accept_lengths)` | `dict` with `output_ids`, `output_logits`, `token_times` | +| Output IDs | Full sequence tensor | List of token chunks per step | +| Timing | External (in `run_eval`) | Internal (in `run()`) | +| Device | `device_map="auto"` | Explicit single device | +| Interface | Function-based | Class-based with async `run()` | + +## Common Pitfalls + +1. **Device Mismatch**: Avoid `device_map="auto"` which spreads model across GPUs. Use explicit `.to(device)`. + +2. **Tensor vs List**: `prompt_ids` in specdec_bench is a Python list, not a tensor. Convert it in `run()`. + +3. **Output Format**: specdec_bench expects `output_ids` as `[[chunk1, chunk2, ...]]` (list of lists of lists for beam_width=1). + +4. **Timing**: Add `time.perf_counter()` calls to track per-step latency. + +5. **EOS Handling**: Strip EOS tokens from output before formatting. + +6. **Async Wrapper**: Use `asyncio.to_thread()` to wrap synchronous forward passes. + +## Example: Mapping Spec-Bench Methods + +| Spec-Bench File | Model Class | Forward Function | Key Utils | +|-----------------|-------------|------------------|-----------| +| `inference_medusa.py` | `MedusaModel` | `medusa_forward` | `generate_medusa_buffers`, `initialize_medusa` | +| `inference_eagle.py` | `EaModel` | `ea_forward` | `generate_tree_buffers`, `initialize_tree` | +| `inference_eagle2.py` | `EaModel` | `ea_forward` | Same as EAGLE | +| `inference_hydra.py` | `HydraModel` | `hydra_forward` | `generate_hydra_buffers`, `initialize_hydra` | +| `inference_lookahead.py` | `LookaheadModel` | `lookahead_forward` | Lookahead-specific utils | + +## Testing Your Port + +```python +import asyncio + +async def test(): + model = SpecBenchModel( + model_dir="/path/to/model", + max_concurrent_requests=1, + sampling_kwargs={"temperature": 0.0}, + # method-specific kwargs... + ) + + result = await model.run( + prompt_ids=[1, 2, 3, 4, 5], # Example token IDs + max_length=100, + end_id=2, # EOS token + request_id="test", + turn_id=0 + ) + + print("Output chunks:", result['output_ids']) + print("Timing:", result['token_times']) + + model.stop() + +asyncio.run(test()) +``` + +Adjust the vicuna chat template to be in the tokenizer_config to be + +Insert to tokenizer_config (for vicuna) + +```json +"chat_template": "{% set ns = namespace(system='') %}{% for m in messages %}{% if m['role'] == 'system' %}{% set ns.system = m['content'] %}{% endif %}{% endfor %}{{ ns.system | trim }}{% if ns.system | trim != '' %} {% endif %}{% for m in messages %}{% if m['role'] == 'user' %}USER: {{ m['content'] | trim }} ASSISTANT:{% elif m['role'] == 'assistant' %}{{ m['content'] | trim }}{% endif %}{% endfor %}" +``` diff --git a/examples/specdec_bench/prepare_data.py b/examples/specdec_bench/prepare_data.py new file mode 100644 index 0000000000..67fe898983 --- /dev/null +++ b/examples/specdec_bench/prepare_data.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +from typing import get_args + +from specdec_bench import datasets +from specdec_bench.datasets.speed import config_type + +datasets_available = { + "speed": datasets.SPEEDBench, +} + + +def prepare_data(args: argparse.Namespace) -> None: + """Prepare and save benchmark data to disk. + + Calls the dataset's ``prepare_data`` classmethod which downloads and + resolves all external data references, then saves the fully-resolved + result as a parquet file so that subsequent benchmark runs can load + directly from disk without re-downloading. + + Args: + args: Parsed CLI arguments containing dataset type, config, + output directory, and optional filtering parameters. + """ + configs = get_args(config_type) if args.config == "all" else [args.config] + + dataset_cls = datasets_available[args.dataset] + + for config in configs: + print(f"Preparing config '{config}' ...") + + output_path = dataset_cls.prepare_data( + output_dir=args.output_dir / args.dataset / config, + config_name=config, + ) + + print(f" -> Saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Download and prepare benchmark datasets for specdec_bench.", + ) + parser.add_argument( + "--dataset", + type=str, + default="speed", + choices=list(datasets_available.keys()), + help="Dataset to prepare (default: %(default)s)", + ) + parser.add_argument( + "--config", + type=str, + default="all", + choices=[*list(get_args(config_type)), "all"], + help='SPEED-Bench configuration to prepare. Use "all" to prepare all configs. (default: %(default)s)', + ) + parser.add_argument( + "--output_dir", + type=Path, + default=Path("data/"), + help="Directory to save the prepared dataset files (default: %(default)s)", + ) + + args = parser.parse_args() + prepare_data(args) diff --git a/examples/specdec_bench/requirements_speed.txt b/examples/specdec_bench/requirements_speed.txt new file mode 100644 index 0000000000..5b0117e3a7 --- /dev/null +++ b/examples/specdec_bench/requirements_speed.txt @@ -0,0 +1,4 @@ +datasets>=4.4.0,<5.0.0 +rich>=14.2.0 +seaborn>=0.13.2 +tiktoken>=0.12.0 diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py index dd57a51427..bd99cff56e 100644 --- a/examples/specdec_bench/run.py +++ b/examples/specdec_bench/run.py @@ -25,15 +25,48 @@ postprocess_base, postprocess_gptoss, ) +from tqdm.asyncio import tqdm engines_available = { "TRTLLM": models.TRTLLMPYTModel, "VLLM": models.VLLMModel, "SGLANG": models.SGLANGModel, + "AUTO_DEPLOY": models.AutoDeployModel, + "SPECBENCH_MEDUSA": models.SpecBenchMedusaModel, } +datasets_available = { + "mtbench": datasets.MTBench, + "random": datasets.RandomToken, + "specbench": datasets.SpecBench, + "speed": datasets.SPEEDBench, +} + + +async def tqdm_gather(*fs, return_exceptions=False, **kwargs): + if not return_exceptions: + return await tqdm.gather(*fs, **kwargs) + + async def wrap(f): + try: + return await f + except Exception as e: + return e + return await tqdm.gather(*map(wrap, fs), **kwargs) -async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10): + +async def run_loop( + runner, + dataset, + tokenizer, + output_length, + postprocess, + concurrency=10, + end_id=-1, + show_progress=False, + completions=False, + chat_template_args={}, +): """ Async version of run_loop with concurrency control using a semaphore. @@ -46,7 +79,6 @@ async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concu """ semaphore = asyncio.Semaphore(concurrency) max_length = output_length - end_id = tokenizer.eos_token_id async def process_single_request(request, i): """Process a single request with all its conversation turns.""" @@ -57,7 +89,12 @@ async def process_single_request(request, i): for turn_id, question in enumerate(request.turns): messages.append({"role": "user", "content": question}) - entry_encoded = encode_chat(tokenizer, messages) + entry_encoded = encode_chat( + tokenizer, + messages, + chat_template_args=chat_template_args, + completions=completions, + ) # Run the async runner.run directly output_tokens = await runner.run( @@ -70,12 +107,19 @@ async def process_single_request(request, i): return messages tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)] - text_outputs = await asyncio.gather(*tasks, return_exceptions=True) + if show_progress: + text_outputs = await tqdm_gather( + *tasks, + return_exceptions=True, + desc=f"Running requests (concurrency={concurrency})", + ) + else: + text_outputs = await asyncio.gather(*tasks, return_exceptions=True) # Check for any exceptions and handle them for i, result in enumerate(text_outputs): if isinstance(result, Exception): - print(f"Error processing request {i}: {result}") + print(f"Error processing request {i}/{dataset.data[i].question_id}: {result}") raise result runner.process_metrics_final(text_outputs) @@ -83,14 +127,23 @@ async def process_single_request(request, i): def run_simple(args): - tokenizer = get_tokenizer(args.tokenizer) + tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) + chat_template_args = args.runtime_params.get("chat_template_args", {}) dataset_kwargs = args.runtime_params.get("dataset_kwargs", {}) - if args.mtbench is not None: - dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs) + if args.num_requests is not None: + dataset_kwargs["num_samples"] = args.num_requests + if args.dataset is not None: + if args.dataset == "random": + assert args.random_isl is not None, "Random input length must be provided" + dataset = datasets.RandomToken(tokenizer, args.random_isl, **dataset_kwargs) + else: + dataset = datasets_available[args.dataset](args.dataset_path, **dataset_kwargs) + elif args.mtbench is not None: + dataset = datasets.MTBench(args.mtbench, **dataset_kwargs) elif args.random_isl is not None: - dataset = datasets.RandomToken( - tokenizer, args.random_isl, args.num_requests, **dataset_kwargs - ) + dataset = datasets.RandomToken(tokenizer, args.random_isl, **dataset_kwargs) + elif args.specbench is not None: + dataset = datasets.SpecBench(args.specbench, **dataset_kwargs) engine_args = args.runtime_params.get("engine_args", {}) sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0}) model_class = engines_available[args.engine] @@ -103,6 +156,7 @@ def run_simple(args): speculative_num_steps=args.draft_length, tensor_parallel_size=args.tp_size, moe_expert_parallel_size=args.ep_size, + trust_remote_code=args.trust_remote_code, **engine_args, ) @@ -111,8 +165,15 @@ def run_simple(args): metrics_list.append(metrics.AATiming(tokenizer)) if args.mtbench is not None: metrics_list.insert(0, metrics.MTBench()) + elif args.specbench is not None or args.dataset == "speed": + metrics_list.insert(0, metrics.SpecBench(requests=dataset.data)) else: metrics_list.insert(0, metrics.AcceptanceRate()) + + if args.save_dir is not None: + for metric in metrics_list: + metric.update_directory(args.save_dir) + runner = runners.SimpleRunner(model, metrics=metrics_list) if args.postprocess == "base": @@ -122,8 +183,21 @@ def run_simple(args): else: raise ValueError(f"Invalid postprocess: {args.postprocess}") + end_id = tokenizer.eos_token_id if not args.ignore_eos else -1 + asyncio.run( - run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency) + run_loop( + runner, + dataset, + tokenizer, + args.output_length, + postprocess, + args.concurrency, + end_id, + args.show_progress, + args.completions, + chat_template_args, + ) ) runner.clear_metrics() @@ -135,7 +209,18 @@ def run_simple(args): "--tokenizer", type=str, required=True, help="Path to the tokenizer directory" ) parser.add_argument( - "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset" + "--mtbench", + type=str, + required=False, + default=None, + help="Path to the mtbench dataset", + ) + parser.add_argument( + "--specbench", + type=str, + required=False, + default=None, + help="Path to the specbench dataset", ) parser.add_argument( "--random_isl", @@ -144,7 +229,28 @@ def run_simple(args): default=None, help="How many tokens random input should be.", ) - parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run") + parser.add_argument( + "--dataset", + type=str, + required=False, + default=None, + choices=list(datasets_available.keys()), + help="Dataset to use", + ) + parser.add_argument( + "--dataset_path", + type=str, + required=False, + default=None, + help="Path to the dataset or config name for SPEEDBench", + ) + parser.add_argument( + "--num_requests", + type=int, + required=False, + default=None, + help="Number of requests to run. If not provided, all requests from the dataset will be run.", + ) parser.add_argument( "--engine", type=str, @@ -193,7 +299,17 @@ def run_simple(args): default=1, help="Maximum number of concurrent requests", ) + parser.add_argument( + "--trust_remote_code", action="store_true", help="Trust remote code for tokenizer and model" + ) parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric") + parser.add_argument("--ignore_eos", action="store_true", help="Ignore EOS token") + parser.add_argument("--show_progress", action="store_true", help="Show progress bar") + parser.add_argument( + "--completions", + action="store_true", + help="Skip chat template, tokenize the message directly", + ) parser.add_argument( "--postprocess", type=str, @@ -202,7 +318,13 @@ def run_simple(args): choices=["base", "gptoss"], help="Postprocess to use", ) - + parser.add_argument( + "--save_dir", + type=str, + required=False, + default=None, + help="Directory to save the results", + ) args = parser.parse_args() if args.runtime_params is not None: @@ -210,9 +332,20 @@ def run_simple(args): args.runtime_params = yaml.safe_load(f) else: args.runtime_params = {} + if args.dataset is None: + assert ( + args.mtbench is not None or args.random_isl is not None or args.specbench is not None + ), "Either mtbench or random_isl or specbench must be provided" + else: + assert args.dataset_path is not None, "Dataset path must be provided" + if args.dataset == "specbench": + args.specbench = args.dataset_path + elif args.dataset == "mtbench": + args.mtbench = args.dataset_path - assert args.mtbench is not None or args.random_isl is not None, ( - "Either mtbench or random_isl must be provided" - ) + if args.ignore_eos: + print( + "Warning: Ignore EOS should only be used in certain cases, do no activate unless necessary" + ) run_simple(args) diff --git a/examples/specdec_bench/specdec_bench/__init__.py b/examples/specdec_bench/specdec_bench/__init__.py index 3159bfe656..47f1c65a15 100644 --- a/examples/specdec_bench/specdec_bench/__init__.py +++ b/examples/specdec_bench/specdec_bench/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,3 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + diff --git a/examples/specdec_bench/specdec_bench/datasets/__init__.py b/examples/specdec_bench/specdec_bench/datasets/__init__.py index 64449d2b58..aefc2605bf 100644 --- a/examples/specdec_bench/specdec_bench/datasets/__init__.py +++ b/examples/specdec_bench/specdec_bench/datasets/__init__.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import Dataset -from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat from .mtbench import MTBench from .random_token import RandomToken +from .specbench import SpecBench +from .speed import SPEEDBench + +__all__ = ["MTBench", "RandomToken", "SPEEDBench", "SpecBench"] diff --git a/examples/specdec_bench/specdec_bench/datasets/base.py b/examples/specdec_bench/specdec_bench/datasets/base.py index 587c04b07a..eb72affa49 100644 --- a/examples/specdec_bench/specdec_bench/datasets/base.py +++ b/examples/specdec_bench/specdec_bench/datasets/base.py @@ -14,11 +14,14 @@ # limitations under the License. from dataclasses import dataclass, field +from pathlib import Path from typing import Any @dataclass class Request: + question_id: int | None = None + category: str | None = None system_prompt: str | None = None turns: list[str] = field(default_factory=list) mm_content: Any | None = None # TODO @@ -35,3 +38,22 @@ def __init__(self, path, **kwargs): def _preprocess(self): raise NotImplementedError + + @classmethod + def prepare_data(cls, output_dir: str | Path, **kwargs) -> Path: + """Prepare and save the dataset to the specified output directory. + + Downloads any external data, resolves all references, and persists + the fully-resolved dataset so that subsequent loads are self-contained. + + Args: + output_dir: Directory where the prepared data file will be saved. + **kwargs: Dataset-specific parameters (e.g. config_name, category). + + Returns: + Path to the saved dataset file. + + Raises: + NotImplementedError: Subclasses must override this method. + """ + raise NotImplementedError diff --git a/examples/specdec_bench/specdec_bench/datasets/base_hf.py b/examples/specdec_bench/specdec_bench/datasets/base_hf.py deleted file mode 100644 index 6c7be3d8ca..0000000000 --- a/examples/specdec_bench/specdec_bench/datasets/base_hf.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -try: - from datasets import load_dataset -except ImportError: - print("datasets is not installed.") - datasets = None - - -from .base import Dataset, Request - - -class BaseHF(Dataset): - def __init__(self, num_samples=100, **kwargs): - self.data: list[Request] = [] # list of list of questions. - self.num_samples = num_samples - self._preprocess() - - def _preprocess(self): - dataset = self._load_dataset(self.num_samples) - for i, line in enumerate(dataset): - if i == self.num_samples: - break - self.data.append(self._single_line_process(line)) - - def _single_line_process(self, line): - raise NotImplementedError - - def _load_dataset(self, num_samples): - raise NotImplementedError - - -class OpenOrca(BaseHF): - def _single_line_process(self, line, **kwargs): - return Request(system_prompt=line["system_prompt"], turns=[line["question"]]) - - def _load_dataset(self, num_samples): - return load_dataset("Open-Orca/OpenOrca", split="train", streaming=True) - - -class OpenMathInstructv2(BaseHF): - def _single_line_process(self, line, **kwargs): - return Request(system_prompt=None, turns=[line["problem"]]) - - def _load_dataset(self, num_samples): - return load_dataset("nvidia/OpenMathInstruct-2", split="train_1M", streaming=True) - - -class UltraChat(BaseHF): - def _single_line_process(self, line, **kwargs): - return Request( - system_prompt=None, turns=[q for i, q in enumerate(line["data"]) if i % 2 == 0] - ) - - def _load_dataset(self, num_samples): - return load_dataset("stingning/ultrachat", split="train", streaming=True) diff --git a/examples/specdec_bench/specdec_bench/datasets/mtbench.py b/examples/specdec_bench/specdec_bench/datasets/mtbench.py index 53295bdbb4..cb58dd2103 100644 --- a/examples/specdec_bench/specdec_bench/datasets/mtbench.py +++ b/examples/specdec_bench/specdec_bench/datasets/mtbench.py @@ -33,11 +33,10 @@ class MTBench(Dataset): def __init__(self, path, num_samples=80, **kwargs): self.data: list[Request] = [] # list of list of questions. self.num_samples = num_samples - self.path = path - self._preprocess() + self._preprocess(path) - def _preprocess(self): - with open(self.path) as f: + def _preprocess(self, path): + with open(path) as f: for json_line in f: line = json.loads(json_line) key = "turns" if "turns" in line else "prompt" diff --git a/examples/specdec_bench/specdec_bench/datasets/random_token.py b/examples/specdec_bench/specdec_bench/datasets/random_token.py index 972a0455c2..521db57ee2 100644 --- a/examples/specdec_bench/specdec_bench/datasets/random_token.py +++ b/examples/specdec_bench/specdec_bench/datasets/random_token.py @@ -24,12 +24,10 @@ def __init__(self, tokenizer, input_len, num_samples=20, **kwargs): self.data: list[Request] = [] # list of list of questions. self.num_samples = num_samples self.input_len = input_len - self.tokenizer = tokenizer - self._preprocess() + self._preprocess(tokenizer) - def _preprocess(self): + def _preprocess(self, tokenizer): np.random.seed(0) - tokenizer = self.tokenizer num_prompts = self.num_samples offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) for i in range(num_prompts): diff --git a/examples/specdec_bench/specdec_bench/datasets/specbench.py b/examples/specdec_bench/specdec_bench/datasets/specbench.py new file mode 100644 index 0000000000..a14d340390 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/specbench.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from .base import Dataset, Request + + +class SpecBench(Dataset): + def __init__(self, path, num_samples=480, **kwargs): + self.data: list[Request] = [] # list of list of questions. + self.num_samples = num_samples + self._preprocess(path) + + def _preprocess(self, path): + with open(path) as f: + for json_line in f: + line = json.loads(json_line) + self.data.append( + Request( + question_id=line["question_id"], + category=line["category"], + system_prompt=None, + turns=line["turns"], + ) + ) + self.data = self.data[: self.num_samples] diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py new file mode 100644 index 0000000000..e3429126d9 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/speed.py @@ -0,0 +1,807 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: disable-error-code="index" +import random +import re +from enum import Enum +from pathlib import Path +from typing import Any, Literal, get_args + +from .base import Dataset, Request + +try: + import numpy as np + import pandas as pd + import tiktoken + from datasets import concatenate_datasets, load_dataset + + not_installed = False +except ImportError: + not_installed = True + + +config_type = Literal[ + "qualitative", + "throughput_1k", + "throughput_2k", + "throughput_8k", + "throughput_16k", + "throughput_32k", +] +TURNS_PLACEHOLDER = "FULL BENCHMARK DATA SHOULD BE FETCHED FROM THE SOURCE USING SPECDEC_BENCH" + + +class BenchmarkDataset(str, Enum): + """Enum for benchmark datasets used in SPEED-Bench. + + Each enum value represents a HuggingFace dataset identifier used for + loading external benchmark datasets. + """ + + BAMBOO = "RUCAIBox/BAMBOO" + CNN_DAILYMAIL = "abisee/cnn_dailymail" + HLE = "cais/hle" + LIVECODEBENCH = "livecodebench/code_generation_lite" + CODE_CONTESTS = "deepmind/code_contests" + MTBENCH_101 = "mtbench101/mt-bench-101" + OPUS100 = "Helsinki-NLP/opus-100" + CHATRAG_BENCH = "nvidia/ChatRAG-Bench" + MMLU_PRO = "TIGER-Lab/MMLU-Pro" + ADALEVAL_STACKSELECT = "AdaLEval/stackselect" + ADALEVAL_TEXTSORT = "AdaLEval/textsort" + ROLEBENCH = "ZenMoore/RoleBench" + ROLEBENCH_ROLES = "ZenMoore/RoleBench/roles" + COSER = "Neph0s/CoSER" + + +DATASETS_AND_LOADERS_FUNCTIONS = { + BenchmarkDataset.BAMBOO.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.CNN_DAILYMAIL.value: lambda dataset_name, config_name: load_dataset( + dataset_name, config_name, split="test" + ), + BenchmarkDataset.HLE.value: lambda dataset_name, config_name: load_dataset( + dataset_name, split="test", revision="021a3d71f516a7ac28ceb8d284969902edf1edeb" + ) + if config_name != "train_test_split" + else load_dataset( + dataset_name, split="test", revision="021a3d71f516a7ac28ceb8d284969902edf1edeb" + ).train_test_split(test_size=0.5, shuffle=True, seed=42), + BenchmarkDataset.LIVECODEBENCH.value: lambda dataset_name, config_name: load_dataset( + "json", + data_files={ + "test": [ + f"https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/0fe84c3912ea0c4d4a78037083943e8f0c4dd505/{file_name}.jsonl" + for file_name in ["test", "test2", "test3", "test4", "test5", "test6"] + ] + }, + split="test", + ), + BenchmarkDataset.CODE_CONTESTS.value: lambda dataset_name, config_name: load_dataset( + dataset_name, split="test", revision="802411c3010cb00d1b05bad57ca77365a3c699d6" + ), + BenchmarkDataset.MTBENCH_101.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.OPUS100.value: lambda dataset_name, config_name: load_dataset( + dataset_name, + config_name, + split="test", + revision="805090dc28bf78897da9641cdf08b61287580df9", + ), + BenchmarkDataset.CHATRAG_BENCH.value: lambda dataset_name, config_names: concatenate_datasets( + [ + load_dataset( + dataset_name, + config_name, + split="test", + revision="af6c7d420ddddf21f54f8ab3394bbf462aad2577", + ) + for config_name in config_names + ] + ), + BenchmarkDataset.MMLU_PRO.value: lambda dataset_name, config_name: load_dataset( + dataset_name, split="test", revision="30527804ea8854662078e457808040d872ecdf29" + ), + BenchmarkDataset.ADALEVAL_STACKSELECT.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.ADALEVAL_TEXTSORT.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.ROLEBENCH.value: lambda dataset_name, config_name: pd.read_json( + config_name, lines=True + ), + BenchmarkDataset.ROLEBENCH_ROLES.value: lambda dataset_name, config_name: load_dataset( + "json", data_files={"test": config_name}, split="test" + ), + BenchmarkDataset.COSER.value: lambda dataset_name, config_name: load_dataset( + "json", + data_files={"test": config_name.replace("tree", "raw") + "/test/test_set.json"}, + split="test", + ), +} + + +class SPEEDBench(Dataset): + def __init__( + self, + config_name: config_type = "qualitative", + num_samples: int | None = None, + _prepare_mode: bool = False, + **kwargs, + ): + if not_installed: + raise ImportError( + "Additional packages are required to use SPEED-Bench. Please run `pip install -r requirements_speed.txt`" + ) + self.data: list[Request] = [] + self.num_samples = num_samples + self.external_datasets: dict[str, Any] = {} + self._config_name = config_name + self._resolved_dataset = None + self._preprocess(config_name, _prepare_mode=_prepare_mode) + + def _get_external_dataset(self, dataset_name: str, config_name: str | list[str] = "default"): + full_name = f"{dataset_name}_{config_name}" + if full_name not in self.external_datasets: + self.external_datasets[full_name] = DATASETS_AND_LOADERS_FUNCTIONS[dataset_name]( + dataset_name, config_name + ) + if config_name == "train_test_split": + self.external_datasets[full_name] = ( + self.external_datasets[full_name]["train"], + self.external_datasets[full_name]["test"], + ) + return self.external_datasets[full_name] + + @staticmethod + def _generate_stackselect_prompt( + question: str, answers: list[str], answer: str, num_tokens: int + ) -> str: + random.seed(42) + encoder = tiktoken.get_encoding("o200k_base") + # Original prompt as given in Ada-LEval paper: https://arxiv.org/pdf/2404.06480 + prompt = """ +You are an AI assistant. Your job is to find out the most helpful answer to a given question. +Each time, you will be provided with a question and n answers to this question. +Each answer begins with an 'A' and a number(e.g. A4), which represents its designation. +You need to determine which answer is the most helpful one to the question. +The case sample is shown below and you should give me the answer in the format exactly the same as the sample. + +However, you should NOT focus on the content of sample answer. + +Sample Input (format only): + +The question is given below. +XXX(The content of question) +Possible answers are given below. +A1: +XXX(The content of answer 1) +A2: +XXX(The content of answer 2) +. +. +. +An: +XXX(The content of answer n) +Now the answers are over, please decide which answer is the most helpful one to the question. +You must give me the designation of the MOST helpful answer and the reason why you choose this answer. +For every other answer, you must give me the reason why you do not choose this answer. + +Sample Output (format only): + +Answer: The designation of the most helpful answer.(e.g. A4 means answer 4 is the most helpful answer) +Explanation: +A4: The reason why you choose this answer. +A1: The reason why you do not choose this answer. +A2: The reason why you do not choose this answer. +. +. +. +An: The reason why you do not choose this answer. +""" + prompt += "The question is given below.\n" + prompt += question + "\n\n" + prompt += "Possible answers are given below.\n" + tokens_prompt = len(encoder.encode(prompt, disallowed_special=())) + end_prompt = "Now the answers are over, please decide which answer is the most helpful one to the question. \n" + end_prompt += "You must give me the designation of the MOST helpful answer and the reason why you choose this answer.\n" + end_prompt += "For every other answer, you must give me the reason why you do not choose this answer.\n" + end_prompt_tokens = len(encoder.encode(end_prompt, disallowed_special=())) + correct_answer_i = int(answer.strip("A")) - 1 + correct_answer_tokens = len( + encoder.encode( + answer + ":\n\n" + answers[correct_answer_i] + "\n\n", + disallowed_special=(), + ) + ) + all_tokens = tokens_prompt + end_prompt_tokens + correct_answer_tokens + answers_to_add_stop = 0 + for i, answer in enumerate(answers): + if i == correct_answer_i: + continue + answer_to_add = f"A{i + 1}:\n\n{answer}\n\n" + answer_to_add_tokens = len(encoder.encode(answer_to_add, disallowed_special=())) + if all_tokens + answer_to_add_tokens > num_tokens: + break + answers_to_add_stop = i + answers_to_add = ( + answers[: answers_to_add_stop + 1] + if answers_to_add_stop >= correct_answer_i + else [answers[correct_answer_i]] + answers[: answers_to_add_stop + 1] + ) + random.shuffle(answers_to_add) + for i, answer in enumerate(answers_to_add): + prompt += f"A{i + 1}:\n\n{answer}\n\n" + prompt += end_prompt + return prompt + + @staticmethod + def _generate_textsort_prompt(prompt: str) -> str: + # Original prompt as given in Ada-LEval paper: https://arxiv.org/pdf/2404.06480 + original_instruction = "\n You are an AI assistant. Your job is to sort multiple book sections into the correct order.\n Each time, you will be provided with 4 pieces of text.\n These texts form a continuous part of a book, but are provided in random order.\n You need to find the correct order and return the answer in a string.\n For example, if you output [4, 1, 3, 2], that means the correct order is: Part 4 -> Part 1 -> Part 3 -> Part 2.\n You will also be provided with the neighboring paragraphs before and after the 4 pieces of texts. \n\n The case sample is shown below and you should give me the answer in the format exactly the same as the sample. \n\n However, you should NOT focus on the content of sample answer. \n\n Please do NOT output any extra content. \n Sample Input (format only): \n\n Before: XXX (Text before the continuous book part)\n\n\n Part 1: XXX\n\n\n Part 2: XXX\n\n\n Part 3: XXX\n\n\n Part 4: XXX\n\n\n After: XXX (Text after the continuous book part)\n\n\n Sample Output (format only): \n\n Answer: [4, 1, 3, 2] \n\n\n\n" + + new_instruction = """ +You are an AI assistant. Your job is to sort multiple book sections into the correct order. + Each time, you will be provided with 4 pieces of text. + These texts form a continuous part of a book, but are provided in random order. + You need to find the correct order and write the all the parts in the correct order. + For example, if the correct order is: Part 4 -> Part 1 -> Part 3 -> Part 2, you need to answer with a continous text of all the parts in the correct order. + You should NOT change the text, just write it in the order it should appear. + You will also be provided with the neighboring paragraphs before and after the 4 pieces of texts. + You should NOT output the before and after paragraphs, just the text in the correct order. + + The case sample is shown below and you should give me the answer in the format exactly the same as the sample. + + However, you should NOT focus on the content of sample answer. + + Please do NOT output any extra content. + + Sample Input (format only): + + Before: BBB (Text before the continuous book part) + + + Part 1: XXX + + + Part 2: YYY + + + Part 3: ZZZ + + + Part 4: WWW + + + After: AAA (Text after the continuous book part) + + Sample Output (format only): + + Answer: + + + WWW + + XXX + + ZZZ + + YYY + """ + return prompt.replace(original_instruction, new_instruction, 1) + + @staticmethod + def _generate_writing_prompt(contents: list[str]) -> str: + content = "\n\n".join( + [ + f"START CONTENT {i + 1}\n\n{content}\n\nEND CONTENT" + for i, content in enumerate(contents) + ] + ) + # Inspired by the prompt used in BAMBOO paper: https://arxiv.org/pdf/2309.13345 + prompt = f""" +I want you to act as a long dialogue completer. +Given a long dialogue(s), your objectives are: +1. Add one speaker mentioned in the past dialogue(s) at the end of the last sentence of each dialogue (between START CONTENT and END CONTENT) to complete the sentence and ensure its semantic integrity. At here, the added word must be a person's name which appears in the dialogue. +2. Continue the dialogue(s) with one or more speakers who appeared in the dialogue(s) before. Be coherent with the previous dialogue(s) and be creative in your response. +The content of the dialogue(s) is given below. + + +{content} +""" + return prompt + + @staticmethod + def _pad_or_truncate_prompt( + prompt: str, target_num_tokens: int, padding: str = "Answer now please.\n" + ) -> str: + encoder = tiktoken.get_encoding("o200k_base") + + tokens = encoder.encode(prompt, disallowed_special=()) + current_num_tokens = len(tokens) + + if current_num_tokens > target_num_tokens: + # Truncate if too long + tokens = encoder.encode(prompt, disallowed_special=()) + return encoder.decode(tokens[:target_num_tokens]) + elif current_num_tokens < target_num_tokens: + # Add padding if too short + padding_tokens = encoder.encode(padding, disallowed_special=()) + tokens_needed = target_num_tokens - current_num_tokens + # Calculate how many full padding sequences we need + num_padding_repeats = (tokens_needed + len(padding_tokens) - 1) // len(padding_tokens) + padded_prompt = prompt + (padding * num_padding_repeats) + # Truncate to exact target length + padded_tokens = encoder.encode(padded_prompt, disallowed_special=()) + return encoder.decode(padded_tokens[:target_num_tokens]) + else: + return prompt + + @staticmethod + def _generate_bamboo_prompt(external_dataset: "Dataset", num_tokens: int) -> str: + prompt = SPEEDBench._generate_writing_prompt(external_dataset["content"]) + return SPEEDBench._pad_or_truncate_prompt(prompt, num_tokens) + + @staticmethod + def _generate_chatrag_bench_prompt(external_dataset: "Dataset") -> list[Any]: + prompt = "Please give a full and complete answer for the questions. \n\nContext:\n{context}\n\nQuestion:\n{question}" + context = "\n\n".join([ctx["text"] for ctx in external_dataset["ctxs"][0]]) + questions = [ + message["content"] + for message in external_dataset["messages"][0] + if message["role"] == "user" + ] + + return [prompt.format(context=context, question=questions[0])] + questions[1:] + + @staticmethod + def _generate_coser_prompt(external_dataset: "Dataset") -> str: + rng = np.random.default_rng(seed=12347) + # Original prompt as given in CoSER paper: https://arxiv.org/pdf/2404.06480 + prompt = """You are {character} from {book_name}. +==={character}'s Profile=== +{character_profile} +===Current Scenario=== +{scenario} +===Information about the other Characters=== +{other_character_profiles_str} +===Your Inner Thoughts=== +{motivation} + +===Requirements=== +Your output should include **thought**, **speech**, and **action**. Use [your thought] +for thoughts, which others can't see, e.g. [I'm terrified, but I must appear strong.]. Use +(your action) for actions, which others can see, such as (watches silently, trying to control +her fear and anger).""" + character = rng.choice(external_dataset["major_characters"][0]) + character_profile = external_dataset["character_profiles"][0][character] + scenario = external_dataset["scenario"][0] + book_name = external_dataset["book"][0] + motivation = next( + ( + key_character["motivation"] + for key_character in external_dataset["key_characters"][0] + if key_character["name"] == character + ), + "No motivation provided", + ) + if motivation == "No motivation provided": + print("warning: no motivation provided for character", character) + other_character_profiles_str = "\n\n".join( + [ + f"{character_name}: {character_profile}" + for character_name, character_profile in external_dataset["character_profiles"][ + 0 + ].items() + if character_name != character and character_profile is not None + ] + ) + return prompt.format( + character=character, + character_profile=character_profile, + book_name=book_name, + scenario=scenario, + other_character_profiles_str=other_character_profiles_str, + motivation=motivation, + ) + + @staticmethod + def _generate_mmlu_pro_prompt(external_dataset: "Dataset", subject: str) -> list[Any]: + def get_question_and_options(question, options): + options = [(chr(ord("A") + i), a) for i, a in enumerate(options)] + options_str = "\n".join([f"({letter}) {option}" for letter, option in options]) + return f"Question: {question}\n\nOptions: {options_str}\n\n" + + # Original prompt as given in MMLU-Pro paper: https://arxiv.org/pdf/2406.01574 + prompt = 'The following are multiple choice questions (with answers) about {subject}. Think step by step and then finish your answer with "the answer is (X)" where X is the correct letter choice.\n\n' + first_question = prompt.format(subject=subject) + get_question_and_options( + external_dataset["question"][0], external_dataset["options"][0] + ) + return [first_question] + [ + get_question_and_options(question, options) + for question, options in zip( + external_dataset["question"][1:], external_dataset["options"][1:] + ) + ] + + @staticmethod + def _generate_hle_prompt( + example: dict[str, Any], + hle_train: "pd.DataFrame", + num_tokens: int, + rng: "np.random.Generator", + ) -> str: + encoder = tiktoken.get_encoding("o200k_base") + prompt = ( + "Please answer the question below.\n\nHere are some examples of question and answer pairs in the category of " + + example["category"] + + ":\n\n" + ) + prompt_tokens = encoder.encode(prompt) + example_tokens = encoder.encode(example["question"]) + current_num_tokens = len(prompt_tokens) + len(example_tokens) + hle_train_category = hle_train[hle_train["category"] == example["category"]] + + while current_num_tokens < num_tokens: + hle_train_category_sample = hle_train_category.sample(1, random_state=rng) + prompt += hle_train_category_sample["demonstration"].iloc[0] + current_num_tokens += len(hle_train_category_sample["tokens"].iloc[0]) + prompt_tokens += list(hle_train_category_sample["tokens"].iloc[0]) + + return encoder.decode( + prompt_tokens[: num_tokens - len(example_tokens) + 1] + example_tokens + ) + + @staticmethod + def _get_num_tokens_from_config(speed_config: config_type | str) -> int: + match = re.search(r"throughput_(\d+)k", speed_config) + if match: + return int(match.group(1)) * 1000 + else: + raise ValueError(f"Could not determine num_tokens from speed_config: {speed_config}") + + def _fetch_all_turns_data( + self, example: dict[str, Any], speed_config: config_type | str + ) -> dict[str, Any]: + turns = example["turns"] + if not turns[0].startswith(TURNS_PLACEHOLDER): + return example + + if BenchmarkDataset.BAMBOO.value in example["source"]: + num_tokens = self._get_num_tokens_from_config(speed_config) + src_ids = [int(match) for match in re.findall(r"_(\d+)", example["src_id"])] + external_dataset = self._get_external_dataset( + BenchmarkDataset.BAMBOO.value, config_name=example["source"] + ) + external_dataset = external_dataset.select(src_ids) + example["turns"] = [self._generate_bamboo_prompt(external_dataset, num_tokens)] + + elif BenchmarkDataset.CNN_DAILYMAIL.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.CNN_DAILYMAIL.value, config_name="3.0.0" + ).to_pandas() + src_id = example["src_id"] + article = external_dataset[external_dataset["id"] == src_id]["article"].iloc[0] + example["turns"] = [ + example["turns"][0].removeprefix(f"{TURNS_PLACEHOLDER}\n\n").format(article=article) + ] + + elif BenchmarkDataset.HLE.value in example["source"]: + if "qualitative" in speed_config: + external_dataset = self._get_external_dataset( + BenchmarkDataset.HLE.value, config_name="test" + ).to_pandas() + src_id = example["src_id"] + example["turns"] = [ + external_dataset[external_dataset["id"] == src_id]["question"].iloc[0] + ] + elif "throughput" in speed_config: + num_tokens = self._get_num_tokens_from_config(speed_config) + hle_train, hle_test = self._get_external_dataset( + BenchmarkDataset.HLE.value, config_name="train_test_split" + ) + hle_train = hle_train.to_pandas() + hle_train = hle_train[hle_train["image"] == ""] + hle_train["demonstration"] = hle_train.apply( + lambda e: "Question: " + + e["question"] + + "\n\nAnswer: " + + e["rationale"] + + "\n\n", + axis=1, + ) + hle_train["tokens"] = hle_train["demonstration"].apply( + lambda e: tiktoken.get_encoding("o200k_base").encode(e, disallowed_special=()) + ) + src_id = example["src_id"] + hle_test = hle_test.to_pandas() + external_dataset_example = hle_test[hle_test["id"] == src_id].iloc[0] + self.hle_rng = getattr(self, "hle_rng", np.random.default_rng(42)) + example["turns"] = [ + self._generate_hle_prompt( + external_dataset_example, hle_train, num_tokens, self.hle_rng + ) + ] + else: + raise ValueError(f"Invalid speed_config: {speed_config}") + + elif BenchmarkDataset.LIVECODEBENCH.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.LIVECODEBENCH.value, config_name="test" + ).to_pandas() + src_id = example["src_id"] + external_dataset_example = external_dataset[ + external_dataset["question_id"] == src_id + ].iloc[0] + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format( + question=external_dataset_example["question_content"], + starter_code=external_dataset_example["starter_code"], + ) + ] + + elif BenchmarkDataset.CODE_CONTESTS.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.CODE_CONTESTS.value, config_name="test" + ).to_pandas() + src_id = example["src_id"] + external_dataset_example = external_dataset[external_dataset["name"] == src_id].iloc[0] + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format(question=external_dataset_example["description"]) + ] + + elif BenchmarkDataset.MTBENCH_101.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.MTBENCH_101.value, config_name=example["source"] + ) + src_id = example["src_id"].rsplit("_", 1)[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [entry["user"] for entry in external_dataset_example["history"][0]] + + elif BenchmarkDataset.OPUS100.value in example["source"]: + _, config_name, src_id = example["src_id"].split("_") + external_dataset = self._get_external_dataset( + BenchmarkDataset.OPUS100.value, config_name=config_name + ) + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format(question=external_dataset_example["translation"][0]) + ] + + elif BenchmarkDataset.CHATRAG_BENCH.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.CHATRAG_BENCH.value, config_name=["hybridial", "sqa"] + ) + src_id = example["src_id"].rsplit("_", 1)[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = self._generate_chatrag_bench_prompt(external_dataset_example) + + elif BenchmarkDataset.MMLU_PRO.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.MMLU_PRO.value, config_name="test" + ) + src_id = int(example["src_id"].split("(")[1].split(",")[0]) + external_dataset_example = external_dataset.select( + range(src_id, src_id + len(example["turns"])) + ) + example["turns"] = self._generate_mmlu_pro_prompt( + external_dataset_example, example["sub_category"] + ) + + elif BenchmarkDataset.ADALEVAL_STACKSELECT.value in example["source"]: + num_tokens = self._get_num_tokens_from_config(speed_config) + external_dataset = self._get_external_dataset( + BenchmarkDataset.ADALEVAL_STACKSELECT.value, + config_name=example["source"], + ).to_pandas() + src_id = example["src_id"] + external_dataset_example = external_dataset[ + external_dataset["question_id"] == src_id + ].iloc[0] + example["turns"] = [ + self._pad_or_truncate_prompt( + self._generate_stackselect_prompt( + question=external_dataset_example["question"], + answers=external_dataset_example["all_answers"], + answer=external_dataset_example["answer"], + num_tokens=num_tokens, + ), + num_tokens, + ) + ] + + elif BenchmarkDataset.ADALEVAL_TEXTSORT.value in example["source"]: + num_tokens = self._get_num_tokens_from_config(speed_config) + external_dataset = self._get_external_dataset( + BenchmarkDataset.ADALEVAL_TEXTSORT.value, config_name=example["source"] + ) + src_id = example["src_id"].split("_")[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [ + self._pad_or_truncate_prompt( + self._generate_textsort_prompt(external_dataset_example["prompt"][0]), + num_tokens, + ) + ] + + elif BenchmarkDataset.ROLEBENCH.value in example["source"]: + config_name = example["src_id"].split("_")[1] + external_dataset = self._get_external_dataset( + BenchmarkDataset.ROLEBENCH.value, + config_name=example["source"].replace("tree", "raw") + + f"/{config_name}/role_specific/test.jsonl", + ) + roles_dataset = self._get_external_dataset( + BenchmarkDataset.ROLEBENCH_ROLES.value, + config_name="https://huggingface.co/datasets/ZenMoore/RoleBench/raw/a57ed54f9613921e4a5f1b63601a558cd5acf971/profiles-eng/desc.json", + ) + src_ids = [int(match) for match in re.findall(r"_(\d+)", example["src_id"])][ + : len(example["turns"]) + ] + external_dataset_example = external_dataset.iloc[src_ids] + role_name = external_dataset_example["role"].iloc[0] + role_description_and_catchphrases = roles_dataset[role_name][0] + example["turns"] = [ + example["turns"][0] + .removeprefix(f"{TURNS_PLACEHOLDER}\n\n") + .format( + role_name=role_name, + role_description_and_catchphrases=role_description_and_catchphrases, + ) + + "\n" + + external_dataset_example["question"].iloc[0] + ] + [ + question.removeprefix(f"{role_name}, ").removeprefix(f" {role_name},") + for question in external_dataset_example["question"].iloc[1:] + ] + + elif BenchmarkDataset.COSER.value in example["source"]: + external_dataset = self._get_external_dataset( + BenchmarkDataset.COSER.value, config_name=example["source"] + ) + src_id = example["src_id"].split("_")[1] + external_dataset_example = external_dataset.select([int(src_id)]) + example["turns"] = [self._generate_coser_prompt(external_dataset_example)] + + return example + + def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Dataset": + """Load the raw HuggingFace dataset from a config name or local path. + + Args: + config_name_or_dataset_path: Either a SPEED-Bench config name + (e.g. ``"qualitative"``) or a path to a local parquet file / + directory. + category: If provided, filter the dataset to this category only. + + Returns: + The loaded (and optionally filtered / truncated) HuggingFace dataset. + """ + if config_name_or_dataset_path in get_args(config_type): + dataset = load_dataset("nvidia/SPEED-Bench", config_name_or_dataset_path, split="test") + else: + config_name_or_dataset_path_path = Path(config_name_or_dataset_path) + if not config_name_or_dataset_path_path.exists(): + msg = ", ".join(get_args(config_type)) + raise ValueError( + f"Dataset path {config_name_or_dataset_path_path} does not exist or not one of the supported configs {msg}" + ) + if config_name_or_dataset_path_path.is_dir(): + data_files = { + "test": [ + str(path) for path in config_name_or_dataset_path_path.rglob("*.parquet") + ] + } + else: + data_files = {"test": [str(config_name_or_dataset_path_path)]} + dataset = load_dataset("parquet", data_files=data_files, split="test") + if self.num_samples is not None: + dataset = dataset.select(range(self.num_samples)) + return dataset + + def _resolve_external_data( + self, dataset: "Dataset", speed_config: config_type | str + ) -> "Dataset": + """Resolve all external data references in the dataset. + + Applies ``_fetch_all_turns_data`` to every example so that turn + placeholders are replaced with fully-resolved prompt text. + + Args: + dataset: The HuggingFace dataset with potentially unresolved turns. + speed_config: The SPEED-Bench config name used to determine + token-length parameters for throughput configs. + + Returns: + The dataset with all turns fully resolved. + """ + return dataset.map(self._fetch_all_turns_data, fn_kwargs={"speed_config": speed_config}) + + def _preprocess( + self, + config_name_or_dataset_path: config_type | str, + *, + _prepare_mode: bool = False, + ): + dataset = self._load_dataset(config_name_or_dataset_path) + + if _prepare_mode: + # Resolve all external data references (only allowed during prepare) + dataset = self._resolve_external_data(dataset, config_name_or_dataset_path) + else: + # Validate that all turns are fully resolved (no placeholders remaining) + for example in dataset: + for turn in example["turns"]: + if turn.startswith(TURNS_PLACEHOLDER): + raise ValueError( + f"Unresolved data placeholder found in question_id={example['question_id']} " + f"(category={example['category']}). Please run " + f"`python prepare_data.py --config ` first to download " + f"and resolve all external data references." + ) + + self._resolved_dataset = dataset + self.data = [ + Request( + system_prompt=None, + turns=example["turns"], + category=example["category"], + question_id=example["question_id"], + ) + for example in dataset + ] + assert len(self.data) == len(dataset), ( # type: ignore[arg-type] + f"Number of requests {len(self.data)} does not match number of requests in the dataset {len(dataset)}" # type: ignore[arg-type] + ) + + @classmethod + def prepare_data( + cls, + output_dir: str | Path, + config_name: config_type = "qualitative", + ) -> Path: + """Download, resolve, and save the SPEED-Bench dataset as parquet. + + This is the **only** entry-point that fetches external data and + resolves turn placeholders. The resulting parquet file can then be + loaded directly by the normal ``SPEEDBench(config_name=)`` + constructor without any further network access. + + Args: + output_dir: Directory where the parquet file will be written. + config_name: SPEED-Bench configuration to prepare. + + Returns: + Path to the saved parquet file. + """ + instance = cls(config_name=config_name, _prepare_mode=True) + + # Persist to parquet + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / "test.parquet" + instance._resolved_dataset.to_parquet(output_path) + return output_path diff --git a/examples/specdec_bench/specdec_bench/metrics/__init__.py b/examples/specdec_bench/specdec_bench/metrics/__init__.py index b61616830c..1f6ac79fcb 100644 --- a/examples/specdec_bench/specdec_bench/metrics/__init__.py +++ b/examples/specdec_bench/specdec_bench/metrics/__init__.py @@ -15,6 +15,8 @@ from .aa_timing import AATiming from .acceptance_rate import AcceptanceRate -from .base import Metric from .mtbench import MTBench +from .specbench import SpecBench from .timing import Timing + +__all__ = ["AATiming", "AcceptanceRate", "MTBench", "SpecBench", "Timing"] diff --git a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py index 21af35112f..cce735d5f1 100644 --- a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py +++ b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py @@ -30,7 +30,7 @@ def __init__(self, base_tokenizer): raise ImportError( "Please install tiktoken to use the AATiming metric, or remove the metric from the run command" ) - self.enc = tiktoken.get_encoding("cl100k_base") + self.enc = tiktoken.get_encoding("o200k_base") self.base_tokenizer = base_tokenizer self.total_tokens = [] diff --git a/examples/specdec_bench/specdec_bench/metrics/specbench.py b/examples/specdec_bench/specdec_bench/metrics/specbench.py new file mode 100644 index 0000000000..32ab3d1c7a --- /dev/null +++ b/examples/specdec_bench/specdec_bench/metrics/specbench.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections import defaultdict +from itertools import chain +from pathlib import Path +from statistics import mean + +try: + import matplotlib.pyplot as plt + import pandas as pd + from rich.console import Console + from rich.table import Table + + not_installed = False +except ImportError: + not_installed = True + +from .acceptance_rate import AcceptanceRate + + +class SpecBench(AcceptanceRate): + def __init__(self, requests): + super().__init__() + if not_installed: + raise ImportError( + "Please install rich, matplotlib, seaborn, and pandas to use the SpecBench metric" + ) + self.requests = requests + + def process_final(self, text_outputs): + lengths = {} + self.out["Request_AR"] = {} + for request_id, request in enumerate(self.requests): + turns = self.prompt_ar[request_id].values() + assert len(turns) == len(request.turns), ( + f"Number of turns {len(turns)} does not match number of turns in request {len(request.turns)}" + ) + self.out["Request_AR"][request.question_id] = mean(list(chain(*turns))) + for turn in turns: + self._get_lengths(turn, lengths) + print(request.category, self.out["Request_AR"][request.question_id]) + per_category = defaultdict(list) + for request in self.requests: + per_category[request.category].append(self.out["Request_AR"][request.question_id]) + self.out["Category_AR"] = {} + for category_name, category_ar in per_category.items(): + if len(category_ar) > 0: + category_ar = mean(category_ar) + self.out["Category_AR"][category_name] = category_ar + average_ar = mean(self.out["Request_AR"].values()) + self.out["Average_AR"] = average_ar + self._process_lengths(lengths) + self.write() + self._format_write_output(text_outputs) + self._pretty_print_results() + self._dump_results() + self._create_visualizations(text_outputs) + + def _format_write_output(self, outputs): + with open(os.path.join(self.directory, "specbench_responses.jsonl"), "w") as outfile: + for i, messages in enumerate(outputs): + out_line = {} + out_line["question_id"] = self.requests[i].question_id + out_line["category"] = self.requests[i].category + q_turns = [c["content"] for c in messages if c["role"] == "user"] + a_turns = [c["content"] for c in messages if c["role"] == "assistant"] + out_line["turns"] = q_turns + out_line["choices"] = [{"index": 0, "turns": a_turns}] + json.dump(out_line, outfile) + outfile.write("\n") + + def _pretty_print_results(self): + # Create and display results table + console = Console() + table = Table( + title="Acceptance Rate Results", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Category", style="cyan", no_wrap=True) + table.add_column("Average AR", justify="right", style="green") + + # Add category rows + for category_name, category_ar in sorted(self.out["Category_AR"].items()): + table.add_row(category_name, f"{category_ar:.4f}") + + # Add separator and summary row + table.add_section() + table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AR']:.4f}[/bold]") + + console.print(table) + + def _dump_results(self): + with open(os.path.join(self.directory, "specbench_results.json"), "w") as outfile: + json.dump(self.out, outfile, indent=4) + + def _create_visualizations( + self, + text_outputs: list[list[dict[str, str]]], + title: str = "Speculative Decoding Acceptance Rate Analysis", + ): + """ + Create professional plots for acceptance rates. + Completely generated by Cursor. + """ + + # Set style + plt.style.use("seaborn-v0_8") + + df_clean = pd.DataFrame.from_dict( + { + "question_id": list(self.out["Request_AR"].keys()), + "acceptance_rate": list(self.out["Request_AR"].values()), + "category": [request.category for request in self.requests], + "response_length": [ + mean([len(c["content"]) for c in messages if c["role"] == "assistant"]) + for messages in text_outputs + ], + } + ) + + if len(df_clean) == 0: + print("Warning: No successful results to plot") + return + + # 1. Acceptance rate by category + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + fig.suptitle(title, fontsize=16, fontweight="bold") + + # Plot 1: Acceptance rate by category + ax1 = axes[0] + category_stats = ( + df_clean.groupby("category") + .agg({"acceptance_rate": ["mean", "std"], "question_id": "count"}) + .round(3) + ) + + categories = category_stats.index.tolist() + means = category_stats[("acceptance_rate", "mean")].values + stds = category_stats[("acceptance_rate", "std")].values + counts = category_stats[("question_id", "count")].values + + bars = ax1.bar(range(len(categories)), means, yerr=stds, capsize=5, alpha=0.8) + ax1.set_xlabel("Category") + ax1.set_ylabel("Acceptance Rate") + ax1.set_title("Acceptance Rate by Category") + ax1.set_xticks(range(len(categories))) + ax1.set_xticklabels(categories, rotation=45, ha="right") + + # Add count labels on bars + for i, (bar, count) in enumerate(zip(bars, counts)): + ax1.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.01, + f"n={count}", + ha="center", + va="bottom", + fontsize=8, + ) + + # Plot 2: Acceptance rate vs response length + ax2 = axes[1] + # Bin response lengths + df_clean["response_length_bin"] = pd.cut( + df_clean["response_length"], + bins=[0, 100, 300, 500, 1000, float("inf")], + labels=["0-100", "100-300", "300-500", "500-1000", "1000+"], + ) + + length_stats = ( + df_clean.groupby("response_length_bin") + .agg({"acceptance_rate": ["mean", "std"], "question_id": "count"}) + .round(3) + ) + + length_bins = length_stats.index.tolist() + length_means = length_stats[("acceptance_rate", "mean")].values + length_stds = length_stats[("acceptance_rate", "std")].values + length_counts = length_stats[("question_id", "count")].values + + bars2 = ax2.bar( + range(len(length_bins)), + length_means, + yerr=length_stds, + capsize=5, + alpha=0.8, + ) + ax2.set_xlabel("Response Length (characters)") + ax2.set_ylabel("Acceptance Rate") + ax2.set_title("Acceptance Rate by Response Length") + ax2.set_xticks(range(len(length_bins))) + ax2.set_xticklabels(length_bins) + + for i, (bar, count) in enumerate(zip(bars2, length_counts)): + ax2.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.01, + f"n={count}", + ha="center", + va="bottom", + fontsize=8, + ) + + # Plot 3: Distribution of acceptance rates + ax3 = axes[2] + ax3.hist(df_clean["acceptance_rate"], bins=20, alpha=0.7, edgecolor="black") + ax3.axvline( + df_clean["acceptance_rate"].mean(), + color="red", + linestyle="--", + label=f"Mean: {df_clean['acceptance_rate'].mean():.3f}", + ) + ax3.set_xlabel("Acceptance Rate") + ax3.set_ylabel("Frequency") + ax3.set_title("Distribution of Acceptance Rates") + ax3.legend() + + plt.tight_layout() + plot_path = Path(self.directory) / "acceptance_rate_analysis.png" + plt.savefig(plot_path, dpi=300, bbox_inches="tight") + plt.close() + print(f"Plots saved to {plot_path}") diff --git a/examples/specdec_bench/specdec_bench/metrics/timing.py b/examples/specdec_bench/specdec_bench/metrics/timing.py index 023aaf7853..5bf33c604e 100644 --- a/examples/specdec_bench/specdec_bench/metrics/timing.py +++ b/examples/specdec_bench/specdec_bench/metrics/timing.py @@ -53,6 +53,7 @@ def process_final(self, text_outputs): if tpot_time: self.out["Request Generation Step Time"] = compute_statistics(tpot_time) self.out["Request Generation Tokens Per Second"] = compute_statistics(gen_tp_time) + self.out["Number of Output Tokens"] = compute_statistics(self.total_tokens) for k, v in self.out.items(): print(k, v) self.write() diff --git a/examples/specdec_bench/specdec_bench/models/__init__.py b/examples/specdec_bench/specdec_bench/models/__init__.py index 5fa1260ab1..e103a9d922 100644 --- a/examples/specdec_bench/specdec_bench/models/__init__.py +++ b/examples/specdec_bench/specdec_bench/models/__init__.py @@ -13,7 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import Model +from .auto_deploy import AutoDeployModel from .sglang import SGLANGModel +from .specbench_medusa import SpecBenchMedusaModel from .trtllm_torch_api import TRTLLMPYTModel from .vllm import VLLMModel + +__all__ = [ + "AutoDeployModel", + "SGLANGModel", + "SpecBenchMedusaModel", + "TRTLLMPYTModel", + "VLLMModel", +] diff --git a/examples/specdec_bench/specdec_bench/models/auto_deploy.py b/examples/specdec_bench/specdec_bench/models/auto_deploy.py new file mode 100644 index 0000000000..bd030e783e --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/auto_deploy.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import itertools +import time +from typing import Any + +try: + from tensorrt_llm._torch.auto_deploy.llm import LLM + from tensorrt_llm.llmapi import DraftTargetDecodingConfig + from tensorrt_llm.sampling_params import SamplingParams +except ImportError: + print("tensorrt_llm._torch.auto_deploy is not installed.") + LLM = None + +from .base import Model + + +class AutoDeployModel(Model): + def __init__(self, model_path, max_concurrent_requests, sampling_kwargs, **kwargs): + self.model = create_auto_deploy_model(model_path, max_concurrent_requests, kwargs) + self.sampling_kwargs = sampling_kwargs + + async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + output_dict = {} + sampling_config = check_sampling_config(self.sampling_kwargs, max_length, end_id) + outputs = [] + timing = [time.perf_counter()] + beam_lens = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + + async for output in self.model.generate_async( + prompt_ids, + streaming=not sampling_config.use_beam_search, + sampling_params=sampling_config, + ): + for beam in output.outputs: + beam_lens[beam.index].append(len(beam.token_ids)) + outputs.append(output.outputs) + timing.append(time.perf_counter()) + + reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + for beam_idx, beam_len in enumerate(beam_lens): + response = outputs[-1][beam_idx] + if beam_len[0] != 0: + reformatted_output_ids[beam_idx].append(response.token_ids[: beam_len[0]]) + for s, e in itertools.pairwise(beam_len): + reformatted_output_ids[beam_idx].append(response.token_ids[s:e]) + if len(response.token_ids) > beam_len[-1]: + reformatted_output_ids[beam_idx].append(response.token_ids[beam_len[-1] :]) + + output_dict["output_ids"] = reformatted_output_ids + output_dict["output_logits"] = None + output_dict["token_times"] = timing + return output_dict + + def stop(self): + """Stop and cleanup the model.""" + if hasattr(self, "model") and self.model is not None: + with contextlib.suppress(Exception): + del self.model + + +def create_auto_deploy_model(model_path: str, max_concurrent_requests: int, kwargs: dict[str, Any]): + world_size = kwargs.get("world_size", kwargs.get("tensor_parallel_size", 1)) + + max_seq_len = kwargs.get("max_seq_len", 8192) + + kv_cache_config = { + "enable_block_reuse": kwargs.get("prefix_cache", False), + "free_gpu_memory_fraction": kwargs.get("free_gpu_memory_fraction", 0.75), + } + + specdec = None + speculative_algorithm = kwargs.get("speculative_algorithm") + + if speculative_algorithm == "DRAFT_TARGET": + specdec = DraftTargetDecodingConfig( + max_draft_len=kwargs.get("speculative_num_steps", 3), + speculative_model_dir=kwargs.get("draft_model_dir"), + ) + elif speculative_algorithm == "NONE": + specdec = None + + max_num_tokens = kwargs.get("max_num_tokens", 8192) + + llm_kwargs = { + "model": model_path, + "world_size": world_size, + "max_batch_size": max_concurrent_requests, + "max_seq_len": max_seq_len, + "max_num_tokens": max_num_tokens, + "skip_tokenizer_init": kwargs.get("skip_tokenizer_init", True), + "kv_cache_config": kv_cache_config, + "runtime": "trtllm", + "disable_overlap_scheduler": kwargs.get("disable_overlap_scheduler", True), + "speculative_config": specdec, + } + + if kwargs.get("attn_backend"): + llm_kwargs["attn_backend"] = kwargs["attn_backend"] + + if kwargs.get("compile_backend"): + llm_kwargs["compile_backend"] = kwargs["compile_backend"] + + # Optimization mode: "graph" uses full torch.export, "transformers" is simpler + # Default to "transformers" to avoid torch.export dimension specialization issues + llm_kwargs["mode"] = kwargs.get("mode", "transformers") + + if kwargs.get("cuda_graph_batch_sizes"): + llm_kwargs["cuda_graph_batch_sizes"] = kwargs["cuda_graph_batch_sizes"] + + model = LLM(**llm_kwargs) + return model + + +def check_sampling_config(sampling_config: dict[str, Any], max_length: int, end_id: int): + return SamplingParams( + use_beam_search=sampling_config.get("beam_width", 1) > 1, + n=sampling_config.get("beam_width", 1), + top_k=sampling_config.get("top_k"), + top_p=sampling_config.get("top_p"), + seed=sampling_config.get("seed"), + temperature=sampling_config.get("temperature", 1.0), + max_tokens=max_length, + end_id=end_id, + detokenize=False, + ) diff --git a/examples/specdec_bench/specdec_bench/models/base.py b/examples/specdec_bench/specdec_bench/models/base.py index 42186fef05..ab26a4704d 100644 --- a/examples/specdec_bench/specdec_bench/models/base.py +++ b/examples/specdec_bench/specdec_bench/models/base.py @@ -18,7 +18,7 @@ class Model: def __init__(self, model_dir, tokenizer, max_draft_length): raise NotImplementedError - async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + async def run(self, prompt_ids, sampling_params, request_id, turn_id): """ prompt_ids is list of tokens output is list of list of tokens diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py index 4840a0eda3..d5ff890ffd 100644 --- a/examples/specdec_bench/specdec_bench/models/sglang.py +++ b/examples/specdec_bench/specdec_bench/models/sglang.py @@ -27,7 +27,12 @@ class SGLANGModel(Model): def __init__( - self, model_dir, max_concurrent_requests, sampling_kwargs, use_draft_logits=False, **kwargs + self, + model_dir, + max_concurrent_requests, + sampling_kwargs, + use_draft_logits=False, + **kwargs, ): speculative_algorithm = kwargs.get("speculative_algorithm") if speculative_algorithm == "MTP": @@ -43,35 +48,44 @@ def __init__( self.model = sgl.Engine( model_path=model_dir, skip_tokenizer_init=True, - mem_fraction_static=0.7, - disable_overlap_schedule=kwargs.get("disable_overlap_schedule", True), + trust_remote_code=kwargs.get("trust_remote_code", False), + mem_fraction_static=0.8, + disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False), tp_size=kwargs.get("tensor_parallel_size", 1), + ep_size=kwargs.get("moe_expert_parallel_size", 1), speculative_algorithm=speculative_algorithm, speculative_num_steps=kwargs.get("speculative_num_steps", 3), speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1), speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4), speculative_draft_model_path=kwargs.get("draft_model_dir"), torch_compile_max_bs=max_concurrent_requests, + max_running_requests=max_concurrent_requests, attention_backend=kwargs.get("attention_backend"), enable_torch_compile=kwargs.get("enable_torch_compile", False), cuda_graph_max_bs=max_concurrent_requests, + disable_cuda_graph=False, ) else: self.model = sgl.Engine( model_path=model_dir, skip_tokenizer_init=True, - mem_fraction_static=0.7, - disable_overlap_schedule=kwargs.get("disable_overlap_schedule", True), + trust_remote_code=kwargs.get("trust_remote_code", False), + mem_fraction_static=0.8, + disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False), tp_size=kwargs.get("tensor_parallel_size", 1), + ep_size=kwargs.get("moe_expert_parallel_size", 1), torch_compile_max_bs=max_concurrent_requests, + max_running_requests=max_concurrent_requests, attention_backend=kwargs.get("attention_backend"), enable_torch_compile=kwargs.get("enable_torch_compile", False), cuda_graph_max_bs=max_concurrent_requests, + disable_cuda_graph=False, ) self.sampling_config = sampling_kwargs async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + """Synchronous version of run for use with asyncio.to_thread""" timing = [] output_dict = {} self.sampling_config["max_new_tokens"] = max_length @@ -79,7 +93,7 @@ async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): timing.append(time.perf_counter()) assert self.sampling_config.get("beam_width", 1) == 1 beam_lens = [[] for _ in range(self.sampling_config.get("beam_width", 1))] - outputs = [] + outputs = [None] result = await self.model.async_generate( sampling_params=self.sampling_config, input_ids=prompt_ids, stream=True ) diff --git a/examples/specdec_bench/specdec_bench/models/specbench_medusa.py b/examples/specdec_bench/specdec_bench/models/specbench_medusa.py new file mode 100644 index 0000000000..e483f379c3 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/specbench_medusa.py @@ -0,0 +1,284 @@ +# Adapted from https://github.com/hemingkx/Spec-Bench/tree/66230f10cb0a02aced5ef3ce1e85163c16160454/model/medusa +# SPDX-FileCopyrightText: Copyright (c) 2024 Heming Xia. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import sys +import time + +import torch + +from .base import Model + +# Medusa dependencies from Spec-Bench +try: + spec_bench_path = os.path.join(os.getcwd(), "Spec-Bench") + sys.path.insert(0, spec_bench_path) + from model.medusa.kv_cache import initialize_past_key_values + from model.medusa.medusa_choices import mc_sim_7b_63 + from model.medusa.medusa_model import MedusaModel + from model.medusa.utils import ( + evaluate_posterior, + generate_candidates, + generate_medusa_buffers, + initialize_medusa, + reset_medusa_mode, + tree_decoding, + update_inference_inputs, + ) +except ImportError as e: + print(f"Medusa dependencies not found: {e}") + MedusaModel = None + + +class SpecBenchMedusaModel(Model): + def __init__( + self, + model_dir, + max_concurrent_requests, + sampling_kwargs, + use_draft_logits=False, + **kwargs, + ): + if MedusaModel is None: + raise ImportError( + "Medusa dependencies not found. Please ensure Spec-Bench is available." + ) + assert max_concurrent_requests == 1, "Only support batch size 1 for now!" + self.medusa_num_heads = kwargs.get("medusa_num_heads", 4) + self.draft_model_path = kwargs.get("draft_model_dir") + self.dtype = kwargs.get("dtype", "float16") + self.max_steps = kwargs.get("max_steps", 512) + + # Medusa decoding parameters + self.temperature = sampling_kwargs.get("temperature", 0.0) + self.posterior_threshold = kwargs.get("posterior_threshold", 0.09) + self.posterior_alpha = kwargs.get("posterior_alpha", 0.3) + self.medusa_choices = kwargs.get("medusa_choices", mc_sim_7b_63) + + # Convert dtype string to torch dtype + dtype_map = { + "float32": torch.float32, + "float64": torch.float64, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = dtype_map.get(self.dtype, torch.float16) + + # Load the Medusa model + # Use single GPU to avoid device mismatch issues with device_map="auto" + self.device = torch.device(kwargs.get("device", "cuda:0")) + self.model = MedusaModel.from_pretrained( + self.draft_model_path, + model_dir, + medusa_num_heads=self.medusa_num_heads, + torch_dtype=torch_dtype, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(self.device) + + self.sampling_kwargs = sampling_kwargs + + def _medusa_forward(self, input_ids, max_new_tokens, end_id): + """ + Run Medusa speculative decoding forward pass. + + Returns: + tuple: (output_ids, new_token_count, num_steps, accept_length_list, timing) + """ + # Avoid modifying the input_ids in-place + accept_length_list = [] + input_ids = input_ids.clone() + timing = [time.perf_counter()] + + # Cache medusa buffers (the fixed patterns for tree attention) + if ( + hasattr(self.model, "medusa_choices") + and self.model.medusa_choices == self.medusa_choices + ): + medusa_buffers = self.model.medusa_buffers + else: + medusa_buffers = generate_medusa_buffers(self.medusa_choices, device=self.device) + self.model.medusa_buffers = medusa_buffers + self.model.medusa_choices = self.medusa_choices + + # Initialize the past key and value states + if hasattr(self.model, "past_key_values"): + past_key_values = self.model.past_key_values + past_key_values_data = self.model.past_key_values_data + current_length_data = self.model.current_length_data + current_length_data.zero_() + else: + ( + past_key_values, + past_key_values_data, + current_length_data, + ) = initialize_past_key_values(self.model.base_model) + self.model.past_key_values = past_key_values + self.model.past_key_values_data = past_key_values_data + self.model.current_length_data = current_length_data + + input_len = input_ids.shape[1] + cur_length = input_len + reset_medusa_mode(self.model) + medusa_logits, logits = initialize_medusa( + input_ids, self.model, medusa_buffers["medusa_attn_mask"], past_key_values + ) + new_token = 0 + + for idx in range(self.max_steps): + candidates, tree_candidates = generate_candidates( + medusa_logits, + logits, + medusa_buffers["tree_indices"], + medusa_buffers["retrieve_indices"], + ) + medusa_logits, logits, outputs = tree_decoding( + self.model, + tree_candidates, + past_key_values, + medusa_buffers["medusa_position_ids"], + input_ids, + medusa_buffers["retrieve_indices"], + ) + best_candidate, accept_length = evaluate_posterior( + logits, + candidates, + self.temperature, + self.posterior_threshold, + self.posterior_alpha, + ) + input_ids, logits, medusa_logits, new_token = update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + medusa_buffers["retrieve_indices"], + outputs, + logits, + medusa_logits, + new_token, + past_key_values_data, + current_length_data, + ) + accept_length_tree = input_ids.shape[1] - cur_length + cur_length = accept_length_tree + cur_length + accept_length_list.append(accept_length_tree) + timing.append(time.perf_counter()) + + if end_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + + return input_ids, new_token, idx + 1, accept_length_list, timing + + async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + """ + Run inference on the given prompt. + + Args: + prompt_ids: List of input token IDs + max_length: Maximum number of new tokens to generate + end_id: End of sequence token ID + request_id: Request identifier + turn_id: Turn identifier + + Returns: + dict with output_ids, output_logits, and token_times + """ + output_dict = {} + + # Convert prompt_ids to tensor + input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=self.device) + + # Run medusa forward pass (synchronously, but wrapped for async interface) + ( + input_ids_out, + new_token, + num_steps, + accept_length_list, + timing, + ) = await asyncio.to_thread(self._medusa_forward, input_ids, max_length, end_id) + + # Extract generated tokens (excluding the prompt) + original_len = len(prompt_ids) + generated_tokens = input_ids_out[0, original_len:].tolist() + + # Remove EOS token from output if present + if end_id in generated_tokens: + eos_idx = generated_tokens.index(end_id) + generated_tokens = generated_tokens[:eos_idx] + # Also adjust accept_length_list and timing + # Count how many tokens we're removing + tokens_to_remove = len(input_ids_out[0, original_len:].tolist()) - len(generated_tokens) + if tokens_to_remove > 0 and len(accept_length_list) > 0: + # Adjust the last accept length + accept_length_list[-1] = max(0, accept_length_list[-1] - tokens_to_remove) + if accept_length_list[-1] == 0: + accept_length_list.pop() + if len(timing) > 1: + timing.pop() + + # Format output_ids as list of list of tokens per step (for beam_width=1) + reformatted_output_ids = [[]] + start = 0 + for accept_len in accept_length_list: + if accept_len > 0: + reformatted_output_ids[0].append(generated_tokens[start : start + accept_len]) + start += accept_len + + # Handle any remaining tokens + if start < len(generated_tokens): + reformatted_output_ids[0].append(generated_tokens[start:]) + + output_dict["output_ids"] = reformatted_output_ids + output_dict["output_logits"] = None + output_dict["token_times"] = timing + + return output_dict + + def stop(self): + """Cleanup resources.""" + # Clear cached KV states to free memory + if hasattr(self.model, "past_key_values"): + del self.model.past_key_values + del self.model.past_key_values_data + del self.model.current_length_data + + # Clear medusa buffers + if hasattr(self.model, "medusa_buffers"): + del self.model.medusa_buffers + + # Move model to CPU or delete to free GPU memory + if hasattr(self, "model") and self.model is not None: + del self.model + torch.cuda.empty_cache() diff --git a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py index 11ceeb2071..25a2aed632 100644 --- a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py +++ b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py @@ -38,7 +38,12 @@ class TRTLLMPYTModel(Model): def __init__( - self, model_path, max_concurrent_requests, sampling_kwargs, use_draft_logits=False, **kwargs + self, + model_path, + max_concurrent_requests, + sampling_kwargs, + use_draft_logits=False, + **kwargs, ): self.model = create_executor(model_path, max_concurrent_requests, kwargs) self.sampling_kwargs = sampling_kwargs @@ -80,16 +85,23 @@ def create_executor(model_path: str, max_concurrent_requests, kwargs): max_draft_len=kwargs.get("speculative_num_steps", 3), speculative_model_dir=kwargs.get("draft_model_dir", None), ) - disable_overlap_schedule = True elif kwargs.get("speculative_algorithm", None) == "EAGLE3": + extra_params = {} + if "allow_advanced_sampling" in EagleDecodingConfig.model_fields: + extra_params["allow_advanced_sampling"] = kwargs.get("allow_advanced_sampling", False) + elif "allow_advanced_sampling" in kwargs: + print( + f"WARNING: allow_advanced_sampling unsupported in tensorrt_llm version: {trtllm.__version__}" + ) specdec = EagleDecodingConfig( max_draft_len=kwargs.get("speculative_num_steps", 3), speculative_model_dir=kwargs.get("draft_model_dir", None), eagle3_one_model=kwargs.get("use_one_model", True), eagle3_layers_to_capture=kwargs.get("eagle3_layers_to_capture", None), + num_eagle_layers=kwargs.get("num_eagle_layers", 1), + **extra_params, ) - disable_overlap_schedule = not kwargs.get("use_one_model", True) elif kwargs.get("speculative_algorithm", None) == "MTP": specdec = MTPDecodingConfig( @@ -127,13 +139,15 @@ def create_executor(model_path: str, max_concurrent_requests, kwargs): moe_expert_parallel_size=kwargs.get("moe_expert_parallel_size", 2), disable_overlap_scheduler=disable_overlap_schedule, cuda_graph_config=cuda_graph_config, - enable_chunked_prefill=kwargs.get("enable_chunked_prefill", False), + enable_chunked_prefill=kwargs.get("enable_chunked_prefill", True), kv_cache_config=kv_cache_config, speculative_config=specdec, enable_attention_dp=kwargs.get("enable_attention_dp", False), max_batch_size=max_concurrent_requests, moe_config=MoeConfig(backend=kwargs.get("moe_backend", "TRTLLM")), sampler_type="TorchSampler", + max_seq_len=kwargs.get("max_seq_len", None), + max_num_tokens=kwargs.get("max_num_tokens", 8192), ) return model diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py index deb79ed89e..0ea6c6c4e5 100644 --- a/examples/specdec_bench/specdec_bench/models/vllm.py +++ b/examples/specdec_bench/specdec_bench/models/vllm.py @@ -51,10 +51,13 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs } elif kwargs.get("speculative_algorithm") == "DRAFT_TARGET": specdec = { - "method": "draft_target", + "method": "draft_model", "model": kwargs.get("draft_model_dir"), "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), } + if kwargs.get("parallel_draft_block_sizes") is not None: + specdec["disable_padded_drafter_batch"] = True + specdec["parallel_draft_block_sizes"] = kwargs.get("parallel_draft_block_sizes") elif kwargs.get("speculative_algorithm") == "MTP": specdec = { "method": "mtp", @@ -62,15 +65,22 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs } elif kwargs.get("speculative_algorithm") == "NONE": specdec = None + + if specdec is None: + num_speculative_tokens = 1 + else: + num_speculative_tokens = specdec.get("num_speculative_tokens", 3) engine_args = AsyncEngineArgs( model=model_dir, - trust_remote_code=True, + trust_remote_code=kwargs.get("trust_remote_code", False), tensor_parallel_size=kwargs.get("tensor_parallel_size", 1), enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1, enable_prefix_caching=kwargs.get("prefix_cache", False), speculative_config=specdec, - max_num_seqs=max_concurrent_requests, + max_num_seqs=max_concurrent_requests * num_speculative_tokens, skip_tokenizer_init=False, + async_scheduling=kwargs.get("async_scheduling", True), + enforce_eager=False, ) self.model = AsyncLLM.from_engine_args(engine_args) self.sampling_kwargs = sampling_kwargs @@ -88,6 +98,8 @@ async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): output_dict = {} self.sampling_config.max_tokens = max_length self.sampling_config.stop_token_ids = [end_id] + if end_id == -1: + self.sampling_config.ignore_eos = True outputs, timing, full_tokens = await self.generate(prompt_ids, request_id, turn_id) diff --git a/examples/specdec_bench/specdec_bench/runners/__init__.py b/examples/specdec_bench/specdec_bench/runners/__init__.py index 61a85c769c..17832bb997 100644 --- a/examples/specdec_bench/specdec_bench/runners/__init__.py +++ b/examples/specdec_bench/specdec_bench/runners/__init__.py @@ -13,5 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import BaseRunner from .simple import SimpleRunner + +__all__ = ["SimpleRunner"] diff --git a/examples/specdec_bench/specdec_bench/runners/base.py b/examples/specdec_bench/specdec_bench/runners/base.py index c481a0fd09..ee9062e39f 100644 --- a/examples/specdec_bench/specdec_bench/runners/base.py +++ b/examples/specdec_bench/specdec_bench/runners/base.py @@ -21,7 +21,7 @@ def __init__(self, model, metrics): self.metrics = metrics self.prompt_ar = [] - async def run(self, prompt_ids, max_length, end_id, request_id, turn_id): + async def run(self, prompt_ids, max_length, end_id, sampling_kwargs): raise NotImplementedError() def process_metrics_final(self, text_outputs): diff --git a/examples/specdec_bench/specdec_bench/utils.py b/examples/specdec_bench/specdec_bench/utils.py index d605f0b4b9..14ded0f31b 100644 --- a/examples/specdec_bench/specdec_bench/utils.py +++ b/examples/specdec_bench/specdec_bench/utils.py @@ -18,13 +18,17 @@ from transformers import AutoTokenizer -def get_tokenizer(path): - return AutoTokenizer.from_pretrained(path) +def get_tokenizer(path, trust_remote_code=False): + return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code) -def encode_chat(tokenizer, messages): +def encode_chat(tokenizer, messages, chat_template_args={}, completions=False): + if completions: + return tokenizer.encode(messages[-1]["content"], add_special_tokens=False) return tokenizer.encode( - tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True), + tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, **chat_template_args + ), add_special_tokens=False, ) @@ -46,4 +50,11 @@ def postprocess_base(text): def postprocess_gptoss(text): - return text.split("<|channel|>final<|message|>")[-1] + final_message = text.split("<|channel|>final<|message|>")[-1] + if "<|end|>" in final_message: + final_message = final_message.split("<|end|>")[0] + if "<|return|>" in final_message: + final_message = final_message.split("<|return|>")[0] + if "<|channel|>" in final_message: + final_message = final_message.split("<|channel|>")[0] + return final_message diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index c495809bb9..7e9c855cb1 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM, ### Docker -Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information. +Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information. Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies. @@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst ## Getting Started: Simplified Workflow ```bash -bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4 +bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct ``` This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it @@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m ./launch_train.sh --model $BASE_MODEL \ --output_dir $OUTPUT_DIR \ --data input_conversations/daring-anteater.jsonl \ - --num_gpu $NUM_GPU \ --num_epochs $NUM_EPOCH \ --eagle_config eagle_config.json ``` -This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details. +FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. ## Training Draft Model with Offline Base Model @@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o ./launch_train.sh --model $BASE_MODEL \ --output_dir $OUTPUT_DIR \ --data $DATA \ - --num_gpu $NUM_GPU \ --num_epochs $NUM_EPOCH \ --eagle_config eagle_config.json \ --offline-data $HIDDEN_STATES_DIR @@ -244,6 +242,17 @@ To add a system prompt, use the `--system_prompt ` argument. For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support. +### Configuring Draft Model + +For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. E.g. To use 2-layer eagle with 8192 intermediate size for MLP, set `eagle_config.json` to: + +```json +{ + "num_hidden_layers": 2, + "intermediate_size":8192 +} +``` + ### Draft Vocabulary Compression We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set: @@ -254,15 +263,7 @@ python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. -### Configuring Draft Model - -For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`: - -```json -{ - "draft_vocab_size": 32000 -} -``` +Then, simply set `{"draft_vocab_size":32000}` in `eagle_config.json` and include `--draft_vocab_cache ` when running `./launch_train.sh`. The draft model will use this provided vocab table during training and export. ### Interact with `modelopt.torch.speculative` diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py index f9818e4642..a3d1681c4c 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py @@ -17,10 +17,10 @@ import argparse import asyncio -import json from pathlib import Path import torch +from datasets import load_dataset from tqdm import tqdm as tqdm from transformers import AutoModel, AutoTokenizer @@ -54,12 +54,10 @@ def parse_args() -> argparse.Namespace: ## I/O Parameters ## parser.add_argument( - "--input-file", + "--input-data", type=Path, required=True, - help="""Path to the input `jsonl` file containing conversations. - Each entry must have a unique `conversation_id` field and a `conversations` field - containing a list of messages.""", + help="""Path to the `jsonl` file or directory containing `jsonl` files.""", ) parser.add_argument( "--output-dir", @@ -75,21 +73,68 @@ def parse_args() -> argparse.Namespace: help="""For debugging purposes, limit the number of conversations processed. Default is None, meaning no limit.""", ) + parser.add_argument( + "--dp-rank", + type=int, + default=0, + help="""Data parallel rank. TASK_ID on SLURM.""", + ) + parser.add_argument( + "--dp-world-size", + type=int, + default=1, + help="""Data parallel world size. Number of tasks on SLURM.""", + ) return parser.parse_args() -async def main(args: argparse.Namespace) -> None: - all_conversations = [] - with args.input_file.open("r", encoding="utf-8") as f: - all_conversations.extend([json.loads(line) for line in f if line.strip()]) +def main(args: argparse.Namespace) -> None: + # Load conversations + if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"): + dataset = load_dataset("json", data_files=str(args.input_data), split="train") + elif args.input_data.is_dir(): + dataset = load_dataset( + "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train" + ) + else: + raise ValueError( + f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}" + ) + print(f"Loaded {len(dataset)} conversations from {args.input_data}") - print("Loaded", len(all_conversations), "conversations from", args.input_file) + # Shard data + if args.dp_world_size > 1: + dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank) + print( + f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}" + ) + + # Remove already dumped conversations + def keep_conversation(entry): + conversation_id = entry.get("conversation_id", entry.get("uuid", None)) + assert conversation_id is not None, "conversation_id is required" + output_file = args.output_dir / f"{conversation_id}.pt" + return not output_file.exists() + + original_num = len(dataset) + dataset = dataset.filter(keep_conversation) + print( + "Removed", + original_num - len(dataset), + "conversations due to existing output files", + ) - model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto") + # For debugging + if args.debug_max_num_conversations is not None: + dataset = dataset.select(range(args.debug_max_num_conversations)) + + model = AutoModel.from_pretrained( + args.model, torch_dtype="auto", device_map="auto", trust_remote_code=True + ) num_hidden_layers = getattr(model.config, "num_hidden_layers", None) - tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") @@ -99,30 +144,11 @@ async def main(args: argparse.Namespace) -> None: num_skipped_too_long = 0 num_invalid = 0 num_success = 0 - num_total_conversations = min( - len(all_conversations), args.debug_max_num_conversations or len(all_conversations) - ) - for idx, entry in enumerate( - tqdm( - all_conversations[: args.debug_max_num_conversations], - desc="Processing conversations", - total=num_total_conversations, - ) - ): - conversation_id = entry.get("conversation_id", "{:08d}".format(idx)) - conversations = entry["conversations"] - if not conversations or not isinstance(conversations, list): - num_invalid += 1 - continue - - # Tokenize and check length - input_ids = tokenizer.apply_chat_template( - conversations, return_tensors="pt", add_generation_template=False - ) - num_input_tokens = input_ids.shape[1] - if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: - num_skipped_too_long += 1 - continue + pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations") + + async def dump_hidden_states(idx: int, conversation_id: int, input_ids: torch.Tensor): + nonlocal num_success + nonlocal num_hidden_layers # Get hidden states with torch.inference_mode(): @@ -144,9 +170,9 @@ async def main(args: argparse.Namespace) -> None: aux_hidden_states = torch.cat( [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1 ) - output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu() + output_hidden_states = hidden_states[-1].squeeze(0).cpu() output_file = output_dir / f"{conversation_id}.pt" - num_success += 1 + with open(output_file, "wb") as f: torch.save( { @@ -158,19 +184,49 @@ async def main(args: argparse.Namespace) -> None: f, ) + num_success += 1 + pbar.update(1) + + async def submit_generates(): + nonlocal num_skipped_too_long + nonlocal num_invalid + tasks = [] + idx = 0 + for entry in dataset: + conversation_id = entry.get("conversation_id", entry.get("uuid")) + + conversations = entry["conversations"] + if not conversations or not isinstance(conversations, list): + num_invalid += 1 + continue + + # Tokenize and check length + input_ids = tokenizer.apply_chat_template( + conversations, return_tensors="pt", add_generation_template=False + )["input_ids"] + num_input_tokens = input_ids.shape[1] + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_skipped_too_long += 1 + continue + + tasks.append(dump_hidden_states(idx, conversation_id, input_ids)) + # Increment only for valid conversations to match dump file index + idx += 1 + await asyncio.gather(*tasks) + + asyncio.run(submit_generates()) + if num_skipped_too_long > 0: print(f"Skipped {num_skipped_too_long} conversations due to length constraints.") if num_invalid > 0: print(f"Skipped {num_invalid} invalid conversations without proper fields.") - if num_success == num_total_conversations: + if num_success == len(dataset): print(f"Successfully processed all {num_success} conversations.") else: - print( - f"Successfully processed {num_success} out of {num_total_conversations} conversations." - ) + print(f"Successfully processed {num_success} out of {len(dataset)} conversations.") if __name__ == "__main__": cli_args = parse_args() - asyncio.run(main(cli_args)) + main(cli_args) diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh index 48d12aeb2d..debbe68814 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh @@ -19,5 +19,5 @@ python3 collect_hidden_states/compute_hidden_states_hf.py \ --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/daring-anteater.jsonl \ + --input-data synthetic_conversations/daring-anteater.jsonl \ --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh index 31e2294d9b..dac0ab9a91 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh @@ -30,7 +30,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI for i in $(seq 0 $((DP_SIZE-1))) do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & +CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & done wait diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh index 487d0d69dc..75a27deb62 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh @@ -20,6 +20,6 @@ export TLLM_LOG_LEVEL="error"; python3 collect_hidden_states/compute_hidden_states_trtllm.py \ --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/daring-anteater.jsonl \ + --input-data synthetic_conversations/daring-anteater.jsonl \ --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh index 4b0fd10605..d06cfc0613 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh @@ -33,7 +33,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI for i in $(seq 0 $((DP_SIZE-1))) do -export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & +export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & done wait diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 45c9c66321..3ef7156372 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -13,23 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import os +import inspect +from collections.abc import Callable from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import FrameType from typing import Any import numpy as np import torch import transformers from datasets import load_dataset -from PIL import Image +from packaging.version import Version from scripts.ar_validate import validate_ar from torch.utils.data import Dataset -from transformers import AutoProcessor, Trainer, TrainerCallback +from transformers import Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother +import modelopt +from modelopt.torch.speculative.utils import get_ttt_msk_func from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master +from modelopt.torch.utils.plugins.transformers_dataset import ( + LanguageDataCollator, + ShardedDataset, + VisionLanguageDataCollator, +) try: import wandb @@ -38,459 +49,124 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index -REMOVE_THINK_CHAT_TEMPLATE = ( - "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" -) - - -def preprocess(examples, tokenizer, **kwargs): - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") - new_examples = { - "input_ids": [], - "attention_mask": [], - "loss_mask": [], - "labels": [], - } - for i in range(len(examples)): - messages = [] - source = examples[i]["conversations"] - - # Detect format: either role/content or from/value - def get_role_content(item): - if "role" in item and "content" in item: - return item["role"], item["content"] - elif "from" in item and "value" in item: - return item["from"], item["value"] - else: - raise ValueError(f"Unknown conversation format: {item}") - - for sentence in source: - role, content = get_role_content(sentence) - messages.append({"role": role.lower(), "content": content}) - conversation = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - - output = tokenizer( - conversation, - return_tensors="pt", - add_special_tokens=False, - truncation=True, - ) - input_ids = output.input_ids[0] - attention_mask = output.attention_mask[0] - loss_mask = torch.ones_like(input_ids) - labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)]) - new_examples["input_ids"].append(input_ids) - new_examples["attention_mask"].append(attention_mask) - new_examples["loss_mask"].append(loss_mask) - new_examples["labels"].append(labels) - - return new_examples - - -def preprocess_vlm(examples, tokenizer, processor, img_dir): - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") - new_examples = { - "input_ids": [], - "attention_mask": [], - "loss_mask": [], - "labels": [], - "pixel_values": [], - "image_flags": [], - } - for i in range(len(examples)): - messages = [] - source = examples[i]["conversations"] - - # Detect format: either role/content or from/value - def get_role_content(item): - if "role" in item and "content" in item: - return item["role"], item["content"] - elif "from" in item and "value" in item: - return item["from"], item["value"] - else: - raise ValueError(f"Unknown conversation format: {item}") - - # align role to user-assistant format - def convert_role(role): - role_map = { - "human": "user", - "gpt": "assistant", - } - return role_map[role.lower()] if role.lower() in role_map else role.lower() - - for sentence in source: - role, content = get_role_content(sentence) - new_role = convert_role(role) - messages.append({"role": new_role, "content": content}) - conversation = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - - img_filename = os.path.join(img_dir, examples[i]["image"]) - img = Image.open(img_filename) - output = processor(images=img, text=conversation, return_tensors="pt") - input_ids = output.input_ids[0] - attention_mask = output.attention_mask[0] - loss_mask = torch.ones_like(input_ids) - labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)]) - # TODO: add labels and answer-only loss masking? - - new_examples["input_ids"].append(input_ids) - new_examples["attention_mask"].append(attention_mask) - new_examples["loss_mask"].append(loss_mask) - new_examples["labels"].append(labels) - new_examples["pixel_values"].append(output.pixel_values) - new_examples["image_flags"].append( - torch.ones((output.pixel_values.shape[0],), dtype=torch.int64) - ) - return new_examples +class OfflineSupervisedDataset(Dataset): + """Offline dataset for supervised fine-tuning. -class SupervisedDataset(Dataset): - """Dataset for supervised fine-tuning. + This dataset loads data on-the-fly from pre-processed .pt data files. Args: - raw_data (list): A list of raw data examples. - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. + dumped_files (list): A list of file paths to the dumped .pt files. """ def __init__( self, - raw_data, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, + dumped_files, ): super().__init__() - - print_rank_0("Formatting inputs...") - sources = raw_data - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess - self.data_dict = self.preprocess_fn( - sources, tokenizer, processor=vlm_processor, img_dir=img_dir - ) + self.dumped_files = dumped_files def __len__(self): - return len(self.data_dict["input_ids"]) + return len(self.dumped_files) def __getitem__(self, i) -> dict[str, torch.Tensor]: - return {k: self.data_dict[k][i] for k in self.data_dict} - + offline_data = torch.load(self.dumped_files[i]) -class LazySupervisedDataset(Dataset): - """Lazy dataset for supervised fine-tuning. - - This dataset loads data on-the-fly when requested, which can be memory-efficient but slower. - - Args: - raw_data (list): A list of raw data examples. - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - """ - - def __init__( - self, - raw_data, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, - ): - super().__init__() - print_rank_0("Formatting inputs...Skip in lazy mode") - self.tokenizer = tokenizer - self.raw_data = raw_data - self.cached_data_dict = {} - self.vlm_processor = vlm_processor - self.img_dir = img_dir - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess - - def __len__(self): - return len(self.raw_data) - - def __getitem__(self, i) -> dict[str, torch.Tensor]: - if i in self.cached_data_dict: - return self.cached_data_dict[i] - ret = self.preprocess_fn( - [self.raw_data[i]], self.tokenizer, processor=self.vlm_processor, img_dir=self.img_dir - ) - ret = {k: ret[k][0] for k in ret} - self.cached_data_dict[i] = ret + labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID) + labels[..., :-1] = offline_data["input_ids"][..., 1:] + ret = { + "input_ids": offline_data["input_ids"], + "base_model_hidden_states": offline_data["hidden_states"], + "aux_hidden_states": offline_data["aux_hidden_states"], + "attention_mask": torch.ones_like(offline_data["input_ids"]), + "loss_mask": torch.ones_like(offline_data["input_ids"]), + "labels": labels, + } return ret -class OfflineSupervisedDataset(Dataset): - """Lazy offline dataset for supervised fine-tuning. +class EagleOfflineDataCollator: + """Data collator that truncate or pads data for offline training.""" - This dataset loads data on-the-fly from pre-processed .pt data files as well as - input conversations in JSON format. + def __init__(self, train_len): + self.train_len = train_len - Args: - data_entries (list): A list of tuples (raw_data_example, file_path). - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - """ + def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0): + """Pad or truncate a tensor to length along a given dimension.""" + dim = dim % x.ndim # support negative dimension - def __init__( - self, - data_entries, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, - ): - super().__init__() - print_rank_0("Formatting inputs...Skip in offline mode") - self.tokenizer = tokenizer - self.data_entries = data_entries - self.vlm_processor = vlm_processor - self.img_dir = img_dir - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess + # allocate output tensor + out_shape = list(x.shape) + out_shape[dim] = length + out = x.new_zeros(out_shape) - # Does not cache the hidden states, as those have an extremely large memory footprint. - self.cached_data_dict = {} + # consturct copy slice + slc = [slice(None)] * x.ndim + slc[dim] = slice(0, min(length, x.size(dim))) - def __len__(self): - return len(self.data_entries) + # populate output tensor + out[tuple(slc)] = x[tuple(slc)] + return out - def __getitem__(self, i) -> dict[str, torch.Tensor]: - # Load the conversational data, using the cache - raw_data, offline_file_path = self.data_entries[i] - # Extend the data sample with the hidden states from the .pt file - max_length = self.tokenizer.model_max_length - offline_data = torch.load(offline_file_path) - offline_data["input_ids"] = offline_data["input_ids"][:max_length] - offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] - offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + base_batch = { + k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + for k in ["input_ids", "attention_mask", "loss_mask", "labels"] + } - ret = { - "input_ids": offline_data["input_ids"], - "attention_mask": torch.ones_like(offline_data["input_ids"]), - "loss_mask": torch.ones_like(offline_data["input_ids"]), - "labels": torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID), - "kwargs": { - "base_model_outputs": { - "base_model_hidden_states": offline_data["hidden_states"], - "aux_hidden_states": offline_data["aux_hidden_states"], - } - }, + base_model_outputs = { + k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + for k in ["base_model_hidden_states", "aux_hidden_states"] } - return ret + + batch = { + **base_batch, + "base_model_outputs": base_model_outputs, + } + return batch def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, - max_length=None, + train_len=None, ) -> dict: - """Make dataset and collator for supervised fine-tuning. - - Args: - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - data_args: Data arguments. + if data_args.offline_data_path is None: + train_dataset = ShardedDataset("json", data_files=data_args.data_path) + + if not data_args.vlm_processor: + data_collator = LanguageDataCollator( + tokenizer=tokenizer, + train_len=train_len, + return_labels=True, + ) + else: + data_collator = VisionLanguageDataCollator( + processor=data_args.vlm_processor, + train_len=train_len, + local_image_path=data_args.vlm_img_dir, + return_labels=True, + ) - Returns: - dict: A dictionary containing train and eval datasets. - """ - if data_args.vlm_processor: - vlm_processor = AutoProcessor.from_pretrained( - data_args.vlm_processor, trust_remote_code=True, use_fast=True - ) - vlm_img_dir = data_args.vlm_img_dir else: - vlm_processor, vlm_img_dir = None, None - # Load the conversations from the source file - print_rank_0("Loading input conversations...") - data_json = [] - data_path_p = Path(data_args.data_path) - if data_path_p.is_dir(): - # Load all .jsonl files in the directory and combine them - for jsonl_file in sorted(data_path_p.glob("*.jsonl")): - with open(jsonl_file) as f: - data_json.extend(json.loads(line) for line in f) - else: - with open(data_args.data_path) as f: - if data_args.data_path.endswith("jsonl"): - data_json = [json.loads(line) for line in f] - else: - data_json = json.load(f) - - if data_args.offline_data_path is not None: print_rank_0("Loading pre-processed data for offline training...") - dataset_cls = OfflineSupervisedDataset + assert not data_args.vlm_processor, "Offline data is not supported for VLM." - # Glob for all .pt files in the data_path directory - assert data_args.offline_data_path is not None, ( - "offline_data_path must be provided for offline training." - ) offline_data_path = Path(data_args.offline_data_path) - # Collect all pt file paths - all_files = {str(p) for p in offline_data_path.glob("*.pt")} - all_files |= {str(p) for p in offline_data_path.glob("**/*.pt")} - if not all_files: + dumped_files = [str(p) for p in offline_data_path.glob("*.pt")] + if not dumped_files: raise ValueError(f"No .pt files found in {data_args.offline_data_path}") - # Build a map from conv_id to file_path for fast lookup - print("building conv_id_to_file map...") - conv_id_to_file = {} - for pt_path in all_files: - pt_name = Path(pt_path).name - # Expect conv_id.pt - if pt_name.endswith(".pt"): - conv_id = pt_name[:-3] - conv_id_to_file[conv_id] = pt_path - - valid_entries = [] - print("filtering valid entries...") - for entry in data_json: - conv_id = entry.get("conversation_id") - if conv_id is None: - conv_id = entry.get("uuid") - if conv_id is None: - conv_id = entry.get("id") - if conv_id is None: - raise ValueError(f"Conversation ID required but not found for entry {entry}") - - file_path = conv_id_to_file.get(str(conv_id)) - if file_path is None: - continue - valid_entries.append((entry, file_path)) - - if len(valid_entries) == 0: - msg = """No valid files found in the offline data path that match the conversation IDs - in the provided data json. Please ensure that the offline data path is correct and - contains .pt files named after the conversation IDs, and that the input conversations - json has the correct format (with 'conversation_id' or 'id' fields).""" - raise ValueError(msg) - elif len(valid_entries) < len(data_json): - print_rank_0( - f"Warning: Only {len(valid_entries)} out of {len(data_json)} conversations" - " have corresponding .pt files in the offline data path. Continuing..." - ) - - num_train = int(len(valid_entries) * 0.95) - train_dataset = dataset_cls( - valid_entries[:num_train], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - eval_dataset = dataset_cls( - valid_entries[num_train:], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - - data_collator = DataCollatorForOffline(max_length=max_length) - else: - print_rank_0("Loading input conversations...") - dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset - - train_dataset = dataset_cls( - data_json[: int(len(data_json) * 0.95)], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - eval_dataset = dataset_cls( - data_json[int(len(data_json) * 0.95) :], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - - data_collator = DataCollatorWithPadding(max_length=max_length) + train_dataset = OfflineSupervisedDataset(dumped_files) + data_collator = EagleOfflineDataCollator(train_len=train_len) return { "train_dataset": train_dataset, - "eval_dataset": eval_dataset, "data_collator": data_collator, } -class DataCollatorWithPadding: - def __init__(self, max_length): - self.max_length = max_length - - def paddingtensor2d(self, intensors, length): - n, dim = intensors.shape - if n > length: - return intensors[:length, :] - padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors - - def paddingtensor(self, intensors, length): - if intensors.shape[0] > length: - return intensors[:length] - padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors - - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - batch_input_ids = torch.stack( - [self.paddingtensor(item["input_ids"], self.max_length) for item in features] - ) - batch_attention_mask = torch.stack( - [self.paddingtensor(item["attention_mask"], self.max_length) for item in features] - ) - batch_loss_mask = torch.stack( - [self.paddingtensor(item["loss_mask"], self.max_length) for item in features] - ) - - batch_labels = torch.stack( - [self.paddingtensor(item["labels"], self.max_length) for item in features] - ) - - batch = { - "input_ids": batch_input_ids, - "attention_mask": batch_attention_mask, - "loss_mask": batch_loss_mask, - "labels": batch_labels, - } - - # Collate VLM data - if "pixel_values" in features[0]: - # pixel values and image flags should be flattened inside a batch - batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0) - batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0) - - return batch - - -class DataCollatorForOffline(DataCollatorWithPadding): - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - base_batch = super().__call__(features) - if "kwargs" not in features[0]: - raise ValueError("No kwargs found in batch features. Offline data required.") - - features = [item["kwargs"]["base_model_outputs"] for item in features] - - batch_hidden_states = torch.stack( - [ - self.paddingtensor2d(item["base_model_hidden_states"], self.max_length) - for item in features - ] - ) - batch_aux_hidden_states = torch.stack( - [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features] - ) - - batch = { - **base_batch, - "base_model_outputs": { - "base_model_hidden_states": batch_hidden_states, - "aux_hidden_states": batch_aux_hidden_states, - }, - } - - return batch - - class EagleTrainerWithAccLog(Trainer): """Wrapper around Trainer that logs training accuracy.""" @@ -566,3 +242,137 @@ def on_step_end(self, args, state, control, **kwargs): except Exception: print_rank_0("AR validation not available.") return control + + +def get_patched_templated_ring_attn(orig_templated_attn: Callable): + """ + Return patched version of + torch.distributed.tensor.experimental._context_parallel._attention._templated_ring_attention + to support TTT. + """ + + def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype): + """Get chunk-interleaved TTT mask for current rank. + e.g.: + 2 ranks, ttt_step=1; + full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0], + [x, 0, 0, 0, 0, x, 0, 0], + [x, x, 0, 0, 0, 0, x, 0], + [x, x, x, 0, 0, 0, 0, x], + + rank 0, step0: [[0, 0, x, 0], + [x, 0, 0, x]] + + rank 1, step0: [[0, 0, x, 0], + [x, 0, 0, x]] + + rank 0, step1: [[0, 0, 0, 0], + [0, 0, 0, 0]] + + rank 1, step1: [[x, x, 0, 0], + [x, x, 0, 0]] + + """ + device = torch.cuda.current_device() + q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device) + kv_indices = ( + torch.arange(q_len * size * (ttt_step + 1), device=device) + .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :] + .reshape(-1) + ) + msk_func = get_ttt_msk_func(q_len * size, ttt_step) + attn_mask = msk_func( + None, + None, + q_indices.view(1, 1, -1, 1), + kv_indices.view(1, 1, 1, -1), + ) + attn_bias = torch.where( + attn_mask, + torch.zeros((), dtype=dtype, device=attn_mask.device), + torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device), + ) + + return attn_bias + + def patched_templated_attn(*args, **kwargs): + """Patched version of _templated_ring_attention.""" + # Get original attention op + # Sensitive to impl of _templated_ring_attention + original_op = args[2] + + # This patch is only enabled for eagle model by context manager, not base model. + patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH + + if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention: + raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}") + + # Unset is_causal to use custom attn mask + if patch_enbabled: + kwargs["is_causal"] = False + + def patched_op(*args, **kwargs): + # Inspect the parent frame to get current shard info + # This is sensitive to torch _templated_ring_attention impl + try: + frame: FrameType = inspect.currentframe() + f_back: FrameType = frame.f_back + rank = f_back.f_locals["rank"] + size = f_back.f_locals["size"] + query = f_back.f_locals["query"] + key = f_back.f_locals["key"] + i = f_back.f_locals["i"] + ttt_step = (key.shape[2] // query.shape[2]) - 1 + except Exception as e: + raise RuntimeError( + f"Failed to capture loop variables in patched _templated_ring_attention: {e}" + ) from e + # Set attn mask to permuted TTT mask + if "attn_bias" in kwargs: + kwargs["attn_bias"] = _get_sharded_ttt_msk( + i, rank, size, query.shape[2], ttt_step, query.dtype + ) + # Perform shard attention + return original_op(*args, **kwargs) + + return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs) + + return patched_templated_attn + + +def patch_ring_attention_for_ttt(): + """Patch torch ring attention to support context parallelism for TTT.""" + # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask. + + if Version(torch.__version__) < Version("2.10.0"): + raise RuntimeError( + f"Context parallel TTT only supported for PyTorch >= 2.10.0. " + f"Got {torch.__version__}. " + f"Please use torch 2.10.0 or cp_size=1." + ) + + from torch.distributed.tensor.experimental._context_parallel import _attention + + # 1. Disable load balance, which is designed for causal mask. + # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init. + _attention._cp_options.enable_load_balance = False + + # 2. Patch templated ring attention for TTT mask. + original_templated_ring_attention = _attention._templated_ring_attention + original_templated_ring_attention_backward = _attention._templated_ring_attention_backward + _attention._templated_ring_attention = get_patched_templated_ring_attn( + original_templated_ring_attention + ) + _attention._templated_ring_attention_backward = get_patched_templated_ring_attn( + original_templated_ring_attention_backward + ) + + # 3. Patch merger to skip the blank shard to avoid difference in output. + original_sdpa_merger_step = _attention._SDPAMerger.step + + def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool): + if lse.sum() <= 0: + return + return original_sdpa_merger_step(self, out, lse, partial) + + _attention._SDPAMerger.step = patched_sdpa_merger_step diff --git a/examples/speculative_decoding/fsdp_config.json b/examples/speculative_decoding/fsdp_config.json new file mode 100644 index 0000000000..6d934182fe --- /dev/null +++ b/examples/speculative_decoding/fsdp_config.json @@ -0,0 +1 @@ +{"fsdp_version":2} \ No newline at end of file diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index e3b6c5a21d..ae8a21eea4 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi EAGLE_CONFIG="${1#*=}" ;; - --fsdp_transformer_layer_cls_to_wrap*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}" - ;; - --num_gpu*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_GPU="${1#*=}" - ;; --disable_tqdm*) if [[ "$1" != *=* ]]; then shift; fi DISABLE_TQDM="${1#*=}" @@ -102,6 +94,22 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi AR_VALIDATE_STEPS="${1#*=}" ;; + --cp_size*) + if [[ "$1" != *=* ]]; then shift; fi + CP_SIZE="${1#*=}" + ;; + --dp_size*) + if [[ "$1" != *=* ]]; then shift; fi + DP_SHARD_SIZE="${1#*=}" + ;; + --log_steps*) + if [[ "$1" != *=* ]]; then shift; fi + LOG_STEPS="${1#*=}" + ;; + --draft_vocab_cache*) + if [[ "$1" != *=* ]]; then shift; fi + DRAFT_VOCAB_CACHE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -126,11 +134,9 @@ OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"} NUM_EPOCHS=${NUM_EPOCHS:-1} SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} LR=${LR:-"1e-4"} -TRAIN_BS=${TRAIN_BS:-4} +TRAIN_BS=${TRAIN_BS:-1} MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} -FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} -NUM_GPU=${NUM_GPU:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} DISABLE_TQDM=${DISABLE_TQDM:-False} @@ -138,6 +144,10 @@ VLM_PROCESSOR=${VLM_PROCESSOR:-} VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} +CP_SIZE=${CP_SIZE:-1} +DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} +LOG_STEPS=${LOG_STEPS:-100} +DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -163,11 +173,6 @@ else OFFLINE_TRAINING_ARGS="" fi -if [[ "$NUM_GPU" == 1 ]]; then - MULTI_GPU="" -else - MULTI_GPU="--multi_gpu" -fi if [[ "$VLM_PROCESSOR" != "" ]]; then VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR" @@ -175,9 +180,25 @@ else VLM_ARGS="" fi +if [[ "$GPU_COUNT" -gt 1 ]]; then + #Use FSDP2 when multi GPU available + FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json" +else + #Otherwise, single GPU training + FSDP_ARGS="" +fi + + +if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then + DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" +else + DRAFT_VOCAB_CACHE_ARGS="" +fi + + # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ +CMD="accelerate launch --mixed_precision bf16 main.py \ --mode $MODE \ --eagle_decoder_type $EAGLE_DECODER_TYPE \ --model_name_or_path $MODEL \ @@ -197,15 +218,19 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ --weight_decay 0.0 \ --warmup_steps 100 \ --lr_scheduler_type linear \ - --logging_steps 100 \ + --logging_steps $LOG_STEPS \ --tf32 True \ --data_path $DATA \ --disable_tqdm $DISABLE_TQDM \ --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ + $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ + $FSDP_ARGS \ + --cp_size $CP_SIZE \ + --dp_shard_size $DP_SHARD_SIZE \ " start_time=$(date +%s) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index cd1af9563b..6821111849 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -36,12 +36,22 @@ import torch import transformers -from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, make_eagle_supervised_data_module +from accelerate import ParallelismConfig +from eagle_utils import ( + EagleTrainerWithAccLog, + EagleTrainingPlot, + make_eagle_supervised_data_module, + patch_ring_attention_for_ttt, +) from medusa_utils import make_medusa_supervised_data_module from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.utils import ( + load_vlm_or_llm_with_kwargs, + patch_transformers5_params_loading, +) from modelopt.torch.utils import print_rank_0 torch.manual_seed(0) @@ -70,9 +80,9 @@ class DataArguments: }, ) lazy_preprocess: bool = True - draft_vocab_cache_dir: str = field( - default="draft_vocab_cache", - metadata={"help": "Path to the d2t cache directory."}, + draft_vocab_cache: str | None = field( + default=None, + metadata={"help": "Path to d2t.pt cache file."}, ) vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) @@ -91,7 +101,7 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", "medusa"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR during training for logging."} ) @@ -100,6 +110,8 @@ class TrainingArguments(transformers.TrainingArguments): remove_unused_columns: bool = field( default=False, metadata={"help": "Set to False to keep extra args for VLM."} ) + cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) + dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."}) @dataclass @@ -130,30 +142,39 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) + training_args.parallelism_config = ParallelismConfig( + cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size + ) + if training_args.cp_size > 1: + patch_ring_attention_for_ttt() + # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 + training_args.parallelism_config.sp_backend = None print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir): - last_checkpoint = get_last_checkpoint(training_args.output_dir) + # Detect checkpoint to resume from + last_checkpoint = ( + get_last_checkpoint(training_args.output_dir) + if os.path.isdir(training_args.output_dir) + else None + ) + if last_checkpoint: print_rank_0(f"Last checkpoint detected: {last_checkpoint}") - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint + checkpoint = training_args.resume_from_checkpoint or last_checkpoint use_offline_training = data_args.offline_data_path is not None if checkpoint: - model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto") + with patch_transformers5_params_loading(): + _, model = load_vlm_or_llm_with_kwargs( + checkpoint, torch_dtype="auto", trust_remote_code=True + ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model = transformers.AutoModelForCausalLM.from_pretrained( + model_config, model = load_vlm_or_llm_with_kwargs( model_args.model_name_or_path, torch_dtype="auto", device_map="cpu", @@ -163,79 +184,48 @@ def train(): if use_offline_training: # When doing offline training, we need to set num_hidden_layers # since we override it when loading the model for space savings - model_config = transformers.AutoConfig.from_pretrained( - model_args.model_name_or_path, trust_remote_code=True - ) model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, trust_remote_code=True, ) - if tokenizer.chat_template is None: - tokenizer.chat_template = ( - "{%- for message in messages %}" - "{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}" - "{%- endfor %}" - ) - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - if training_args.mode == "medusa": config = { "medusa_num_heads": medusa_args.medusa_num_heads, "medusa_num_layers": medusa_args.medusa_num_layers, } mtsp.convert(model, [("medusa", config)]) - elif training_args.mode in ["eagle1", "eagle3"]: - from modelopt.torch.speculative.config import ( - default_eagle_config, - eagle3_default_config, - kimik2_eagle_default_config, + elif training_args.mode == "eagle3": + custom_config = ( + json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {} ) - if eagle_args.eagle_decoder_type == "kimik2": - eagle_architecture_config = kimik2_eagle_default_config - else: - eagle_architecture_config = { - "eagle1": default_eagle_config, - "eagle3": eagle3_default_config, - }[training_args.mode] - - if eagle_args.eagle_config: - with open(eagle_args.eagle_config) as f: - custom_config = json.load(f) - eagle_architecture_config.update(custom_config) - config = { "eagle_decoder_type": eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, - "eagle_architecture_config": eagle_architecture_config, + "eagle_architecture_config": custom_config, } mtsp.convert(model, [("eagle", config)]) # read draft vocab cache if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: - try: - model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path)) - vocab_cache_path = os.path.join( - data_args.draft_vocab_cache_dir, model_name, "d2t.pt" + if not os.path.isfile(data_args.draft_vocab_cache): + raise FileNotFoundError( + f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" ) - vocab_cache = torch.load(vocab_cache_path) - model.eagle_module.d2t = vocab_cache - print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.") - except Exception as e: - raise e + model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") if training_args.mode == "medusa": data_module = make_medusa_supervised_data_module(tokenizer, data_args) - elif training_args.mode in ["eagle1", "eagle3"]: + elif training_args.mode == "eagle3": data_module = make_eagle_supervised_data_module( - tokenizer, data_args, max_length=training_args.training_seq_len + tokenizer, data_args, train_len=training_args.training_seq_len ) trainer = EagleTrainerWithAccLog( diff --git a/examples/speculative_decoding/requirements.txt b/examples/speculative_decoding/requirements.txt index 765af61041..6324bac62b 100644 --- a/examples/speculative_decoding/requirements.txt +++ b/examples/speculative_decoding/requirements.txt @@ -1,5 +1,2 @@ -flash-attn -openai -py7zr -sentencepiece>=0.2.0 -tensorboardX +accelerate==1.12.0 +transformers==5.0.0rc1 diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 38b8866933..d5c37a8951 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -18,10 +18,11 @@ from accelerate import Accelerator from datasets import load_dataset from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer import modelopt.torch.opt as mto from modelopt.torch.speculative.plugins.transformers import HFARValidation +from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs mto.enable_huggingface_checkpointing() @@ -71,7 +72,7 @@ def main(): accelerator = Accelerator() # Load model and tokenizer - model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto") + _, model = load_vlm_or_llm_with_kwargs(args.model_path, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(args.model_path) model.eval() model = accelerator.prepare(model) diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index dfc293ee98..fc34215830 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -18,10 +18,10 @@ import argparse import torch -from transformers import AutoModelForCausalLM import modelopt.torch.opt as mto from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs def parse_args(): @@ -38,11 +38,11 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto") +_, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): export_hf_checkpoint( - model, # The quantized model. - export_dir=args.export_path, # The directory where the exported files will be stored. + model, + export_dir=args.export_path, ) print(f"Exported checkpoint to {args.export_path}") diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index 60c57d1c3a..4e117d1238 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -17,12 +17,11 @@ set -eo pipefail -# Set default values for BASE_MODEL, NUM_GPU, and DATA +# Set default values for BASE_MODEL and DATA BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct -NUM_GPU=1 DATA=input_conversations/daring-anteater.jsonl -# Parse input arguments --base_model, --num_gpu, and --data +# Parse input arguments --base_model and --data while [[ $# -gt 0 ]]; do key="$1" case $key in @@ -30,10 +29,6 @@ while [[ $# -gt 0 ]]; do BASE_MODEL="$2" shift; shift ;; - --num_gpu) - NUM_GPU="$2" - shift; shift - ;; --data) DATA="$2" shift; shift @@ -49,15 +44,6 @@ while [[ $# -gt 0 ]]; do esac done - -if [[ "$NUM_GPU" == 1 ]]; then - export CUDA_VISIBLE_DEVICES=0 -else - # Export as 0,1,...,N-1 for NUM_GPU GPUs - devs="$(seq -s, 0 $((NUM_GPU-1)))" - export CUDA_VISIBLE_DEVICES="$devs" -fi - if [[ "$OFFLINE_DATA_PATH" != "" ]]; then OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH" else @@ -73,7 +59,6 @@ mkdir -p "$(dirname "$OUTPUT_DIR")" --output_dir $OUTPUT_DIR \ $OFFLINE_DATA_ARGS \ --data $DATA \ - --num_gpu $NUM_GPU \ --num_epochs 2 \ --eagle_config eagle_config.json diff --git a/examples/windows/Benchmark.md b/examples/windows/Benchmark.md index 0105a7fad3..6714a61e8f 100644 --- a/examples/windows/Benchmark.md +++ b/examples/windows/Benchmark.md @@ -24,6 +24,8 @@ Memory savings and inference speedup are compared to the ONNX FP16 baseline. ### 1.2 Accuracy Comparison +#### 1.2.1 MMLU + For accuracy evaluation, the [Massive Multitask Language Understanding (MMLU)](https://arxiv.org/abs/2009.03300) benchmark has been utilized. Please refer to the [detailed instructions](./accuracy_benchmark/README.md) for running the MMLU accuracy benchmark. The table below shows the MMLU 5-shot score for some models. @@ -39,3 +41,56 @@ The table below shows the MMLU 5-shot score for some models. | [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) | 61.76 | 60.73 | | [Llama3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | 60.8 | 57.71 | | [Gemma-2b-it](https://huggingface.co/google/gemma-2b-it) | 37.01 | 37.2 | + +#### 1.2.2 Perplexity (PPL) + +Perplexity measures how well a probability model predicts a sample. Lower perplexity values indicate better model quality. The following table shows perplexity values at input sequence length 1024 with chunk size of 512. + +**Learn more about Perplexity:** [Perplexity - Wikipedia](https://en.wikipedia.org/wiki/Perplexity) | [Hugging Face - Perplexity of Fixed-Length Models](https://huggingface.co/docs/transformers/en/perplexity) + +- **FP16-MB**: Baseline FP16 genai model (Model Builder) +- **Mixed AWQ-MO**: Important linear layers in INT8, rest in INT4 (AWQ), with ModelOpt. +- **Mixed RTN-MO**: Important linear layers in INT8, rest in INT4 (RTN), with ModelOpt. +- **Pure INT4 AWQ-MO**: All linear layers INT4 (AWQ) with ModelOpt. +- **Pure INT4 RTN-MO**: All linear layers INT4 (RTN) with ModelOpt. +- **Pure INT8 RTN-MO**: All linear layers INT8 (RTN) with ModelOpt. +- **Pure INT8 AWQ-MO**: All linear layers INT8 (AWQ) with ModelOpt. +- **Configuration**: Windows OS, GPU RTX 5090, nvidia-modelopt v0.39.0, onnxruntime-genai-cuda 0.9.2, onnxruntime-gpu 1.23.0, torch 2.8.0+cu128, transformers 4.49.0 + +| Model | FP16-MB | Mixed AWQ-MO | Mixed RTN-MO | Pure INT4 AWQ-MO | Pure INT4 RTN-MO | Pure INT8 RTN-MO | Pure INT8 AWQ-MO | +|:------|:--------|:-------------|:-------------|:-----------------|:-----------------|:-----------------|:-----------------| +| DeepSeek R1 Distill Qwen 1.5B | 39.447 | 41.699 | 44.332 | 44.213 | 46.304 | 39.802 | 39.713 | +| Llama 3.2 1B Instruct | 12.631 | 13.852 | 14.176 | 14.549 | 16.900 | 12.664 | 12.637 | +| Phi-3.5 Mini Instruct | 6.046 | 6.500 | 6.599 | 6.711 | 7.070 | - | - | +| Phi-4 Mini Instruct | 9.039 | 9.673 | 9.712 | 10.015 | 10.911 | - | - | +| Qwen 2.5 1.5B Instruct | 9.216 | 10.084 | 10.338 | 10.495 | 10.933 | 9.227 | 9.232 | + +For detailed instructions on evaluating perplexity, please refer to the [Perplexity Evaluation Guide](./accuracy_benchmark/perplexity_metrics/README.md). + +#### 1.2.3 KL-divergence + +KL-divergence (Kullback-Leibler divergence) quantifies the distributional difference between the quantized model and the baseline model. Lower KL-divergence values indicate that the quantized model's output distribution is closer to the original model. + +**Learn more about KL-divergence:** [KL Divergence - Wikipedia](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) | [Understanding KL Divergence](https://www.countbayesie.com/blog/2017/5/9/kullback-leibler-divergence-explained) + +**Supported backends:** PyTorch and Onnxruntim-cuda, onnxruntime-trt-rtx-ep are both supported for evaluation. + +- **Baseline model**: Hugging Face FP16 model +- **Quantized models**: Models where quantization is simulated (a.k.a. fake quantization), typically using the PyTorch-CUDA backend for evaluation. Fake quantization means quantized weights and dequantized simultaneously to simulate quantization. The inference backend column in the table below indicates whether the reported results are from PyTorch simulation or ONNX-runtime-based inference. +- **Configuration**: Windows OS, GPU RTX 5090, nvidia-modelopt v0.39.0, onnxruntime-genai-cuda 0.9.2, onnxruntime-gpu 1.23.0, torch 2.8.0+cu128, transformers 4.49.0 + +| Model | Quantization Method | Quantization Granularity | KL-divergence | Inference Backend | +|:-----------------------|:-------------------------------------------------|:--------------------------------------------------------------------|:--------------|:------------------------------| +| Qwen2.5-1.5B-Instruct | Base FP16 (Baseline) | - | 0.000 | PyTorch (FP16) | +| Qwen2.5-1.5B-Instruct | int4+int8 Blockwise-max_algo-mixed_quant (simulated) | INT4: per-block (block-size=128), INT8: per-channel (row-wise) | 0.336 | PyTorch (fake quantization) | +| Qwen2.5-1.5B-Instruct | int4+int8 max_algo-mixed_quant (simulated, per-channel) | INT4: per-block (block-size=128), INT8: per-channel (row-wise) | 0.337 | PyTorch (fake quantization) | +| Llama-3.2-3B-Instruct | Base FP16 (Baseline) | - | 0.000 | PyTorch (FP16) | +| Llama-3.2-3B-Instruct | int4+int8 Blockwise-awq-lite_algo-mixed_quant (simulated) | INT4: per-block (block-size=128), INT8: per-channel (row-wise) | 0.228 | PyTorch (fake quantization) | +| Llama-3.2-3B-Instruct | int4+int8 per-channel-awq-lite_algo-mixed_quant (simulated) | INT4: per-block (block-size=128), INT8: per-channel (row-wise) | 0.230 | PyTorch (fake quantization) | +| Llama-3.2-3B-Instruct | int4+int8 Blockwise-max_algo-mixed_quant (simulated) | INT4: per-block (block-size=128), INT8: per-channel (row-wise) | 0.238 | PyTorch (fake quantization) | +| Llama-3.2-3B-Instruct | int4+int8 per-channel-max_algo-mixed_quant (simulated) | INT4: per-block (block-size=128), INT8: per-channel (row-wise) | 0.238 | PyTorch (fake quantization) | +| Llama-3.2-3B-Instruct | int4 Blockwise-max_algo only (simulated) | INT4: per-block (block-size=128) | 0.334 | PyTorch (fake quantization) | + +*All KL-divergence results above are obtained via PyTorch fake quantization simulation unless otherwise noted. Inference with ONNX-runtime can also be evaluated .* + +For detailed instructions on computing KL-divergence, please refer to the [KL-divergence Evaluation Guide](./accuracy_benchmark/kl_divergence_metrics/README.md). diff --git a/examples/windows/README.md b/examples/windows/README.md index 2c9b212389..30e6cad1d3 100644 --- a/examples/windows/README.md +++ b/examples/windows/README.md @@ -117,7 +117,7 @@ onnx.save_model( ) ``` -For detailed instructions about deployment of quantized models with DirectML backend (ORT-DML), see the [DirectML](https://nvidia.github.io/Model-Optimizer/deployment/2_directml.html#directml-deployment). +For detailed instructions about deployment of quantized models with ONNX Runtime, see the [ONNX Runtime Deployment Guide](https://nvidia.github.io/Model-Optimizer/deployment/2_onnxruntime.html). > [!Note] > The ready-to-deploy optimized ONNX models from ModelOpt-Windows are available at HuggingFace [NVIDIA collections](https://huggingface.co/collections/nvidia/optimized-onnx-models-for-nvidia-rtx-gpus). @@ -134,7 +134,12 @@ For detailed instructions about deployment of quantized models with DirectML bac ## Support Matrix -Please refer to [support matrix](https://nvidia.github.io/Model-Optimizer/guides/0_support_matrix.html) for a full list of supported features and models. +| Model Type | Support Matrix | +|------------|----------------| +| Large Language Models (LLMs) | [View Support Matrix](./onnx_ptq/genai_llm/README.md#support-matrix) | +| Automatic Speech Recognition | [View Support Matrix](./onnx_ptq/whisper/README.md#support-matrix) | +| Segmentation Models | [View Support Matrix](./onnx_ptq/sam2/README.md#support-matrix) | +| Diffusion Models | [View Support Matrix](./torch_onnx/diffusers/README.md#support-matrix) | ## Benchmark Results diff --git a/examples/windows/onnx_ptq/genai_llm/README.md b/examples/windows/onnx_ptq/genai_llm/README.md index b833d44dc6..83274bb40c 100644 --- a/examples/windows/onnx_ptq/genai_llm/README.md +++ b/examples/windows/onnx_ptq/genai_llm/README.md @@ -1,12 +1,23 @@ +## Table of Contents + +- [Overview](#overview) +- [Setup](#setup) +- [Prepare ORT-GenAI Compatible Base Model](#prepare-ort-genai-compatible-base-model) +- [Quantization](#quantization) +- [Evaluate the Quantized Model](#evaluate-the-quantized-model) +- [Deployment](#deployment) +- [Support Matrix](#support-matrix) +- [Troubleshoot](#troubleshoot) + ## Overview -The example script showcases how to utilize the **ModelOpt-Windows** toolkit for optimizing ONNX (Open Neural Network Exchange) models through quantization. This toolkit is designed for developers looking to enhance model performance, reduce size, and accelerate inference times, while preserving the accuracy of neural networks deployed with backends like DirectML on local RTX GPUs running Windows. +The example script showcases how to utilize the **ModelOpt-Windows** toolkit for optimizing ONNX (Open Neural Network Exchange) models through quantization. This toolkit is designed for developers looking to enhance model performance, reduce size, and accelerate inference times, while preserving the accuracy of neural networks deployed with backends like TensorRT-RTX, DirectML, CUDA on local RTX GPUs running Windows. Quantization is a technique that converts models from floating-point to lower-precision formats, such as integers, which are more computationally efficient. This process can significantly speed up execution on supported hardware, while also reducing memory and bandwidth requirements. This example takes an ONNX model as input, along with the necessary quantization settings, and generates a quantized ONNX model as output. This script can be used for quantizing popular, [ONNX Runtime GenAI](https://onnxruntime.ai/docs/genai) built Large Language Models (LLMs) in the ONNX format. -### Setup +## Setup 1. Install ModelOpt-Windows. Refer [installation instructions](../../README.md). @@ -16,15 +27,15 @@ This example takes an ONNX model as input, along with the necessary quantization pip install -r requirements.txt ``` -### Prepare ORT-GenAI Compatible Base Model +## Prepare ORT-GenAI Compatible Base Model -You may generate the base model using the model builder that comes with onnxruntime-genai. The ORT-GenAI's [model-builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models) downloads the original Pytorch model from Hugging Face, and produces an ONNX GenAI compatible base model in ONNX format. See example command-line below: +You may generate the base model using the model builder that comes with onnxruntime-genai. The ORT-GenAI's [model-builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models) downloads the original Pytorch model from Hugging Face, and produces an ONNX GenAI-compatible base model in ONNX format. See example command-line below: ```bash python -m onnxruntime_genai.models.builder -m meta-llama/Meta-Llama-3-8B -p fp16 -e dml -o E:\llama3-8b-fp16-dml-genai ``` -### Quantization +## Quantization To begin quantization, run the script like below: @@ -35,13 +46,13 @@ python quantize.py --model_name=meta-llama/Meta-Llama-3-8B \ --calib_size=32 --algo=awq_lite --dataset=cnn ``` -#### Command Line Arguments +### Command Line Arguments The table below lists key command-line arguments of the ONNX PTQ example script. | **Argument** | **Supported Values** | **Description** | |---------------------------|------------------------------------------------------|-------------------------------------------------------------| -| `--calib_size` | 32 (default), 64, 128 | Specifies the calibration size. | +| `--calib_size` | 32 , 64, 128 (default) | Specifies the calibration size. | | `--dataset` | cnn (default), pilevel | Choose calibration dataset: cnn_dailymail or pile-val. | | `--algo` | awq_lite (default), awq_clip, rtn, rtn_dq | Select the quantization algorithm. | | `--onnx_path` | input .onnx file path | Path to the input ONNX model. | @@ -54,12 +65,13 @@ The table below lists key command-line arguments of the ONNX PTQ example script. | `--awqclip_alpha_step` | 0.05 (default) | Step-size for AWQ weight clipping, user-defined | | `--awqclip_alpha_min` | 0.5 (default) | Minimum AWQ weight-clipping threshold, user-defined | | `--awqclip_bsz_col` | 1024 (default) | Chunk size in columns during weight clipping, user-defined | -| `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution-providers to use for session run during calibration | -| `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data| -| `--enable_mixed_quant` | Default: disabled mixed quant | Use this option to enable mixed precsion quantization| -| `--layers_8bit` | Default: None | Use this option to Overrides default mixed quant strategy| +| `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [cuda,cpu]) | List of execution-providers to use for session run during calibration | +| `--add_position_ids` | Default: position_ids input is disabled | Use this option to enable position_ids input in calibration data| +| `--enable_mixed_quant` | Default: mixed-quant is disabled | Use this option to enable mixed precision quantization| +| `--layers_8bit` | Default: None | Use this option to override default mixed-quant strategy| | `--gather_quantize_axis` | Default: None | Use this option to enable INT4 quantization of Gather nodes - choose 0 or 1| | `--gather_block_size` | Default: 32 | Block-size for Gather node's INT4 quantization (when its enabled using gather_quantize_axis option)| +| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization.| Run the following command to view all available parameters in the script: @@ -69,37 +81,134 @@ python quantize.py --help Note: -1. For the `algo` argument, we have following options to choose form: awq_lite, awq_clip, rtn, rtn_dq. +1. For the `algo` argument, we have following options to choose from: awq_lite, awq_clip, rtn, rtn_dq. - The 'awq_lite' option does core AWQ scale search and INT4 quantization. - The 'awq_clip' option primarily does weight clipping and INT4 quantization. - The 'rtn' option does INT4 RTN quantization with Q->DQ nodes for weights. - The 'rtn_dq' option does INT4 RTN quantization with only DQ nodes for weights. 1. RTN algorithm doesn't use calibration-data. -1. If needed for the input base model, use `--no_position_ids` command-line option to disable +1. If needed for the input base model, use `--add_position_ids` command-line option to enable generating position_ids calibration input. The GenAI built LLM models produced with DML EP has position_ids input but ones produced with CUDA EP, NvTensorRtRtx EP don't have position_ids input. Use `--help` or command-line options table above to inspect default values. Please refer to `quantize.py` for further details on command-line parameters. -### Evaluate the Quantized Model +### Mixed Precision Quantization (INT4 + INT8) + +ModelOpt-Windows supports **mixed precision quantization**, where different layers in the model can be quantized to different bit-widths. This approach combines INT4 quantization for most layers (for maximum compression and speed) with INT8 quantization for important or sensitive layers (to preserve accuracy). + +#### Why Use Mixed Precision? + +Mixed precision quantization provides an optimal balance between: + +- **Model Size**: Primarily INT4 keeps the model small +- **Inference Speed**: INT4 layers run faster and smaller +- **Accuracy Preservation**: Critical layers in INT8 maintain model quality + +Based on benchmark results, mixed precision quantization shows significant advantages: + +| Model | Metric | INT4 RTN | Mixed RTN (INT4+INT8) | Improvement | +|:------|:-------|:-------------|:---------------------|:-----------| +| DeepSeek R1 1.5B | MMLU | 32.40% | 33.90% | +1.5% | +| | Perplexity | 46.304 | 44.332 | -2.0 (lower is better) | +| Llama 3.2 1B | MMLU | 39.90% | 44.70% | +4.8% | +| | Perplexity | 16.900 | 14.176 | -2.7 (lower is better) | +| Qwen 2.5 1.5B | MMLU | 56.70% | 57.50% | +0.8% | +| | Perplexity | 10.933 | 10.338 | -0.6 (lower is better) | + +As shown above, mixed precision significantly improves accuracy with minimal disk size increase (~85-109 MB). + +#### How Mixed Precision Works + +The quantization strategy selects which layers to quantize to INT8 vs INT4: + +1. **INT8 Layers** (Higher Precision): Important layers that significantly impact model quality. Quantized per-channel + +2. **INT4 Layers** (Maximum Compression): All other layers. Qunatized blockwise. + +This strategy preserves accuracy for the most sensitive layers while maintaining aggressive compression elsewhere. + +#### Using Mixed Precision Quantization + +##### Method 1: Use the default mixed precision strategy + +```bash +python quantize.py --model_name=meta-llama/Meta-Llama-3.2-1B \ + --onnx_path="E:\models\llama3.2-1b-fp16\model.onnx" \ + --output_path="E:\models\llama3.2-1b-int4-int8-mixed\model.onnx" \ + --algo=awq_lite \ + --calib_size=32 \ + --enable_mixed_quant +``` + +The `--enable_mixed_quant` flag automatically applies the default strategy. + +##### Method 2: Specify custom layers for INT8 + +```bash +python quantize.py --model_name=meta-llama/Meta-Llama-3.2-1B \ + --onnx_path="E:\models\llama3.2-1b-fp16\model.onnx" \ + --output_path="E:\models\llama3.2-1b-int4-int8-custom\model.onnx" \ + --algo=awq_lite \ + --calib_size=32 \ + --layers_8bit="layers.0,layers.1,layers.15,layers.16" +``` + +The `--layers_8bit` option allows you to manually specify which layers to quantize to INT8. You can use: + +- Layer indices: `layers.0,layers.5,layers.10` +- Layer paths: `model/layers.0/attn/qkv_proj` +- Partial names: `qkv_proj,down_proj` + +##### Technical Details + +- **Block Size**: INT4 layers use block-wise quantization (default block-size=128), INT8 uses per-channel quantization +- **Quantization Axis**: INT4 (per-block), INT8 (per-channel row-wise) +- **Compatibility**: Works with both `awq_lite` and `rtn_dq` algorithms +- **Automatic Detection**: The `--layers_8bit` option automatically enables mixed quantization + +For more benchmark results and detailed accuracy metrics, refer to the [Benchmark Guide](../../Benchmark.md). + +## Evaluate the Quantized Model To evaluate the quantized model, please refer to the [accuracy benchmarking](../../accuracy_benchmark/README.md) and [onnxruntime-genai performance benchmarking](https://github.com/microsoft/onnxruntime-genai/tree/main/benchmark/python). -### Deployment +## Deployment -Once an ONNX FP16 model is quantized using ModelOpt-Windows, the resulting quantized ONNX model can be deployed on the DirectML backend using [ORT-GenAI](https://onnxruntime.ai/) or [ORT](https://onnxruntime.ai/). +Once an ONNX FP16 model is quantized using ModelOpt-Windows, the resulting quantized ONNX model can be deployed using [ORT-GenAI](https://onnxruntime.ai/) or [ORT](https://onnxruntime.ai/). Refer to the following example scripts and tutorials for deployment: 1. [ORT GenAI examples](https://github.com/microsoft/onnxruntime-genai/tree/main/examples/python) 1. [ONNX Runtime documentation](https://onnxruntime.ai/docs/api/python/) -### Model Support Matrix - -Please refer to [support matrix](https://nvidia.github.io/Model-Optimizer/guides/0_support_matrix.html) for a full list of supported features and models. - -### Troubleshoot +## Support Matrix + +| Model | ONNX INT4 AWQ (W4A16) | +| :---: | :---: | +| Llama3.1-8B-Instruct | ✅ | +| Phi3.5-mini-Instruct | ✅ | +| Mistral-7B-Instruct-v0.3 | ✅ | +| Llama3.2-3B-Instruct | ✅ | +| Gemma-2b-it | ✅ | +| Gemma-2-2b | ✅ | +| Gemma-2-9b | ✅ | +| Nemotron Mini 4B Instruct | ✅ | +| Qwen2.5-7B-Instruct | ✅ | +| DeepSeek-R1-Distill-Llama-8B | ✅ | +| DeepSeek-R1-Distil-Qwen-1.5B | ✅ | +| DeepSeek-R1-Distil-Qwen-7B | ✅ | +| DeepSeek-R1-Distill-Qwen-14B | ✅ | +| Mistral-NeMo-Minitron-2B-128k-Instruct | ✅ | +| Mistral-NeMo-Minitron-4B-128k-Instruct | ✅ | +| Mistral-NeMo-Minitron-8B-128k-Instruct | ✅ | + +> *All LLMs in the above table are [GenAI](https://github.com/microsoft/onnxruntime-genai/) built LLMs.* + + > *`ONNX INT4 AWQ (W4A16)` means INT4 weights and FP16 activations using AWQ algorithm.* + +## Troubleshoot 1. **Configure Directories** @@ -132,4 +241,4 @@ Please refer to [support matrix](https://nvidia.github.io/Model-Optimizer/guides 1. **Error - Invalid Position-IDs input to the ONNX model** - The ONNX models produced using ONNX GenerativeAI (GenAI) have different IO bindings for models produced using different execution-providers (EPs). For instance, model built with DML EP has position-ids input in the ONNX model but models builts using CUDA EP or NvTensorRtRtx EP don't have position-ids inputs. So, if base model requires, use `no_position_ids` command-line argument for disabling position_ids calibration input or set "add_position_ids" variable to `False` value (hard-code) in the quantize script if required. + The ONNX models produced using ONNX GenerativeAI (GenAI) have different IO bindings for models produced using different execution-providers (EPs). For instance, model built with DML EP has position-ids input in the ONNX model but models builts using CUDA EP or NvTensorRtRtx EP don't have position-ids inputs. So, if base model requires, use `--add_position_ids` command-line argument for enabling position_ids calibration input or set "add_position_ids" variable to `True` value (hard-code) in the quantize script if required. diff --git a/examples/windows/onnx_ptq/genai_llm/quantize.py b/examples/windows/onnx_ptq/genai_llm/quantize.py index 57021ed4d6..d21d1d796b 100644 --- a/examples/windows/onnx_ptq/genai_llm/quantize.py +++ b/examples/windows/onnx_ptq/genai_llm/quantize.py @@ -230,7 +230,7 @@ def get_calib_inputs( ) if "cnn" in dataset_name: - dataset2 = load_dataset("cnn_dailymail", name="3.0.0", split="train").select( + dataset2 = load_dataset("abisee/cnn_dailymail", name="3.0.0", split="train").select( range(max_calib_rows_to_load) ) column = "article" @@ -369,7 +369,7 @@ def main(args): f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, " f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, " f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}, " - f"layers_8bit={args.layers_8bit}\n" + f"layers_8bit={args.layers_8bit}, use_column_major={args.use_column_major}\n" ) print( @@ -443,6 +443,7 @@ def main(args): layers_8bit=args.layers_8bit, gather_block_size=args.gather_block_size, gather_quantize_axis=args.gather_quantize_axis, + use_column_major=args.use_column_major, ) logging.info(f"\nQuantization process took {time.time() - t} seconds") @@ -589,10 +590,10 @@ def main(args): help="True when --use_gqa was passed during export.", ) parser.add_argument( - "--no_position_ids", + "--add_position_ids", dest="add_position_ids", - action="store_false", - default=True, + action="store_true", + default=False, help="True when we want to also pass position_ids input to model", ) parser.add_argument( @@ -605,7 +606,7 @@ def main(args): parser.add_argument( "--calibration_eps", type=parse_calibration_eps, # Use the custom parser - default=["dml", "cpu"], # Default as a list + default=["cuda", "cpu"], # Default as a list help="Comma-separated list of calibration endpoints. Choose from 'cuda', 'cpu', 'dml', 'NvTensorRtRtx'.", ) parser.add_argument( @@ -629,5 +630,14 @@ def main(args): default="", help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"), ) + parser.add_argument( + "--use_column_major", + default=False, + action="store_true", + help=( + "Apply column-major storage optimization for execution providers that need it. " + "Only applicable for DQ-only quantization (e.g., rtn_dq, awq_lite, awq_clip)." + ), + ) args = parser.parse_args() main(args) diff --git a/examples/windows/onnx_ptq/sam2/README.md b/examples/windows/onnx_ptq/sam2/README.md index 24eb9a634a..8d9e5fed3e 100644 --- a/examples/windows/onnx_ptq/sam2/README.md +++ b/examples/windows/onnx_ptq/sam2/README.md @@ -6,6 +6,7 @@ This repository contains an example to demontrate 8-bit quantization of SAM2 ONN - [ONNX export and Inference tool](#onnx-export-and-inference-tool) - [Quantization](#quantization) +- [Support Matrix](#support-matrix) - [Validated Settings](#validated-settings) - [Troubleshoot](#troubleshoot) @@ -59,6 +60,14 @@ python .\sam2_onnx_quantization.py --onnx_path=E:\base\sam2_hiera_large.encoder. ``` +## Support Matrix + +| Model | ONNX INT8 Max (W8A8) | ONNX FP8 Max (W8A8) | +| :---: | :---: | :---: | +| sam2-hiera-large | ✅ | ✅ | + +> *`ONNX INT8 Max` means INT8 (W8A8) quantization of ONNX model using Max calibration. Similar holds true for the term `ONNX FP8 Max`.* + ## Validated Settings This example is currently validated with following settings: diff --git a/examples/windows/onnx_ptq/whisper/README.md b/examples/windows/onnx_ptq/whisper/README.md index 7d82c0dc45..8757aaeb53 100644 --- a/examples/windows/onnx_ptq/whisper/README.md +++ b/examples/windows/onnx_ptq/whisper/README.md @@ -7,6 +7,7 @@ This repository contains an example to demontrate 8-bit quantization of Whisper - [ONNX export](#onnx-export) - [Inference script](#inference-script) - [Quantization script](#quantization-script) +- [Support Matrix](#support-matrix) - [Validated Settings](#validated-settings) - [Troubleshoot](#troubleshoot) @@ -149,6 +150,14 @@ python .\whisper_onnx_quantization.py --model_name=openai/whisper-large --base_m - In case, ONNX installation unexpectedly throws error, then one can try with other ONNX versions. +## Support Matrix + +| Model | ONNX INT8 Max (W8A8) | ONNX FP8 Max (W8A8) | +| :---: | :---: | :---: | +| whisper-large | ✅ | ✅ | + +> *`ONNX INT8 Max` means INT8 (W8A8) quantization of ONNX model using Max calibration. Similar holds true for the term `ONNX FP8 Max`.* + ## Validated Settings These scripts are currently validated with following settings: diff --git a/examples/windows/torch_onnx/diffusers/README.md b/examples/windows/torch_onnx/diffusers/README.md index 962bf4ec0f..856fcb63be 100644 --- a/examples/windows/torch_onnx/diffusers/README.md +++ b/examples/windows/torch_onnx/diffusers/README.md @@ -7,7 +7,7 @@ This repository provides relevant steps, script, and guidance for quantization o - [Installation and Pre-requisites](#installation-and-pre-requisites) - [Quantization of Backbone](#quantization-of-backbone) - [Inference using ONNX runtime](#inference-using-onnxruntime) -- [Quantization Support Matrix](#quantization-support-matrix) +- [Support Matrix](#support-matrix) - [Validated Settings](#validated-settings) - [Troubleshoot](#troubleshoot) @@ -92,10 +92,10 @@ Place the quantized ONNX model file for the backbone inside its specific subdire Optimum-ONNX Runtime provides pipelines such as ORTStableDiffusion3Pipeline and ORTFluxPipeline that can be used to run ONNX-exported diffusion models. These pipelines offer a convenient, high-level interface for loading the exported graph and performing inference. For a practical reference, see the stable diffusion inference [example script](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/python/models/stable_difusion) in the ONNX Runtime inference examples repository. -## Quantization Support Matrix +## Support Matrix | Model | fp8 | nvfp41 | -| :---: | :---: | +| :---: | :---: | :---: | | SD3-Medium-Diffusers | ❌ | ✅ | | SD3.5-Medium | ✅ | ✅ | | Flux.1.Dev2 | ✅ | ✅ | @@ -106,7 +106,7 @@ Optimum-ONNX Runtime provides pipelines such as ORTStableDiffusion3Pipeline and > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try disabling with KV-Cache or MHA quantization, try out different calibration settings (like calibration samples data, samples size, diffusion steps etc.) or perform QAT / QAD (not yet supported / validated on Windows RTX).* -Please refer to [support matrix](https://nvidia.github.io/Model-Optimizer/guides/0_support_matrix.html) for a full list of supported features and models. +> *There are some known performance issues with NVFP4 model execution using TRTRTX EP. Stay tuned for further updates!* ## Validated Settings diff --git a/experimental/README.md b/experimental/README.md new file mode 100644 index 0000000000..61c2233c16 --- /dev/null +++ b/experimental/README.md @@ -0,0 +1,135 @@ +# Experimental Optimization Techniques + +Experimental optimization algorithms and research prototypes under active development. + +## Purpose + +For new optimization techniques (quantization, pruning, sparsity, etc.) that are: + +- Novel or research-stage algorithms +- Not yet production-ready +- May have unstable APIs + +**⚠️ Warning**: Experimental features are not guaranteed to work across releases. APIs may change or features may be removed without notice. Use at your own risk. + +## Requirements + +Each experimental technique must include: + +- **README.md** - Explains what the technique does, how to use it, current status, model support, and references +- **Working code** - Clear, readable implementation +- **Comprehensive tests** - Good test coverage demonstrating correctness +- **Detailed documentation** - Clear docs on usage, APIs, and behavior +- **Example** - Demonstrating usage +- **Model support list** - Which models/frameworks are supported +- **Deployment info** - Supported deployment frameworks (TensorRT-LLM, vLLM, SGLang, etc.) and whether custom kernels are required +- **requirements.txt** - Additional dependencies beyond base modelopt +- **License headers** - Apache 2.0 headers on all Python files + +## Example Structures + +Organize your code however makes sense. Here are some examples: + +**Simple flat structure:** + +```text +experimental/my_technique/ +├── README.md +├── requirements.txt +├── my_technique.py +├── test_my_technique.py +└── example.py +``` + +**Package structure:** + +```text +experimental/my_technique/ +├── README.md +├── requirements.txt +├── my_technique/ +│ ├── __init__.py +│ ├── core.py +│ └── config.py +├── tests/ +│ └── test_core.py +└── examples/ + └── example_usage.py +``` + +## Quality Standards + +Experimental code must meet quality standards: + +- Comprehensive test coverage required +- Clear documentation required +- Pass all pre-commit checks + +## PR Guidelines + +Keep PRs focused and reviewable: + +- **Split large features**: Break complex techniques into multiple PRs if needed +- **Reasonable scope**: PRs with tens of thousands of lines are difficult to review +- **Incremental development**: Consider submitting core functionality first, then enhancements +- If your technique is large, discuss the implementation plan in an issue first + +## Example Documentation Template + +Your technique's README.md should include: + +```markdown +# Your Technique Name + +Brief description of the optimization technique. + +## Model Support + +| Model/Framework | Supported | Notes | +|-----------------|-----------|-------| +| LLMs (Llama, GPT, etc.) | ✅ | Tested on Llama 3.1 | +| Diffusion Models | ❌ | Not yet supported | +| Vision Models | ✅ | Experimental | + +## Deployment + +| Framework | Supported | Notes | +|-----------|-----------|-------| +| TensorRT-LLM | ✅ | Requires custom kernel | +| vLLM | ❌ | Not yet supported | +| SGLang | ✅ | Uses standard ops | + +## Usage + +\`\`\`python +from experimental.my_technique import my_optimize +... +\`\`\` + +## Status + +Current state: Prototype + +Known issues: +- Issue 1 +- Issue 2 + +## References + +- [Paper](link) +- [Code repository](link) +- [Project page](link) +- [Related work](link) +``` + +## Path to Production + +When a technique is ready for production (proven effective, stable API, full tests, comprehensive docs), it can be promoted to the main `modelopt` package. + +**Contributors**: Open an issue proposing graduation with evidence of effectiveness and stability. + +**Users**: If you find an experimental feature valuable, open a GitHub issue requesting promotion to production. User demand is a key signal for production readiness. + +## Questions? + +Open a GitHub issue with `[experimental]` prefix. diff --git a/experimental/__init__.py b/experimental/__init__.py new file mode 100644 index 0000000000..cc4a318199 --- /dev/null +++ b/experimental/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental optimization techniques for Model Optimizer. + +This package contains experimental and research-stage optimization algorithms +that are under active development. APIs may change without notice. + +Warning: + Code in this package is experimental and not covered by semantic versioning. + Use at your own risk in production environments. +""" + +import warnings + +warnings.warn( + "The 'experimental' package contains unstable APIs that may change. " + "Use at your own risk in production environments.", + FutureWarning, + stacklevel=2, +) + +__all__ = [] diff --git a/experimental/dms/ARCHITECTURE.md b/experimental/dms/ARCHITECTURE.md new file mode 100644 index 0000000000..e884d246ba --- /dev/null +++ b/experimental/dms/ARCHITECTURE.md @@ -0,0 +1,130 @@ +# DMS Architecture and Advanced Options + +This document describes DMS internals, configuration options, and how to extend the codebase. + +## Code Details + +### Eviction Decisions + +DMS supports two ways to compute the eviction decision: + +- **Extracted from a single neuron of a key or query vector**: see Section 3.1 of [Dynamic Memory Compression: Retrofitting LLMs for Accelerated Inference](https://arxiv.org/pdf/2403.09636). Enable with `dms_separate_alpha=False`. +- **Produced by a learned linear projection (adapter) from the hidden state**: see Section 3.2 of [Inference-Time Hyper-Scaling with KV Cache Compression](https://arxiv.org/pdf/2506.05345). Enable with `dms_separate_alpha=True`. + +You can also choose the granularity of eviction decisions: + +- `dms_alpha_per: "head"`: decisions are made independently per attention head (KV cache lengths may differ across heads). +- `dms_alpha_per: "layer"`: decisions are shared across heads within a layer (all heads in the layer keep the same number of tokens). + +During training, decision logits are augmented with Gumbel noise to enable differentiable gating (`dms.core.get_gating_with_noise`). During inference, a hard threshold is used. + +### Attention + +The DMS attention implementation (given decision logits) can be found in `dms/attention.py` (see `dms_attn_train_mode`). + +### Loss Function + +Training uses knowledge distillation with forward KL divergence between student and teacher logits, computed in `dms/training/engine.py` (`distillation_loss`). This is combined with a DMS compression loss that encourages the model to match the target eviction fraction. + +### DMS Schedule + +The compression ratio increases linearly from `initial_cr` (typically 1.0) to `final_cr` (e.g., 16.0) over `final_step` training steps. See `dms_schedule()` in `dms/training/engine.py`. + +## Advanced Options + +### Chunked Prefill + +Chunked prefill reduces peak memory usage during the prefill phase by processing the input sequence in fixed-size chunks. Set the chunk size (in tokens) via: + +```python +Qwen3ForCausalLMDMS.from_pretrained(..., dms_chunked_prefill=4096) +``` + +### Cache Preallocation + +The paged KV cache uses a dynamically resizable per-attention-layer block table (similar to `std::vector` in C++), growing as needed during generation. If you know your maximum context length ahead of time, you can preallocate to avoid runtime reallocations: + +```python +Qwen3ForCausalLMDMS.from_pretrained(..., dms_preallocate_for_tokens=2048) +``` + +## Retrofitting a New Model Family + +To add DMS support for a new model family, create a new directory under `models/`: + +```bash +models/new_model/ +├── configuration_new_model_dms.py # Config extending the base model config +├── extract.py # Checkpoint extraction +├── modeling_new_model_dms.py # Model with DMS attention +└── train.py # Training entry point +``` + +The model-specific code should: + +1. Extend the model's config class with DMS parameters (see `models/qwen3/configuration_qwen3_dms.py`). +2. Override the attention forward pass and call: + - `dms.core.prepare_attention_input` + - `dms.attention.dms_attention` +3. Add `dms_proj_alpha` and `dms_proj_alpha_norm` layers to the attention layer. +4. Add a YAML config under `configs/`. + +Core DMS operations (`prepare_attention_input`, `dms_attention`, `post_process_attention_output`) are model-agnostic; model-specific code provides its Q/K/V projections and any required norms as inputs. + +## Adding a New Dataset + +To add a new training dataset, edit `dms/training/data.py`: + +1. Define `filter_fn` and `extract_fn` for your dataset. +2. Create a `DatasetInfo` instance. + +Example: + +```python +def my_dataset_filter_fn(ds_elem): + return ds_elem["quality_score"] > 0.8 + +def my_dataset_extract_fn(ds_elem): + return { + "conversation": [ + {"role": "user", "content": ds_elem["prompt"]}, + {"role": "assistant", "content": ds_elem["response"]}, + ] + } + +MyNewDataset = DatasetInfo( + args=("org/my-dataset",), + kwargs={"split": "train"}, + filter_fn=my_dataset_filter_fn, + extract_fn=my_dataset_extract_fn, +) +``` + +Then reference it in your YAML config: + +```yaml +data: + blend: "MyNewDataset:0.5,OpenR1Math220k:0.5" +``` + +## Checkpoint Resume + +To resume training from the latest checkpoint, set the following in your YAML config: + +```yaml +hf_trainer: + resume_from_checkpoint: "auto" +``` + +This auto-detects the latest `checkpoint-N` directory under the output directory. You can also specify an explicit path: + +```yaml +hf_trainer: + resume_from_checkpoint: outputs/qwen3_8b/checkpoint-300 +``` + +Resume works because: + +- The Hugging Face Trainer restores optimizer state, LR scheduler state, the training step counter, and RNG states. +- The DMS schedule is deterministic given the current training step. +- Gumbel noise is seeded from `step + process_index + grad_acc_step`. diff --git a/experimental/dms/README.md b/experimental/dms/README.md new file mode 100644 index 0000000000..5e49f011af --- /dev/null +++ b/experimental/dms/README.md @@ -0,0 +1,134 @@ +# Dynamic Memory Sparsification (DMS) + +A minimal, optimized implementation of the DMS algorithm for KV-cache compression, as described in: + +> **Inference-Time Hyper-Scaling with KV Cache Compression** +> Adrian Łańcucki, Konrad Staniszewski, Piotr Nawrot, Edoardo M. Ponti +> Paper: [https://arxiv.org/abs/2506.05345](https://arxiv.org/abs/2506.05345) +> NeurIPS: [https://neurips.cc/virtual/2025/loc/san-diego/poster/119605](https://neurips.cc/virtual/2025/loc/san-diego/poster/119605) + +Inference-time scaling trades efficiency for improved reasoning by generating longer sequences. In Transformer LLMs, generation cost is often bottlenecked by the size of the key-value (KV) cache. DMS addresses this by learning a KV cache eviction policy that compresses the cache while preserving accuracy. + +## How it works + +DMS learns a per-head eviction policy that determines which KV cache entries to keep during generation. Rather than immediately discarding tokens, DMS delays eviction decisions, implicitly merging representations and preserving critical information. During training, the compression ratio is gradually increased from 1× to a target value (e.g., 8×), using knowledge distillation to match the outputs of an uncompressed teacher model. + +## What makes DMS practical + +- Achieves **8× compression** with minimal accuracy loss +- Adapter training: the default recipe trains eviction adapters only and freezes base weights for efficiency +- Requires **~250 training steps** (about **4 hours on 8× H100**) to adapt Qwen3-8B +- Drop-in replacement for Hugging Face models via a custom cache that supports variable sequence lengths across attention heads + +| Model family | Size | Training time (8× H100) | +|------------|------|--------------------------| +| Qwen3 | 8B | ~4 hours | + +--- + +## Quick start: Retrofitting Qwen3-8B with DMS + +### Installation + +This repository is designed to run inside an NVIDIA PyTorch container: + +```bash +docker pull nvcr.io/nvidia/pytorch:25.11-py3 +``` + +Clone and install: + +```bash +git clone https://github.com/NVIDIA/Model-Optimizer +cd experimental/dms +pip install -e . +``` + +This single install provides everything needed for training and evaluation (including lm-eval-harness). + +### Train DMS adapters + +**Note:** The number of GPUs determines the effective batch size. The configuration below was tested on a DGX H100 with 8× H100 80GB GPUs. For debugging with a smaller compute budget (e.g., a single RTX 5090), see [`scripts/train_small_debug.sh`](scripts/train_small_debug.sh). + +```bash +bash scripts/train.sh configs/qwen3_8b.yaml +``` + +This freezes the original Qwen3-8B weights and trains only the DMS eviction-policy parameters using knowledge distillation. Training completes in ~4 hours on a single DGX H100 node. + +The trained student model is saved to `outputs/qwen3_8b/student_model/` at the end of training. + +To resume training from the latest checkpoint, set `resume_from_checkpoint: "auto"` in the YAML config. + +### Extract from an intermediate checkpoint (optional) + +To extract a model from an intermediate checkpoint, run: + +```bash +python -m models.qwen3.extract \ + --config outputs/qwen3_8b/config.yaml \ + --checkpoint outputs/qwen3_8b/checkpoint-238 +``` + +### Evaluate + +Evaluate on the RULER long-context benchmark: + +```bash +bash scripts/evaluate.sh outputs/qwen3_8b/student_model +``` + +**Prerequisite:** The saved model relies on the `dms` package for its attention and cache implementations. Ensure `dms` is installed (`pip install -e .`) in any environment where you load the model for inference or evaluation. + +--- + +## Repository structure + +```bash +. +├── configs # YAML experiment configs +│   └── qwen3_8b.yaml +├── dms # Core DMS library (pip install -e .) +│   ├── attention_prefill.py # Exact prefill with eviction-based masking +│   ├── attention.py # DMS attention: train + inference modes +│   ├── cache_paged.py # Paged cache with block-based memory management +│   ├── cache.py # KV cache: HF wrapper + combined + contiguous +│   ├── core.py # Shared ops: prepare_attention_input, gating, chunked prefill +│   └── training +│   ├── data.py # Data pipeline: loading, blending, tokenization +│   └── engine.py # Distillation, model config, noise, trainer state +├── ARCHITECTURE.md +├── example_inference.ipynb +├── models # Model-specific adaptations +│   └── qwen3 +│   ├── configuration_qwen3_dms.py # Qwen3ConfigDMS +│   ├── extract.py # Checkpoint extraction +│   ├── modeling_qwen3_dms.py # Qwen3ForCausalLMDMS +│   └── train.py # Training entry point +└── scripts # Launch scripts +    ├── evaluate.sh +    └── train.sh +``` + +For code details, advanced options, and guides on extending DMS, see [ARCHITECTURE.md](ARCHITECTURE.md). + +## Limitations + +This repository currently supports training eviction adapters only and keeps base model weights frozen. This training approach can achieve comparable accuracy while being roughly two orders of magnitude cheaper than full fine-tuning. In contrast, the original recipe used in the paper updates all model weights during training; we plan to support it in the near future. + +For inference, this repository currently supports a single prefill-then-generate workflow. Multi-turn conversations with interleaved `prefill, generate, prefill, ...` steps are not yet optimized: the cache must be reset between independent sequences, and a slow fallback is used that simulates generation via repeated prefill steps. See [example_inference.ipynb](./example_inference.ipynb) for details. + +## Citation + +If you found DMS useful, please cite: + +```bibtex +@inproceedings{ + lancucki2025inferencetime, + title={Inference-Time Hyper-Scaling with {KV} Cache Compression}, + author={Adrian {\L}a{\'n}cucki and Konrad Staniszewski and Piotr Nawrot and Edoardo Ponti}, + booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems}, + year={2025}, + url={https://openreview.net/forum?id=8ZiElzQxf1} +} +``` diff --git a/experimental/dms/configs/qwen3_1.7b.yaml b/experimental/dms/configs/qwen3_1.7b.yaml new file mode 100644 index 0000000000..c889c27b21 --- /dev/null +++ b/experimental/dms/configs/qwen3_1.7b.yaml @@ -0,0 +1,64 @@ +# DMS debug configuration for Qwen3-1.7B +# Designed for debugging and code optimization on a limited compute budget. + +model: + name: Qwen/Qwen3-1.7B + dtype: float32 + forward_fn_kwargs: + train_attn_kwargs: + kernel_options: + BLOCK_M1: 16 + BLOCK_M2: 16 + BLOCK_N1: 16 + BLOCK_N2: 16 + +dms: + alpha_scale: 100.0 + initial_alpha_offset: 5.0 + window_size: 512 + disable_eviction: false + separate_alpha: true + alpha_per: head + tau: 0.1 + initial_cr: 1.0 + final_cr: 16.0 + final_step: 510 + +data: + blend: "OpenR1Math220k:1.0" + train_samples: 4000 + max_length: 8192 + concat_always_start_new: true + process_vocab_using_chunk: 4096 + tokenizer_kwargs: + enable_thinking: true + +hf_trainer: + output_dir: outputs/qwen3_1.7b_small + run_name: dms_qwen3_1.7b_small + max_steps: 544 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 16 + learning_rate: 3.0e-5 + weight_decay: 0.0 + warmup_steps: 0 + lr_scheduler_type: constant + save_strategy: steps + save_steps: 34 + save_total_limit: 5 + logging_strategy: steps + logging_steps: 1 + gradient_checkpointing: false + tf32: false + bf16: true + save_safetensors: false + adam_beta1: 0.9 + adam_beta2: 0.95 + max_grad_norm: 1.0 + seed: 42 + fsdp: "full_shard offload" + fsdp_config: + use_orig_params: true + sync_module_states: true + activation_checkpointing: true + resume_from_checkpoint: # null = fresh start, "auto" = latest, or explicit path diff --git a/experimental/dms/configs/qwen3_8b.yaml b/experimental/dms/configs/qwen3_8b.yaml new file mode 100644 index 0000000000..ab69b055fb --- /dev/null +++ b/experimental/dms/configs/qwen3_8b.yaml @@ -0,0 +1,62 @@ +# DMS training configuration for Qwen3-8B +# +# Usage: +# accelerate launch -m models.qwen3.train --config configs/qwen3_8b.yaml +# +# To resume from latest checkpoint: +# Set resume_from_checkpoint to "auto" below, or pass an explicit path. + +model: + name: Qwen/Qwen3-8B + dtype: float32 + +dms: + alpha_scale: 100.0 + initial_alpha_offset: 5.0 + window_size: 512 + disable_eviction: false + separate_alpha: true + alpha_per: head + tau: 0.1 + initial_cr: 1.0 + final_cr: 16.0 + final_step: 510 + +data: + blend: "OpenR1Math220k:1.0" + train_samples: 4000 + max_length: 32768 + concat_always_start_new: true + process_vocab_using_chunk: 4096 + tokenizer_kwargs: + enable_thinking: true + +hf_trainer: + output_dir: outputs/qwen3_8b + run_name: dms_qwen3_8b + max_steps: 544 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 3.0e-5 + weight_decay: 0.0 + warmup_steps: 0 + lr_scheduler_type: constant + save_strategy: steps + save_steps: 34 + save_total_limit: 5 + logging_strategy: steps + logging_steps: 1 + gradient_checkpointing: false + tf32: false + bf16: true + save_safetensors: false + adam_beta1: 0.9 + adam_beta2: 0.95 + max_grad_norm: 1.0 + seed: 42 + fsdp: "full_shard offload" + fsdp_config: + use_orig_params: true + sync_module_states: true + activation_checkpointing: true + resume_from_checkpoint: # null = fresh start, "auto" = latest, or explicit path diff --git a/experimental/dms/dms/__init__.py b/experimental/dms/dms/__init__.py new file mode 100644 index 0000000000..8d96709dc9 --- /dev/null +++ b/experimental/dms/dms/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Dynamic Memory Sparsification (DMS) package.""" diff --git a/experimental/dms/dms/attention.py b/experimental/dms/dms/attention.py new file mode 100644 index 0000000000..4beaa3727f --- /dev/null +++ b/experimental/dms/dms/attention.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMS attention: dispatch, training mode (FlexAttention), and inference mode (Flash Attention).""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import torch +from torch.nn.attention.flex_attention import flex_attention + +from dms.cache import DMSCache +from dms.logging import get_logger + +if TYPE_CHECKING: + from dms.cache import DMSCombinedCacheLayer + +logger = get_logger("Attention") + +try: + from flash_attn import flash_attn_with_kvcache +except ImportError as e: + logger.warning(f"Error importing flash_attn_with_kvcache: {e}") + flash_attn_with_kvcache = None + +try: + from dms.attention_prefill import dms_run_prefill_flex +except ImportError as e: + logger.warning(f"Error importing dms_run_prefill_flex: {e}") + dms_run_prefill_flex = None + + +# ============================================================================= +# Dispatch +# ============================================================================= + + +def dms_attention( + new_q_flash: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + decisions: torch.Tensor, + decision_logits: torch.Tensor, + attention_mask: torch.Tensor, + layer_idx: int, + dms_cache: DMSCache | None, + attn_scaling: float, + window_size: int, + train_attn_kwargs: dict[str, Any] = {}, +): + """Handles prefill/inference of DMS (Dynamic Memory Sparsification). + + If dms_cache is None, we are in train mode, otherwise we are in eval mode. + """ + if dms_cache is None: + # train mode + attn_output = dms_attn_train_mode( + q_flash=new_q_flash, + k=new_k, + v=new_v, + decision_logits=decision_logits, + attention_mask=attention_mask, + layer_idx=layer_idx, + attn_scaling=attn_scaling, + window_size=window_size, + train_attn_kwargs=train_attn_kwargs, + ) + else: + # eval mode + decisions = decisions.to(torch.int32) + + attn_output = dms_attn_eval_mode( + new_q_flash=new_q_flash, + new_k=new_k, + new_v=new_v, + decisions=decisions, + attention_mask=attention_mask, + layer_idx=layer_idx, + dms_cache=dms_cache, + attn_scaling=attn_scaling, + ) + + return attn_output + + +# ============================================================================= +# Training mode (FlexAttention with soft gating) +# ============================================================================= + +MASK_VALUE = -50000.0 # score used to mask tokens in attention + + +def dms_attn_train_mode( + q_flash: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + decision_logits: torch.Tensor, + attention_mask: torch.Tensor, + layer_idx: int, + attn_scaling: float, + window_size: int, + train_attn_kwargs: dict[str, Any] = {}, +): + """Perform DMS attention in training mode using FlexAttention with soft gating.""" + page_batch, seq_len_qf, gqa_factor, head_dim_qf = q_flash.size() + batch, head_k, seq_len_k, head_dim_k = k.size() + assert page_batch == batch * head_k, ( + f"page_batch: {page_batch} != batch * head_k: {batch * head_k}" + ) + assert seq_len_qf == seq_len_k, f"seq_len_qf: {seq_len_qf} != seq_len_k: {seq_len_k}" + assert head_dim_qf == head_dim_k, f"head_dim_qf: {head_dim_qf} != head_dim_k: {head_dim_k}" + + seq_len = seq_len_k + + assert v.size() == k.size(), f"v: {v.size()} k: {k.size()}" + + assert decision_logits.size() == (batch, head_k, seq_len), ( + f"decision_logits.size: {decision_logits.size()} != (batch, head_k, seq_len): {(batch, head_k, seq_len)}" + ) + assert attention_mask is None or attention_mask.ndim == 4, ( + f"attention_mask.ndim: {attention_mask.ndim} is not 4" + ) + assert layer_idx >= 0, f"layer_idx: {layer_idx} is not >= 0" + + decision_logits = decision_logits.reshape(page_batch, seq_len) + # note that dms has shifted the decision logits by 1 + decision_logits = torch.nn.functional.pad(decision_logits[..., 1:], (0, 1), value=0.0) + dms_mask_values = torch.nn.functional.logsigmoid(-decision_logits) + k = k.reshape(page_batch, 1, seq_len, head_dim_k) + v = v.reshape(page_batch, 1, seq_len, head_dim_k) + + q_flash = q_flash.transpose(1, 2) + + def score_mod(score, b, h, q_idx, k_idx): + causal = q_idx >= k_idx + within_sliding_window = q_idx - k_idx <= window_size + + causal = causal.to(score.dtype) + within_sliding_window = within_sliding_window.to(score.dtype) + + modified_score = within_sliding_window * score + (1 - within_sliding_window) * ( + dms_mask_values[b, k_idx] + score + ) + + return (1 - causal) * MASK_VALUE + causal * modified_score + + attn_output = torch.compile(flex_attention)( + query=q_flash, + key=k, + value=v, + score_mod=score_mod, + scale=attn_scaling, + enable_gqa=True, + **train_attn_kwargs, + ) + + attn_output = attn_output.reshape(batch, head_k, gqa_factor, seq_len_qf, head_dim_qf).transpose( + 2, 3 + ) + + return attn_output + + +# ============================================================================= +# Inference mode (Flash Attention + paged KV cache) +# ============================================================================= + + +def dms_attn_eval_mode( + new_q_flash: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + decisions: torch.Tensor, + attention_mask: torch.Tensor | None, + layer_idx: int, + dms_cache: DMSCache, + attn_scaling: float, + flash_attn_fn: Callable = flash_attn_with_kvcache, + prefill_attn_fn: Callable = dms_run_prefill_flex, + prefill_attn_fn_kwargs: dict = {}, +): + """Perform DMS attention in evaluation mode using flash attention or flex prefill.""" + assert decisions.dtype in (torch.int32, torch.long), ( + f"decisions.dtype: {decisions.dtype} is not int32 or long" + ) + batch, head_k, new_seq_len, head_dim_k = new_k.size() + page_batch, seq_len_qf, gqa_factor, head_dim_qf = new_q_flash.size() + + assert page_batch == batch * head_k, ( + f"page_batch: {page_batch} != batch * head_k: {batch * head_k}" + ) + assert seq_len_qf == new_seq_len, f"seq_len_qf: {seq_len_qf} != new_seq_len: {new_seq_len}" + assert head_dim_qf == head_dim_k, f"head_dim_qf: {head_dim_qf} != head_dim_k: {head_dim_k}" + + assert new_v.size() == new_k.size(), f"new_v: {new_v.size()} new_k: {new_k.size()}" + + assert decisions.size() == (batch, head_k, new_seq_len), ( + f"decisions.size: {decisions.size()} != (batch, head_k, new_seq_len): {(batch, head_k, new_seq_len)}" + ) + assert attention_mask is None or attention_mask.ndim == 4, ( + f"attention_mask.ndim: {attention_mask.ndim} is not 4" + ) + assert layer_idx >= 0, f"layer_idx: {layer_idx} is not >= 0" + + layer_cache: DMSCombinedCacheLayer = dms_cache[layer_idx] + + if layer_cache.is_inference_mode(): + layer_cache.update( + key_states=new_k, + value_states=new_v, + cache_kwargs={ + "eviction_info": decisions, + "sequence_lengths": None, + "cumulative_length": 1, + }, + ) + + attn_output = flash_attn_fn( + new_q_flash, + layer_cache.paged_cache.get_key_blocks(), + layer_cache.paged_cache.get_value_blocks(), + k=None, + v=None, + cache_seqlens=layer_cache.paged_cache.get_seq_lengths(), + causal=True, + softmax_scale=attn_scaling, + block_table=layer_cache.paged_cache.get_block_table(), + ) + + attn_output = attn_output.reshape(batch, head_k, seq_len_qf, gqa_factor, head_dim_qf) + + return attn_output + elif layer_cache.is_prefill_mode(): + if attention_mask is None: + attention_mask = torch.ones((batch, 1, 1, new_seq_len), dtype=torch.bool).to( + new_q_flash.device + ) + attention_mask = attention_mask.to(torch.bool) + assert attention_mask.ndim == 4, ( + f"attention_mask.ndim: {attention_mask.ndim} is not 4" + ) # [batch, head or 1, q_seq_len, k_seq_len] + + attention_mask = attention_mask[:, 0, -1, -new_seq_len:] + + attention_output = prefill_attn_fn( + q_flash=new_q_flash, + keys=new_k, + values=new_v, + decisions=decisions, + attn_mask=attention_mask, + cache=layer_cache, + attn_scaling=attn_scaling, + flash_attn_fn=flash_attn_fn, + **prefill_attn_fn_kwargs, + ) + return attention_output + else: + raise ValueError(f"Invalid mode: {layer_cache.current_mode}") diff --git a/experimental/dms/dms/attention_prefill.py b/experimental/dms/dms/attention_prefill.py new file mode 100644 index 0000000000..8f93fc1779 --- /dev/null +++ b/experimental/dms/dms/attention_prefill.py @@ -0,0 +1,494 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exact DMS prefill attention with eviction-based sparse masking and cache rewriting.""" + +from collections.abc import Callable + +import torch +from torch.nn.attention.flex_attention import AuxRequest, flex_attention + +from dms.cache import DMSCombinedCacheLayer +from dms.logging import get_logger + +logger = get_logger("AttentionPrefill") + + +# ============================================================================= +# Cache rewriting utilities +# ============================================================================= + + +def rewrite_cache_in_left_padding_style( + compressed_attention_mask: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + eviction_info: torch.Tensor, +): + """Rewrite cache entries in left-padded format, removing evicted tokens. + + Args: + - compressed_attention_mask: torch.Tensor of shape (batch, kv_seq_len) + that for each key specifies how far to the right the key is visible from the query. + 0 denotes attention masking. + - key_states: torch.Tensor of shape (batch, heads_kv, kv_seq_len, head_dim) + - value_states: torch.Tensor of shape (batch, heads_kv, kv_seq_len, head_dim) + - eviction_info: torch.Tensor of shape (batch, heads_kv, kv_seq_len). + + Returns: + - left padded, potentially pruned (eviction) version of the key, value and eviction info tensors + """ + _batch, heads_kv, kv_seq_len, head_dim = key_states.shape + assert heads_kv == 1, "kv heads should be merged into batch dim" + + new_space_size = kv_seq_len + 1 + + new_key_states, new_value_states, new_eviction_info, how_many_to_maintain = ( + _rewrite_cache_in_left_padding_style_aux( + compressed_attention_mask=compressed_attention_mask, + key_states=key_states, + value_states=value_states, + eviction_info=eviction_info, + heads_kv=heads_kv, + head_dim=head_dim, + new_space_size=new_space_size, + ) + ) + return new_key_states, new_value_states, new_eviction_info, how_many_to_maintain + + +@torch.compile() +def _rewrite_cache_in_left_padding_style_aux( + compressed_attention_mask: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + eviction_info: torch.Tensor, + heads_kv: int, + head_dim: int, + new_space_size: int, +): + batch, kv_seq_len = compressed_attention_mask.shape + # elements that should be evicted before cache end are removed + should_remove = compressed_attention_mask < kv_seq_len + + should_maintain = torch.logical_not(should_remove).to(torch.int32) + + maintain_id = should_maintain.cumsum(dim=1) + how_many_to_maintain = maintain_id[:, -1] + # write elements to their new positions omitting removed elements + write_indexer = new_space_size - how_many_to_maintain[:, None] + maintain_id - 1 + # removed elements will be written to position 0 + write_indexer[should_remove] = 0 + + new_key_states = torch.empty( + batch, heads_kv, new_space_size, head_dim, device=key_states.device, dtype=key_states.dtype + ) + new_value_states = torch.empty( + batch, + heads_kv, + new_space_size, + head_dim, + device=value_states.device, + dtype=value_states.dtype, + ) + new_eviction_info = torch.empty( + batch, new_space_size, device=eviction_info.device, dtype=eviction_info.dtype + ) + + assert write_indexer.shape == (batch, kv_seq_len), ( + f"write_indexer.shape: {write_indexer.shape} != (batch, kv_seq_len): {(batch, kv_seq_len)}" + ) + assert eviction_info.shape == (batch, kv_seq_len), ( + f"eviction_info.shape: {eviction_info.shape} != (batch, kv_seq_len): {(batch, kv_seq_len)}" + ) + new_eviction_info.scatter_(dim=1, index=write_indexer, src=eviction_info) + + write_indexer = write_indexer[:, None, :, None].broadcast_to( + batch, heads_kv, kv_seq_len, key_states.shape[3] + ) + new_key_states.scatter_(dim=2, index=write_indexer, src=key_states) + new_value_states.scatter_(dim=2, index=write_indexer, src=value_states) + + is_padding = ( + torch.arange(new_space_size - 1, -1, -1, device=new_key_states.device, dtype=torch.int32)[ + None, : + ] + >= how_many_to_maintain[:, None] + ) + + kv_states_mask = is_padding[:, None, :, None].broadcast_to( + batch, heads_kv, new_space_size, head_dim + ) + + new_key_states[kv_states_mask] = 0 + new_value_states[kv_states_mask] = 0 + new_eviction_info[is_padding[:, :]] = 0 + + return new_key_states, new_value_states, new_eviction_info, how_many_to_maintain + + +# ============================================================================= +# Prefill attention with FlexAttention +# ============================================================================= + + +def wrapped_flex_attention(query, key, value, score_mod, scale, enable_gqa): + """Run flex attention with LSE auxiliary output.""" + return flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + scale=scale, + enable_gqa=enable_gqa, + return_aux=AuxRequest(lse=True), + ) + + +@torch.compile() +def compiled_flex_attention(*args, **kwargs): + """Compile and run flex attention with LSE auxiliary output.""" + return wrapped_flex_attention(*args, **kwargs) + + +def get_mask( + dms_window_size: int, + compressed_attention_mask: torch.Tensor, + q_seq_len: int, + gqa_factor: int, + flex_attention_fn: Callable, +): + """Build a score modification function for DMS sparse attention masking.""" + _page_batch, kv_seq_len = compressed_attention_mask.shape + q_offset = kv_seq_len - q_seq_len + + def score_mod(score, b, h, q_idx, k_idx): + causal = q_idx + q_offset >= k_idx + + within_range = q_idx + q_offset < compressed_attention_mask[b, k_idx] + can_attend = torch.logical_and(causal, within_range).to(score.dtype) + + return can_attend * score + (1 - can_attend) * (-1e5) + + return score_mod, flex_attention_fn + + +def dms_prefill_flex( + queries: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + decisions: torch.Tensor, + attn_mask: torch.Tensor, + cache: DMSCombinedCacheLayer, + attn_scaling: float, + flash_attn_fn: Callable, + q_flash: torch.Tensor, + flex_attention_fn: Callable = compiled_flex_attention, +): + """Compute DMS prefill attention using FlexAttention with eviction-based sparse masking. + + This function performs attention over both recent (contiguous) and paged KV caches, + applying eviction decisions to determine which KV pairs remain visible. + It combines attention outputs from both cache types using softmax LSE rescaling, + then updates the cache by evicting dms marked tokens and rewriting in left-padded format. + + Args: + queries: Query tensor of shape (batch, heads_q, seq_len, head_dim). + keys: Key tensor of shape (batch, heads_kv, seq_len, head_dim). + values: Value tensor of shape (batch, heads_kv, seq_len, head_dim). + decisions: Eviction decisions per token (0=keep, 1=evict after window). + attn_mask: Boolean attention mask of shape (batch, seq_len). + cache: Combined DMS cache containing recent and paged KV storage. + attn_scaling: Scaling factor for attention scores. + flash_attn_fn: Flash attention function for paged cache computation. + q_flash: Query tensor in flash attention layout for paged cache. + flex_attention_fn: Flex attention function for local cache computation. + + Returns: + Attention output tensor of shape (batch, heads_q, seq_len, head_dim). + """ + assert decisions.dtype in (torch.int32, torch.long), ( + f"decisions.dtype: {decisions.dtype} is not int32 or long" + ) + batch_k, heads_kv, seq_len_k, head_dim_k = keys.size() + batch_q, heads_q, seq_len_q, head_dim_q = queries.size() + assert keys.size() == values.size(), f"keys.size: {keys.size()} != values.size: {values.size()}" + assert batch_q == batch_k, f"batch_q: {batch_q} != batch_k: {batch_k}" + assert seq_len_q == seq_len_k, ( + f"dms_prefill_flex handles cache by itself, " + f"so query and key must have the same sequence length: q_seq_len: {seq_len_q} k_seq_len: {seq_len_k}" + ) + assert head_dim_k == head_dim_q, f"head_dim_k: {head_dim_k} != head_dim_q: {head_dim_q}" + assert heads_kv <= heads_q, f"heads_kv: {heads_kv} > heads_q: {heads_q}" + assert decisions.size() == (batch_q, heads_kv, seq_len_k), ( + f"decisions.size: {decisions.size()} != (batch_q, heads_kv, seq_len_k): {(batch_q, heads_kv, seq_len_k)}" + ) + + batch = batch_q + head_dim = head_dim_q + + page_batch = batch * heads_kv + + gqa_factor = heads_q // heads_kv + + keys = keys.reshape(page_batch, 1, seq_len_k, head_dim_k) + values = values.reshape(page_batch, 1, seq_len_k, head_dim_k) + queries = queries.reshape(page_batch, gqa_factor, seq_len_q, head_dim_q) + + decisions = decisions.reshape(page_batch, seq_len_k) + + assert attn_mask.size() == ( + batch, + seq_len_k, + ), f"Attention mask shape does not match: {attn_mask.size()} != {batch, seq_len_k}" + + assert attn_mask.dtype == torch.bool, f"Attention mask dtype is not bool: {attn_mask.dtype}" + + # transformers uses False to mask out positions + # here we use True to mask out positions + attn_mask = torch.logical_not(attn_mask) + + # used to zero out results for masked positions + results_masking = attn_mask[:, None, :, None] + + attn_mask = ( + attn_mask[:, None, :] + .broadcast_to(batch, heads_kv, seq_len_k) + .reshape(batch_q * heads_kv, seq_len_k) + ) + + # eviction info about i'th token is produced by i'th+1 token + # and may require carrying over to the cached kv pairs + eviction_info = decisions.clone() + eviction_info_carry = eviction_info[:, 0] + # we assume contiguous masking + eviction_info = torch.nn.functional.pad(eviction_info[:, 1:], (0, 1), value=0) + + assert eviction_info.shape == attn_mask.shape, ( + f"eviction_info: {eviction_info.shape} attn_mask: {attn_mask.shape}" + ) + eviction_info[attn_mask] = 2 # 0 - no eviction, 1 - dms eviction, 2 - attention mask + + assert isinstance(cache, DMSCombinedCacheLayer), ( + f"requires DMSCombinedCacheLayer, got {type(cache)}" + ) + + if cache.get_recent_cache_csize() > 0: + past_keys, past_values, past_cache_seq_lengths, past_eviction_info = ( + cache.get_recent_cache() + ) + + past_keys = past_keys.reshape(page_batch, 1, -1, head_dim) + past_values = past_values.reshape(page_batch, 1, -1, head_dim) + past_cache_seq_lengths = past_cache_seq_lengths.reshape(page_batch) + past_eviction_info = past_eviction_info.reshape(page_batch, -1) + + _, _, past_seq_len, _ = past_keys.shape + + keys = torch.cat([past_keys, keys], dim=2) + values = torch.cat([past_values, values], dim=2) + + assert past_eviction_info.size() == (page_batch, past_seq_len) + # cache should be left padded + past_eviction_info = torch.nn.functional.pad(past_eviction_info[:, 1:], (0, 1), value=0) + + past_eviction_info[:, -1] = eviction_info_carry + + # mask out padding in the prefix + padded_prefix_indexer = torch.arange( + past_seq_len - 1, -1, -1, device=keys.device, dtype=torch.int32 + ) + assert past_cache_seq_lengths.size() == (page_batch,) + padded_prefix_indexer = padded_prefix_indexer[None, :] >= past_cache_seq_lengths[:, None] + assert past_eviction_info.size() == (page_batch, past_seq_len) + past_eviction_info[padded_prefix_indexer] = 2 # 1 - dms eviction 2-attention mask + + eviction_info = torch.cat( + [ + past_eviction_info.to(torch.int32), + eviction_info, + ], + dim=1, + ) + + dms_window_size = cache.cont_cache.dms_window_size + total_padded_seq_len = keys.shape[2] + + # paddle paddle flashmask style attention mask + # that for each kv-pair we say till what position the it is visible from the query + compressed_attention_mask = torch.full_like( + eviction_info, total_padded_seq_len, dtype=torch.int32 + ) + assert compressed_attention_mask.shape == (page_batch, total_padded_seq_len), ( + f"compressed_attention_mask.shape: {compressed_attention_mask.shape}" + f" != (page_batch, total_padded_seq_len): {(page_batch, total_padded_seq_len)}" + ) + + position_indexer = torch.arange(total_padded_seq_len, device=keys.device, dtype=torch.int32) + position_indexer = position_indexer[None, :] + position_indexer = position_indexer.broadcast_to(page_batch, total_padded_seq_len) + + compressed_attention_mask[eviction_info == 2] = 0 # attention mask/padding + compressed_attention_mask[eviction_info == 1] = ( + position_indexer[eviction_info == 1] + + dms_window_size # Warning we do not support attention gaps + ) + + compressed_attention_mask = torch.clamp(compressed_attention_mask, max=total_padded_seq_len) + + score_mod_fn, attention_fn = get_mask( + dms_window_size=dms_window_size, + compressed_attention_mask=compressed_attention_mask, + q_seq_len=seq_len_q, + gqa_factor=gqa_factor, + flex_attention_fn=flex_attention_fn, + ) + + attn_output_local, aux_request = attention_fn( + query=queries, + key=keys, + value=values, + score_mod=score_mod_fn, + scale=attn_scaling, + enable_gqa=True, + ) + + attn_output_local = attn_output_local.reshape(batch, heads_q, seq_len_q, head_dim_q) + + if cache.get_paged_cache_csize() > 0: + paged_cache = cache.get_paged_cache() + + attention_output_paged, softmax_lse_paged = flash_attn_fn( + q_flash, + paged_cache.get_key_blocks(), + paged_cache.get_value_blocks(), + k=None, + v=None, + cache_seqlens=paged_cache.get_seq_lengths(), + causal=False, + softmax_scale=attn_scaling, + block_table=paged_cache.get_block_table(), + return_softmax_lse=True, + ) + + softmax_lse_paged = torch.where(torch.isinf(softmax_lse_paged), 0, softmax_lse_paged) + + attention_output_paged = attention_output_paged.reshape( + batch, heads_kv, seq_len_q, gqa_factor, head_dim_k + ).transpose(2, 3) + + attention_output_paged = attention_output_paged.reshape(batch, heads_q, seq_len_q, head_dim) + + softmax_lse_local = aux_request.lse + denom_local = torch.exp(softmax_lse_local.float()) + denom_local = denom_local.reshape(batch, heads_q, seq_len_q) + denom_paged = torch.exp(softmax_lse_paged.float()) + denom_paged = denom_paged.reshape(batch, heads_q, seq_len_q) + + new_denom = denom_local + denom_paged + + denom_changer_local = (denom_local / new_denom).to(attn_output_local.dtype) + denom_changer_local = denom_changer_local[:, :, :, None] + denom_changer_paged = (denom_paged / new_denom).to(attention_output_paged.dtype) + denom_changer_paged = denom_changer_paged[:, :, :, None] + + assert denom_changer_local.ndim == attn_output_local.ndim, ( + f"denom_changer_local.ndim: {denom_changer_local.ndim} != attn_output_local.ndim: {attn_output_local.ndim}" + ) + assert denom_changer_paged.ndim == attention_output_paged.ndim, ( + f"denom_changer_paged.ndim: {denom_changer_paged.ndim}" + f" != attention_output_paged.ndim: {attention_output_paged.ndim}" + ) + assert attention_output_paged.ndim == attn_output_local.ndim, ( + f"attention_output_paged.ndim: {attention_output_paged.ndim}" + f" != attn_output_local.ndim: {attn_output_local.ndim}" + ) + + attn_output = ( + attn_output_local * denom_changer_local + attention_output_paged * denom_changer_paged + ) + else: + attn_output = attn_output_local + + attn_output[results_masking.broadcast_to(batch, heads_q, seq_len_q, head_dim_q)] = 0 + + # performs eviction and rewrites cache in left padding style + new_key_states, new_value_states, new_eviction_info, seq_lengths = ( + rewrite_cache_in_left_padding_style( + compressed_attention_mask=compressed_attention_mask, + key_states=keys, + value_states=values, + eviction_info=eviction_info, + ) + ) + + cache.update( + key_states=new_key_states.reshape(batch, heads_kv, -1, head_dim), + value_states=new_value_states.reshape(batch, heads_kv, -1, head_dim), + cache_kwargs={ + "eviction_info": torch.nn.functional.pad( + new_eviction_info[..., :-1].reshape(batch, heads_kv, -1), + (1, 0), + value=0, + ), + "sequence_lengths": seq_lengths.reshape(batch, heads_kv), + "cumulative_length": seq_len_q, + }, + ) + + return attn_output + + +def dms_run_prefill_flex( + q_flash: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + decisions: torch.Tensor, + attn_mask: torch.Tensor, + cache: DMSCombinedCacheLayer, + attn_scaling: float, + flash_attn_fn: Callable, + flex_attention_fn: Callable = compiled_flex_attention, +): + """Run DMS prefill using FlexAttention, reshaping tensors for flash layout.""" + _page_batch, seq_len_q, gqa_factor, head_dim_q = q_flash.size() + batch, head_k, _seq_len_k, head_dim_k = keys.size() + + head_q = gqa_factor * head_k + + queries = q_flash.transpose(1, 2).reshape(batch, head_q, seq_len_q, head_dim_q) + + attn_output = dms_prefill_flex( + queries=queries, + keys=keys, + values=values, + decisions=decisions, + attn_mask=attn_mask, + cache=cache, + attn_scaling=attn_scaling, + flash_attn_fn=flash_attn_fn, + q_flash=q_flash, + flex_attention_fn=flex_attention_fn, + ) + + assert attn_output.shape == (batch, head_q, seq_len_q, head_dim_q), ( + f"attn_output.shape: {attn_output.shape} != {batch, head_q, seq_len_q, head_dim_q}" + ) + + attn_output = attn_output.reshape(batch, head_k, gqa_factor, seq_len_q, head_dim_k).transpose( + 2, 3 + ) + + return attn_output diff --git a/experimental/dms/dms/cache.py b/experimental/dms/dms/cache.py new file mode 100644 index 0000000000..e453747fc3 --- /dev/null +++ b/experimental/dms/dms/cache.py @@ -0,0 +1,539 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMS KV cache: HF-compatible wrapper, combined cache layer, and contiguous cache layer. + +The paged cache layer (DMSPagedCacheLayer) is in dms.cache_paged. +""" + +import functools +from enum import Enum +from typing import Any + +import torch +from transformers import CacheLayerMixin +from transformers.cache_utils import Cache + +from dms.cache_paged import DMSPagedCacheLayer + +# ============================================================================= +# Contiguous (non-paged) DMS cache layer +# ============================================================================= + + +class DMSContCacheLayer(CacheLayerMixin): + """Used for storing contiguous (non-paged) cache.""" + + def __init__( + self, + dms_window_size: int, + max_context_length: int, + block_size: int = 256, + growth_factor: float = 1.5, + accommodate_min_initial_context_length: int = 4096, + disable_eviction: bool = False, + ): + """Initialize contiguous cache layer.""" + super().__init__() + self.dms_window_size = dms_window_size + self.max_context_length = max_context_length + self.block_size = block_size + self.growth_factor = growth_factor + self.min_initial_context_length = accommodate_min_initial_context_length + self.disable_eviction = disable_eviction + + self.key_cache = None + self.value_cache = None + self.eviction_info = None + self.cache_seq_lengths = None + self.cumulative_length = 0 + + self.device = None + + def offload(self): + """Offload cache tensors to CPU.""" + if self.key_cache is not None: + self.key_cache = self.key_cache.to("cpu", non_blocking=True) + self.value_cache = self.value_cache.to("cpu", non_blocking=True) + self.eviction_info = self.eviction_info.to("cpu", non_blocking=True) + self.cache_seq_lengths = self.cache_seq_lengths.to("cpu", non_blocking=True) + + def prefetch(self): + """Prefetch cache tensors back to the original device.""" + if self.key_cache is not None and self.key_cache.device != self.device: + self.key_cache = self.key_cache.to(self.device, non_blocking=True) + self.value_cache = self.value_cache.to(self.device, non_blocking=True) + self.eviction_info = self.eviction_info.to(self.device, non_blocking=True) + self.cache_seq_lengths = self.cache_seq_lengths.to(self.device, non_blocking=True) + + def reset(self): + """Reset cache to uninitialized state.""" + self.key_cache = None + self.value_cache = None + self.eviction_info = None + self.cache_seq_lengths = None + self.cumulative_length = 0 + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder cache for beam search (not supported).""" + raise NotImplementedError("Beam search is not supported") + + def lazy_initialization(self, key_states: torch.Tensor): + """Lazy initialization placeholder.""" + return None + + def update(self): + """Update cache (not implemented for contiguous cache).""" + raise NotImplementedError("update method is not implemented") + + def is_initialized(self) -> bool: + """Check if the cache has been initialized.""" + return self.key_cache is not None + + def replace( + self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: dict[str, Any] + ): + """Replace the entire cache contents.""" + if self.device is None: + self.device = key_states.device + eviction_info = cache_kwargs["eviction_info"] + seq_lengths = cache_kwargs["sequence_lengths"] + cumulative_length = cache_kwargs["cumulative_length"] + + assert key_states is not None, "key_states is None" + assert value_states is not None, "value_states is None" + assert eviction_info is not None, "eviction_info is None" + assert seq_lengths is not None, "seq_lengths is None" + + self.cumulative_length = cumulative_length + self.key_cache = key_states + self.value_cache = value_states + self.eviction_info = eviction_info + self.cache_seq_lengths = seq_lengths + + def get_cache( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get the cached key, value, sequence lengths, and eviction info.""" + return ( + self.key_cache, + self.value_cache, + self.cache_seq_lengths, + self.eviction_info, + ) + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the mask.""" + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object.""" + return self.max_context_length + + +# ============================================================================= +# Combined contiguous + paged cache layer +# ============================================================================= + + +class Mode(Enum): + """Cache operation modes.""" + + START = 0 + PREFILL = 1 + INFERENCE = 2 + + +class DMSCombinedCacheLayer(CacheLayerMixin): + """Used for handling prefill along with inference. + + Contiguous cache is used for recent tokens, paged cache is used for tokens outside of the sliding window. + """ + + def __init__( + self, + dms_window_size: int, + max_context_length: int, + block_size: int = 256, + growth_factor: float = 1.5, + accommodate_min_initial_context_length: int = 4096, + disable_eviction: bool = False, + ): + """Initialize combined cache with contiguous and paged sub-caches.""" + super().__init__() + self.dms_window_size = dms_window_size + self.block_size = block_size + self.disable_eviction = disable_eviction + self.paged_cache = DMSPagedCacheLayer( + dms_window_size=dms_window_size, + max_context_length=max_context_length, + block_size=block_size, + growth_factor=growth_factor, + accommodate_min_initial_context_length=accommodate_min_initial_context_length, + disable_eviction=True, + ) # For prefill & inference + self.cont_cache = DMSContCacheLayer( + dms_window_size=dms_window_size, + max_context_length=max_context_length, + block_size=block_size, + growth_factor=growth_factor, + accommodate_min_initial_context_length=accommodate_min_initial_context_length, + disable_eviction=True, + ) # For prefill + + self.max_context_length = max_context_length + + self.current_mode = Mode.START + self.cumulative_length = 0 + + def offload(self): + """Offload cache tensors to CPU.""" + self.paged_cache.offload() + self.cont_cache.offload() + + def prefetch(self): + """Prefetch cache tensors back to the original device.""" + self.paged_cache.prefetch() + self.cont_cache.prefetch() + + def reset(self): + """Reset both sub-caches and return to start mode.""" + self.paged_cache.reset() + self.cont_cache.reset() + self.current_mode = Mode.START + self.cumulative_length = 0 + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorder cache for beam search (not supported).""" + raise NotImplementedError("Beam search is not supported") + + def lazy_initialization(self, key_states: torch.Tensor): + """Lazy initialization placeholder.""" + return None + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: dict[str, Any], + ): + """Update the cache with new key-value states and eviction info.""" + assert self.current_mode != Mode.START + + batch, head, seq_len, _head_dim = key_states.size() + seq_lengths = cache_kwargs["sequence_lengths"] + eviction_info = cache_kwargs["eviction_info"] + cumulative_length = cache_kwargs["cumulative_length"] + self.cumulative_length += cumulative_length + + assert value_states.size() == key_states.size(), ( + f"value_states.size: {value_states.size()} != key_states.size: {key_states.size()}" + ) + assert seq_lengths is None or seq_lengths.size() == (batch, head), ( + f"seq_lengths.size: {seq_lengths.size()} != (batch, head): {(batch, head)}" + ) + assert eviction_info.size() == ( + batch, + head, + seq_len, + ), ( + f"eviction info size: {eviction_info.size()} should be {(batch, head, seq_len)}" + ) # Eviction info is right shifted by 1 + + if self.current_mode == Mode.PREFILL: + assert seq_lengths is not None + + keys_recent = key_states[:, :, -self.cont_cache.dms_window_size :, :] + keys_to_paged_cache = key_states[:, :, : -self.cont_cache.dms_window_size, :] + + values_recent = value_states[:, :, -self.cont_cache.dms_window_size :, :] + values_to_paged_cache = value_states[:, :, : -self.cont_cache.dms_window_size, :] + + eviction_info_recent = eviction_info[..., -self.cont_cache.dms_window_size :] + + seq_lengths_recent = torch.clamp(seq_lengths, max=self.cont_cache.dms_window_size) + seq_lengths_to_paged_cache = (seq_lengths - self.cont_cache.dms_window_size).clamp( + min=0 + ) + + # move what we can to the paged cache + + cumulative_length_to_paged_cache = keys_to_paged_cache.shape[2] + + if cumulative_length_to_paged_cache > 0 and seq_lengths_to_paged_cache.max() > 0: + self.paged_cache.fast_update_ignore_eviction( + key_states=keys_to_paged_cache, + value_states=values_to_paged_cache, + sequence_lengths=seq_lengths_to_paged_cache, + ) + + self.cont_cache.replace( + key_states=keys_recent, + value_states=values_recent, + cache_kwargs={ + "eviction_info": eviction_info_recent, + "sequence_lengths": seq_lengths_recent, + "cumulative_length": keys_recent.shape[2], + }, + ) + elif self.current_mode == Mode.INFERENCE: + assert seq_lengths is None, "seq_lengths is not None in inference mode" + assert cumulative_length == 1, f"cumulative_length: {cumulative_length} != 1" + assert self.cont_cache.cumulative_length == 0 + + self.paged_cache.update( + key_states=key_states, + value_states=value_states, + cache_kwargs={ + "eviction_info": eviction_info, + "sequence_lengths": None, + "cumulative_length": 1, + }, + ) + else: + raise ValueError(f"Invalid mode: {self.current_mode}") + + def prefill_mode(self): + """Switch to prefill mode.""" + if self.current_mode == Mode.PREFILL: + pass + elif self.current_mode == Mode.INFERENCE: + # Revert last self.window_size keys and values to contiguous cache + raise NotImplementedError("Cannot revert to prefill mode from inference mode") + elif self.current_mode == Mode.START: + pass + else: + raise ValueError(f"Invalid mode: {self.current_mode}") + + self.paged_cache.enable_prefill_mode() + self.current_mode = Mode.PREFILL + + def inference_mode(self): + """Switch to inference mode.""" + if self.current_mode == Mode.INFERENCE: + pass + elif self.current_mode == Mode.PREFILL: + key_states, value_states, seq_lengths, eviction_info = self.cont_cache.get_cache() + + self.current_mode = Mode.INFERENCE + + self.paged_cache.disable_prefill_mode(disable_eviction=self.disable_eviction) + + self.paged_cache.update( + key_states=key_states, + value_states=value_states, + cache_kwargs={ + "eviction_info": eviction_info, + "sequence_lengths": seq_lengths, + "cumulative_length": self.cont_cache.cumulative_length, + }, + ) + self.cont_cache.reset() + + elif self.current_mode == Mode.START: + self.current_mode = Mode.INFERENCE + else: + raise ValueError(f"Invalid mode: {self.current_mode}") + + def start_mode(self): + """Assert that the cache is in start mode.""" + assert self.current_mode == Mode.START, ( + f"current_mode: {self.current_mode} is not Mode.START" + ) + + def get_recent_cache(self): + """Get the recent contiguous cache contents.""" + assert self.current_mode == Mode.PREFILL, ( + f"current_mode: {self.current_mode} is not Mode.PREFILL" + ) + return self.cont_cache.get_cache() + + def get_recent_cache_csize(self): + """Get the cumulative length of the recent cache.""" + return self.cont_cache.cumulative_length + + def get_paged_cache_csize(self): + """Get the cumulative length of the paged cache.""" + return self.paged_cache.cumulative_length + + def get_paged_cache(self): + """Get the paged cache layer.""" + return self.paged_cache + + def is_inference_mode(self): + """Check if in inference mode.""" + return self.current_mode == Mode.INFERENCE + + def is_prefill_mode(self): + """Check if in prefill mode.""" + return self.current_mode == Mode.PREFILL + + def is_start_mode(self): + """Check if in start mode.""" + return self.current_mode == Mode.START + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Return the length and offset of the cache, used to generate the mask.""" + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object.""" + return self.paged_cache.max_context_length + + +# ============================================================================= +# HuggingFace-compatible DMS cache wrapper +# ============================================================================= + + +class DMSCache(Cache): + """HuggingFace Cache implementation for DMS with combined cache layers.""" + + def __init__( + self, + dms_window_size: int, + max_context_length: int, + offloading: bool = False, + offload_only_non_sliding: bool = False, + accommodate_min_initial_context_length: int = 4096, + disable_eviction: bool = False, + block_size: int = 256, + ): + """Initialize the DMS cache.""" + super().__init__( + layer_class_to_replicate=functools.partial( + DMSCombinedCacheLayer, + dms_window_size=dms_window_size, + max_context_length=max_context_length, + accommodate_min_initial_context_length=accommodate_min_initial_context_length, + disable_eviction=disable_eviction, + block_size=block_size, + ), + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + ) + + self.current_mode = Mode.START + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]: + """Convert to legacy cache format (not supported).""" + raise NotImplementedError("Not Supported") + + def _match_single_layer_mode(self, layer: DMSCombinedCacheLayer): + if self.current_mode == Mode.PREFILL: + layer.prefill_mode() + elif self.current_mode == Mode.INFERENCE: + layer.inference_mode() + elif self.current_mode == Mode.START: + layer.start_mode() + else: + raise ValueError(f"Invalid mode: {self.current_mode}") + + def _match_all_layers_mode(self): + for layer in self.layers: + assert isinstance(layer, DMSCombinedCacheLayer) + self._match_single_layer_mode(layer) + + def prefill_mode(self): + """Set all layers to prefill mode.""" + self.current_mode = Mode.PREFILL + self._match_all_layers_mode() + + def inference_mode(self): + """Set all layers to inference mode.""" + self.current_mode = Mode.INFERENCE + self._match_all_layers_mode() + + def is_prefill_mode(self): + """Check if in prefill mode.""" + return self.current_mode == Mode.PREFILL + + def is_inference_mode(self): + """Check if in inference mode.""" + return self.current_mode == Mode.INFERENCE + + def is_start_mode(self): + """Check if in start mode.""" + return self.current_mode == Mode.START + + @classmethod + def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]): + """Create from legacy cache (not supported).""" + raise NotImplementedError("Not Supported") + + def early_initialization( + self, + batch_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + ): + """Perform early initialization (not supported).""" + return None # not supported + + def __iter__(self): + raise NotImplementedError("Not Supported") + + def __getitem__(self, layer_idx: int): + while layer_idx >= len(self.layers): + self.layers.append(self.layer_class_to_replicate()) + self._match_single_layer_mode(self.layers[-1]) + return self.layers[layer_idx] + + def __setitem__(self, layer_idx: int, value: DMSCombinedCacheLayer): + while layer_idx >= len(self.layers): + self.layers.append(self.layer_class_to_replicate()) + self._match_single_layer_mode(self.layers[-1]) + self.layers[layer_idx] = value + self._match_single_layer_mode(self.layers[layer_idx]) + + def get_cr(self, get_per_layer_cr: bool = False) -> float | tuple[float, list[float]]: + """Compute the compression ratio across all cache layers.""" + per_elem = [] + for layer in self.layers: + assert isinstance(layer, DMSCombinedCacheLayer) + cum_seq_len = layer.get_seq_length() + if layer.paged_cache.cache_seq_lengths is None: + return 1.0 + sizes = layer.paged_cache.cache_seq_lengths.cpu() + + frac = sizes / max(cum_seq_len, 1) + per_elem.append(frac) + + per_elem = torch.stack(per_elem, dim=0) + total_cr = 1 / per_elem.mean() + + if get_per_layer_cr: + per_layer_cr = 1 / per_elem.mean(dim=-1) + return total_cr.item(), per_layer_cr.tolist() + else: + return total_cr.item() diff --git a/experimental/dms/dms/cache_paged.py b/experimental/dms/dms/cache_paged.py new file mode 100644 index 0000000000..2cd4df9de7 --- /dev/null +++ b/experimental/dms/dms/cache_paged.py @@ -0,0 +1,1054 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Paged DMS cache layer implementation with block-based memory management.""" + +import gc +import math +from typing import Any + +import torch +from transformers import CacheLayerMixin + + +def ceil_int_div(a: int, b: int) -> int: + """Return the ceiling of integer division a / b.""" + return (a + b - 1) // b + + +def float_ceil(a: float): + """Return the ceiling of a float value.""" + return math.ceil(a) + + +def _aux_potential_eviction( + vals_for_replacement: torch.Tensor, + to_be_evicted_table_block_id: torch.Tensor, + to_be_evicted_position_within_block: torch.Tensor, + to_be_evicted_mask: torch.Tensor, + block_table: torch.Tensor, + blocks: torch.Tensor, + page_batch_index: torch.Tensor, + last_table_block_id: torch.Tensor, + next_position_within_block: torch.Tensor, +): + """Adding a new element to KV cache may lead to eviction of the last element in the DMS sliding window.""" + # For each batch element the block table contains a list of blocks allocated for this batch element + block_ids = block_table[page_batch_index, to_be_evicted_table_block_id] + + # Override the last element of the sliding window with the new element if the last element of the sliding window + # is marked for the eviction and the window is full + blocks[block_ids, to_be_evicted_position_within_block, :, :] = ( + blocks[block_ids, to_be_evicted_position_within_block, :, :] + * (1 - to_be_evicted_mask[:, None, None]) + + vals_for_replacement[:, 0, None, :] * to_be_evicted_mask[:, None, None] + ) + + # Otherwise write the new element to the next position within the last allocated block + block_ids = block_table[page_batch_index, last_table_block_id] + blocks[block_ids, next_position_within_block, :, :] = blocks[ + block_ids, next_position_within_block, :, : + ] * to_be_evicted_mask[:, None, None] + vals_for_replacement[:, 0, None, :] * ( + 1 - to_be_evicted_mask[:, None, None] + ) + + +def _aux_no_eviction( + vals_for_replacement: torch.Tensor, + block_table: torch.Tensor, + blocks: torch.Tensor, + page_batch_index: torch.Tensor, + last_table_block_id: torch.Tensor, + next_position_within_block: torch.Tensor, +): + """Adding new element to kv cache without eviction of the last element.""" + # otherwise write the new element to the next position within the last + # allocated block + block_ids = block_table[page_batch_index, last_table_block_id] + blocks[block_ids, next_position_within_block, :, :] = vals_for_replacement[:, 0, None, :] + + +@torch.compile() +def _aux_update_single( + key_states: torch.Tensor, + value_states: torch.Tensor, + eviction_info: torch.Tensor, + recent_info: torch.Tensor | None, + recent_info_position: torch.Tensor | None, + block_table: torch.Tensor, + key_blocks: torch.Tensor, + value_blocks: torch.Tensor, + cache_seq_lengths: torch.Tensor, + page_batch_index: torch.Tensor, +) -> torch.Tensor: + """Updates the paged cache during token by token generation.""" + # page_batch, seq_len, head_dim = key_states.size() + # page_batch, seq_len = eviction_info.size() + # page_batch_index is a tensor of shape (page_batch,): 0, 1, 2, ... page_batch - 1 + block_size = key_blocks.size(1) + + last_table_block_id = cache_seq_lengths // block_size + next_position_within_block = cache_seq_lengths % block_size + + # `recent_info_position` points to the next position in the sliding window; when the sliding window is full, + # it points to the first position. Not-filled elements are not zeroed out and not marked for eviction + # (see recent_info initialization). + if recent_info is not None: # DMS eviction is enabled + assert recent_info_position is not None + eviction_candidate_info_position = recent_info_position % recent_info.size(1) + + eviction_candidate_info = recent_info[ + page_batch_index, eviction_candidate_info_position + ] # Note that this is zeroed out in the beginning + + # `eviction_candidate_info[:, 1]` is 1 when the element is marked for eviction and 0 otherwise + # `block_table[eviction_candidate_info[:, 0] // block_size]` is the block id where the element resides + # and `eviction_candidate_info[:, 0] % block_size` is the position (offset) within the block + to_be_evicted = eviction_candidate_info[:, 1] == 1 + to_be_evicted_kv = to_be_evicted.to(key_blocks.dtype) + to_be_evicted_int = to_be_evicted.to(torch.int32) + to_be_evicted_position = eviction_candidate_info[:, 0] + to_be_evicted_table_block_id = to_be_evicted_position // block_size + to_be_evicted_position_within_block = to_be_evicted_position % block_size + + _aux_potential_eviction( + vals_for_replacement=key_states, + to_be_evicted_table_block_id=to_be_evicted_table_block_id, + to_be_evicted_position_within_block=to_be_evicted_position_within_block, + to_be_evicted_mask=to_be_evicted_kv, + block_table=block_table, + blocks=key_blocks, + page_batch_index=page_batch_index, + last_table_block_id=last_table_block_id, + next_position_within_block=next_position_within_block, + ) + + _aux_potential_eviction( + vals_for_replacement=value_states, + to_be_evicted_table_block_id=to_be_evicted_table_block_id, + to_be_evicted_position_within_block=to_be_evicted_position_within_block, + to_be_evicted_mask=to_be_evicted_kv, + block_table=block_table, + blocks=value_blocks, + page_batch_index=page_batch_index, + last_table_block_id=last_table_block_id, + next_position_within_block=next_position_within_block, + ) + + final_position = to_be_evicted_position * to_be_evicted_int + (1 - to_be_evicted_int) * ( + cache_seq_lengths + ) + + previous_recent_info_position = ( + recent_info_position + recent_info.size(1) - 1 + ) % recent_info.size(1) + + # Update the eviction info for the previous element in the sliding window (if present) + recent_info[page_batch_index, previous_recent_info_position, 1] = ( + eviction_info[:, 0] * (cache_seq_lengths > 0).to(torch.int32) + ).to(torch.int32) + + # No info about eviction yet for the new element + recent_info[page_batch_index, eviction_candidate_info_position, 1] = 0 + recent_info[page_batch_index, eviction_candidate_info_position, 0] = final_position + + recent_info_position[...] += 1 + + cache_seq_lengths[...] = cache_seq_lengths + (1 - to_be_evicted_int) + + # At the beginning of this function call block_table[cache_seq_lengths // block_size] points to a block with + # at least one free position; need to maintain this invariant by detecting filled blocks + requires_free_page = torch.logical_and( + (cache_seq_lengths % block_size) == 0, to_be_evicted_int == 0 + ) + + else: # DMS eviction is disabled + _aux_no_eviction( + vals_for_replacement=key_states, + block_table=block_table, + blocks=key_blocks, + page_batch_index=page_batch_index, + last_table_block_id=last_table_block_id, + next_position_within_block=next_position_within_block, + ) + _aux_no_eviction( + vals_for_replacement=value_states, + block_table=block_table, + blocks=value_blocks, + page_batch_index=page_batch_index, + last_table_block_id=last_table_block_id, + next_position_within_block=next_position_within_block, + ) + cache_seq_lengths[...] = cache_seq_lengths + 1 + + requires_free_page = (cache_seq_lengths % block_size) == 0 + + return requires_free_page + + +def _aux_write_kv( + block_table: torch.Tensor, + blocks: torch.Tensor, + write_positions: torch.Tensor, + values: torch.Tensor, + page_batch_index: torch.Tensor, +): + _page_batch, _chunk_len = write_positions.size() + block_size = blocks.size(1) + block_table_id = write_positions // block_size + position_within_block = write_positions % block_size + + block_id = block_table[page_batch_index[:, None], block_table_id] + assert (block_id != -1).all(), f"block_id: {block_id} is -1" + + blocks[block_id, position_within_block, :, :] = values[:, :, None, :] + + +@torch.compile() +def _aux_update_many_handle_single_chunk( + update_key_chunk: torch.Tensor, + update_value_chunk: torch.Tensor, + eviction_info_chunk: torch.Tensor, + block_table: torch.Tensor, + key_blocks: torch.Tensor, + value_blocks: torch.Tensor, + cache_seq_lengths: torch.Tensor, + is_non_empty: torch.Tensor, + recent_info: torch.Tensor | None, + recent_info_position: torch.Tensor | None, + page_batch_index: torch.Tensor, + update_mask: torch.Tensor, + true_update_size: torch.Tensor, +) -> torch.Tensor: + """Used for prefilling the KV cache as each tensor has a fixed size. + + `true_update_size` represents the true number of elements to be added for each batch index. + """ + assert update_key_chunk.size() == update_value_chunk.size(), ( + f"update_key_chunk.size: {update_key_chunk.size()} != update_value_chunk.size: {update_value_chunk.size()}" + ) + page_batch, chunk_len, _head_dim = update_key_chunk.size() + assert recent_info is None or chunk_len < recent_info.size(1), ( + f"recent_info {recent_info.shape} {chunk_len}" + ) + + assert eviction_info_chunk.size() == (page_batch, chunk_len), ( + f"eviction_info_chunk.size: {eviction_info_chunk.size()} != (page_batch, chunk_len): {(page_batch, chunk_len)}" + ) + assert page_batch_index.size() == (page_batch,), ( + f"page_batch_index.size: {page_batch_index.size()} != (page_batch,): {(page_batch,)}" + ) + + block_size = key_blocks.size(1) + + device = update_key_chunk.device + + chunk_indexer = torch.arange(chunk_len, dtype=torch.int32, device=device) + + if recent_info is not None: # DMS eviction is enabled + assert recent_info_position is not None + # First we update the eviction info for the previous element if present + update_eviction_info_positions = (recent_info_position - 1) % recent_info.size(1) + update_eviction_info_mask = (cache_seq_lengths > 0).to(torch.int32) + + recent_info[page_batch_index, update_eviction_info_positions, 1] = ( + eviction_info_chunk[:, 0] * update_eviction_info_mask + + (1 - update_eviction_info_mask) + * recent_info[page_batch_index, update_eviction_info_positions, 1] + ).to(torch.int32) + + # The following trick handles variable lens: if the index is longer than true_update_size, then pad the index + # with the last element within the true_update_size, e.g., [0, 1, 2, 3, 4, 5] and true_update_size = [3] + # means that we have [0, 1, 2, 2, 2, 2] . This will later be used to write the same element multiple times + # while preserving the constant shapes of the tensors. + + potential_eviction_positions_in_recent_info = ( + recent_info_position[:, None] + + torch.minimum(chunk_indexer[None, :], true_update_size[:, None] - 1) + ) % recent_info.size(1) + + potential_eviction_positions_in_seq = recent_info[ + page_batch_index[:, None], potential_eviction_positions_in_recent_info, 0 + ] + confirmed_evictions_mask = ( + recent_info[ + page_batch_index[:, None], + potential_eviction_positions_in_recent_info, + 1, + ] + == 1 + ) + + confirmed_evictions_mask = torch.logical_and( + confirmed_evictions_mask, is_non_empty[:, None] + ) + + # Account for the padding with the last element (as described above) + # to get a proper count of confirmed evictions + confirmed_evictions_mask[:, 1:] = torch.logical_and( + confirmed_evictions_mask[:, 1:], + potential_eviction_positions_in_recent_info[:, 1:] + != potential_eviction_positions_in_recent_info[:, :-1], + ) + + confirmed_evictions_cum_sum = confirmed_evictions_mask.to(torch.int32).cumsum(dim=-1) + confirmed_evictions_mask = torch.logical_and( + confirmed_evictions_mask, + confirmed_evictions_cum_sum <= true_update_size[:, None], + ) + + # Count how many new positions are needed for each element of the batch + num_confirmed_evictions = confirmed_evictions_mask.to(torch.int32).sum(dim=-1) + new_positions_used = true_update_size - num_confirmed_evictions + + assert (new_positions_used >= 0).all(), ( + f"new_positions_used: {new_positions_used} is less than 0" + ) + assert new_positions_used.size() == (page_batch,), ( + f"new_positions_used.size: {new_positions_used.size()} != (page_batch,): {(page_batch,)}" + ) + + new_free_positions = cache_seq_lengths[:, None] + torch.clamp( + torch.minimum(chunk_indexer[None, :], new_positions_used[:, None] - 1), + min=0, + ) + + assert new_free_positions.size() == (page_batch, chunk_len), ( + f"new_free_positions.size: {new_free_positions.size()}" + f" != (page_batch, chunk_len): {(page_batch, chunk_len)}" + ) + assert new_free_positions.size() == potential_eviction_positions_in_seq.size(), ( + f"new_free_positions.size: {new_free_positions.size()}" + f" != potential_eviction_positions_in_seq.size: {potential_eviction_positions_in_seq.size()}" + ) + + potential_eviction_positions_in_seq = torch.cat( + [ + potential_eviction_positions_in_seq, + new_free_positions, + ], + dim=-1, + ) + + # Padding below allows for constant shape ops to take prefix + # of length new_positions_used from new_free_positions + confirmed_evictions_padding = torch.zeros_like(confirmed_evictions_mask) + padding_chunk_size = chunk_len - num_confirmed_evictions[:, None] + indexer = torch.minimum(chunk_indexer[None, :], torch.clamp(padding_chunk_size - 1, min=0)) + + confirmed_evictions_padding[page_batch_index[:, None], indexer] = True + # If only post eviction positions are used, then have writing padding that ends in the last of those positions, + # instead of the next free position + confirmed_evictions_padding = torch.logical_and( + confirmed_evictions_padding, padding_chunk_size > 0 + ) + + confirmed_evictions_mask = torch.cat( + [confirmed_evictions_mask, confirmed_evictions_padding], dim=-1 + ) + + pad_selector = (new_positions_used > 0).to(torch.int32)[:, None] + + potential_eviction_positions_in_seq[:, chunk_len:] = ( + pad_selector * potential_eviction_positions_in_seq[:, chunk_len:] + + (1 - pad_selector) * potential_eviction_positions_in_seq[:, [chunk_len - 1]] + ) + + new_write_positions = potential_eviction_positions_in_seq[confirmed_evictions_mask].reshape( + page_batch, chunk_len + ) + + # Always perform dummy write for empty sequences + new_write_positions = new_write_positions * is_non_empty[:, None] + cache_seq_lengths[ + :, None + ] * (~is_non_empty[:, None]) + + _aux_write_kv( + block_table=block_table, + blocks=key_blocks, + write_positions=new_write_positions, + values=update_key_chunk, + page_batch_index=page_batch_index, + ) + + _aux_write_kv( + block_table=block_table, + blocks=value_blocks, + write_positions=new_write_positions, + values=update_value_chunk, + page_batch_index=page_batch_index, + ) + + recent_indexer = torch.minimum( + chunk_indexer[None, :], torch.clamp(true_update_size[:, None] - 1, min=0) + ) + + recent_info_indexer = (recent_info_position[:, None] + recent_indexer) % recent_info.size(1) + + # update the info about last window positions + + non_empty_update = (true_update_size[:, None] > 0).to(torch.int32) + + recent_info[page_batch_index[:, None], recent_info_indexer, 0] = ( + new_write_positions * non_empty_update + + recent_info[page_batch_index[:, None], recent_info_indexer, 0] + * (1 - non_empty_update) + ).to(torch.int32) + + eviction_info_chunk = torch.cat( + [ + eviction_info_chunk[:, 1:], + torch.zeros_like(eviction_info_chunk[:, [0]]), + ], + dim=-1, + ) + recent_info[page_batch_index[:, None], recent_info_indexer, 1] = ( + eviction_info_chunk[:, :] * non_empty_update + + recent_info[page_batch_index[:, None], recent_info_indexer, 1] + * (1 - non_empty_update) + ).to(torch.int32) + + recent_info_position[...] += true_update_size + + cache_seq_lengths[...] += new_positions_used + + require_free_pages = torch.logical_and( + new_positions_used > 0, cache_seq_lengths % block_size == 0 + ) + else: + new_write_positions = cache_seq_lengths[:, None] + torch.clamp( + torch.minimum(chunk_indexer[None, :], true_update_size[:, None] - 1), min=0 + ) + + _aux_write_kv( + block_table=block_table, + blocks=key_blocks, + write_positions=new_write_positions, + values=update_key_chunk, + page_batch_index=page_batch_index, + ) + + _aux_write_kv( + block_table=block_table, + blocks=value_blocks, + write_positions=new_write_positions, + values=update_value_chunk, + page_batch_index=page_batch_index, + ) + + cache_seq_lengths[...] += true_update_size + + require_free_pages = torch.logical_and( + true_update_size > 0, cache_seq_lengths % block_size == 0 + ) + + return require_free_pages + + +class DMSPagedCacheLayer(CacheLayerMixin): + """Paged cache layer with block-based storage and optional DMS eviction.""" + + def __init__( + self, + dms_window_size: int, + max_context_length: int, + block_size: int = 256, + growth_factor: float = 1.5, + accommodate_min_initial_context_length: int = 4096, + disable_eviction: bool = False, + ): + """Initialize the paged cache layer.""" + super().__init__() + assert block_size <= dms_window_size, ( + f"block_size: {block_size} > dms_window_size: {dms_window_size}" + ) + self.block_size = block_size + self.dms_window_size = dms_window_size + self.prefill_chunk_size = max(self.dms_window_size - 2, block_size) + assert self.prefill_chunk_size > 0, ( + f"prefill_chunk_size: {self.prefill_chunk_size} is not greater than 0" + ) + self.growth_factor = growth_factor + self.min_initial_context_length = accommodate_min_initial_context_length + self.disable_eviction = disable_eviction + + self.max_context_length = max_context_length + + self.max_blocks_per_sequence = ceil_int_div(self.max_context_length, self.block_size) + + self.key_blocks = None + self.value_blocks = None + self.block_table = None + self.free_page_ids = None + self.cache_seq_lengths = None + self.recent_info = None # Position and eviction info of last window_size keys/values + self.recent_info_position = None + + self.device = None + + self.cumulative_length = 0 + + self.prefill_mode = False + + def offload(self): + """Offload cache tensors to CPU.""" + if self.key_blocks is not None: + self.key_blocks = self.key_blocks.to("cpu", non_blocking=True) + self.value_blocks = self.value_blocks.to("cpu", non_blocking=True) + self.block_table = self.block_table.to("cpu", non_blocking=True) + self.free_page_ids = self.free_page_ids.to("cpu", non_blocking=True) + self.cache_seq_lengths = self.cache_seq_lengths.to("cpu", non_blocking=True) + if self.recent_info is not None: + self.recent_info = self.recent_info.to("cpu", non_blocking=True) + self.recent_info_position = self.recent_info_position.to("cpu", non_blocking=True) + + def enable_prefill_mode(self): + """Enable prefill mode and disable eviction.""" + self.prefill_mode = True + self.disable_eviction = True + self.recent_info = None + self.recent_info_position = None + + def disable_prefill_mode(self, disable_eviction: bool): + """Disable prefill mode and optionally re-enable eviction.""" + self.prefill_mode = False + self.disable_eviction = disable_eviction + if self.key_blocks is not None and self.recent_info is None and (not disable_eviction): + self._initialize_recent_info() + + def prefetch(self): + """Prefetch cache tensors back to the original device.""" + if self.key_blocks is not None and self.key_blocks.device != self.device: + self.key_blocks = self.key_blocks.to(self.device, non_blocking=True) + self.value_blocks = self.value_blocks.to(self.device, non_blocking=True) + self.block_table = self.block_table.to(self.device, non_blocking=True) + self.free_page_ids = self.free_page_ids.to(self.device, non_blocking=True) + self.cache_seq_lengths = self.cache_seq_lengths.to(self.device, non_blocking=True) + if self.recent_info is not None: + self.recent_info = self.recent_info.to(self.device, non_blocking=True) + self.recent_info_position = self.recent_info_position.to( + self.device, non_blocking=True + ) + + def reset(self) -> None: + """Resets the cache values while preserving the objects.""" + if self.key_blocks is not None: + self.key_blocks = None + self.value_blocks = None + self.block_table = None + self.free_page_ids = None + self.cache_seq_lengths = None + self.recent_info = None + self.recent_info_position = None + gc.collect() + torch.cuda.empty_cache() + self.cumulative_length = 0 + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorders this layer's cache for beam search.""" + raise NotImplementedError("Beam search is not supported") + + def _get_free_pages(self, num_pages: int): + assert self.free_page_ids is not None + assert self.key_blocks is not None + assert self.value_blocks is not None + while len(self.free_page_ids) < num_pages: + + def expand_blocks(blocks: torch.Tensor): + return torch.cat( + [ + blocks, + torch.zeros( + ( + float_ceil(blocks.size(0) * self.growth_factor) - blocks.size(0), + blocks.size(1), + blocks.size(2), + blocks.size(3), + ), + dtype=blocks.dtype, + device=blocks.device, + ), + ], + dim=0, + ) + + old_num_blocks = self.key_blocks.size(0) + self.key_blocks = expand_blocks(self.key_blocks) + self.value_blocks = expand_blocks(self.value_blocks) + assert self.key_blocks.size(0) == self.value_blocks.size(0), ( + f"key_blocks.size: {self.key_blocks.size(0)} != value_blocks.size: {self.value_blocks.size(0)}" + ) + self.free_page_ids = torch.cat( + [ + self.free_page_ids, + torch.arange( + old_num_blocks, + self.key_blocks.size(0), + dtype=torch.int32, + device=self.device, + ), + ], + dim=0, + ) + + result = self.free_page_ids[:num_pages] + assert result.size() == (num_pages,), ( + f"result.size: {result.size()} != (num_pages,): {(num_pages,)}" + ) + self.free_page_ids = self.free_page_ids[num_pages:] + return result + + def _initialize_recent_info(self): + assert self.cache_seq_lengths is not None + self.recent_info = torch.zeros( + (self.page_batch, self.dms_window_size, 2), + dtype=torch.int32, + device=self.device, + ) + self.recent_info_position = self.cache_seq_lengths.clone() + + def lazy_initialization(self, key_states: torch.Tensor): + """Lazily initialize cache storage based on key state shape.""" + self.dtype, self.device = key_states.dtype, key_states.device + self.batch_size, self.num_heads, _, self.head_dim = key_states.shape + + self.page_batch = self.batch_size * self.num_heads + + initial_num_blocks = max( + ceil_int_div(self.min_initial_context_length, self.block_size) * self.page_batch, + self.page_batch, + ) + + self.block_table = -torch.ones( + self.page_batch, + self.max_blocks_per_sequence + 1, # +1 for handling full cache case + dtype=torch.int32, + device=self.device, + ) + self.key_blocks = torch.zeros( + (initial_num_blocks, self.block_size, 1, self.head_dim), + dtype=self.dtype, + device=self.device, + ) + self.value_blocks = torch.zeros( + (initial_num_blocks, self.block_size, 1, self.head_dim), + dtype=self.dtype, + device=self.device, + ) + + self.free_page_ids = torch.arange( + 0, initial_num_blocks, dtype=torch.int32, device=self.device + ) + + self.cache_seq_lengths = torch.zeros(self.page_batch, dtype=torch.int32, device=self.device) + + if not self.disable_eviction: + self._initialize_recent_info() + + assert self.block_table is not None + self.block_table[:, 0] = self._get_free_pages(self.block_table.size(0)) + + def _handle_page_allocation( + self, requires_free_page: torch.Tensor, page_batch_index: torch.Tensor + ): + assert self.block_table is not None + assert self.cache_seq_lengths is not None + if requires_free_page.any(): + req_free_pages = page_batch_index[requires_free_page] + free_pages = self._get_free_pages(len(req_free_pages)) + + self.block_table[ + req_free_pages, + self.cache_seq_lengths[req_free_pages] // self.block_size, + ] = free_pages + + def _update_single( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + eviction_info: torch.Tensor, + ): + batch_x_head, seq_len, _head_dim = key_states.size() + page_batch_index = torch.arange(batch_x_head, dtype=torch.int32, device=self.device) + + assert seq_len == 1, f"seq_len: {seq_len} != 1" + + requires_free_page = _aux_update_single( + key_states=key_states, + value_states=value_states, + eviction_info=eviction_info, + recent_info=self.recent_info, + recent_info_position=self.recent_info_position, + block_table=self.block_table, + key_blocks=self.key_blocks, + value_blocks=self.value_blocks, + cache_seq_lengths=self.cache_seq_lengths, + page_batch_index=page_batch_index, + ) + + self._handle_page_allocation( + requires_free_page=requires_free_page, page_batch_index=page_batch_index + ) + + def _update_many( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + eviction_info: torch.Tensor, + sequence_lengths: torch.Tensor, + ): + assert self.cache_seq_lengths is not None + # Assume key and value states are left padded, e.g., [_, _, _, 1, 2, 3, 4] + + page_batch, seq_len, head_dim = key_states.size() + assert page_batch == self.page_batch, ( + f"page_batch: {page_batch} != self.page_batch: {self.page_batch}" + ) + assert head_dim == self.head_dim, f"head_dim: {head_dim} != self.head_dim: {self.head_dim}" + assert eviction_info.size() == (page_batch, seq_len), ( + f"eviction_info.size: {eviction_info.size()} != (page_batch, seq_len): {(page_batch, seq_len)}" + ) + assert sequence_lengths.ndim == 1, f"sequence_lengths.ndim: {sequence_lengths.ndim} != 1" + + is_non_empty = sequence_lengths > 0 + + start_positions = seq_len - sequence_lengths + + end_positions = start_positions + sequence_lengths + + page_batch_index = torch.arange(page_batch, dtype=torch.int32, device=self.device) + + while (start_positions < end_positions).any(): + chunk_indexer = torch.arange( + self.prefill_chunk_size, dtype=torch.int32, device=self.device + )[None, :] + + update_mask = chunk_indexer < ( + self.block_size - (self.cache_seq_lengths[:, None] % self.block_size) + ) + + chunk_indexer = start_positions[:, None] + chunk_indexer + + update_mask = torch.logical_and(update_mask, chunk_indexer < end_positions[:, None]) + + chunk_indexer = torch.clamp( + torch.minimum(chunk_indexer, end_positions[:, None] - 1), min=0 + ) + + true_update_size = update_mask.to(torch.int32).sum(dim=1) + + chunk_indexer = torch.clamp( + torch.minimum( + chunk_indexer, + start_positions[:, None] + true_update_size[:, None] - 1, + ), + min=0, + ) + + key_chunk = key_states[page_batch_index[:, None], chunk_indexer] + value_chunk = value_states[page_batch_index[:, None], chunk_indexer] + eviction_info_chunk = eviction_info[page_batch_index[:, None], chunk_indexer] + + requires_free_page = _aux_update_many_handle_single_chunk( + update_key_chunk=key_chunk, + update_value_chunk=value_chunk, + eviction_info_chunk=eviction_info_chunk, + is_non_empty=is_non_empty, + block_table=self.block_table, + key_blocks=self.key_blocks, + value_blocks=self.value_blocks, + cache_seq_lengths=self.cache_seq_lengths, + recent_info=self.recent_info, + recent_info_position=self.recent_info_position, + page_batch_index=page_batch_index, + update_mask=update_mask, + true_update_size=true_update_size, + ) + + self._handle_page_allocation( + requires_free_page=requires_free_page, page_batch_index=page_batch_index + ) + + start_positions[...] += true_update_size + + def fast_update_ignore_eviction( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + sequence_lengths: torch.Tensor, + ): + """Bulk-write key-value pairs to the paged cache without eviction.""" + if self.key_blocks is None: + self.lazy_initialization(key_states) + + assert self.key_blocks is not None + assert self.value_blocks is not None + assert self.block_table is not None + assert self.cache_seq_lengths is not None + + assert self.disable_eviction, ( + "fast_update_ignore_eviction is only supported when eviction is disabled" + ) + + assert sequence_lengths.max().item() != 0, ( + f"sequence_lengths.max(): {sequence_lengths.max().item()} is 0" + ) + + batch, head, seq_len, head_dim = key_states.size() + + self.cumulative_length += seq_len + assert value_states.size() == key_states.size(), ( + f"value_states.size: {value_states.size()} != key_states.size: {key_states.size()}" + ) + + assert sequence_lengths.size() == (batch, head), ( + f"sequence_lengths.size: {sequence_lengths.size()} != (batch, head): {(batch, head)}" + ) + + key_states = key_states.reshape(self.page_batch, seq_len, head_dim) + value_states = value_states.reshape(self.page_batch, seq_len, head_dim) + sequence_lengths = sequence_lengths.reshape(self.page_batch) + + assert self.cache_seq_lengths.size() == sequence_lengths.size(), ( + f"cache_seq_lengths.size: {self.cache_seq_lengths.size()}" + f" != sequence_lengths.size: {sequence_lengths.size()}" + ) + + last_allocated_block_index = self.cache_seq_lengths // self.block_size + last_to_allocate_block_index = ( + self.cache_seq_lengths + sequence_lengths + ) // self.block_size + + per_page_batch_blocks_available = last_allocated_block_index + 1 + per_page_batch_blocks_required = last_to_allocate_block_index + 1 + + blocks_to_alloc = per_page_batch_blocks_required - per_page_batch_blocks_available + max_blocks_to_alloc = blocks_to_alloc.max().item() + + if max_blocks_to_alloc > 0: + assert blocks_to_alloc.shape == (self.page_batch,), ( + f"blocks_to_alloc.shape: {blocks_to_alloc.shape} != (self.page_batch,): {(self.page_batch,)}" + ) + assert (blocks_to_alloc >= 0).all(), ( + f"blocks_to_alloc: {blocks_to_alloc} is less than 0" + ) + + total_blocks_to_alloc = blocks_to_alloc.sum() + + free_blocks = self._get_free_pages(total_blocks_to_alloc) + + # Greedy assignment of free blocks + free_block_indexer = blocks_to_alloc.cumsum(dim=0) + + free_block_write_indexer = torch.arange( + max_blocks_to_alloc, device=self.device, dtype=torch.int32 + ) + # For writing into the block table + free_block_write_indexer = ( + last_allocated_block_index[:, None] + 1 + free_block_write_indexer[None, :] + ) + + assert free_block_write_indexer.size() == ( + self.page_batch, + max_blocks_to_alloc, + ), ( + f"free_block_write_indexer.size: {free_block_write_indexer.size()}" + f" != (self.page_batch, max_blocks_to_alloc):" + f" {(self.page_batch, max_blocks_to_alloc)}" + ) + + # +1 acts as a sink for handling sizes < max_blocks_to_alloc + # free_block_write_indexer[a, b] where to write b'th free block for a'th page batch + free_block_write_indexer = torch.minimum( + free_block_write_indexer, + last_allocated_block_index[:, None] + blocks_to_alloc[:, None] + 1, + ) + + # free_blocks[free_block_get_indexer[a, b]] is the b'th free block for a'th page batch + free_block_get_offset = torch.nn.functional.pad(free_block_indexer, (1, -1), value=0) + free_block_get_indexer = torch.arange( + max_blocks_to_alloc, device=self.device, dtype=torch.int32 + ) + free_block_get_indexer = torch.minimum( + free_block_get_indexer[None, :], blocks_to_alloc[:, None] + ) + free_block_get_indexer = free_block_get_offset[:, None] + free_block_get_indexer + free_block_get_indexer = torch.clamp( + free_block_get_indexer, max=total_blocks_to_alloc - 1 + ) + + free_block_assignment = free_blocks[free_block_get_indexer] + assert free_block_assignment.shape == (self.page_batch, max_blocks_to_alloc), ( + f"free_block_assignment.shape: {free_block_assignment.shape}" + f" != (self.page_batch, max_blocks_to_alloc):" + f" {(self.page_batch, max_blocks_to_alloc)}" + ) + + # If max_blocks_to_alloc is more than the number of blocks that we want + + mask = (free_block_write_indexer <= last_to_allocate_block_index[:, None]).to( + torch.int32 + ) + + masked_free_block_assignment = free_block_assignment * mask - ( + 1 - mask + ) * torch.ones_like(free_block_assignment) + + self.block_table.scatter_( + dim=1, index=free_block_write_indexer, src=masked_free_block_assignment + ) + + write_seq_positions = torch.arange( + seq_len, + device=self.cache_seq_lengths.device, + dtype=self.cache_seq_lengths.dtype, + ) + write_seq_positions = self.cache_seq_lengths[:, None] + write_seq_positions[None, :] + write_seq_positions = torch.minimum( + write_seq_positions, (self.cache_seq_lengths + sequence_lengths)[:, None] + ) + + source_seq_positions = torch.arange( + seq_len, + device=self.cache_seq_lengths.device, + dtype=self.cache_seq_lengths.dtype, + ) + + # Left padded input + + source_seq_positions = (seq_len - sequence_lengths)[:, None] + source_seq_positions[None, :] + source_seq_positions = torch.clamp(source_seq_positions, max=seq_len - 1) + assert source_seq_positions.size() == (self.page_batch, seq_len), ( + f"source_seq_positions.size: {source_seq_positions.size()}" + f" != (self.page_batch, seq_len): {(self.page_batch, seq_len)}" + ) + + write_block_table_ids = write_seq_positions // self.block_size + write_block_offsets = write_seq_positions % self.block_size + + assert write_block_table_ids.size() == (self.page_batch, seq_len), ( + f"write_block_table_ids.size: {write_block_table_ids.size()}" + f" != (self.page_batch, seq_len): {(self.page_batch, seq_len)}" + ) + assert write_block_offsets.size() == (self.page_batch, seq_len), ( + f"write_block_offsets.size: {write_block_offsets.size()}" + f" != (self.page_batch, seq_len): {(self.page_batch, seq_len)}" + ) + + write_block_ids = self.block_table.gather(dim=1, index=write_block_table_ids) + assert write_block_ids.size() == (self.page_batch, seq_len), ( + f"write_block_ids.size: {write_block_ids.size()}" + f" != (self.page_batch, seq_len): {(self.page_batch, seq_len)}" + ) + + source_seq_positions = source_seq_positions[:, :, None].broadcast_to( + self.page_batch, seq_len, head_dim + ) + + keys_to_write = key_states.gather(dim=1, index=source_seq_positions) + self.key_blocks[write_block_ids, write_block_offsets, 0, :] = keys_to_write + + values_to_write = value_states.gather(dim=1, index=source_seq_positions) + self.value_blocks[write_block_ids, write_block_offsets, 0, :] = values_to_write + + self.cache_seq_lengths += sequence_lengths + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: dict[str, Any], + ): + """Update the paged cache with new key-value states.""" + eviction_info = cache_kwargs["eviction_info"] + sequence_lengths = cache_kwargs["sequence_lengths"] + cumulative_length = cache_kwargs["cumulative_length"] + + if self.key_blocks is None: + self.lazy_initialization(key_states) + + batch, head, seq_len, head_dim = key_states.size() + assert key_states.size() == value_states.size(), ( + f"key_states.size: {key_states.size()} != value_states.size: {value_states.size()}" + ) + assert key_states.size()[:3] == eviction_info.size(), ( + f"key_states.size()[:3]: {key_states.size()[:3]} != eviction_info.size(): {eviction_info.size()}" + ) + assert sequence_lengths is None or sequence_lengths.size() == (batch, head), ( + f"sequence_lengths.size: {sequence_lengths.size()} != (batch, head): {(batch, head)}" + ) + + assert batch * head == self.page_batch, ( + f"batch * head: {batch * head} != self.page_batch: {self.page_batch}" + ) + assert self.head_dim == head_dim, f"self.head_dim: {self.head_dim} != head_dim: {head_dim}" + + key_states = key_states.reshape(self.page_batch, seq_len, head_dim) + value_states = value_states.reshape(self.page_batch, seq_len, head_dim) + eviction_info = eviction_info.reshape(self.page_batch, seq_len) + if sequence_lengths is not None: + sequence_lengths = sequence_lengths.reshape(self.page_batch) + + if seq_len == 1 and not self.prefill_mode: + assert sequence_lengths is None or (sequence_lengths == 1).all() + assert cumulative_length == 1 + self._update_single( + key_states=key_states, + value_states=value_states, + eviction_info=eviction_info, + ) + else: + self._update_many( + key_states=key_states, + value_states=value_states, + eviction_info=eviction_info, + sequence_lengths=sequence_lengths, + ) + + self.cumulative_length += cumulative_length + + return None, None + + def get_block_table(self): + """Get the block table mapping.""" + return self.block_table + + def get_key_blocks(self): + """Get the key cache blocks.""" + return self.key_blocks + + def get_value_blocks(self): + """Get the value cache blocks.""" + return self.value_blocks + + def get_seq_lengths(self): + """Get the sequence lengths per batch element.""" + return self.cache_seq_lengths + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + """Returns the length and offset of the cache, used to generate the mask.""" + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length() + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + def get_seq_length(self) -> int: + """Returns the sequence length of the cached states.""" + return self.cumulative_length + + def get_max_cache_shape(self) -> int: + """Returns the maximum sequence length of the cache object.""" + return self.max_context_length diff --git a/experimental/dms/dms/core.py b/experimental/dms/dms/core.py new file mode 100644 index 0000000000..5a926956be --- /dev/null +++ b/experimental/dms/dms/core.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMS core operations: attention I/O, gating, output types, chunked prefill, and training state.""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch +from tqdm import tqdm +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from dms.cache import Mode +from dms.logging import get_logger + +logger = get_logger("DMSCore") + + +# ============================================================================= +# Setup utilities +# ============================================================================= + + +def setup_compile_limit_for_dms(compile_limit: int = 72): + """Set the torch.compile cache size limit for DMS layer compilation.""" + # we want to compile the prepare_attention_input and post_process_attention_output functions + # for each layer + if torch._dynamo.config.cache_size_limit != compile_limit: + logger.info(f"Setting up compile limit for DMS to {compile_limit}") + torch._dynamo.config.cache_size_limit = compile_limit + + +# ============================================================================= +# Training state +# ============================================================================= + + +@dataclass +class DMSTrainingStateAux: + """Auxiliary state information for DMS training. + + Attributes: + dms_teacher_mode: Whether the model is in teacher mode. + target_frac_to_close: Target fraction of DMS key-value pairs to evict. + current_step: Current training step. + grad_acc_step: Gradient accumulation step. + process_index: Index of the current process in distributed training. + noise: Tensor of noise values per layer (shape: [num_layers, ...]). + right_padding_size: Right padding size per batch item (shape: [batch_size]). + kv_cache_shape: Shape tuple for KV cache (num_layers, batch_size, num_kv_heads, seq_length). + """ + + dms_teacher_mode: bool + target_frac_to_close: float | None + current_step: int + grad_acc_step: int + process_index: int + noise: torch.Tensor | None + right_padding_size: int + kv_cache_shape: tuple[int, int, int, int] # num_layers, batch_size, num_kv_heads, seq_length + + +# ============================================================================= +# Output dataclasses +# ============================================================================= + + +@dataclass +class DMSBaseModelOutputWithPastAndCR(BaseModelOutputWithPast): + """DMS base model output with compression ratio. + + Args: + cr: (`float`, *optional*, returned when DMS cache is used): + Compression ratio, that is size of cache without compression + divided by size of cache with compression + dms_frac_closed: (`torch.Tensor`, *optional*, returned when DMS cache is used): + Per head average number of tokens (soft) evicted by DMS, used for DMS loss computation. + """ + + cr: float | None = None + dms_frac_closed: torch.Tensor | None = None + + +@dataclass +class DMSCausalLMOutputWithPastAndCR(CausalLMOutputWithPast): + """DMS causal LM output with compression ratio. + + Args: + cr: (`float`, *optional*, returned when DMS cache is used): + Compression ratio, that is size of cache without compression + divided by size of cache with compression + dms_frac_closed: (`torch.Tensor`, *optional*, returned when DMS cache is used): + Per head average number of tokens (soft) evicted by DMS, used for DMS loss computation. + """ + + cr: float | None = None + dms_loss: torch.Tensor | None = None + dms_frac_closed: torch.Tensor | None = None + + +# ============================================================================= +# Attention input preparation and output processing +# ============================================================================= + + +@torch.compile() +def prepare_attention_input( + pre_attn_norm_hidden_states: torch.Tensor, + post_attn_norm_hidden_states: torch.Tensor, + q_proj_fn: torch.nn.Linear, + k_proj_fn: torch.nn.Linear, + v_proj_fn: torch.nn.Linear, + q_norm_fn: torch.nn.Module, + k_norm_fn: torch.nn.Module, + head_dim: int, + cos: torch.Tensor, + sin: torch.Tensor, + dms_proj_alpha_norm_fn: torch.nn.Module | None, + dms_proj_alpha_fn: torch.nn.Linear | None, + dms_alpha_per: str, + dms_decision_scale: float, + dms_initial_decision_offset: float, + dms_training: bool, + dms_disable_eviction: bool, + dms_tau: float, + apply_rotary_pos_emb_fn: Callable, + dms_teacher_mode: bool, + dms_noise: torch.Tensor | None = None, +): + """Prepare query, key, value, and DMS decision tensors for attention.""" + batch, seq_len, _hidden_dim = pre_attn_norm_hidden_states.size() + + query_states = q_norm_fn( + q_proj_fn(post_attn_norm_hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2) + ) + + key_states = k_norm_fn( + k_proj_fn(post_attn_norm_hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2) + ) + value_states = ( + v_proj_fn(post_attn_norm_hidden_states).view(batch, seq_len, -1, head_dim).transpose(1, 2) + ) + + _, num_q_heads, _, _ = query_states.size() + _, num_kv_heads, _, _ = key_states.size() + gqa_factor = num_q_heads // num_kv_heads + + if dms_proj_alpha_fn is None: + assert dms_proj_alpha_norm_fn is None, ( + "dms_proj_alpha_norm_fn is not None when dms_proj_alpha_fn is None" + ) + decision_logits = ( + query_states[:, ::gqa_factor, :, -1].clone() * dms_decision_scale + - dms_initial_decision_offset + ) + assert decision_logits.shape == (batch, num_kv_heads, seq_len), ( + f"decision_logits.shape: {decision_logits.shape} != {(batch, num_kv_heads, seq_len)}" + ) + + query_states[:, ::gqa_factor, :, -1] = 0 + query_states, key_states = apply_rotary_pos_emb_fn(query_states, key_states, cos, sin) + query_states[:, ::gqa_factor, :, -1] = 0 + else: + assert dms_proj_alpha_norm_fn is not None, ( + "dms_proj_alpha_norm_fn is None when dms_proj_alpha_fn is not None" + ) + decision_logits = ( + dms_proj_alpha_fn(dms_proj_alpha_norm_fn(pre_attn_norm_hidden_states)) + * dms_decision_scale + - dms_initial_decision_offset + ) + assert decision_logits.shape == (batch, seq_len, num_kv_heads), ( + f"decision_logits.shape: {decision_logits.shape} != {(batch, seq_len, num_kv_heads)}" + ) + decision_logits = decision_logits.transpose(1, 2) + + query_states, key_states = apply_rotary_pos_emb_fn(query_states, key_states, cos, sin) + + if dms_training and not dms_teacher_mode: + assert dms_noise is not None, "dms_noise is None when dms_training and not dms_teacher_mode" + dms_noise = dms_noise.to(decision_logits.device) + _probs, decisions, decision_logits = get_gating_with_noise( + gating_weights=decision_logits, noise=dms_noise, tau=dms_tau + ) + else: + decisions = (decision_logits > 0).to(decision_logits.dtype) + assert decisions.shape == (batch, num_kv_heads, seq_len), ( + f"decisions.shape: {decisions.shape} != {(batch, num_kv_heads, seq_len)}" + ) + + if dms_alpha_per == "head": + decisions = decisions.broadcast_to(batch, num_kv_heads, seq_len) + decision_logits = decision_logits.broadcast_to(batch, num_kv_heads, seq_len) + elif dms_alpha_per == "layer": + decisions = decisions[:, [0], :].broadcast_to(batch, num_kv_heads, seq_len) + decision_logits = decision_logits[:, [0], :].broadcast_to(batch, num_kv_heads, seq_len) + else: + raise ValueError(f"Invalid dms_alpha_per: {dms_alpha_per}") + + flash_attn_query_states = query_states.reshape( + batch * num_kv_heads, gqa_factor, seq_len, head_dim + ).transpose(1, 2) + + if dms_disable_eviction or dms_teacher_mode: + decisions = torch.zeros_like(decisions) + decision_logits = torch.full_like(decision_logits, fill_value=-1000.0) + + return flash_attn_query_states, key_states, value_states, decisions, decision_logits + + +@torch.compile() +def post_process_attention_output( + attn_output: torch.Tensor, + o_proj: torch.nn.Linear, +): + """Reshape attention output and apply output projection.""" + batch, heads_kv, seq_len_q, gqa_factor, head_dim = attn_output.size() + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape( + batch, seq_len_q, heads_kv * gqa_factor * head_dim + ).contiguous() + + attn_output = o_proj(attn_output) + + return attn_output + + +def get_gating_with_noise(gating_weights: torch.Tensor, noise: torch.Tensor, tau: float): + """Apply Gumbel noise to gating weights and return discretized decisions.""" + assert gating_weights.shape == noise.shape, ( + f"gating_weights.shape: {gating_weights.shape} != noise.shape {noise.shape}" + ) + + logits = (gating_weights + noise) / tau + probs = torch.nn.functional.sigmoid(logits) + + discretized = (probs > 0.5).to(probs.dtype) - probs.detach() + probs + + return probs, discretized, logits + + +# ============================================================================= +# Chunked prefill +# ============================================================================= + + +def run_decoder_layers( + decoder_layers: list[torch.nn.Module], + hidden_states: torch.Tensor, + **kwargs: Any, +): + """Pass hidden states through decoder layers. + + Returns the final hidden states along with + the per head average number of tokens (soft) evicted by DMS. + """ + acc_dms_frac_closed = 0 + num_layers = 0 + for dl in decoder_layers: + hidden_states, dms_frac_closed = dl( + hidden_states, + **kwargs, + ) + if dms_frac_closed is not None: + num_layers += 1 + acc_dms_frac_closed += dms_frac_closed + + # NOTE: assumption that each attention enabled layer has the same number of attention heads + dms_frac_closed = acc_dms_frac_closed / num_layers if num_layers > 0 else None + return hidden_states, dms_frac_closed + + +def dms_perform_chunked_prefill( + decoder_layers: list[torch.nn.Module], + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None, + position_ids: torch.Tensor | None, + past_key_values: Any, + use_cache: bool, + cache_position: torch.Tensor | None, + position_embeddings: torch.Tensor | None, + dms_manual_inference_mode: bool, + dms_chunked_prefill: int | None, + **kwargs: Any, +): + """Used to chunk the input for transformer decoder layers. + + At this point in transformers most elements (masks, embeddings) + are already constructed. + """ + batch, seq_len, _hidden_dim = hidden_states.size() + + assert attention_mask is None or attention_mask.ndim == 4, ( + f"attention_mask.ndim: {attention_mask.ndim}" + ) + + if not dms_manual_inference_mode: + if ( + seq_len == 1 + and past_key_values is not None + and len(past_key_values) > 0 + and past_key_values[0].get_seq_length() > 0 + ): + if past_key_values.current_mode != Mode.INFERENCE: + past_key_values.inference_mode() + logger.debug( + f"Setting inference mode for past_key_values with cr: {past_key_values.get_cr()}" + ) + + elif past_key_values is not None: + logger.debug( + f"Setting prefill mode for past_key_values with seq_length {past_key_values[0].get_seq_length()}" + ) + past_key_values.prefill_mode() + + if seq_len > 1 and dms_chunked_prefill is not None: + num_chunks = (seq_len + dms_chunked_prefill - 1) // dms_chunked_prefill + + hidden_states_chunks = [] + + for chid in tqdm( + range(num_chunks), + desc=f"Chunked prefill for batch_size:{batch} seq_len:{seq_len} chunk_size:{dms_chunked_prefill}", + ): + start_pos = chid * dms_chunked_prefill + end_pos = min(start_pos + dms_chunked_prefill, seq_len) + hidden_states_chunk = hidden_states[:, start_pos:end_pos, :] + unpadded_chunk_len = hidden_states_chunk.shape[1] + + if attention_mask is not None: + # take attention mask from the last query + assert attention_mask.shape[-1] == seq_len, ( + f"attention_mask.shape[-1]: {attention_mask.shape[-1]} != {seq_len}" + ) + attention_mask_chunk = attention_mask[:, :, [-1], start_pos:end_pos] + else: + attention_mask_chunk = None + + if position_ids is not None: + assert position_ids.shape[-1] == seq_len, ( + f"position_ids.shape[-1]: {position_ids.shape[-1]} != {seq_len}" + ) + position_ids_chunk = position_ids[..., start_pos:end_pos] + else: + position_ids_chunk = None + + if cache_position is not None: + assert cache_position.shape[-1] == seq_len, ( + f"cache_position.shape[-1]: {cache_position.shape[-1]} != {seq_len}" + ) + cache_position_chunk = cache_position[..., start_pos:end_pos] + else: + cache_position_chunk = None + + if position_embeddings is not None: + assert isinstance(position_embeddings, tuple) + for e in position_embeddings: + assert e.shape[1] == seq_len, f"e.shape[1]: {e.shape[1]} != {seq_len}" + position_embeddings_chunk = tuple( + e[:, start_pos:end_pos, :] for e in position_embeddings + ) + else: + position_embeddings_chunk = None + + hidden_states_chunk, _dms_loss = run_decoder_layers( + decoder_layers=decoder_layers, + hidden_states=hidden_states_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids_chunk, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position_chunk, + position_embeddings=position_embeddings_chunk, + **kwargs, + ) + hidden_states_chunks.append(hidden_states_chunk[:, :unpadded_chunk_len, :]) + hidden_states_chunks = torch.cat(hidden_states_chunks, dim=1) + + return hidden_states_chunks, None + else: + return run_decoder_layers( + decoder_layers=decoder_layers, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) diff --git a/experimental/dms/dms/logging.py b/experimental/dms/dms/logging.py new file mode 100644 index 0000000000..72152ddd44 --- /dev/null +++ b/experimental/dms/dms/logging.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Logging utilities for DMS.""" + +import logging +import sys + + +def get_logger(name: str) -> logging.Logger: + """Get or create a logger with the specified name and DMS formatting.""" + logger = logging.getLogger(name) + + # Only configure if logger doesn't already have handlers (avoid duplicates) + if not logger.handlers: + logger.setLevel(logging.INFO) + handler = logging.StreamHandler(sys.stderr) + handler.setFormatter( + logging.Formatter("[%(asctime)s] DMS %(name)s [%(levelname)s]: %(message)s") + ) + logger.addHandler(handler) + + return logger diff --git a/experimental/dms/dms/training/__init__.py b/experimental/dms/dms/training/__init__.py new file mode 100644 index 0000000000..9035bc54f4 --- /dev/null +++ b/experimental/dms/dms/training/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMS training infrastructure: data pipeline, distillation, and trainer utilities.""" diff --git a/experimental/dms/dms/training/data.py b/experimental/dms/dms/training/data.py new file mode 100644 index 0000000000..18527f0697 --- /dev/null +++ b/experimental/dms/dms/training/data.py @@ -0,0 +1,505 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data pipeline for DMS training: dataset loading, tokenization, blending, and concatenation. + +To add a new dataset: +1. Define a filter_fn and extract_fn for your dataset +2. Create a DatasetInfo instance at the bottom of this file +3. Reference it by name in your YAML config's data.blend field + (e.g. "MyNewDataset:0.5,OpenR1Math220k:0.5") +""" + +import hashlib +import json +from collections.abc import Callable +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np +from datasets import Dataset, load_dataset, load_from_disk +from tqdm import tqdm +from transformers import PreTrainedTokenizerBase + +from dms.logging import get_logger + +logger = get_logger("Data") + +# Cache directory structure constants +_CACHE_DIR_NAME = "datasets_cache" +_CACHE_PLAIN = "plain" +_CACHE_TOKENIZED = "tokenized" +_CACHE_SHUFFLED = "shuffled" +_CACHE_CONCATENATED = "concatenated" + +# Default parallelism for dataset operations +_DEFAULT_FILTER_NUM_PROC = 8 +_DEFAULT_TOKENIZE_NUM_PROC = 32 + + +# ============================================================================= +# Dataset pipeline utilities +# ============================================================================= + + +def _get_module_dir() -> Path: + """Return the directory containing this module.""" + return Path(__file__).parent + + +def _load_or_create_cached_dataset( + cache_path: Path, + create_fn: Callable[[], Dataset], + description: str, + caching_enabled: bool, +) -> Dataset: + """Load dataset from cache if available, otherwise create and optionally cache it. + + Args: + cache_path: Path to the cached dataset. + create_fn: Function that creates the dataset if not cached. + description: Human-readable description for logging. + caching_enabled: Whether to use/save cache. + + Returns: + The loaded or newly created dataset. + """ + if caching_enabled and cache_path.exists(): + logger.info(f"Loading {description} from cache: {cache_path}") + return load_from_disk(str(cache_path)) + + logger.info(f"Processing {description}") + dataset = create_fn() + + if caching_enabled: + logger.info(f"Saving {description} to cache: {cache_path}") + dataset.save_to_disk(str(cache_path)) + + return dataset + + +@dataclass +class ConfiguredTokenizer: + """A tokenizer with pre-configured kwargs for chat template and encoding. + + Attributes: + tokenizer: The underlying HuggingFace tokenizer. + apply_chat_template_kwargs: Additional kwargs for apply_chat_template. + encode_kwargs: Additional kwargs for encode. + """ + + tokenizer: PreTrainedTokenizerBase + apply_chat_template_kwargs: dict[str, Any] = field(default_factory=dict) + encode_kwargs: dict[str, Any] = field(default_factory=dict) + + def apply_chat_template(self, conversation: list[dict[str, str]]) -> str: + """Apply chat template to a conversation without tokenizing.""" + return self.tokenizer.apply_chat_template( + conversation, tokenize=False, **self.apply_chat_template_kwargs + ) + + def encode(self, prompt: str) -> list[int]: + """Encode a prompt string to token IDs.""" + return self.tokenizer.encode(prompt, **self.encode_kwargs) + + def decode(self, token_ids: list[int]) -> str: + """Decode token IDs back to a string.""" + return self.tokenizer.decode(token_ids) + + def get_hash(self) -> str: + """Generate a unique hash based on tokenizer configuration.""" + config = { + "tokenizer": self.tokenizer.name_or_path, + "apply_chat_template_kwargs": self.apply_chat_template_kwargs, + "encode_kwargs": self.encode_kwargs, + } + return hashlib.sha256(json.dumps(config, sort_keys=True).encode()).hexdigest() + + +@dataclass +class DatasetInfo: + """Configuration for loading and processing a dataset. + + To add a new dataset, create a DatasetInfo instance with: + - args/kwargs for HuggingFace load_dataset + - filter_fn to select relevant samples + - extract_fn to transform samples into chat format + + See the dataset definitions at the bottom of this file for examples. + + Attributes: + args: Positional arguments for load_dataset. + kwargs: Keyword arguments for load_dataset. + filter_fn: Function to filter dataset samples. + extract_fn: Function to extract/transform samples into chat format. + caching_enabled: Whether to cache intermediate results to disk. + """ + + args: tuple[Any, ...] + kwargs: dict[str, Any] + filter_fn: Callable[[Any], bool] + extract_fn: Callable[[Any], Any] + caching_enabled: bool = True + + def get_str_identifier(self) -> str: + """Return a JSON string uniquely identifying this dataset configuration.""" + identifier_parts = { + "args": self.args, + "kwargs": self.kwargs, + "filter_fn": self.filter_fn.__name__, + "extract_fn": self.extract_fn.__name__, + } + return json.dumps(identifier_parts, sort_keys=True) + + def get_hash(self) -> str: + """Generate a unique hash based on dataset configuration.""" + return hashlib.sha256(self.get_str_identifier().encode()).hexdigest() + + def _get_cache_base_path(self) -> Path: + """Return the base cache directory path for this dataset.""" + return _get_module_dir() / _CACHE_DIR_NAME / self.get_hash() + + def _load_plain_dataset(self) -> Dataset: + """Load and process the raw dataset (filter + extract).""" + cache_path = self._get_cache_base_path() / _CACHE_PLAIN + + def create_dataset() -> Dataset: + dataset = load_dataset(*self.args, **self.kwargs) + dataset = dataset.filter(self.filter_fn, num_proc=_DEFAULT_FILTER_NUM_PROC) + dataset = dataset.map(self.extract_fn, num_proc=_DEFAULT_FILTER_NUM_PROC) + return dataset + + return _load_or_create_cached_dataset( + cache_path=cache_path, + create_fn=create_dataset, + description=f"plain dataset {self.get_str_identifier()}", + caching_enabled=self.caching_enabled, + ) + + def _tokenize_dataset( + self, + dataset: Dataset, + configured_tokenizer: ConfiguredTokenizer, + ) -> Dataset: + """Apply tokenization to a dataset.""" + cache_path = ( + self._get_cache_base_path() / _CACHE_TOKENIZED / configured_tokenizer.get_hash() + ) + + def create_tokenized() -> Dataset: + def apply_tokenizer(sample: dict[str, Any]) -> dict[str, Any]: + conversation = sample["conversation"] + prompt = configured_tokenizer.apply_chat_template(conversation) + encoded_prompt = configured_tokenizer.encode(prompt) + return { + "conversation": conversation, + "prompt": prompt, + "encoded_prompt": encoded_prompt, + } + + return dataset.map(apply_tokenizer, num_proc=_DEFAULT_TOKENIZE_NUM_PROC) + + return _load_or_create_cached_dataset( + cache_path=cache_path, + create_fn=create_tokenized, + description=f"tokenized dataset {self.get_str_identifier()}", + caching_enabled=self.caching_enabled, + ) + + def _shuffle_dataset( + self, + dataset: Dataset, + configured_tokenizer: ConfiguredTokenizer, + shuffle_seed: int, + ) -> Dataset: + """Shuffle dataset with a given seed.""" + cache_path = ( + self._get_cache_base_path() + / _CACHE_SHUFFLED + / configured_tokenizer.get_hash() + / f"shuffle_seed_{shuffle_seed}" + ) + + return _load_or_create_cached_dataset( + cache_path=cache_path, + create_fn=lambda: dataset.shuffle(seed=shuffle_seed), + description=f"shuffled dataset {self.get_str_identifier()}", + caching_enabled=self.caching_enabled, + ) + + def _concatenate_dataset( + self, + dataset: Dataset, + configured_tokenizer: ConfiguredTokenizer, + shuffle_seed: int | None, + concat_up_to: int, + concat_always_start_new: bool, + ) -> Dataset: + """Concatenate samples to create fixed-length contexts. + + Args: + dataset: The tokenized dataset. + configured_tokenizer: Tokenizer used (for cache path). + shuffle_seed: Shuffle seed used (for cache path). + concat_up_to: Target context length in tokens. + concat_always_start_new: If True, discard tokens that overflow; + otherwise, carry them to the next context. + + Returns: + Dataset with concatenated samples. + """ + cache_path = ( + self._get_cache_base_path() + / _CACHE_CONCATENATED + / configured_tokenizer.get_hash() + / f"shuffle_seed_{shuffle_seed}" + / f"concat_up_to_{concat_up_to}" + / f"always_start_new_{concat_always_start_new}" + ) + + def create_concatenated() -> Dataset: + concatenated_samples: list[dict[str, Any]] = [] + current_context: dict[str, Any] = { + "prompt": "", + "encoded_prompt": [], + "num_samples": 0, + } + + for sample in tqdm(dataset, desc="Concatenating dataset"): + current_context["prompt"] += sample["prompt"] + current_context["encoded_prompt"] += sample["encoded_prompt"] + current_context["num_samples"] += 1 + + while len(current_context["encoded_prompt"]) >= concat_up_to: + # Store the full context before trimming + full_encoded = current_context["encoded_prompt"] + current_context["encoded_prompt_untrimmed"] = full_encoded + current_context["encoded_prompt"] = full_encoded[:concat_up_to] + concatenated_samples.append(current_context) + + # Handle overflow tokens + remaining_tokens = full_encoded[concat_up_to:] + current_context = { + "prompt": "", + "encoded_prompt": [], + "num_samples": 0, + } + + if not concat_always_start_new and remaining_tokens: + current_context["encoded_prompt"] = remaining_tokens + current_context["num_samples"] = 1 + + result = Dataset.from_list(concatenated_samples) + logger.info(f"Created concatenated dataset with {len(result)} samples") + return result + + return _load_or_create_cached_dataset( + cache_path=cache_path, + create_fn=create_concatenated, + description=f"concatenated dataset {self.get_str_identifier()}", + caching_enabled=self.caching_enabled, + ) + + def get_dataset( + self, + configured_tokenizer: ConfiguredTokenizer, + concat_up_to: int | None, + concat_always_start_new: bool = True, + shuffle_seed: int | None = None, + ) -> Dataset: + """Load and process dataset through the full pipeline. + + The pipeline stages are: + 1. Load raw dataset and apply filter/extract functions + 2. Tokenize using the configured tokenizer + 3. Optionally shuffle with given seed + 4. Optionally concatenate samples to fixed-length contexts + + Each stage is cached independently for efficient reprocessing. + + Args: + configured_tokenizer: Tokenizer configuration for encoding. + concat_up_to: Target context length in tokens. If None, no concatenation. + concat_always_start_new: If True, discard overflow tokens when concatenating; + otherwise, carry them to the next context. + shuffle_seed: Random seed for shuffling. If None, no shuffling. + + Returns: + Processed HuggingFace Dataset. + """ + # Stage 1: Load and filter/extract + dataset = self._load_plain_dataset() + + # Stage 2: Tokenize + dataset = self._tokenize_dataset(dataset, configured_tokenizer) + + # Stage 3: Shuffle (optional) + if shuffle_seed is not None: + dataset = self._shuffle_dataset(dataset, configured_tokenizer, shuffle_seed) + + # Stage 4: Concatenate (optional) + if concat_up_to is not None: + dataset = self._concatenate_dataset( + dataset, + configured_tokenizer, + shuffle_seed, + concat_up_to, + concat_always_start_new, + ) + + return dataset + + +# ============================================================================= +# Data blending +# ============================================================================= + + +@dataclass +class DataBlendElement: + """A single dataset element with its blend weight.""" + + dataset: DatasetInfo + weight: float # weight in the datablend (should be > 0.0) + + def __post_init__(self): + assert self.weight > 0.0, f"weight: {self.weight} is not greater than 0.0" + + +class DataBlend: + """Blends multiple datasets with configurable weights. + + Args: + data_blend_elements: list of datasets along with their weights in the datablend + configured_tokenizer: the tokenizer to use for the dataset + train_samples: the number of samples to provide + seed: used for datasets and datablend shuffling + concat_up_to: each sample is concatenated to match this length + concat_always_start_new: if true then suffixes of documents + that do not fit in concat_up_to context will be discarded, + otherwise they will be put at the beginning of the next context. + """ + + def __init__( + self, + data_blend_elements: list[DataBlendElement], + configured_tokenizer: ConfiguredTokenizer, + train_samples: int, + seed: int = 42, + concat_up_to: int | None = None, + concat_always_start_new: bool = True, + ): + """Initialize the data blend with weighted datasets.""" + self.configured_tokenizer = configured_tokenizer + logger.info(f"Configured tokenizer: {self.configured_tokenizer}") + + logger.info(f"Initializing DataBlend with {len(data_blend_elements)} data blend elements") + + self.dataset_weights = [] + self.datasets = [] + self.dataset_iterators = [] + + for dbe in tqdm(data_blend_elements, desc="Processing data blend elements"): + logger.info(f"Data blend element: {dbe.dataset.get_str_identifier()}") + logger.info(f"Data blend element weight: {dbe.weight}") + self.dataset_weights.append(dbe.weight) + self.datasets.append( + dbe.dataset.get_dataset( + configured_tokenizer=self.configured_tokenizer, + concat_up_to=concat_up_to, + shuffle_seed=seed, + concat_always_start_new=concat_always_start_new, + ) + ) + self.dataset_iterators.append(0) + + self.normalized_weights = np.array(self.dataset_weights, dtype=np.float64) + self.normalized_weights /= self.normalized_weights.sum() + + # self[id] -> dataset id + self.sample_mapping = [] + for i, nw in enumerate(self.normalized_weights): + nw = nw.item() + self.sample_mapping.append(np.full(int(nw * train_samples), i)) + + self.sample_mapping = np.concatenate(self.sample_mapping, axis=0) + rng = np.random.default_rng(seed=seed) + rng.shuffle(self.sample_mapping) + + def __len__(self): + return len(self.sample_mapping) + + def __getitem__(self, index: int): + ds_idx = self.sample_mapping[index] + ds_sample_idx = self.dataset_iterators[ds_idx] + ds_sample_idx %= len(self.datasets[ds_idx]) + + ds_sample = self.datasets[ds_idx][ds_sample_idx] + self.dataset_iterators[ds_idx] += 1 + + input_ids = np.array(ds_sample["encoded_prompt"], dtype=np.int64) + + ds_sample_augmented = { + "input_ids": input_ids, + "attention_mask": np.ones_like(input_ids, dtype=bool), + } + return ds_sample_augmented + + +# ============================================================================= +# Dataset definitions +# +# To add a new dataset: +# 1. Define filter_fn and extract_fn functions +# 2. Create a DatasetInfo instance +# 3. Reference it by name in YAML config data.blend field +# ============================================================================= + + +## OpenR1-Math-220k +def openr1_math_220k_filter_fn(ds_elem: Any) -> bool: + """Filter function to keep only samples with verified correct solutions.""" + return any(ds_elem["correctness_math_verify"]) + + +def openr1_math_220k_extract_fn(ds_elem: Any) -> dict[str, Any]: + """Extract problem-solution chat format from a dataset element.""" + problem = ds_elem["problem"] + solution = None + for gen, correctness in zip(ds_elem["generations"], ds_elem["correctness_math_verify"]): + if correctness: + solution = gen + + assert solution is not None, ( + "solution is None, filtering should remove problems without correct solutions" + ) + + chat = { + "conversation": [ + {"role": "user", "content": problem}, + {"role": "assistant", "content": solution}, + ] + } + + return chat + + +OpenR1Math220k = DatasetInfo( + args=("open-r1/OpenR1-Math-220k",), + kwargs={"split": "train"}, + filter_fn=openr1_math_220k_filter_fn, + extract_fn=openr1_math_220k_extract_fn, +) diff --git a/experimental/dms/dms/training/engine.py b/experimental/dms/dms/training/engine.py new file mode 100644 index 0000000000..740b94c715 --- /dev/null +++ b/experimental/dms/dms/training/engine.py @@ -0,0 +1,677 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DMS training engine: model configuration, distillation, noise, trainer state, and combined model.""" + +import hashlib +import json +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.distributed as dist +from transformers import ( + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + Trainer, + TrainerCallback, + TrainingArguments, +) + +from dms.core import DMSTrainingStateAux +from dms.logging import get_logger + +logger = get_logger("Engine") + + +# ============================================================================= +# Model configuration and loading +# ============================================================================= + +_DMS_PROJ_ALPHA_PATTERN = "dms_proj_alpha" +_UNFROZEN_DUMMY_PATTERN = "_unfrozen_dummy_param" + + +@dataclass +class ModelArguments: + """Arguments for specifying a model to load.""" + + model_name_or_path: str = field( + default="Qwen/Qwen3-0.6B", + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models." + }, + ) + dtype: str = field( + default="bfloat16", + metadata={"help": "Data type for model weights (e.g., 'bfloat16', 'float16', 'float32')."}, + ) + forward_fn_kwargs: dict[str, Any] = field( + default_factory=dict, metadata={"help": "Additional arguments for the forward function."} + ) + + +@dataclass +class DistillationModelArguments: + """Arguments for student-teacher distillation setup.""" + + student: ModelArguments = field(default_factory=ModelArguments) + teacher: ModelArguments = field(default_factory=ModelArguments) + + +def _is_trainable_param(name: str) -> bool: + """Check if a parameter should have gradients enabled for DMS training.""" + return _DMS_PROJ_ALPHA_PATTERN in name or _UNFROZEN_DUMMY_PATTERN in name + + +def _configure_gradients(model: PreTrainedModel) -> tuple[list[str], list[str]]: + """Configure which parameters require gradients for DMS training. + + Only DMS projection alpha parameters and unfrozen dummy parameters are trainable. + + Returns: + Tuple of (disabled_grad_names, enabled_grad_names). + """ + disabled_grad: list[str] = [] + enabled_grad: list[str] = [] + + for name, param in model.named_parameters(): + if _is_trainable_param(name): + param.requires_grad = True + enabled_grad.append(name) + else: + param.requires_grad = False + disabled_grad.append(name) + + return disabled_grad, enabled_grad + + +def get_student_model( + model_args: DistillationModelArguments, + model_constructor: type[PreTrainedModel], + zero_out_proj_alpha: bool, + dms_kwargs: dict[str, Any] | None, +) -> PreTrainedModel: + """Load and configure a student model for DMS distillation training. + + Args: + model_args: Distillation model arguments containing student config. + model_constructor: Model class with from_pretrained method. + zero_out_proj_alpha: If True, zero out DMS projection alpha parameters. + dms_kwargs: Additional keyword arguments for model construction. + + Returns: + Configured student model in training mode. + """ + if dms_kwargs is None: + dms_kwargs = {} + + logger.info( + f"Loading student model from {model_args.student.model_name_or_path} " + f"with dtype {model_args.student.dtype} and dms kwargs {json.dumps(dms_kwargs, indent=4)}" + ) + + model = model_constructor.from_pretrained( + model_args.student.model_name_or_path, + dtype=model_args.student.dtype, + **dms_kwargs, + ) + + if zero_out_proj_alpha: + for name, param in model.named_parameters(): + if _DMS_PROJ_ALPHA_PATTERN in name and f"{_DMS_PROJ_ALPHA_PATTERN}_norm" not in name: + logger.info(f"Zeroing out {name}") + param.data.zero_() + + disabled_grad, enabled_grad = _configure_gradients(model) + logger.info(f"Disabled gradients for: {disabled_grad}") + logger.info(f"Enabled gradients for: {enabled_grad}") + + model.train() + return model + + +def get_teacher_model( + model_args: DistillationModelArguments, + model_constructor: type[PreTrainedModel], +) -> PreTrainedModel: + """Load a teacher model for DMS distillation training. + + Args: + model_args: Distillation model arguments containing teacher config. + model_constructor: Model class with from_pretrained method. + + Returns: + Teacher model in evaluation mode. + """ + logger.info( + f"Loading teacher model from {model_args.teacher.model_name_or_path} with dtype {model_args.teacher.dtype}" + ) + + model = model_constructor.from_pretrained( + model_args.teacher.model_name_or_path, + dtype=model_args.teacher.dtype, + ) + model.eval() + return model + + +def get_tokenizer(model_args: ModelArguments) -> PreTrainedTokenizer: + """Load a tokenizer for the specified model. + + Args: + model_args: Model arguments containing the model path. + + Returns: + Loaded tokenizer instance. + """ + logger.info(f"Loading tokenizer from {model_args.model_name_or_path}") + return AutoTokenizer.from_pretrained(model_args.model_name_or_path) + + +# ============================================================================= +# DMS noise generation and schedule +# ============================================================================= + + +def get_gumbel_dist(dtype, device): + """Get the Gumbel distribution for DMS.""" + return torch.distributions.gumbel.Gumbel( + loc=torch.tensor(0.0, dtype=dtype, device=device), + scale=torch.tensor(1.0, dtype=dtype, device=device), + validate_args=None, + ) + + +def str_to_seed(text: str): + """Convert a string to a seed.""" + return int(hashlib.sha256(text.encode()).hexdigest(), 16) % (2**32 - 1) + + +def get_dms_noise(dist, device: torch.device, dms_state: DMSTrainingStateAux): + """Get the Gumbel noise for DMS. + + Uses current process index, gradient accumulation step and current step + to seed the random number generator. The shape of the noise is the same as the shape of the KV cache. + """ + with torch.random.fork_rng(devices=[device], enabled=True): + seed = str_to_seed( + f"{dms_state.process_index}_{dms_state.grad_acc_step}_{dms_state.current_step}" + ) + torch.manual_seed(seed) + a, b = [dist.sample(sample_shape=dms_state.kv_cache_shape).bfloat16() for _ in range(2)] + noise = a - b + + return noise + + +def dms_schedule( + step: int, + training_args: TrainingArguments, + dms_initial_cr, + dms_final_cr, + dms_final_step: int | None = None, +): + """Given the current training step, compute the DMS schedule. + + Returns the target fraction of DMS key-value pairs to evict and the compression ratio. + """ + if dms_final_step is not None: + max_steps = dms_final_step + else: + max_steps = training_args.max_steps + + progress = min(step / max_steps, 1.0) + + cr = dms_initial_cr + (dms_final_cr - dms_initial_cr) * progress + + frac = 1 / cr + + target = 1 - frac # what fraction of gates to close + + return target, cr + + +# ============================================================================= +# Trainer state (replaces module-level globals) +# ============================================================================= + + +class DMSTrainerState: + """Encapsulates training state shared between the trainer and the combined model. + + Usage: + state = DMSTrainerState() + combined_model = CombinedModel(..., trainer_state=state) + trainer = ModifiedTrainer(trainer_state=state, ...) + state.set_trainer(trainer) + """ + + def __init__(self): + """Initialize with no trainer (set later after trainer construction).""" + self.trainer: Trainer | None = None + self.logs: dict[str, float] = {} + self.grad_acc_step: int = 0 + + def set_trainer(self, trainer: Trainer): + """Set the HF trainer after construction.""" + self.trainer = trainer + + @property + def step(self) -> int: + """Get the current global training step.""" + assert self.trainer is not None, "Trainer not set. Call set_trainer() first." + return self.trainer.state.global_step + + @property + def process_index(self) -> int: + """Get the process index in distributed training.""" + assert self.trainer is not None, "Trainer not set. Call set_trainer() first." + return self.trainer.args.process_index + + def reset_grad_acc_step(self): + """Update the gradient accumulation step.""" + self.grad_acc_step = 0 + + def increment_grad_acc_step(self): + """Increment the gradient accumulation step.""" + self.grad_acc_step += 1 + + def update_logs(self, logs: dict[str, float]): + """Update logs and track gradient accumulation step.""" + self.logs = dict(**logs) + + +class DMSGradAccCallback(TrainerCallback): + """Callback to track gradient accumulation steps for DMS noise seeding.""" + + def __init__(self, trainer_state: DMSTrainerState): + """Initialize with a reference to the shared trainer state.""" + self.trainer_state = trainer_state + + def on_step_begin(self, *args, **kwargs): + """Reset the gradient accumulation step on new step.""" + self.trainer_state.reset_grad_acc_step() + + def on_substep_end(self, *args, **kwargs): + """Increment the gradient accumulation step on substep end.""" + self.trainer_state.increment_grad_acc_step() + + +class ModifiedTrainer(Trainer): + """Modified Trainer class that gathers DMS logs across distributed processes.""" + + def __init__(self, trainer_state: DMSTrainerState, **kwargs): + """Initialize with a reference to the shared trainer state.""" + super().__init__(**kwargs) + self.trainer_state = trainer_state + self.add_callback(DMSGradAccCallback(trainer_state)) + + def log(self, logs: dict[str, float], start_time: float | None = None): + """Log training metrics with gathered global DMS logs.""" + custom_logs = dict(**self.trainer_state.logs) + names = list(custom_logs.keys()) + values = [custom_logs[key] for key in names] + + if dist.is_initialized(): + values = torch.tensor(values, dtype=torch.float32, device=torch.cuda.current_device()) + dist.all_reduce(values, op=dist.ReduceOp.AVG) + + values = values.tolist() + + for key, value in zip(names, values): + logs["gl_" + key] = value + + super().log(logs=logs, start_time=start_time) + + +# ============================================================================= +# Distillation loss computation +# ============================================================================= + + +@torch.compile() +def distillation_loss( + student_raw_logits: torch.Tensor, # bfloat16 of shape batch, seq, vocab + teacher_raw_logits: torch.Tensor, # bfloat16 of shape batch, seq, vocab + loss_mask: torch.Tensor, # boolean of shape batch, seq + vocab_chunk: int, +): + """Compute KL divergence distillation loss between student and teacher logits (forward KL).""" + assert student_raw_logits.ndim == 3, ( + f"student_raw_logits.ndim: {student_raw_logits.ndim} != 3 (batch, seq, vocab)" + ) + assert teacher_raw_logits.shape == student_raw_logits.shape, ( + f"teacher_raw_logits.shape: {teacher_raw_logits.shape} != student_raw_logits.shape: {student_raw_logits.shape}" + ) + assert loss_mask.shape == student_raw_logits.shape[:2], ( + f"loss_mask.shape: {loss_mask.shape} != student_raw_logits.shape[:2]: {student_raw_logits.shape[:2]}" + ) + assert loss_mask.dtype == torch.bool, f"loss_mask.dtype: {loss_mask.dtype} != torch.bool" + + # log Denominator of size batch, seq + s_lse = torch.logsumexp(student_raw_logits.float(), dim=-1) # batch, seq + t_lse = torch.logsumexp(teacher_raw_logits.float(), dim=-1) + + # per-token KL + token_kl = torch.zeros_like(s_lse, dtype=torch.float32) + + vocab_size = student_raw_logits.shape[-1] + for start in range(0, vocab_size, vocab_chunk): + end = min(start + vocab_chunk, vocab_size) + + # batch, seq, vchunk + s_chunk = student_raw_logits[..., start:end] + t_chunk = teacher_raw_logits[..., start:end] + + s_logp = s_chunk.float() - s_lse[:, :, None] + t_logp = t_chunk.float() - t_lse[:, :, None] + + # Forward KL: KLD(Student, Teacher) + token_kl = token_kl + (s_logp.exp() * (s_logp - t_logp)).sum(dim=-1) + + denom = loss_mask.sum().clamp_min(1) + token_kl = token_kl.masked_fill(~loss_mask, 0.0) + return token_kl.sum() / denom + + +@torch.compile() +def calc_lm_loss( + student_raw_logits: torch.Tensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + eos_mask: torch.Tensor, +): + """Compute language modeling cross-entropy loss for the student model.""" + assert student_raw_logits.ndim == 3, ( + f"student_raw_logits.ndim: {student_raw_logits.ndim} != 3 (batch, seq, vocab)" + ) + assert input_ids.ndim == 2, f"input_ids.ndim: {input_ids.ndim} != 2 (batch, seq)" + assert attention_mask.ndim == 2, f"attention_mask.ndim: {attention_mask.ndim} != 2 (batch, seq)" + assert attention_mask.dtype == torch.bool, ( + f"attention_mask.dtype: {attention_mask.dtype} != torch.bool" + ) + assert eos_mask.ndim == 2, f"eos_mask.ndim: {eos_mask.ndim} != 2 (batch, seq)" + assert eos_mask.dtype == torch.bool, f"eos_mask.dtype: {eos_mask.dtype} != torch.bool" + student_raw_logits = student_raw_logits[:, :-1, :] + s_lse = torch.logsumexp(student_raw_logits, dim=-1) + + target_ids = input_ids[:, 1:] + source_raw_logits = student_raw_logits.gather(dim=-1, index=target_ids[:, :, None])[:, :, 0] + + assert source_raw_logits.shape == s_lse.shape + source_logp = source_raw_logits - s_lse + + # first do not predict from masked + # second do not predict masked + neg_loss_mask = torch.logical_or(~attention_mask[:, :-1], ~attention_mask[:, 1:]) + # do not predict from eos + neg_loss_mask = torch.logical_or(neg_loss_mask, ~eos_mask[:, :-1]) + source_logp = source_logp.masked_fill(neg_loss_mask, 0.0) + + lm_loss = -source_logp.sum() / (~neg_loss_mask).sum().clamp_min(1) + + return lm_loss + + +# ============================================================================= +# Distillation forward pass +# ============================================================================= + + +def distillation_forward( + student_model: PreTrainedModel, + teacher_model: PreTrainedModel | None, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + dms_schedule: Callable[[int], tuple[float, float]], + process_index: int, + tokenizer: PreTrainedTokenizer, + student_is_teacher: bool, + current_step: int, + grad_acc_step: int, + process_vocab_using_chunk: int, + forward_fn_kwargs_student: dict[str, Any], + forward_fn_kwargs_teacher: dict[str, Any], + **kwargs: Any, +): + """Run the distillation forward pass with student and teacher models. + + Args: + student_model: the student model + teacher_model: the teacher model, None if student is teacher + input_ids: input tokens for the models + attention_mask: boolean attention mask, true -> not masked, false -> masked + dms_schedule: fraction_of_kv_pairs_to_evict, 1/(1-fraction_of_kv_pairs_to_evict) = dms_schedule(current_step) + process_index: used for seeding of DMS noise + tokenizer: the tokenizer used for detecting eos tokens + student_is_teacher: true -> student is teacher, false -> student is student + current_step: the current step of the training loop (same across gradient accumulation steps) + grad_acc_step: the current gradient accumulation step (from 0 to gradient_accumulation_steps - 1) + process_vocab_using_chunk: the chunk size for processing the vocabulary in distillation loss calculation + forward_fn_kwargs_student: additional arguments for the student forward function + forward_fn_kwargs_teacher: additional arguments for the teacher forward function + **kwargs: additional arguments for the student and teacher models. + + Returns: + dict with loss (distillation + dms) and other metrics + """ + # here hf style mask + # true -> not masked + eos_mask = input_ids != tokenizer.eos_token_id + # no prediction from attn masked and no prediction from eos + assert attention_mask.dtype == torch.bool, ( + f"attention_mask.dtype: {attention_mask.dtype} != torch.bool" + ) + distill_loss_mask = torch.logical_and(attention_mask, eos_mask) + + dms_target_frac, dms_target_cr = dms_schedule(current_step) + + dms_state = DMSTrainingStateAux( + target_frac_to_close=dms_target_frac, + current_step=current_step, + grad_acc_step=grad_acc_step, + process_index=process_index, + noise=None, + right_padding_size=0, + kv_cache_shape=( + student_model.config.num_hidden_layers, + input_ids.shape[0], + student_model.config.num_key_value_heads, + input_ids.shape[1], + ), + dms_teacher_mode=False, + ) + dist_obj = get_gumbel_dist(dtype=torch.bfloat16, device=input_ids.device) + dms_state.noise = get_dms_noise(dist=dist_obj, device=input_ids.device, dms_state=dms_state) + + with torch.no_grad(): + dms_state_teacher = DMSTrainingStateAux( + target_frac_to_close=None, + current_step=current_step, + grad_acc_step=grad_acc_step, + process_index=process_index, + noise=None, + right_padding_size=0, + kv_cache_shape=dms_state.kv_cache_shape, + dms_teacher_mode=True, + ) + if student_is_teacher: + assert teacher_model is None, "teacher_model is not None when student is teacher" + teacher_output = student_model( + input_ids, + attention_mask, + dms_state=dms_state_teacher, + **forward_fn_kwargs_teacher, + **kwargs, + ) + else: + assert teacher_model is not None, "teacher_model is None when student is not teacher" + teacher_output = teacher_model( + input_ids, attention_mask, **forward_fn_kwargs_teacher, **kwargs + ) + teacher_logits = teacher_output.logits + assert teacher_logits.ndim == 3, ( + f"teacher_logits.ndim: {teacher_logits.ndim} != 3 (batch, seq, vocab)" + ) + + student_output = student_model( + input_ids, attention_mask, dms_state=dms_state, **forward_fn_kwargs_student, **kwargs + ) + dms_loss = student_output.dms_loss + + student_logits = student_output.logits + assert student_logits.ndim == 3, ( + f"student_logits.ndim: {student_logits.ndim} != 3 (batch, seq, vocab)" + ) + assert input_ids.ndim == 2, f"input_ids.ndim: {input_ids.ndim} != 2 (batch, seq)" + distil_loss = torch.utils.checkpoint.checkpoint( + distillation_loss, + student_logits, + teacher_logits, + distill_loss_mask, + process_vocab_using_chunk, + use_reentrant=False, + ) + with torch.no_grad(): + lm_loss_detach = calc_lm_loss(student_logits.detach(), input_ids, attention_mask, eos_mask) + + loss = distil_loss + dms_loss + + dms_closed_frac_detach = student_output.dms_frac_closed.detach() + + result = { + "loss": loss, + "dms_loss": dms_loss.detach(), + "distil_loss": distil_loss.detach(), + "dms_target_frac": torch.tensor(dms_target_frac, dtype=torch.float32, device=loss.device), + "dms_closed_frac": dms_closed_frac_detach, + "dms_target_cr": torch.tensor(dms_target_cr, dtype=torch.float32, device=loss.device), + "dms_cr": 1 / torch.clamp(1.0 - dms_closed_frac_detach, min=1e-6), + "input_tokens": torch.tensor(input_ids.shape[1], dtype=torch.int32, device=loss.device), + "detached_lm_loss": lm_loss_detach, + "positions_for_loss_calculation": distill_loss_mask.sum(), + "positions_without_loss_calculation": (~distill_loss_mask).sum(), + "eos_tokens": (~eos_mask).to(torch.int32).sum(), + "masked_tokens": (~attention_mask).to(torch.int32).sum(), + } + + return result + + +# ============================================================================= +# Combined student-teacher model +# ============================================================================= + + +class CombinedModel(torch.nn.Module): + """Combined student-teacher model wrapper for distillation training.""" + + def __init__( + self, + student_model: PreTrainedModel, + teacher_model: PreTrainedModel, + trainer_state: DMSTrainerState, + dms_schedule: Callable[[int], float], + forward_fn: Callable, + student_is_teacher: bool, + tokenizer: PreTrainedTokenizer, + process_vocab_using_chunk: int, + forward_fn_kwargs_student: dict[str, Any] | None = None, + forward_fn_kwargs_teacher: dict[str, Any] | None = None, + ): + """Initialize the combined model for distillation. + + Args: + student_model: the student model + teacher_model: the teacher model + trainer_state: shared trainer state object (replaces global callbacks) + dms_schedule: a function that given current step returns the DMS schedule + (target fraction of tokens to evict and compression ratio) + forward_fn: a function that performs the forward pass + student_is_teacher: whether the student is the teacher + tokenizer: the tokenizer + process_vocab_using_chunk: the chunk size for processing the vocabulary + forward_fn_kwargs_student: additional arguments for the student forward function + forward_fn_kwargs_teacher: additional arguments for the teacher forward function + """ + super().__init__() + if forward_fn_kwargs_student is None: + forward_fn_kwargs_student = {} + if forward_fn_kwargs_teacher is None: + forward_fn_kwargs_teacher = {} + self.student_is_teacher = student_is_teacher + self.student_model = student_model + if self.student_is_teacher: + self.teacher_model = None + else: + self.teacher_model = teacher_model + self._freeze_teacher_model() + self.trainer_state = trainer_state + self.dms_schedule = dms_schedule + self.forward_fn = forward_fn + + self.tokenizer = tokenizer + self.process_vocab_using_chunk = process_vocab_using_chunk + self.forward_fn_kwargs_student = forward_fn_kwargs_student + self.forward_fn_kwargs_teacher = forward_fn_kwargs_teacher + + def _freeze_teacher_model(self): + assert self.teacher_model is not None, "teacher_model is None" + for param in self.teacher_model.parameters(): + param.requires_grad = False + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any): + """Run the forward pass with student and teacher models.""" + process_index = self.trainer_state.process_index + current_step = self.trainer_state.step + grad_acc_step = self.trainer_state.grad_acc_step + result = self.forward_fn( + student_model=self.student_model, + teacher_model=self.teacher_model, + input_ids=input_ids, + attention_mask=attention_mask, + dms_schedule=self.dms_schedule, + process_index=process_index, + student_is_teacher=self.student_is_teacher, + tokenizer=self.tokenizer, + current_step=current_step, + grad_acc_step=grad_acc_step, + process_vocab_using_chunk=self.process_vocab_using_chunk, + forward_fn_kwargs_student=self.forward_fn_kwargs_student, + forward_fn_kwargs_teacher=self.forward_fn_kwargs_teacher, + **kwargs, + ) + + self.trainer_state.update_logs( + {key: value.detach().clone().cpu().item() for key, value in result.items()} + ) + + return result + + def get_parameters_to_optimize(self): + """Get the trainable parameters from the student model.""" + params_to_optimize = [ + param for param in self.student_model.parameters() if param.requires_grad + ] + return tuple(params_to_optimize) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """Enable gradient checkpointing on the student model.""" + self.student_model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs + ) diff --git a/experimental/dms/example_inference.ipynb b/experimental/dms/example_inference.ipynb new file mode 100644 index 0000000000..e38ff213bc --- /dev/null +++ b/experimental/dms/example_inference.ipynb @@ -0,0 +1,145 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "7ea47246", + "metadata": {}, + "outputs": [], + "source": [ + "# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", + "# SPDX-License-Identifier: Apache-2.0\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import gc\n", + "\n", + "import torch\n", + "from models.qwen3.modeling_qwen3_dms import Qwen3ForCausalLMDMS\n", + "from transformers import AutoTokenizer, TextStreamer\n", + "\n", + "# Auxiliary functions\n", + "\n", + "\n", + "def get_model_and_tokenizer(model_name: str, **model_kwargs):\n", + " model = Qwen3ForCausalLMDMS.from_pretrained(model_name, dtype=torch.bfloat16, **model_kwargs)\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + " model.eval()\n", + " model.to(torch.device(\"cuda:0\"))\n", + " return model, tokenizer\n", + "\n", + "\n", + "def clean_cache():\n", + " gc.collect()\n", + " torch.cuda.empty_cache()\n", + "\n", + "\n", + "def get_example_input_ids(tokenizer: AutoTokenizer, device: torch.device):\n", + " prompt = [{\"role\": \"user\", \"content\": \"Solve x^2 -2x + 1 = 0\"}]\n", + " prompt = tokenizer.apply_chat_template(\n", + " prompt, tokenize=False, add_generation_prompt=True, enable_thinking=True\n", + " )\n", + " input_data = tokenizer(prompt, return_tensors=\"pt\").to(device)\n", + " input_ids = input_data[\"input_ids\"]\n", + " return input_ids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "096cd6f8", + "metadata": {}, + "outputs": [], + "source": [ + "# The following function will run chunked prefill to avoid OOM,\n", + "# and follow with generation\n", + "\n", + "\n", + "def example_prefill_generate(model_name: str):\n", + " model, tokenizer = get_model_and_tokenizer(model_name, dms_chunked_prefill=4)\n", + "\n", + " input_ids = get_example_input_ids(tokenizer, model.device)\n", + "\n", + " streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)\n", + "\n", + " # automatically replaces DynamicCache with DMSCache\n", + " model.generate(input_ids, max_new_tokens=4096, do_sample=False, streamer=streamer)\n", + "\n", + "\n", + "clean_cache()\n", + "example_prefill_generate(model_name=\"nvidia/Qwen3-8B-DMS-8x\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b0a362a", + "metadata": {}, + "outputs": [], + "source": [ + "# In order to run multi-turn conversations using this code, which interleave prefill with generation,\n", + "# a fallback of using only prefills can be used.\n", + "\n", + "\n", + "def example_prefill_only_mode(model_name: str):\n", + " model, tokenizer = get_model_and_tokenizer(\n", + " model_name, dms_manual_inference_mode=True, dms_preallocate_for_tokens=512\n", + " )\n", + "\n", + " cache = model.get_cache()\n", + " input_ids = get_example_input_ids(tokenizer, model.device)\n", + "\n", + " with torch.no_grad():\n", + " cache.prefill_mode()\n", + "\n", + " for _ in range(4096):\n", + " # prefill mode is slower but allows for interleaving of prefill and inference\n", + " assert cache.is_prefill_mode()\n", + " outputs = model(input_ids, past_key_values=cache, use_cache=True)\n", + " input_ids = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n", + " print(\n", + " tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=False),\n", + " end=\"\",\n", + " flush=True,\n", + " )\n", + " if input_ids[0, 0] == tokenizer.eos_token_id:\n", + " break\n", + "\n", + "\n", + "clean_cache()\n", + "example_prefill_only_mode(model_name=\"nvidia/Qwen3-8B-DMS-8x\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/experimental/dms/models/__init__.py b/experimental/dms/models/__init__.py new file mode 100644 index 0000000000..15d1286188 --- /dev/null +++ b/experimental/dms/models/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-specific DMS adaptations.""" diff --git a/experimental/dms/models/qwen3/__init__.py b/experimental/dms/models/qwen3/__init__.py new file mode 100644 index 0000000000..7b44a9a499 --- /dev/null +++ b/experimental/dms/models/qwen3/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3 model family with DMS (Dynamic Memory Sparsification).""" diff --git a/experimental/dms/models/qwen3/configuration_qwen3_dms.py b/experimental/dms/models/qwen3/configuration_qwen3_dms.py new file mode 100644 index 0000000000..39d2c93739 --- /dev/null +++ b/experimental/dms/models/qwen3/configuration_qwen3_dms.py @@ -0,0 +1,101 @@ +# Adapted from https://github.com/huggingface/transformers/blob/47b0e478f324b54f177ea7998a0791870fdd0324/src/transformers/models/qwen3/configuration_qwen3.py + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3 model configuration.""" + +from dms.core import setup_compile_limit_for_dms +from transformers import Qwen3Config + + +class Qwen3ConfigDMS(Qwen3Config): + """DMS configuration for Qwen3 model.""" + + def __init__( + self, + dms_alpha_scale: float = 100.0, + dms_initial_alpha_offset: float = 5.0, + dms_window_size: int = 512, + dms_paged_attention_block_size: int = 256, + dms_cr: int = 8, + dms_disable_eviction: bool = False, + dms_separate_alpha: bool = False, + dms_alpha_per: str = "head", + dms_tau: float = 0.1, + dms_compile_limit: int | None = 72, + dms_manual_inference_mode: bool = False, + dms_chunked_prefill: int | None = None, + dms_preallocate_for_tokens: int = 4096, + **kwargs, + ): + """Initialize the Qwen3ConfigDMS model. + + Args: + dms_alpha_scale: scaling factor for DMS decision logits. + dms_initial_alpha_offset: initial offset for DMS decision logits. + dms_window_size: sliding window size for DMS. + dms_paged_attention_block_size: block size for paged cache. + dms_cr: compression ratio for DMS. For documentation purposes only. + dms_disable_eviction: turns adapter DMS models into vanilla models. + dms_separate_alpha: True -> We initialise new parameters, False -> DMS uses query parameters. + dms_alpha_per: Whether to make per head or per layer for DMS eviction decisions. + dms_tau: Temperature for DMS decision logits. + dms_compile_limit: Torch.compile limit. + dms_manual_inference_mode: Whether to use inference with manual prefill/inference switching for kv-cache. + dms_chunked_prefill: Chunk size for prefill. + dms_preallocate_for_tokens: Preallocate space for tokens in kv-cache. + """ + self.dms_alpha_scale = dms_alpha_scale + self.dms_initial_alpha_offset = dms_initial_alpha_offset + self.dms_window_size = dms_window_size + self.dms_paged_attention_block_size = dms_paged_attention_block_size + self.dms_cr = dms_cr + self.dms_disable_eviction = dms_disable_eviction + self.dms_separate_alpha = dms_separate_alpha + self.dms_alpha_per = dms_alpha_per + self.dms_tau = dms_tau + self.dms_manual_inference_mode = dms_manual_inference_mode + self.dms_chunked_prefill = dms_chunked_prefill + self.dms_preallocate_for_tokens = dms_preallocate_for_tokens + + assert self.dms_paged_attention_block_size > 0, ( + f"dms_paged_attention_block_size: {self.dms_paged_attention_block_size} is not greater than 0" + ) + assert self.dms_window_size > self.dms_paged_attention_block_size, ( + f"dms_window_size: {self.dms_window_size} " + f"is not greater than dms_paged_attention_block_size: {self.dms_paged_attention_block_size}" + ) + assert self.dms_alpha_per in ["head", "layer"], ( + f"dms_alpha_per: {self.dms_alpha_per} is not supported" + ) + if dms_compile_limit is not None: + setup_compile_limit_for_dms(compile_limit=dms_compile_limit) + super().__init__( + **kwargs, + ) diff --git a/experimental/dms/models/qwen3/extract.py b/experimental/dms/models/qwen3/extract.py new file mode 100644 index 0000000000..e93c4c0488 --- /dev/null +++ b/experimental/dms/models/qwen3/extract.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Extract a trained DMS student model from an intermediate checkpoint. + +Usage: + python -m models.qwen3.extract \ + --config outputs/qwen3_8b/config.yaml \ + --checkpoint outputs/qwen3_8b/checkpoint-238 \ + --output outputs/qwen3_8b/student_model_step238 +""" + +import argparse +from pathlib import Path + +import torch +from dms.logging import get_logger +from dms.training.engine import DistillationModelArguments, DMSTrainerState, ModelArguments +from transformers import AutoTokenizer, TrainingArguments + +from .train import build_combined_model, extract_student_model, load_config + +logger = get_logger("Extract") + + +def main() -> None: + """Extract student model from a training checkpoint.""" + parser = argparse.ArgumentParser( + description="Extract DMS student model from a training checkpoint." + ) + parser.add_argument( + "--config", type=str, required=True, help="Path to the training YAML config" + ) + parser.add_argument( + "--checkpoint", type=str, required=True, help="Path to the checkpoint directory" + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output path (defaults to checkpoint/student_model)", + ) + cli_args = parser.parse_args() + + cfg = load_config(cli_args.config) + model_cfg = cfg["model"] + dms_cfg = cfg["dms"] + data_cfg = cfg["data"] + + checkpoint_dir = Path(cli_args.checkpoint) + model_path = checkpoint_dir / "pytorch_model.bin" + save_path = cli_args.output or str(checkpoint_dir / "student_model") + + logger.info(f"Loading model from: {model_path}") + logger.info(f"Saving model to: {save_path}") + + tokenizer = AutoTokenizer.from_pretrained(str(checkpoint_dir)) + + model_args = DistillationModelArguments( + student=ModelArguments( + model_name_or_path=model_cfg["name"], + dtype=model_cfg.get("dtype", "float32"), + ), + teacher=ModelArguments( + model_name_or_path=model_cfg.get("teacher_name", model_cfg["name"]), + dtype=model_cfg.get("teacher_dtype", model_cfg.get("dtype", "float32")), + ), + ) + + training_args = TrainingArguments(output_dir=".") + trainer_state = DMSTrainerState() + + logger.info("Creating combined model...") + combined_model = build_combined_model( + model_args=model_args, + training_args=training_args, + dms_cfg=dms_cfg, + data_cfg=data_cfg, + tokenizer=tokenizer, + trainer_state=trainer_state, + ) + + logger.info("Loading checkpoint weights...") + combined_model.load_state_dict(torch.load(model_path, weights_only=True)) + + extract_student_model(combined_model, tokenizer, save_path) + + +if __name__ == "__main__": + main() diff --git a/experimental/dms/models/qwen3/modeling_qwen3_dms.py b/experimental/dms/models/qwen3/modeling_qwen3_dms.py new file mode 100644 index 0000000000..ce4a56e0ac --- /dev/null +++ b/experimental/dms/models/qwen3/modeling_qwen3_dms.py @@ -0,0 +1,498 @@ +# Adapted from https://github.com/huggingface/transformers/blob/47b0e478f324b54f177ea7998a0791870fdd0324/src/transformers/models/qwen3/modeling_qwen3.py + +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Qwen3 model with DMS (Dynamic Memory Sparsification).""" + +import torch +from dms.attention import dms_attention +from dms.cache import DMSCache, Mode +from dms.core import ( + DMSBaseModelOutputWithPastAndCR, + DMSCausalLMOutputWithPastAndCR, + DMSTrainingStateAux, + dms_perform_chunked_prefill, + post_process_attention_output, + prepare_attention_input, +) +from dms.logging import get_logger +from torch import nn +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3MLP, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + apply_rotary_pos_emb, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg + +from .configuration_qwen3_dms import Qwen3ConfigDMS + +logger = get_logger("Qwen3ForCausalLMDMS") + + +class Qwen3AttentionDMS(Qwen3Attention): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__(self, config: Qwen3ConfigDMS, layer_idx: int): + """Initialize the Qwen3AttentionDMS model.""" + super().__init__(config=config, layer_idx=layer_idx) + self.dms_alpha_scale = config.dms_alpha_scale + self.dms_initial_alpha_offset = config.dms_initial_alpha_offset + self.dms_window_size = config.dms_window_size + self.dms_disable_eviction = config.dms_disable_eviction + + self.num_key_value_heads = config.num_key_value_heads + + self.dms_tau = config.dms_tau + + if self.config.dms_separate_alpha: + self.dms_proj_alpha_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dms_proj_alpha = nn.Linear( + config.hidden_size, self.num_key_value_heads, bias=config.attention_bias + ) + else: + self.dms_proj_alpha_norm = None + self.dms_proj_alpha = None + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + pre_attn_norm_hidden_states: torch.Tensor, + post_attn_norm_hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: DMSCache | None = None, + cache_position: torch.LongTensor | None = None, + dms_state: DMSTrainingStateAux = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """A modified version of the forward pass from the transformers Qwen3Attention model.""" + if self.training: + assert dms_state is not None, "dms_state is None in training mode" + flash_attn_query_states, key_states, value_states, decisions, decision_logits = ( + prepare_attention_input( + pre_attn_norm_hidden_states=pre_attn_norm_hidden_states, + post_attn_norm_hidden_states=post_attn_norm_hidden_states, + q_proj_fn=self.q_proj, + k_proj_fn=self.k_proj, + v_proj_fn=self.v_proj, + q_norm_fn=self.q_norm, + k_norm_fn=self.k_norm, + head_dim=self.head_dim, + cos=position_embeddings[0], + sin=position_embeddings[1], + dms_proj_alpha_norm_fn=self.dms_proj_alpha_norm, + dms_proj_alpha_fn=self.dms_proj_alpha, + dms_alpha_per=self.config.dms_alpha_per, + dms_decision_scale=self.dms_alpha_scale, + dms_initial_decision_offset=self.dms_initial_alpha_offset, + dms_training=self.training, + dms_disable_eviction=self.dms_disable_eviction, + dms_tau=self.dms_tau, + apply_rotary_pos_emb_fn=apply_rotary_pos_emb, + dms_teacher_mode=(dms_state is not None and dms_state.dms_teacher_mode), + dms_noise=dms_state.noise[self.layer_idx] + if (dms_state is not None and dms_state.noise is not None) + else None, + ) + ) + + attn_output = dms_attention( + new_q_flash=flash_attn_query_states, + new_k=key_states, + new_v=value_states, + decisions=decisions, + decision_logits=decision_logits, + attention_mask=attention_mask, + layer_idx=self.layer_idx, + dms_cache=past_key_values, + attn_scaling=self.scaling, + window_size=self.dms_window_size, + train_attn_kwargs=kwargs.get("train_attn_kwargs", {}), + ) + + attn_output = post_process_attention_output( + attn_output=attn_output, + o_proj=self.o_proj, + ) + + return attn_output, decisions + + +class Qwen3DecoderLayerDMS(GradientCheckpointingLayer): + """A modified version of the transformers Qwen3DecoderLayer model.""" + + def __init__(self, config: Qwen3ConfigDMS, layer_idx: int): + """Initialize the Qwen3DecoderLayerDMS model.""" + super().__init__() + self.dms_window_size = config.dms_window_size + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3AttentionDMS(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: DMSCache | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + dms_state: DMSTrainingStateAux | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """A modified version of the forward pass from the transformers Qwen3DecoderLayer model.""" + residual = hidden_states + # Self Attention + pre_attn_norm_hidden_states = hidden_states + post_attn_norm_hidden_states = self.input_layernorm(hidden_states) + hidden_states, decisions = self.self_attn( + pre_attn_norm_hidden_states=pre_attn_norm_hidden_states, + post_attn_norm_hidden_states=post_attn_norm_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + dms_state=dms_state, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.training and dms_state is not None and not dms_state.dms_teacher_mode: + assert dms_state.kv_cache_shape[1:] == decisions.shape, ( + f"dms_state.kv_cache_shape[1:]: {dms_state.kv_cache_shape[1:]} != decisions.shape: {decisions.shape}" + ) + if dms_state.right_padding_size > 0: + decisions = decisions[:, :, : -dms_state.right_padding_size] + dms_frac_closed = decisions.float().mean(dim=(1, 2)) + else: + dms_frac_closed = None + + return hidden_states, dms_frac_closed + + +class Qwen3PreTrainedModelDMS(PreTrainedModel): + """A modified version of the transformers Qwen3PreTrainedModel model.""" + + config: Qwen3ConfigDMS + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayerDMS"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3DecoderLayerDMS, + "attentions": Qwen3AttentionDMS, + } + + +class Qwen3ModelDMS(Qwen3PreTrainedModelDMS): + """A modified version of the transformers Qwen3Model model.""" + + def __init__(self, config: Qwen3ConfigDMS): + """Initialize the Qwen3ModelDMS model.""" + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayerDMS(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: DMSCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + dms_state: DMSTrainingStateAux = None, + **kwargs: Unpack[TransformersKwargs], + ) -> DMSBaseModelOutputWithPastAndCR: + """A modified version of the forward pass from the transformers Qwen3Model model.""" + if self.training: + assert dms_state is not None + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + assert past_key_values is None or isinstance(past_key_values, DMSCache), ( + f"past_key_values is not a DMSCache: {type(past_key_values)}" + ) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask( + **mask_kwargs + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + hidden_states, dms_frac_closed = dms_perform_chunked_prefill( + decoder_layers=self.layers[: self.config.num_hidden_layers], + hidden_states=hidden_states, + attention_mask=causal_mask_mapping["full_attention"], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + dms_state=dms_state, + dms_manual_inference_mode=self.config.dms_manual_inference_mode, + dms_chunked_prefill=self.config.dms_chunked_prefill, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return DMSBaseModelOutputWithPastAndCR( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + cr=past_key_values.get_cr() if past_key_values is not None else None, + dms_frac_closed=dms_frac_closed, + ) + + +class Qwen3ForCausalLMDMS(Qwen3PreTrainedModelDMS, GenerationMixin): + """A modified version of the transformers Qwen3ForCausalLM model.""" + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config: Qwen3ConfigDMS): + """Initialize the Qwen3ForCausalLMDMS model.""" + super().__init__(config) + self.model = Qwen3ModelDMS(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_cache(self, preallocate_for_tokens: int | None = None): + """Get the DMS cache for inference.""" + if preallocate_for_tokens is None: + preallocate_for_tokens = self.config.dms_preallocate_for_tokens + return DMSCache( + dms_window_size=self.config.dms_window_size + 1, + max_context_length=self.config.max_position_embeddings, + accommodate_min_initial_context_length=preallocate_for_tokens, + disable_eviction=self.config.dms_disable_eviction, + block_size=self.config.dms_paged_attention_block_size, + ) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: DMSCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + dms_state: DMSTrainingStateAux = None, + **kwargs: Unpack[TransformersKwargs], + ) -> DMSCausalLMOutputWithPastAndCR: + """A modified version of the forward pass from the transformers Qwen3ForCausalLM model.""" + if self.training: + assert dms_state is not None, "dms_state is None in training mode" + if (not self.training) and ( + (use_cache and past_key_values is None) or not isinstance(past_key_values, DMSCache) + ): + if past_key_values is not None: + logger.warning( + f"past_key_values is of type {type(past_key_values)}, it will be replaced with an empty DMSCache!" + ) + past_key_values = self.get_cache() + + outputs: DMSBaseModelOutputWithPastAndCR = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + dms_state=dms_state, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs + ) + + dms_frac_closed: torch.Tensor | None = outputs.dms_frac_closed + dms_cr = None + + if dms_state is not None and dms_state.target_frac_to_close is not None: + assert self.training, ( + "dms_state.target_frac_to_close is only supported in training mode" + ) + assert dms_frac_closed is not None, "dms_frac_closed is None during training" + dms_loss = torch.clamp(dms_state.target_frac_to_close - dms_frac_closed, min=0.0).mean() + dms_frac_open = 1 - dms_frac_closed.detach().mean() + dms_cr = 1 / torch.clamp(dms_frac_open, min=1e-6) + else: + assert (not self.training) or dms_state.dms_teacher_mode, ( + "dms_state.target_frac_to_close is required in training mode" + ) + dms_loss = None + + if past_key_values is not None and past_key_values.current_mode == Mode.INFERENCE: + dms_cr = past_key_values.get_cr() + + return DMSCausalLMOutputWithPastAndCR( + loss=loss, + dms_loss=dms_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cr=dms_cr, + dms_frac_closed=dms_frac_closed.detach().mean() + if dms_frac_closed is not None + else None, + ) + + +class Qwen3ForSequenceClassificationDMS(GenericForSequenceClassification, Qwen3PreTrainedModelDMS): + """Qwen3 model for sequence classification with DMS.""" + + +class Qwen3ForTokenClassificationDMS(GenericForTokenClassification, Qwen3PreTrainedModelDMS): + """Qwen3 model for token classification with DMS.""" + + +class Qwen3ForQuestionAnsweringDMS(GenericForQuestionAnswering, Qwen3PreTrainedModelDMS): + """Qwen3 model for question answering with DMS.""" + + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +__all__ = [ + "Qwen3ForCausalLMDMS", + "Qwen3ForQuestionAnsweringDMS", + "Qwen3ForSequenceClassificationDMS", + "Qwen3ForTokenClassificationDMS", + "Qwen3ModelDMS", + "Qwen3PreTrainedModelDMS", +] diff --git a/experimental/dms/models/qwen3/train.py b/experimental/dms/models/qwen3/train.py new file mode 100644 index 0000000000..78fbb19a2f --- /dev/null +++ b/experimental/dms/models/qwen3/train.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training entry point for Qwen3 model with DMS. + +Usage: + # First run prepares the dataset (single process): + python -m models.qwen3.train --config configs/qwen3_8b.yaml --prepare-dataset-only + + # Then launch distributed training: + accelerate launch -m models.qwen3.train --config configs/qwen3_8b.yaml +""" + +import argparse +import functools +import json +import os +import shutil +from pathlib import Path + +import torch +import yaml +from datasets import Dataset +from dms.logging import get_logger +from dms.training.data import ConfiguredTokenizer, DataBlend, DataBlendElement +from dms.training.engine import ( + CombinedModel, + DistillationModelArguments, + DMSTrainerState, + ModelArguments, + ModifiedTrainer, + distillation_forward, + dms_schedule, + get_student_model, + get_teacher_model, + get_tokenizer, +) +from transformers import ( + DataCollatorWithPadding, + PreTrainedTokenizer, + Qwen3ForCausalLM, + TrainingArguments, +) + +from .modeling_qwen3_dms import Qwen3ForCausalLMDMS + +logger = get_logger("Train") + + +# ============================================================================= +# Config loading +# ============================================================================= + + +def load_config(path: str) -> dict: + """Load a YAML configuration file.""" + with open(path) as f: + return yaml.safe_load(f) + + +def save_config(cfg: dict, output_dir: str) -> None: + """Save the configuration to the output directory for reproducibility.""" + os.makedirs(output_dir, exist_ok=True) + config_path = os.path.join(output_dir, "config.yaml") + with open(config_path, "w") as f: + yaml.dump(cfg, f, default_flow_style=False, sort_keys=False) + logger.info(f"Saved config to {config_path}") + + +def resolve_checkpoint(cfg: dict) -> str | None: + """Resolve the checkpoint path for resume, supporting 'auto' detection.""" + resume = cfg["hf_trainer"].get("resume_from_checkpoint") + if resume == "auto": + output_dir = Path(cfg["hf_trainer"]["output_dir"]) + checkpoints = sorted( + output_dir.glob("checkpoint-*"), + key=lambda p: int(p.name.split("-")[1]), + ) + if checkpoints: + logger.info(f"Auto-detected latest checkpoint: {checkpoints[-1]}") + return str(checkpoints[-1]) + else: + logger.info("No checkpoints found, starting fresh.") + return None + return resume + + +# ============================================================================= +# Dataset parsing +# ============================================================================= + + +def _parse_data_blend_elements(blend_string: str) -> list[DataBlendElement]: + """Parse data blend elements from a comma-separated string. + + Args: + blend_string: Comma-separated string of "dataset_name:weight" pairs. + + Returns: + List of DataBlendElement objects. + """ + import dms.training.data as rf_datasets + + elements = [] + for entry in blend_string.split(","): + dataset_name, weight_str = entry.split(":") + dataset = getattr(rf_datasets, dataset_name) + weight = float(weight_str) + logger.info(f"Adding dataset {dataset_name} with weight {weight}") + elements.append(DataBlendElement(dataset=dataset, weight=weight)) + return elements + + +def _create_train_dataset( + data_cfg: dict, + configured_tokenizer: ConfiguredTokenizer, +) -> Dataset: + """Create the training dataset from data blend configuration.""" + data_blend_elements = _parse_data_blend_elements(data_cfg["blend"]) + + data_blend = DataBlend( + data_blend_elements=data_blend_elements, + configured_tokenizer=configured_tokenizer, + train_samples=data_cfg["train_samples"], + concat_up_to=data_cfg["max_length"], + concat_always_start_new=data_cfg.get("concat_always_start_new", True), + ) + + return Dataset.from_generator(lambda: (data_blend[i] for i in range(len(data_blend)))) + + +# ============================================================================= +# Model building +# ============================================================================= + + +def build_combined_model( + model_args: DistillationModelArguments, + training_args: TrainingArguments, + dms_cfg: dict, + data_cfg: dict, + tokenizer: PreTrainedTokenizer, + trainer_state: DMSTrainerState, +) -> CombinedModel: + """Build the combined student-teacher model for distillation.""" + dms_kwargs = { + f"dms_{k}" if not k.startswith("dms_") else k: v + for k, v in dms_cfg.items() + if k not in ("initial_cr", "final_cr", "final_step") + } + + student_model = get_student_model( + model_args, + zero_out_proj_alpha=True, + model_constructor=Qwen3ForCausalLMDMS, + dms_kwargs=dms_kwargs, + ) + + student_is_teacher = ( + model_args.student.model_name_or_path == model_args.teacher.model_name_or_path + and model_args.student.dtype == model_args.teacher.dtype + ) + if student_is_teacher: + logger.info("Student and teacher are the same model - optimization enabled") + teacher_model = student_model + else: + logger.info("Student and teacher are different models") + teacher_model = get_teacher_model(model_args, model_constructor=Qwen3ForCausalLM) + + return CombinedModel( + student_model=student_model, + teacher_model=teacher_model, + trainer_state=trainer_state, + dms_schedule=functools.partial( + dms_schedule, + training_args=training_args, + dms_initial_cr=dms_cfg["initial_cr"], + dms_final_cr=dms_cfg["final_cr"], + dms_final_step=dms_cfg.get("final_step"), + ), + forward_fn=distillation_forward, + student_is_teacher=student_is_teacher, + tokenizer=tokenizer, + process_vocab_using_chunk=data_cfg.get("process_vocab_using_chunk", 4096), + forward_fn_kwargs_student=model_args.student.forward_fn_kwargs, + forward_fn_kwargs_teacher=model_args.teacher.forward_fn_kwargs, + ) + + +# ============================================================================= +# Student model extraction +# ============================================================================= + +AUTO_MAP_CONFIG = { + "AutoConfig": "configuration_qwen3_dms.Qwen3ConfigDMS", + "AutoModel": "modeling_qwen3_dms.Qwen3ModelDMS", + "AutoModelForCausalLM": "modeling_qwen3_dms.Qwen3ForCausalLMDMS", + "AutoModelForQuestionAnswering": "modeling_qwen3_dms.Qwen3ForQuestionAnsweringDMS", + "AutoModelForSequenceClassification": "modeling_qwen3_dms.Qwen3ForSequenceClassificationDMS", + "AutoModelForTokenClassification": "modeling_qwen3_dms.Qwen3ForTokenClassificationDMS", +} + + +def extract_student_model( + combined_model: CombinedModel, + tokenizer: PreTrainedTokenizer, + save_path: str, +) -> None: + """Extract the student model from a CombinedModel and save it for inference. + + The saved model includes: + - Model weights in bfloat16 + - Config with auto_map for trust_remote_code + - Model implementation files (config.py, model.py) + - Tokenizer + + Note: The saved model imports from the `dms` package. Make sure `dms` is + installed (pip install -e .) in any environment where you load this model. + """ + student_model = combined_model.student_model + logger.info(f"Extracting student model to: {save_path}") + + student_model.to(torch.bfloat16) + student_model.save_pretrained(save_path) + tokenizer.save_pretrained(save_path) + + # Update config.json with auto_map + config_path = Path(save_path) / "config.json" + with open(config_path) as f: + config = json.load(f) + config.pop("architectures", None) + config["auto_map"] = AUTO_MAP_CONFIG + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + # Copy model implementation files for trust_remote_code + model_dir = Path(__file__).parent + for src_name in ["configuration_qwen3_dms.py", "modeling_qwen3_dms.py"]: + shutil.copy(model_dir / src_name, Path(save_path) / src_name) + + logger.info(f"Successfully saved student model to: {save_path}") + + +# ============================================================================= +# Main entry point +# ============================================================================= + + +def main() -> None: + """Main entry point for the training script.""" + parser = argparse.ArgumentParser(description="Train DMS adapter for Qwen3") + parser.add_argument("--config", required=True, help="Path to YAML config file") + parser.add_argument( + "--prepare-dataset-only", + action="store_true", + help="Only prepare the dataset, then exit (run with single process first)", + ) + args, _unknown = parser.parse_known_args() + + cfg = load_config(args.config) + + model_cfg = cfg["model"] + dms_cfg = cfg["dms"] + data_cfg = cfg["data"] + hf_trainer_cfg = cfg["hf_trainer"] + + # Build model arguments + model_args = DistillationModelArguments( + student=ModelArguments( + model_name_or_path=model_cfg["name"], + dtype=model_cfg.get("dtype", "float32"), + forward_fn_kwargs=model_cfg.get("forward_fn_kwargs", {}), + ), + teacher=ModelArguments( + model_name_or_path=model_cfg.get("teacher_name", model_cfg["name"]), + dtype=model_cfg.get("teacher_dtype", model_cfg.get("dtype", "float32")), + forward_fn_kwargs=model_cfg.get( + "teacher_forward_fn_kwargs", model_cfg.get("forward_fn_kwargs", {}) + ), + ), + ) + + # Resolve checkpoint resume + checkpoint_path = resolve_checkpoint(cfg) + if checkpoint_path: + hf_trainer_cfg["resume_from_checkpoint"] = checkpoint_path + + training_args = TrainingArguments(**hf_trainer_cfg) + + logger.info(f"\n--- Config ---\n{yaml.dump(cfg, default_flow_style=False)}") + + # Tokenizer + tokenizer = get_tokenizer(model_args.student) + tokenizer_kwargs = data_cfg.get("tokenizer_kwargs", {}) + configured_tokenizer = ConfiguredTokenizer( + tokenizer=tokenizer, + apply_chat_template_kwargs=tokenizer_kwargs, + encode_kwargs={}, + ) + + # Dataset + train_dataset = _create_train_dataset(data_cfg, configured_tokenizer) + + if args.prepare_dataset_only: + logger.info("Dataset preparation complete. Exiting (--prepare-dataset-only).") + return + + # Save config for reproducibility + save_config(cfg, training_args.output_dir) + + # Build model + trainer_state = DMSTrainerState() + combined_model = build_combined_model( + model_args=model_args, + training_args=training_args, + dms_cfg=dms_cfg, + data_cfg=data_cfg, + tokenizer=tokenizer, + trainer_state=trainer_state, + ) + + data_collator = DataCollatorWithPadding( + tokenizer=tokenizer, + padding="max_length", + max_length=data_cfg["max_length"], + return_tensors="pt", + ) + + trainer = ModifiedTrainer( + trainer_state=trainer_state, + model=combined_model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator, + ) + + trainer_state.set_trainer(trainer) + + # Train + if checkpoint_path and os.path.exists(checkpoint_path): + logger.info(f"Resuming from checkpoint: {checkpoint_path}") + trainer.train(resume_from_checkpoint=checkpoint_path) + else: + trainer.train() + + # Auto-save student model at end of training + student_model_path = os.path.join(training_args.output_dir, "student_model") + extract_student_model(combined_model, tokenizer, student_model_path) + + logger.info("Training complete.") + + +if __name__ == "__main__": + main() diff --git a/experimental/dms/pyproject.toml b/experimental/dms/pyproject.toml new file mode 100644 index 0000000000..af904be23b --- /dev/null +++ b/experimental/dms/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "dms" +version = "0.1.0" +requires-python = ">=3.10,<3.14" +description = "Dynamic Memory Sparsification (DMS) for KV cache compression" +dependencies = [ + "transformers==4.57.3", + "datasets==4.4.2", + "accelerate==1.4.0", + "lm_eval[ruler]", +] + +[tool.setuptools.packages.find] +include = ["dms*"] + +[tool.pytest.ini_options] +testpaths = ["tests/"] +pythonpath = ["."] +filterwarnings = [ + "ignore::DeprecationWarning", + "ignore:flex_attention called without torch.compile:UserWarning", +] diff --git a/experimental/dms/scripts/evaluate.sh b/experimental/dms/scripts/evaluate.sh new file mode 100755 index 0000000000..08dfabc9a1 --- /dev/null +++ b/experimental/dms/scripts/evaluate.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Evaluate a trained DMS model using lm-eval-harness. +# +# Prerequisites: +# pip install -e . (installs dms + lm-eval) +# +# Usage: +# bash scripts/evaluate.sh /path/to/student_model +# +# The saved model imports from the dms package, so it must be installed +# in the environment where evaluation runs. + +set -x + +MODEL_PATH=$1 +test -z "$MODEL_PATH" && echo "Usage: bash scripts/evaluate.sh MODEL_PATH" && exit 1 + +accelerate launch -m lm_eval \ + --model hf \ + --model_args pretrained=${MODEL_PATH},dtype="bfloat16",trust_remote_code=true,dms_chunked_prefill=4096 \ + --tasks niah_single_2 \ + --output_path "${MODEL_PATH}/eval_results" \ + --log_samples \ + --device cuda \ + --batch_size 2 \ + --metadata '{"max_seq_lengths":[32768]}' diff --git a/experimental/dms/scripts/train.sh b/experimental/dms/scripts/train.sh new file mode 100755 index 0000000000..2f869b4a7d --- /dev/null +++ b/experimental/dms/scripts/train.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DMS adapter training script for Qwen3-8B. +# +# Usage: +# bash scripts/train.sh configs/qwen3_8b.yaml +# +# This first prepares the dataset with a single process, +# then launches distributed training with accelerate. + +set -e + +CONFIG=${1:-configs/qwen3_8b.yaml} +test -f "$CONFIG" || { echo "Config not found: $CONFIG"; exit 1; } + +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +echo "=== Preparing dataset (single process) ===" +python -m models.qwen3.train --config "$CONFIG" --prepare-dataset-only + +echo "=== Launching distributed training ===" +accelerate launch -m models.qwen3.train --config "$CONFIG" diff --git a/experimental/dms/scripts/train_small.sh b/experimental/dms/scripts/train_small.sh new file mode 100755 index 0000000000..9e8e73de77 --- /dev/null +++ b/experimental/dms/scripts/train_small.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Example debug launch script for single-GPU training with limited memory. + +set -e + +CONFIG=${1:-configs/qwen3_1.7b.yaml} +test -f "$CONFIG" || { echo "Config not found: $CONFIG"; exit 1; } + + +# to handle limited memory on GPUs +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + + +# Single-node, single-GPU FSDP configuration +export MASTER_ADDR=localhost +export MASTER_PORT=29500 +export WORLD_SIZE=1 +export RANK=0 +export LOCAL_RANK=0 +export CUDA_VISIBLE_DEVICES=0 + + +echo "=== Launching single-GPU training with FSDP offloading ===" +python3 -m models.qwen3.train --config "$CONFIG" diff --git a/experimental/dms/tests/conftest.py b/experimental/dms/tests/conftest.py new file mode 100644 index 0000000000..01073e41fe --- /dev/null +++ b/experimental/dms/tests/conftest.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""pytest configuration for DMS tests.""" + +import warnings + +# Runs at collection time, before imports in test files +warnings.filterwarnings( + "ignore", + message="The 'experimental' package contains unstable APIs", + category=FutureWarning, +) diff --git a/experimental/dms/tests/test_chunked_prefill.py b/experimental/dms/tests/test_chunked_prefill.py new file mode 100644 index 0000000000..6490d46916 --- /dev/null +++ b/experimental/dms/tests/test_chunked_prefill.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DMS chunked prefill.""" + +import pytest +import torch + +from experimental.dms.tests.utils import add_dms_to_path + +try: + from dms.core import dms_perform_chunked_prefill +except ImportError: + add_dms_to_path() + from dms.core import dms_perform_chunked_prefill + + +class IdentityDecoderLayer(torch.nn.Module): + """Decoder layer that returns hidden states unchanged.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: torch.Tensor, + use_cache: bool, + cache_position: torch.Tensor, + position_embeddings: torch.Tensor, + **kwargs: dict, + ): + """Return hidden states unchanged, matching the HF decoder layer interface. + + All arguments besides hidden_states are accepted but ignored, so that + this layer can be used as a drop-in replacement in dms_perform_chunked_prefill. + """ + return hidden_states, None + + +@pytest.fixture +def decoder_layers(): + """Create a list of identity decoder layers.""" + return [IdentityDecoderLayer() for _ in range(2)] + + +class TestChunkedPrefill: + """Tests for dms_perform_chunked_prefill with identity decoder layers.""" + + @pytest.mark.parametrize("seed", range(10)) + def test_chunked_matches_unchunked(self, decoder_layers, seed): + """Chunked prefill output should match non-chunked output.""" + torch.manual_seed(seed) + batch_size = torch.randint(1, 10, (1,)).item() + seq_len = torch.randint(1, 100, (1,)).item() + chunk_size = torch.randint(1, 10, (1,)).item() + hidden_dim = torch.randint(1, 16, (1,)).item() + + hidden_states = torch.randn(batch_size, seq_len, hidden_dim) + + output_no_chunking, _ = dms_perform_chunked_prefill( + decoder_layers=decoder_layers, + hidden_states=hidden_states, + attention_mask=None, + position_ids=None, + past_key_values=None, + use_cache=True, + cache_position=None, + position_embeddings=None, + dms_manual_inference_mode=False, + dms_chunked_prefill=None, + ) + + output_chunking, _ = dms_perform_chunked_prefill( + decoder_layers=decoder_layers, + hidden_states=hidden_states, + attention_mask=None, + position_ids=None, + past_key_values=None, + use_cache=True, + cache_position=None, + position_embeddings=None, + dms_manual_inference_mode=False, + dms_chunked_prefill=chunk_size, + ) + + torch.testing.assert_close(output_no_chunking, output_chunking) diff --git a/experimental/dms/tests/test_dms_utils.py b/experimental/dms/tests/test_dms_utils.py new file mode 100644 index 0000000000..f6c9cc6f74 --- /dev/null +++ b/experimental/dms/tests/test_dms_utils.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DMS utility functions.""" + +import torch + +from experimental.dms.tests.utils import add_dms_to_path + +try: + from dms.core import get_gating_with_noise +except ImportError: + add_dms_to_path() + from dms.core import get_gating_with_noise + + +class TestGetGatingWithNoise: + """Tests for the get_gating_with_noise function.""" + + def test_output_shapes(self): + """Gating outputs should match the shape of the input weights.""" + batch, heads, seq_len = 2, 4, 16 + gating_weights = torch.randn(batch, heads, seq_len) + noise = torch.randn(batch, heads, seq_len) + + probs, decisions, logits = get_gating_with_noise(gating_weights, noise, tau=1.0) + + assert probs.shape == (batch, heads, seq_len) + assert decisions.shape == (batch, heads, seq_len) + assert logits.shape == (batch, heads, seq_len) + + def test_decisions_are_binary(self): + """Discretized decisions should contain only 0s and 1s (in forward pass values).""" + gating_weights = torch.randn(2, 4, 16) + noise = torch.randn(2, 4, 16) + + _probs, decisions, _logits = get_gating_with_noise(gating_weights, noise, tau=1.0) + + # The straight-through estimator means forward values are binary + unique_vals = set(decisions.detach().unique().tolist()) + assert unique_vals.issubset({0.0, 1.0}) + + def test_probs_in_unit_interval(self): + """Probabilities should be in the [0, 1] range (sigmoid output).""" + gating_weights = torch.randn(2, 4, 16) + noise = torch.randn(2, 4, 16) + + probs, _decisions, _logits = get_gating_with_noise(gating_weights, noise, tau=1.0) + + assert (probs >= 0.0).all() + assert (probs <= 1.0).all() + + def test_temperature_effect(self): + """Higher temperature should push probabilities closer to 0.5.""" + gating_weights = torch.tensor([2.0, -2.0]) + noise = torch.zeros(2) + + probs_low_tau, _, _ = get_gating_with_noise(gating_weights, noise, tau=0.1) + probs_high_tau, _, _ = get_gating_with_noise(gating_weights, noise, tau=10.0) + + # With high tau, probs should be closer to 0.5 than with low tau + dist_low = (probs_low_tau - 0.5).abs() + dist_high = (probs_high_tau - 0.5).abs() + assert (dist_high < dist_low).all() + + def test_gradient_flows_through_decisions(self): + """Straight-through estimator: gradients should flow through decisions.""" + gating_weights = torch.randn(2, 4, 16, requires_grad=True) + noise = torch.randn(2, 4, 16) + + _probs, decisions, _logits = get_gating_with_noise(gating_weights, noise, tau=1.0) + loss = decisions.sum() + loss.backward() + + assert gating_weights.grad is not None + assert gating_weights.grad.shape == gating_weights.shape diff --git a/experimental/dms/tests/test_paged_cache.py b/experimental/dms/tests/test_paged_cache.py new file mode 100644 index 0000000000..fbffd0ad8d --- /dev/null +++ b/experimental/dms/tests/test_paged_cache.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DMS paged cache layer.""" + +import pytest +import torch + +from experimental.dms.tests.utils import add_dms_to_path + +try: + from dms.cache import DMSPagedCacheLayer +except ImportError: + add_dms_to_path() + from dms.cache import DMSPagedCacheLayer + +# --------------------------------------------------------------------------- +# Test helper: ExtendedDMSPagedCacheLayer +# --------------------------------------------------------------------------- + + +def _recent_position_size(cache_seq_lengths, dms_window_size): + """Return the number of tokens stored in the recent (sliding-window) region.""" + return torch.clamp(cache_seq_lengths, max=dms_window_size) + + +def _first_recent_position(recent_info_position, cache_seq_lengths, dms_window_size): + """Return the ring-buffer index of the oldest token in the recent region.""" + return recent_info_position - _recent_position_size(cache_seq_lengths, dms_window_size) + + +class ExtendedDMSPagedCacheLayer(DMSPagedCacheLayer): + """DMSPagedCacheLayer extended with a get_contiguous_cache method for test verification. + + Reconstructs a dense (contiguous) view of the paged KV cache so that tests can + compare the internal block-based storage against a naive tracked reference. + """ + + def get_contiguous_cache(self, right_padded: bool = True): + """Return a dense view of keys, values, seq lengths, and eviction info. + + Args: + right_padded: If True (default), valid tokens are left-aligned and the + right side is zero-padded. If False, valid tokens are right-aligned + (left-padded). + + Returns: + Tuple of (keys, values, seq_lengths, eviction_info), each shaped with + leading dims (batch_size, num_heads, ...). + """ + assert self.key_blocks is not None + num_blocks_per_seq = self.cache_seq_lengths.max().item() // self.block_size + 1 + blocks_to_retrieve = self.block_table[:, :num_blocks_per_seq] + max_length = self.cache_seq_lengths.max().item() + + def _gather_and_reorder(blocks): + """Gather blocks into a flat sequence and reorder by window position.""" + retrieved = blocks[blocks_to_retrieve].reshape( + self.page_batch, + num_blocks_per_seq * self.block_size, + self.head_dim, + ) + retrieved = retrieved[:, :max_length, :] + + permutation = torch.arange(max_length, device=blocks.device, dtype=torch.int32)[None, :] + permutation = torch.minimum(permutation, self.cache_seq_lengths[:, None] - 1) + permutation = permutation.broadcast_to(self.page_batch, max_length) + + page_batch_idx = torch.arange(self.page_batch, device=self.device, dtype=torch.int32) + + if not self.disable_eviction: + recent_size = torch.clamp(self.cache_seq_lengths, max=self.dms_window_size) + window_idx = torch.arange( + self.dms_window_size, device=self.device, dtype=torch.int32 + ) + adjusted_window_idx = torch.minimum( + window_idx[None, :], + torch.clamp(recent_size[:, None] - 1, min=0), + ) + + last_pos_ptr = ( + self.recent_info_position[:, None] - 1 - adjusted_window_idx + ) % self.dms_window_size + + window_positions = self.recent_info[page_batch_idx[:, None], last_pos_ptr, 0] + + perm_idx = self.cache_seq_lengths[:, None] - 1 - adjusted_window_idx + assert (perm_idx >= 0).all() + + non_window = permutation[:, :, None] != window_positions[:, None, :] + non_window = non_window.to(torch.int32).min(dim=-1).values.to(torch.bool) + + result_perm = torch.zeros_like(permutation) + result_perm[page_batch_idx[:, None], perm_idx] = window_positions + else: + non_window = torch.ones_like(permutation, dtype=torch.bool) + result_perm = torch.zeros_like(permutation) + + num_non_window = non_window.to(torch.int32).sum(dim=-1).cpu().tolist() + for i in range(self.page_batch): + result_perm[i, : num_non_window[i]] = permutation[i, non_window[i]] + + return retrieved[page_batch_idx[:, None], result_perm] + + retrieved_keys = _gather_and_reorder(self.key_blocks) + retrieved_values = _gather_and_reorder(self.value_blocks) + seq_lengths = self.cache_seq_lengths + + # Retrieve eviction info from the ring buffer + if not self.disable_eviction: + page_batch_idx = torch.arange(self.page_batch, device=self.device, dtype=torch.int32) + + ei_idx = torch.arange(self.dms_window_size, device=self.device, dtype=torch.int32) + ei_idx = torch.minimum( + ei_idx[None, :], + _recent_position_size(self.cache_seq_lengths, self.dms_window_size)[:, None] - 1, + ) + ei_idx = ( + _first_recent_position( + self.recent_info_position, self.cache_seq_lengths, self.dms_window_size + )[:, None] + + ei_idx + ) % self.dms_window_size + + eviction_info = self.recent_info[page_batch_idx[:, None], ei_idx, 1] + + # Convert to left-padded layout if requested + if not right_padded: + + def _left_pad(x, lens): + total = x.shape[1] + lens_list = lens.cpu().tolist() + padded = [] + for i in range(self.page_batch): + valid = x[i, : lens_list[i]] + pad = torch.zeros( + total - lens_list[i], *valid.shape[1:], device=x.device, dtype=x.dtype + ) + padded.append(torch.cat([pad, valid], dim=0)) + return torch.stack(padded, dim=0) + + retrieved_keys = _left_pad(retrieved_keys, seq_lengths) + retrieved_values = _left_pad(retrieved_values, seq_lengths) + seq_lengths = seq_lengths.reshape(self.batch_size, self.num_heads) + if not self.disable_eviction: + eviction_info = _left_pad( + eviction_info, + _recent_position_size(self.cache_seq_lengths, self.dms_window_size), + ) + + retrieved_keys = retrieved_keys.reshape( + self.batch_size, self.num_heads, max_length, self.head_dim + ).contiguous() + retrieved_values = retrieved_values.reshape( + self.batch_size, self.num_heads, max_length, self.head_dim + ).contiguous() + seq_lengths = seq_lengths.reshape(self.batch_size, self.num_heads) + + if not self.disable_eviction: + eviction_info = eviction_info.reshape(self.batch_size, self.num_heads, -1) + else: + eviction_info = torch.zeros( + self.batch_size, + self.num_heads, + max_length, + device=retrieved_values.device, + dtype=torch.int32, + ) + + return retrieved_keys, retrieved_values, seq_lengths, eviction_info + + +# --------------------------------------------------------------------------- +# Tests: fast_update_ignore_eviction +# --------------------------------------------------------------------------- + + +class TestFastUpdateIgnoreEviction: + """Verify fast_update_ignore_eviction produces the same cache state as the regular update path.""" + + @pytest.mark.parametrize("seed", range(5)) + def test_fast_update_matches_regular_update(self, seed): + """fast_update_ignore_eviction should produce identical cache contents to update().""" + torch.manual_seed(seed) + + max_elem = 32 + max_seq_len_bound = 128 + + max_seq_len = torch.randint(2, max_seq_len_bound, (1,)).item() + block_size = torch.randint(2, max_elem - 1, (1,)).item() + dms_window_size = torch.randint(block_size + 1, max_elem, (1,)).item() + batch_size = torch.randint(1, max_elem, (1,)).item() + head = torch.randint(1, 3, (1,)).item() + head_dim = torch.randint(1, 4, (1,)).item() + max_context_length = 10 * max_seq_len + + cache_regular = ExtendedDMSPagedCacheLayer( + dms_window_size=dms_window_size, + max_context_length=max_context_length, + block_size=block_size, + accommodate_min_initial_context_length=max_context_length, + disable_eviction=True, + ) + cache_fast = ExtendedDMSPagedCacheLayer( + dms_window_size=dms_window_size, + max_context_length=max_context_length, + block_size=block_size, + accommodate_min_initial_context_length=max_context_length, + disable_eviction=True, + ) + + for _ in range(5): + seq_len = torch.randint(0, max_seq_len, (batch_size, head)) + key_states = torch.randint(0, 100, (batch_size, head, max_seq_len, head_dim)) + value_states = torch.randint(0, 100, (batch_size, head, max_seq_len, head_dim)) + + if seq_len.max() == 0: + continue + + cache_regular.update( + key_states, + value_states, + { + "eviction_info": torch.zeros(batch_size, head, max_seq_len), + "sequence_lengths": seq_len, + "cumulative_length": max_seq_len, + }, + ) + cache_fast.fast_update_ignore_eviction(key_states, value_states, seq_len) + + cont_regular = cache_regular.get_contiguous_cache() + cont_fast = cache_fast.get_contiguous_cache() + + for regular_tensor, fast_tensor in zip(cont_regular, cont_fast): + assert (regular_tensor == fast_tensor).all() + + +# --------------------------------------------------------------------------- +# Tests: paged cache update correctness +# --------------------------------------------------------------------------- + + +def _run_paged_cache_update_test(seed, disable_eviction): + """Run a multi-step paged cache update test against a naive tracked reference. + + At each step, random KV pairs with eviction decisions are fed into the cache. + A naive Python tracker mirrors the expected cache state, and the two are compared + after each step for keys, values, eviction info, and left/right padding consistency. + """ + torch.manual_seed(seed) + max_val = 16 + upper_bound_seq_len = 100 + batch_size = torch.randint(1, max_val, (1,)).item() + num_heads = torch.randint(1, max_val, (1,)).item() + block_size = torch.randint(1, max_val, (1,)).item() + head_dim = torch.randint(1, max_val, (1,)).item() + dms_window_size = torch.randint(block_size + 1, max_val + 1, (1,)).item() + + cache = ExtendedDMSPagedCacheLayer( + dms_window_size=dms_window_size, + max_context_length=32768, + block_size=block_size, + accommodate_min_initial_context_length=torch.randint(1, max_val, (1,)).item(), + disable_eviction=disable_eviction, + ) + + page_batch = batch_size * num_heads + tracked_keys = [[] for _ in range(page_batch)] + tracked_values = [[] for _ in range(page_batch)] + tracked_eviction = [[] for _ in range(page_batch)] + + for step in range(10): + max_seq_len = torch.randint(1, upper_bound_seq_len, (1,)).item() + keys = torch.randint(0, 1_000_000_000, (batch_size, num_heads, max_seq_len, head_dim)) + values = torch.randint(0, 1_000_000_000, (batch_size, num_heads, max_seq_len, head_dim)) + eviction_info = ( + torch.randint(0, 2, (batch_size, num_heads, max_seq_len)) + if not disable_eviction + else torch.zeros(batch_size, num_heads, max_seq_len) + ) + + seq_lengths = torch.randint(1, max_seq_len + 1, (batch_size, num_heads)) + + # Occasionally force a sequence length of 1 for edge-case coverage + if max_seq_len > 1 and torch.randint(0, 2, (1,)).item() == 1: + p = torch.randint(0, batch_size, (1,)).item() + q = torch.randint(0, num_heads, (1,)).item() + seq_lengths[p, q] = 1 + + keys_flat = keys.reshape(page_batch, max_seq_len, head_dim) + values_flat = values.reshape(page_batch, max_seq_len, head_dim) + eviction_flat = eviction_info.reshape(page_batch, max_seq_len) + seq_lengths_flat = seq_lengths.reshape(page_batch) + + # --- Update naive tracked reference --- + for j in range(page_batch): + sl = seq_lengths_flat[j] + if sl == 0: + continue + + for s in range(max_seq_len - sl, max_seq_len): + tracked_keys[j].append(keys_flat[j, [s]]) + tracked_values[j].append(values_flat[j, [s]]) + if len(tracked_eviction[j]) > 0: + tracked_eviction[j][-1] = eviction_flat[j, [s]] + tracked_eviction[j].append(torch.zeros_like(eviction_flat[j, [s]])) + + if len(tracked_keys[j]) > dms_window_size: + if tracked_eviction[j][-dms_window_size - 1] == 1: + del tracked_keys[j][-dms_window_size - 1] + del tracked_values[j][-dms_window_size - 1] + del tracked_eviction[j][-dms_window_size - 1] + + # --- Update actual cache --- + cache.update( + key_states=keys, + value_states=values, + cache_kwargs={ + "eviction_info": eviction_info, + "sequence_lengths": seq_lengths, + "cumulative_length": 1, + }, + ) + + # --- Retrieve and verify --- + (ret_keys, ret_values, cache_seq_lens, ret_eviction) = cache.get_contiguous_cache() + (ret_keys_lp, ret_values_lp, cache_seq_lens_lp, ret_eviction_lp) = ( + cache.get_contiguous_cache(right_padded=False) + ) + + ret_keys = ret_keys.reshape(page_batch, -1, head_dim) + ret_keys_lp = ret_keys_lp.reshape(page_batch, -1, head_dim) + ret_values = ret_values.reshape(page_batch, -1, head_dim) + ret_values_lp = ret_values_lp.reshape(page_batch, -1, head_dim) + cache_seq_lens = cache_seq_lens.reshape(page_batch) + cache_seq_lens_lp = cache_seq_lens_lp.reshape(page_batch) + ret_eviction = ret_eviction.reshape(page_batch, -1) + ret_eviction_lp = ret_eviction_lp.reshape(page_batch, -1) + + def _assert_tracked_matches(tracked, retrieved, j): + """Assert that the tracked reference matches the retrieved cache for head j.""" + tracked_cat = torch.concat(tracked[j], dim=0) + sl = cache_seq_lens[j].item() + retrieved_trimmed = retrieved[j, :sl] + + assert len(retrieved_trimmed) == sl + assert len(tracked_cat) == len(retrieved_trimmed), ( + f"step: {step}, j: {j}\n" + f" tracked {tracked_cat.shape}: {tracked_cat}\n" + f" retrieved {retrieved_trimmed.shape}: {retrieved_trimmed}" + ) + + # Recent window must match exactly + a = tracked_cat[-dms_window_size:] + b = retrieved_trimmed[-dms_window_size:] + assert torch.allclose(a, b), ( + f"dms_window_size: {dms_window_size}, step: {step}, j: {j}\n" + f" tracked {a.shape}: {a}\n retrieved {b.shape}: {b}" + ) + + # Older tokens may be reordered by the paged allocator but must all be present + a = tracked_cat[:-dms_window_size] + b = retrieved_trimmed[:-dms_window_size] + if len(a) > 0: + cmp = a[:, None, :] == b[None, :, :] + cmp = cmp.to(torch.int32).min(dim=-1).values.max(dim=-1).values.to(torch.bool) + assert cmp.all(), ( + f"pref: dms_window_size: {dms_window_size}, step: {step}, j: {j}\n" + f" tracked {a.shape}: {a}\n retrieved {b.shape}: {b}" + ) + + for j in range(page_batch): + sl = cache_seq_lens[j].item() + sl_lp = cache_seq_lens_lp[j].item() + + # Left-padded view should match right-padded content + assert (ret_keys_lp[j][-sl_lp:] == ret_keys[j][:sl]).all() + assert (ret_values_lp[j][-sl_lp:] == ret_values[j][:sl]).all() + assert (ret_eviction_lp[j][-sl_lp:] == ret_eviction[j][:sl]).all() + + # Compare tracked reference against retrieved cache + _assert_tracked_matches(tracked_keys, ret_keys, j) + _assert_tracked_matches(tracked_values, ret_values, j) + + # Verify eviction info + a = ret_eviction[j][:sl] + b = torch.concat(tracked_eviction[j][-len(a) :], dim=0) + assert b[-1] == 0 + assert (a[:-1] == b[:-1]).all(), ( + f"step: {step}, j: {j}\n a {a.shape}: {a}\n b {b.shape}: {b}" + ) + + +class TestPagedCacheUpdate: + """Verify paged cache update against a naive element-by-element tracked reference.""" + + @pytest.mark.parametrize("seed", range(5)) + def test_update_with_eviction(self, seed): + """Cache update with eviction enabled should match the tracked reference.""" + _run_paged_cache_update_test(seed, disable_eviction=False) + + @pytest.mark.parametrize("seed", range(5)) + def test_update_without_eviction(self, seed): + """Cache update with eviction disabled should match the tracked reference.""" + _run_paged_cache_update_test(seed, disable_eviction=True) diff --git a/experimental/dms/tests/test_prefill_and_generate.py b/experimental/dms/tests/test_prefill_and_generate.py new file mode 100644 index 0000000000..55a53127ce --- /dev/null +++ b/experimental/dms/tests/test_prefill_and_generate.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DMS flex attention prefill and inference.""" + +import pytest +import torch + +from experimental.dms.tests.utils import add_dms_to_path, ignore_flex_attention_warnings + +try: + from dms.attention import dms_attn_eval_mode + from dms.attention_prefill import dms_run_prefill_flex, wrapped_flex_attention + from dms.cache import DMSCombinedCacheLayer +except ImportError: + add_dms_to_path() + from dms.attention import dms_attn_eval_mode + from dms.attention_prefill import dms_run_prefill_flex, wrapped_flex_attention + from dms.cache import DMSCombinedCacheLayer + +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + +MASK_VALUE = -1e9 + + +def fake_flash_attn_with_kvcache( + q: torch.Tensor, + k_blocks: torch.Tensor, + v_blocks: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cache_seqlens: torch.Tensor, + causal: bool, + softmax_scale: float, + block_table: torch.Tensor, + return_softmax_lse: bool = False, +): + """Naive flash_attn_with_kvcache replacement that reconstructs KV from paged blocks.""" + k = k_blocks + v = v_blocks + + _num_blocks, block_size, _, head_dim = k.size() + page_batch, seq_len, q_per_kv, _head_dim = q.size() + + cont_cache_keys = [] + cont_cache_values = [] + + max_seq_len = cache_seqlens.max() + max_blocks = (max_seq_len + block_size - 1) // block_size + + for pb in range(page_batch): + cont_cache_keys.append([]) + cont_cache_values.append([]) + for b in range(max_blocks): + block_id = block_table[pb, b] + cont_cache_keys[pb].append(k[block_id, :, 0, :]) + cont_cache_values[pb].append(v[block_id, :, 0, :]) + + cont_cache_keys[pb] = torch.cat(cont_cache_keys[pb], dim=0) + cont_cache_values[pb] = torch.cat(cont_cache_values[pb], dim=0) + + cont_cache_keys = torch.stack(cont_cache_keys, dim=0) + cont_cache_values = torch.stack(cont_cache_values, dim=0) + + attention_mask = ( + torch.arange(cont_cache_keys.size(1), device=q.device)[None, :] >= cache_seqlens[:, None] + ) + + attn_scores = torch.einsum("bsqd,bld->bqsl", q, cont_cache_keys * softmax_scale) + + attention_mask = attention_mask[:, None, None, :].broadcast_to(attn_scores.size()) + + iter_k = torch.arange(cont_cache_keys.size(1), device=q.device) + iter_q = ( + torch.arange(seq_len, device=q.device)[None, None, :, None] + + cache_seqlens[:, None, None, None] + - seq_len + ) + causal_mask = iter_q[:, :, :, :] < iter_k[None, None, None, :] + if not causal: + causal_mask = torch.zeros_like(causal_mask, dtype=torch.bool) + + attention_mask = torch.logical_or(attention_mask, causal_mask) + + attn_scores[attention_mask] = -1e9 + + logsumexp = torch.logsumexp(attn_scores, dim=-1) + + attn_scores = torch.softmax(attn_scores, dim=-1) + + result = torch.einsum("bqsl,bld->bsqd", attn_scores, cont_cache_values) + + if return_softmax_lse: + return result, logsumexp + else: + return result + + +def _simple_code_for_dms_exact_attention(q, k, v, d, a, state, attn_scaling, window_size): + """Reference implementation of DMS exact attention for prefill verification.""" + page_batch, seq_len_q, q_per_kv, head_dim = q.size() + batch, head_k, seq_len_k, head_dim_k = k.size() + assert seq_len_q == seq_len_k + k = k.reshape(page_batch, seq_len_k, head_dim) + v = v.reshape(page_batch, seq_len_k, head_dim) + d = d.reshape(page_batch, seq_len_k) + a = a.reshape(page_batch, seq_len_k) + + org_a = a.clone() + + if len(state) != 0: + k = torch.cat([state["k"], k], dim=1) + v = torch.cat([state["v"], v], dim=1) + d = torch.cat([state["d"], d], dim=1) + a = torch.cat([state["a"], a], dim=1) + + attn_scores = torch.einsum("psgh,plh->pgsl", q, k * attn_scaling) + + seq_len_k = k.shape[1] + offset = seq_len_k - seq_len_q + id_q = torch.arange(seq_len_q, device=q.device, dtype=torch.int32) + offset + id_k = torch.arange(seq_len_k, device=k.device, dtype=torch.int32) + + causal_mask = id_q[:, None] < id_k[None, :] + + attn_mask = ( + torch.logical_not(a).reshape(page_batch, 1, 1, seq_len_k).broadcast_to(attn_scores.size()) + ) + causal_mask = causal_mask[None, None, :, :].broadcast_to(attn_scores.size()) + attn_mask = torch.logical_or(attn_mask, causal_mask) + dms_within_window = (id_q[:, None] - id_k[None, :]) < window_size + + shifted_d = torch.nn.functional.pad(d[:, 1:], (0, 1), value=0) + dms_eviction = shifted_d == 1 + dms_masked = torch.logical_and( + torch.logical_not(dms_within_window[None, None, :, :]), + dms_eviction[:, None, None, :], + ) + + attn_mask = torch.logical_or(attn_mask, dms_masked) + attn_scores[attn_mask] = MASK_VALUE + + state["k"] = k + state["v"] = v + state["d"] = d + state["a"] = a + + attn_scores = torch.softmax(attn_scores, dim=-1) + attn_output = torch.einsum("pgsl,plh->psgh", attn_scores, v) + + tmp = org_a[:, :, None, None].broadcast_to(attn_output.size()) + attn_output[torch.logical_not(tmp)] = 0 + + return attn_output.reshape(batch, head_k, seq_len_q, q_per_kv, head_dim) + + +def _simple_code_for_dms_fast_attention_inference(q, k, v, d, a, attn_scaling, window_size): + """Reference implementation of DMS attention for single-step inference verification.""" + page_batch, seq_len_q, q_per_kv, head_dim = q.size() + batch, head_k, seq_len_k, head_dim_k = k.size() + assert seq_len_q == 1 + assert seq_len_k > seq_len_q + k = k.reshape(page_batch, seq_len_k, head_dim) + v = v.reshape(page_batch, seq_len_k, head_dim) + d = d.reshape(page_batch, seq_len_k) + a = a.reshape(page_batch, seq_len_k) + + attn_scores = torch.einsum("psgh,plh->pgsl", q, k * attn_scaling) + + shifted_d = torch.nn.functional.pad(d[:, 1:], (0, 1), value=0) + k_iter = torch.arange(seq_len_k - 1, -1, -1, device=d.device, dtype=torch.int32) + should_be_evicted = torch.logical_and(shifted_d == 1, k_iter[None, :] >= window_size) + + attn_mask = ( + torch.logical_not(a).reshape(page_batch, 1, 1, seq_len_k).broadcast_to(attn_scores.size()) + ) + attn_mask = torch.logical_or(attn_mask, should_be_evicted[:, None, None, :]) + + attn_scores[attn_mask] = MASK_VALUE + attn_scores = torch.softmax(attn_scores, dim=-1) + attn_output = torch.einsum("pgsl,plh->psgh", attn_scores, v) + + return attn_output.reshape(batch, head_k, seq_len_q, q_per_kv, head_dim) + + +def _generate_random_test_params(seed): + """Generate randomized test parameters from a seed.""" + torch.manual_seed(seed) + batch = torch.randint(1, 5, (1,)).item() + heads_kv = torch.randint(1, 5, (1,)).cuda().item() + gqa_factor = torch.randint(1, 4, (1,)).cuda().item() + seq_len = torch.randint(8, 1024, (1,)).cuda().item() + head_dim = 3 + chunk_size = torch.randint(1, 128, (1,)).cuda().item() + dms_block_size = torch.randint(2, 32, (1,)).cuda().item() + dms_window_size = torch.randint(dms_block_size + 1, 128, (1,)).cuda().item() + + return { + "batch": batch, + "heads_kv": heads_kv, + "gqa_factor": gqa_factor, + "seq_len": seq_len, + "head_dim": head_dim, + "chunk_size": chunk_size, + "dms_block_size": dms_block_size, + "dms_window_size": dms_window_size, + } + + +def _run_prefill(params): + """Run chunked prefill and verify against the reference implementation. + + Returns the tensors and cache needed for the subsequent inference test. + """ + batch = params["batch"] + heads_kv = params["heads_kv"] + gqa_factor = params["gqa_factor"] + seq_len = params["seq_len"] + head_dim = params["head_dim"] + chunk_size = params["chunk_size"] + dms_block_size = params["dms_block_size"] + dms_window_size = params["dms_window_size"] + + query = torch.randn( + (batch * heads_kv, seq_len, gqa_factor, head_dim), dtype=torch.float64 + ).cuda() + key = torch.randn((batch, heads_kv, seq_len, head_dim), dtype=torch.float64).cuda() + value = torch.randn((batch, heads_kv, seq_len, head_dim), dtype=torch.float64).cuda() + decisions = (torch.randint(0, 100, (batch, heads_kv, seq_len)) <= 90).to(torch.long).cuda() + attention_mask = torch.ones((batch, seq_len), dtype=torch.bool).cuda() + + for i in range(batch): + rnd = torch.randint(0, 2, (1,)).item() + if rnd == 0: + pad_len = torch.randint(0, 6, (1,)).item() + attention_mask[i, :pad_len] = False + + cache = DMSCombinedCacheLayer( + dms_window_size=dms_window_size, + max_context_length=8192, + block_size=dms_block_size, + ) + cache.prefill_mode() + state = {} + + for chunk_idx in range((seq_len + chunk_size - 1) // chunk_size): + start_idx = chunk_idx * chunk_size + end_idx = min(start_idx + chunk_size, seq_len) + q = query[:, start_idx:end_idx, :, :] + k = key[:, :, start_idx:end_idx, :] + v = value[:, :, start_idx:end_idx, :] + d = decisions[:, :, start_idx:end_idx] + a = attention_mask[:, start_idx:end_idx] + + attn_output_actual = dms_run_prefill_flex( + q_flash=q, + keys=k, + values=v, + decisions=d, + attn_mask=a, + cache=cache, + attn_scaling=1.0, + flash_attn_fn=fake_flash_attn_with_kvcache, + flex_attention_fn=wrapped_flex_attention, + ) + + a_expanded = a.reshape(a.shape[0], 1, a.shape[1]).broadcast_to( + (a.shape[0], heads_kv, a.shape[1]) + ) + attn_output_expected = _simple_code_for_dms_exact_attention( + q=q, + k=k, + v=v, + d=d, + a=a_expanded, + state=state, + attn_scaling=1.0, + window_size=dms_window_size, + ) + + diff = (attn_output_actual - attn_output_expected).abs().max() + assert diff < 1e-6, f"Prefill chunk {chunk_idx}: max diff = {diff}" + + return query, key, value, decisions, attention_mask, cache + + +@requires_cuda +@ignore_flex_attention_warnings +class TestDMSPrefill: + """Tests for DMS prefill attention matching the reference implementation.""" + + @pytest.mark.parametrize("seed", range(5)) + def test_prefill_matches_reference(self, seed): + """Chunked prefill output should match the naive reference implementation.""" + params = _generate_random_test_params(seed) + _run_prefill(params) + + +@requires_cuda +@ignore_flex_attention_warnings +class TestDMSInferenceAfterPrefill: + """Tests for DMS inference-mode attention after prefill.""" + + @pytest.mark.parametrize("seed", range(5)) + def test_generate_after_prefill_matches_reference(self, seed): + """Single-step generation output should match the naive reference implementation.""" + params = _generate_random_test_params(seed) + query, total_k, total_v, total_d, total_a, cache = _run_prefill(params) + + torch.manual_seed(seed) + cache.inference_mode() + + page_batch, _seq_len_q, q_per_kv, head_dim = query.size() + batch, head_k, _seq_len_k, head_dim_k = total_k.size() + dms_window_size = cache.dms_window_size + + num_generate_steps = 10 + for step in range(num_generate_steps): + q = torch.randn((page_batch, 1, q_per_kv, head_dim), dtype=torch.float64).cuda() + k = torch.randn((batch, head_k, 1, head_dim_k), dtype=torch.float64).cuda() + v = torch.randn((batch, head_k, 1, head_dim_k), dtype=torch.float64).cuda() + d = torch.randint(0, 2, (batch, head_k, 1), dtype=torch.long).cuda() + a = torch.ones((batch, 1), dtype=torch.bool).cuda() + + total_k = torch.cat([total_k, k], dim=2) + total_v = torch.cat([total_v, v], dim=2) + total_d = torch.cat([total_d, d], dim=2) + total_a = torch.cat([total_a, a], dim=1) + + attn_output_actual = dms_attn_eval_mode( + new_q_flash=q.clone(), + new_k=k.clone(), + new_v=v.clone(), + decisions=d.clone(), + attention_mask=None, + layer_idx=0, + dms_cache=[cache], + attn_scaling=1.0, + flash_attn_fn=fake_flash_attn_with_kvcache, + prefill_attn_fn_kwargs={ + "flex_attention_fn": wrapped_flex_attention, + }, + ) + + total_a_expanded = total_a.reshape(batch, 1, total_a.shape[1]).broadcast_to( + (batch, head_k, total_a.shape[1]) + ) + attn_output_expected = _simple_code_for_dms_fast_attention_inference( + q=q, + k=total_k, + v=total_v, + d=total_d, + a=total_a_expanded, + attn_scaling=1.0, + window_size=dms_window_size, + ) + + diff = (attn_output_actual - attn_output_expected).abs().max() + assert diff < 1e-7, f"Generate step {step}: max diff = {diff}" diff --git a/experimental/dms/tests/utils.py b/experimental/dms/tests/utils.py new file mode 100644 index 0000000000..56fe5a323c --- /dev/null +++ b/experimental/dms/tests/utils.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Utility functions for DMS tests.""" + +import os +import sys + +import pytest + + +def add_dms_to_path(): + """Add the DMS package to the Python path.""" + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +ignore_flex_attention_warnings = pytest.mark.filterwarnings( + "ignore:flex_attention called without torch.compile:UserWarning", +) diff --git a/modelopt/onnx/autocast/__main__.py b/modelopt/onnx/autocast/__main__.py index cabeff733a..54fc2c232a 100644 --- a/modelopt/onnx/autocast/__main__.py +++ b/modelopt/onnx/autocast/__main__.py @@ -66,8 +66,12 @@ def get_parser() -> argparse.ArgumentParser: "--calibration_data", "-d", type=str, - help="File path to inputs for reference runner, either NPZ or Polygraphy JSON file. " - "If not provided, random inputs will be used", + help="File path to inputs for reference runner. Supports: " + "(1) NPZ file for single batch, " + "(2) Directory containing multiple NPZ files for multi-batch calibration, " + "(3) Polygraphy JSON file (supports multiple batches). " + "Multi-batch calibration aggregates statistics across all batches for more robust " + "precision conversion decisions. If not provided, random inputs will be used.", ) parser.add_argument( "--nodes_to_exclude", @@ -185,6 +189,16 @@ def get_parser() -> argparse.ArgumentParser: "higher version." ), ) + parser.add_argument( + "--use_standalone_type_inference", + action="store_true", + default=False, + help=( + "Use local type inference implementation instead of ONNX's infer_shapes (experimental)." + "This is a workaround for cases where shape inference fails for any reason." + "Default: False (uses ONNX's infer_shapes which does both shape and type inference)." + ), + ) return parser @@ -218,6 +232,7 @@ def main(argv=None): trt_plugins_precision=args.trt_plugins_precision, max_depth_of_reduction=args.max_depth_of_reduction, opset=args.opset, + use_standalone_type_inference=args.use_standalone_type_inference, ) output_path = args.output_path diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py index 4328c9fc29..ec66ec11ed 100644 --- a/modelopt/onnx/autocast/convert.py +++ b/modelopt/onnx/autocast/convert.py @@ -33,6 +33,7 @@ from modelopt.onnx.autocast.nodeclassifier import NodeClassifier, NodeRuleBase from modelopt.onnx.autocast.precisionconverter import PrecisionConverter from modelopt.onnx.autocast.referencerunner import ReferenceRunner +from modelopt.onnx.utils import get_min_opset_for_precisions, get_qdq_precisions """ FP16 accuracy decreases in accordance with the data's magnitude. @@ -61,6 +62,7 @@ def convert_to_mixed_precision( trt_plugins_precision: list[str] = [], max_depth_of_reduction: int | None = None, opset: int | None = None, + use_standalone_type_inference: bool = False, ) -> onnx.ModelProto: """Convert model to mixed precision. @@ -83,8 +85,11 @@ def convert_to_mixed_precision( trt_plugins_precision: List indicating the precision for each custom op. max_depth_of_reduction: Maximum depth of reduction for node classification. opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type - (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations + (22 for bf16, 19 for fp16). The opset may be automatically increased if certain operations require a higher version. + use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's + infer_shapes. This is a workaround (WAR) when only type inference is + needed without shape inference. Default: False. Returns: onnx.ModelProto: The converted mixed precision model. @@ -132,7 +137,7 @@ def convert_to_mixed_precision( model = graph_sanitizer.model # Setup internal mappings - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) # Automatically add 'trt' to list of providers if custom ops are detected @@ -164,6 +169,7 @@ def convert_to_mixed_precision( low_precision_type=low_precision_type, init_conversion_max_bytes=init_conversion_max_bytes, custom_ops=graph_sanitizer.custom_ops, + use_standalone_type_inference=use_standalone_type_inference, ) # Obtain reference data @@ -196,6 +202,8 @@ def convert_to_f16( op_block_list: list[str] = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, trt_plugins: list[str] | None = [], + use_standalone_type_inference: bool = False, + opset: int | None = None, ) -> onnx.ModelProto: """Convert model to mixed precision, using PrecisionConverter. @@ -208,13 +216,48 @@ def convert_to_f16( op_block_list: List of operation types that should remain in FP32. tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library). + use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's + infer_shapes. This is a workaround (WAR) when only type inference is + needed without shape inference. Default: False. + opset: Target ONNX opset version. If None, uses default minimum opset based on precision type + (22 for bf16, 19 for fp16) and Q/DQ node requirements. The opset may be automatically + increased if Q/DQ nodes in the model require a higher version (e.g., FP8 requires 19, + INT4 requires 21, NVFP4 requires 23). """ assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16" - # Opset 21 is needed for NVFP4 quantization support (DQ with 'block_size' attribute) + # Check Q/DQ precision types in the model and determine required opset + qdq_precisions = get_qdq_precisions(model) + qdq_min_opset = get_min_opset_for_precisions(qdq_precisions) + + # Base minimum opset for FP16/BF16 conversion + # Opset 19 is the first to support fp16 scales in Q/DQ nodes + base_min_opset = 22 if low_precision_type == "bf16" else 19 + + # Determine target opset version + if opset is not None: + min_opset = opset + # Check if Q/DQ nodes require a higher opset + if qdq_precisions and qdq_min_opset > min_opset: + logger.warning( + f"Model contains Q/DQ nodes with precisions {qdq_precisions} that require " + f"opset >= {qdq_min_opset}. Upgrading from specified opset {opset} to {qdq_min_opset}." + ) + min_opset = qdq_min_opset + # Also ensure we meet base minimum for precision type + if min_opset < base_min_opset: + logger.warning( + f"Opset {min_opset} is below minimum opset {base_min_opset} for {low_precision_type}. " + f"Upgrading to opset {base_min_opset}." + ) + min_opset = base_min_opset + else: + # Use the highest required opset between base and Q/DQ requirements + min_opset = max(base_min_opset, qdq_min_opset) + sanitizer = GraphSanitizer( model, - min_opset=21, + min_opset=min_opset, trt_plugins=trt_plugins, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT, ) @@ -225,7 +268,7 @@ def convert_to_f16( model = sanitizer.model # Setup internal mappings - model = onnx_utils.infer_shapes(model) + model = onnx_utils.infer_types(model, use_standalone_type_inference) value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) precision_converter = PrecisionConverter( @@ -237,6 +280,7 @@ def convert_to_f16( low_precision_type=low_precision_type, custom_ops=sanitizer.custom_ops, tensor_block_dict=tensor_block_dict, + use_standalone_type_inference=use_standalone_type_inference, ) high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list] low_precision_nodes = [ diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index 0b63593705..85f407a591 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -67,6 +67,7 @@ def sanitize(self) -> None: self.convert_opset() self.replace_layernorm_pattern() self.ensure_graph_name_exists() + self.duplicate_shared_constants() onnx_utils.name_onnx_nodes(self.model.graph) self.replace_custom_domain_nodes() self.sanitize_io_casts() @@ -161,8 +162,8 @@ def convert_opset(self) -> None: ) self.min_opset = 19 - # Convert if any default domain opset is below minimum - if any(op.version < self.min_opset for op in default_opsets): + # Convert if the default domain opset is below minimum + if onnx_utils.get_opset_version(self.model) < self.min_opset: invalid_opsets = [op.version for op in default_opsets if op.version < self.min_opset] try: logger.info( @@ -254,6 +255,12 @@ def ensure_graph_name_exists(self) -> None: if not self.model.graph.name: self.model.graph.name = "model" + def duplicate_shared_constants(self) -> None: + """Duplicate constant tensors if they are shared.""" + self.model, is_duplicated_constant = onnx_utils.duplicate_shared_constants(self.model) + if is_duplicated_constant: + logger.warning("Shared constants were detected and duplicated accordingly.") + def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: """Match the sequence of operations that constitute a LayerNorm. diff --git a/modelopt/onnx/autocast/logging_config.py b/modelopt/onnx/autocast/logging_config.py index f65f041c09..21f9f65c5b 100644 --- a/modelopt/onnx/autocast/logging_config.py +++ b/modelopt/onnx/autocast/logging_config.py @@ -22,26 +22,54 @@ import logging import os +import sys -# Create a parent logger for all AutoCast components -logger = logging.getLogger("autocast") +# Create a logger for all AutoCast components as a child of modelopt.onnx +# This ensures autocast inherits log level and format when called from quantization +logger = logging.getLogger("modelopt.onnx.autocast") -def configure_logging(level=logging.INFO, log_file=None): +def configure_logging(level=None, log_file=None): """Configure logging for all AutoCast components. + If logging level is provided, it will be used regardless of parent logger log level. + Otherwise, inherits from parent logger if exists, or fallback to default: logging.INFO. + Args: - level: The logging level to use (default: logging.INFO). + level: The logging level to use. Can be a string (e.g., "DEBUG", "INFO") or + a logging constant (e.g., logging.DEBUG) default: None. log_file: Optional path to a log file. If provided, logs will be written to this file in addition to stdout (default: None). """ - # Set level for the parent logger and all child loggers + # Check if parent logger (modelopt.onnx) already has handlers configured + parent_logger = logging.getLogger("modelopt.onnx") + parent_has_handlers = len(parent_logger.handlers) > 0 + + # Determine the logging level to use + if level is None: + # No explicit level provided - inherit from parent or use default + if parent_has_handlers: + level = parent_logger.level + else: + level = logging.INFO + # else: use the provided level as-is + + # Set level for the autocast logger (accepts both string and int) logger.setLevel(level) + # If parent has handlers (standalone mode), also update parent's level + # so the parent's console handler respects the autocast log level + if parent_has_handlers: + parent_logger.setLevel(level) + # Remove any existing handlers to ensure clean configuration for handler in logger.handlers[:]: logger.removeHandler(handler) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(message)s" + ) + # Add file handler if log_file is specified if log_file: try: @@ -50,9 +78,6 @@ def configure_logging(level=logging.INFO, log_file=None): if log_dir: os.makedirs(log_dir, exist_ok=True) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(message)s" - ) file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) logger.addHandler(file_handler) @@ -60,14 +85,22 @@ def configure_logging(level=logging.INFO, log_file=None): except Exception as e: logging.error(f"Failed to setup file logging to {log_file}: {e!s}") - # Allow log messages to propagate to the root logger for testing compatibility - # This enables pytest's caplog fixture to capture logs while still maintaining - # our custom formatting through the handlers above - logger.propagate = True + if parent_has_handlers: + # Parent logger is configured (called from quantization/other onnx modules) + # Propagate to parent to use its handlers and format + logger.propagate = True + else: + # Standalone mode (called directly via python3 -m modelopt.onnx.autocast) + # Add our own console handler with autocast-specific format + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + # Always propagate to support pytest's caplog fixture in tests + logger.propagate = True # Ensure all child loggers inherit the level setting for name in logging.root.manager.loggerDict: - if name.startswith("autocast"): + if name.startswith("modelopt.onnx.autocast"): logging.getLogger(name).setLevel(level) diff --git a/modelopt/onnx/autocast/nodeclassifier.py b/modelopt/onnx/autocast/nodeclassifier.py index 0a76384291..cbfc003412 100644 --- a/modelopt/onnx/autocast/nodeclassifier.py +++ b/modelopt/onnx/autocast/nodeclassifier.py @@ -150,20 +150,51 @@ def _log_skipped(self, node, **kwargs): class IORangeRule(NodeRuleBase): - """Rule for keeping nodes with out-of-range inputs/outputs in high precision.""" + """Rule for keeping nodes with out-of-range inputs/outputs in high precision. + + Supports both single-batch (raw numpy arrays) and multi-batch (TensorStats objects) + reference data for flexible precision conversion decisions. + """ def __init__(self, data_max, reference_data, node_to_init_map): """Initialize the rule. Args: data_max: Maximum absolute value allowed for node I/O. - reference_data: Reference data for checking I/O ranges. + reference_data: Reference data for checking I/O ranges. Can contain either + raw numpy arrays (single batch) or TensorStats objects (multi-batch aggregated). node_to_init_map: Mapping from node names to their initializers. """ self.data_max = data_max self.reference_data = reference_data self.node_to_init_map = node_to_init_map self.output_data = None + self.output_stats = None # For TensorStats + + def _get_tensor_stats(self, ref_data): + """Extract statistics from reference data (supports both numpy arrays and TensorStats). + + Args: + ref_data: Either a numpy array or a TensorStats object. + + Returns: + tuple: (absmax, min_val, max_val, size) statistics. + """ + # Import here to avoid circular imports + from modelopt.onnx.autocast.referencerunner import TensorStats + + if isinstance(ref_data, TensorStats): + return ref_data.absmax, ref_data.min_val, ref_data.max_val, ref_data.size + else: + # Raw numpy array + if ref_data.size == 0: + return 0, 0, 0, 0 + return ( + np.max(np.abs(ref_data)), + np.min(ref_data), + np.max(ref_data), + ref_data.size, + ) def _check_inner(self, node): def is_io_out_of_range(node, tensor_name): @@ -176,18 +207,25 @@ def is_io_out_of_range(node, tensor_name): f"Node {node.name}: Tensor {tensor_name} not found in reference data." ) return False + ref_data = self.reference_data[tensor_name] - if ref_data.size == 0: + absmax, min_val, max_val, size = self._get_tensor_stats(ref_data) + + if size == 0: logger.debug( f"Node {node.name}: Tensor {tensor_name} has size 0. Skipping I/O range check." ) return False + logger.debug( - f"Node {node.name}: reference data: min={np.min(ref_data)}, max={np.max(ref_data)}" + f"Node {node.name}: reference data: min={min_val}, max={max_val}, absmax={absmax}" ) - if np.any(np.abs(ref_data) > self.data_max): + + if absmax > self.data_max: self.output_data = ref_data + self.output_stats = (absmax, min_val, max_val) return True + return False if node.op_type == "Constant": return False @@ -202,7 +240,13 @@ def is_io_out_of_range(node, tensor_name): def _log_skipped(self, node, **kwargs): """Log information about skipped nodes with I/O range violations.""" - if self.output_data is not None: + if self.output_stats is not None: + absmax, min_val, max_val = self.output_stats + logger.info( + f"Skipping node {node.name}: reference IO out of range: min={min_val}, " + f"max={max_val}, absmax={absmax}, range=[{-self.data_max}, {self.data_max}]" + ) + elif self.output_data is not None: logger.info( f"Skipping node {node.name}: reference IO out of range: min={np.min(self.output_data)}, " f"max={np.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]" @@ -230,9 +274,18 @@ def __init__(self, max_depth_of_reduction, reference_data, node_to_init_map, ini self.reduction_depth = 0 def _get_tensor_shape(self, tensor_name): - """Get tensor shape from reference data.""" + """Get tensor shape from reference data. + + Supports both raw numpy arrays and TensorStats objects. + """ if tensor_name in self.reference_data: - return self.reference_data[tensor_name].shape + ref_data = self.reference_data[tensor_name] + # Import here to avoid circular imports + from modelopt.onnx.autocast.referencerunner import TensorStats + + if isinstance(ref_data, TensorStats): + return ref_data.shape + return ref_data.shape if tensor_name in self.initializer_map: return self.initializer_map[tensor_name].dims return None diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index d708c890f7..278486c4b4 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -97,6 +97,7 @@ def __init__( max_ir_version: int | None = None, trt_plugins: list[str] | None = [], tensor_block_dict: dict[str, dict[str, list[int]]] = {}, + use_standalone_type_inference: bool = False, ) -> None: """Initialize PrecisionConverter. @@ -114,6 +115,7 @@ def __init__( max_ir_version: Max IR version for conversion. trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library). tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32. + use_standalone_type_inference: Use standalone type inference instead of ONNX's infer_shapes. """ self.model = deepcopy(model) self.value_info_map = value_info_map @@ -140,6 +142,7 @@ def __init__( self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins + self.use_standalone_type_inference = use_standalone_type_inference # Detect additional ops not supported in low precision according to the model's opset version self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + ( @@ -172,7 +175,7 @@ def convert( onnx.ModelProto: The converted mixed precision model. """ try: - self.model = onnx_utils.check_model(self.model) + onnx_utils.check_model(self.model) except onnx.checker.ValidationError as e: logger.error(f"Internal error: onnx.checker failed on input model {e}") raise Exception( @@ -254,10 +257,14 @@ def convert( # Clear type/shape information for intermediates and outputs (including subgraphs) self._clear_types_and_shapes_recursive(self.model.graph) # Populate type information with inferred types - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True, check_type=False + ) self._ensure_types_are_defined() # Sanity check: Verify type correctness - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True, check_type=True + ) # Update value_info_map and initializer_map with casts we added self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( @@ -282,9 +289,9 @@ def _clear_types_and_shapes_recursive( ) -> None: """Recursively clear type/shape information for a graph and all its subgraphs. - This is necessary for control flow operators (Scan, If, Loop) which have subgraphs. - For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph). - For main graph, clear all value_info. + If use_standalone_type_inference is True, we clear only types, not shapes. + For subgraphs, input types/shapes are cleared, so that the input types/shapes are propagated + from the main graph. Args: graph: The ONNX graph to clear types and shapes for. @@ -301,9 +308,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> for inp in g.input: if inp.type.HasField("tensor_type"): inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(inp.type.tensor_type.shape.dim): - if d.dim_value: - inp.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(inp.type.tensor_type.shape.dim): + if d.dim_value: + inp.type.tensor_type.shape.dim[idx].dim_param = "unk" if is_sub: # Identify which tensors are produced by nodes in this subgraph @@ -315,9 +323,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> for vi in g.value_info: if vi.name in subgraph_outputs: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(vi.type.tensor_type.shape.dim): - if d.dim_value: - vi.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(vi.type.tensor_type.shape.dim): + if d.dim_value: + vi.type.tensor_type.shape.dim[idx].dim_param = "unk" else: for vi in g.value_info: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED @@ -328,9 +337,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> # Clear outputs for both main graph and subgraphs for out in g.output: out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(out.type.tensor_type.shape.dim): - if d.dim_value: - out.type.tensor_type.shape.dim[idx].dim_param = "unk" + if not self.use_standalone_type_inference: + for idx, d in enumerate(out.type.tensor_type.shape.dim): + if d.dim_value: + out.type.tensor_type.shape.dim[idx].dim_param = "unk" utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph) @@ -347,6 +357,8 @@ def _get_np_type(node, inp, opset=onnx.defs.onnx_opset_version()): return node.inputs[1].dtype # scale type elif node.op == "QuantizeLinear": return node.inputs[2].dtype # zero_point type + elif node.op == "ConstantOfShape": + return node.attrs["value"].dtype elif not inp.dtype or inp.dtype == onnx.TensorProto.UNDEFINED: return None elif node.op not in self.custom_ops: @@ -1175,8 +1187,16 @@ def _remove_redundant_casts(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True) - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, self.use_standalone_type_inference, strict_mode=True + ) + if not self.use_standalone_type_inference: + self.model = onnx_utils.infer_types( + self.model, + self.use_standalone_type_inference, + strict_mode=True, + check_type=True, + ) nodes_to_remove = [] for node in self.model.graph.node: @@ -1261,7 +1281,12 @@ def _fix_network_output_names(self): if self.custom_ops: self.model = self._propagate_types_shapes_custom_ops(self.model) else: - self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True) + self.model = onnx_utils.infer_types( + self.model, + self.use_standalone_type_inference, + strict_mode=True, + check_type=True, + ) self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( self.model ) @@ -1419,6 +1444,11 @@ def _sanitize_model(self): graph_sanitizer.sanitize() self.model = graph_sanitizer.model + # Update value_info_map and initializer_map after sanitizing model + self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings( + self.model + ) + def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}): """Create mapping of op types to indices of inputs that should not be converted to low precision.""" skip_inputs_map = {} diff --git a/modelopt/onnx/autocast/referencerunner.py b/modelopt/onnx/autocast/referencerunner.py index 8dc91ff089..0228d88f06 100644 --- a/modelopt/onnx/autocast/referencerunner.py +++ b/modelopt/onnx/autocast/referencerunner.py @@ -19,22 +19,55 @@ implementation. It supports both random input generation and user-provided inputs through NPZ or Polygraphy JSON files. The runner is used to analyze model behavior and validate outputs during precision conversion. + +When multiple batches of calibration data are provided, the runner aggregates statistics +across all batches to provide more robust range information for precision conversion decisions. """ import copy import io import sys +import tempfile from collections import OrderedDict +from dataclasses import dataclass import numpy as np import onnx +from modelopt.onnx import utils as onnx_utils from modelopt.onnx.autocast.logging_config import configure_logging, logger +from modelopt.onnx.quantization.calib_utils import CalibrationDataProvider from modelopt.onnx.quantization.ort_utils import _prepare_ep_list configure_logging() +@dataclass +class TensorStats: + """Statistics for a tensor aggregated across multiple batches.""" + + absmax: float + """Maximum absolute value across all batches.""" + min_val: float + """Minimum value across all batches.""" + max_val: float + """Maximum value across all batches.""" + shape: tuple + """Shape of the tensor (from first batch).""" + + def __abs__(self): + """Return the maximum absolute value (for compatibility with np.abs).""" + return self.absmax + + @property + def size(self): + """Return total number of elements.""" + result = 1 + for dim in self.shape: + result *= dim + return result + + class ReferenceRunner: """A class to run ONNX models with ONNXRuntime for reference inference.""" @@ -69,8 +102,35 @@ def _load_inputs_from_json(self, input_data_path): return load_json(input_data_path, description="input data") def _load_inputs_from_npz(self, input_data_path): - """Load inputs from NPZ format.""" - return [np.load(input_data_path)] + """Load inputs from NPZ format. + + Supports both single NPZ file and directory containing multiple NPZ files for multi-batch calibration. + + Args: + input_data_path: Path to NPZ file or directory containing NPZ files. + + Returns: + List of input dictionaries, one per batch. + """ + import os + + if os.path.isdir(input_data_path): + # Load all NPZ files in the directory as multiple batches + npz_files = sorted([f for f in os.listdir(input_data_path) if f.endswith(".npz")]) + if not npz_files: + raise ValueError(f"No NPZ files found in directory: {input_data_path}") + logger.info( + f"Loading {len(npz_files)} NPZ files from directory for multi-batch calibration" + ) + return [np.load(os.path.join(input_data_path, f)) for f in npz_files] + else: + calib_data = np.load(input_data_path) + if isinstance(calib_data, np.lib.npyio.NpzFile): + # Wrap data into a CalibDataProvider to support a single NPZ file + # containing data from multiple batches + data_loader = {key: calib_data[key] for key in calib_data.files} + return CalibrationDataProvider(self.model, data_loader).calibration_data_list + return [calib_data] def _validate_inputs(self, data_loader): """Validate that input names and shapes match the model.""" @@ -80,7 +140,16 @@ def _validate_inputs(self, data_loader): if sorted(self.input_names) != sorted(data_loader[0].keys()): raise ValueError("Input names from ONNX model do not match provided input names.") for inp_name, inp_shape in data_loader[0].items(): - if self.input_shapes[inp_name] != list(inp_shape.shape): + # Get model and data shapes as numpy arrays + inp_shape_model = np.array(self.input_shapes[inp_name]) + inp_shape_data = np.array(inp_shape.shape) + # Compare input rank + raise_value_error = len(inp_shape_model) != len(inp_shape_data) + if not raise_value_error: + # Compare input shape, skipping check for unknown dimensions + mask = inp_shape_model > 0 + raise_value_error = np.any(inp_shape_model[mask] != inp_shape_data[mask]) + if raise_value_error: raise ValueError( f"Input shape from '{inp_name}' does not match provided input shape: " f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. " @@ -96,16 +165,18 @@ def _load_inputs(self, inputs): # If no inputs are provided, use random inputs data_loader = DataLoader(val_range={"": (-1, 1)}) + import os + if inputs is not None: if isinstance(inputs, str): if inputs.endswith(".json"): data_loader = self._load_inputs_from_json(inputs) - elif inputs.endswith(".npz"): + elif inputs.endswith(".npz") or os.path.isdir(inputs): data_loader = self._load_inputs_from_npz(inputs) else: raise ValueError( - f"Invalid input file: {inputs}. Supported input file types: .json (Polygraphy JSON format), " - ".npz (Numpy)" + f"Invalid input file: {inputs}. Supported input types: .json (Polygraphy JSON format), " + ".npz (Numpy), or a directory containing .npz files" ) elif isinstance(inputs, (dict, OrderedDict)): data_loader = [inputs] @@ -118,13 +189,105 @@ def _load_inputs(self, inputs): return data_loader + def _get_ort_runner(self, model): + from polygraphy.backend.onnx import BytesFromOnnx + from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx + + # Check if model has external data by checking: + # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded) + # 2. If model size would exceed 2GB (indicating need for external data) + needs_external_data = onnx_utils.check_model_uses_external_data( + self.model + ) or self.model.ByteSize() > 2 * (1024**3) + if needs_external_data: + logger.debug("Model has external data, using file-based approach") + # Get the actual ONNX ModelProto from ModifyOutputs wrapper + modified_model = model() + + # Use a persistent temp file, because we need the file to be present in an broader context + tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) + tmp_file.close() + tmp_file_path = tmp_file.name + onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True) + logger.debug(f"Model with all outputs saved to {tmp_file_path}") + build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers) + + else: + # For models without external data, use the original BytesFromOnnx approach (no tmp files) + logger.debug("Model has no external data, using BytesFromOnnx approach") + serialize_onnx = BytesFromOnnx(model) + build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers) + runners = [OnnxrtRunner(build_onnxrt_session)] + return runners + + def _aggregate_tensor_stats(self, all_batch_data: list[OrderedDict]) -> OrderedDict: + """Aggregate tensor statistics across multiple batches. + + Args: + all_batch_data: List of dictionaries containing tensor data for each batch. + + Returns: + OrderedDict mapping tensor names to TensorStats objects. + """ + if len(all_batch_data) == 1: + # Single batch - return raw data for backward compatibility + return all_batch_data[0] + + logger.info(f"Aggregating statistics across {len(all_batch_data)} batches...") + + aggregated = OrderedDict() + tensor_names = all_batch_data[0].keys() + + for name in tensor_names: + absmax = -np.inf + min_val = np.inf + max_val = -np.inf + shape = None + + for batch_data in all_batch_data: + if name not in batch_data: + continue + data = batch_data[name] + if shape is None: + shape = data.shape + + batch_absmax = np.max(np.abs(data)) if data.size > 0 else 0 + batch_min = np.min(data) if data.size > 0 else 0 + batch_max = np.max(data) if data.size > 0 else 0 + + absmax = max(absmax, batch_absmax) + min_val = min(min_val, batch_min) + max_val = max(max_val, batch_max) + + if shape is not None: + aggregated[name] = TensorStats( + absmax=absmax, + min_val=min_val, + max_val=max_val, + shape=shape, + ) + + return aggregated + def run(self, inputs=None): - """Run FP32 inference with provided or random inputs.""" + """Run FP32 inference with provided or random inputs. + + When multiple batches of input data are provided, inference is run for each batch + and statistics are aggregated across all batches for more robust range estimation. + + Args: + inputs: Optional input data. Can be: + - None: Random inputs will be generated + - str: Path to JSON file, NPZ file, or directory containing NPZ files + - dict/OrderedDict: Single batch of input data + + Returns: + OrderedDict: Combined input and output data. For single batch, returns raw arrays. + For multiple batches, returns TensorStats objects with aggregated statistics. + """ import onnxruntime as ort from polygraphy import constants - from polygraphy.backend.onnx import BytesFromOnnx from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs - from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx from polygraphy.comparator import Comparator logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...") @@ -133,9 +296,9 @@ def run(self, inputs=None): model_copy = copy.deepcopy(self.model) modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL) - serialize_onnx = BytesFromOnnx(modify_outputs) - build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers) - runners = [OnnxrtRunner(build_onnxrt_session)] + + # Load the modified model and create an inference session + runners = self._get_ort_runner(modify_outputs) # Comparator is used despite the fact that we are using ONNXRuntime # because it provides the ability to generate random inputs using DataLoader @@ -156,15 +319,30 @@ def run(self, inputs=None): logger.error(f"ONNXRuntime execution failed with output:\n{captured_output}") raise Exception("ONNXRuntime failed to run, see logs for details") - # Get the output results - output_dict = OrderedDict(results[0][1][0]) + # Collect all batch data (inputs + outputs) + all_batch_data = [] + runner_results = results[0][1] # Get all iteration results for the first runner + data_loader_iter = iter(data_loader) + + for iter_idx, iter_result in enumerate(runner_results): + output_dict = OrderedDict(iter_result) + + # Get corresponding input data + try: + input_data = next(data_loader_iter) + except StopIteration: + # If data_loader is exhausted, it might be a DataLoader that generates random data + input_data = {} - # Include input data for completeness - input_data = next(iter(data_loader)) + # Combine inputs and outputs for this batch + batch_dict = OrderedDict() + batch_dict.update(input_data) + batch_dict.update(output_dict) + all_batch_data.append(batch_dict) - # Combine inputs and outputs in the returned dictionary - combined_dict = OrderedDict() - combined_dict.update(input_data) - combined_dict.update(output_dict) + num_batches = len(all_batch_data) + if num_batches > 1: + logger.info(f"Processed {num_batches} batches of calibration data") - return combined_dict + # Aggregate statistics across all batches + return self._aggregate_tensor_stats(all_batch_data) diff --git a/modelopt/onnx/logging_config.py b/modelopt/onnx/logging_config.py index fd0c306a68..99468b87f2 100644 --- a/modelopt/onnx/logging_config.py +++ b/modelopt/onnx/logging_config.py @@ -64,8 +64,10 @@ def configure_logging(level=logging.INFO, log_file=None): console_handler.setFormatter(formatter) logger.addHandler(console_handler) - # Prevent log messages from propagating to the root logger - logger.propagate = False + # Allow log messages to propagate to the root logger for testing compatibility + # This enables pytest's caplog fixture to capture logs while still maintaining + # our custom formatting through the handlers above + logger.propagate = True # Ensure all child loggers inherit the level setting for name in logging.root.manager.loggerDict: diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index cc94a221fb..7e11d25e66 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str): ] -def get_copy_ops(): +def get_copy_ops() -> list[str]: """Returns list of copy operators.""" return [ "Flatten", @@ -303,3 +303,86 @@ def is_data_dependent_shape_op(op_type: str): "NonZero", "RoiAlign", ] + + +def get_bool_ops(): + """Returns set of bool operations.""" + return { + "Not", + "And", + "Or", + "Xor", + } + + +def get_bitwise_ops(): + """Returns set of bitwise operations.""" + return { + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "BitShift", + } + + +def get_value_check_ops(): + """Returns set of value checking operations.""" + return { + "IsNaN", + "IsInf", + "Sign", + "Abs", + } + + +def get_comparison_ops(): + """Returns set of comparison operations.""" + return { + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + } + + +def get_conditional_ops(): + """Returns set of conditional operations.""" + return { + "Where", + } + + +def get_aggregation_ops(): + """Returns set of aggregation operations.""" + return { + "All", + "Any", + } + + +def get_set_ops(): + """Returns set of set/search operations.""" + return { + "Unique", + "NonZero", + } + + +def get_symmetric_ops(): + """Returns set of commutative/symmetric operations where operand order doesn't matter.""" + return { + "Add", + "Mul", + "And", + "Or", + "Xor", + "Equal", + "Max", + "Min", + "Sum", + "Mean", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + } diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py index 55cca6ee51..6c79d93179 100644 --- a/modelopt/onnx/quantization/__main__.py +++ b/modelopt/onnx/quantization/__main__.py @@ -16,6 +16,7 @@ """Command-line entrypoint for ONNX PTQ.""" import argparse +import os import numpy as np @@ -24,6 +25,31 @@ __all__ = ["main"] +def validate_file_size(file_path: str, max_size_bytes: int) -> None: + """Validate that a file exists and does not exceed the maximum allowed size. + + Args: + file_path: Path to the file to validate + max_size_bytes: Maximum allowed file size in bytes + + Raises: + FileNotFoundError: If the file does not exist + ValueError: If the file exceeds the maximum allowed size + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + file_size = os.path.getsize(file_path) + if file_size > max_size_bytes: + max_size_gb = max_size_bytes / (1024 * 1024 * 1024) + actual_size_gb = file_size / (1024 * 1024 * 1024) + raise ValueError( + f"File size validation failed: {file_path} ({actual_size_gb:.2f}GB) exceeds " + f"maximum allowed size of {max_size_gb:.2f}GB. This limit helps prevent potential " + f"denial-of-service attacks." + ) + + def get_parser() -> argparse.ArgumentParser: """Get the argument parser for ONNX PTQ.""" argparser = argparse.ArgumentParser("python -m modelopt.onnx.quantization") @@ -52,6 +78,11 @@ def get_parser() -> argparse.ArgumentParser: type=str, help="Calibration data in npz/npy format. If None, random data for calibration will be used.", ) + group.add_argument( + "--trust_calibration_data", + action="store_true", + help="If True, trust the calibration data and allow pickle deserialization.", + ) group.add_argument( "--calibration_cache_path", type=str, @@ -255,18 +286,50 @@ def get_parser() -> argparse.ArgumentParser: "The currently supported precisions are {fp16, int8, fp8}." ), ) + argparser.add_argument( + "--opset", + type=int, + help=( + "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset " + "(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased " + "if certain operations require a higher version." + ), + ) return argparser def main(): """Command-line entrypoint for ONNX PTQ.""" args = get_parser().parse_args() + + # Security: Validate onnx model size is under 2GB by default + if not args.use_external_data_format: + try: + validate_file_size(args.onnx_path, 2 * (1024**3)) + except ValueError as e: + raise ValueError( + "Onnx model size larger than 2GB. Please set --use_external_data_format flag to bypass this validation." + ) from e + calibration_data = None if args.calibration_data_path: - calibration_data = np.load(args.calibration_data_path, allow_pickle=True) - if args.calibration_data_path.endswith(".npz"): - # Convert the NpzFile object to a Python dictionary - calibration_data = {key: calibration_data[key] for key in calibration_data.files} + # Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks + try: + calibration_data = np.load( + args.calibration_data_path, allow_pickle=args.trust_calibration_data + ) + if args.calibration_data_path.endswith(".npz"): + # Convert the NpzFile object to a Python dictionary + calibration_data = {key: calibration_data[key] for key in calibration_data.files} + except ValueError as e: + if "allow_pickle" in str(e) and not args.trust_calibration_data: + raise ValueError( + "Calibration data file contains pickled objects which pose a security risk. " + "For trusted sources, you may enable pickle deserialization by setting the " + "--trust_calibration_data flag." + ) from e + else: + raise quantize( args.onnx_path, @@ -298,6 +361,7 @@ def main(): simplify=args.simplify, calibrate_per_node=args.calibrate_per_node, direct_io_types=args.direct_io_types, + opset=args.opset, ) diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py new file mode 100644 index 0000000000..a8929315a8 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/common.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common data structures and types for the QDQ Autotuner.""" + +import hashlib +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, +) + + +class AutotunerError(Exception): + """Base exception for autotuner-related errors.""" + + +class AutotunerNotInitializedError(AutotunerError): + """Exception raised when autotuner is used without initialization.""" + + +class InvalidSchemeError(AutotunerError): + """Exception raised when an invalid scheme is referenced.""" + + +class RegionType(Enum): + """Region type enumeration for hierarchical graph structure. + + - LEAF: Atomic region containing direct nodes with no child regions + - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes) + - ROOT: Top-level region encompassing the entire computation graph + """ + + LEAF = "LEAF" + COMPOSITE = "COMPOSITE" + ROOT = "ROOT" + + +class Region: + """A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion. + + Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions + contain child regions, and LEAF regions contain only nodes. Each region tracks + its direct nodes, input/output tensors, and a pattern signature for matching + regions with identical structure. + """ + + def __init__(self, region_id: int, level: int, region_type: RegionType): + """Initialize a new region. + + Args: + region_id: Unique identifier within the region hierarchy + level: Hierarchical level (0 = leaf, higher = more composite) + region_type: Type classification (LEAF, COMPOSITE, or ROOT) + """ + self.id = region_id + self.level = level + self.type = region_type + self.parent: Region | None = None + self.children: list[Region] = [] + self.nodes: set[int] = set() + self.inputs: list[str] = [] + self.outputs: list[str] = [] + self.metadata: dict[str, str] = {} + + def get_children(self, *, sort: bool = False) -> list["Region"]: + """Get all child regions. If sort is True, sort the children by level and size. + + Args: + sort: Whether to sort the children by level and size + + Returns: + List of child regions + """ + if sort: + return sorted( + self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) + ) + return self.children + + def remove_child(self, child: "Region") -> bool: + """Remove a child region from this region's children list.""" + if child not in self.children: + return False + self.children.remove(child) + if child.parent and child.parent.id == self.id: + child.parent = None + return True + + def add_child(self, child: "Region") -> None: + """Add a child sub-region.""" + if child.id == self.id: + logger.warning(f"Cannot add region {self.id} as its own child") + return + + if self.is_descendant_of(child): + logger.warning( + f"Cycle detected: region {self.id} is already a descendant of region {child.id}" + ) + return + + if child.parent is not None and child.parent.id != self.id: + old_parent_id = child.parent.id + logger.debug( + f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}" + ) + child.parent.remove_child(child) + + if any(c.id == child.id for c in self.children): + logger.debug(f"Region {child.id} already child of {self.id}") + return + + self.children.append(child) + child.parent = self + + def is_descendant_of(self, potential_ancestor: "Region") -> bool: + """Check if this region is a descendant of potential_ancestor.""" + visited = set() + current = self.parent + while current: + if current.id in visited: + return False + visited.add(current.id) + if current.id == potential_ancestor.id: + return True + current = current.parent + return False + + def get_nodes(self, *, sort: bool = False) -> list[int]: + """Get direct node indices in this region only.""" + if sort: + return sorted(self.nodes) + return list(self.nodes) + + def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants.""" + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" + + _visited.add(self.id) + all_nodes = set(self.nodes) + for child in self.children: + all_nodes.update(child.get_region_nodes_and_descendants(_visited)) + return all_nodes + + def contains_node(self, node_index: int) -> bool: + """Check if region contains a specific node (direct only).""" + return node_index in self.nodes + + def contains_node_within_region_and_descendants(self, node_index: int) -> bool: + """Check if region contains a node recursively.""" + return node_index in self.get_region_nodes_and_descendants() + + def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants.""" + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, ( + f"Cycle detected in region {self.id} during size calculation" + ) + + _visited.add(self.id) + total = len(self.nodes) + for child in self.children: + total += child.get_size_of_region_and_descendants(_visited) + return total + + def merge(self, other: "Region") -> None: + """Merge another region into this one.""" + if not other: + return + self.nodes.update(other.nodes) + for child in other.children: + self.add_child(child) + + def __repr__(self) -> str: + type_str = self.type.value + return ( + f"Region[id={self.id}, level={self.level}, type={type_str}, " + f"nodes={len(self.nodes)}, children={len(self.children)}, " + f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" + ) + + +@dataclass +class InsertionScheme: + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + """ + + node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) + child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) + region_outputs: list[ChildRegionOutputInsertionPoint] = field(default_factory=list) + latency_ms: float = float("inf") + error: bool = False + profile_timestamp: str | None = None + + @property + def hash(self) -> str: + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + """ + sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) + sorted_regions = sorted( + [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] + ) + sorted_region_outputs = sorted( + [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] + ) + + hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" + + return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] + + @property + def is_empty(self) -> bool: + """Check if this is a baseline scheme with no Q/DQ insertions.""" + return not self.node_inputs and not self.child_region_inputs and not self.region_outputs + + @property + def is_profiled(self) -> bool: + """Check if this scheme has been profiled (measured). + + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + """ + return self.error or self.latency_ms != float("inf") + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "latency_ms": self.latency_ms, + "error": self.error, + "profile_timestamp": self.profile_timestamp, + "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs], + "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs], + "region_outputs": [pt.to_dict() for pt in self.region_outputs], + "hash": self.hash, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": + """Create InsertionScheme from serialized dictionary.""" + scheme = cls() + scheme.latency_ms = data.get("latency_ms", float("inf")) + scheme.error = data.get("error", False) + scheme.profile_timestamp = data.get("profile_timestamp") + + scheme.node_inputs = [ + NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", []) + ] + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint.from_dict(pt) + for pt in data.get("child_region_inputs", []) + ] + scheme.region_outputs = [ + ChildRegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ] + + return scheme + + def distance(self, other: "InsertionScheme") -> int: + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + """ + return ( + len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) + + len(set(self.region_outputs).symmetric_difference(other.region_outputs)) + ) + + def __str__(self) -> str: + """String representation for debugging.""" + error_str = ", error=True" if self.error else "" + return ( + f"InsertionScheme(node_insertions={len(self.node_inputs)}, " + f"region_insertions={len(self.child_region_inputs)}, " + f"region_output_insertions={len(self.region_outputs)}, " + f"latency={self.latency_ms:.3f}ms{error_str})" + ) diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py new file mode 100644 index 0000000000..dd01848ddb --- /dev/null +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ insertion point management for ONNX quantization autotune. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. +""" + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnx_graphsurgeon as gs + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.common import Region + +from modelopt.onnx.op_types import ( + get_aggregation_ops, + get_bitwise_ops, + get_bool_ops, + get_comparison_ops, + get_conditional_ops, + get_copy_ops, + get_set_ops, + get_value_check_ops, + is_fusible_reduction_op, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class InsertionPoint(ABC): + """Abstract base class for pattern-relative Q/DQ insertion points.""" + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + ... + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionPoint": + """Create from dictionary.""" + ... + + @abstractmethod + def resolve(self, region: "Region", graph: gs.Graph) -> set["ResolvedInsertionPoint"]: + """Resolve pattern-relative insertion point to actual tensor names.""" + ... + + @staticmethod + @abstractmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoint"]: + """Collect all valid insertion points of this type from a region.""" + ... + + +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. It contains the + tensor name and the node index (if applicable) and input index (if applicable). + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + tensor_name: str + node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) + input_index: int | None = None # Input tensor index of that node (or None) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable). + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + node_index: int # Pattern-relative node index + input_index: int # Input tensor index of that node + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region.""" + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + node = graph.nodes[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + resolved_ips = set() + # Determine which input indices to resolve (include weights for Conv/ConvTranspose) + input_indices = [self.input_index] + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv/ConvTranspose inputs and weights must be quantized together" + ) + assert len(node.inputs) >= 2, "Conv/ConvTranspose should have at least 2 inputs" + input_indices.append(1) + + for idx in input_indices: + inp = node.inputs[idx] + if hasattr(inp, "name") and inp.name: + resolved_ips.add( + ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=idx + ) + ) + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region.""" + node_indices = region.get_nodes(sort=True) + insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + name = getattr(inp, "name", None) + if not name or skip_invalid_insertion_points(graph, name, node): + continue + insertion_points.append( + NodeInputInsertionPoint(node_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region's input boundary (frozen/hashable). + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. + + Only applies to COMPOSITE regions; LEAF regions have no children. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Pattern-relative child region index + region_index: int + # Input tensor index of that child region + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return set() + + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.inputs), "Input index out of range" + tensor_name = child_region.inputs[self.input_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return [] + + insertion_points = [] + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for input_idx, inp in enumerate(child_region.inputs): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + insertion_points.append( + ChildRegionInputInsertionPoint(region_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable). + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + region_index: int | None # Pattern-relative child region index (or None) + node_index: int | None # Pattern-relative node index (or None) + output_index: int # Output tensor index + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionOutputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names.""" + if self.region_index is not None: + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.outputs), "Output index out of range" + tensor_name = child_region.outputs[self.output_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + if self.node_index is not None: + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + node = graph.nodes[node_indices[self.node_index]] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + return resolve_region_io_insertion_points(None, graph, tensor.name) + + return set() + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + region_outputs_set = set(region.outputs) + insertion_points = [] + + # For COMPOSITE regions: collect child region outputs + if region.type != RegionType.LEAF: + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for output_idx, out in enumerate(child_region.outputs): + if out in region_outputs_set and not skip_invalid_insertion_points( + graph, out, child_region + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=local_idx, node_index=None, output_index=output_idx + ) + ) + + # For all regions: collect node outputs + for local_idx, node_idx in enumerate(region.get_nodes(sort=True)): + node = graph.nodes[node_idx] + for output_idx, out in enumerate(node.outputs): + if not (hasattr(out, "name") and out.name): + continue + if out.name in region_outputs_set and not skip_invalid_insertion_points( + graph, out.name, node + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=None, node_index=local_idx, output_index=output_idx + ) + ) + + return insertion_points + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. + + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ + from modelopt.onnx.quantization.autotune.common import Region + + if isinstance(region_or_node, Region): + node_indices = region_or_node.get_region_nodes_and_descendants() + nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] + else: + assert isinstance(region_or_node, gs.Node) + nodes = [region_or_node] + + for node in nodes: + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + # Skip weights of Conv and ConvTranspose, they should be quantized with inputs at same time + if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + return True + # Conv -> ReLU/Softmax or Conv -> BatchNormalization -> ReLU/Softmax + if node.op in ["Relu", "Softmax"]: + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + if ( + producer.op == "BatchNormalization" + and len(producer.inputs[0].inputs) == 1 + and producer.inputs[0].inputs[0].op in ["Conv", "ConvTranspose"] + ): + return True + # Conv -> BatchNormalization + if node.op == "BatchNormalization": + assert len(node.inputs) >= 1, "BN node should have more than one inputs" + if len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Filter 1: out boolean operations + if node.op in ( + get_bool_ops() + | get_bitwise_ops() + | get_value_check_ops() + | get_comparison_ops() + | get_conditional_ops() + | get_aggregation_ops() + | get_set_ops() + ) or is_fusible_reduction_op(node.op): + return True + # Filter 2: out shape operations + if node.op in get_autotuner_skip_ops(): + return True + # Filter 3: Skip operation-specific non-quantizable inputs + if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1: + return True + if node.op in ["Conv", "Gemm"] and input_idx >= 2: + return True + # Filter 4: Skip non-floating-point tensors (int/bool indices, masks, etc.) + if hasattr(inp, "dtype") and inp.dtype not in [ + None, + np.float32, + np.float16, + np.float64, + ]: + return True + # Filter 5: Skip small tensors (scalars, small vectors) + if hasattr(inp, "shape") and inp.shape is not None: + if all(isinstance(s, int) for s in inp.shape): + if np.prod(inp.shape) < 8: + return True + return False + + +def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: + """Check if a region contains major quantizable operations (only checks LEAF regions). + + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type != RegionType.LEAF: + return True + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops & get_autotuner_quantizable_ops()) + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + tensor_users_map = getattr(graph, "tensor_users_map", None) or get_tensor_consumer_node_indices( + graph + ) + + node_indices: set[int] = set() + if region is not None: + node_indices.update(region.get_region_nodes_and_descendants()) + node_indices.update(tensor_users_map.get(tensor_name, [])) + + resolved = set() + for node_idx in node_indices: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + if not skip_invalid_insertion_points(graph, tensor_name, node): + resolved.add( + ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + ) + return resolved + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. + + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. + + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_ips = {ip for ip in resolved_insertion_points if ip.node_index is not None} + + results = resolved_insertion_points - node_ips + for tensor_name in {ip.tensor_name for ip in node_ips}: + all_users = set(tensor_users_map.get(tensor_name, [])) + qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name} + if all_users == {ip.node_index for ip in qdq_users}: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + return results + + +def get_autotuner_skip_ops(): + """Returns set of shape/structural operations that are not quantizable.""" + return set(get_copy_ops()) | { + # Additional indexing/scatter/reshape ops + "Compress", + "Scatter", + "ExpandDims", + "Unsqueeze", + "View", + "Pad", + # Utility ops + "Cast", + "Ceil", + "Clip", + "Identity", + "Range", + "Shape", + } + + +def get_autotuner_quantizable_ops(): + """Returns set of key operations that benefit from quantization.""" + return { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", + } diff --git a/modelopt/onnx/quantization/autotune/region_inspect.py b/modelopt/onnx/quantization/autotune/region_inspect.py new file mode 100644 index 0000000000..beb60268d0 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_inspect.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region search inspection tool for ONNX models.""" + +import argparse +import logging +import sys +from collections import Counter + +import onnx +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations +from modelopt.onnx.quantization.autotune.region_search import ( + DEFAULT_MAX_STEPS, + CombinedRegionSearch, +) + + +def inspect_region_search( + onnx_path: str, + max_sequence_size: int = 10, + include_all_regions: bool = False, +) -> list[Region]: + """Inspect region search results for an ONNX model. + + This function loads an ONNX model, runs CombinedRegionSearch (which performs + both bottom-up partitioning and top-down refinement internally), and prints + detailed information about the discovered regions including their hierarchical + structure. + + **What it does:** + 1. Loads ONNX model and converts to GraphSurgeon format + 2. Creates CombinedRegionSearch instance with specified parameters + 3. Runs two-phase search (partitioning + refinement) via search_regions() + 4. Displays detailed region structure and statistics + 5. Returns the final list of refined regions + + **Output Sections:** + - Initialization: Shows search parameters + - Two-Phase Search: Runs automatically via CombinedRegionSearch.search_regions() + - Detailed Structure: Shows each region's hierarchy and properties + - Summary Statistics: Shows region counts and node coverage + + Args: + onnx_path: Path to the ONNX model file + max_sequence_size: Maximum size for sequence regions during refinement (default: 10) + include_all_regions: Include all regions, even those without major quantizable + operations (Conv, MatMul, etc.). Default: False (skips such regions) + + Returns: + List of discovered and refined regions (LEAF and COMPOSITE) + """ + # Load ONNX model + logger.info(f"Loading model: {onnx_path}") + onnx_model = onnx.load(onnx_path) + # Convert to onnx_graphsurgeon Graph + graph = gs.import_onnx(onnx_model) + graph.cleanup().toposort() + logger.info( + f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, {len(graph.outputs)} outputs" + ) + # Initialize CombinedRegionSearch (contains RegionPartitioner internally) + logger.debug( + f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}" + ) + + combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size) + + # Run complete two-phase region search + logger.info("Running region search") + regions = combined_search.search_regions() + # Show detailed region structure + logger.info("Analyzing region structure") + all_regions = [] + for i, region in enumerate(regions): + region.children = [ + c + for c in region.get_children() + if include_all_regions or has_quantizable_operations(c, graph) + ] + if not include_all_regions and not has_quantizable_operations(region, graph): + logger.debug(f"Filtered out region {i} (no quantizable operations)") + continue + logger.debug( + f"Region {i}: {region.type.value}, {len(region.get_region_nodes_and_descendants())} nodes, " + f"{len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + all_regions.append(region) + if region.type == RegionType.COMPOSITE: + logger.debug(f" {len(region.get_children())} child regions") + all_regions.extend(region.get_children()) + combined_search.print_tree(region, indent=2) + + # Summary statistics + type_counts = Counter(r.type for r in all_regions) + leaf_regions, composite_regions = ( + type_counts[RegionType.LEAF], + type_counts[RegionType.COMPOSITE], + ) + + all_nodes = {n for r in all_regions for n in r.get_region_nodes_and_descendants()} + total_nodes = len(all_nodes) + coverage_pct = 100 * total_nodes / len(graph.nodes) if graph.nodes else 0 + + logger.info( + f"Summary: {len(all_regions)} regions ({leaf_regions} LEAF, {composite_regions} COMPOSITE), " + f"{total_nodes}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + # Print histogram of region sizes + region_sizes = [ + len(r.get_region_nodes_and_descendants()) for r in all_regions if r.type == RegionType.LEAF + ] + + if region_sizes: + min_size = min(region_sizes) + max_size = max(region_sizes) + avg_size = sum(region_sizes) / len(region_sizes) + + logger.info(f"LEAF region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + size_counts = Counter(region_sizes) + logger.debug("Size distribution:") + for size in sorted(size_counts.keys()): + count = size_counts[size] + bar = "█" * min(count, 50) + logger.debug(f" {size:4d} nodes: {bar} ({count} regions)") + + return all_regions + + +def main(): + """Command-line entry point for region search inspection.""" + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune.region_inspect", + description="Inspect region search results for ONNX models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic inspection + python -m modelopt.onnx.quantization.autotune.region_inspect --model model.onnx + + # Verbose mode for debug logging + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --verbose + + # Custom maximum sequence size + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --max-sequence-size 20 + """, + ) + + parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") + parser.add_argument( + "--max-sequence-size", + type=int, + default=10, + help="Maximum size for sequence regions during refinement (default: 10)", + ) + parser.add_argument( + "--include-all-regions", + action="store_true", + help="Include all regions, even those without major quantizable operations. " + "Default: False (skips such regions)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging") + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + logger.setLevel(log_level) + + try: + regions = inspect_region_search( + onnx_path=args.model, + max_sequence_size=args.max_sequence_size, + include_all_regions=args.include_all_regions, + ) + logger.info(f"✓ Inspection complete: {len(regions)} regions discovered") + return 0 + except Exception as e: + logger.error(f"Inspection failed: {e}", exc_info=args.verbose) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/modelopt/onnx/quantization/autotune/region_pattern.py b/modelopt/onnx/quantization/autotune/region_pattern.py new file mode 100644 index 0000000000..a32273f849 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_pattern.py @@ -0,0 +1,444 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region pattern signature generator for grouping structurally similar regions.""" + +import hashlib +from typing import Union, overload + +import onnx_graphsurgeon as gs + +from modelopt.onnx.op_types import get_symmetric_ops +from modelopt.onnx.quantization.autotune.common import InsertionScheme, Region +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, + ResolvedInsertionPoint, +) + + +class RegionPattern: + """Represents a structural pattern of a region. + + The pattern captures the topology and operation types in a region, + enabling pattern matching and region comparison. Patterns are hashable + and can be used as dictionary keys for efficient grouping and lookup. + """ + + def __init__(self, signature: str, size: int): + """Initialize a region pattern. + + Args: + signature: The structural signature of the region. + size: The number of nodes in the region. + """ + self.signature = signature + self.size = size + + @property + def is_empty(self) -> bool: + """Check if the pattern represents an empty region.""" + return self.size == 0 + + @property + def is_composite(self) -> bool: + """Check if the pattern represents a composite region.""" + return self.signature.startswith("COMPOSITE(") + + @property + def is_leaf(self) -> bool: + """Check if the pattern represents a leaf region (no composite structure).""" + return not self.is_composite and not self.is_empty + + def __str__(self) -> str: + """String representation of the pattern.""" + return self.signature + + def __repr__(self) -> str: + """Developer-friendly representation with signature and size.""" + return f"RegionPattern('{self.signature}', size={self.size})" + + def __eq__(self, other) -> bool: + """Check equality based on signature only.""" + if not isinstance(other, RegionPattern): + return False + return self.signature == other.signature + + def __hash__(self) -> int: + """Hash based on signature for use as dict key.""" + return hash(self.signature) + + def get_hash(self) -> str: + """Get a 128-bit cryptographic hash of the pattern signature.""" + return hashlib.sha256(self.signature.encode("utf-8")).hexdigest()[:32] + + def get_short_signature(self, max_length: int = 80) -> str: + """Get a truncated version of the signature for display purposes.""" + if len(self.signature) <= max_length or max_length > len(self.signature): + return self.signature + return self.signature[: max_length - 3] + "..." + + @classmethod + def from_region(cls, region: Region, graph: gs.Graph) -> "RegionPattern": + """Compute a structural pattern for a region. + + The pattern captures: + - Direct node operations in the region + - Structure of sub-regions (recursively) + - Handles symmetric operations consistently + - Sorts sub-regions by size for determinism + + Args: + region: The region to compute pattern for + graph: The ONNX graph containing the nodes + + Returns: + RegionPattern object containing the signature and metadata + """ + signature_str = cls._compute_signature_recursive(region, graph) + total_size = len(region.get_region_nodes_and_descendants()) + return cls(signature_str, total_size) + + @overload + def matches(self, other: "RegionPattern") -> bool: ... + @overload + def matches(self, other: Region, graph: gs.Graph, scheme: None = None) -> list[int] | None: ... + @overload + def matches( + self, other: Region, graph: gs.Graph, scheme: InsertionScheme + ) -> set[ResolvedInsertionPoint]: ... + + def matches( + self, + other: Union["RegionPattern", Region], + graph: gs.Graph | None = None, + scheme: InsertionScheme | None = None, + ) -> bool | list[int] | set[ResolvedInsertionPoint] | None: + """Check if this pattern matches another pattern or region. + + This method provides three distinct behaviors depending on the arguments: + + 1. **Pattern-to-pattern comparison** (other is RegionPattern, scheme is None): + Returns bool indicating structural equivalence. + + 2. **Pattern-to-region matching** (other is Region, scheme is None): + Returns list of node IDs in pattern order if match succeeds, None otherwise. + + 3. **Pattern-to-region with insertion scheme** (other is Region, scheme provided): + Returns set of resolved insertion points where Q/DQ should be inserted, considering: + - NodeInputInsertionPoints from the scheme (node-level Q/DQ) + - ChildRegionInputInsertionPoints from the scheme (child region input Q/DQ) + - RegionOutputInsertionPoints from the scheme (region output Q/DQ) + Returns empty set if pattern doesn't match. + + Args: + other: Either a RegionPattern or Region to compare with + graph: Required when other is a Region (for computing its pattern) + scheme: Optional InsertionScheme containing node_inputs, + child_region_inputs, and region_outputs + to resolve to tensor names + + Returns: + - True if other is RegionPattern and patterns match + - List of node IDs in pattern order if other is Region and scheme is None, None if no match + - Set of resolved insertion points for Q/DQ insertion if other is Region and scheme is provided + + Raises: + ValueError: If other is Region but graph is not provided, or if scheme + is provided but other is not a Region + TypeError: If other is neither RegionPattern nor Region + """ + if isinstance(other, RegionPattern): + if scheme is not None: + raise ValueError("scheme parameter can only be used when matching against a Region") + return self._matches_pattern(other) + elif isinstance(other, Region) and scheme is None: + return self._matches_region(other, graph) + elif isinstance(other, Region) and scheme is not None: + if graph is None: + raise ValueError("graph parameter is required") + + region_pattern = RegionPattern.from_region(other, graph) + if self != region_pattern: + return set() + + resolved_ips = set() + for ip in scheme.node_inputs: + resolved_ips.update(ip.resolve(other, graph)) + for ip in scheme.child_region_inputs: + resolved_ips.update(ip.resolve(other, graph)) + for ip in scheme.region_outputs: + resolved_ips.update(ip.resolve(other, graph)) + return resolved_ips + else: + raise TypeError(f"Expected RegionPattern or Region, got {type(other).__name__}") + + def _matches_pattern(self, other: "RegionPattern") -> bool: + """Internal function: Match this pattern against another pattern. + + Args: + other: Another RegionPattern to compare with + + Returns: + True if patterns are structurally equivalent, False otherwise + """ + return self == other + + def _matches_region(self, region: Region, graph: gs.Graph | None) -> list[int] | None: + """Internal function: Match this pattern against a region. + + Args: + region: The region to match against + graph: The ONNX graph containing the nodes + + Returns: + List of node IDs in match order if pattern matches, None otherwise. + Match order follows the pattern computation order: + - Direct nodes of the region (sorted) + - Then recursively, nodes from child regions (in child sort order) + + Raises: + ValueError: If graph is not provided + """ + if graph is None: + raise ValueError("graph parameter is required when matching against a Region") + + region_pattern = RegionPattern.from_region(region, graph) + + if self == region_pattern: + return self._collect_nodes_in_match_order(region) + else: + return None + + def get_full_insertion_scheme(self, region: Region, graph: gs.Graph) -> InsertionScheme: + """Collect all possible insertion points for quantization in a region. + + This method gathers all locations where Q/DQ nodes could be inserted within a region's + computational graph. These insertion points are organized into three categories: + - node_inputs: Inputs to individual nodes within the region + - child_region_inputs: Inputs to child regions within composite regions + - region_outputs: Outputs from the region or its child regions + + Args: + region: The region to collect insertion points for + graph: The ONNX graph containing the nodes + + Returns: + InsertionScheme object containing the insertion points + """ + region_pattern = RegionPattern.from_region(region, graph) + + if self != region_pattern: + raise ValueError("Region pattern mismatch") + + scheme = InsertionScheme() + scheme.node_inputs = NodeInputInsertionPoint.collect_from_region(region, graph) + scheme.child_region_inputs = ChildRegionInputInsertionPoint.collect_from_region( + region, graph + ) + scheme.region_outputs = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) + + return scheme + + def format_tree(self, region: Region, graph: gs.Graph, indent: int = 0) -> str: + """Format this pattern and region as a human-readable tree. + + Useful for debugging and visualization. + + Args: + region: The region associated with this pattern + graph: The ONNX graph + indent: Indentation level + + Returns: + Formatted string representation + """ + prefix = " " * indent + result = f"{prefix}Region {region.id}: {self.signature} (size={self.size})\n" + + for child in region.get_children(): + child_pattern = RegionPattern.from_region(child, graph) + result += child_pattern.format_tree(child, graph, indent + 1) + + return result + + @staticmethod + def _collect_nodes_in_match_order(region: Region) -> list[int]: + """Collect node IDs in the same order as signature computation. + + This follows the traversal order used by _compute_signature_recursive: + 1. Direct nodes of the region (sorted by node index) + 2. Recursively, nodes from child regions (children sorted by -level, then size) + + The child sorting order MUST match _compute_signature_recursive and + insertion_points.py for correct pattern-relative index alignment. + + Args: + region: The region to collect nodes from + + Returns: + List of node IDs in match order + """ + node_ids = [] + + node_ids.extend(region.get_nodes(sort=True)) + sorted_children = region.get_children(sort=True) + + for child in sorted_children: + node_ids.extend(RegionPattern._collect_nodes_in_match_order(child)) + + return node_ids + + @staticmethod + def _compute_signature_recursive(region: Region, graph: gs.Graph) -> str: + """Recursively compute structural signature for a region. + + The signature captures: + - Node operations and their key parameters (for LEAF regions) + - Hierarchical structure with child patterns (for COMPOSITE regions) + - Deterministic ordering (sorted nodes and children) + - Normalized handling of symmetric/commutative operations + + Signature formats: + - Empty region: "EMPTY" + - Leaf region: "Op1->Op2->Op3" or "Op1[params]->Op2[params]" + - Composite with nodes: "COMPOSITE(nodes|child1+child2)" + - Composite without nodes: "COMPOSITE(child1+child2)" + + Child Sorting: + - Children are sorted by (-level, size) for deterministic signatures + - This order MUST match insertion_points.py for correct pattern-relative indexing + - Higher-level (more abstract) children come first + - Within same level, smaller children come first + + Args: + region: The region to process + graph: The ONNX graph containing the nodes + + Returns: + Deterministic signature string representing the region structure + """ + nodes_list = list(graph.nodes) + node_indices_set = set(region.get_nodes()) + + if node_indices_set and max(node_indices_set) >= len(nodes_list): + raise ValueError("Region contains node indices outside the graph") + + node_ops = [ + RegionPattern._make_node_with_params_signature(nodes_list[idx], graph, node_indices_set) + for idx in sorted(node_indices_set) + ] + + sorted_children = region.get_children(sort=True) + + if not sorted_children: + return "->".join(node_ops) if node_ops else "EMPTY" + + child_sigs = "+".join( + [RegionPattern._compute_signature_recursive(child, graph) for child in sorted_children] + ) + + if node_ops: + node_sig = "->".join(node_ops) + return f"COMPOSITE({node_sig}|{child_sigs})" + return f"COMPOSITE({child_sigs})" + + @staticmethod + def _get_symmetric_input_signature( + node: gs.Node, graph: gs.Graph, region_node_indices: set + ) -> str | None: + """Compute normalized input source signature for symmetric operations.""" + if node.op not in get_symmetric_ops() or len(node.inputs) <= 1: + return None + + nodes_list = list(graph.nodes) + node_to_idx = {id(n): idx for idx, n in enumerate(nodes_list)} + + input_sources = [] + for inp in node.inputs: + if inp is None or not hasattr(inp, "inputs") or not inp.inputs: + input_sources.append(("external", "input-or-constant")) + else: + producer_node = inp.inputs[0] if inp.inputs else None + if producer_node and id(producer_node) in node_to_idx: + producer_idx = node_to_idx[id(producer_node)] + location = "internal" if producer_idx in region_node_indices else "external" + input_sources.append((location, producer_node.op)) + else: + input_sources.append(("external", "unknown")) + + sorted_sources = sorted(input_sources) + return ",".join(f"{loc}:{op}" for loc, op in sorted_sources) + + @staticmethod + def _format_attr_value(value: object) -> str: + """Format an attribute value for inclusion in a signature.""" + if isinstance(value, (list, tuple)): + if len(value) > 0 and all(isinstance(v, (int, float)) for v in value): + if all(isinstance(v, int) for v in value): + return "x".join(str(v) for v in value) + return "x".join(f"{v:.4g}" if isinstance(v, float) else str(v) for v in value) + return ",".join(str(v) for v in value) + if isinstance(value, float): + return f"{value:.4g}" + if isinstance(value, bool): + return "1" if value else "0" + if isinstance(value, bytes): + hex_str = value.hex() + return hex_str if len(hex_str) <= 16 else f"{hex_str[:16]}..." + return str(value) + + @staticmethod + def _make_node_with_params_signature( + node: gs.Node, graph: gs.Graph, region_node_indices: set + ) -> str: + """Create signature for a single node including its parameters. + + Includes operation type and key attributes that affect behavior. + For symmetric/commutative operations (Add, Mul, etc.), normalizes + input order to ensure consistent signatures regardless of operand order. + Ensures deterministic ordering by sorting attributes by key name. + + Args: + node: The ONNX node + graph: The ONNX graph containing all nodes + region_node_indices: Set of node indices in the current region + + Returns: + Signature string examples: + - "Relu" - Simple operation without attributes + - "Conv[dilations=1x1,kernel_shape=3x3]" - Operation with attributes + - "Add" - Symmetric op with sorted input sources + - "Mul[axis=1]" - Symmetric op with both + """ + op = node.op + sym_sig = RegionPattern._get_symmetric_input_signature(node, graph, region_node_indices) + + attr_sig = "" + if node.attrs: + attr_parts = [ + f"{key}={RegionPattern._format_attr_value(node.attrs[key])}" + for key in sorted(node.attrs.keys()) + ] + attr_sig = f"[{','.join(attr_parts)}]" + + if attr_sig and sym_sig: + return f"{op}{attr_sig}<{sym_sig}>" + if sym_sig: + return f"{op}<{sym_sig}>" + if attr_sig: + return f"{op}{attr_sig}" + return op diff --git a/modelopt/onnx/quantization/autotune/region_search.py b/modelopt/onnx/quantization/autotune/region_search.py new file mode 100644 index 0000000000..02f8282a01 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_search.py @@ -0,0 +1,1083 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hierarchical region discovery and partitioning for ONNX graphs.""" + +import sys +from collections import defaultdict, deque + +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +DEFAULT_MAX_STEPS = 10 +DEFAULT_MAX_NODES_TO_SHOW = 20 +MAX_PROBE_STEPS_AFTER_CONVERGE = 3 + + +class RegionSearchBase: + """Base class for region search algorithms providing common graph analysis utilities. + + This class serves as a foundation for region-based graph analysis algorithms by + providing essential data structures and methods for: + - Graph traversal and reachability analysis + - Divergence/convergence pattern detection + - Region boundary computation + - Tensor flow tracking + + For large graphs, initialization may take significant time but enables + efficient queries during region formation. + """ + + def __init__( + self, + graph: gs.Graph, + root: Region | None = None, + max_steps: int = DEFAULT_MAX_STEPS, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the base region search with graph analysis. + + Performs pre-computation of essential data structures for efficient + region analysis: + 1. Creates or validates root region containing all nodes + 2. Builds tensor-to-users mapping for divergence detection + 3. Pre-computes forward reachability for convergence detection + """ + self.graph = graph + if tensor_users_map is None: + tensor_users_map = get_tensor_consumer_node_indices(self.graph) + self.tensor_users_map = tensor_users_map + if root is None: + root = self._build_root_region() + self.root = root + if forward_reachable_nodes_map is None: + forward_reachable_nodes_map = self._build_forward_reachable_nodes_map( + max_steps=max_steps + ) + self.forward_reachable_nodes_map = forward_reachable_nodes_map + + def _build_root_region(self) -> Region: + """Create a root region containing all nodes in the graph. + + The root region serves as the universal search space for region + formation algorithms. It represents the entire computation graph + as a single region before any partitioning. + + Returns: + Region of type ROOT containing all graph nodes. + """ + root = Region(region_id=0, level=0, region_type=RegionType.ROOT) + root.nodes.update(range(len(self.graph.nodes))) + self.compute_region_boundaries(root) + return root + + def _is_tensor_divergent(self, tensor_name: str) -> bool: + """Check if a tensor is consumed by multiple nodes (divergent). + + A divergent tensor indicates branching in the computation graph, + where one operation's output feeds into multiple downstream operations. + + Args: + tensor_name: Name of the tensor to check + + Returns: + True if tensor has more than one consumer, False otherwise + """ + return len(self.tensor_users_map.get(tensor_name, [])) > 1 + + def _is_node_divergent(self, node_idx: int) -> bool: + """Check if a node has outputs that branch to multiple consumers. + + A divergent node is one that produces outputs consumed by multiple + downstream nodes, creating branches in the computation graph. These + nodes are important boundaries for region formation. + + Args: + node_idx: Index of the node to check + + Returns: + True if the node has at least one output consumed by multiple nodes, + False otherwise or if node is not in root region. + """ + if node_idx not in self.root.get_nodes(): + logger.debug(f"Node {node_idx} not in root region") + return False + + node = self.graph.nodes[node_idx] + divergent_outputs = [ + out.name for out in node.outputs if self._is_tensor_divergent(out.name) + ] + is_divergent = len(divergent_outputs) > 0 + + if is_divergent: + logger.debug( + f"Divergent node {node_idx} ({node.op}): {len(divergent_outputs)} branches" + ) + + return is_divergent + + def _compute_forward_reachable_nodes( + self, start_node_idx: int, max_steps: int + ) -> dict[int, int]: + """Compute all nodes reachable forward from a starting node with distances. + + Uses breadth-first search (BFS) to find all nodes reachable by following + forward edges (data flow direction) from the start node, up to a maximum + distance. Records the shortest-path distance to each reachable node. + + Args: + start_node_idx: Index of node to start search from + max_steps: Maximum forward distance to explore + + Returns: + Dictionary mapping reachable node indices to their distances from start. + Includes start_node_idx mapped to distance 0. + """ + reachable: dict[int, int] = {start_node_idx: 0} + queue: deque[tuple[int, int]] = deque([(start_node_idx, 0)]) + while queue: + current_node_idx, distance = queue.popleft() + if distance >= max_steps: + continue + for output in self.graph.nodes[current_node_idx].outputs: + for next_node_idx in self.tensor_users_map.get(output.name, ()): + if next_node_idx not in reachable: + reachable[next_node_idx] = distance + 1 + queue.append((next_node_idx, distance + 1)) + return reachable + + def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[int, int]]: + """Pre-compute forward reachability for all nodes in the graph. + + This is a key optimization that enables efficient convergence detection. + By pre-computing forward reachability once, we can quickly answer queries + like "Can node A reach node B?" and "What is the distance from A to B?" + + Args: + max_steps: Maximum forward distance to pre-compute for each node. + Limits both time and space complexity. + + Returns: + Nested dictionary where outer key is start node index, inner key is + reachable node index, and value is shortest-path distance. + """ + logger.debug(f"Building forward reachability map (max_steps={max_steps})...") + forward_reachable_nodes_map: dict[int, dict[int, int]] = {} + for node_idx in self.root.get_nodes(): + forward_reachable_nodes_map[node_idx] = self._compute_forward_reachable_nodes( + node_idx, max_steps + ) + + total_reachable = sum(len(reachable) for reachable in forward_reachable_nodes_map.values()) + avg_reachable = total_reachable / len(self.root.get_nodes()) if self.root.get_nodes() else 0 + logger.debug(f"Reachability map complete: avg {avg_reachable:.1f} reachable nodes per node") + return forward_reachable_nodes_map + + def _find_common_reachable_nodes( + self, node_idx: int, branches: list[int] + ) -> tuple[list[dict], set[int]]: + """Find common reachable nodes from all branches (potential convergence points). + + Used as STEP 1 of convergence detection in _find_converge_nodes. + + Args: + node_idx: Index of the divergent node (excluded from common_nodes). + branches: List of branch head node indices. + + Returns: + (branch_reachable, common_nodes) + """ + branch_reachable = [self.forward_reachable_nodes_map.get(b, {}) for b in branches] + + if not branch_reachable: + logger.debug(" No reachable nodes from branches") + return [], set() + + common_nodes = set.intersection(*[set(r.keys()) for r in branch_reachable]) + logger.debug(f" {len(common_nodes)} common nodes found") + common_nodes.discard(node_idx) + + if not common_nodes: + logger.debug(" No valid convergence candidates") + return [], set() + + return branch_reachable, common_nodes + + def _evaluate_convergence_candidate( + self, + candidate_idx: int, + reachable_from_start: dict, + branch_reachable: list, + ) -> tuple[bool, int]: + r"""Check if a candidate convergence node forms a valid region and return its max distance. + + A valid region has no \"escaping\" edges: no node inside the region may reach a node + outside the region before reaching the candidate convergence point. + + Args: + candidate_idx: Candidate convergence node index. + reachable_from_start: Forward reachability from the divergent node. + branch_reachable: Per-branch reachability dicts (for max distance). + + Returns: + (is_valid, max_distance). max_distance is only meaningful when is_valid is True. + """ + region_nodes: set[int] = set(reachable_from_start.keys()) + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + region_nodes = region_nodes - set(reachable_from_candidate.keys()) + + for rnode_index in region_nodes: + reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {}) + rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf")) + for test_node_idx in reachable_from_rnode: + if test_node_idx in region_nodes: + continue + rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf")) + if any( + d == float("inf") for d in (rnode_to_test_distance, rnode_to_candidate_distance) + ): + return False, 0 + + max_distance = max(reachable[candidate_idx] for reachable in branch_reachable) + return True, max_distance + + def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: + """Find convergence point and intermediate nodes for a divergent node. + + Given a divergent node (where computation branches), this method finds: + 1. The convergence node: Where the branches rejoin + 2. All nodes between divergence and convergence + + Args: + node_idx: Index of the divergent node to find convergence for + + Returns: + Tuple containing: + - Convergence node index (None if no convergence found) + - Set of nodes between divergence and convergence + """ + node = self.graph.nodes[node_idx] + logger.debug(f"Finding convergence for node {node_idx} ({node.op})") + + branches: list[int] = [] + for output in node.outputs: + branches.extend(self.tensor_users_map.get(output.name, [])) + + branches = list(dict.fromkeys(branches)) + + logger.debug(f" {len(branches)} unique branches found") + + if len(branches) <= 1: + logger.debug(" Insufficient branches for convergence") + return None, set() + + branch_reachable, common_nodes = self._find_common_reachable_nodes(node_idx, branches) + if not branch_reachable or not common_nodes: + return None, set() + + # Select Best Convergence Node with Region Validity Check + converge_node_idx: int | None = None + min_max_distance = float("inf") + + reachable_from_start = self.forward_reachable_nodes_map.get(node_idx, {}) + + for candidate_idx in common_nodes: + valid, max_distance = self._evaluate_convergence_candidate( + candidate_idx, reachable_from_start, branch_reachable + ) + if not valid: + continue + if max_distance < min_max_distance: + min_max_distance = max_distance + converge_node_idx = candidate_idx + + # If no valid convergence found, this divergence has no convergence + if converge_node_idx is None: + logger.debug(" No valid convergence found") + return None, set() + + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Convergence at node {converge_node_idx} ({converge_node.op}), distance {min_max_distance}" + ) + + # Compute All Nodes Between Divergence and Convergence + visited_nodes: set[int] = set() + for candidate_idx in reachable_from_start: + if candidate_idx == converge_node_idx: + continue + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + if converge_node_idx in reachable_from_candidate: + visited_nodes.add(candidate_idx) + logger.debug(f" {len(visited_nodes)} nodes between divergence and convergence") + return converge_node_idx, visited_nodes + + def _max_distance_to_nodes(self, src_idx: int, dst_indices: set[int]) -> int: + """Compute maximum distance from a source node to a set of destination nodes. + + Uses pre-computed forward reachability to efficiently find the maximum + shortest-path distance from src_idx to any node in dst_indices. + + Args: + src_idx: Index of the source node + dst_indices: Set of destination node indices + + Returns: + Maximum distance from src_idx to any node in dst_indices + """ + max_distance = 0 + for dst_idx in dst_indices: + reachable = self.forward_reachable_nodes_map.get(src_idx, {}) + if dst_idx in reachable: + max_distance = max(max_distance, reachable[dst_idx]) + + logger.debug( + f"Max distance from node {src_idx}: {max_distance} steps to {len(dst_indices)} nodes" + ) + return max_distance + + def compute_region_boundaries(self, region: Region, include_constant: bool = False) -> None: + """Compute input and output tensor boundaries for a region. + + Args: + region: The region to compute boundaries for + include_constant: Whether to include constant tensors in inputs + """ + node_indices = region.get_region_nodes_and_descendants() + + consumed_tensors: set[str] = set() + produced_tensors: set[str] = set() + region_outputs: set[str] = set() + + for node_idx in node_indices: + if node_idx >= len(self.graph.nodes): + continue + node = self.graph.nodes[node_idx] + + # Collect consumed tensors (potential inputs) + for input_tensor in node.inputs: + if isinstance(input_tensor, gs.Constant) and not include_constant: + continue + consumed_tensors.add(input_tensor.name) + + # Collect produced tensors and determine outputs + for output_tensor in node.outputs: + tensor_name = output_tensor.name + produced_tensors.add(tensor_name) + + consumer_indices = self.tensor_users_map.get(tensor_name, []) + has_external_consumer = any(idx not in node_indices for idx in consumer_indices) + is_graph_output = output_tensor in self.graph.outputs + + if has_external_consumer or is_graph_output or not consumer_indices: + region_outputs.add(tensor_name) + + # Region inputs = consumed tensors not produced internally + region.inputs = sorted(consumed_tensors - produced_tensors) + region.outputs = sorted(region_outputs) + + logger.debug( + f"Computed boundaries: {len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + + def print_tree( + self, + region: Region | None = None, + indent: int = 0, + max_items: int = DEFAULT_MAX_NODES_TO_SHOW, + file=None, + ) -> None: + """Print hierarchical region tree in human-readable text format.""" + region = region or self.root + file = file or sys.stdout + p = " " * indent + + def truncated(items, fmt=str): + """Yield formatted items, truncating with count if needed.""" + items = list(items) + yield from (fmt(x) for x in items[:max_items]) + if len(items) > max_items: + yield f"... and {len(items) - max_items} more" + + direct_nodes = region.get_nodes() + children = region.get_children() + # Header + counts + print( + f"{p}├─ Region {region.id} (Level {region.level}, Type: {region.type.value})", file=file + ) + print(f"{p}│ ├─ Direct nodes: {len(direct_nodes)}", file=file) + print(f"{p}│ ├─ Total nodes: {len(region.get_region_nodes_and_descendants())}", file=file) + print(f"{p}│ ├─ Children: {len(children)}", file=file) + # I/O + for label, items in [("Inputs", region.inputs), ("Outputs", region.outputs)]: + print(f"{p}│ ├─ {label}: {len(items)}", file=file) + for line in truncated(items): + print(f"{p}│ │ - {line}", file=file) + # Direct nodes + if direct_nodes: + print(f"{p}│\n{p}│ Nodes in this region:", file=file) + + def node_fmt(i: int) -> str: + return f"Node {i}: {self.graph.nodes[i].op} ({self.graph.nodes[i].name})" + + for line in truncated(sorted(direct_nodes), node_fmt): + print(f"{p}│ - {line}", file=file) + # Children + if children: + print(f"{p}│\n{p}│ Child regions:", file=file) + for child in children: + print(f"{p}│", file=file) + self.print_tree(child, indent + 1, max_items, file) + + +class RegionPartitioner(RegionSearchBase): + """Bottom-up graph partitioner that creates initial regions based on divergence patterns. + + This class implements Phase 1 of the combined region search strategy. It performs + a systematic traversal of the computation graph from inputs to outputs, identifying + natural boundaries for region formation based on computation flow patterns. + + **Core Strategy:** + Partitions the graph by analyzing three types of computational patterns: + + 1. **Divergent Nodes with Convergence:** + - Nodes whose outputs branch to multiple paths (divergence) + - Paths that eventually rejoin at a common node (convergence) + - Creates a single region encompassing divergence + branches + convergence + - Example: A → (B,C) → D creates region containing {A, B, C, D} + + 2. **Divergent Nodes without Convergence:** + - Nodes whose outputs branch but never rejoin + - Creates a single-node "orphan" region for the divergent node + - Example: A → (B,C) with no convergence creates region {A} + + 3. **Linear Sequences:** + - Chains of non-divergent nodes (simple sequential computation) + - Groups entire sequence into one region + - Example: A → B → C → D creates region {A, B, C, D} + """ + + def __init__( + self, + graph: gs.Graph, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the partitioner with a computation graph. + + Sets up necessary data structures and inherits graph analysis utilities + from RegionSearchBase (tensor users map, reachability, etc.). + + Args: + graph: The ONNX computation graph (onnx_graphsurgeon.Graph) + tensor_users_map: Mapping from tensor names to consuming node indices + forward_reachable_nodes_map: Pre-computed forward reachability for all nodes + """ + super().__init__( + graph, + root=None, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) + self.regions: list[Region] = [] + self.current_region: Region | None = None + self.current_region_id: int = 0 + self.visited_nodes: set[int] = set() + + def _append_node_to_region(self, node_idx: int): + """Add a node to the current region, creating a new region if needed. + + This is the primary method for building regions incrementally. If no + region is currently active, creates a new LEAF region. Then adds the + specified node to that region. + + Args: + node_idx: Index of the node to add to the current region + + Returns: + None + """ + node = self.graph.nodes[node_idx] + if self.current_region is None: + self.current_region = Region( + region_id=self.current_region_id, level=0, region_type=RegionType.LEAF + ) + logger.debug(f"Started region {self.current_region_id}") + self.current_region_id += 1 + + self.current_region.nodes.add(node_idx) + logger.debug( + f" Added node {node_idx} ({node.op}), region size: {len(self.current_region.nodes)}" + ) + + def _commit_region(self): + """Finalize and store the current region being built. + + Completes region construction by: + 1. Computing input/output tensor boundaries + 2. Adding region to the completed regions list + 3. Resetting current_region to None for next region + + **Post-Conditions:** + - current_region is added to regions list + - current_region is reset to None + - Region has computed input/output tensor lists + + Side Effects: + - Appends current_region to self.regions + - Sets current_region to None + - Logs region commit with size info + """ + if self.current_region is not None: + region_size = len(self.current_region.nodes) + region_id = self.current_region.id + + self.compute_region_boundaries(self.current_region) + + self.regions.append(self.current_region) + logger.debug( + f"Committed region {region_id}: {region_size} nodes (total: {len(self.regions)})" + ) + self.current_region = None + else: + logger.debug("No region to commit") + + def _build_sequence_from_node(self, node_idx: int, max_nodes: int = -1): + """Build a region from a linear sequence of nodes. + + Starting from a non-divergent node, follows the forward chain of nodes, + adding each non-divergent node to the current region. Stops when hitting: + - A divergent node (branches to multiple paths) + - A node already visited + - End of graph + + Args: + node_idx: Index of the starting node + max_nodes: Maximum number of nodes to add to the region (-1 for no limit) + + Returns: + None + """ + logger.debug(f"Building sequence from node {node_idx} ({self.graph.nodes[node_idx].op})") + + queue: deque[int] = deque([node_idx]) + nodes_added = 0 + + while queue: + current_idx = queue.popleft() + node = self.graph.nodes[current_idx] + + self._append_node_to_region(current_idx) + self.visited_nodes.add(current_idx) + nodes_added += 1 + + if self._is_node_divergent(current_idx): + logger.debug(f" Stopped at divergent node {current_idx} ({node.op})") + else: + # Queue successors for non-divergent nodes + for output in node.outputs: + if output.name in self.tensor_users_map: + queue.extend(self.tensor_users_map[output.name]) + + if 0 < max_nodes <= nodes_added: + logger.debug(" Max nodes reached") + break + + logger.debug(f"Sequence complete: {nodes_added} nodes") + + def _build_small_converged_region( + self, start_node_idx: int, converge_node_idx: int, visited_nodes: set[int] + ): + r"""Create a region encompassing divergence, branches, and convergence. + + Builds a single region containing: + - The divergent node (where branches split) + - All nodes in the branches + - The convergence node (where branches rejoin) + + This creates a "diamond" or "funnel" shaped region that captures + parallel computation paths and their merge point. + + **Structure:** + ``` + start (divergent) + / \ + path1 path2 (visited_nodes) + \\ / + convergence + ``` + """ + visited_nodes.remove(start_node_idx) + for node_idx in sorted(visited_nodes): + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + if not self._is_node_divergent(converge_node_idx): + self._append_node_to_region(converge_node_idx) + self.visited_nodes.add(converge_node_idx) + self._build_sequence_from_node(converge_node_idx, max_nodes=MAX_PROBE_STEPS_AFTER_CONVERGE) + + def _build_region_from_node(self, node_idx: int): + """Process a single node and create appropriate region(s) based on its pattern. + + This is the core dispatch method that determines how to handle each node based on whether + it's divergent (branches) or sequential. + + - Pattern 1: Divergent with Convergence (Ideal Case) + - Pattern 2: Divergent without Convergence (Boundary Case) + - Pattern 3: Sequential Chain (Common Case) + + Args: + node_idx: Index of node to process + + Side Effects: + - Marks processed nodes as visited + - Creates and commits region(s) via helper methods + - May recursively process successor nodes (in sequence building) + """ + node = self.graph.nodes[node_idx] + + # Skip nodes already assigned to regions + if node_idx in self.visited_nodes: + logger.debug(f"Skipping node {node_idx} ({node.op}): already visited") + return + + logger.debug(f"Processing node {node_idx} ({node.op})") + + # Pattern 1 & 2: Handle divergent nodes + if self._is_node_divergent(node_idx): + logger.debug(" Divergent node, searching for convergence") + # Attempt to find where branches rejoin + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + # Check if convergence creates a reasonable-sized region + max_distance = self._max_distance_to_nodes(node_idx, visited_nodes) + # Pattern 1: Convergence found and region size is acceptable + if converge_node_idx is not None and max_distance < DEFAULT_MAX_STEPS: + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Creating converged region: {len(visited_nodes)} nodes, " + f"convergence at {converge_node_idx} ({converge_node.op}), distance {max_distance}" + ) + # Create region containing: divergence + all branches + convergence + self._build_small_converged_region(node_idx, converge_node_idx, visited_nodes) + self._commit_region() + # Pattern 2: No convergence or region would be too large + else: + logger.debug(" Creating orphan region for divergent node") + # Create single-node region for this divergent node + # Its successors will be processed separately + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + self._commit_region() + else: + # Pattern 3: Handle non-divergent (sequential) nodes + logger.debug(" Non-divergent node, building sequence") + # Build region by following the linear chain forward + self._build_sequence_from_node(node_idx) + self._commit_region() + + def partition_graph(self): + """Partition the entire graph into non-overlapping LEAF regions. + + This is the main entry point for bottom-up graph partitioning. Performs + a single pass over all nodes in graph order, creating regions based on + divergence/convergence patterns and sequential chains. + + Returns: + List of non-overlapping LEAF regions created from the graph. + + """ + logger.info(f"Partitioning graph ({len(self.graph.nodes)} nodes)") + logger.debug( + f"Initial state: {len(self.visited_nodes)} visited, {len(self.regions)} regions" + ) + + for node_idx in range(len(self.graph.nodes)): + self._build_region_from_node(node_idx) + + coverage_pct = ( + 100 * len(self.visited_nodes) / len(self.graph.nodes) if self.graph.nodes else 0 + ) + logger.info( + f"Partitioning complete: {len(self.regions)} regions, " + f"{len(self.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + if self.regions: + region_sizes = [len(r.nodes) for r in self.regions] + avg_size = sum(region_sizes) / len(region_sizes) + min_size = min(region_sizes) + max_size = max(region_sizes) + logger.debug(f"Region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + + return self.regions + + +class TopDownRegionBuilder(RegionSearchBase): + """Top-down region refiner that creates hierarchical structure from initial regions. + + This class implements Phase 2 of the combined region search strategy. It takes + a region created by RegionPartitioner and refines it by: + 1. Identifying and merging converged sub-patterns + 2. Splitting long sequences into optimal sub-regions + 3. Creating a hierarchical COMPOSITE region structure + """ + + def __init__( + self, + graph: gs.Graph, + root: Region, + next_region_id: int = 0, + maximum_sequence_region_size: int = 10, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the refiner with a region to refine. + + Args: + graph: The ONNX graph (onnx_graphsurgeon.Graph) + root: The region to refine (typically from RegionPartitioner) + next_region_id: Starting ID for new regions created during refinement + maximum_sequence_region_size: Maximum nodes per sequence region during merging (default: 10) + """ + super().__init__( + graph, + root=root, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) + self.regions: list[Region] = [] + self.next_region_id = next_region_id + self.maximum_sequence_region_size = maximum_sequence_region_size + self.boundary_op_types = { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + } + + def _create_leaf_region(self, node_indices: set[int]) -> Region: + """Create a new LEAF region containing specified nodes. + + Args: + node_indices: Set of node indices to add to the region + + Returns: + New LEAF region containing the specified nodes + """ + region = Region( + region_id=self.next_region_id, level=self.root.level + 1, region_type=RegionType.LEAF + ) + self.next_region_id += 1 + for node_idx in node_indices: + region.nodes.add(node_idx) + self.compute_region_boundaries(region) + return region + + def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Region]]: + """Build mapping from tensor names to regions that consume them. + + Similar to tensor_users_map but at the region level instead of node level. + This enables efficient traversal of region dependencies for merging decisions. + + Args: + regions: List of regions to build the usage map for + + Returns: + Mapping from tensor names to regions that consume them + """ + region_usage_map: dict[str, list[Region]] = defaultdict(list) + for region in regions: + for input_tensor in region.inputs: + region_usage_map[input_tensor].append(region) + return region_usage_map + + def _split_sequence_regions(self, root: Region) -> list[Region]: + """Split a region into smaller sub-regions by merging producer-consumer chains. + + Takes a region and creates optimal sub-regions by: + 1. Initially splitting into individual single-node regions + 2. Traversing in data flow order (following tensor dependencies) + 3. Merging adjacent regions that form simple producer-consumer chains + 4. Respecting boundary operations and size limits + + Args: + root: The region to split + + Returns: + List of smaller sub-regions + """ + result_regions: list[Region] = [] + removed_regions: set[int] = set() + + # PHASE 1: Split into Single-Node Regions + for node_idx in root.get_nodes(): + region = Region( + region_id=self.next_region_id, level=root.level + 1, region_type=RegionType.LEAF + ) + region.nodes.add(node_idx) + self.compute_region_boundaries(region) + result_regions.append(region) + self.next_region_id += 1 + + region_usage_map = self._build_region_usage_map(result_regions) + + # PHASE 2: Merge Regions in Data Flow Order + queue = deque(root.inputs) + + while len(queue) > 0: + tensor_name = queue.popleft() + # Skip tensors not produced by any region in our scope + if tensor_name not in region_usage_map: + continue + # Process each region consuming this tensor (potential merge targets) + consumers = region_usage_map[tensor_name] + for consumer in consumers: + # Skip regions already merged into others + if consumer.id in removed_regions: + continue + # Merging criteria: ALL outputs go to same single region + common_use_region = None + can_merge = True + # Check all outputs of the consumer region + for output_tensor in consumer.outputs: + queue.append(output_tensor) + if output_tensor not in region_usage_map: + can_merge = False + break + use_regions = region_usage_map[output_tensor] + if len(use_regions) != 1: + can_merge = False + break + if common_use_region is None: + common_use_region = use_regions[0] + elif common_use_region != use_regions[0]: + can_merge = False + break + # No valid downstream region to merge with + if common_use_region is None or common_use_region.id in removed_regions: + can_merge = False + continue + # Constraint 1: Limit the number of boundary operations after merge + nodes_after_merge = set() + nodes_after_merge.update(consumer.get_nodes()) + nodes_after_merge.update(common_use_region.get_nodes()) + node_ops = [self.graph.nodes[idx].op for idx in nodes_after_merge] + boundary_op_count = sum( + [1 if op in self.boundary_op_types else 0 for op in node_ops] + ) + if boundary_op_count > 3: + can_merge = False + continue + # Constraint 2: Size limits to avoid overly large regions + # Keep regions manageable for optimization passes + if ( + len(consumer.nodes) >= self.maximum_sequence_region_size + or len(common_use_region.nodes) >= self.maximum_sequence_region_size + ): + # One or both regions too large - don't merge + can_merge = False + continue + # All criteria met: merge consumer into its downstream region + if can_merge: + common_use_region.merge(consumer) + removed_regions.add(consumer.id) + # Remove regions that were merged into others + result_regions = [region for region in result_regions if region.id not in removed_regions] + # Recompute boundaries for all remaining regions + for region in result_regions: + self.compute_region_boundaries(region) + + return result_regions + + def _merge_converged_regions(self, root: Region): + """Identify and merge convergence patterns within a region. + + Traverses the region to find divergent nodes and their convergence points, + creating sub-regions that capture divergence→branches→convergence patterns. + Nodes not part of any convergence pattern are left for sequence splitting. + + Args: + root: The region to merge + + Returns: + List of merged regions + """ + result_regions: list[Region] = [] + removed_nodes: set[int] = set() + queue = deque(root.inputs) + while len(queue) > 0: + tensor_name = queue.popleft() + if tensor_name not in self.tensor_users_map: + continue + consumer_nodes = self.tensor_users_map[tensor_name] + for node_idx in consumer_nodes: + # stop at boundary nodes + if node_idx not in root.get_nodes(): + continue + consumer = self.graph.nodes[node_idx] + for output_tensor in consumer.outputs: + if output_tensor.name not in self.tensor_users_map: + continue + queue.append(output_tensor.name) + # if the node is already in a region, skip + if node_idx in removed_nodes: + continue + if not self._is_node_divergent(node_idx): + continue + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + visited_nodes = visited_nodes.intersection(root.get_region_nodes_and_descendants()) + # if no convergence found, skip + if converge_node_idx is None: + continue + # group converged nodes into a region + if converge_node_idx in root.get_nodes(): + converged_region = self._create_leaf_region(visited_nodes) + result_regions.append(converged_region) + removed_nodes.update(visited_nodes) + continue + # create a leaf region for the remaining nodes + remaining_nodes = set(root.get_nodes()) - removed_nodes + if len(remaining_nodes) > 0: + result_regions.append(self._create_leaf_region(remaining_nodes)) + # compute region boundaries for all regions + for region in result_regions: + self.compute_region_boundaries(region) + return result_regions + + def build_composite_region(self) -> Region: + """Refine a flat region into a hierarchical COMPOSITE region.""" + # merge converged regions into composite regions + regions = self._merge_converged_regions(self.root) + # split sequence regions into smaller regions + result_regions: list[Region] = [] + for region in regions: + result_regions.extend(self._split_sequence_regions(region)) + for region in result_regions: + self.compute_region_boundaries(region, include_constant=True) + regions = result_regions + # merge all regions into a single composite region + if len(regions) > 1: + composite = Region( + region_id=self.next_region_id, + level=self.root.level, + region_type=RegionType.COMPOSITE, + ) + self.next_region_id += 1 + regions = sorted( + regions, key=lambda x: RegionPattern.from_region(x, self.graph).signature + ) + for region in regions: + composite.add_child(region) + self.compute_region_boundaries(composite) + regions = [composite] + self.regions = regions + return self.regions[0] + + +class CombinedRegionSearch(RegionSearchBase): + """Two-phase region search combining bottom-up partitioning with top-down refinement. + + This class implements a sophisticated region discovery algorithm that combines two + complementary strategies to create well-formed, hierarchical regions from an ONNX + computation graph. + + """ + + def __init__( + self, + graph: gs.Graph, + maximum_sequence_region_size: int = 10, + minimum_topdown_search_size: int = 10, + ): + """Initialize CombinedRegionSearch for a given ONNX graph.""" + super().__init__(graph) + self.regions: list[Region] = [] + self.minimum_topdown_search_size = minimum_topdown_search_size + self.maximum_sequence_region_size = maximum_sequence_region_size + + def search_regions(self) -> list[Region]: + """Execute two-phase region search to partition the graph into hierarchical regions. + + 1. Bottom-up partitioning + 2. Top-down refinement + + Args: + None + + Returns: + List of hierarchical regions created from the graph + """ + logger.info("Phase 1: Bottom-up partitioning") + logger.debug("Initializing RegionPartitioner") + region_partitioner = RegionPartitioner(self.graph) + + # Execute the bottom-up partitioning algorithm. + self.regions = region_partitioner.partition_graph() + + coverage_pct = ( + 100 * len(region_partitioner.visited_nodes) / len(self.graph.nodes) + if self.graph.nodes + else 0 + ) + logger.info( + f"Phase 1 complete: {len(self.regions)} regions, " + f"{len(region_partitioner.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + logger.debug("Proceeding to Phase 2: Top-down refinement") + + logger.info("Phase 2: Top-down refinement") + next_region_id = region_partitioner.current_region_id + + refined_count = 0 + for idx, region in enumerate(self.regions): + node_count = len(region.get_region_nodes_and_descendants()) + if node_count < self.minimum_topdown_search_size: + logger.debug(f"Skipping region {idx}: {node_count} nodes (below minimum)") + continue + + logger.debug(f"Refining region {idx}: {node_count} nodes") + region_builder = TopDownRegionBuilder( + self.graph, + region, + next_region_id=next_region_id, + maximum_sequence_region_size=self.maximum_sequence_region_size, + tensor_users_map=region_partitioner.tensor_users_map, + forward_reachable_nodes_map=region_partitioner.forward_reachable_nodes_map, + ) + + self.regions[idx] = region_builder.build_composite_region() + node_count_after = len(self.regions[idx].get_region_nodes_and_descendants()) + if node_count != node_count_after: + logger.warning( + f"Node count mismatch in region {idx}: {node_count} → {node_count_after}" + ) + + region_partitioner.compute_region_boundaries(self.regions[idx]) + next_region_id = region_builder.next_region_id + refined_count += 1 + + logger.info(f"Phase 2 complete: refined {refined_count}/{len(self.regions)} regions") + + return self.regions diff --git a/modelopt/onnx/quantization/calib_utils.py b/modelopt/onnx/quantization/calib_utils.py index a962f48b3d..56e0d4cc01 100644 --- a/modelopt/onnx/quantization/calib_utils.py +++ b/modelopt/onnx/quantization/calib_utils.py @@ -38,7 +38,7 @@ class CalibrationDataProvider(CalibrationDataReader): def __init__( self, - onnx_path: str, + onnx_path: str | onnx.ModelProto, calibration_data: CalibrationDataType, calibration_shapes: str | None = None, ): @@ -58,7 +58,7 @@ def __init__( logger.info("Setting up CalibrationDataProvider for calibration") # Tensor data is not required to generate the calibration data # So even if the model has external data, we don't need to load them here - onnx_model = onnx.load(onnx_path) + onnx_model = onnx.load(onnx_path) if isinstance(onnx_path, str) else onnx_path input_names = get_input_names(onnx_model) input_shapes = {} if calibration_shapes is None else parse_shapes_spec(calibration_shapes) inferred_input_shapes = get_input_shapes(onnx_model) @@ -89,7 +89,7 @@ def __init__( # Create list of model inputs with appropriate batch size n_itr = int(calibration_data[input_names[0]].shape[0] / input_shapes[input_names[0]][0]) logger.debug(f"Creating {n_itr} calibration iterations") - self.calibration_data_list = [{}] * n_itr + self.calibration_data_list = [{} for _ in range(n_itr)] for input_name in input_names: for idx, calib_data in enumerate( np.array_split(calibration_data[input_name], n_itr, axis=0) diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py index bca898b0cf..76a3e81674 100755 --- a/modelopt/onnx/quantization/fp8.py +++ b/modelopt/onnx/quantization/fp8.py @@ -102,7 +102,7 @@ def _convert(node: onnx.NodeProto): ) zero_point = initializers[zero_point_idx] dtype = onnx.helper.tensor_dtype_to_np_dtype(zero_point.data_type) - vals = np.array(zero_point.int32_data, dtype=dtype).tobytes() + vals = np.array(zero_point.int32_data, dtype=dtype).tobytes() or zero_point.raw_data np_zero_point = onnx.helper.make_tensor( zero_point_name, onnx.TensorProto.FLOAT8E4M3FN, zero_point.dims, vals, raw=True @@ -182,6 +182,7 @@ def quantize( calibrate_per_node: bool = False, custom_ops_to_quantize: list[str] = [], direct_io_types: bool = False, + opset: int | None = None, **kwargs, ) -> onnx.ModelProto: """Applies FP8 GEMM only quantization to an ONNX file. @@ -328,6 +329,7 @@ def quantize( tensor_block_dict=custom_ops_to_cast_fp32 or {}, low_precision_type=high_precision_dtype, trt_plugins=trt_extra_plugin_lib_paths, + opset=opset, ) current_opsets = {opset.domain: opset.version for opset in onnx_model.opset_import} diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df1..efa77dd7bb 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,28 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. + + Args: + graph: ONNX GraphSurgeon graph to analyze + Returns: + Dictionary mapping tensor names to lists of node indices that consume them + """ + tensor_consumer_map: dict[str, list[int]] = defaultdict(list) + nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node + for node_idx, node in enumerate(nodes): + inputs = node.inputs if isinstance(node, gs.Node) else node.input + for tensor in inputs: + tensor_name = tensor + if isinstance(tensor, str): + tensor_name = tensor + elif hasattr(tensor, "name") and isinstance(tensor.name, str): + tensor_name = tensor.name + tensor_consumer_map[tensor_name].append(node_idx) + return tensor_consumer_map + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index a6e98a5792..b17431fb9b 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -220,7 +220,21 @@ def quantize_rtn( Always selects the first dimension (0) to block over. This is because we must batch over the Cin dimension, and in ONNX, weights are always plugged into the RHS (i.e. y = x @ W). + + Args: + use_column_major: If True, apply column-major storage optimization for execution + providers that need it. Passed via kwargs. """ + use_column_major = kwargs.get("use_column_major", False) + + # Column-major only makes sense for DQ-only mode + if use_column_major and not dq_only: + logger.warning( + "use_column_major=True has no effect in QDQ mode. " + "Column-major optimization only applies to DQ-only quantization." + ) + use_column_major = False + logger.info("Starting RTN quantization") t_start = time.time() @@ -295,8 +309,15 @@ def quantize_rtn( qw = np.asnumpy(qw) scales[name] = np.asnumpy(scales[name]) gemm_weights_quantized[name] = numpy.asarray(qw) + # Apply column-major optimization if flag is set + # Transposes the weights and scales in-place + if use_column_major: + qdq.apply_column_major_transformation(gemm_weights_quantized, scales) + dq_node_attributes = {"axis": 1, "block_size": block_size} + else: + dq_node_attributes = {"axis": 0, "block_size": block_size} + scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) - dq_node_attributes = {"axis": 0, "block_size": block_size} qdq.insert_dq_nodes( graph, scales, @@ -305,6 +326,10 @@ def quantize_rtn( layer_info=layer_info, ) + # Add transpose nodes for column-major if needed + if use_column_major: + qdq.insert_transpose_nodes_for_column_major(graph) + if gather_w_map is not None: gather_dq_node_attributes = { "axis": gather_quantize_axis, @@ -605,7 +630,14 @@ def _quantize_awq_clip( ) t = time.time() - dq_node_attributes = {"axis": 0, "block_size": block_size} + # Apply column-major optimization if flag is set + # Transposes the weights and scales in-place + use_column_major = kwargs.get("use_column_major", False) + if use_column_major: + qdq.apply_column_major_transformation(gemm_weights_quantized, scales) + dq_node_attributes = {"axis": 1, "block_size": block_size} + else: + dq_node_attributes = {"axis": 0, "block_size": block_size} scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) qdq.insert_dq_nodes( graph_gs, @@ -614,6 +646,9 @@ def _quantize_awq_clip( attributes=dq_node_attributes, layer_info=layer_info, ) + # Add transpose nodes for column-major if needed + if use_column_major: + qdq.insert_transpose_nodes_for_column_major(graph_gs) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size} @@ -1308,7 +1343,14 @@ def _quantize_awq_lite( ) t = time.time() - dq_node_attributes = {"axis": 0, "block_size": block_size} + # Apply column-major optimization if flag is set + # Transposes the weights and scales in-place + use_column_major = kwargs.get("use_column_major", False) + if use_column_major: + qdq.apply_column_major_transformation(gemm_weights_quantized, scales) + dq_node_attributes = {"axis": 1, "block_size": block_size} + else: + dq_node_attributes = {"axis": 0, "block_size": block_size} scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) qdq.insert_dq_nodes( graph_gs, @@ -1318,6 +1360,9 @@ def _quantize_awq_lite( zero_points=zero_points if use_zero_point else None, layer_info=layer_info, ) + # Add transpose nodes for column-major if needed + if use_column_major: + qdq.insert_transpose_nodes_for_column_major(graph_gs) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" assert not use_zero_point or gather_zp_map, ( @@ -1420,10 +1465,19 @@ def quantize( Default: False. - **layers_8bit** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4. Default: []. + - **use_column_major** (bool): If True, apply column-major storage optimization for + execution providers that need it. This transposes + weights and adds Transpose nodes around MatMul operations. + Only applies to DQ-only quantization mode. + Default: False. **Returns**: A quantized ONNX model in ONNX ModelProto format. """ configure_logging(level=log_level.upper()) logger.info(f"Starting INT4 quantization with method: {calibration_method}") + + # Log if column-major optimization is enabled (works for all methods) + if kwargs.get("use_column_major", False): + logger.info("Column-major storage optimization enabled via use_column_major flag") t_start = time.time() if cupy_warning_msg: diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py index 01929667cd..6e350a16f4 100755 --- a/modelopt/onnx/quantization/int8.py +++ b/modelopt/onnx/quantization/int8.py @@ -132,6 +132,7 @@ def quantize( calibrate_per_node: bool = False, custom_ops_to_quantize: list[str] = [], direct_io_types: bool = False, + opset: int | None = None, **kwargs, ) -> onnx.ModelProto: """Applies INT8 quantization to an ONNX file using the compiler friendly heuristics. @@ -289,6 +290,7 @@ def quantize( tensor_block_dict=custom_ops_to_cast_fp32 or {}, low_precision_type=high_precision_dtype, trt_plugins=trt_extra_plugin_lib_paths, + opset=opset, ) if nodes_to_quantize: diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 026b8d062d..a7e4208d00 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -1022,6 +1022,142 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn return onnx_model +# ============================================================================= +# Column-major weight storage transformation for execution providers that need it +# ============================================================================= + + +def _apply_transpose_perm_to_shape(shape, perm): + """Apply transpose permutation to a shape to get the output shape. + + Args: + shape: Input shape as a list/tuple + perm: Permutation indices + + Returns: + Transposed shape or None if inputs are None + """ + if shape is None or perm is None: + return None + return [shape[i] for i in perm] + + +def insert_transpose_nodes_for_column_major(graph: gs.Graph): + """Add a single Transpose node after each DequantizeLinear for column-major weights. + + This implements the simple transformation: A @ B = A @ ((B^T)^T) + where B^T is stored in the DequantizeLinear node, and we add a Transpose + node after DQ to recover B before the MatMul. + + Graph transformation: + Before: DQ(W) -> MatMul/Gemm + After: DQ(W^T) -> Transpose -> W -> MatMul/Gemm + + Args: + graph: ONNX GraphSurgeon graph to modify in-place + """ + nodes_to_add = [] + dq_nodes_processed = set() + + for node in graph.nodes: + if node.op in ["MatMul", "Gemm"]: + # Check if second input (weight) is from DequantizeLinear + weight_input = node.inputs[1] + if not isinstance(weight_input, gs.Variable): + continue + + # Find the producer of the weight input + producer_nodes = [n for n in graph.nodes if weight_input in n.outputs] + if not producer_nodes: + continue + + producer_node = producer_nodes[0] + if producer_node.op != DEQUANTIZE_NODE_NAME: + continue + + # Skip if we already processed this DQ node + if producer_node.name in dq_nodes_processed: + continue + dq_nodes_processed.add(producer_node.name) + + # For Gemm nodes with transB=1, flip to transB=0 since weights are already transposed + # Original: Gemm expects W and internally computes A @ W^T + # After column-major: weight is W^T, so set transB=0 to use W^T directly -> A @ W^T + if node.op == "Gemm": + if hasattr(node, "attrs") and "transB" in node.attrs and node.attrs["transB"] > 0: + logger.debug( + f"Gemm node {node.name} has transB=1, flipping to transB=0 for column-major" + ) + node.attrs["transB"] = 0 + continue + + # Get weight shape and dtype from DQ output + # DQ outputs W^T (transposed), shape is [N, K] instead of [K, N] + weight_shape = weight_input.shape if hasattr(weight_input, "shape") else None + weight_dtype = weight_input.dtype if hasattr(weight_input, "dtype") else None + + # Permutation for 2D weights: [1, 0] to transpose back + # The stored weight is B^T (transposed), we need to get B back + # For 2D [N, K] (stored as transposed): perm [1, 0] -> [K, N] (original) + perm = [1, 0] + + # Compute the transposed shape (original weight shape) + transposed_weight_shape = _apply_transpose_perm_to_shape(weight_shape, perm) + + # Create output variable for the transpose node + transpose_out = gs.Variable( + f"{producer_node.name}_transposed_back", + dtype=weight_dtype, + shape=transposed_weight_shape, + ) + + # Create transpose node: (B^T)^T = B + transpose_node = gs.Node( + op="Transpose", + name=f"{producer_node.name}_transpose_back", + inputs=[weight_input], + outputs=[transpose_out], + attrs={"perm": perm}, + ) + + # Update MatMul/Gemm to use the transposed weight + node.inputs[1] = transpose_out + + # Add transpose node to list + nodes_to_add.append(transpose_node) + + # Add all new nodes to graph + if nodes_to_add: + graph.nodes.extend(nodes_to_add) + logger.info(f"Added {len(nodes_to_add)} transpose nodes for column-major optimization") + + # Clean up and reorder graph + graph.cleanup().toposort() + + +def apply_column_major_transformation( + gemm_weights_quantized: dict, + scales: dict, +) -> None: + """Transpose quantized weights and scales in-place for column-major storage. + + Note: After calling this function and inserting DQ nodes with axis=1, + you should call insert_transpose_nodes_for_column_major() on the graph. + + Args: + gemm_weights_quantized: Dictionary mapping weight names to quantized weight arrays + scales: Dictionary mapping weight names to scale arrays + """ + logger.info("Applying column-major storage optimization") + + # Transpose weights and scales in-place + for name in list(gemm_weights_quantized.keys()): + gemm_weights_quantized[name] = gemm_weights_quantized[name].T + + for name in list(scales.keys()): + scales[name] = scales[name].T + + def cast_initializer_to_dtype( node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto] ): @@ -1035,3 +1171,30 @@ def cast_initializer_to_dtype( input_onnx = onnx.numpy_helper.from_array(input, input_name) input_onnx.data_type = onnx_dtype_map[dtype] initializer_map[input_name].CopyFrom(input_onnx) + + +def get_quantized_tensors(onnx_model: onnx.ModelProto) -> set[str]: + """Get the names of all quantized tensors from an ONNX model. + + This function identifies all DequantizeLinear nodes in the ONNX model + and extracts the names of tensors being dequantized (the first input of + each DequantizeLinear node, excluding scale and zero-point inputs). + + Args: + onnx_model: ONNX model protobuf to analyze + + Returns: + Set of tensor names that are inputs to DequantizeLinear nodes + (i.e., the tensors being dequantized) + """ + quantized_tensors = set() + + for node in onnx_model.graph.node: + if node.op_type == "DequantizeLinear": + # First input is the tensor being dequantized + # (inputs[1] is scale, inputs[2] is zero-point) + if node.input and len(node.input) > 0: + quantized_tensors.add(node.input[0]) + + logger.debug(f"Found {len(quantized_tensors)} dequantized tensors in ONNX model") + return quantized_tensors diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 96ee406c70..da7ff126dd 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -69,6 +69,8 @@ ) from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model from modelopt.onnx.utils import ( + BASE_MIN_OPSET, + QDQ_PRECISION_MIN_OPSET, duplicate_shared_constants, get_opset_version, name_onnx_nodes, @@ -78,6 +80,17 @@ __all__ = ["quantize"] +def _normalize_quantize_mode_for_opset(quantize_mode: str) -> str: + """Map variants like "int4_awq", "int4_rtn", "nvfp4" to their base precision types for lookup purposes.""" + mode_lower = quantize_mode.lower() + if "int4" in mode_lower: + return "int4" + if "nvfp4" in mode_lower or "float4" in mode_lower: + return "float4_e2m1fn" + # For "int8", "fp8", etc., return as-is (fp8 falls back to BASE_MIN_OPSET which is correct) + return quantize_mode + + def _preprocess_onnx( onnx_path: str, use_external_data_format: bool, @@ -88,6 +101,7 @@ def _preprocess_onnx( override_shapes: str, simplify: bool = False, quantize_mode: str = "int8", + opset: int | None = None, ) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]: logger.info(f"Preprocessing the model {onnx_path}") intermediate_generated_files = [] @@ -118,16 +132,45 @@ def _preprocess_onnx( " '--trt_plugins' flag (requires TRT 10+)." ) - # Per-Channel support with QDQ format requires onnx opset version 13 or above - opset_version = get_opset_version(onnx_model) + # Opset 19 is the minimum required for fp16 scales in Q/DQ nodes + # Higher opsets required for specific quantization modes (int4: 21, nvfp4: 23) + original_opset_version = get_opset_version(onnx_model) + + # Determine minimum required opset based on quantization mode + # Normalize quantize_mode to handle variants like "int4_awq", "nvfp4", etc. + normalized_mode = _normalize_quantize_mode_for_opset(quantize_mode) + mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(normalized_mode, BASE_MIN_OPSET) + + # Determine target opset version + if opset is not None: + target_opset = opset + # Warn if user-specified opset is below mode minimum (but still respect it) + if opset < mode_min_opset: + logger.warning( + f"Opset {opset} is below the minimum opset {mode_min_opset} required for " + f"{quantize_mode} quantization. Upgrading to opset {mode_min_opset}." + ) + target_opset = mode_min_opset + # Warn if user-specified opset is lower than original + if opset < original_opset_version: + logger.warning( + f"Specified opset {opset} is lower than the original model's opset {original_opset_version}. " + f"Using original model's opset {original_opset_version}." + ) + target_opset = max(target_opset, original_opset_version) + else: + # Use model's opset if it's >= mode_min_opset, otherwise upgrade to mode_min_opset + target_opset = ( + max(original_opset_version, mode_min_opset) + if original_opset_version != 1 + else mode_min_opset + ) - required_opset_version = 13 - if opset_version < required_opset_version and opset_version != 1: - opset_version = required_opset_version - onnx_model = onnx.version_converter.convert_version(onnx_model, opset_version) - onnx_path = os.path.join(output_dir, f"{model_name}_opset{opset_version}.onnx") + if original_opset_version < target_opset and original_opset_version != 1: + onnx_model = onnx.version_converter.convert_version(onnx_model, target_opset) + onnx_path = os.path.join(output_dir, f"{model_name}_opset{target_opset}.onnx") save_onnx(onnx_model, onnx_path, use_external_data_format) - logger.info(f"Model is cloned to {onnx_path} with opset_version {opset_version}") + logger.info(f"Model is cloned to {onnx_path} with opset_version {target_opset}") intermediate_generated_files.append(onnx_path) # Simplify model if requested @@ -223,7 +266,7 @@ def quantize( high_precision_dtype: str = "fp16", mha_accumulation_dtype: str = "fp16", disable_mha_qdq: bool = False, - dq_only: bool = True, + dq_only: bool = False, block_size: int | None = None, use_zero_point: bool = False, passes: list[str] = ["concat_elimination"], @@ -231,6 +274,7 @@ def quantize( calibrate_per_node: bool = False, input_shapes_profile: Sequence[dict[str, str]] | None = None, direct_io_types: bool = False, + opset: int | None = None, **kwargs: Any, ) -> None: """Quantizes the provided ONNX model. @@ -302,7 +346,7 @@ def quantize( disable_mha_qdq: Don't add Q/DQ layers to MatMuls in MHA pattern. dq_only: - If True (default), only add DQ nodes to the model. If False, add Q/DQ nodes to the model. + If True, only add DQ nodes to the model. If False (default), add Q/DQ nodes to the model. block_size: Block size parameter for int4 quantization. use_zero_point: @@ -350,6 +394,10 @@ def quantize( direct_io_types: If True, modify the I/O types in the quantized ONNX model to be lower precision whenever possible. If False, keep the I/O types in the quantized ONNX model the same as in the given ONNX model. + opset: + Target ONNX opset version for the quantized model. If None, uses required minimum opset + (19 for int8/fp8, 21 for int4, 23 for nvfp4). If the specified opset is lower than the required minimum, + a warning will be issued and the opset will be upgraded to the required minimum. kwargs: Additional keyword arguments for int4 quantization, including: - awqlite_alpha_step (float): Alpha step for lite, range [0, 1]. @@ -420,6 +468,7 @@ def quantize( override_shapes, # type: ignore[arg-type] simplify, quantize_mode, + opset, ) trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type] @@ -481,6 +530,7 @@ def quantize( calibrate_per_node=calibrate_per_node, custom_ops_to_quantize=list(custom_ops_to_quantize.keys()), direct_io_types=direct_io_types, + opset=opset, **kwargs, ) elif "int4" in quantize_mode: diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index a6b37758ef..4025ea065a 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -15,6 +15,7 @@ """Utility functions related to onnx.""" +import copy import io import os import tempfile @@ -30,6 +31,9 @@ from modelopt.onnx.logging_config import logger +# Base minimum opset for quantization (opset 19 is the first to support fp16 scales) +BASE_MIN_OPSET = 19 + def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]: """This function returns the inputs names of the given onnx model in bytes. @@ -552,7 +556,7 @@ def _get_unique_name(old_name): return onnx_model, is_modified -def check_model(model: onnx.ModelProto) -> onnx.ModelProto: +def check_model(model: onnx.ModelProto) -> None: """Checks if the given model is valid.""" if model.ByteSize() > (2 * (1024**3)): # 2GB limit with tempfile.TemporaryDirectory() as temp_dir: @@ -561,10 +565,8 @@ def check_model(model: onnx.ModelProto) -> onnx.ModelProto: onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx") save_onnx(model, onnx_tmp_path, save_as_external_data=True) onnx.checker.check_model(onnx_tmp_path) - return onnx.load(onnx_tmp_path) else: onnx.checker.check_model(model) - return model def find_lowest_common_ancestor(node1: Node, node2: Node) -> tuple[str | None, int, int]: @@ -658,15 +660,16 @@ def save_onnx(model: onnx.ModelProto, onnx_path: str, save_as_external_data: boo # Set ir_version to 10, remove it once ORT supports ir_version 11 model.ir_version = 10 - if save_as_external_data: external_data_path = os.path.basename(onnx_path) + "_data" if os.path.exists(external_data_path): logger.warning(f"Removing existing external data file: {external_data_path}") os.remove(external_data_path) + # Copy so the onnx.ModelProto object will not be modified + model_copy = copy.deepcopy(model) onnx.save_model( - model, + model_copy, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True, @@ -696,6 +699,75 @@ def get_opset_version(model: onnx.ModelProto) -> int: return ai_onnx_domain[0].version +def check_model_uses_external_data(model: onnx.ModelProto) -> bool: + """Checks if the model uses external data. True if any initializer tensor has data_location set to EXTERNAL.""" + return any( + init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL + for init in model.graph.initializer + ) + + +def get_qdq_precisions(model: onnx.ModelProto) -> set: + """Gets the Q/DQ precision types present in the model. + + Args: + model: Loaded in-memory onnx ModelProto. + + Returns: + set: Set of Q/DQ precision types present in the model (e.g., 'float8_e4m3fn', 'int8', + 'int4', 'float4_e2m1fn'). + """ + graph = gs.import_onnx(model) + precisions = set() + + # Check for custom 'NVFP4' nodes + custom_fp4_q_nodes = [node for node in graph.nodes if node.op == "TRT_FP4DynamicQuantize"] + if custom_fp4_q_nodes: + precisions.add("float4_e2m1fn") + + # Check for precision in DQ nodes + dq_nodes = [node for node in graph.nodes if node.op == "DequantizeLinear"] + for dq_node in dq_nodes: + if len(dq_node.inputs) >= 3 and dq_node.inputs[2] is not None: + # If zero-point is set, return that as the quantization mode + if isinstance(dq_node.inputs[2], Constant) and dq_node.inputs[2].values is not None: + precisions.add(dq_node.inputs[2].values.dtype.name) + elif isinstance(dq_node.inputs[0], Constant) and dq_node.inputs[0].values is not None: + # Else, return the node's input precision (ex: 'NVFP4' weight quantization) + precisions.add(dq_node.inputs[0].values.dtype.name) + + return precisions + + +# Minimum opset requirements by quantization mode/precision +# Base minimum is 19 (first opset that allows fp16 scales in Q/DQ nodes) +# Supports both quantize modes (e.g., "fp8") and dtype prefixes (e.g., "float8" for "float8_e4m3fn") +QDQ_PRECISION_MIN_OPSET = { + "int8": BASE_MIN_OPSET, + "float8_e4m3fn": BASE_MIN_OPSET, + "int4": 21, + "uint4": 21, + "float4_e2m1fn": 23, +} + + +def get_min_opset_for_precisions(precisions: set) -> int: + """Gets the minimum required opset version for a set of Q/DQ precision types. + + Args: + precisions: Set of precision type strings (e.g., 'float8_e4m3fn', 'int4'). + + Returns: + int: Minimum required opset version for the given precisions. + """ + min_opset = BASE_MIN_OPSET # Base minimum for fp16 scales support + for precision in precisions: + # Direct lookup first + if precision in QDQ_PRECISION_MIN_OPSET: + min_opset = max(min_opset, QDQ_PRECISION_MIN_OPSET[precision]) + return min_opset + + def bfloat16_to_float32(bf16_array): """Converts a bfloat16 array (as raw data) to a float32 array.""" uint32_array = bf16_array.astype(np.uint32) << 16 @@ -728,6 +800,366 @@ def get_attribute(node: onnx.NodeProto, attr_name: str) -> Any: raise ValueError(f"Attribute {attr_name} not found in node {node.name}") +def _infer_types_only(model: onnx.ModelProto) -> onnx.ModelProto: + """Infers types (but not shapes) of the onnx graph using local implementation. + + This is an internal function. Use infer_types() as the public API. + + This is a workaround for cases when ONNX's shape inference fails. + ONNX's infer_shapes performs both shape and type inference together, but for AutoCast, we only + need type inference. + + Args: + model: ONNX model to infer types for. + + Returns: + onnx.ModelProto: Model with inferred types updated in value_info and outputs. + """ + from modelopt.onnx.autocast import utils as autocast_utils + + # Get opset version + opset = get_opset_version(model) + + # Process each graph (main graph and all subgraphs) recursively + def infer_types_for_graph( + graph: onnx.GraphProto, parent_node: onnx.NodeProto = None, is_subgraph: bool = False + ) -> None: + """Infer types for a single graph (main or subgraph). + + Args: + graph: The graph to infer types for. + parent_node: The parent node containing this subgraph (None for main graph). + is_subgraph: Whether this is a subgraph (True) or the main graph (False). + """ + # Use graphsurgeon to topologically sort nodes for efficient single-pass traversal + # Create a temporary model with just this graph for graphsurgeon + temp_model = onnx.ModelProto() + temp_model.graph.CopyFrom(graph) + temp_model.opset_import.add().version = opset + temp_model.ir_version = model.ir_version + + try: + gs_graph = gs.import_onnx(temp_model) + gs_graph.toposort() + # Convert back to ONNX to get topologically sorted nodes + sorted_model = gs.export_onnx(gs_graph) + sorted_graph = sorted_model.graph + except Exception as e: + logger.debug( + f"Graphsurgeon toposort failed for {'subgraph' if is_subgraph else 'main graph'}," + f"using original order: {e!s}" + ) + # Fallback: process nodes in original order + sorted_graph = graph + + # Create mappings for quick lookup for this graph + initializer_map = {init.name: init for init in graph.initializer} + value_info_map = {vi.name: vi for vi in graph.value_info} + output_names = {out.name for out in graph.output} + + # Map tensor names to their inferred types (scoped to this graph) + tensor_types = {} + + # Initialize types from inputs and initializers + for inp in graph.input: + if inp.type.HasField("tensor_type"): + tensor_types[inp.name] = inp.type.tensor_type.elem_type + + for init_name, init in initializer_map.items(): + tensor_types[init_name] = init.data_type + + # Helper function to get tensor type + def get_tensor_type_from_name(tensor_name: str) -> int | None: + if tensor_name in tensor_types: + return tensor_types[tensor_name] + if tensor_name in value_info_map: + vi = value_info_map[tensor_name] + return _get_tensor_type(vi) + return None + + # Process nodes in topological order (single pass) + for node in sorted_graph.node: + # Get input types for this node + input_types = [] + for inp_name in node.input: + # an empty tensor name is typically a sign of an optional input, skip it + if not inp_name: + continue + inp_type = get_tensor_type_from_name(inp_name) + if inp_type is None: + raise ValueError(f"Input {inp_name} of node {node.name} has unknown type") + input_types.append(inp_type) + + # Infer output types for this node + output_types = [] + + if node.op_type == "Cast": + # Cast node: output type is the 'to' attribute + cast_to_type = None + for attr in node.attribute: + if attr.name == "to": + cast_to_type = attr.i + break + if cast_to_type is None: + raise ValueError(f"Cast node {node.name} has unknown target type") + output_types = [cast_to_type] + elif node.op_type == "DequantizeLinear": + # DequantizeLinear: output type is determined by output_dtype attribute if present, + # otherwise use the scale type (input[1]) + # inputs: [data, scale, zero_point (optional)] + output_dtype = None + for attr in node.attribute: + if attr.name == "output_dtype": + output_dtype = attr.i + break + + if output_dtype is not None: + output_types = [output_dtype] + elif len(node.input) >= 2 and node.input[1]: + scale_type = get_tensor_type_from_name(node.input[1]) + if scale_type is not None: + output_types = [scale_type] + else: + # Fallback: use first input type or FLOAT + output_types = [input_types[0] if input_types else onnx.TensorProto.FLOAT] + else: + # Fallback: use first input type or FLOAT + output_types = [input_types[0] if input_types else onnx.TensorProto.FLOAT] + elif node.op_type == "QuantizeLinear": + # QuantizeLinear: output type is determined by output_dtype attribute if present, + # otherwise use the zero_point type (input[2]) + # inputs: [data, scale, zero_point] + output_dtype = None + for attr in node.attribute: + if attr.name == "output_dtype": + output_dtype = attr.i + break + + if output_dtype is not None: + output_types = [output_dtype] * len(node.output) + elif len(node.input) >= 3 and node.input[2]: + zero_point_type = get_tensor_type_from_name(node.input[2]) + if zero_point_type is not None: + output_types = [zero_point_type] + else: + # Fallback: use INT8 as fallback, since TRT doesn't support UINT8 + output_types = [onnx.TensorProto.INT8] + else: + # Fallback: use INT8 as fallback, since TRT doesn't support UINT8 + output_types = [onnx.TensorProto.INT8] + elif node.op_type == "Constant": + # Constant: output type is from the value attribute's tensor data_type + const_type = None + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + if attr.t.HasField("data_type"): + const_type = attr.t.data_type + break + assert const_type is not None + output_types = [const_type] + elif node.op_type == "ConstantOfShape": + # ConstantOfShape: output type is from the value attribute's tensor data_type + # If no value attribute, defaults to FLOAT + # Note: Schema allows multiple types, so we need to check the value attribute + const_type = None + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + if attr.t.HasField("data_type"): + const_type = attr.t.data_type + break + assert const_type is not None + output_types = [const_type] + elif node.op_type == "Split": + # Split schema allows multiple outputs, but the schema only specifies one output type + output_types = [input_types[0]] * len(node.output) + else: + # Check if this node has subgraphs (GRAPH or GRAPHS attributes) + # Common nodes with subgraphs: If, Loop, Scan + subgraphs = [] + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + subgraphs.append(attr.g) + elif attr.type == onnx.AttributeProto.GRAPHS: + subgraphs.extend(attr.graphs) + + # If node has subgraphs, infer types for them first + if subgraphs: + for subgraph in subgraphs: + infer_types_for_graph(subgraph, parent_node=node, is_subgraph=True) + + # For nodes with subgraphs, try to infer output types from subgraph outputs + # This avoids incorrectly matching to control inputs (e.g., condition for If, trip_count for Loop) + output_types = [] + if len(node.output) > 0: + # Use the first subgraph as reference (works for If, Loop, Scan) + first_subgraph = subgraphs[0] + for out_idx, out_name in enumerate(node.output): + if out_idx < len(first_subgraph.output): + subgraph_out = first_subgraph.output[out_idx] + # Typically we only have one subgraph, but If nodes have two subgraphs + # (then_branch and else_branch). In any case, the output types of the + # subgraphs must be identical, so we check just the first one + if ( + subgraph_out.type.HasField("tensor_type") + and subgraph_out.type.tensor_type.elem_type + != onnx.TensorProto.UNDEFINED + ): + output_types.append(subgraph_out.type.tensor_type.elem_type) + else: + output_types.append(onnx.TensorProto.FLOAT) + else: + # Fallback if we can't infer from subgraphs + output_types = None + + # If we couldn't infer from subgraphs, fall through to schema-based inference + if output_types is None or len(output_types) != len(node.output): + output_types = None + else: + # No subgraphs, proceed with normal inference + output_types = None + + # If output_types not set yet, use schema-based inference + if output_types is None: + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + # Use ONNX operator schema to determine output types + try: + schema = onnx.defs.get_schema(node.op_type, opset, domain=node.domain or "") + assert schema.outputs and len(schema.outputs) >= len(node.output) + except Exception as e: + # Fallback: if schema lookup fails, propagate first input type + logger.debug( + f"Node {node.name}: Failed to get schema for {node.op_type}: {e}, " + "propagate first input type" + ) + default_type = input_types[0] if input_types else onnx.TensorProto.FLOAT + output_types = [default_type] * len(node.output) + else: + # Try to infer from schema + input_schemas = [ + schema.inputs[i].type_str for i in range(len(schema.inputs)) + ] + output_schemas = [ + schema.outputs[i].type_str for i in range(len(schema.outputs)) + ] + output_types = [None] * len(node.output) + + for output_idx in range(len(node.output)): + # explicit type is set in schema, use it + if "tensor" in output_schemas[output_idx]: + found_type = onnx_type_str_to_enum(output_schemas[output_idx]) + output_types[output_idx] = found_type + continue + # sometimes output type is set with a placeholder name despite supporting a single type + # e.g. Shape operator is constrained to int64, but the type_str is "T1" + for constraint in schema.type_constraints: + # If output type constraint has only one allowed type, use it directly + if constraint.type_param_str == output_schemas[output_idx]: + if len(constraint.allowed_type_strs) == 1: + found_type = onnx_type_str_to_enum( + constraint.allowed_type_strs[0] + ) + output_types[output_idx] = found_type + break + else: + # We have a placeholder name "T", "T1", "T2", etc that should + # match one of the input types + try: + input_match_idx = input_schemas.index( + output_schemas[output_idx] + ) + except ValueError: + input_match_idx = None + if input_match_idx is not None: + found_type = input_types[input_match_idx] + else: + found_type = default_type + logger.debug( + f"Node {node.name}: Failed to infer type for output " + f"#{output_idx}, propagate first input type" + ) + output_types[output_idx] = found_type + + # Update output tensor types + for out_idx, out_name in enumerate(node.output): + if not out_name or out_idx >= len(output_types): + continue + + output_type = output_types[out_idx] + tensor_types[out_name] = output_type + + # Update value_info if it exists + if out_name in value_info_map: + value_info_map[out_name].type.tensor_type.elem_type = output_type + elif out_name not in output_names: + # Create new value_info for intermediate tensor + new_vi = graph.value_info.add() + new_vi.name = out_name + new_vi.type.tensor_type.elem_type = output_type + value_info_map[out_name] = new_vi + + # Update output types for this graph + for out in graph.output: + if out.name in tensor_types: + out.type.tensor_type.elem_type = tensor_types[out.name] + + # Process main graph and all subgraphs recursively + autocast_utils.walk_subgraphs_recursive(model.graph, infer_types_for_graph, is_subgraph=False) + infer_types_verification(model) + return model + + +def infer_types_verification(model: onnx.ModelProto) -> onnx.ModelProto: + """Verify that all reachable tensors have a defined type. + + This is necessary because some nodes may be removed during the inference process, + leaving unreachable value_info entries. + """ + reachable_tensors = set() + + # Add graph inputs as reachable + for inp in model.graph.input: + reachable_tensors.add(inp.name) + + # Add initializers as reachable + for init in model.graph.initializer: + reachable_tensors.add(init.name) + + # Traverse nodes to find all reachable tensor outputs + for node in model.graph.node: + # A node is reachable if any of its inputs are reachable + # (or if it has no inputs - rare but possible) + node_is_reachable = not node.input or any( + inp in reachable_tensors for inp in node.input if inp + ) + + if node_is_reachable: + # All outputs of a reachable node are reachable + for out in node.output: + if out: # Skip empty output names + reachable_tensors.add(out) + + is_undefined = False + # Check value_info for reachable tensors + for vi in model.graph.value_info: + if vi.name in reachable_tensors: + if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: + logger.error( + f"Infer types verification failed. Value info {vi.name} has undefined type" + ) + is_undefined = True + + # Graph outputs should always be reachable + for out in model.graph.output: + if out.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: + logger.error(f"Infer types verification failed. Output {out.name} has undefined type") + is_undefined = True + if is_undefined: + raise ValueError( + "Infer types verification failed. Undefined types found in the model - see logs for details." + ) + return model + + def infer_shapes(model: onnx.ModelProto, **kwargs): """Infers shapes of the onnx graph, handles large models.""" if model.ByteSize() > (2 * (1024**3)): # 2GB limit @@ -744,6 +1176,29 @@ def infer_shapes(model: onnx.ModelProto, **kwargs): return onnx.shape_inference.infer_shapes(model, **kwargs) +def infer_types( + model: onnx.ModelProto, use_standalone_type_inference: bool = False, **kwargs +) -> onnx.ModelProto: + """Infers types (and optionally shapes) based on the use_standalone_type_inference flag. + + When use_standalone_type_inference is True, uses a standalone type inference implementation + that only infers types. Otherwise, uses ONNX's infer_shapes which infers both types and shapes. + + Args: + model: ONNX model to infer types/shapes for. + use_standalone_type_inference: If True, use standalone type inference (_infer_types_only). + If False, use ONNX's shape inference (infer_shapes). + **kwargs: Additional arguments passed to infer_shapes when not using standalone type inference. + + Returns: + onnx.ModelProto: Model with inferred types (and shapes if not using standalone type inference). + """ + if use_standalone_type_inference: + return _infer_types_only(model) + else: + return infer_shapes(model, **kwargs) + + def onnx_type_str_to_enum(dtype: str) -> int: """Converts ONNX type in string format to onnx.TensorProto format. diff --git a/modelopt/torch/_deploy/utils/onnx_utils.py b/modelopt/torch/_deploy/utils/onnx_utils.py index a377afcb6b..9120eb73a1 100644 --- a/modelopt/torch/_deploy/utils/onnx_utils.py +++ b/modelopt/torch/_deploy/utils/onnx_utils.py @@ -45,14 +45,3 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> list[str]: if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL ] return model_tensors_ext - - -def check_model_uses_external_data(model: onnx.ModelProto) -> bool: - """ - Checks if the model uses external data. - """ - model_tensors = _get_initializer_tensors(model) - return any( - tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL - for tensor in model_tensors - ) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 26a5781ed6..304fb8ec7a 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -42,6 +42,7 @@ ) from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero from modelopt.onnx.utils import ( + check_model_uses_external_data, get_input_names, get_input_shapes, get_node_names, @@ -55,7 +56,6 @@ from modelopt.torch.utils._pytree import TreeSpec from ..utils.onnx_optimizer import Optimizer -from .onnx_utils import check_model_uses_external_data ModelMetadata = dict[str, Any] ModelType = Any diff --git a/modelopt/torch/distill/__init__.py b/modelopt/torch/distill/__init__.py index a09aa6b8ef..dad15dcc64 100644 --- a/modelopt/torch/distill/__init__.py +++ b/modelopt/torch/distill/__init__.py @@ -19,6 +19,7 @@ from .config import * from .distillation import * from .distillation_model import * +from .layerwise_distillation_model import * from .loss_balancers import * from .losses import * from .registry import * diff --git a/modelopt/torch/distill/config.py b/modelopt/torch/distill/config.py index cfdb3ccb61..74ef153005 100644 --- a/modelopt/torch/distill/config.py +++ b/modelopt/torch/distill/config.py @@ -26,7 +26,7 @@ from .loss_balancers import DistillationLossBalancer -__all__ = ["KDLossConfig"] +__all__ = ["ExportStudentConfig", "KDLossConfig", "LayerwiseKDConfig"] Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007 @@ -120,6 +120,25 @@ def _strict_validate(self) -> None: ) +class LayerwiseKDConfig(KDLossConfig): + """Configuration for the Layerwise Knowledge-Distillation mode. + + This mode is used to distill knowledge from a teacher model to a student model using layerwise distillation. + """ + + @pydantic.field_validator("criterion") + @classmethod + def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]: + """Ensure criterion is a mapping from layer names to loss (potentially entire module).""" + if not isinstance(criterion, dict): + raise ValueError("Layerwise Distillation mode requires explicit criterion pairs.") + if any(key == ("", "") for key in criterion): + raise ValueError( + "Layerwise Distillation mode does not support output-only distillation." + ) + return criterion + + class ExportStudentConfig(ModeloptBaseConfig): """Configuration for the export_student mode. diff --git a/modelopt/torch/distill/distillation_model.py b/modelopt/torch/distill/distillation_model.py index 930b68560d..fa344385a3 100644 --- a/modelopt/torch/distill/distillation_model.py +++ b/modelopt/torch/distill/distillation_model.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - """Meta-model wrapper to support knowledge-distillation learning.""" import inspect @@ -45,6 +43,7 @@ def _setup(self): self._register_temp_attribute("_loss_modules", nn.ModuleList()) self._register_temp_attribute("_only_teacher_fwd", False) self._register_temp_attribute("_only_student_fwd", False) + self._register_temp_attribute("_hook_handles", set()) # HACK: set model's forward signature to match student class' original. # Needed for HF `transformers.utils.find_labels` which relies on inspecting class signature. @@ -57,13 +56,13 @@ def _setup(self): def modify( self, - teacher_model: nn.Module, # To be frozen. + teacher_model: nn.Module, criterion: dict[ tuple[ - str, # Student model layer whose output to capture. - str, # Teacher model layer whose output to capture. + str, # Student model layer whose output to capture + str, # Teacher model layer whose output to capture ], - Loss, # Loss fn. + Loss, # Loss function ], loss_balancer: DistillationLossBalancer | None = None, expose_minimal_state_dict: bool = True, @@ -71,9 +70,8 @@ def modify( """Constructor. Args: - teacher_model: A teacher model which this class would encapsulate. - criterion: A dictionary mapping the tuple of student and teacher - model layer names to the loss function to apply to that layer pair. + teacher_model: The teacher model (will be frozen). + criterion: Dictionary mapping (student_layer_name, teacher_layer_name) to loss functions. loss_balancer: Instance of :class:`DistillationLossBalancer ` which reduces distillation and non-distillation losses into a single value using some weighing scheme. @@ -106,22 +104,30 @@ def modify( {m for m in self._layers_to_loss.values() if len(list(m.parameters())) > 0} ) - # Disable grad for teacher + # Disable grad for teacher. self._teacher_model.requires_grad_(False) - # Register hooks for intermediate outputs from teacher models and the student model. - # HACK: For inexplicable reasons, sometimes a model will have hooks remain after - # `ato.restore()` so we check if they are present accidentally first. + # Use hooks to caputure relevant activation tensors for loss computation. + self._register_hooks() + + def _register_hooks(self): + """Register hooks for intermediate tensors from teacher models and the student model.""" for student_layer, teacher_layer in self._layers_to_loss: setattr(student_layer, "_intermediate_output", None) - if student_output_capture_fwd_hook not in student_layer._forward_hooks.values(): - student_layer.register_forward_hook(student_output_capture_fwd_hook) + handle_s = student_layer.register_forward_hook(student_output_capture_fwd_hook) setattr(teacher_layer, "_intermediate_output", None) - if teacher_output_capture_fwd_hook not in teacher_layer._forward_hooks.values(): - teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook) + handle_t = teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook) + self._hook_handles.update([handle_s, handle_t]) + + def export(self): + """Export the distillation model.""" + for handle in self._hook_handles: + handle.remove() + self._hook_handles.clear() + return super().export() @property - def teacher_model(self) -> nn.ModuleList: + def teacher_model(self) -> nn.Module: """Fetch the teacher model.""" return self._teacher_model @@ -148,7 +154,7 @@ def hide_teacher_model(self, enable=True): @contextmanager def hide_loss_modules(self, enable=True): - """Context manager to temporarily hide teacher model from the model.""" + """Context manager to temporarily hide loss modules from the model.""" loss_modules = self._loss_modules if enable: self._loss_modules = nn.ModuleList() @@ -169,7 +175,7 @@ def only_teacher_forward(self, enable=True): @contextmanager def only_student_forward(self, enable=True): - """Context manager to temporarily disable forward passes on the student model.""" + """Context manager to temporarily run forward passes only on the student model.""" if enable: self._only_student_fwd = True try: @@ -245,15 +251,13 @@ def compute_kd_loss( Args: student_loss: Original loss computed from the student's output. - loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for - loss-masking situations where the callable changes arguments each iteration. + loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. + Useful for loss-masking situations where the callable changes arguments each iteration. skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar. **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed. - This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``. Returns: - If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses. - If reduce is False, a dict of student model output loss and layer-wise distillation losses. + A dict of losses if skip_balancer is True, else the scalar total loss. """ if self._loss_balancer is None: assert student_loss is None, "Cannot pass in student loss without using Loss Balancer." @@ -288,9 +292,9 @@ def compute_kd_loss( return loss_total -def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin +def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): """A hook to capture layer output.""" - # NOTE: Defined externally to allow pickling. + # NOTE: Defined externally to allow pickling during DDP initialization. if getattr(module, "_only_teacher_fwd", False): return # Might be hooked on entire model fwd @@ -303,9 +307,9 @@ def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): module._intermediate_output = output -def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin +def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): """A hook to capture layer output.""" - # NOTE: Defined externally to allow pickling. + # NOTE: Defined externally to allow pickling during DDP initialization. if module._intermediate_output is not None: # NOTE: cannot tell if train or eval since teacher is always eval diff --git a/modelopt/torch/distill/layerwise_distillation_model.py b/modelopt/torch/distill/layerwise_distillation_model.py new file mode 100644 index 0000000000..e8cbef99fe --- /dev/null +++ b/modelopt/torch/distill/layerwise_distillation_model.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Meta-model wrapper to support layerwise-enabled knowledge-distillation learning.""" + +import warnings +from typing import Any + +import torch.nn as nn + +from .distillation_model import DistillationModel, student_output_capture_fwd_hook + +__all__ = ["LayerwiseDistillationModel"] + + +class LayerwiseDistillationModel(DistillationModel): + """Meta-model wrapper to support layerwise-enabled knowledge-distillation learning. + + The LayerwiseDistillationModel is a subclass of the DistillationModel that injects teacher inputs + into the corresponding student layers. This accomodates the case where the student model is the + teacher with specific submodules replaced, which now need to be trained to mimic the original + submodule in the teacher. + """ + + def modify(self, *args, **kwargs): + """Modify the distillation model.""" + super().modify(*args, **kwargs) + + # Freeze student layers except those in criterion. + self.requires_grad_(False) + for student_layer, _ in self._layers_to_loss: + student_layer.requires_grad_(True) + + # Make lm heads (if we have them) no-ops to save compute. + if hasattr(self, "lm_head"): + self._lm_head = self.lm_head + self.lm_head = nn.Identity() + if hasattr(self._teacher_model, "lm_head"): + self._teacher_model._lm_head = self._teacher_model.lm_head + self._teacher_model.lm_head = nn.Identity() + + return self + + def _register_hooks(self): + """Register hooks for intermediate tensors from teacher models and the student model.""" + for student_layer, teacher_layer in self._layers_to_loss: + setattr(student_layer, "_teacher_layer", [teacher_layer]) + handle_s1 = student_layer.register_forward_pre_hook(student_input_bypass_fwd_hook) + setattr(student_layer, "_intermediate_output", None) + handle_s2 = student_layer.register_forward_hook(student_output_capture_fwd_hook) + setattr(teacher_layer, "_intermediate_input", None) + setattr(teacher_layer, "_intermediate_output", None) + handle_t = teacher_layer.register_forward_hook(teacher_input_output_capture_fwd_hook) + self._hook_handles.update([handle_s1, handle_s2, handle_t]) + + def export(self): + """Export the distillation model.""" + for student_layer, _ in self._layers_to_loss: + delattr(student_layer, "_teacher_layer") + + if hasattr(self, "_lm_head"): + self.lm_head = self._lm_head + if hasattr(self._teacher_model, "_lm_head"): + self._teacher_model.lm_head = self._teacher_model._lm_head + + return super().export() + + +def student_input_bypass_fwd_hook(module: nn.Module, input: Any): + """A hook to inject teacher input into corresponding student layer.""" + # NOTE: Defined externally to allow pickling during DDP initialization. + + if getattr(module, "_only_teacher_fwd", False): + return input # Might be hooked on entire model fwd + + teacher_layer = module._teacher_layer[0] + teacher_input = teacher_layer._intermediate_input + if teacher_input is None: + warnings.warn( + f"Teacher's Module `{type(teacher_layer).__name__}` has no intermediate input stored." + " This is expected when the `only_student_forward` context manager is in use." + ) + return input + + teacher_layer._intermediate_input = None # reset + return teacher_input + + +def teacher_input_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): + """A hook to capture layer input and output.""" + # NOTE: Defined externally to allow pickling during DDP initialization. + + if module._intermediate_output is not None: + # NOTE: cannot tell if train or eval since teacher is always eval + warnings.warn( + f"Teacher's Module `{type(module).__name__}` already has an intermediate output stored." + " This is expected when `DistillationModel.compute_kd_loss` is not called in eval mode." + ) + + module._intermediate_input = input + module._intermediate_output = output diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py index 75ea751f40..18ccfd9bb0 100644 --- a/modelopt/torch/distill/mode.py +++ b/modelopt/torch/distill/mode.py @@ -21,24 +21,23 @@ import warnings import torch.nn as nn -from torch.nn.modules.loss import _Loss as Loss from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.opt.conversion import ModeloptStateManager +from modelopt.torch.opt.dynamic import _DMRegistryCls from modelopt.torch.opt.mode import ( ConvertEntrypoint, ConvertReturnType, - MetadataDict, ModeDescriptor, RestoreEntrypoint, - UpdateEntrypoint, _ModeRegistryCls, ) from modelopt.torch.utils import init_model_from_model_like, unwrap_model -from .config import ExportStudentConfig, KDLossConfig +from .config import ExportStudentConfig, KDLossConfig, LayerwiseKDConfig from .distillation_model import DistillationModel -from .registry import DistillationDMRegistry +from .layerwise_distillation_model import LayerwiseDistillationModel +from .registry import DistillationDMRegistry, LayerwiseDistillationDMRegistry DistillModeRegistry = _ModeRegistryCls("distill") @@ -75,17 +74,35 @@ def restore(self) -> RestoreEntrypoint: """The mode's entrypoint for restoring a model.""" raise NotImplementedError(f"{self.name} mode does not support restore.") - @property - def update_for_new_mode(self) -> UpdateEntrypoint: - """The mode's entrypoint for updating the models state for adding new mode.""" - return _reset_kd_state_config - @property def save_mode_in_state(self) -> bool: """Whether the mode should be saved into the modelopt state.""" return False +@DistillModeRegistry.register_mode +class LayerwiseKDModeDescriptor(KnowledgeDistillationModeDescriptor): + """Class to describe the Layerwise Knowledge-Distillation mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "layerwise_kd" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return LayerwiseKDConfig + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return _convert_for_layerwise + + @DistillModeRegistry.register_mode class ExportStudentModeDescriptor(ModeDescriptor): """Class to describe the specific Export mode to be used with Knowledge Distillation. @@ -124,7 +141,12 @@ def save_mode_in_state(self) -> bool: return False -def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType: +def _convert_for_kd( + model: nn.Module, + config: KDLossConfig, + model_cls: type[nn.Module] = DistillationModel, + model_registry: _DMRegistryCls = DistillationDMRegistry, +) -> ConvertReturnType: """Function for converting a model to a distillation meta-model. This is the only utility needed to use the ``modelopt.torch.distill`` API directly. @@ -158,12 +180,12 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType # initialize distillation model original_cls = type(student) - if original_cls not in DistillationDMRegistry: - DistillationDMRegistry.register({original_cls: "student_class"})(DistillationModel) + if original_cls not in model_registry: + model_registry.register({original_cls: "student_class"})(model_cls) # TODO (lucasl): look into ways to avoid registering every class manually # (e.g. by just registering nn.Module and disable the "forward" check for the inherited class check - distillation_model = DistillationDMRegistry.convert(student) + distillation_model = model_registry.convert(student) distillation_model.modify( **{**config, "teacher_model": teacher} # overwrite with instantiated teacher ) @@ -174,11 +196,14 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType return distillation_model, metadata -def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict): - """Function for resetting the state's config.""" - config.teacher_model = nn.Module - config.criterion = Loss() - config.loss_balancer = None +def _convert_for_layerwise(model: nn.Module, config: LayerwiseKDConfig) -> ConvertReturnType: + """Function for converting a model to a layerwise distillation meta-model.""" + return _convert_for_kd( + model, + config, + model_cls=LayerwiseDistillationModel, + model_registry=LayerwiseDistillationDMRegistry, + ) def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType: diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 500921ce38..dbfad6fb6b 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -26,6 +26,7 @@ from typing import TYPE_CHECKING import torch +import torch.distributed.nn as dist_nn import torch.nn as nn import torch.nn.functional as F import yaml @@ -57,6 +58,7 @@ class DistillationConfig: skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``). kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``. logit_kl_temperature: Temperature for the logit KL-divergence loss. + logit_kl_topk: If not None, use TopKLogitsKLLoss instead of LogitsKLLoss with this top-k value. """ intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list) @@ -64,6 +66,7 @@ class DistillationConfig: skip_lm_loss: bool = True kd_loss_scale: float = 1.0 logit_kl_temperature: float = 1.0 + logit_kl_topk: int | None = None criterion: Criterion | None = None loss_balancer: mtd.DistillationLossBalancer | None = None @@ -123,9 +126,15 @@ def setup_distillation_config( if cfg.criterion is None: criterion = {} if parallel_state.is_pipeline_last_stage(): - criterion[tuple(cfg.logit_layers)] = LogitsKLLoss( - student_cfg, temperature=cfg.logit_kl_temperature - ) + # Use TopKLogitsKLLoss if logit_kl_topk is specified, otherwise use LogitsKLLoss + if cfg.logit_kl_topk is not None: + criterion[tuple(cfg.logit_layers)] = TopKLogitsKLLoss( + student_cfg, temperature=cfg.logit_kl_temperature, top_k=cfg.logit_kl_topk + ) + else: + criterion[tuple(cfg.logit_layers)] = LogitsKLLoss( + student_cfg, temperature=cfg.logit_kl_temperature + ) # NOTE: Projection layer shared among intermediate layer pairs. projection_layer = ProjectionLayer(student_cfg, teacher_cfg) @@ -310,81 +319,143 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: predictions, targets = self.pre_forward(predictions, targets) # Division by temp should happen prior to finding max for both student and teacher. - # Currently we don't use temperature in any of ours runs (temp=1.0) output_teacher = targets.float() / self._temperature output_student = predictions.float() / self._temperature # Compute local softmax, and the reweight to compute global softmax. if self._config.tensor_model_parallel_size > 1: - # Maximum value along vocab dimension across all GPUs. - teacher_logits_max, _ = torch.max(output_teacher, dim=-1) + tp_group = parallel_state.get_tensor_model_parallel_group() + + # Subtract maximum value along vocab dimension across all GPUs (for stability) + teacher_logits_max, _ = torch.max(output_teacher, dim=-1, keepdim=True) torch.distributed.all_reduce( teacher_logits_max, op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_tensor_model_parallel_group(), + group=tp_group, ) - output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) + output_teacher -= teacher_logits_max - denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) - # We can't use standard reduction function here since the computation - # that follows it isn't identical across TP ranks. - denom_teacher = all_reduce_autograd( - denom_teacher, group=parallel_state.get_tensor_model_parallel_group() - ) - - # Maximum value along vocab dimension across all GPUs. - student_logits_max, _ = torch.max(output_student, dim=-1) + student_logits_max, _ = torch.max(output_student, dim=-1, keepdim=True) torch.distributed.all_reduce( student_logits_max, op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_tensor_model_parallel_group(), + group=tp_group, ) - output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() + output_student -= student_logits_max.detach() - denom_student = torch.sum(torch.exp(output_student), dim=-1) - denom_student = all_reduce_autograd( - denom_student, group=parallel_state.get_tensor_model_parallel_group() - ) + # Compute global softmax denominators + # We can't use standard all_reduce function here since the computation + # that follows it isn't identical across TP ranks. + denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True) + denom_teacher = dist_nn.functional.all_reduce(denom_teacher, group=tp_group) + + denom_student = torch.sum(torch.exp(output_student), dim=-1, keepdim=True) + denom_student = dist_nn.functional.all_reduce(denom_student, group=tp_group) + + # Compute log probabilities (log softmax) + teacher_log_prob = output_teacher - torch.log(denom_teacher) + student_log_prob = output_student - torch.log(denom_student) + + # KL divergence + p, q = student_log_prob, teacher_log_prob + else: + # Compute log probabilities + p, q = F.log_softmax(output_student, dim=-1), F.log_softmax(output_teacher, dim=-1) + + # KL divergence + if self._reverse: + p, q = q, p + loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) + + return self.post_forward(loss, tp_reduce=True) + + +class TopKLogitsKLLoss(LogitsKLLoss): + """Calculates KL-Divergence loss restricted to the Teacher's Top-K vocabulary entries. + + Calculates using the global Top-K entries without gathering full logits. + NOTE: Will gather Top-K logits per rank, so mind the value of K for memory and communication. + """ + + def __init__( + self, + model_config: "TransformerConfig", + temperature: float = 1.0, + reverse: bool = False, + top_k: int = 1024, + ): + """Constructor. + + Args: + model_config: MCore transformer config. + temperature: Divide tensors by this value prior to calculating loss. + reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher) + top_k: The number of top vocabulary entries to keep from the teacher's distribution. + """ + super().__init__(model_config, temperature, reverse) + self.top_k = top_k + + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """Forward function. - slen, bsz, sharded_vocab_size = output_student.shape - student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( - slen, bsz, sharded_vocab_size + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + Top-K KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + tp_size = self._config.tensor_model_parallel_size + assert self.top_k <= targets.size(-1) * tp_size, ( + f"top_k ({self.top_k}) is larger than total vocab size ({targets.size(-1) * tp_size})" + ) + + # Divide by temperature first + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + + # Extract local Top-K + # We take K from each rank and then find the global Top-K of all those. + local_top_k = min(self.top_k, targets.size(-1)) + top_teacher_vals, top_idx = torch.topk(output_teacher, local_top_k, dim=-1) + top_student_vals = torch.gather(output_student, dim=-1, index=top_idx) + + if tp_size > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + + # Gather all candidates into shape [N_rows, local_k * tp_size] + # Use all_gather from torch.distributed.nn.functional to preserve gradients + all_teacher_vals = dist_nn.functional.all_gather( + top_teacher_vals.contiguous(), group=tp_group ) - teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand( - slen, bsz, sharded_vocab_size + all_student_vals = dist_nn.functional.all_gather( + top_student_vals.contiguous(), group=tp_group ) + all_teacher_vals = torch.cat(all_teacher_vals, dim=-1) + all_student_vals = torch.cat(all_student_vals, dim=-1) - if self._reverse: - loss = torch.sum( - F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True), - dim=-1, - ) - else: - loss = torch.sum( - F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True), - dim=-1, - ) + # Pick the true Top-K based on Teacher values + global_top_vals, global_top_idx = torch.topk(all_teacher_vals, self.top_k, dim=-1) - elif self._reverse: - loss = torch.sum( - F.kl_div( - F.log_softmax(output_teacher, dim=-1), - F.softmax(output_student, dim=-1), - reduction="none", - ), - dim=-1, - ) + final_teacher_logits = global_top_vals + final_student_logits = torch.gather(all_student_vals, dim=-1, index=global_top_idx) else: - loss = torch.sum( - F.kl_div( - F.log_softmax(output_student, dim=-1), - F.softmax(output_teacher, dim=-1), - reduction="none", - ), - dim=-1, - ) + final_teacher_logits = top_teacher_vals + final_student_logits = top_student_vals - return self.post_forward(loss, tp_reduce=True) + # Standard (dense) Softmax + KL + p = F.log_softmax(final_student_logits, dim=-1) + q = F.log_softmax(final_teacher_logits, dim=-1) + + # KL divergence + if self._reverse: + p, q = q, p + loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) + + # No need to reduce since all ranks compute same global Top-K + return self.post_forward(loss, tp_reduce=False) class LogitsAndIntermediatesLossBalancer(mtd.DistillationLossBalancer): @@ -417,7 +488,7 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor: """ original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY) for _key in loss_dict: - if _key.startswith(LogitsKLLoss.__name__): + if "Logits" in _key: # class name logits_key = _key # should only be one logits_loss = loss_dict.pop(logits_key) intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1) @@ -481,32 +552,6 @@ def _init_weights(self, module): module.bias.data.zero_() -class _AllReduce(torch.autograd.Function): - """Implementation from old PyTorch `torch.distributed.nn.parallel`.""" - - @staticmethod - def forward(ctx, op, group, tensor): - ctx.group, ctx.op = group, op - tensor = tensor.clone() - torch.distributed.all_reduce(tensor, op=op, group=group) - return tensor - - @staticmethod - def backward(ctx, grad_output): - return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) - - -def all_reduce_autograd( - tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD -): - """Custom all-reduce function. - - Needed instead of other all-reduce functions available when the computation following - the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. - """ - return _AllReduce.apply(op, group, tensor) - - ######################################################## diff --git a/modelopt/torch/distill/registry.py b/modelopt/torch/distill/registry.py index 905378cc7c..69ea47308e 100644 --- a/modelopt/torch/distill/registry.py +++ b/modelopt/torch/distill/registry.py @@ -17,7 +17,10 @@ from modelopt.torch.opt.dynamic import _DMRegistryCls -__all__ = ["DistillationDMRegistry"] +__all__ = ["DistillationDMRegistry", "LayerwiseDistillationDMRegistry"] DistillationDMRegistry = _DMRegistryCls(prefix="Distill") # global instance for the registry + +# Need separate one due to registration override issues when using single registry for both. +LayerwiseDistillationDMRegistry = _DMRegistryCls(prefix="LayerwiseDistill") diff --git a/modelopt/torch/export/__init__.py b/modelopt/torch/export/__init__.py index 8b2ba56f4d..5c0905ba3b 100644 --- a/modelopt/torch/export/__init__.py +++ b/modelopt/torch/export/__init__.py @@ -19,6 +19,7 @@ from .model_config import * from .model_config_export import * from .model_utils import * +from .moe_utils import * from .plugins import * from .transformer_engine import * from .unified_export_hf import * diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py new file mode 100644 index 0000000000..a9bf138767 --- /dev/null +++ b/modelopt/torch/export/diffusers_utils.py @@ -0,0 +1,658 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Code that export quantized Hugging Face models for deployment.""" + +import warnings +from collections.abc import Callable +from contextlib import contextmanager +from importlib import import_module +from typing import Any + +import torch +import torch.nn as nn + +from .layer_utils import is_quantlinear + +DiffusionPipeline: type[Any] | None +ModelMixin: type[Any] | None +try: # diffusers is optional for LTX-2 export paths + from diffusers import DiffusionPipeline as _DiffusionPipeline + from diffusers import ModelMixin as _ModelMixin + + DiffusionPipeline = _DiffusionPipeline + ModelMixin = _ModelMixin + _HAS_DIFFUSERS = True +except Exception: # pragma: no cover + DiffusionPipeline = None + ModelMixin = None + _HAS_DIFFUSERS = False + +TI2VidTwoStagesPipeline: type[Any] | None +try: # optional for LTX-2 export paths + from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline as _TI2VidTwoStagesPipeline + + TI2VidTwoStagesPipeline = _TI2VidTwoStagesPipeline +except Exception: # pragma: no cover + TI2VidTwoStagesPipeline = None + + +def is_diffusers_object(model: Any) -> bool: + """Return True if model is a diffusers pipeline/component or LTX-2 pipeline.""" + if not _HAS_DIFFUSERS: + return False + + diffusers_types: tuple[type, ...] = () + if DiffusionPipeline is not None: + diffusers_types = (*diffusers_types, DiffusionPipeline) + if ModelMixin is not None: + diffusers_types = (*diffusers_types, ModelMixin) + if TI2VidTwoStagesPipeline is not None: + diffusers_types = (*diffusers_types, TI2VidTwoStagesPipeline) + + if not diffusers_types: + return False + + return isinstance(model, diffusers_types) + + +def generate_diffusion_dummy_inputs( + model: nn.Module, device: torch.device, dtype: torch.dtype +) -> dict[str, torch.Tensor] | None: + """Generate dummy inputs for diffusion model forward pass. + + Different diffusion models have very different input formats: + - DiTTransformer2DModel: 4D hidden_states + class_labels + - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections + - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections + - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states + - WanTransformer3DModel: 5D hidden_states + encoder_hidden_states + timestep + + Args: + model: The diffusion model component. + device: Device to create tensors on. + dtype: Data type for tensors. + + Returns: + Dictionary of dummy inputs, or None if model type is not supported. + """ + model_class_name = type(model).__name__ + batch_size = 1 + + # Try to import specific model classes for isinstance checks + def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool: + try: + module = import_module(module_path) + return isinstance(model, getattr(module, class_name)) + except (ImportError, AttributeError): + return fallback + + is_flux = _is_model_type( + "diffusers.models.transformers", + "FluxTransformer2DModel", + "flux" in model_class_name.lower(), + ) + is_sd3 = _is_model_type( + "diffusers.models.transformers", + "SD3Transformer2DModel", + "sd3" in model_class_name.lower(), + ) + is_dit = _is_model_type( + "diffusers.models.transformers", + "DiTTransformer2DModel", + model_class_name == "DiTTransformer2DModel", + ) + is_wan = _is_model_type( + "diffusers.models.transformers", + "WanTransformer3DModel", + "wan" in model_class_name.lower(), + ) + is_unet = _is_model_type( + "diffusers.models.unets", + "UNet2DConditionModel", + "unet" in model_class_name.lower(), + ) + + cfg = getattr(model, "config", None) + + def _flux_inputs() -> dict[str, torch.Tensor]: + # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Use small dimensions for dummy forward + img_seq_len = 16 # 4x4 latent grid + text_seq_len = 8 + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), + "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) + return dummy_inputs + + def _sd3_inputs() -> dict[str, torch.Tensor]: + # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep + in_channels = getattr(cfg, "in_channels", 16) + sample_size = getattr(cfg, "sample_size", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + def _dit_inputs() -> dict[str, torch.Tensor]: + # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) + # Requires: hidden_states, timestep, class_labels + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 32) + num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) + + # Use smaller sample size for speed + test_size = min(sample_size, 16) + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "return_dict": False, + } + + def _unet_inputs() -> dict[str, torch.Tensor]: + # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) + # Requires: sample, timestep, encoder_hidden_states + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 64) + cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + dummy_inputs = { + "sample": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype + ), + "return_dict": False, + } + + # Handle SDXL additional conditioning + if getattr(cfg, "addition_embed_type", None) == "text_time": + # SDXL requires text_embeds and time_ids + add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": torch.randn( + batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype + ), + "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), + } + return dummy_inputs + + def _wan_inputs() -> dict[str, torch.Tensor]: + # WanTransformer3DModel: 5D hidden_states (batch, channels, frames, height, width) + # Requires: hidden_states, encoder_hidden_states, timestep + in_channels = getattr(cfg, "in_channels", 16) + text_dim = getattr(cfg, "text_dim", 4096) + max_seq_len = getattr(cfg, "rope_max_seq_len", 512) + + patch_dtype = getattr(getattr(model, "patch_embedding", None), "weight", None) + patch_dtype = patch_dtype.dtype if patch_dtype is not None else dtype + text_embedder = getattr(getattr(model, "condition_embedder", None), "text_embedder", None) + text_dtype = ( + text_embedder.linear_1.weight.dtype + if text_embedder is not None and hasattr(text_embedder, "linear_1") + else dtype + ) + + # Wan expects num_frames = 4 * n + 1; keep n small for dummy forward + num_frames = 5 + text_seq_len = min(max_seq_len, 512) + + # Keep spatial dims small and divisible by patch size (default 2x2) + height = 8 + width = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, num_frames, height, width, device=device, dtype=patch_dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, text_dim, device=device, dtype=text_dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + if cfg is None: + return None + if not (hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size")): + return None + + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): + dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype + ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + model_input_builders = [ + ("flux", is_flux, _flux_inputs), + ("sd3", is_sd3, _sd3_inputs), + ("dit", is_dit, _dit_inputs), + ("wan", is_wan, _wan_inputs), + ("unet", is_unet, _unet_inputs), + ] + + for _, matches, build_inputs in model_input_builders: + if matches: + return build_inputs() + + generic_inputs = _generic_transformer_inputs() + if generic_inputs is not None: + return generic_inputs + + return None + + +def generate_diffusion_dummy_forward_fn(model: nn.Module) -> Callable[[], None]: + """Create a dummy forward function for diffusion(-like) models. + + - For diffusers components, this uses `generate_diffusion_dummy_inputs()` and calls `model(**kwargs)`. + - For LTX-2 stage-1 transformer (X0Model), the forward signature is + `model(video: Modality|None, audio: Modality|None, perturbations: BatchedPerturbationConfig)`, + so we build tiny `ltx_core` dataclasses and call the model directly. + """ + # Duck-typed LTX-2 stage-1 transformer wrapper + velocity_model = getattr(model, "velocity_model", None) + if velocity_model is not None: + + def _ltx2_dummy_forward() -> None: + try: + from ltx_core.guidance.perturbations import BatchedPerturbationConfig + from ltx_core.model.transformer.modality import Modality + except Exception as e: # pragma: no cover + raise RuntimeError( + "LTX-2 export requires `ltx_core` to be installed (Modality, BatchedPerturbationConfig)." + ) from e + + # Small shapes for speed/memory + batch_size = 1 + v_seq_len = 8 + a_seq_len = 8 + ctx_len = 4 + + device = next(model.parameters()).device + default_dtype = next(model.parameters()).dtype + + def _param_dtype(module: Any, fallback: torch.dtype) -> torch.dtype: + w = getattr(getattr(module, "weight", None), "dtype", None) + return w if isinstance(w, torch.dtype) else fallback + + def _positions(bounds_dims: int, seq_len: int) -> torch.Tensor: + # [B, dims, seq_len, 2] bounds (start/end) + pos = torch.zeros( + (batch_size, bounds_dims, seq_len, 2), device=device, dtype=torch.float32 + ) + pos[..., 1] = 1.0 + return pos + + has_video = hasattr(velocity_model, "patchify_proj") and hasattr( + velocity_model, "caption_projection" + ) + has_audio = hasattr(velocity_model, "audio_patchify_proj") and hasattr( + velocity_model, "audio_caption_projection" + ) + if not has_video and not has_audio: + raise ValueError( + "Unsupported LTX-2 velocity model: missing both video and audio preprocessors." + ) + + video = None + if has_video: + v_in = int(velocity_model.patchify_proj.in_features) + v_caption_in = int(velocity_model.caption_projection.linear_1.in_features) + v_latent_dtype = _param_dtype(velocity_model.patchify_proj, default_dtype) + v_ctx_dtype = _param_dtype( + velocity_model.caption_projection.linear_1, default_dtype + ) + video = Modality( + enabled=True, + latent=torch.randn( + batch_size, v_seq_len, v_in, device=device, dtype=v_latent_dtype + ), + # LTX `X0Model` uses `timesteps` as the sigma tensor in `to_denoised(sample, velocity, sigma)`. + # It must be broadcastable to `[B, T, D]`, so we use `[B, T, 1]`. + timesteps=torch.full( + (batch_size, v_seq_len, 1), 0.5, device=device, dtype=torch.float32 + ), + positions=_positions(bounds_dims=3, seq_len=v_seq_len), + context=torch.randn( + batch_size, ctx_len, v_caption_in, device=device, dtype=v_ctx_dtype + ), + context_mask=None, + ) + + audio = None + if has_audio: + a_in = int(velocity_model.audio_patchify_proj.in_features) + a_caption_in = int(velocity_model.audio_caption_projection.linear_1.in_features) + a_latent_dtype = _param_dtype(velocity_model.audio_patchify_proj, default_dtype) + a_ctx_dtype = _param_dtype( + velocity_model.audio_caption_projection.linear_1, default_dtype + ) + audio = Modality( + enabled=True, + latent=torch.randn( + batch_size, a_seq_len, a_in, device=device, dtype=a_latent_dtype + ), + timesteps=torch.full( + (batch_size, a_seq_len, 1), 0.5, device=device, dtype=torch.float32 + ), + positions=_positions(bounds_dims=1, seq_len=a_seq_len), + context=torch.randn( + batch_size, ctx_len, a_caption_in, device=device, dtype=a_ctx_dtype + ), + context_mask=None, + ) + + perturbations = BatchedPerturbationConfig.empty(batch_size) + model(video, audio, perturbations) + + return _ltx2_dummy_forward + + # Default: diffusers-style `model(**kwargs)` + def _diffusers_dummy_forward() -> None: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) + if dummy_inputs is None: + raise ValueError( + f"Unknown model type '{type(model).__name__}', cannot generate dummy inputs." + ) + model(**dummy_inputs) + + return _diffusers_dummy_forward + + +def is_qkv_projection(module_name: str) -> bool: + """Check if a module name corresponds to a QKV projection layer. + + In diffusers, QKV projections typically have names like: + - to_q, to_k, to_v (most common in diffusers attention) + - q_proj, k_proj, v_proj + - query, key, value + - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) + + We exclude: + - norm*.linear (AdaLayerNorm modulation layers) + - proj_out, proj_mlp (output projections) + - ff.*, mlp.* (feed-forward layers) + - to_out (output projection) + + Args: + module_name: The full module name path. + + Returns: + True if this is a QKV projection layer. + """ + # Get the last component of the module name + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + second_last = name_parts[-2] if len(name_parts) >= 2 else "" + + # QKV projection patterns (positive matches) + qkv_patterns = [ + "to_q", + "to_k", + "to_v", + "q_proj", + "k_proj", + "v_proj", + "query", + "key", + "value", + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + + # Check last or second-to-last for cases like "attn.to_q.weight" + return last_part in qkv_patterns or second_last in qkv_patterns + + +def get_qkv_group_key(module_name: str) -> str: + """Extract the parent attention block path and QKV type for grouping. + + QKV projections should only be fused within the same attention block AND + for the same type of attention (main vs added/cross). + + Examples: + - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' + - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' + + Args: + module_name: The full module name path. + + Returns: + A string key representing the attention block and QKV type for grouping. + """ + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + + # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) + added_patterns = [ + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + qkv_type = "add" if last_part in added_patterns else "main" + + # Find the parent attention block by removing the QKV projection name + # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' + parent_parts = name_parts[:-1] + parent_path = ".".join(parent_parts) if parent_parts else "" + + return f"{parent_path}.{qkv_type}" + + +def get_diffusion_components( + model: Any, + components: list[str] | None = None, +) -> dict[str, Any]: + """Get all exportable components from a diffusion(-like) pipeline. + + Supports: + - diffusers `DiffusionPipeline`: returns `pipeline.components` + - diffusers component `nn.Module` (e.g., UNet / transformer) + - LTX-2 pipeline (duck-typed): returns stage-1 transformer only as `stage_1_transformer` + + Args: + model: The pipeline or component. + components: Optional list of component names to filter. If None, all + components are returned. + + Returns: + Dictionary mapping component names to their instances (can be nn.Module, + tokenizers, schedulers, etc.). + """ + # LTX-2 pipeline: duck-typed stage-1 transformer export + stage_1 = getattr(model, "stage_1_model_ledger", None) + transformer_fn = getattr(stage_1, "transformer", None) + if stage_1 is not None and callable(transformer_fn): + all_components: dict[str, Any] = {"stage_1_transformer": stage_1.transformer()} + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + return all_components + + # diffusers pipeline + if _HAS_DIFFUSERS and DiffusionPipeline is not None and isinstance(model, DiffusionPipeline): + # Get all components from the pipeline + all_components = {name: comp for name, comp in model.components.items() if comp is not None} + + # If specific components requested, filter to only those + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + # Warn about requested components that don't exist + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + + if isinstance(model, nn.Module): + # Single component model (e.g., UNet2DConditionModel, DiTTransformer2DModel, FluxTransformer2DModel) + component_name = type(model).__name__ + all_components = {component_name: model} + + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + + raise TypeError(f"Expected DiffusionPipeline or nn.Module, got {type(model).__name__}") + + +# Backward-compatible alias +get_diffusers_components = get_diffusion_components + + +@contextmanager +def hide_quantizers_from_state_dict(model: nn.Module): + """Context manager that temporarily removes quantizer modules from the model. + + This allows save_pretrained to save the model without quantizer buffers like _amax. + The quantizers are restored after exiting the context. + + Args: + model: The model with quantizers to temporarily hide. + + Yields: + None - the model can be saved within the context. + """ + # Store references to quantizers that we'll temporarily remove + quantizer_backup: dict[str, dict[str, nn.Module]] = {} + + for name, module in model.named_modules(): + if is_quantlinear(module): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + if hasattr(module, attr): + backup[attr] = getattr(module, attr) + delattr(module, attr) + if backup: + quantizer_backup[name] = backup + + try: + yield + finally: + # Restore quantizers + for name, backup in quantizer_backup.items(): + module = model.get_submodule(name) + for attr, quantizer in backup.items(): + setattr(module, attr, quantizer) + + +def infer_dtype_from_model(model: nn.Module) -> torch.dtype: + """Infer the dtype from a model's parameters. + + Args: + model: The model to infer dtype from. + + Returns: + The dtype of the model's parameters, defaulting to float16 if no parameters found. + """ + for param in model.parameters(): + return param.dtype + return torch.float16 diff --git a/modelopt/torch/export/distribute.py b/modelopt/torch/export/distribute.py index f9d902fd23..4fe7be43ef 100644 --- a/modelopt/torch/export/distribute.py +++ b/modelopt/torch/export/distribute.py @@ -91,6 +91,7 @@ def read_configs_and_weights_from_rank( raise ValueError("NFSWorkspace is not initialized!") state_path = self._get_state_path(target_rank) if state_path.exists(): + # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state = torch.load(state_path, map_location="cpu", weights_only=False) return state["config"], state["weight"] else: diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e35ee070fc..9c68899d92 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -345,7 +345,15 @@ def is_moe(module: nn.Module) -> bool: def is_quantlinear(module: nn.Module) -> bool: """Returns whether the module is a quantized linear layer.""" - return "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower() + name = type(module).__name__ + return ( + any( + keyword in name + for keyword in ["QuantLinear", "QuantCompressedLinear", "QuantFP8Linear"] + ) + and "lora" not in name.lower() + and "ds_kernel" not in name.lower() + ) def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor: diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py index 306348f2c1..5d282add79 100755 --- a/modelopt/torch/export/model_config.py +++ b/modelopt/torch/export/model_config.py @@ -33,8 +33,10 @@ QUANTIZATION_INT4_AWQ = "int4_awq" QUANTIZATION_W4A8_AWQ = "w4a8_awq" QUANTIZATION_NVFP4 = "nvfp4" +QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant" QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8" QUANTIZATION_MXFP4 = "mxfp4" +QUANTIZATION_MXFP8 = "mxfp8" QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8" QUANTIZATION_NVFP4_AWQ = "nvfp4_awq" QUANTIZATION_FP8_PB_REAL = "fp8_pb_real" @@ -507,12 +509,20 @@ def hidden_size(self): """Returns the hidden size of the transformer model.""" if isinstance(self.mlp, MOEConfig): # fc.weight for MOE is stacked - if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]: + if self.mlp.fc.quantization in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: return self.mlp.fc.weight.shape[-1] * 2 return self.mlp.fc.weight.shape[-1] else: k = self.mlp.fc.weight.shape[1] - if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]: + if self.mlp.fc.quantization in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: return k * 2 return k diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5a24429ad7..637860357d 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -55,6 +55,7 @@ "Deepseek": "deepseek", "Whisper": "whisper", "gptoss": "gptoss", + "MiniMax": "minimax", } __doc__ = f"""Utility functions for model type detection and classification. @@ -85,6 +86,7 @@ def is_multimodal_model(model): - Vision LoRA configurations - Audio processing capabilities - Image embedding layers + - Nemotron-Parse conditional generation models Args: model: The HuggingFace model instance to check @@ -103,6 +105,10 @@ def is_multimodal_model(model): """ config = model.config + # Check for Nemotron-Parse encoder-decoder architecture + architectures = getattr(config, "architectures", []) + is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures) + return ( hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA) @@ -112,6 +118,7 @@ def is_multimodal_model(model): or ( hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers + or is_nemotron_parse # Nemotron-Parse conditional generation model ) @@ -141,5 +148,11 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None: if hasattr(model, "language_model"): return [model, model.language_model] - # Pattern 3: No language_model found + # Pattern 3: For encoder-decoder VL models (e.g., Nemotron-Parse), the decoder is the language model. + # Only match if the model is detected as multimodal to avoid matching non-VLM encoder-decoder + # models like T5, Bart, Whisper which also have .decoder. + if hasattr(model, "decoder") and is_multimodal_model(model): + return [model, model.decoder] + + # Pattern 4: No language_model found return None diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py new file mode 100644 index 0000000000..a5ba465b11 --- /dev/null +++ b/modelopt/torch/export/moe_utils.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Mixture-of-Experts (MoE) model export.""" + +from pathlib import Path + +import torch.nn as nn + + +def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): + """Collect expert_token_count from all quantized MoE layers and save as an HTML table. + + The table has rows for each MoE layer and columns for each expert, with cell values + showing the number of tokens routed to that expert during calibration. + + Args: + model: The model containing quantized MoE layers with ``expert_token_count`` attributes. + output_dir: Directory to save the HTML file. Defaults to current directory. + """ + rows = [] + for name, module in model.named_modules(): + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: + rows.append((name, module.expert_token_count)) + + if not rows: + return + + num_experts = rows[0][1].shape[0] + assert all(r[1].shape[0] == num_experts for r in rows), ( + "All MoE layers must have the same number of experts" + ) + html_parts = [ + "", + "

Expert Token Counts (per MoE layer)

", + "", + ] + html_parts.extend(f"" for i in range(num_experts)) + html_parts.append("") + + for name, counts in rows: + avg = counts.float().mean().item() + html_parts.append(f"") + for c in counts.tolist(): + if avg > 0 and c < avg * 0.05: + style = ' style="background: #ff6666;"' + elif avg > 0 and c < avg * 0.1: + style = ' style="background: #ffcccc;"' + else: + style = "" + html_parts.append(f"{c}") + html_parts.append("") + + html_parts.append("
Layer/Expert{i}
{name}
") + html_content = "\n".join(html_parts) + + if output_dir is None: + output_dir = Path(".") + output_path = Path(output_dir) / ".moe.html" + output_path.write_text(html_content, encoding="utf-8") + print(f"\033[1mExpert token count table saved to {output_path}\033[0m") diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py new file mode 100644 index 0000000000..e89900cbba --- /dev/null +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hugging Face checkpoint utility.""" + +import json +import os +import shutil +from pathlib import Path + +import torch +from safetensors.torch import safe_open +from tqdm import tqdm + + +def copy_remote_code( + pretrained_model_path: str | os.PathLike, + save_directory: str | os.PathLike, +): + """Copy remote code from pretrained model to save directory. + + For models that keep configuration and modeling files as part of the checkpoint, + we need to copy them to the export directory for seamless integration with inference + frameworks. + + Args: + pretrained_model_path: Path to the pretrained model. + save_directory: Path to the save directory. + + Raises: + ValueError: If the pretrained model path is not a directory. + """ + hf_checkpoint_path = Path(pretrained_model_path) + save_dir = Path(save_directory) + + if not hf_checkpoint_path.is_dir(): + raise ValueError( + f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." + ) + + for py_file in hf_checkpoint_path.glob("*.py"): + if py_file.is_file(): + shutil.copy(py_file, save_dir / py_file.name) + + +def load_multimodal_components( + pretrained_model_path: str | os.PathLike, +) -> dict[str, torch.Tensor]: + """Load multimodal components from safetensors file. + + Args: + pretrained_model_path: Path to the pretrained model. + + Returns: + A dictionary of multimodal components. + """ + hf_checkpoint_path = Path(pretrained_model_path) + if not hf_checkpoint_path.is_dir(): + raise ValueError( + f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." + ) + + safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" + safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" + + multimodal_state_dict = {} + + if safetensors_file.is_file(): + print(f"Loading multimodal components from single file: {safetensors_file}") + with safe_open(safetensors_file, framework="pt") as f: + multimodal_keys = [ + key + for key in f.keys() # noqa: SIM118 + if key.startswith(("multi_modal_projector", "vision_model")) + ] + for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): + multimodal_state_dict[key] = f.get_tensor(key) + + elif safetensors_index_file.is_file(): + print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") + with open(safetensors_index_file) as f: + safetensors_index = json.load(f) + + # For multimodal models, vision_model and multi_modal_projector are in the first shard + all_shard_files = sorted(set(safetensors_index["weight_map"].values())) + first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" + + # Load multimodal components from the first shard file + safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file + print(f"Loading multimodal components from {first_shard_file}") + + with safe_open(safetensors_filepath, framework="pt") as f: + shard_keys = list(f.keys()) + multimodal_keys_in_shard = [ + k for k in shard_keys if k.startswith(("multi_modal_projector", "vision_model")) + ] + + if multimodal_keys_in_shard: + print( + f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" + ) + for key in tqdm(multimodal_keys_in_shard, desc="Loading multimodal tensors"): + multimodal_state_dict[key] = f.get_tensor(key) + else: + print(f"No multimodal components found in {first_shard_file}") + + else: + print(f"Warning: No safetensors files found in {hf_checkpoint_path}") + + print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") + return multimodal_state_dict diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 23804b322d..90c523d849 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -103,6 +103,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] ) +class GroupedMLPMerging(CustomModuleMapping): + """A custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that merges up_proj and down_proj for Grouped MLP.""" + super().__init__( + func_name="grouped_mlp_merging", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) + + class GatedMLPMerging(CustomModuleMapping): """A custom module mapping that merges gate_proj and up_proj.""" @@ -127,6 +139,18 @@ def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] ) +class SelfAttentionScaling(CustomModuleMapping): + """A custom module mapping that scales self attention.""" + + def __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}): + """Create a custom module mapping that scales self attention.""" + super().__init__( + func_name="self_attention_scaling", + target_name_or_prefix=target_name_or_prefix, + func_kwargs=func_kwargs, + ) + + class GatedMLPSlicing(CustomModuleMapping): """A custom module mapping that slices gate_proj and up_proj.""" @@ -250,6 +274,59 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike): json.dump(safetensor_index, f, indent=4) +def save_safetensors_by_layer_index( + layer_state_dicts: dict[int, dict[str, torch.Tensor]], + total_layers: int, + save_directory: str | os.PathLike, + name_template: str = "model-{:05d}-of-{:05d}", +): + """Save safetensors by layer index. + + Args: + layer_state_dicts: A dictionary of layer state dictionaries. + total_layers: Total number of layers. + save_directory: Path to the save directory. + name_template: Template for the filename. + """ + for layer_index, layer_state_dict in layer_state_dicts.items(): + filename = name_template.format(layer_index, total_layers) + meta_filename = filename + ".json" + ckpt_filename = filename + ".safetensors" + + weight_map = {} + layer_total_size = 0 + for key, val in layer_state_dict.items(): + tensor_size = val.numel() * val.element_size() + layer_total_size += tensor_size + weight_map[key] = ckpt_filename + + with open(save_directory + "/" + meta_filename, "w") as f: + json.dump( + {"metadata": {"total_size": layer_total_size}, "weight_map": weight_map}, + f, + indent=4, + ) + save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + + # [TODO]: this global barrier needs to be replaced with something safer + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + safetensor_index = { + "metadata": {"total_size": 0}, + "weight_map": {}, + } + for layer_index in range(total_layers): + meta_filename = name_template.format(layer_index + 1, total_layers) + ".json" + with open(save_directory + "/" + meta_filename) as f: + shard = json.load(f) + safetensor_index["metadata"]["total_size"] += shard["metadata"]["total_size"] + safetensor_index["weight_map"].update(shard["weight_map"]) + + with open(save_directory + "/model.safetensors.index.json", "w") as f: + json.dump(safetensor_index, f, indent=4) + + def _get_safetensors_file(pretrained_model_path: str | Path, key: str) -> Path | None: """Given a tensor key return the safetensors file that contains this tensor if exists. diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py index 03a2c5fe76..7fb8ec76ac 100644 --- a/modelopt/torch/export/plugins/mcore_llama.py +++ b/modelopt/torch/export/plugins/mcore_llama.py @@ -30,6 +30,7 @@ PackNameRemapping, QKVMerging, QKVSlicing, + SelfAttentionScaling, UnpackNameRemapping, ) @@ -38,6 +39,8 @@ "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), + # KV cache quant export + "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 5fdb8ba1b8..6883c51c94 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -23,9 +23,11 @@ ROW_ETP, ROW_TP, CustomModuleMapping, + GroupedMLPMerging, NameRemapping, QKVMerging, QKVSlicing, + SelfAttentionScaling, ) # Example on adding a new CausalLM. @@ -35,6 +37,7 @@ "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), # NemotronForCausalLM is using square-relu where no gated handle is needed. "linear_fc1": NameRemapping("model.layers.{}.mlp.up_proj."), @@ -55,7 +58,11 @@ "D": NameRemapping("backbone.layers.{}.mixer.D", REPLICATE), "dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias", REPLICATE), "conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d.", REPLICATE), - "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.", COL_TP), + # mapping layer_norm_weight to None tells _name_remapping to skip it; + # the fused layer_norm_weight is loaded separately via the "fused_norm" rule. + "in_proj": NameRemapping( + "backbone.layers.{}.mixer.in_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}} + ), "out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj.", ROW_TP), # Attention "input_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE), @@ -63,8 +70,13 @@ "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj.", ROW_TP), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE), - "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP), + "linear_fc1": NameRemapping( + "backbone.layers.{}.mixer.up_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}} + ), "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP), + # Fused layer norm: loads the HF norm weight into fused TELayerNormColumnParallelLinear + # modules (in_proj, linear_qkv, linear_fc1) when using TE spec. + "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"), # MoE "router": NameRemapping( "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}} @@ -81,9 +93,25 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE), + # Repeated MTP module + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm.", {"is_mtp": True}), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), + # Grouped local experts (used for TEGroupedMLP in both decoder and MTP layers). + # The prefix uses "backbone" for regular decoder layers; when called from MTP + # context (is_mtp=True), _grouped_mlp_merging replaces "backbone" with "mtp". + "experts.linear_fc1": GroupedMLPMerging( + "backbone.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP + ), + "experts.linear_fc2": GroupedMLPMerging( + "backbone.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP + ), } - nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("backbone.embeddings."), "final_norm": NameRemapping("backbone.norm_f."), @@ -101,6 +129,7 @@ "input_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_qkv": QKVSlicing("backbone.layers.{}.mixer."), "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj."), + "core_attention": SelfAttentionScaling("backbone.layers.{}.mixer."), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), @@ -115,4 +144,12 @@ "shared_experts.linear_fc2": NameRemapping( "backbone.layers.{}.mixer.shared_experts.down_proj." ), + # Latent MoE + "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."), + "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."), + # MTP + "mtp.enorm": NameRemapping("mtp.layers.{}.enorm."), + "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm."), + "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj."), + "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm."), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index f663e19216..a156f2cd8c 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -19,7 +19,7 @@ from pathlib import Path import torch -import torch.distributed +import torch.distributed as dist from huggingface_hub import snapshot_download from tqdm import tqdm @@ -94,12 +94,12 @@ def __init__( if workspace_dir is None: workspace_dir = tempfile.gettempdir() pretrained_model_path = workspace_dir + "/" + pretrained_model_name_or_path - if torch.distributed.get_rank() == 0: + if dist.get_rank() == 0: snapshot_download( repo_id=pretrained_model_name_or_path, local_dir=pretrained_model_path, ) - torch.distributed.barrier() + dist.barrier() self.arch = self._hf_config.architectures[0] self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] @@ -108,7 +108,7 @@ def __init__( self.dtype = dtype self.dequantize = dequantize self.verbose = verbose - self.disable_tqdm = torch.distributed.get_rank() > 0 or verbose + self.disable_tqdm = dist.get_rank() > 0 or verbose def _populate_rule_book(self): """The rule book maps each state_dict key to a Callable.""" @@ -119,6 +119,7 @@ def _custom_mapping_to_lambda(mapping): "name_remapping": self._name_remapping, "qkv_merging": self._qkv_merging, "gated_mlp_merging": self._gated_mlp_merging, + "grouped_mlp_merging": self._grouped_mlp_merging, "unpack_name_remapping": self._unpack_name_remapping, "unpack_name_remapping_gpt_oss": self._unpack_name_remapping_gpt_oss, } @@ -150,7 +151,13 @@ def _name_remapping( mapping={}, parallel_config: ParallelConfig | None = None, dtype: torch.dtype | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") if dtype is None: dtype = self.dtype if isinstance(module, torch.Tensor): @@ -183,7 +190,7 @@ def _name_remapping( tensor = expanded_tensor state_dict["weight"] = tensor.view(dtype=weight.dtype).to(device=weight.device) else: - state_dict["weight"] = tensor.to(dtype=self.dtype).to(device=weight.device) + state_dict["weight"] = tensor.to(dtype=dtype).to(device=weight.device) # Handle the rest of the state_dict. for key, val in module.state_dict().items(): @@ -193,6 +200,12 @@ def _name_remapping( state_dict[key] = val else: source_key = mapping.get(key, key) + # A mapping value of None means "skip this key" (keep existing value). + # This is used for fused TE modules where layer_norm_weight is loaded + # separately from a different HF path. + if source_key is None: + state_dict[key] = val + continue # For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding # since bias should always be replicated, not sharded if ( @@ -216,7 +229,14 @@ def _gated_mlp_merging( gate_proj_name="gate_proj", up_proj_name="up_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + weight = module.state_dict().get("weight", None) weight_scale = module.state_dict().get("weight_quantizer._scale", None) @@ -254,6 +274,33 @@ def _gated_mlp_merging( module.load_state_dict(state_dict) + def _grouped_mlp_merging( + self, + module, + prefix, + parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, + init_expert_id: int = 0, + num_local_experts: int = 1, + ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") + + state_dict = module.state_dict() + + assert module.num_gemms == num_local_experts, ( + "num_gemms must be equal to num_local_experts in TEGroupedMLP" + ) + for expert_id in range(init_expert_id, init_expert_id + num_local_experts): + tensor = self._get_safetensor(prefix.format(expert_id) + ".weight") + state_dict[f"weight{expert_id}"] = tensor + # TODO handle weight_scale + + module.load_state_dict(state_dict) + def _qkv_merging( self, module, @@ -262,7 +309,13 @@ def _qkv_merging( k_proj_name="k_proj", v_proj_name="v_proj", parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, ): + if is_mtp: + if "backbone" in prefix: + prefix = prefix.replace("backbone", "mtp") + else: + prefix = prefix.replace("model", "mtp") config = module.config hidden_size = config.hidden_size num_query_groups = config.num_query_groups @@ -289,8 +342,9 @@ def _qkv_merging( state_dict = {} - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) if weight is None: raise ValueError(f"{module!s} does not contain weight!") @@ -344,7 +398,7 @@ def _qkv_merging( state_dict["weight"] = tensor.reshape(-1, hidden_size) # Handle bias merging - bias = module.state_dict().get("bias", None) + bias = module_state_dict.get("bias", None) if bias is not None: q_bias = self._get_safetensor( prefix + q_proj_name + ".bias", parallel_config=parallel_config @@ -371,6 +425,11 @@ def _qkv_merging( state_dict["bias"] = bias_tensor.reshape(-1) + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) + if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + state_dict["_extra_state"] = None # for TE modules require _extra_state key + module.load_state_dict(state_dict) def _unpack_name_remapping( @@ -379,6 +438,7 @@ def _unpack_name_remapping( prefix, layer_type: str, parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, # no-op: necessary for _import_transformer_layer ): tensor = self._get_safetensor(prefix, parallel_config=parallel_config) @@ -409,6 +469,7 @@ def _unpack_name_remapping_gpt_oss( prefix, layer_type: str, parallel_config: ParallelConfig | None = None, + is_mtp: bool = False, # no-op: necessary for _import_transformer_layer ): tensor_blocks = self._get_safetensor(prefix + "_blocks", parallel_config=parallel_config) tensor_bias = self._get_safetensor(prefix + "_bias", parallel_config=parallel_config) @@ -469,9 +530,188 @@ def _unpack_name_remapping_gpt_oss( linear_module.load_state_dict(state_dict) + def _import_mamba_layer(self, layer, layer_id, layer_pbar): + layer_pbar.set_description("Importing Mamba layer") + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) + + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) + + # TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear). + # Load the fused layer_norm_weight from the HF norm path. + if ( + isinstance(layer.norm, IdentityOp) + and hasattr(layer.mixer.in_proj, "layer_norm_weight") + and "fused_norm" in self.rules + ): + self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id) + + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) + + attention = layer.self_attention + if not isinstance(attention, IdentityOp): + if "MLASelfAttention" in str(type(attention)): + if hasattr(attention, "linear_q_proj"): + layer_pbar.set_description("Importing MLA (without q LoRA)") + self.rules["linear_q_proj"](attention.linear_q_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing MLA (with q LoRA)") + self.rules["linear_q_down_proj"]( + attention.linear_q_down_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_q_up_proj"]( + attention.linear_q_up_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_kv_down_proj"]( + attention.linear_kv_down_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_kv_up_proj"]( + attention.linear_kv_up_proj, layer_id, is_mtp=is_mtp + ) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + else: + layer_pbar.set_description("Importing GQA/MHA") + if attention.q_layernorm is not None and not isinstance( + attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](attention.q_layernorm, layer_id, is_mtp=is_mtp) + self.rules["k_layernorm"](attention.k_layernorm, layer_id, is_mtp=is_mtp) + self.rules["linear_qkv"](attention.linear_qkv, layer_id, is_mtp=is_mtp) + self.rules["linear_proj"](attention.linear_proj, layer_id, is_mtp=is_mtp) + if getattr(attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp + ) + + # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear). + # Load the fused layer_norm_weight from the HF norm path. + if ( + isinstance(layer.input_layernorm, IdentityOp) + and hasattr(attention, "linear_qkv") + and hasattr(attention.linear_qkv, "layer_norm_weight") + and "fused_norm" in self.rules + ): + self.rules["fused_norm"]( + attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp + ) + + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) + + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + layer_pbar.set_description( + f"Importing MoE with moe_router_dtype: {self.moe_router_dtype}" + ) + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype, is_mtp=is_mtp + ) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"]( + layer.mlp.fc1_latent_proj, layer_id, is_mtp=is_mtp + ) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"]( + layer.mlp.fc2_latent_proj, layer_id, is_mtp=is_mtp + ) + + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + layer_pbar.set_description("Importing MoE shared experts") + fc1 = layer.mlp.shared_experts.linear_fc1 + fc2 = layer.mlp.shared_experts.linear_fc2 + self.rules["shared_experts.linear_fc1"](fc1, layer_id, is_mtp=is_mtp) + self.rules["shared_experts.linear_fc2"](fc2, layer_id, is_mtp=is_mtp) + if not self.rules.get("use_packed_local_experts", False): # Import local experts + experts = layer.mlp.experts + if hasattr(experts, "local_experts"): + for local_expert_id, expert in tqdm( + enumerate(layer.mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + expert_id = layer.mlp.local_expert_indices[local_expert_id] + fc1 = expert.linear_fc1 + fc2 = expert.linear_fc2 + self.rules["local_experts.linear_fc1"]( + fc1, layer_id, expert_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2"]( + fc2, layer_id, expert_id, is_mtp=is_mtp + ) + else: # Slice TEGroupedMLP + layer_pbar.set_description("Importing MoE grouped local experts") + num_local_experts = experts.num_local_experts + num_global_experts = experts.config.num_moe_experts + assert num_local_experts == num_global_experts, ( + "num_local_experts must be equal to num_global_experts during MoE import" + ) + init_index = 0 + + self.rules["experts.linear_fc1"]( + experts.linear_fc1, + layer_id, + init_expert_id=init_index, + num_local_experts=num_local_experts, + is_mtp=is_mtp, + ) + self.rules["experts.linear_fc2"]( + experts.linear_fc2, + layer_id, + init_expert_id=init_index, + num_local_experts=num_local_experts, + is_mtp=is_mtp, + ) + + # We only support either EP or ETP for now + elif get_expert_tensor_parallel_world_size() > 1: + # ETP supports for packed MoE + # ETP is not supported for gpt-oss model + if self.arch in ["GptOssForCausalLM"]: + raise ValueError("ETP is not supported for gpt-oss model") + self.rules["local_experts.linear_fc1_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_etp"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + # EP supports for packed MoE + self.rules["local_experts.linear_fc1_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + self.rules["local_experts.linear_fc2_ep"]( + layer.mlp.experts.local_experts, layer_id, is_mtp=is_mtp + ) + else: + layer_pbar.set_description("Importing MLP") + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + + # TE spec: pre_mlp_layernorm is fused into linear_fc1 + # (TELayerNormColumnParallelLinear). + # Load the fused layer_norm_weight from the HF norm path. + if ( + isinstance(layer.pre_mlp_layernorm, IdentityOp) + and hasattr(layer.mlp.linear_fc1, "layer_norm_weight") + and "fused_norm" in self.rules + ): + self.rules["fused_norm"]( + layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp + ) + def _import_state_dict(self): model = self.model - layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) # Embedding @@ -481,113 +721,18 @@ def _import_state_dict(self): # Decoder layers for layer in layer_pbar: + layer_pbar.set_description(f"Importing Decoder layer {layer.layer_number}") layer_id = layer.layer_number - 1 if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) - + self._import_mamba_layer(layer, layer_id, layer_pbar) elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - attention = layer.self_attention - if not isinstance(attention, IdentityOp): - if "MLASelfAttention" in str(type(attention)): - if hasattr(attention, "linear_q_proj"): - layer_pbar.set_description("Importing MLA (without q LoRA)") - self.rules["linear_q_proj"](attention.linear_q_proj, layer_id) - else: - layer_pbar.set_description("Importing MLA (with q LoRA)") - self.rules["linear_q_down_proj"](attention.linear_q_down_proj, layer_id) - self.rules["linear_q_layernorm"](attention.q_layernorm, layer_id) - self.rules["linear_q_up_proj"](attention.linear_q_up_proj, layer_id) - self.rules["linear_kv_down_proj"](attention.linear_kv_down_proj, layer_id) - self.rules["linear_kv_layernorm"](attention.kv_layernorm, layer_id) - self.rules["linear_kv_up_proj"](attention.linear_kv_up_proj, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - else: - layer_pbar.set_description("Importing GQA/MHA") - if attention.q_layernorm is not None and not isinstance( - attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](attention.q_layernorm, layer_id) - self.rules["k_layernorm"](attention.k_layernorm, layer_id) - self.rules["linear_qkv"](attention.linear_qkv, layer_id) - self.rules["linear_proj"](attention.linear_proj, layer_id) - if getattr(attention.core_attention, "softmax_offset", None) is not None: - self.rules["softmax_offset"]( - attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - layer_pbar.set_description("Importing MoE") - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - layer_pbar.set_description("Importing MoE shared experts") - fc1 = layer.mlp.shared_experts.linear_fc1 - fc2 = layer.mlp.shared_experts.linear_fc2 - self.rules["shared_experts.linear_fc1"](fc1, layer_id) - self.rules["shared_experts.linear_fc2"](fc2, layer_id) - if not self.rules.get("use_packed_local_experts", False): - for local_expert_id, expert in tqdm( - enumerate(layer.mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - expert_id = layer.mlp.local_expert_indices[local_expert_id] - fc1 = expert.linear_fc1 - fc2 = expert.linear_fc2 - self.rules["local_experts.linear_fc1"](fc1, layer_id, expert_id) - self.rules["local_experts.linear_fc2"](fc2, layer_id, expert_id) - # We only support either EP or ETP for now - elif get_expert_tensor_parallel_world_size() > 1: - # ETP supports for packed MoE - # ETP is not supported for gpt-oss model - if self.arch in ["GptOssForCausalLM"]: - raise ValueError("ETP is not supported for gpt-oss model") - self.rules["local_experts.linear_fc1_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_etp"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - # EP supports for packed MoE - self.rules["local_experts.linear_fc1_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2_ep"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - layer_pbar.set_description("Importing MLP") - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + self._import_transformer_layer(layer, layer_id, layer_pbar) if self.verbose: print( "{:3}/{:3} completes importing layer {:3}.".format( - torch.distributed.get_rank(), torch.distributed.get_world_size(), layer_id + dist.get_rank(), dist.get_world_size(), layer_id ), flush=True, ) @@ -595,67 +740,92 @@ def _import_state_dict(self): # Final layernorm if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: self.rules["final_layernorm"](model.decoder.final_layernorm) - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: self.rules["final_norm"](model.decoder.final_norm) # Output layer if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) + # MTP if hasattr(model, "mtp"): - # MTP is the last layer in DeepSeek V3/R1 - layer_id += 1 - for mtp in model.mtp: - self.rules["mtp.fc"](mtp.fc, layer_id) + layer_pbar.set_description("Importing MTP") + if len(model.mtp.layers) == 1: # Repeated MTP + layer_id = 0 # reset layer_id for repeated MTP + mtp = model.mtp.layers[0] + + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) self.rules["mtp.enorm"](mtp.enorm, layer_id) self.rules["mtp.hnorm"](mtp.hnorm, layer_id) - self.rules["mtp.input_layernorm"](mtp.decoder.layers[0].input_layernorm, layer_id) - if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): - self.rules["mtp.linear_q_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + + mtp_model_layers = mtp.mtp_model_layer.layers + for mtp_model_layer in mtp_model_layers: + if isinstance(mtp_model_layer, TransformerLayer): + self._import_transformer_layer( + mtp_model_layer, layer_id, layer_pbar, is_mtp=True + ) + else: + raise ValueError( + f"Unsupported layer type during MTP import: {type(mtp_model_layer)}.\n" + "Only TransformerLayer is supported." + ) + + layer_id += 1 + else: # non-repeated MTP + # MTP is the last layer in DeepSeek V3/R1 + layer_id += 1 + for mtp in model.mtp.layers: + self.rules["mtp.eh_proj"](mtp.eh_proj, layer_id) + self.rules["mtp.enorm"](mtp.enorm, layer_id) + self.rules["mtp.hnorm"](mtp.hnorm, layer_id) + self.rules["mtp.input_layernorm"]( + mtp.decoder.layers[0].input_layernorm, layer_id ) - else: - self.rules["mtp.linear_q_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + if hasattr(mtp.decoder.layers[0].self_attention, "linear_q_proj"): + self.rules["mtp.linear_q_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_proj, layer_id + ) + else: + self.rules["mtp.linear_q_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_down_proj, layer_id + ) + self.rules["mtp.linear_q_layernorm"]( + mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + ) + self.rules["mtp.linear_q_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + ) + self.rules["mtp.linear_kv_down_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id ) - self.rules["mtp.linear_q_layernorm"]( - mtp.decoder.layers[0].self_attention.q_layernorm, layer_id + self.rules["mtp.linear_kv_layernorm"]( + mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id ) - self.rules["mtp.linear_q_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_q_up_proj, layer_id + self.rules["mtp.linear_kv_up_proj"]( + mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id ) - self.rules["mtp.linear_kv_down_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_down_proj, layer_id - ) - self.rules["mtp.linear_kv_layernorm"]( - mtp.decoder.layers[0].self_attention.kv_layernorm, layer_id - ) - self.rules["mtp.linear_kv_up_proj"]( - mtp.decoder.layers[0].self_attention.linear_kv_up_proj, layer_id - ) - self.rules["mtp.linear_proj"]( - mtp.decoder.layers[0].self_attention.linear_proj, layer_id - ) - self.rules["mtp.pre_mlp_layernorm"]( - mtp.decoder.layers[0].pre_mlp_layernorm, layer_id - ) - self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) - self.rules["mtp.shared_experts.linear_fc1"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["mtp.shared_experts.linear_fc2"]( - mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in tqdm( - enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), - desc="Importing MoE local experts", - leave=False, - disable=self.disable_tqdm, - ): - self.rules["mtp.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id + self.rules["mtp.linear_proj"]( + mtp.decoder.layers[0].self_attention.linear_proj, layer_id + ) + self.rules["mtp.pre_mlp_layernorm"]( + mtp.decoder.layers[0].pre_mlp_layernorm, layer_id ) - self.rules["mtp.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id + self.rules["mtp.router"](mtp.decoder.layers[0].mlp.router, layer_id) + self.rules["mtp.shared_experts.linear_fc1"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc1, layer_id ) + self.rules["mtp.shared_experts.linear_fc2"]( + mtp.decoder.layers[0].mlp.shared_experts.linear_fc2, layer_id + ) + for expert_id, expert in tqdm( + enumerate(mtp.decoder.layers[0].mlp.experts.local_experts), + desc="Importing MoE local experts", + leave=False, + disable=self.disable_tqdm, + ): + self.rules["mtp.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["mtp.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 95b194c3ff..3f69271b06 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -72,7 +72,7 @@ class VllmFqGPTModelExporter(GPTModelExporter): def save_pretrained( self, save_directory: str | os.PathLike, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, ): os.makedirs(save_directory, exist_ok=True) gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) @@ -91,7 +91,7 @@ def _get_quantization_format(self, module: torch.nn.Module): def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), diff --git a/modelopt/torch/export/postprocess.py b/modelopt/torch/export/postprocess.py index 5c3d0fcf35..376a52a413 100644 --- a/modelopt/torch/export/postprocess.py +++ b/modelopt/torch/export/postprocess.py @@ -35,6 +35,7 @@ LINEAR_ROW, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, ConvConfig, EmbeddingConfig, ExpertConfig, @@ -398,7 +399,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None): group_size=config.awq_block_size, quantization=config.quantization, ) - if config.quantization == QUANTIZATION_NVFP4_AWQ: + if config.quantization in [ + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: # We have to update weight_scaling_factor and weight_scaling_factor_2 config.weights_scaling_factor, config.weights_scaling_factor_2 = ( NVFP4QTensor.get_weights_scaling_factor( @@ -430,6 +434,7 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None): if config.quantization in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, ]: ( config.weights_scaling_factor, diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eee13dc518..0d99d44f04 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -25,11 +25,16 @@ import torch.nn as nn from modelopt import __version__ -from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection +from modelopt.torch.quantization.model_calib import ( + enable_stats_collection, + finish_stats_collection, + svd, +) from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear from modelopt.torch.quantization.qtensor import ( FP8QTensor, MXFP4QTensor, + MXFP8QTensor, NVFP4QTensor, QTensorWrapper, ) @@ -54,9 +59,11 @@ QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO, QUANTIZATION_MXFP4, + QUANTIZATION_MXFP8, QUANTIZATION_NONE, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_W4A8_NVFP4_FP8, @@ -165,7 +172,7 @@ def resmooth_and_get_scale( ) new_weights.append(weight) # If NVFP4_AWQ then we view the scales as uint8 to allow for cat later - if quantization == QUANTIZATION_NVFP4_AWQ: + if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]: scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, group_size).view(torch.uint8) else: scale = get_scaling_factor_from_weight(weight, group_size) @@ -176,7 +183,7 @@ def resmooth_and_get_scale( return ( torch.cat(new_weights, dim=0), resmoothed_scales.view(torch.float8_e4m3fn) - if quantization == QUANTIZATION_NVFP4_AWQ + if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT] else resmoothed_scales, # if NVFP4_AWQ we view the scales back as float8_e4m3fn after cat new_pre_quant_scale, ) @@ -231,6 +238,31 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor: return scaling_factor +def _ensure_weight_quantizer_calibrated( + weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = "" +) -> None: + """Calibrate weight quantizer if amax is not set. + + This is a lazy calibration pattern used during export when weight quantizers + may not have been calibrated during the main calibration phase. + + Args: + weight_quantizer: The weight quantizer to calibrate + weight: The weight tensor to use for calibration + module_name: Optional module name for better warning messages + """ + if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None: + warn( + f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. " + f"Computing amax from weights. This may occur if: " + f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size" + ) + weight_quantizer.reset_amax() + enable_stats_collection(weight_quantizer) + weight_quantizer(weight) + finish_stats_collection(weight_quantizer) + + def get_activation_scaling_factor( module: nn.Module, input_quantizer_name: str = "input_quantizer" ) -> torch.Tensor: @@ -243,6 +275,7 @@ def get_activation_scaling_factor( if get_quantization_format(module) in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, ]: return NVFP4QTensor.get_activation_scaling_factor(input_quantizer) return get_scaling_factor(input_quantizer) @@ -270,8 +303,13 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_NVFP4_FP8, ]: + # Calibrate weight quantizer if amax is not set + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. # This is because the kernel dequantizes weight to fp8, which is in range 448. @@ -290,6 +328,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[ 1 ].reshape(*weight.shape[:-1], -1) + + if quantization_format == QUANTIZATION_MXFP8: + return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer) return get_scaling_factor(weight_quantizer) @@ -300,12 +341,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") if weight_quantizer is None: return None - if get_quantization_format(module) in [ + quantization_format = get_quantization_format(module) + + # Calibrate weight quantizer if amax is not set for all NVFP4 variants + if quantization_format in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A8_NVFP4_FP8, + ]: + weight = getattr(module, weight_name) + module_name = f"{type(module).__name__}.{weight_name}" + _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name) + + if quantization_format in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, ]: return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer) - elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8: + elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8: # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6. # This is because the kernel dequantizes weight to fp8, which is in range 448. return weight_quantizer._amax.float() / 448.0 @@ -345,18 +400,27 @@ def get_kv_cache_bias(kv_module: nn.Module) -> list[torch.Tensor]: return kv_bias -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - if not hasattr(kv_module, "k_bmm_quantizer") or not hasattr(kv_module, "v_bmm_quantizer"): +def get_kv_cache_scaling_factor(self_attention_module: nn.Module) -> list[torch.Tensor | None]: + """Get the K and V BMM scaling factors for the self attention module. + + Args: + self_attention_module: The self attention module to get the K and V BMM scaling factors from. + + Returns: + The K and V BMM scaling factors. + """ + if not hasattr(self_attention_module, "k_bmm_quantizer") or not hasattr( + self_attention_module, "v_bmm_quantizer" + ): return [None, None] scaling_factors = [ - get_scaling_factor(getattr(kv_module, quantizer)) + get_scaling_factor(getattr(self_attention_module, quantizer)) for quantizer in ("k_bmm_quantizer", "v_bmm_quantizer") ] # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: + if get_kv_cache_dtype(self_attention_module) == KV_CACHE_FP8: for i, factor in enumerate(scaling_factors): if factor.item() > 0.5: warn( @@ -366,7 +430,6 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> list[torch.Tensor]: scaling_factors[i] = torch.max( factor, torch.tensor([1.0], dtype=torch.float, device=factor.device) ) - return scaling_factors @@ -397,6 +460,23 @@ def get_kv_cache_dtype(modules: list[nn.Module] | nn.Module) -> str | None: num_bits_list.append(quantizer_attr.num_bits) is_affine &= hasattr(quantizer_attr, "_bias_value") + return _compute_kv_cache_dtype(num_bits_list) + + +def _compute_kv_cache_dtype(num_bits_list: list[int | tuple[int, int]]) -> str | None: + """Returns the kv_cache dtype. + + If num_bits of output_quantizer is (4, 3) then returns FP8; if it is 8, returns int8, + otherwise returns None. + + Args: + modules: The module or list of modules to inspect. + + Returns: + The kv_cache dtype. + """ + is_affine = True + if (4, 3) in num_bits_list: return KV_CACHE_FP8 elif 8 in num_bits_list: @@ -474,6 +554,14 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames if weight_quantizer.num_bits == (4, 3): if weight_quantizer.block_sizes: assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer" + # Check if this is MXFP8 (dynamic block quantization with scale_bits (8, 0)) + block_sizes = getattr(weight_quantizer, "block_sizes") + if ( + isinstance(block_sizes, dict) + and block_sizes.get("type", "static") == "dynamic" + and block_sizes.get("scale_bits") == (8, 0) + ): + return QUANTIZATION_MXFP8 if weight_quantizer.fake_quant: return QUANTIZATION_FP8_PB_WO else: @@ -487,6 +575,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames block_sizes = getattr(weight_quantizer, "block_sizes") scale_bits = block_sizes.get("scale_bits") + if input_quantizer is not None and hasattr(weight_quantizer, "svdquant_lora_a"): + return QUANTIZATION_NVFP4_SVDQUANT if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"): return QUANTIZATION_NVFP4_AWQ if getattr(layer, "fused_with_prequant", False): @@ -660,15 +750,23 @@ def process_layer_quant_config(layer_config_dict): elif v == "w4a8_nvfp4_fp8": layer_config = { "quant_algo": "W4A8_NVFP4_FP8", - "group_size": layer_config_dict[prefix + ".awq_block_size"], - "has_zero_point": False, - "pre_quant_scale": True, + "group_size": block_size_value, } elif v == "w4a8_mxfp4_fp8": layer_config = { "quant_algo": "W4A8_MXFP4_FP8", "group_size": block_size_value, } + elif v == "nvfp4_svdquant": + layer_config = { + "quant_algo": "NVFP4_SVD", + "group_size": block_size_value, + } + elif v == "mxfp8": + layer_config = { + "quant_algo": "MXFP8", + "group_size": block_size_value, + } else: layer_config = {"quant_algo": v} @@ -773,6 +871,9 @@ def to_quantized_weight( if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]: return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8) + if quantization == QUANTIZATION_MXFP8: + return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor) + if quantization == QUANTIZATION_FP8_PB_WO: return FP8QTensor.quantize( weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size} @@ -813,7 +914,12 @@ def to_quantized_weight( if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]: return pack_int4_in_uint8(weight, weights_scaling_factor) - if quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_W4A8_NVFP4_FP8]: + if quantization in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_W4A8_NVFP4_FP8, + QUANTIZATION_NVFP4_SVDQUANT, + ]: assert block_size is not None, "Block size not passed. Unable to quantize to NVFP4 format." assert weights_scaling_factor2 is not None, ( "Weights scaling factor 2 not passed. Unable to quantize to NVFP4 format" @@ -881,7 +987,13 @@ def postprocess_state_dict( "v_bmm_quantizer._bias_value": "v_proj.v_bias", "input_quantizer._pre_quant_scale": "pre_quant_scale", } - skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"] + skip_keys = [ + "output_quantizer", + "_amax", + "_bias_value", + "input_quantizer._pre_quant_scale", + "weight_shape", + ] # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment if is_modelopt_qlora: @@ -931,7 +1043,6 @@ def postprocess_state_dict( # We export real value for KV_CACHE_NVFP4 if quantization == KV_CACHE_FP8: value.clamp_(min=1.0) - post_state_dict[prefix + new_suffix] = value break @@ -1008,6 +1119,40 @@ def _update_pre_quant_scale(module, new_pre_quant_scale): finish_stats_collection(module.weight_quantizer) +def _update_svdquant(modules, new_pre_quant_scale): + """Updates the pre_quant_scale, svdquant_lora_a and svdquant_lora_b matrices when pre_quant_scale is changed.""" + new_pre_quant_scale = new_pre_quant_scale.to(torch.float32) + lora_a = [m.weight_quantizer.svdquant_lora_a.to(torch.float32) for m in modules] + lora_b = [m.weight_quantizer.svdquant_lora_b.to(torch.float32) for m in modules] + weight = [m.weight.to(torch.float32) for m in modules] + old_pre_quant_scale = [m.input_quantizer._pre_quant_scale.to(torch.float32) for m in modules] + weight = [ + (w + (lb @ la)) * (s / new_pre_quant_scale) + for w, la, lb, s in zip(weight, lora_a, lora_b, old_pre_quant_scale) + ] + weight_concatenated = torch.cat(weight, dim=0) + lb, la = svd(weight_concatenated, rank=lora_a[0].shape[0]) + weight_concatenated -= lb @ la + weight_concatenated = weight_concatenated.to(modules[0].weight.dtype) + la = la.to(modules[0].weight_quantizer.svdquant_lora_a.dtype) + lb = lb.to(modules[0].weight_quantizer.svdquant_lora_b.dtype) + new_pre_quant_scale = new_pre_quant_scale.to(modules[0].input_quantizer.pre_quant_scale.dtype) + + index = 0 + for i, module in enumerate(modules): + module.input_quantizer.pre_quant_scale = new_pre_quant_scale + module.weight_quantizer.svdquant_lora_a = la + assert lora_b[i].shape[0] == module.weight.shape[0] + module.weight_quantizer.svdquant_lora_b = lb[index : index + lora_b[i].shape[0], :] + module.weight = nn.Parameter(weight_concatenated[index : index + lora_b[i].shape[0], :]) + index += lora_b[i].shape[0] + # Redo weights collection + module.weight_quantizer.reset_amax() + enable_stats_collection(module.weight_quantizer) + module.weight_quantizer(module.weight) + finish_stats_collection(module.weight_quantizer) + + # Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale)) PQS_FUSE_MODULE_MAPPING = [ # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension @@ -1101,6 +1246,16 @@ def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False): setattr(linear_pqs_from, "fused_with_prequant", True) +def _layernorm_uses_weight_plus_one(module: torch.nn.Module) -> bool: + if any( + name in type(module).__name__ + for name in ["LayerNorm1P", "GemmaRMSNorm", "Gemma2RMSNorm", "Gemma3RMSNorm"] + ): + return True + + return bool(hasattr(module, "zero_centered_gamma") and module.zero_centered_gamma) + + def fuse_prequant_layernorm( layernorm_module: torch.nn.Module, modules: list[torch.Tensor], @@ -1116,13 +1271,17 @@ def fuse_prequant_layernorm( fused_bias = bias * avg_pre_quant_scale layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias """ - layernorm_module.weight = torch.nn.Parameter( - layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale") + pre_quant_scale = getattr(modules[0].input_quantizer, "_pre_quant_scale").to( + layernorm_module.weight.device ) + if _layernorm_uses_weight_plus_one(layernorm_module): + # For norms that use (1 + weight) in forward, fold pre_quant_scale into the effective weight. + fused_weight = (layernorm_module.weight + 1.0) * pre_quant_scale - 1.0 + else: + fused_weight = layernorm_module.weight * pre_quant_scale + layernorm_module.weight = torch.nn.Parameter(fused_weight.to(layernorm_module.weight.dtype)) if hasattr(layernorm_module, "bias") and layernorm_module.bias is not None: - layernorm_module.bias = torch.nn.Parameter( - layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale") - ) + layernorm_module.bias = torch.nn.Parameter(layernorm_module.bias * pre_quant_scale) # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm for module in modules: delattr(module.input_quantizer, "_pre_quant_scale") @@ -1146,9 +1305,14 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False dim=0, ) - for module in modules: - if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale): - _update_pre_quant_scale(module, avg_prequant_scale) + if all( + getattr(m.weight_quantizer, "svdquant_lora_a", None) is not None for m in modules + ): + _update_svdquant(modules, avg_prequant_scale) + else: + for module in modules: + if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale): + _update_pre_quant_scale(module, avg_prequant_scale) if resmooth_only: return @@ -1286,3 +1450,18 @@ def get_quant_config( quant_config["quantization"]["kv_cache_quant_algo"] = kv_cache_format return quant_config + + +def has_quantized_modules(model: nn.Module) -> bool: + """Check if a model has any quantized modules. + + Args: + model: The model to check. + + Returns: + True if the model contains quantized modules, False otherwise. + """ + return any( + get_quantization_format(sub_module) != QUANTIZATION_NONE + for _, sub_module in model.named_modules() + ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ccfc012001..ca80cb450d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -22,19 +22,43 @@ import warnings from builtins import ValueError from collections import defaultdict +from collections.abc import Callable from pathlib import Path from typing import Any import torch import torch.nn as nn from safetensors.torch import save_file + +try: + import diffusers + + from .diffusers_utils import ( + generate_diffusion_dummy_forward_fn, + get_diffusion_components, + get_qkv_group_key, + hide_quantizers_from_state_dict, + infer_dtype_from_model, + is_diffusers_object, + is_qkv_projection, + ) + + HAS_DIFFUSERS = True +except ImportError: + HAS_DIFFUSERS = False + from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer -from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names +try: + from modelopt.torch.sparsity.attention_sparsity.conversion import export_sparse_attention_config +except ImportError: + export_sparse_attention_config = None + from .convert_hf_config import convert_hf_quant_config_format from .layer_utils import ( get_expert_linear_names, @@ -45,15 +69,14 @@ set_expert_quantizer_amax, ) from .model_config import ( - KV_CACHE_FP8, - KV_CACHE_NVFP4, - KV_CACHE_NVFP4_AFFINE, QUANTIZATION_FP8, QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PC_PT, + QUANTIZATION_MXFP8, QUANTIZATION_NONE, QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, ) @@ -68,6 +91,7 @@ get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, + has_quantized_modules, maybe_transpose_expert_weight_dimensions, postprocess_state_dict, preprocess_linear_fusion, @@ -87,57 +111,201 @@ def _is_enabled_quantizer(quantizer): return False -def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): - """Group modules that take the same input and register shared parameters in module.""" - # TODO: Handle DBRX MoE - input_to_linear = defaultdict(list) - output_to_layernorm = defaultdict(None) - quantization_format = get_quantization_format(model) +def _save_component_state_dict_safetensors( + component: nn.Module, component_export_dir: Path +) -> None: + cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()} + save_file(cpu_state_dict, str(component_export_dir / "model.safetensors")) + with open(component_export_dir / "config.json", "w") as f: + json.dump( + { + "_class_name": type(component).__name__, + "_export_format": "safetensors_state_dict", + }, + f, + indent=4, + ) + + +def _collect_shared_input_modules( + model: nn.Module, + dummy_forward_fn: Callable[[], None], + collect_layernorms: bool = False, +) -> tuple[dict, dict | None]: + """Collect modules that share the same input using forward hooks. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model to analyze. + dummy_forward_fn: A callable that runs a dummy forward pass on the model. + Should be a function that takes no arguments. + collect_layernorms: If True, also collect layernorm output mappings (for AWQ). + + Returns: + A tuple of (input_to_linear, output_to_layernorm). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (or None). + """ + input_to_linear: dict = defaultdict(list) + output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None def _input_hook(module, input, output): """Update dictionary with list of all modules that share the same input.""" - # TODO: Handle DBRX MoE case - input_to_linear[input[0]].append(module) + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) def _output_hook(module, input, output): """Update dictionary with mapping of layernorms and their outputs.""" - output_to_layernorm[output] = module + if output_to_layernorm is not None and isinstance(output, torch.Tensor): + output_to_layernorm[output] = module handles = [] - model_type = type(model).__name__.lower() + # Register hooks on all quantized linear modules (and optionally layernorms) + for name, module in model.named_modules(): + if collect_layernorms and is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + return input_to_linear, output_to_layernorm + + # Run dummy forward pass to collect modules sharing same input + try: + with torch.no_grad(), set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + dummy_forward_fn() + finally: + # Always remove hooks + for handle in handles: + handle.remove() + + return input_to_linear, output_to_layernorm + + +def _fuse_shared_input_modules( + model: nn.Module, + input_to_linear: dict, + output_to_layernorm: dict | None = None, + qkv_only: bool = False, + fuse_layernorms: bool = False, + quantization_format: str | None = None, +) -> dict[str, list[str]]: + """Fuse modules that share the same input. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model being processed (for FSDP-aware updates). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (optional). + qkv_only: If True, only fuse QKV projection layers (for diffusion models). + fuse_layernorms: If True, also fuse layernorms with pre_quant_scale (for AWQ). + quantization_format: The quantization format of the model. + + Returns: + Dict mapping first module name to list of all fused module names. + """ fused_linears = {} + fused_count = 0 + + for tensor, modules in input_to_linear.items(): + # Get quantization format for this group of modules + # (must be re-evaluated per group as different modules may have different formats) + group_quant_format = get_quantization_format(modules[0]) if modules else quantization_format + + if len(modules) > 1 and group_quant_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + if qkv_only: + # Filter to only include QKV projection layers (diffusion models) + qkv_modules = [m for m in modules if is_qkv_projection(getattr(m, "name", ""))] + + if len(qkv_modules) > 1: + # Group QKV modules by their parent attention block + qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) + for m in qkv_modules: + group_key = get_qkv_group_key(getattr(m, "name", "")) + qkv_groups[group_key].append(m) + + # Fuse each group separately + for group_key, group_modules in qkv_groups.items(): + if len(group_modules) >= 2: + preprocess_linear_fusion(group_modules, resmooth_only=False) + fused_count += 1 + module_names = [getattr(m, "name", "unknown") for m in group_modules] + print(f" Fused QKV group: {module_names}") + else: + # Fuse all modules that have the same input (LLM models) + with fsdp2_aware_weight_update(model, modules): + preprocess_linear_fusion(modules) + fused_linears[modules[0].name] = [module.name for module in modules] + fused_count += 1 + + # Fuse layernorms (for AWQ) + if ( + fuse_layernorms + and output_to_layernorm is not None + and group_quant_format is not None + and group_quant_format != QUANTIZATION_NONE + and "awq" in group_quant_format + and tensor in output_to_layernorm + ): + with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): + fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + + if qkv_only: + if fused_count > 0: + print(f"Fused {fused_count} QKV group(s) for unified amax values.") + else: + print("No QKV groups found to fuse.") + + return fused_linears + + +def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): + """Group modules that take the same input and register shared parameters in module.""" + # TODO: Handle DBRX MoE + quantization_format = get_quantization_format(model) + model_type = type(model).__name__.lower() module_names = set() + # NVFP4 SVDQuant does not need pre-quant scale fusion (either into previous linear or layernorm) because + # 1) its kernel handles pre-quant scale. + # 2) fusing into previous linear will need to change the lora_up in up_proj which may cause issue in + # the later gate up fusion. # Fuse pre_quant_scale to the linear weights if possible if quantization_format is not None and "nvfp4_awq" in quantization_format.lower(): fuse_prequant_to_linear(model) + # Pre-process MoE experts for name, module in model.named_modules(): module_names.add(name) # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts - if is_moe(module) and ("awq" in quantization_format): + if is_moe(module) and ( + quantization_format is not QUANTIZATION_NONE + and ("awq" in quantization_format or quantization_format == QUANTIZATION_NVFP4_SVDQUANT) + ): # update_experts_avg_prequant_scale(module) grouped_experts = get_experts_list(module, model_type) for modules in grouped_experts: with fsdp2_aware_weight_update(model, modules): preprocess_linear_fusion(modules, resmooth_only=True) - # Attach hook to layernorm modules that need to be fused - if is_layernorm(module): - module.name = name - handle = module.register_forward_hook(_output_hook) - handles.append(handle) - elif is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) - - with torch.no_grad(): + # Define the dummy forward function for LLM + def llm_dummy_forward(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input @@ -153,57 +321,42 @@ def _output_hook(module, input, output): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - # Run forward pass so that all modules sharing the same input are collected using forward hook. - - with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): - if getattr(model.config, "is_encoder_decoder", False): - # For encoder-decoder models, we need to pass both the encoder and decoder input ids - model(fake_input, decoder_input_ids=decoder_fake_input) - elif is_vl_model and "nemotron" in model_type: - # For Nemotron VL models, try to run optimization on just the language model part - language_model_lineage = get_language_model_from_vl(model) - - if language_model_lineage is not None: - # Run optimization on just the language model with the same input format as regular LLMs - # Use the same fake_input tensor that regular LLMs use - language_model = language_model_lineage[-1] - print( - f"Running optimization on language model with fake_input shape: {fake_input.shape}" - ) - language_model(fake_input) - else: - raise ValueError( - f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " - "This is required for requantization/resmoothing optimization. " - "Please ensure the model architecture is supported or file an issue." - ) - else: - model(fake_input) + if is_vl_model and "nemotron" in model_type: + # For Nemotron VL models, run optimization on just the language model/decoder. + # This avoids needing pixel_values for the vision encoder. + language_model_lineage = get_language_model_from_vl(model) - for handle in handles: - handle.remove() + if language_model_lineage is not None: + language_model = language_model_lineage[-1] + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + # Pass use_cache=False to avoid KV cache issues in encoder-decoder models + language_model(fake_input, use_cache=False) + else: + raise ValueError( + f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " + "This is required for requantization/resmoothing optimization. " + "Please ensure the model architecture is supported or file an issue." + ) + elif getattr(model.config, "is_encoder_decoder", False): + # For other encoder-decoder models (non-VL), pass both encoder and decoder input ids + model(fake_input, decoder_input_ids=decoder_fake_input) + else: + model(fake_input) - for tensor, modules in input_to_linear.items(): - quantization_format = get_quantization_format(modules[0]) - if len(modules) > 1 and quantization_format not in [ - QUANTIZATION_FP8, - QUANTIZATION_NONE, - QUANTIZATION_FP8_PB_REAL, - ]: - # Fuse modules that have the same input - with fsdp2_aware_weight_update(model, modules): - preprocess_linear_fusion(modules) - fused_linears[modules[0].name] = [module.name for module in modules] + input_to_linear, output_to_layernorm = _collect_shared_input_modules( + model, llm_dummy_forward, collect_layernorms=True + ) - # Fuse layernorms - if ( - quantization_format is not QUANTIZATION_NONE - and "awq" in quantization_format - and tensor in output_to_layernorm - ): - # Pre quant scale of modules is already updated to avg_pre_quant_scale - with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): - fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + fused_linears = _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm, + qkv_only=False, + fuse_layernorms=True, + quantization_format=quantization_format, + ) # The dummy forward may not be able to activate all the experts. # Process experts by naming rules like experts.0, experts.1, etc. @@ -295,6 +448,15 @@ def _export_quantized_weight( weight_quantizer._scale.to(torch.float32), ) del weight_quantizer._scale + elif quantization_format == QUANTIZATION_MXFP8: + # MXFP8 uses dynamic block quantization with E8M0 scales (uint8) + weight = getattr(sub_module, weight_name) + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer( + weight, weight_quantizer + ) + sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale) + if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: + del weight_quantizer._scale else: sub_module.register_buffer( quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name) @@ -314,6 +476,7 @@ def _export_quantized_weight( if quantization_format in [ QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_NVFP4, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, @@ -334,7 +497,11 @@ def _export_quantized_weight( for expert_type in ["Llama4TextExperts", "GptOssExperts"] ) - if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]: + if quantization_format in [ + QUANTIZATION_NVFP4, + QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_NVFP4_SVDQUANT, + ]: # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim) # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization weight, _ = maybe_transpose_expert_weight_dimensions( @@ -390,8 +557,73 @@ def _export_quantized_weight( if weight_scale is not None: sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) + torch.cuda.empty_cache() + + +def _process_quantized_modules( + model: nn.Module, + dtype: torch.dtype, + is_modelopt_qlora: bool = False, +) -> None: + """Process all quantized modules in model, export weights in-place. + + This function iterates through all modules in the model and exports quantized weights + for modules that have quantization enabled. It handles both standard linear layers + and specialized expert modules (Llama4TextExperts, GptOssExperts). + + Args: + model: The model containing quantized modules. + dtype: The data type for weight conversion. + is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model. + If True, modules with base_layer attribute are skipped. + """ + fsdp_module_to_reshard = None + + for _, sub_module in model.named_modules(): + # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead + if isinstance(sub_module, FSDPModule): + # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. + # We need to reshard the previous FSDPModule to prevent potential OOM. + # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. + if fsdp_module_to_reshard is not None: + fsdp_module_to_reshard.reshard() + + fsdp_module_to_reshard = sub_module + + # We skip QuantLoraLinear module for modelopt QLoRA + if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): + continue -def _export_hf_checkpoint( + if hasattr(sub_module, "weight_packed") or ( + "QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1 + ): + sub_module.unpack_weight() + if get_quantization_format(sub_module) != QUANTIZATION_NONE: + if is_quantlinear(sub_module): + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + elif ( + "Llama4TextExperts" in type(sub_module).__name__ + or "GptOssExperts" in type(sub_module).__name__ + ): + # TODO: consolidate uncalibrated experts handling logic + # Handle weight quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], + ) + # Handle input quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], + ) + # Export the quantized weights + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + for weight_name in ["gate_up_proj", "down_proj"]: + _export_quantized_weight(sub_module, dtype, weight_name) + + +def _export_transformers_checkpoint( model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -485,93 +717,305 @@ def _export_hf_checkpoint( quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora) - kv_cache_max_bound = 0 + # Add MTP layer prefixes to exclude_modules if they were excluded from quantization + # This ensures they appear in quantization_config["ignore"] in config.json + mtp_layer_prefixes = getattr(model, "_mtp_layer_prefixes", None) + if mtp_layer_prefixes: + exclude_modules = quant_config["quantization"].setdefault("exclude_modules", []) + for prefix in mtp_layer_prefixes: + # Add wildcard pattern to exclude all submodules under this MTP layer + pattern = f"{prefix}*" + if pattern not in exclude_modules: + exclude_modules.append(pattern) + print(f"Adding MTP layer to quantization_config ignore: {pattern}") + + # Process all quantized modules and export weights + _process_quantized_modules(model, dtype, is_modelopt_qlora) + + if accelerator is not None: + # Gather state_dict from all ranks + quantized_state_dict = accelerator.get_state_dict(model) + else: + quantized_state_dict = model.state_dict() + + # We define kv cache scale as amax / 448 for both FP8 and NVFP4 KV cache quantization. + kv_cache_max_bound = 448 kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"] + quantized_state_dict = postprocess_state_dict( + quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora + ) - cache_bound_mapping = { - KV_CACHE_NVFP4: 6 * 448, - KV_CACHE_NVFP4_AFFINE: 6 * 448, - KV_CACHE_FP8: 448, + return quantized_state_dict, quant_config + + +def _fuse_qkv_linears_diffusion( + model: nn.Module, dummy_forward_fn: Callable[[], None] | None = None +) -> None: + """Fuse QKV linear layers that share the same input for diffusion models. + + This function uses forward hooks to dynamically identify linear modules that + share the same input tensor (e.g., q_proj, k_proj, v_proj in attention). + For these modules, it unifies their input and weight amax values. + + Note: This is a simplified version for diffusion models that: + - Handles QKV fusion (shared input detection) + - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) + - Skips pre_quant_scale handling (TODO for future) + - Skips FFN fusion with layernorm (TODO for future) + + Args: + model: The diffusion model component (e.g., transformer, unet). + dummy_forward_fn: Optional callable to run a dummy forward pass. Use this + for diffusion-like models whose forward signature is not compatible + with `generate_diffusion_dummy_inputs`. + """ + quantization_format = get_quantization_format(model) + + if quantization_format == QUANTIZATION_NONE: + return + + if dummy_forward_fn is None: + dummy_forward_fn = generate_diffusion_dummy_forward_fn(model) + + # Collect modules sharing the same input + try: + input_to_linear, _ = _collect_shared_input_modules( + model, dummy_forward_fn, collect_layernorms=False + ) + except Exception as e: + print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") + print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") + return + + if not input_to_linear: + print("No quantized linear modules found for QKV fusion.") + return + + # Fuse the collected modules (QKV only for diffusion) + _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm=None, + qkv_only=True, + fuse_layernorms=False, + quantization_format=quantization_format, + ) + + +def _export_diffusers_checkpoint( + pipe: Any, + dtype: torch.dtype | None, + export_dir: Path, + components: list[str] | None, + max_shard_size: int | str = "10GB", +) -> None: + """Internal: Export diffusion(-like) model/pipeline checkpoint. + + This function handles the export of: + - diffusers models: DiffusionPipeline and individual ModelMixin components. + - LTX-2 pipelines (duck-typed): exports stage-1 transformer only. + + Args: + pipe: The model or pipeline to export. + dtype: The data type for weight conversion. If None, will be inferred from model. + export_dir: The directory to save the exported checkpoint. + components: Optional list of component names to export. Only used for pipelines. + If None, all components are exported. + max_shard_size: Maximum size of each shard file. If the model exceeds this size, + it will be sharded into multiple files and a .safetensors.index.json will be + created. Use smaller values like "5GB" or "2GB" to force sharding. + """ + export_dir = Path(export_dir) + + # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) + all_components = get_diffusion_components(pipe, components) + + if not all_components: + warnings.warn("No exportable components found in the model.") + return + + # Separate nn.Module components for quantization-aware export + module_components = { + name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) } - # Only update kv_cache_max_bound if a quantization is applied. - if kv_cache_format != QUANTIZATION_NONE: - kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format) + # Best-effort diffusers pipeline check (kept for folder layout + model_index.json behavior) + is_diffusers_pipe = False + if HAS_DIFFUSERS: + try: + from diffusers import DiffusionPipeline as _DiffusionPipeline + + is_diffusers_pipe = isinstance(pipe, _DiffusionPipeline) + except Exception: + is_diffusers_pipe = False + + # Step 3: Export each nn.Module component with quantization handling + for component_name, component in module_components.items(): + is_quantized = has_quantized_modules(component) + status = "quantized" if is_quantized else "non-quantized" + print(f"Exporting component: {component_name} ({status})") + + # Determine component export directory + # For pipelines, each component goes in a subfolder + if is_diffusers_pipe: + component_export_dir = export_dir / component_name + else: + component_export_dir = export_dir - # Track if any layers are quantized to properly set exclude_modules - fsdp_module_to_reshard = None + component_export_dir.mkdir(parents=True, exist_ok=True) - for _, sub_module in model.named_modules(): - # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead - if isinstance(sub_module, FSDPModule): - # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. - # We need to reshard the previous FSDPModule to prevent potential OOM. - # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. - if fsdp_module_to_reshard is not None: - fsdp_module_to_reshard.reshard() + # Infer dtype if not provided + component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) - fsdp_module_to_reshard = sub_module + if is_quantized: + # Step 3.5: Fuse QKV linears that share the same input (unify amax values) + # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion + # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization + print(f" Running QKV fusion for {component_name}...") + _fuse_qkv_linears_diffusion(component) - # We skip QuantLoraLinear module for modelopt QLoRA - if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): - continue + # Step 4: Process quantized modules (convert weights, register scales) + _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) - if get_quantization_format(sub_module) != QUANTIZATION_NONE: - if is_quantlinear(sub_module): - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) - elif ( - "Llama4TextExperts" in type(sub_module).__name__ - or "GptOssExperts" in type(sub_module).__name__ - ): - # TODO: consolidate uncalibrated experts handling logic - # Handle weight quantizers amax values using smart fallback logic - set_expert_quantizer_amax( - modules=sub_module, - quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], - ) - # Handle input quantizers amax values using smart fallback logic - set_expert_quantizer_amax( - modules=sub_module, - quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], + # Step 5: Build quantization config + quant_config = get_quant_config(component, is_modelopt_qlora=False) + + # Step 6: Save the component + # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter + # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save + if hasattr(component, "save_pretrained"): + with hide_quantizers_from_state_dict(component): + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + else: + with hide_quantizers_from_state_dict(component): + _save_component_state_dict_safetensors(component, component_export_dir) + + # Step 7: Update config.json with quantization info + if quant_config is not None: + hf_quant_config = convert_hf_quant_config_format(quant_config) + + config_path = component_export_dir / "config.json" + if config_path.exists(): + with open(config_path) as file: + config_data = json.load(file) + config_data["quantization_config"] = hf_quant_config + with open(config_path, "w") as file: + json.dump(config_data, file, indent=4) + # Non-quantized component: just save as-is + elif hasattr(component, "save_pretrained"): + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + else: + _save_component_state_dict_safetensors(component, component_export_dir) + + print(f" Saved to: {component_export_dir}") + + # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) + if is_diffusers_pipe: + for component_name, component in all_components.items(): + # Skip nn.Module components (already handled above) + if isinstance(component, nn.Module): + continue + + component_export_dir = export_dir / component_name + component_export_dir.mkdir(parents=True, exist_ok=True) + + print(f"Exporting component: {component_name} ({type(component).__name__})") + + # Handle different component types + if hasattr(component, "save_pretrained"): + # Tokenizers, feature extractors, image processors + component.save_pretrained(component_export_dir) + elif hasattr(component, "save_config"): + # Schedulers + component.save_config(component_export_dir) + else: + warnings.warn( + f"Component '{component_name}' of type {type(component).__name__} " + "does not have save_pretrained or save_config method. Skipping." ) - # Export the quantized weights - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - for weight_name in ["gate_up_proj", "down_proj"]: - _export_quantized_weight(sub_module, dtype, weight_name) + continue - if accelerator is not None: - # Gather state_dict from all ranks - quantized_state_dict = accelerator.get_state_dict(model) - else: - quantized_state_dict = model.state_dict() + print(f" Saved to: {component_export_dir}") - quantized_state_dict = postprocess_state_dict( - quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora - ) + # Step 5: For pipelines, also save model_index.json + if is_diffusers_pipe: + model_index_path = export_dir / "model_index.json" + is_partial_export = components is not None - return quantized_state_dict, quant_config + # For full export, preserve original model_index.json when possible. + # For partial export, skip this to avoid listing non-exported components. + if not is_partial_export: + source_path = getattr(pipe, "name_or_path", None) or getattr( + getattr(pipe, "config", None), "_name_or_path", None + ) + if source_path: + candidate_model_index = Path(source_path) / "model_index.json" + if candidate_model_index.exists(): + with open(candidate_model_index) as file: + model_index = json.load(file) + with open(model_index_path, "w") as file: + json.dump(model_index, file, indent=4) + + # Full-export fallback to Diffusers-native config serialization. + # Partial export skips this for the same reason as above. + if not is_partial_export and not model_index_path.exists() and hasattr(pipe, "save_config"): + pipe.save_config(export_dir) + + # Last resort: synthesize a minimal model_index.json from exported components. + if not model_index_path.exists() and hasattr(pipe, "config") and pipe.config is not None: + model_index = { + "_class_name": type(pipe).__name__, + "_diffusers_version": diffusers.__version__, + } + for name, comp in all_components.items(): + module = type(comp).__module__ + library = module.split(".")[0] + model_index[name] = [library, type(comp).__name__] + + with open(model_index_path, "w") as file: + json.dump(model_index, file, indent=4) + + print(f"Export complete. Saved to: {export_dir}") def export_hf_checkpoint( - model: nn.Module, + model: Any, dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, + components: list[str] | None = None, + extra_state_dict: dict[str, torch.Tensor] | None = None, ): - """Exports the torch model to unified checkpoint and saves to export_dir. + """Export quantized HuggingFace model checkpoint (transformers or diffusers). + + This function automatically detects whether the model is from transformers + or diffusers and applies the appropriate export logic. Args: - model: the full torch model to export. The actual quantized model may be a submodule. - dtype: the weights data type to export the unquantized layers or the default model data type if None. - export_dir: the target export path. - save_modelopt_state: whether to save the modelopt state_dict. + model: The full torch model to export. The actual quantized model may be a submodule. + Supports both transformers models (e.g., LlamaForCausalLM) and diffusers + models/pipelines (e.g., StableDiffusionPipeline, UNet2DConditionModel). + dtype: The weights data type to export the unquantized layers or the default + model data type if None. + export_dir: The target export path. + save_modelopt_state: Whether to save the modelopt state_dict. + components: Only used for diffusers pipelines. Optional list of component names + to export. If None, all quantized components are exported. + extra_state_dict: Extra state dictionary to add to the exported model. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) + is_diffusers_obj = False + if HAS_DIFFUSERS: + is_diffusers_obj = is_diffusers_object(model) + if is_diffusers_obj: + _export_diffusers_checkpoint(model, dtype, export_dir, components) + return + + # Transformers model export # NOTE: (hg) Early exit for speculative decoding models - # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint + # This is a temp workaround to avoid error with offline spec ckpt during export if spec_opt_only(model): save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors") with open(f"{export_dir}/config.json", "w") as file: @@ -579,18 +1023,24 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype) + post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) if hf_quant_config is not None: - # Save hf_quant_config.json for\ backward compatibility + # Save hf_quant_config.json for backward compatibility with open(f"{export_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + # Remove hf_quantizer from model so post_state_dict can be exported. + if getattr(model, "hf_quantizer", None) is not None: + model.hf_quantizer = None + # Save model model.save_pretrained( - export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state + export_dir, + state_dict={**post_state_dict, **(extra_state_dict or {})}, + save_modelopt_state=save_modelopt_state, ) original_config = f"{export_dir}/config.json" @@ -602,6 +1052,12 @@ def export_hf_checkpoint( if hf_quant_config is not None: config_data["quantization_config"] = hf_quant_config + # Add sparse attention config if available + if export_sparse_attention_config is not None: + sparse_attn_config = export_sparse_attention_config(model) + if sparse_attn_config is not None: + config_data["sparse_attention_config"] = sparse_attn_config + with open(original_config, "w") as file: json.dump(config_data, file, indent=4) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index f1bd673277..0567d0d1f2 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -20,23 +20,20 @@ import json import os -import shutil import tempfile from collections import OrderedDict from pathlib import Path from typing import Any -from warnings import warn import torch import torch.distributed -import torch.nn as nn -from huggingface_hub import hf_hub_download, snapshot_download -from safetensors.torch import safe_open, save_file -from tqdm import tqdm +from huggingface_hub import hf_hub_download +from safetensors.torch import save_file from modelopt import __version__ from modelopt.torch.utils import import_plugin +from .convert_hf_config import convert_hf_quant_config_format from .model_config import ( KV_CACHE_FP8, KV_CACHE_NVFP4, @@ -46,14 +43,19 @@ QUANTIZATION_NONE, QUANTIZATION_NVFP4, ) +from .plugins.hf_checkpoint_utils import copy_remote_code, load_multimodal_components from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, save_safetensors +from .plugins.mcore_custom import ( + CustomModuleMapping, + get_safetensor, + save_safetensors_by_layer_index, +) from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, get_kv_cache_dtype, + get_kv_cache_scaling_factor, get_quantization_format, - get_scaling_factor, get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, @@ -86,33 +88,6 @@ ] -# This path uses output_quantizer for KV cache quantization. -# The function below is the old version of get_kv_cache_scaling_factor which is now refactored to handle bmm_quantizer. -def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: - """Returns the kv_cache scaling factor if output quantizer is set. Else returns None by default.""" - scaling_factor = ( - get_scaling_factor(kv_module.output_quantizer) - if hasattr(kv_module, "output_quantizer") - else None - ) - - if not scaling_factor: - return None - - # For FP8, we recommend default kv cache scaling factor to be 1. - if get_kv_cache_dtype(kv_module) == KV_CACHE_FP8: - if scaling_factor.item() > 0.5: - warn( - f"!!!!Large KV activations detected: {scaling_factor.item()}, " - "Quantized KV cache may lead to higher accuracy drop.\n!!!!" - ) - scaling_factor = torch.max( - scaling_factor, - torch.tensor([1.0], dtype=torch.float, device=scaling_factor.device), - ) - return scaling_factor - - class GPTModelExporter: """Megatron Core GPTModel Exporter. @@ -129,6 +104,8 @@ class GPTModelExporter: export_extra_modules: If True, export extra modules like medusa_heads or eagle_module. Otherwise, only export the base model. dtype: The weights data type to export the unquantized layers. + trust_remote_code: Whether to trust remote code in the HuggingFace pretrained model. + moe_router_dtype: The data type of the MoE router. Can be "fp32", "fp64", or None (default to the model dtype). """ def __init__( @@ -137,14 +114,15 @@ def __init__( pretrained_model_name_or_path: str | os.PathLike | None = None, export_extra_modules: bool = False, dtype=torch.bfloat16, - trust_remote_code: bool = True, - moe_router_dtype: torch.dtype | None = None, + trust_remote_code: bool = False, + moe_router_dtype: str | None = None, ): """Create a GPTModel exporter instance.""" if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)): raise ValueError("Input to GPTModelExport must be a megatron.core.models.GPTModel!") self._state_dict = OrderedDict() + self._layer_state_dicts = OrderedDict() self._hf_pretrained_model_name = pretrained_model_name_or_path self._hf_config = transformers.AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code @@ -154,6 +132,7 @@ def __init__( self.moe_router_dtype = torch.float32 elif moe_router_dtype == "fp64": self.moe_router_dtype = torch.float64 + print(f"Exporting model with moe_router_dtype: {self.moe_router_dtype}") # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -167,7 +146,7 @@ def __init__( self.is_multimodal = isinstance(model, LLaVAModel) if not self.is_multimodal: self._hf_text_config.intermediate_size = model.config.ffn_hidden_size - self._hf_quant_config = None + self._hf_quant_config: dict = {} self._hf_extra_config = None self.export_extra_modules = export_extra_modules self.is_multimodal = isinstance(model, LLaVAModel) @@ -181,6 +160,7 @@ def __init__( del self._hf_config.quantization_config self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] + self.exclude_modules = [] if not hasattr(model, "_modelopt_state"): return @@ -245,10 +225,29 @@ def __init__( self._hf_extra_config.update(eagle_config_update) + def save_pretrained_extra_modules( + self, + save_directory: str | os.PathLike, + ): + """Save a EAGLE or Medusa checkpoints which can be deployed by vLLM and TensorRT-LLM.""" + # We use the last PP rank to write the config because + # medusa_heads and eagle_module only exist in the last stage. + pp_rank = get_pipeline_model_parallel_rank() + pp_size = get_pipeline_model_parallel_world_size() + is_last_stage_main_rank = pp_rank == pp_size - 1 + + state_dict = self.extra_state_dict + + if is_last_stage_main_rank and self._hf_extra_config is not None: + self._hf_extra_config.save_pretrained(save_directory) + save_file(state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"}) + + torch.distributed.barrier() + def save_pretrained( self, save_directory: str | os.PathLike, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, ): """Save a unified checkpoint which can be deployed by vLLM and TensorRT-LLM. @@ -266,11 +265,10 @@ def save_pretrained( is_last_stage_main_rank = pp_rank == pp_size - 1 # Main export process - state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict - quantization_format = self._get_quantization_format(self.model) + layer_state_dicts = self.layer_state_dicts + quantization_format = self._get_quantization_format(self.model) quantization = None - if quantization_format in ( QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PB_WO, @@ -281,190 +279,372 @@ def save_pretrained( elif quantization_format == QUANTIZATION_NVFP4: quantization = "NVFP4" - kv_cache_quantization = None - kv_cache_dtype = get_kv_cache_dtype(self.model) - if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): - # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM - kv_cache_quantization = kv_cache_dtype # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. if is_last_stage_main_rank: - if self.export_extra_modules and self._hf_extra_config is not None: - self._hf_extra_config.save_pretrained(save_directory) - else: - self._hf_config.save_pretrained(save_directory) - try: - generation_config = transformers.GenerationConfig.from_pretrained( - self._hf_pretrained_model_name - ) - generation_config.save_pretrained(save_directory) - except OSError: - pass - try: - tokenizer = transformers.AutoTokenizer.from_pretrained( - self._hf_pretrained_model_name - ) - tokenizer.save_pretrained(save_directory) - except OSError: - pass - except TypeError: - pass - try: - # Load and save preprocessor config from the original model - processor = AutoProcessor.from_pretrained( - self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code - ) - if hasattr(processor, "image_processor"): - processor.image_processor.save_pretrained(save_directory) - except (OSError, ValueError, ImportError): - pass + self._hf_config.save_pretrained(save_directory) + try: + generation_config = transformers.GenerationConfig.from_pretrained( + self._hf_pretrained_model_name + ) + generation_config.save_pretrained(save_directory) + except OSError: + pass + try: + tokenizer = transformers.AutoTokenizer.from_pretrained( + self._hf_pretrained_model_name + ) + tokenizer.save_pretrained(save_directory) + except OSError: + pass + except TypeError: + pass + try: + # Load and save preprocessor config from the original model + processor = AutoProcessor.from_pretrained( + self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code + ) + if hasattr(processor, "image_processor"): + processor.image_processor.save_pretrained(save_directory) + except (OSError, ValueError, ImportError): + pass + + mtp_state_dict = self._get_mtp_state_dict() + if len(mtp_state_dict) > 0: + layer_state_dicts[self.model.config.num_layers].update(mtp_state_dict) + print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") + + combined_exclude_modules = self._gather_exclude_modules() if is_last_stage_main_rank and quantization is not None: - hf_quant_config = { + self._hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, "quantization": { "quant_algo": quantization, - "kv_cache_quant_algo": kv_cache_quantization, - "exclude_modules": ["lm_head"], + "exclude_modules": combined_exclude_modules, }, } + if quantization == "NVFP4": # update block size + self._hf_quant_config["quantization"]["group_size"] = 16 + if hasattr(self, "kv_cache_dtype"): + self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype with open(save_directory + "/hf_quant_config.json", "w") as f: - json.dump(hf_quant_config, f, indent=4) + json.dump(self._hf_quant_config, f, indent=4) - if ( - is_first_stage_main_rank - and self.is_multimodal - and pretrained_model_name_or_path is not None - ): - hf_checkpoint_path = Path(pretrained_model_name_or_path) - if not hf_checkpoint_path.is_dir(): - hf_checkpoint_path = tempfile.gettempdir() + "/" + pretrained_model_name_or_path - if not Path(hf_checkpoint_path).exists(): - snapshot_download( - repo_id=pretrained_model_name_or_path, - local_dir=hf_checkpoint_path, + # Add multimodal components to state_dict. Since only support decoder model quantization, + # no changes will be made to the multimodal components. We copy the multimodal components + # from the pretrained model directly to the state_dict to avoid implementing the export logic. + if is_first_stage_main_rank and self.is_multimodal: + multimodal_state_dict = load_multimodal_components(pretrained_model_name_or_path) + layer_state_dicts[0].update(multimodal_state_dict) + + # Barrier to ensure the export_dir has been created. + torch.distributed.barrier() + + if is_last_stage_main_rank and self._hf_config is not None: + copy_remote_code(pretrained_model_name_or_path, save_directory) + + # Newer versions of VLLM expect config.json with hf_quant_config + config_json_file = save_directory + "/config.json" + if self._hf_quant_config and os.path.exists(config_json_file): + converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) + with open(config_json_file) as f: + config_dict = json.load(f) + config_dict["quantization_config"] = converted_quant_config + with open(config_json_file, "w") as f: + json.dump(config_dict, f, indent=4) + + # save_safetensors(state_dict, save_directory) + save_safetensors_by_layer_index( + layer_state_dicts=layer_state_dicts, + total_layers=self.model.config.num_layers, + save_directory=save_directory, + name_template="model-{:05d}-of-{:05d}", + ) + + @property + def state_dict(self): + """Return the real quantized state_dict of the base model.""" + if len(self._state_dict) == 0: + self._get_state_dict() + return self._state_dict + + @property + def layer_state_dicts(self): + if len(self._layer_state_dicts) == 0: + self._get_state_dict() + return self._layer_state_dicts + + @property + def extra_state_dict(self): + if len(self._state_dict) == 0: + self._get_medusa_heads_state_dict() + self._get_eagle_module_state_dict() + return self._state_dict + + def _get_state_dict(self): + model = self.model + + # Embedding + if hasattr(model, "embedding"): + self.rules["word_embeddings"](model.embedding.word_embeddings) + + # Decoder layers + for layer in model.decoder.layers: + layer_id = layer.layer_number - 1 + if isinstance(layer, MambaLayer): + self._get_mamba_layer_state_dict(layer, layer_id) + elif isinstance(layer, TransformerLayer): + self._get_transformer_layer_state_dict(layer, layer_id) + else: + raise ValueError("Only TransformerLayer or MambaLayer are supported.") + + self._layer_state_dicts[layer.layer_number] = self._state_dict + if layer.layer_number != self.model.config.num_layers: + self._state_dict = OrderedDict() + + # Final layernorm + if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: + self.rules["final_layernorm"](model.decoder.final_layernorm) + + if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: + self.rules["final_norm"](model.decoder.final_norm) + + # Output layer + if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: + self.rules["output_layer"](model.output_layer) + + def _get_transformer_layer_state_dict(self, layer, layer_id): + if not isinstance(layer.input_layernorm, IdentityOp): + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + if not isinstance(layer.self_attention, IdentityOp): + if "MLASelfAttention" in str(type(layer.self_attention)): + if hasattr(layer.self_attention, "linear_q_proj"): + self.rules["linear_q_proj"](layer.self_attention.linear_q_proj, layer_id) + else: + self.rules["linear_q_down_proj"]( + layer.self_attention.linear_q_down_proj, layer_id ) + self.rules["linear_q_layernorm"](layer.self_attention.q_layernorm, layer_id) + self.rules["linear_q_up_proj"](layer.self_attention.linear_q_up_proj, layer_id) - safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" - safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" + self.rules["linear_kv_down_proj"]( + layer.self_attention.linear_kv_down_proj, layer_id + ) + self.rules["linear_kv_layernorm"](layer.self_attention.kv_layernorm, layer_id) + self.rules["linear_kv_up_proj"](layer.self_attention.linear_kv_up_proj, layer_id) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + else: + if layer.self_attention.q_layernorm is not None and not isinstance( + layer.self_attention.q_layernorm, (IdentityOp, L2Norm) + ): + self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) + self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) + self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) + if ( + hasattr(layer.self_attention, "core_attention") + and "core_attention" in self.rules + ): # KV cache quant export + self.rules["core_attention"](layer.self_attention.core_attention, layer_id) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + if getattr(layer.self_attention.core_attention, "softmax_offset", None) is not None: + self.rules["softmax_offset"]( + layer.self_attention.core_attention.softmax_offset, layer_id + ) - multimodal_state_dict = {} + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - if safetensors_file.is_file(): - print(f"Loading multimodal components from single file: {safetensors_file}") - with safe_open(safetensors_file, framework="pt") as f: - multimodal_keys = [ - key - for key in f.keys() # noqa: SIM118 - if key.startswith(("multi_modal_projector", "vision_model")) - ] - for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): - multimodal_state_dict[key] = f.get_tensor(key) + if not isinstance(layer.mlp, IdentityOp): + if "MoE" in str(type(layer.mlp)): + self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype) + if hasattr(layer.mlp, "fc1_latent_proj") and layer.mlp.fc1_latent_proj is not None: + self.rules["fc1_latent_proj"](layer.mlp.fc1_latent_proj, layer_id) + if hasattr(layer.mlp, "fc2_latent_proj") and layer.mlp.fc2_latent_proj is not None: + self.rules["fc2_latent_proj"](layer.mlp.fc2_latent_proj, layer_id) + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + self.rules["shared_experts.linear_fc1"]( + layer.mlp.shared_experts.linear_fc1, layer_id + ) + self.rules["shared_experts.linear_fc2"]( + layer.mlp.shared_experts.linear_fc2, layer_id + ) + if not self.rules.get("use_packed_local_experts", False): + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + # For llama 4, in hf unified checkpoint, all local experts share one scale + self.rules["local_experts.linear_fc1"]( + layer.mlp.experts.local_experts, layer_id + ) + self.rules["local_experts.linear_fc2"]( + layer.mlp.experts.local_experts, layer_id + ) + else: + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) + + def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]: + """Export the MTP module. + + Currently, we copy the BF16 MTP weights from the pretrained model if the pretrained model has MTP layers. + """ + # TODO Implement MTP export for quantized MTP + # Hacky version for now: copy MTP weights from pretrained model + mtp_state_dict = {} + if self._hf_pretrained_model_name: + if os.path.isdir(self._hf_pretrained_model_name): + safetensors_index_file = ( + Path(self._hf_pretrained_model_name) / "model.safetensors.index.json" + ) + else: + safetensors_index_file = hf_hub_download( + repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json" + ) - elif safetensors_index_file.is_file(): - print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") + print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}") + mtp_exists = False + if safetensors_index_file and os.path.exists(safetensors_index_file): with open(safetensors_index_file) as f: safetensors_index = json.load(f) + model_dir = Path(safetensors_index_file).parent + for key in safetensors_index["weight_map"]: + if key.startswith("mtp.") and key not in self._state_dict: + mtp_state_dict[key] = get_safetensor(model_dir, key) + mtp_exists = True - # For multimodal models, vision_model and multi_modal_projector are in the first shard - all_shard_files = sorted(set(safetensors_index["weight_map"].values())) - first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" - - # Load multimodal components from the first shard file - safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file - print(f"Loading multimodal components from {first_shard_file}") - - with safe_open(safetensors_filepath, framework="pt") as f: - shard_keys = list(f.keys()) - multimodal_keys_in_shard = [ - k - for k in shard_keys - if k.startswith(("multi_modal_projector", "vision_model")) - ] - - if multimodal_keys_in_shard: - print( - f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" - ) - for key in tqdm( - multimodal_keys_in_shard, desc="Loading multimodal tensors" - ): - multimodal_state_dict[key] = f.get_tensor(key) - else: - print(f"No multimodal components found in {first_shard_file}") + if mtp_exists: + self.exclude_modules.append("mtp*") + return mtp_state_dict - else: - print(f"Warning: No safetensors files found in {hf_checkpoint_path}") + def _get_mamba_layer_state_dict(self, layer, layer_id): + if not isinstance(layer.norm, IdentityOp): + self.rules["norm"](layer.norm, layer_id) - print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") - # Add multimodal components to state_dict - state_dict.update(multimodal_state_dict) + self.rules["mixer_norm"](layer.mixer.norm, layer_id) + self.rules["A_log"](layer.mixer.A_log, layer_id) + self.rules["D"](layer.mixer.D, layer_id) + self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - # Barrier to ensure the export_dir has been created. - torch.distributed.barrier() + self.rules["conv1d"](layer.mixer.conv1d, layer_id) + self.rules["in_proj"](layer.mixer.in_proj, layer_id) + self.rules["out_proj"](layer.mixer.out_proj, layer_id) - if self.export_extra_modules: - if is_last_stage_main_rank: - save_file( - state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"} - ) - torch.distributed.barrier() + def _get_medusa_heads_state_dict(self): + medusa_heads = getattr(self.model, "medusa_heads", None) + if medusa_heads is None: return - if ( - is_last_stage_main_rank - and self._hf_config is not None - and pretrained_model_name_or_path is not None - ): - # For models that keep configuration and modeling files as part of the checkpoint, - # we need to copy them to the export directory for seamless integration with inference - # frameworks. - hf_checkpoint_path = Path(pretrained_model_name_or_path) - model_type = getattr(self._hf_config, "model_type", None) - - if hf_checkpoint_path.is_dir(): - # Local directory - files should be there - config_file = hf_checkpoint_path / f"configuration_{model_type}.py" - modeling_file = hf_checkpoint_path / f"modeling_{model_type}.py" - else: - # Remote model ID - download from HuggingFace Hub (cached automatically) - try: - config_file = hf_hub_download( - repo_id=pretrained_model_name_or_path, - filename=f"configuration_{model_type}.py", + for head_id, head in enumerate(medusa_heads): + self.rules["medusa_heads.lm_head"](head.lm_head, head_id) + for layer_id, layer in enumerate(head.medusa_layers): + self.rules["medusa_heads.medusa_layers.linear"](layer.linear, head_id, layer_id) + + def _get_eagle_module_state_dict(self): + eagle_module = getattr(self.model, "eagle_module", None) + + if eagle_module is None: + return + + # if hasattr(self.model, "embedding"): + # self.rules["word_embeddings"](self.model.embedding.word_embeddings) + + self.rules["fc"](eagle_module.fc) + if self.model.eagle_config.use_aux_hidden_state: + self.rules["enorm"](eagle_module.enorm) + elif self.model.eagle_config.use_mtp_layernorm: + self.rules["enorm"](eagle_module.enorm) + self.rules["hnorm"](eagle_module.hnorm) + + if self.model.eagle_config.use_last_layernorm: + self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) + + if hasattr(self.model.eagle_module, "eagle_output_layer"): + self.rules["output_layer"](eagle_module.eagle_output_layer) + if hasattr(self.model.eagle_module, "dt2"): + self.rules["d2t"](eagle_module.d2t) + + for layer in eagle_module.decoder.layers: + layer_id = layer.layer_number - 1 + + # The first layernorm needs special handling here. We have a dedicated mapping + # for the first layernorm since in EAGLE3 it will be mapped to hidden_norm + # instead of input_layernorm (due to the specialized transformer layer). + # The remaining EAGLE3 layers (if more than 1) are normal transformer layers + # where input_layernorm is mapped to input_layernorm. + if layer_id == 0 and self.model.eagle_config.use_input_layernorm_in_first_layer: + self.rules["first_input_layernorm"](layer.input_layernorm, layer_id) + elif layer_id > 0: + self.rules["input_layernorm"](layer.input_layernorm, layer_id) + + if "MLASelfAttention" in str(type(layer.self_attention)): + if hasattr(layer.self_attention, "linear_q_proj"): + self.rules["eagle_module.linear_q_proj"]( + layer.self_attention.linear_q_proj, layer_id + ) + else: + self.rules["eagle_module.linear_q_down_proj"]( + layer.self_attention.linear_q_down_proj, layer_id + ) + self.rules["eagle_module.linear_q_layernorm"]( + layer.self_attention.q_layernorm, layer_id ) - except Exception: - config_file = "" - try: - modeling_file = hf_hub_download( - repo_id=pretrained_model_name_or_path, filename=f"modeling_{model_type}.py" + self.rules["eagle_module.linear_q_up_proj"]( + layer.self_attention.linear_q_up_proj, layer_id ) - except Exception: - modeling_file = "" - if config_file and os.path.exists(config_file): - shutil.copy(config_file, f"{save_directory}/configuration_{model_type}.py") - if modeling_file and os.path.exists(modeling_file): - shutil.copy(modeling_file, f"{save_directory}/modeling_{model_type}.py") + self.rules["eagle_module.linear_kv_down_proj"]( + layer.self_attention.linear_kv_down_proj, layer_id + ) + self.rules["eagle_module.linear_kv_layernorm"]( + layer.self_attention.kv_layernorm, layer_id + ) + self.rules["eagle_module.linear_kv_up_proj"]( + layer.self_attention.linear_kv_up_proj, layer_id + ) + else: + self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - save_safetensors(state_dict, save_directory) + self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) + self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - @property - def state_dict(self): - """Return the real quantized state_dict of the base model.""" - if len(self._state_dict) == 0: - self._get_state_dict() - return self._state_dict + if "MoE" in str(type(layer.mlp)): + self.rules["eagle_module.router"](layer.mlp.router, layer_id) + if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: + self.rules["eagle_module.shared_experts.linear_fc1"]( + layer.mlp.shared_experts.linear_fc1, layer_id + ) + self.rules["eagle_module.shared_experts.linear_fc2"]( + layer.mlp.shared_experts.linear_fc2, layer_id + ) + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["eagle_module.local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["eagle_module.local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) + self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - @property - def extra_state_dict(self): - if len(self._state_dict) == 0: - self._get_medusa_heads_state_dict() - self._get_eagle_module_state_dict() - return self._state_dict + parallel_draft_heads = getattr(eagle_module, "parallel_draft_heads", None) + if parallel_draft_heads is not None: + for head_id, head in enumerate(parallel_draft_heads.medusa_heads): + for layer_id, layer in enumerate(head): + self.rules["parallel_draft_heads.medusa_layers"]( + layer.linear, head_id, layer_id + ) + self.rules["parallel_draft_heads.lm_head"](parallel_draft_heads.lm_head) def _populate_rule_book(self): all_rules = {} @@ -473,6 +653,7 @@ def _custom_mapping_to_lambda(mapping): method_map = { "name_remapping": self._name_remapping, "qkv_slicing": self._qkv_slicing, + "self_attention_scaling": self._self_attention_scaling, "gated_mlp_slicing": self._gated_mlp_slicing, "pack_name_remapping": self._pack_name_remapping, "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss, @@ -497,30 +678,40 @@ def _get_quantized_state( self, module: torch.nn.Module, dtype: torch.dtype = torch.float16, + prefix: str = "", ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. Args: module: The target module to perform real quantization. dtype: The default data type. + prefix: The prefix of the layer. Returns: Tuple: state_dict, quantization format, and block_size of the module. """ name_to_value = {} qformat: str = self._get_quantization_format(module) + if qformat is None and "norm" not in prefix: + # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty + # string then it usually ends with "." which needs to be removed. + self.exclude_modules.append(prefix.removesuffix(".")) block_size = get_weight_block_size(module) - if hasattr(module, "weight") and module.weight is not None: + if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: weight = module.weight.to(dtype).cpu() name_to_value["weight"] = weight else: return name_to_value, qformat, block_size - if hasattr(module, "bias") and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None and module.bias.numel() > 0: name_to_value["bias"] = module.bias.to(dtype).cpu() - if hasattr(module, "expert_bias") and module.expert_bias is not None: + if ( + hasattr(module, "expert_bias") + and module.expert_bias is not None + and module.expert_bias.numel() > 0 + ): name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() if qformat == QUANTIZATION_NONE: @@ -542,11 +733,6 @@ def _get_quantized_state( if hasattr(module.input_quantizer, "_pre_quant_scale"): raise ValueError("Detect pre_quant_scale! SmoothQuant/AWQ are not yet supported!") - if hasattr(module, "output_quantizer"): - output_scale = get_kv_cache_scaling_factor(module) - if output_scale is not None: - name_to_value["output_scale"] = output_scale - return name_to_value, qformat, block_size def _get_quantization_format(self, module: torch.nn.Module): @@ -580,7 +766,7 @@ def _name_remapping( self._state_dict[prefix] = module return - name_to_value, qformat, block_size = self._get_quantized_state(module, dtype) + name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -612,7 +798,9 @@ def _name_remapping( def _gated_mlp_slicing( self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj" ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = self._get_quantized_state( + module, self.dtype, prefix=prefix + ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -674,10 +862,10 @@ def _qkv_slicing( q_proj_name="q_proj", k_proj_name="k_proj", v_proj_name="v_proj", - k_scale_name="k_scale", - v_scale_name="v_scale", ): - name_to_value, qformat, block_size = self._get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = self._get_quantized_state( + module, self.dtype, prefix=prefix + ) q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." @@ -774,10 +962,7 @@ def _qkv_slicing( q_proj_key = q_proj_prefix + key k_proj_key = k_proj_prefix + key v_proj_key = v_proj_prefix + key - if key == "output_scale": - self._state_dict[prefix + k_scale_name] = val.detach().clone() - self._state_dict[prefix + v_scale_name] = val.detach().clone() - elif key == "bias": + if key == "bias": # Slice bias similar to weight bias = val.detach().clone() bias = bias.reshape([qkv_total_dim, head_size]) @@ -790,6 +975,23 @@ def _qkv_slicing( self._state_dict[k_proj_key] = val.detach().clone() self._state_dict[v_proj_key] = val.detach().clone() + def _self_attention_scaling( + self, module, prefix, k_scale_name="k_scale", v_scale_name="v_scale" + ): + """KV cache scaling for CoreAttention module.""" + k_scale_key = prefix + k_scale_name + v_scale_key = prefix + v_scale_name + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): + kv_scales = get_kv_cache_scaling_factor(module) + if all(s is not None for s in kv_scales): + self._state_dict[k_scale_key] = kv_scales[0] + self._state_dict[v_scale_key] = kv_scales[1] + + kv_cache_dtype = get_kv_cache_dtype(module) + if kv_cache_dtype in (KV_CACHE_FP8, KV_CACHE_NVFP4): + # FP8 KV Cache is supported in VLLM; NVFP4 supported in TRTLLM + self.kv_cache_dtype = kv_cache_dtype + def _pack_name_remapping(self, module, prefix, layer_type=None): """Pack name remapping into one tensor.""" weight_list = [] @@ -800,7 +1002,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = self._get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, prefix=prefix ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -866,7 +1068,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = self._get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, prefix=prefix ) weight = name_to_value.pop("weight") bias = name_to_value.pop("bias", None) @@ -969,242 +1171,27 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): # TODO: May need to modify the key name later. self._state_dict[prefix + "_bias"] = merged_bias - def _get_medusa_heads_state_dict(self): - medusa_heads = getattr(self.model, "medusa_heads", None) - if medusa_heads is None: - return - - for head_id, head in enumerate(medusa_heads): - self.rules["medusa_heads.lm_head"](head.lm_head, head_id) - for layer_id, layer in enumerate(head.medusa_layers): - self.rules["medusa_heads.medusa_layers.linear"](layer.linear, head_id, layer_id) - - def _get_eagle_module_state_dict(self): - eagle_module = getattr(self.model, "eagle_module", None) - - if eagle_module is None: - return - - # if hasattr(self.model, "embedding"): - # self.rules["word_embeddings"](self.model.embedding.word_embeddings) - - self.rules["fc"](eagle_module.fc) - if self.model.eagle_config.use_aux_hidden_state: - self.rules["enorm"](eagle_module.enorm) - elif self.model.eagle_config.use_mtp_layernorm: - self.rules["enorm"](eagle_module.enorm) - self.rules["hnorm"](eagle_module.hnorm) - - if self.model.eagle_config.use_last_layernorm: - self.rules["final_layernorm"](eagle_module.decoder.final_layernorm) - - if hasattr(self.model.eagle_module, "eagle_output_layer"): - self.rules["output_layer"](eagle_module.eagle_output_layer) - if hasattr(self.model.eagle_module, "dt2"): - self.rules["d2t"](eagle_module.d2t) - - for layer in eagle_module.decoder.layers: - layer_id = layer.layer_number - 1 - - # The first layernorm needs special handling here. We have a dedicated mapping - # for the first layernorm since in EAGLE3 it will be mapped to hidden_norm - # instead of input_layernorm (due to the specialized transformer layer). - # The remaining EAGLE3 layers (if more than 1) are normal transformer layers - # where input_layernorm is mapped to input_layernorm. - if layer_id == 0 and self.model.eagle_config.use_input_layernorm_in_first_layer: - self.rules["first_input_layernorm"](layer.input_layernorm, layer_id) - elif layer_id > 0: - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["eagle_module.linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) - else: - self.rules["eagle_module.linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, layer_id - ) - self.rules["eagle_module.linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["eagle_module.linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) - - self.rules["eagle_module.linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, layer_id - ) - self.rules["eagle_module.linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["eagle_module.linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) - else: - self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if "MoE" in str(type(layer.mlp)): - self.rules["eagle_module.router"](layer.mlp.router, layer_id) - if hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None: - self.rules["eagle_module.shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["eagle_module.shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, layer_id - ) - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - self.rules["eagle_module.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["eagle_module.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) - else: - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - - parallel_draft_heads = getattr(eagle_module, "parallel_draft_heads", None) - if parallel_draft_heads is not None: - for head_id, head in enumerate(parallel_draft_heads.medusa_heads): - for layer_id, layer in enumerate(head): - self.rules["parallel_draft_heads.medusa_layers"]( - layer.linear, head_id, layer_id - ) - self.rules["parallel_draft_heads.lm_head"](parallel_draft_heads.lm_head) - - def _get_state_dict(self): - model = self.model - - # Embedding - if hasattr(model, "embedding"): - self.rules["word_embeddings"](model.embedding.word_embeddings) - - # Final layernorm - if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: - self.rules["final_layernorm"](model.decoder.final_layernorm) - - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: - self.rules["final_norm"](model.decoder.final_norm) - - # Output layer - if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: - self.rules["output_layer"](model.output_layer) - - # Decoder layers - for layer in model.decoder.layers: - layer_id = layer.layer_number - 1 - - if isinstance(layer, MambaLayer): - if not isinstance(layer.norm, IdentityOp): - self.rules["norm"](layer.norm, layer_id) - - self.rules["mixer_norm"](layer.mixer.norm, layer_id) - self.rules["A_log"](layer.mixer.A_log, layer_id) - self.rules["D"](layer.mixer.D, layer_id) - self.rules["dt_bias"](layer.mixer.dt_bias, layer_id) - - self.rules["conv1d"](layer.mixer.conv1d, layer_id) - self.rules["in_proj"](layer.mixer.in_proj, layer_id) - self.rules["out_proj"](layer.mixer.out_proj, layer_id) + def _gather_exclude_modules(self): + """Get exclude_modules from all ranks to ensure hf_quant_config is complete.""" + if not torch.distributed.is_initialized(): + return sorted(self.exclude_modules) - elif isinstance(layer, TransformerLayer): - if not isinstance(layer.input_layernorm, IdentityOp): - self.rules["input_layernorm"](layer.input_layernorm, layer_id) - - if not isinstance(layer.self_attention, IdentityOp): - if "MLASelfAttention" in str(type(layer.self_attention)): - if hasattr(layer.self_attention, "linear_q_proj"): - self.rules["linear_q_proj"]( - layer.self_attention.linear_q_proj, layer_id - ) - else: - self.rules["linear_q_down_proj"]( - layer.self_attention.linear_q_down_proj, layer_id - ) - self.rules["linear_q_layernorm"]( - layer.self_attention.q_layernorm, layer_id - ) - self.rules["linear_q_up_proj"]( - layer.self_attention.linear_q_up_proj, layer_id - ) - - self.rules["linear_kv_down_proj"]( - layer.self_attention.linear_kv_down_proj, layer_id - ) - self.rules["linear_kv_layernorm"]( - layer.self_attention.kv_layernorm, layer_id - ) - self.rules["linear_kv_up_proj"]( - layer.self_attention.linear_kv_up_proj, layer_id - ) - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - else: - if layer.self_attention.q_layernorm is not None and not isinstance( - layer.self_attention.q_layernorm, (IdentityOp, L2Norm) - ): - self.rules["q_layernorm"](layer.self_attention.q_layernorm, layer_id) - self.rules["k_layernorm"](layer.self_attention.k_layernorm, layer_id) - self.rules["linear_qkv"](layer.self_attention.linear_qkv, layer_id) - self.rules["linear_proj"](layer.self_attention.linear_proj, layer_id) - if ( - getattr(layer.self_attention.core_attention, "softmax_offset", None) - is not None - ): - self.rules["softmax_offset"]( - layer.self_attention.core_attention.softmax_offset, layer_id - ) - - if not isinstance(layer.pre_mlp_layernorm, IdentityOp): - self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - - if not isinstance(layer.mlp, IdentityOp): - if "MoE" in str(type(layer.mlp)): - self.rules["router"]( - layer.mlp.router, layer_id, dtype=self.moe_router_dtype - ) - if ( - hasattr(layer.mlp, "shared_experts") - and layer.mlp.shared_experts is not None - ): - self.rules["shared_experts.linear_fc1"]( - layer.mlp.shared_experts.linear_fc1, layer_id - ) - self.rules["shared_experts.linear_fc2"]( - layer.mlp.shared_experts.linear_fc2, layer_id - ) - if not self.rules.get("use_packed_local_experts", False): - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - self.rules["local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) - else: - # For llama 4, in hf unified checkpoint, all local experts share one scale - self.rules["local_experts.linear_fc1"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2"]( - layer.mlp.experts.local_experts, layer_id - ) - else: - self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) - self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) - else: - raise ValueError("Only TransformerLayer or MambaLayer are supported.") + all_exclude_modules = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_exclude_modules, self.exclude_modules) + combined_exclude_modules = set() + for modules in all_exclude_modules: + if modules: + combined_exclude_modules.update(modules) + return sorted(combined_exclude_modules) def export_mcore_gpt_to_hf( model: torch.nn.Module, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), + trust_remote_code: bool = False, moe_router_dtype: torch.dtype | None = None, ): """Export Megatron Core GPTModel to unified checkpoint and save to export_dir. @@ -1225,9 +1212,13 @@ def export_mcore_gpt_to_hf( pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype, + trust_remote_code=trust_remote_code, moe_router_dtype=moe_router_dtype, ) - exporter.save_pretrained(export_dir, pretrained_model_name_or_path) + if exporter.export_extra_modules: + exporter.save_pretrained_extra_modules(export_dir) + else: + exporter.save_pretrained(export_dir, pretrained_model_name_or_path) def import_mcore_gpt_from_hf( @@ -1235,6 +1226,7 @@ def import_mcore_gpt_from_hf( pretrained_model_path: str, workspace_dir: str | None = None, dtype: torch.dtype = torch.bfloat16, + trust_remote_code: bool = False, moe_router_dtype: torch.dtype | None = None, ): """Import GPTModel state_dict from supported HuggingFace pretrained model path. @@ -1243,13 +1235,17 @@ def import_mcore_gpt_from_hf( model: The Megatron Core GPTModel instance. pretrained_model_path: A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + workspace_dir: The directory to save the workspace. dtype: The weights data type to import. + trust_remote_code: If True, this allows importing from a wider range of sources. + moe_router_dtype: The data type to import the moe router weights. """ importer = GPTModelImporter( model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype, + trust_remote_code=trust_remote_code, moe_router_dtype=moe_router_dtype, ) importer._import_state_dict() diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index a250127098..83fa3c1d85 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -18,26 +18,18 @@ import types from abc import ABC from collections.abc import Callable, Sequence -from typing import Any import torch import torch.nn as nn from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.gpt import GPTModel -from megatron.core.parallel_state import ( - get_data_parallel_group, - get_pipeline_model_parallel_group, - get_tensor_model_parallel_group, - is_pipeline_first_stage, - is_pipeline_last_stage, -) +from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage from megatron.core.tensor_parallel.layers import ( ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, ) -from megatron.core.transformer import MegatronModule from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.mlp import MLP @@ -51,29 +43,14 @@ from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.opt.dynamic import DynamicModule from modelopt.torch.opt.hparam import HPType -from modelopt.torch.opt.searcher import ConstraintsDict from modelopt.torch.trace import Symbol from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import ( - get_module_device, - make_divisible, - param_num_from_forward, - print_rank_0, - random, -) +from modelopt.torch.utils import make_divisible -from ..algorithms import ( - MODULE_TYPE_TO_CONSTRAINTS_FUNC, - ConstraintEvalFunc, - ConstraintInterpolator, - ConstraintsFunc, - ConstraintsRes, -) from ..hparams.concat import build_concat_hp from ..modules import _DynamicLayerNorm from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices from ..registry import DMRegistry -from ..search_space import SampleFunc from ..traced_hp import TracedHp SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"} @@ -634,7 +611,6 @@ def modify( def _export_reinit_token_dispatcher(self) -> None: """Reinitialize the token dispatcher after pruning.""" - print_rank_0("Reinitializing token dispatcher after pruning") if hasattr(moe_utils, "get_default_model_comm_pgs"): model_comm_pgs = moe_utils.get_default_model_comm_pgs() else: @@ -1045,27 +1021,30 @@ def modify( *, hidden_size_divisor: int = 1, ffn_hidden_size_divisor: int = 1, - mamba_num_heads_divisor: int = 1, mamba_head_dim_divisor: int = 1, num_moe_experts_divisor: int = 1, + num_layers_divisor: int = 1, ): """Modify the dynamic choices of the module according to provided keyword arguments. Args: hidden_size_divisor: The divisor of the hidden_size. ffn_hidden_size_divisor: The divisor of the mlp ffn_hidden_size. - mamba_num_heads_divisor: The divisor of the mamba num_heads. mamba_head_dim_divisor: The divisor of the mamba head_dim. num_moe_experts_divisor: The divisor of the number of MoE experts. + num_layers_divisor: The divisor of the number of layers. """ - hp = self.get_hparam("hidden_size") - choices = {int(make_divisible(c, hidden_size_divisor)) for c in hp.choices} # type: ignore[arg-type] - hp.choices = list(set(hp.choices) & choices | {hp.original}) + for hp_name, divisor in [ + ("hidden_size", hidden_size_divisor), + ("num_layers", num_layers_divisor), + ]: + hp = self.get_hparam(hp_name) + choices = {int(make_divisible(c, divisor)) for c in hp.choices} # type: ignore[arg-type] + hp.choices = list(set(hp.choices) & choices | {hp.original}) for layer in self.decoder.layers: layer.modify( ffn_hidden_size_divisor=ffn_hidden_size_divisor, - mamba_num_heads_divisor=mamba_num_heads_divisor, mamba_head_dim_divisor=mamba_head_dim_divisor, num_moe_experts_divisor=num_moe_experts_divisor, ) @@ -1084,86 +1063,3 @@ def export(self) -> torch.nn.Module: ).export() self.output_layer.export() return super().export() - - -class MegatronConstraintsFunc(ConstraintsFunc): - """A Functor class to check if sub-net satisfied all provided constraints. - - We intentionally expose some attributes like `limits` s.t. we can modify it manually. - """ - - _sample_points_dict: dict[tuple[str, ...], dict[str, SampleFunc]] = { - ("params",): {"min": min, "centroid": random.centroid, "max": max}, - } - - def __init__( - self, - model: MegatronModule, - constraints: ConstraintsDict, - dummy_input: Any | tuple[Any, ...], - deployment: dict | None = None, - fast_eval: bool = True, - ): - """Initialize with additional data parallel group info from megatron.""" - for key in constraints: - if key != "params": - raise ValueError("Only params constraints is supported for MegatronModule!") - - self.model = model - self.dummy_input = dummy_input - self.deployment = deployment - self._fast_eval = fast_eval - - # Getting data parallel group for - self.dp_group = get_data_parallel_group() - - # initialize latency interpolator - keys_for_interpolation = ("params",) - if ConstraintsFunc.is_configurable(self.model, "depth"): - keys_for_interpolation += ("flops_min_depth",) - self._latency_interpolator = ConstraintInterpolator( - self.model, - points_funcs={k: self.constraint_eval_funcs[k] for k in keys_for_interpolation}, - value_func=self._get_true_latency, - ) - # set fast/regular mode for latency interpolator - self._latency_interpolator.collect_mode = not self.fast_eval - - # set limit at the end with setter to use sanity checks on constraints - self._limits = {} - self.limits = constraints - - @property - def constraint_eval_funcs(self) -> dict[str, ConstraintEvalFunc]: - """Get constraint eval fns.""" - return { - "params": self._get_params, - } - - def _get_params(self, _: ConstraintsRes | None = None) -> float: - """Get number of model parameters from forward pass.""" - params = param_num_from_forward(self.model, args=self.dummy_input, unit=1.0) - reduced_params = torch.Tensor([params]).to(device=get_module_device(self.model)) - torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) - torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) - return reduced_params.item() - - def _get_flops(self, _: ConstraintsRes | None = None) -> float: - """Get inference FLOPs.""" - raise NotImplementedError - - def _get_flops_min_depth(self, _: ConstraintsRes | None = None) -> float: - """Get inference FLOPs with depth set to minimum.""" - raise NotImplementedError - - def _get_true_latency(self, _: ConstraintsRes | None = None) -> float: - """Get true inference latency.""" - raise NotImplementedError - - def _get_latency(self, precomputed: ConstraintsRes | None = None) -> float: - """Get inference latency from interpolator.""" - raise NotImplementedError - - -# Clear the mapping and reinsert. -MODULE_TYPE_TO_CONSTRAINTS_FUNC[MegatronModule] = MegatronConstraintsFunc diff --git a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py index 0ed0003490..7cd7214443 100644 --- a/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py +++ b/modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py @@ -1042,7 +1042,7 @@ def __call__( args: Tuple with one tensor entry (B, T, I) output: Router logits of shape (B, T, E) """ - router_logits = output + router_logits = output[0] if isinstance(output, tuple) else output num_experts = router_logits.shape[-1] router_argsort = torch.argsort(router_logits, dim=-1, descending=True) router_argsort = router_argsort.view(-1, num_experts).to(torch.int16).cpu() diff --git a/modelopt/torch/nas/search_space.py b/modelopt/torch/nas/search_space.py index 6da4d425a8..4c7d2172a7 100644 --- a/modelopt/torch/nas/search_space.py +++ b/modelopt/torch/nas/search_space.py @@ -135,9 +135,7 @@ def sort_parameters(self, hps_to_sort: set[str] | None = None, verbose: bool = F hps_to_sort: A set of hparam names to sort. If not provided or empty, all hparams will be sorted. verbose: Whether to print the search space and hparam importances. """ - print_rank_0("Sorting parameters...") - if verbose: - self.print_summary() + print_rank_0("\nSorting parameters...") # get config and set to max config = self.config() diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 1de6143bd1..874c51b599 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -526,6 +526,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any] model = ... # Create the model-like object # Restore the previously saved modelopt state followed by model weights + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input mto.restore_from_modelopt_state( model, torch.load("modelopt_state.pt", weights_only=False) ) # Restore modelopt state diff --git a/modelopt/torch/opt/dynamic.py b/modelopt/torch/opt/dynamic.py index a2834329e5..7988f9f970 100644 --- a/modelopt/torch/opt/dynamic.py +++ b/modelopt/torch/opt/dynamic.py @@ -31,7 +31,6 @@ from torch.nn.parameter import Parameter from modelopt.torch.utils import get_unwrapped_name, is_channels_last, unwrap_model -from modelopt.torch.utils.distributed import ParallelState from modelopt.torch.utils.network import bind_forward_method from .config import ModeloptBaseRule, RulesDict @@ -359,14 +358,10 @@ class DynamicModule(nn.Module): should ensure only to expose ``hparams`` in the outermost class and handle other ``hparams`` internally including ``hparams`` of child modules that are exposed on their own usually (e.g. block module implementations containing DynamicLinear). - - In addition, the class also provides ``parallel_state`` attribute that can be used to access - the parallel state of the module. """ # this is needed to store the special attributes for dynamic modules _dm_attribute_manager: _DMAttributeManager - _parallel_state: ParallelState def __init__(self, *args, **kwargs): """Initializing a dynamic module is not allowed!""" @@ -584,6 +579,15 @@ def export(self) -> nn.Module: assert not is_dynamic, "Exported module must not be a DynamicModule anymore!" delattr(self, "_dm_attribute_manager") + # If this module had a monkey-patched forward before DynamicModule.convert(), we may have + # overridden it by binding the dynamic forward onto the instance (to follow the MRO). + # On final export, restore the original forward to avoid leaking a dynamic forward + # (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance. + # please see: https://github.com/NVIDIA/Model-Optimizer/pull/824 + if hasattr(self, "_forward_pre_dm"): + setattr(self, "forward", getattr(self, "_forward_pre_dm")) + delattr(self, "_forward_pre_dm") + return self @classmethod @@ -621,6 +625,10 @@ def bind_forward_method_if_needed(self): # accelerate patched module bind_forward_method(self, self.__class__.forward) else: + if not hasattr(self, "_forward_pre_dm"): + # Keep the patched forward for downstream modules that want to call it. + self._forward_pre_dm = self.forward + bind_forward_method(self, self.__class__.forward) warnings.warn( "Received a module with monkey patched forward method. Dynamic converted module" " might not work." @@ -644,10 +652,6 @@ def bind_forward_method_if_needed(self): # setup new hparams and dynamic attributes module._setup(**setup_kwargs) - # setup parallel state now that the module is converted - if module.parallel_state is None: - module._initialize_parallel_state() - return module def _setup(self, **setup_kwargs: Any): @@ -854,36 +858,6 @@ def original_cls(self) -> type[nn.Module]: """ return self._get_dm_attribute_manager().og_cls - @property - def parallel_state(self) -> ParallelState | None: - """Return the parallel state of the dynamic module.""" - return getattr(self, "_parallel_state", None) - - @parallel_state.setter - def parallel_state(self, parallel_state: ParallelState): - """Set the parallel state of the dynamic module.""" - assert isinstance(parallel_state, ParallelState), ( - "parallel_state must be a ParallelState object!" - ) - self._parallel_state = parallel_state - - def _initialize_parallel_state(self): - """Initialize the parallel state of the dynamic module. - - This method is called only if the `DynamicModule` does not have a `parallel_state` attribute - after `_setup` is called. - """ - if torch.distributed.is_initialized(): - warnings.warn( - f"Distributed training is initialized but no parallel_state is set for {type(self)}. " - "Using default parallel_state which has data_parallel_group set to the default process group and " - "tensor_parallel_group is unspecified. " - "If you are using tensor parallelism for this module, you should set the parallel_state " - "in its `_setup` method." - ) - - self.parallel_state = ParallelState(data_parallel_group=None) - def get_original_cls_by_level(self, level: int = -1) -> type[nn.Module]: """Return the original class of the dynamic module. diff --git a/modelopt/torch/opt/plugins/huggingface.py b/modelopt/torch/opt/plugins/huggingface.py index 672d0f99a8..99bab77257 100644 --- a/modelopt/torch/opt/plugins/huggingface.py +++ b/modelopt/torch/opt/plugins/huggingface.py @@ -79,6 +79,7 @@ def new_init_fn(self, *args, **kwargs): modelopt_state_path = _get_modelopt_state_path(model_path) _original__init__(self, *args, **kwargs) if os.path.isfile(modelopt_state_path): + # Security NOTE: weights_only=False is used on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load(modelopt_state_path, map_location="cpu", weights_only=False) with extra_context() if extra_context else nullcontext(): restore_from_modelopt_state(self, modelopt_state) diff --git a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py index 278c7b9fb5..3e5b359468 100644 --- a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py +++ b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py @@ -242,6 +242,7 @@ def restore_sharded_modelopt_state( return # Loading the common modelopt_state (replicated on all ranks) + # Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input common_modelopt_state = torch.load( modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False ) diff --git a/modelopt/torch/opt/plugins/megatron.py b/modelopt/torch/opt/plugins/megatron.py index fb2a8ea875..e45198c208 100644 --- a/modelopt/torch/opt/plugins/megatron.py +++ b/modelopt/torch/opt/plugins/megatron.py @@ -102,7 +102,7 @@ def _modelopt_set_extra_state(self, state: Any): # Default format: byte tensor with pickled data # # TODO: possible deserialization improvement - # https://github.com/NVIDIA/TensorRT-LLM/commits/main/tensorrt_llm/serialization.py + # https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serialization.py extra_state = pickle.loads(state.detach().cpu().numpy().tobytes()) # nosec else: raise RuntimeError("Unsupported extra_state format.") diff --git a/modelopt/torch/opt/plugins/peft.py b/modelopt/torch/opt/plugins/peft.py index 5e5ed0f934..c3fd268a58 100644 --- a/modelopt/torch/opt/plugins/peft.py +++ b/modelopt/torch/opt/plugins/peft.py @@ -72,6 +72,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs): assert adapter_name in self.peft_config, ( f"ModelOpt modified model should have adapter_name={adapter_name} in peft_config" ) + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input restore_from_modelopt_state( self, torch.load(modelopt_state_path, map_location="cpu", weights_only=False) ) @@ -85,6 +86,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs): if os.path.isfile(_get_quantizer_state_save_path(model_id)): from modelopt.torch.quantization.nn import TensorQuantizer + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input quantizer_state_dict = torch.load( _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False ) diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 5eb2e134ec..ab3930c207 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -27,7 +27,6 @@ from collections.abc import Callable from contextlib import nullcontext from typing import Any, final -from warnings import warn import numpy as np import pulp @@ -35,7 +34,7 @@ import torch.nn as nn from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import no_stdout, run_forward_loop +from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0 LimitsTuple = tuple[float, float] ConstraintsDict = dict[str, str | float | dict | None] @@ -212,6 +211,7 @@ def construct_forward_loop( return None def forward_loop_with_silence_check(m: nn.Module) -> None: + print_rank_0("Running forward loop...") with no_stdout() if silent else nullcontext(): if data_loader is not None: run_forward_loop( @@ -243,12 +243,12 @@ def load_search_checkpoint(self) -> bool: if checkpoint is None: return False if not os.path.exists(checkpoint): - if dist.is_master(): - warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") + warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") return False # iterate through state dict and load keys - print(f"Loading searcher state from {checkpoint}...") + print_rank_0(f"Loading searcher state from {checkpoint}...") + # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state_dict = torch.load(checkpoint, weights_only=False) assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" for key, state in state_dict.items(): diff --git a/modelopt/torch/prune/__init__.py b/modelopt/torch/prune/__init__.py index aac5f7e878..847b22e9d6 100644 --- a/modelopt/torch/prune/__init__.py +++ b/modelopt/torch/prune/__init__.py @@ -19,8 +19,6 @@ simplifies the overall workflow to accommodate for the simpler nature of pruning algorithms. """ -# nas is a required - so let's check if it's available -import modelopt.torch.nas from modelopt.torch.utils import import_plugin from . import fastnas, gradnas, plugins diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index db6769b7b5..6470776e7a 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -24,25 +24,30 @@ Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`. """ -import copy from collections.abc import Callable +from dataclasses import dataclass from functools import partial +from itertools import product from typing import Any from warnings import warn import torch import torch.nn as nn import torch.nn.functional as F +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel import ( gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, ) from pydantic import create_model +from tqdm import tqdm from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.nas.plugins.megatron import ( @@ -52,12 +57,13 @@ _DynamicMambaMixer, _DynamicMCoreLanguageModel, _DynamicMLP, + _DynamicMoELayer, _DynamicSelfAttention, _DynamicSequentialMLP, _DynamicTransformerLayer, ) from modelopt.torch.nas.registry import DMRegistry -from modelopt.torch.nas.utils import get_subnet_config, sort_parameters +from modelopt.torch.nas.utils import get_subnet_config, sample, sort_parameters from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules from modelopt.torch.opt.conversion import ApplyModeError from modelopt.torch.opt.dynamic import DynamicModule, DynamicSpace @@ -70,7 +76,7 @@ from modelopt.torch.opt.searcher import BaseSearcher, SearchConfig, SearchStateDict from modelopt.torch.opt.utils import named_hparams from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import get_module_device, print_rank_0 +from modelopt.torch.utils import get_module_device, num2hrb, print_rank_0 from ..pruning import PruneModeRegistry @@ -99,6 +105,7 @@ "MCoreMinitronSearcher", "drop_mcore_language_model_layers", "get_mcore_minitron_config", + "get_mcore_param_count", ] @@ -120,7 +127,7 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i assert isinstance(model, supported_model_types), ( f"Model should have one of {supported_model_types} submodule, got {model}" ) - print_rank_0(f"Dropping layers {layers_to_drop} from {n} ({type(model)}).") + print_rank_0(f"Dropping decoder layers {layers_to_drop} from model.") # get the number of layers remaining in each pp rank layers_remaining_per_pp = torch.zeros( @@ -144,34 +151,42 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i new_num_layers = sum(layers_remaining_per_pp) # reindex kept layers, exclude sharded state dict for dropped layers - layer_offset = sum(layers_remaining_per_pp[: get_pipeline_model_parallel_rank()]) - layer_number = layer_offset + 1 - dropped_layers = [] + layer_number = sum(layers_remaining_per_pp[: get_pipeline_model_parallel_rank()]) + 1 + kept_layers = [] for layer in model.decoder.layers: - if layer.layer_number in layers_to_drop: - layer.layer_number = -1 # should not be used - # layer.sharded_state_dict = lambda prefix, sharded_offsets, metadata: {} - dropped_layers.append(layer) - else: + if layer.layer_number not in layers_to_drop: layer.layer_number = layer_number - layer.get_transformer_layer_offset = lambda: layer_offset layer_number += 1 - - # remove dropped layers from the modulelist - model.decoder.layers = nn.ModuleList( - [layer for layer in model.decoder.layers if layer.layer_number != -1] - ) - for layer in dropped_layers: - del layer + kept_layers.append(layer) + model.decoder.layers = nn.ModuleList(kept_layers) model.config.num_layers = new_num_layers +@dataclass +class CandidateSubnet: + ss_config: dict + params: float + score: float | None + + class MCoreMinitronSearcher(BaseSearcher): - """Searcher for Minitron pruning algorithm.""" + """Searcher for Minitron pruning algorithm. + + Available additional config options (used when `params` constraint is provided): + - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.40). + Only top (1 - max_width_pruning) choices will be considered. + - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.20). + Only top (1 - max_depth_pruning) choices will be considered. + - `hparams_to_skip`: List of hparams to skip during the search (default: None). + - `top_k`: Number of candidates to consider for score_func validation (default: 10). + """ activations_per_rank: list[dict[str, torch.Tensor]] layer_scores: dict[int, torch.Tensor] + sorted_layers: list[int] | None # 1-indexed sorted list of layer numbers + # Dict from params constraint to list of tuples (ss_config, params, score) + top_k_candidates_per_constraint: dict[float, list[CandidateSubnet]] @property def default_search_config(self) -> SearchConfig: @@ -181,17 +196,28 @@ def default_search_config(self) -> SearchConfig: "max_iter_data_loader": 1024, "skip_sorting": False, "scores_path": None, + # Additional search config for parameter-based pruning + "max_width_pruning": 0.40, + "max_depth_pruning": 0.20, + "hparams_to_skip": None, + "top_k": 10, } @property def default_state_dict(self) -> SearchStateDict: """Return default state dict for importance scores and activations from forward loop.""" - return {"activations_per_rank": [], "layer_scores": {}} + return { + "activations_per_rank": [], + "layer_scores": {}, + "sorted_layers": None, + "top_k_candidates_per_constraint": {}, + } def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: """Sanitize the search config dict.""" config = super().sanitize_search_config(config) - config["checkpoint"] = config["scores_path"] + if config["scores_path"]: + config["checkpoint"] = config["scores_path"] config["verbose"] = True # Print for all ranks return config @@ -200,53 +226,56 @@ def before_search(self) -> None: super().before_search() # Check that the constraint is valid - assert self.constraints.keys() == {"export_config"}, ( - "Only `export_config` constraint is supported for pruning!" - ) - - self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"]) - export_config = self.constraints["export_config"] - if "num_query_groups" in export_config: - warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.") - if export_config["num_query_groups"] != self.model.config.num_query_groups: # type: ignore[index] - raise ValueError(f"num_query_groups must be {self.model.config.num_query_groups}!") - export_config.pop("num_query_groups") # type: ignore[union-attr] - assert isinstance(export_config, dict) # to keep mypy happy - assert export_config.keys() <= SUPPORTED_HPARAMS, ( - f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}" - ) + assert len(self.constraints) == 1 and next(iter(self.constraints.keys())) in { + "export_config", + "params", + }, "Only `export_config` or `params` constraint is supported!" + + if "export_config" in self.constraints: + export_config = self.constraints["export_config"] + assert isinstance(export_config, dict) # to keep mypy happy + if "num_query_groups" in export_config: + warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.") + if export_config["num_query_groups"] != self.model.config.num_query_groups: + raise ValueError( + f"num_query_groups must be {self.model.config.num_query_groups}!" + ) + export_config.pop("num_query_groups") + assert export_config.keys() <= SUPPORTED_HPARAMS, ( + f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config=}" + ) - # Only sort the parameters that are to be pruned - # If a user only prunes depth, we should not sort width parameters - self.hps_to_sort = SUPPORTED_HPARAMS & export_config.keys() + # Only sort the parameters that are to be pruned + # If a user only prunes depth, we should not sort width parameters + self.hps_to_sort = set(export_config.keys()) + else: + assert isinstance(self.constraints["params"], (int, float)), "params must be a float!" + assert self.has_score, "score_func (e.g. MMLU) is required for parameter-based pruning!" + export_config = None + # Sort all parameters for parameter-based pruning + self.hps_to_sort = SUPPORTED_HPARAMS for n, hp in named_hparams(self.model, unique=True): hp_name = n.split(".")[-1] if hp.is_configurable: # Make sure configurable hparams are the ones with right names else implementation needs to be fixed! assert hp_name in SUPPORTED_HPARAMS, f"[ImplError] Invalid hparam {hp_name}!" - if hp_name in export_config: + if export_config is not None and hp_name in export_config: assert export_config[hp_name] in hp.choices, ( f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}" ) hp.reset_choices() # Make sure ConcatHparam choices are updated after modify() - def run_search(self) -> None: - """Run actual search.""" - # Run forward loop to collect activations and sort parameters - unwrapped_model = self.model - for m in self.model.modules(): - if isinstance(m, _DynamicMCoreLanguageModel): - unwrapped_model = m - break - assert isinstance(unwrapped_model, _DynamicMCoreLanguageModel), "Model not supported!" + assert isinstance(self.model, _DynamicMCoreLanguageModel), ( + "Input should be unwrapped MCore model!" + ) - registry = ImportanceEstimatorRegistry(unwrapped_model) + def run_search(self) -> None: + """Run forward loop to collect activations, sort parameters, and prune the model.""" + registry = ImportanceEstimatorRegistry(self.model) if self.layer_scores and self.activations_per_rank: # Available from checkpoint - print_rank_0("Loading activations and scores per rank from checkpoint...") registry.set_activations_and_layer_scores(self.activations_per_rank, self.layer_scores) elif not self.config["skip_sorting"]: - print_rank_0("Running forward loop...") assert self.forward_loop is not None is_training = self.model.training self.model.eval() @@ -264,36 +293,349 @@ def run_search(self) -> None: print_rank_0("Skipping sorting parameters...") else: sort_parameters(self.model, self.hps_to_sort, verbose=True) + registry.cleanup() + + if self.layer_scores: + # sort layers by scores and drop the lowest ones + self.sorted_layers = [ + layer + for layer, _ in sorted(self.layer_scores.items(), key=lambda x: x[1], reverse=True) + ] + assert sorted(self.sorted_layers) == list(range(1, self.model.config.num_layers + 1)) + else: + assert ( + self.constraints.keys() == {"export_config"} + and "num_layers" not in self.constraints["export_config"] + ), "Cannot prune `num_layers` without collecting layer scores!" + self.sorted_layers = None + + if "params" in self.constraints: + export_config = self.search_best_arch_by_params() + else: + export_config = self.constraints["export_config"] + + # Prune homogeneously + self._prune(export_config, prune_depth=True) + + # TODO: Rename to hybrid_layer_pattern after https://github.com/NVIDIA/Megatron-LM/pull/3377 + # Update hybrid_override_pattern if pruning is done on a hybrid model + if isinstance(self.model, MambaModel): + print_rank_0(f"Original hybrid_override_pattern: {self.model.hybrid_override_pattern}") + new_num_layers = self.model.config.num_layers + assert self.sorted_layers is not None + kept_layers_numbers = self.sorted_layers[:new_num_layers] + self.model.hybrid_override_pattern = "".join( + c + for i, c in enumerate(self.model.hybrid_override_pattern) + if i + 1 in kept_layers_numbers + ) + print_rank_0(f"Pruned hybrid_override_pattern: {self.model.hybrid_override_pattern}") + def _prune( + self, + export_config: dict, + prune_depth: bool = True, + ) -> None: + """Prune the model homogeneously based on the export_config by setting active choices for configurable hparams. + + Args: + export_config: Dictionary mapping hyperparameter names to their pruned values. + prune_depth: Whether to drop layers based on sorted_layers (default: True). + """ # Prune homogeneously - export_config = self.constraints["export_config"] - assert isinstance(export_config, dict) # to keep mypy happy for n, hp in named_hparams(self.model, configurable=True): hp_name = n.split(".")[-1] if hp_name in export_config: hp.active = export_config[hp_name] # Drop layers if depth pruning is enabled - num_layers_hp = unwrapped_model.get_hparam("num_layers") - if num_layers_hp.active != num_layers_hp.max: - # sort layers by scores and drop the lowest ones - sorted_layers = sorted(self.layer_scores.items(), key=lambda x: x[1], reverse=True) - layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]] # type: ignore[misc] - drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop) - - # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads - model_cfg = self.model.config - orig_kv_channels = getattr(model_cfg, "kv_channels") - if orig_kv_channels is None: - orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr( - model_cfg, "num_attention_heads" + if prune_depth: + num_layers_hp = self.model.get_hparam("num_layers") + if num_layers_hp.active != num_layers_hp.max: + assert self.sorted_layers is not None + layers_to_drop = self.sorted_layers[num_layers_hp.active :] + drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop) + + # Update model config with pruned architecture + # kv_channels can be None so we need to save from original hidden_size and num_attention_heads + if self.model.config.kv_channels is None: + self.model.config.kv_channels = ( + self.model.config.hidden_size // self.model.config.num_attention_heads ) - setattr(model_cfg, "kv_channels", orig_kv_channels) - for n in SUPPORTED_HPARAMS: - if n in export_config: - setattr(model_cfg, n, export_config[n]) + # num_query_groups can be None so we need to save from original num_attention_heads + if self.model.config.num_query_groups is None: + self.model.config.num_query_groups = self.model.config.num_attention_heads + # moe_ffn_hidden_size can be None so we need to save from original ffn_hidden_size + if ( + self.model.config.moe_ffn_hidden_size is None + and self.model.config.num_moe_experts is not None + ): + self.model.config.moe_ffn_hidden_size = self.model.config.ffn_hidden_size + # Now set hparam active choices + for hp_name, hp_value in export_config.items(): + setattr(self.model.config, hp_name, hp_value) + + # Reinitialize the MoE token dispatcher after pruning + for m in self.model.modules(): + if isinstance(m, _DynamicMoELayer): + m._export_reinit_token_dispatcher() + break - registry.cleanup() + def search_best_arch_by_params(self) -> dict: + """Search for the best architecture based on the given parameters constraints. + + We perform a grid-search over the search space to find subnets (homogeneous) fitting the constraints. + Top-k candidates (sorted by param count) are then validated using the score_func (e.g. MMLU) + and the best subnet is returned. + + Returns: + export_config: Dictionary mapping hyperparameter names to their pruned values. + """ + assert self.sorted_layers is not None + max_params = float(self.constraints["params"]) # type: ignore[arg-type] + max_width_pruning = self.config["max_width_pruning"] + max_depth_pruning = self.config["max_depth_pruning"] + hparams_to_skip = self.config["hparams_to_skip"] + top_k = self.config["top_k"] + print_rank_0( + f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..." + ) + + # 1. Find available search space choices (across all PP ranks) + hp_choices = {} + for n, hp in named_hparams(self.model, configurable=True): + hp_name = n.split(".")[-1] + hp_choices[hp_name] = hp.choices + pp_group = dist.DistributedProcessGroup(get_pipeline_model_parallel_group()) + hp_choices = dist.DistributedProcessGroup.get_dist_syncd_obj( + hp_choices, + pp_group, + op=lambda all_pp_search_spaces: { + k: v for d in all_pp_search_spaces for k, v in d.items() + }, + ) + + # 2. Perform grid-search over the search space to find subnets fitting the constraints + if ( + max_params not in self.top_k_candidates_per_constraint + or len(self.top_k_candidates_per_constraint[max_params]) != top_k + ): + max_num_layers = self.model.get_hparam("num_layers").max + search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( + hp_choices, + max_width_pruning, + max_depth_pruning, + hparams_to_skip, + ) + sample(self.model, sample_func=max) # reset to max subnet (for sanity) + selected = [] + for ss_config in tqdm( + search_space_configs, + desc=f"Finding top {top_k} (`config['top_k']`) candidates fitting the constraints...", + disable=not dist.is_master(), + ): + self._prune(ss_config, prune_depth=False) + layer_ids = None + if "num_layers" in ss_config and ss_config["num_layers"] < max_num_layers: + layer_ids = self.sorted_layers[: ss_config["num_layers"]] + candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids) + if candidate_params <= max_params: + selected.append(CandidateSubnet(ss_config, candidate_params, None)) + sample(self.model, sample_func=max) # reset to max subnet + assert len(selected) > 0, "No subnets found fitting the constraints!" + print_rank_0(f"Found {len(selected)} candidates fitting the constraints!") + self.top_k_candidates_per_constraint[max_params] = sorted( + selected, key=lambda x: x.params, reverse=True + )[:top_k] + self.save_search_checkpoint(verbose=True) + else: + print_rank_0(f"\nUsing top {top_k} candidates from checkpoint") + top_k_candidates = self.top_k_candidates_per_constraint[max_params] + + print_rank_0(f"\n====================\nTop {top_k} candidates:") + for candidate in top_k_candidates: + print_rank_0(f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params") + print_rank_0("====================\n") + + # 3. Optional Knowledge Distillation (KD) step for all top-k candidates + print_rank_0( + "\nSkipping optional Knowledge Distillation (KD) step for candidates as it is a manual step. " + "As per the original paper (https://arxiv.org/pdf/2407.14679), ideally we need to perform a short " + f"Knowledge Distillation on ~2B tokens for all top {top_k} candidates before evaluating the " + "`score_func`, which will take a lot longer to prune, require splitting the pruning process into multiple " + "stages and a lot more compute for pruning but can lead to better pruned model selection. If you are " + f"interested to do this, you can take the top {top_k} candidates' `export_config` from the logs above and " + "then export all models separately and perform Knowledge Distillation on each of them before evaluating " + "the `score_func`.\n" + ) + + # 4. Validate top-k candidates using the score_func and return the best subnet + for candidate in tqdm( + top_k_candidates, + desc=f"Validating top {top_k} candidates on given score_func (this will take some time)...", + disable=not dist.is_master(), + smoothing=0.7, + ): + if candidate.score is None: # not restored from checkpoint + all_layers = self.model.decoder.layers + start_layer_number = all_layers[0].layer_number + + self._prune(candidate.ss_config, prune_depth=True) + candidate.score = self.eval_score(silent=False) + self.save_search_checkpoint(verbose=False) + + # reset to max subnet and revert dropped layers + sample(self.model, sample_func=max) + for layer in all_layers: + layer.layer_number = start_layer_number + start_layer_number += 1 + self.model.decoder.layers = all_layers + print_rank_0( + f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score\n" + ) + + print_rank_0(f"\n====================\nTop {top_k} candidates with scores:") + for candidate in top_k_candidates: + print_rank_0( + f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score" + ) + print_rank_0("====================\n") + + dist.barrier() + best = max(top_k_candidates, key=lambda x: x.score) # type: ignore[arg-type, return-value] + print_rank_0( + f"\n[BEST SUBNET] {best.ss_config} -> {num2hrb(best.params)} params, {best.score:.4f} score\n" + ) + return best.ss_config + + @staticmethod + def _generate_search_space_combos( + search_space: dict[str, list], + max_width_pruning: float = 0.40, + max_depth_pruning: float = 0.20, + hparams_to_skip: list[str] | None = None, + ) -> list[dict[str, Any]]: + """Generate all possible combinations of hyperparameters from the search space. + + Args: + search_space: Dictionary mapping hyperparameter names to their possible sorted choices. + Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]} + max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.40). + Only top (1 - max_width_pruning) choices will be considered. + max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.20). + Only top (1 - max_depth_pruning) choices will be considered. + hparams_to_skip: List of hparams to skip during the search (default: None). + + Returns: + List of configuration dictionaries, where each dictionary maps hyperparameter + names to their chosen values. Example: + [ + {"hidden_size": 1024, "num_layers": 1}, + {"hidden_size": 1024, "num_layers": 2}, + ... + {"hidden_size": 4096, "num_layers": 32}, + ] + """ + print_rank_0( + f"\nOnly considering atmost {(max_width_pruning * 100):.0f}% for width and " + f"{max_depth_pruning * 100:.0f}% for depth pruning hparams" + ) + + if hparams_to_skip: + search_space = dict(search_space) # Avoid modifying the original search space + print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...") + for hparam in hparams_to_skip: + if hparam in search_space: + search_space.pop(hparam) + else: + warn(f"Hparam {hparam} not found in search space! Skipping...") + + filtered_ss = { + k: ( + sorted(v)[int((1 - max_depth_pruning) * len(v)) :] + if k == "num_layers" + else sorted(v)[int((1 - max_width_pruning) * len(v)) :] + ) + for k, v in search_space.items() + if len(v) > 1 + } + + ss_size = 1 + for k, v in filtered_ss.items(): + print_rank_0(f"\tSearch space for {k}: {v}") + ss_size *= len(v) + print_rank_0(f"\tTotal search space in consideration: {ss_size}\n") + + hparam_names = list(filtered_ss.keys()) + hparam_choices_lists = [filtered_ss[name] for name in hparam_names] + + search_space_combos = [ + dict(zip(hparam_names, choices)) for choices in product(*hparam_choices_lists) + ] + assert len(search_space_combos) == ss_size + + return search_space_combos + + +def get_mcore_param_count(model: GPTModel | MambaModel) -> float: + """Get the number of parameters in the MCore GPTModel or MambaModel (reduced across TP and PP ranks).""" + assert isinstance(model, (GPTModel, MambaModel)), "Model must be a GPTModel or MambaModel" + if isinstance(model, DynamicModule): + return _param_num_dynamic(model) + else: + return _param_num(model) + + +def _param_num(model: GPTModel | MambaModel) -> float: + """Get the number of parameters in the model (reduced across TP and PP ranks).""" + # Dont double count output_layer parameters if model.share_embeddings_and_output_weights is True + params = sum( + p.numel() + for name, p in model.named_parameters() + if not model.share_embeddings_and_output_weights or "output_layer.weight" not in name + ) + + reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) + torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) + return reduced_params.item() + + +def _param_num_dynamic( + model: _DynamicMCoreLanguageModel, *, layer_numbers_to_count: list[int] | None = None +) -> float: + """Get the number of parameters in the Dynamic Module (reduced across TP and PP ranks). + + Args: + model: GPTModel or MambaModel converted to a DynamicModule. + layer_numbers_to_count: If specified, only count the parameters of the given layer numbers (1-indexed). + Only needed when input is a DynamicModule to correctly count the parameters of the active layers. + """ + + # NOTE: model.parameters() doesnt consider active_slice so we dont get sorted or trimmed parameters! + def get_param_count(mod, name) -> int: + """Use getattr to access parameters correctly.""" + module_path, _, param_name = name.rpartition(".") + submodule = mod.get_submodule(module_path) if module_path else mod + return getattr(submodule, param_name).numel() + + # Account for depth pruning with uneven PP and hybrid models! + # Dont double count output_layer parameters if model.share_embeddings_and_output_weights is True + params = sum( + get_param_count(model, name) + for name, _ in model.named_parameters() + if ("decoder.layers." not in name or layer_numbers_to_count is None) + and not (model.share_embeddings_and_output_weights and "output_layer.weight" in name) + ) + if layer_numbers_to_count is not None: + for layer in model.decoder.layers: + if layer.layer_number in layer_numbers_to_count: + params += sum(get_param_count(layer, name) for name, _ in layer.named_parameters()) + + reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) + torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) + return reduced_params.item() MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model( @@ -302,17 +644,19 @@ def run_search(self) -> None: registry=DMRegistry, default_rules={ "megatron.core.models.gpt.GPTModel": { - "hidden_size_divisor": 64, - "ffn_hidden_size_divisor": 64, - "num_moe_experts_divisor": 1, + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 512, + "num_moe_experts_divisor": 8, + "num_layers_divisor": 2, }, **( { "megatron.core.models.mamba.MambaModel": { - "hidden_size_divisor": 64, - "ffn_hidden_size_divisor": 64, - "mamba_head_dim_divisor": 4, - "num_moe_experts_divisor": 1, + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 512, + "mamba_head_dim_divisor": 8, + "num_moe_experts_divisor": 8, + "num_layers_divisor": 2, } } if HAS_MAMBA @@ -325,23 +669,30 @@ def run_search(self) -> None: def get_mcore_minitron_config( - channel_divisor: int = 64, - mamba_head_dim_divisor: int = 4, - num_moe_experts_divisor: int = 1, + *, + hidden_size_divisor: int = 256, + ffn_hidden_size_divisor: int = 512, + mamba_head_dim_divisor: int = 8, + num_moe_experts_divisor: int = 8, + num_layers_divisor: int = 2, ) -> ModeloptBaseConfig: - """Get a MCoreMinitronConfig with the given channel divisor instead of default.""" + """Get a MCoreMinitronConfig with the given divisors instead of default.""" config = MCoreMinitronConfig() def _set_divisors(c): for k, v in c.items(): if isinstance(v, dict): _set_divisors(v) - elif k in ["hidden_size_divisor", "ffn_hidden_size_divisor"]: - c[k] = channel_divisor + elif k == "hidden_size_divisor": + c[k] = hidden_size_divisor + elif k == "ffn_hidden_size_divisor": + c[k] = ffn_hidden_size_divisor elif k == "mamba_head_dim_divisor": c[k] = mamba_head_dim_divisor elif k == "num_moe_experts_divisor": c[k] = num_moe_experts_divisor + elif k == "num_layers_divisor": + c[k] = num_layers_divisor _set_divisors(config) return config @@ -524,11 +875,14 @@ def get_layer_scores(self) -> dict[int, torch.Tensor]: layer_scores = {} for layer in self.model.decoder.layers: layer_scores[layer.layer_number] = layer._scores - all_pp_layer_scores = [None] * get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object( - all_pp_layer_scores, layer_scores, group=get_pipeline_model_parallel_group() + pp_group = dist.DistributedProcessGroup(get_pipeline_model_parallel_group()) + layer_scores = dist.DistributedProcessGroup.get_dist_syncd_obj( + layer_scores, + pp_group, + op=lambda all_pp_layer_scores: { + k: v for d in all_pp_layer_scores for k, v in d.items() + }, ) - layer_scores = {k: v for d in all_pp_layer_scores for k, v in d.items()} # type: ignore[attr-defined] print_rank_0(f"Layerwise scores (1-indexed, higher is better): {layer_scores}") assert sorted(layer_scores.keys()) == list(range(1, num_layers_hp.max + 1)) # type: ignore[arg-type] @@ -562,6 +916,7 @@ def set_activations_and_layer_scores( activations_per_rank: List of dicts from module name to activations. Should match PP size. layer_scores: Dict from layer_number (1-indexed) to score. """ + print_rank_0("Loading activations and scores per rank from checkpoint...") rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() assert len(activations_per_rank) == pp_size, ( diff --git a/modelopt/torch/prune/pruning.py b/modelopt/torch/prune/pruning.py index cdc4e7d8f8..50a4850ea7 100644 --- a/modelopt/torch/prune/pruning.py +++ b/modelopt/torch/prune/pruning.py @@ -78,7 +78,7 @@ def prune( constraints = {"params": "60%"} # Specify export_config with pruned hyperparameters - # This is supported and required if the model is converted via ``mcore_minitron`` mode. + # This is supported only if the model is converted via ``mcore_minitron`` mode. constraints = { "export_config": { "ffn_hidden_size": 128, diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter.py b/modelopt/torch/puzzletron/anymodel/converter/converter.py index 67ed74ed9c..5fdc92718c 100644 --- a/modelopt/torch/puzzletron/anymodel/converter/converter.py +++ b/modelopt/torch/puzzletron/anymodel/converter/converter.py @@ -27,6 +27,7 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm from transformers import PretrainedConfig +from transformers.integrations.mxfp4 import convert_moe_packed_tensors from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig @@ -61,9 +62,9 @@ def _get_weight_map(input_dir: Path) -> Dict[str, str]: f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." ) - @staticmethod + @classmethod def convert_model_weights( - input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int + cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int ): """Convert model weights to subblock format.""" param_to_file = Converter._get_weight_map(input_dir) @@ -95,7 +96,21 @@ def convert_model_weights( data = load_file(os.path.join(input_dir, file)) for name in param_names: if param_to_file[name] == file and name in data: - tensors[name] = data[name] + converted_name = cls.convert_weight_name(name) + # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b + if getattr(cls, "quantized", None) == "mxfp4": + if name.endswith("_blocks"): + converted_name = converted_name.replace("_blocks", "") + tensors[converted_name] = convert_moe_packed_tensors( + data[converted_name + "_blocks"], + data[converted_name + "_scales"], + ) + elif name.endswith("_scales"): + continue + else: + tensors[converted_name] = data[name] + else: + tensors[converted_name] = data[name] # Save this subblock print(f"\n✅ Group: {subblock} ({len(tensors)} layers)") @@ -195,3 +210,26 @@ def create_block_configs_from_main_config(config: PretrainedConfig) -> List[Bloc return [BlockConfig(...) for layer_idx in range(num_layers)] """ raise NotImplementedError + + @staticmethod + def convert_weight_name(name: str) -> str: + """ + Convert weight names during checkpoint conversion. + + This method can be overridden by subclasses to apply model-specific weight name + transformations when converting checkpoints from HuggingFace format to Puzzletron format. + + Default implementation returns the name unchanged (identity function). + + Args: + name: Original weight name from HuggingFace checkpoint + + Returns: + Converted weight name for Puzzletron format + + Example: + For Qwen2.5-VL, this converts: + - visual.* → model.visual.* + - model.* → model.language_model.* + """ + return name diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py index f7352993da..23a42da581 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -17,10 +17,46 @@ import inspect from typing import Callable, Type +from transformers import AutoConfig + from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor import ModelDescriptor __all__ = ["ModelDescriptorFactory"] +# Map from HuggingFace config.model_type (in checkpoint config.json) to ModelDescriptorFactory name. +# Local to this script; add entries when supporting new model types for auto-detection. +_MODEL_TYPE_TO_DESCRIPTOR = { + "llama": "llama", + "mistral": "mistral_small", + "qwen2": "qwen2", + "qwen3": "qwen3", + "nemotron_h": "nemotron_h", + "nemotron_h_v2": "nemotron_h_v2", + "gpt_oss_20b": "gpt_oss_20b", +} + + +def resolve_descriptor_from_pretrained(pretrained: str | None, trust_remote_code: bool = True): + """Resolve the model descriptor by loading the checkpoint config and mapping model_type.""" + if not pretrained: + raise ValueError("pretrained must be provided") + + config = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code) + model_type = getattr(config, "model_type", None) + + if model_type and model_type in _MODEL_TYPE_TO_DESCRIPTOR: + detected = _MODEL_TYPE_TO_DESCRIPTOR[model_type] + print( + f"[resolve_descriptor_from_pretrained] Auto-detected model_type='{model_type}' → descriptor='{detected}'" + ) + return ModelDescriptorFactory.get(detected) + + known = sorted(_MODEL_TYPE_TO_DESCRIPTOR.keys()) + raise ValueError( + f"Cannot auto-detect descriptor for model_type='{model_type}'. " + f"Known model types: {known}. Add this model_type to _MODEL_TYPE_TO_DESCRIPTOR if supported." + ) + class ModelDescriptorFactory: """Factory for registering and retrieving ModelDescriptor classes.""" diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py index b7e83dceca..d3c3d6cf6e 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py @@ -37,6 +37,8 @@ class GptOss20bConverter(Converter): All layers use MoE FFN (no standard dense FFN layers). """ + quantized = "mxfp4" + @staticmethod def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: """Create block configs for GPT-OSS-20B layers. diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py index fd5edc0636..85a6f139d5 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py @@ -42,6 +42,7 @@ # Production models use MXFP4 quantized MoE with combined tensors # (gate_up_proj_blocks, down_proj_blocks), which is not yet supported. from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn +from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock @ModelDescriptorFactory.register_decorator("gpt_oss_20b") @@ -50,6 +51,13 @@ class GptOss20bModelDescriptor(ModelDescriptor): _DECODER_LAYER_CLS: Type[nn.Module] = None + @classmethod + def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module: + dummy_block = DummyBlock(block_index=block_index) + # Required by `GptOssModel.forward`. + dummy_block.attention_type = original_layer.attention_type + return dummy_block + @staticmethod def decoder_layer_cls(): """Get the decoder layer class for GPT-OSS models. @@ -132,7 +140,7 @@ def build_ffn_predicates() -> Dict[str, re.Pattern]: r"(post_attention_layernorm\.weight" r"|mlp\.router\.weight" r"|mlp\.router\.bias" - r"|mlp\.experts\.((\d+\.)?(gate_up_proj|down_proj)(\.(weight|bias|blocks|scales))?|gate_up_proj_(bias|blocks|scales)|down_proj_(bias|blocks|scales)))$" + r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$" ) for layer_idx in range(num_layers) } @@ -190,12 +198,15 @@ class GptOss20bExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mlp" moe_prefix_name: str = "model.layers.{layer_idx}.mlp" - expert_prefix_name: str = "experts.{expert_idx}" + expert_prefix_name: str = "experts" # Router has both weight and bias router_weights: List[str] = field(default_factory=lambda: ["router.weight"]) router_biases: List[str] = field(default_factory=lambda: ["router.bias"]) + # Fused format: experts stored as single tensors + is_fused_experts: bool = True + # Fused format: single tensors containing all experts (test models) fused_expert_weights: List[str] = field( default_factory=lambda: [ @@ -212,5 +223,16 @@ class GptOss20bExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"] ) - # Fused format: experts stored as single tensors - is_fused_experts: bool = True + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_class_name = "GptOssTopKRouter" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if ( + module_name.endswith(self.target_name) + and module.__class__.__name__ == target_class_name + ): + module_names_to_hook.append( + (self.block_idx_from_module_name(module_name), module_name) + ) + return module_names_to_hook diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py new file mode 100644 index 0000000000..8500abad71 --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py @@ -0,0 +1,548 @@ +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c + +# MIT License +# +# Copyright (c) 2020 EleutherAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Create a HuggingFace checkpoint with MXFP4 MoE weights from the original gpt-oss-120b model. + +This script: +1. Copies non-MoE weights from the student model (trained attention, embeddings, etc.) +2. Extracts MoE expert weights from the original gpt-oss-120b in MXFP4 format +3. Deduces expert mappings by comparing weights +4. Outputs a new pruned (heterogeneous) checkpoint with PACKED MXFP4 expert weights +""" + +import argparse +import json +import os +import shutil +from typing import Any, Dict, List, Optional, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from tqdm import tqdm +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + + +def deduce_experts_for_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, +) -> Tuple[List[int], int, int]: + """ + Deduce which original experts match the student experts by comparing weights. + + Compares dequantized MXFP4 weights from the original model against the student + model's BF16 weights using L2 distance. Finds the best 1-to-1 matching. + + Args: + layer: Layer index + original_path: Path to original model + original_index: Original model's safetensors index + student_path: Path to student model + num_student_experts: Number of experts in student model (if None, auto-detect) + + Returns: + Tuple of (expert_indices, num_student_experts, num_original_experts) + """ + # Load original tensors + orig_tensors = load_layer_tensors(original_path, layer, original_index) + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + + num_original_experts = mlp1_blocks.shape[0] + + # Load student tensors + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + if not os.path.exists(student_ffn): + print(f"FFN file not found at {student_ffn} - fallback to no_op") + return [], 0, num_original_experts + + student_experts = {} + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + if "experts" in key or "router" in key: + student_experts[key] = f.get_tensor(key) + + # Auto-detect number of student experts + num_student_experts = student_experts[f"model.layers.{layer}.mlp.experts.gate_up_proj"].size(0) + print( + f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts" + ) + + # Pre-dequantize all original experts once (optimization) + print(f" Pre-dequantizing {num_original_experts} original experts...") + deqexpert_mlp1 = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu() + deqexpert_mlp2 = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu() + original_experts_dequant = [] + for orig_idx in range(num_original_experts): + original_experts_dequant.append( + {"up": deqexpert_mlp1[orig_idx], "down": deqexpert_mlp2[orig_idx]} + ) + + # For each student expert, find best matching original expert + experts_to_keep = [] + used_original_indices = set() + + # Number of values to use for quick comparison (tune this) + quick_compare_size = 8 + # Number of candidates to keep for full comparison + top_k_candidates = min(10, num_original_experts) + + for student_idx in range(num_student_experts): + # Get student expert weights + prefix = f"model.layers.{layer}.mlp" + student_up = student_experts.get(f"{prefix}.experts.gate_up_proj")[student_idx] + student_down = student_experts.get(f"{prefix}.experts.down_proj")[student_idx] + + # if student_gate is None or student_up is None or student_down is None: + if student_up is None or student_down is None: + raise ValueError( + f"Missing student expert weights for layer {layer} expert {student_idx}" + ) + + # Step 1: Quick filtering using first N values + candidate_scores = [] + for orig_idx in range(num_original_experts): + if orig_idx in used_original_indices: + continue + + orig_expert = original_experts_dequant[orig_idx] + + up_quick = ( + ( + orig_expert["up"].flatten()[:quick_compare_size] + - student_up.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + down_quick = ( + ( + orig_expert["down"].flatten()[:quick_compare_size] + - student_down.float().flatten()[:quick_compare_size] + ) + .pow(2) + .mean() + .sqrt() + ) + + quick_score = (up_quick + down_quick) / 2.0 + candidate_scores.append((orig_idx, quick_score.item())) + + # Step 2: Get top-k candidates based on quick comparison + candidate_scores.sort(key=lambda x: x[1]) + top_candidates = [idx for idx, _ in candidate_scores[:top_k_candidates]] + + # Step 3: Full comparison only on top candidates + best_match_idx = None + best_match_score = float("inf") + + for orig_idx in top_candidates: + orig_expert = original_experts_dequant[orig_idx] + + # Full comparison across all values + up_diff = (orig_expert["up"] - student_up.float()).pow(2).mean().sqrt() + down_diff = (orig_expert["down"] - student_down.float()).pow(2).mean().sqrt() + + score = (up_diff + down_diff) / 2.0 + + if score < best_match_score: + best_match_score = score + best_match_idx = orig_idx + + if best_match_idx is None: + raise ValueError( + f"Could not find match for student expert {student_idx} in layer {layer}" + ) + + experts_to_keep.append(best_match_idx) + used_original_indices.add(best_match_idx) + print( + f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})" + ) + + return experts_to_keep, num_student_experts, num_original_experts + + +def load_original_index(path: str) -> Dict[str, Any]: + """Load the original model's safetensors index.""" + with open(path, "r") as f: + return json.load(f) + + +def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]: + """Load all MoE-related tensors for a layer, potentially from multiple files.""" + keys_to_load = [ + f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks", + f"model.layers.{layer}.mlp.experts.gate_up_proj_scales", + f"model.layers.{layer}.mlp.experts.gate_up_proj_bias", + f"model.layers.{layer}.mlp.experts.down_proj_blocks", + f"model.layers.{layer}.mlp.experts.down_proj_scales", + f"model.layers.{layer}.mlp.experts.down_proj_bias", + f"model.layers.{layer}.mlp.router.weight", # Router weight + f"model.layers.{layer}.mlp.router.bias", # Router bias + ] + + # Group by file + file_to_keys = {} + for key in keys_to_load: + if key in index["weight_map"]: + filename = index["weight_map"][key] + if filename not in file_to_keys: + file_to_keys[filename] = [] + file_to_keys[filename].append(key) + + # Load from each file + tensors = {} + for filename, keys in file_to_keys.items(): + filepath = os.path.join(original_path, filename) + with safe_open(filepath, framework="pt") as f: + for key in keys: + tensors[key] = f.get_tensor(key) + + return tensors + + +def copy_non_moe_weights(student_path: str, output_path: str, num_layers: int) -> Dict[str, str]: + """ + Copy non-MoE weights from student model. + Returns weight_map for the new index. + """ + weight_map = {} + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + os.makedirs(subblocks_dir, exist_ok=True) + + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Copy embeddings + src_emb = os.path.join(student_subblocks, "embeddings.safetensors") + dst_emb = os.path.join(subblocks_dir, "embeddings.safetensors") + shutil.copy2(src_emb, dst_emb) + with safe_open(src_emb, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/embeddings.safetensors" + + # Copy lm_head + src_head = os.path.join(student_subblocks, "lm_head.safetensors") + dst_head = os.path.join(subblocks_dir, "lm_head.safetensors") + shutil.copy2(src_head, dst_head) + with safe_open(src_head, framework="pt") as f: + for key in f.keys(): + weight_map[key] = "subblocks_safetensors/lm_head.safetensors" + + # Copy attention blocks + for layer in range(num_layers): + src_attn = os.path.join(student_subblocks, f"block_{layer}_attention.safetensors") + dst_attn = os.path.join(subblocks_dir, f"block_{layer}_attention.safetensors") + shutil.copy2(src_attn, dst_attn) + with safe_open(src_attn, framework="pt") as f: + for key in f.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_attention.safetensors" + + return weight_map + + +def process_single_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, + output_path: str, + experts_to_keep: List[int], +) -> Tuple[Dict[str, str], List[str]]: + """ + Process a single layer - loads tensors from potentially multiple files. + Returns (weight_map, verification_errors). + """ + weight_map = {} + verification_errors = [] + subblocks_dir = os.path.join(output_path, "subblocks_safetensors") + student_subblocks = os.path.join(student_path, "subblocks_safetensors") + + # Load all tensors for this layer (may come from multiple files) + orig_tensors = load_layer_tensors(original_path, layer, original_index) + + # Load student FFN file + student_ffn = os.path.join(student_subblocks, f"block_{layer}_ffn.safetensors") + + tensors_to_save = {} + student_tensors = {} + + with safe_open(student_ffn, framework="pt") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if "experts" not in key and "router" not in key: + # Copy norm weights + tensors_to_save[key] = tensor + + # Get router from original model, sliced to kept experts + orig_router_weight = orig_tensors[f"model.layers.{layer}.mlp.router.weight"] + orig_router_bias = orig_tensors[f"model.layers.{layer}.mlp.router.bias"] + + kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long) + sliced_router_weight = orig_router_weight[kept_indices_tensor] + sliced_router_bias = orig_router_bias[kept_indices_tensor] + + tensors_to_save[f"model.layers.{layer}.mlp.router.weight"] = sliced_router_weight + tensors_to_save[f"model.layers.{layer}.mlp.router.bias"] = sliced_router_bias + + # Get MoE tensors + mlp1_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] + mlp1_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] + mlp2_blocks = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] + mlp2_scales = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_scales"] + mlp1_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] + mlp2_bias = orig_tensors[f"model.layers.{layer}.mlp.experts.down_proj_bias"] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_blocks"] = mlp1_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_scales"] = mlp1_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.gate_up_proj_bias"] = mlp1_bias[ + kept_indices_tensor + ] + + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_blocks"] = mlp2_blocks[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_scales"] = mlp2_scales[ + kept_indices_tensor + ] + tensors_to_save[f"model.layers.{layer}.mlp.experts.down_proj_bias"] = mlp2_bias[ + kept_indices_tensor + ] + + # Save the FFN file + output_file = os.path.join(subblocks_dir, f"block_{layer}_ffn.safetensors") + save_file(tensors_to_save, output_file) + + # Build weight map + for key in tensors_to_save.keys(): + weight_map[key] = f"subblocks_safetensors/block_{layer}_ffn.safetensors" + + return weight_map, verification_errors + + +def copy_config_files(student_path: str, output_path: str): + """Copy configuration files from student model and update config.json.""" + files_to_copy = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "chat_template.jinja", + ] + + # Also copy transformers compatibility files + if os.path.exists(student_path): + for f in os.listdir(student_path): + if f.startswith("transformers_"): + files_to_copy.append(f) + + for filename in files_to_copy: + src = os.path.join(student_path, filename) + dst = os.path.join(output_path, filename) + + # Try student path first + if os.path.exists(src): + try: + shutil.copy2(src, dst) + continue + except PermissionError: + pass + + # If we get here, file doesn't exist or permission denied + if not os.path.exists(dst): + print(f" Warning: Could not copy {filename}") + + # Update config.json for DeciGptOssForCausalLM with MXFP4 + src_config = os.path.join(student_path, "config.json") + if not os.path.exists(src_config): + raise FileNotFoundError(f"config.json not found at {src_config}") + + with open(src_config, "r") as f: + config = json.load(f) + + # Set architecture to DeciGptOssForCausalLM for MXFP4 support + config["architectures"] = ["DeciGptOssForCausalLM"] + + # Add quantization_config so vllm calls _load_weights_mxfp4 + config["quantization_config"] = { + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + } + + dst_config = os.path.join(output_path, "config.json") + with open(dst_config, "w") as f: + json.dump(config, f, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description="Create MXFP4 checkpoint from student model") + parser.add_argument( + "--student-path", type=str, required=True, help="Path to student model checkpoint" + ) + parser.add_argument( + "--original-path", + type=str, + required=True, + help="Path to original gpt-oss-120b model with MXFP4 weights", + ) + parser.add_argument( + "--output-path", type=str, required=True, help="Output path for the new checkpoint" + ) + parser.add_argument("--num-layers", type=int, default=36, help="Number of transformer layers") + args = parser.parse_args() + + print(f"Creating MXFP4 checkpoint...") + print(f" Student model: {args.student_path}") + print(f" Original model: {args.original_path}") + print(f" Output: {args.output_path}") + + # Load original model index + original_index = load_original_index( + os.path.join(args.original_path, "model.safetensors.index.json") + ) + + print("\nDeducing expert mappings by comparing weights...") + experts_to_keep = [] + layer_statistics = [] # Store (num_student, num_original) for each layer + + for layer in range(args.num_layers): + layer_experts, num_student, num_original = deduce_experts_for_layer( + layer, + args.original_path, + original_index, + args.student_path, + ) + experts_to_keep.append(layer_experts) + layer_statistics.append((num_student, num_original)) + + # Print statistics + print(f"\n{'=' * 70}") + print("EXPERT DEDUCTION STATISTICS") + print(f"{'=' * 70}") + print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}") + print(f"{'-' * 70}") + + total_student = 0 + total_original = 0 + for layer, (num_student, num_original) in enumerate(layer_statistics): + percentage = (num_student / num_original * 100) if num_original > 0 else 0 + print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}") + total_student += num_student + total_original += num_original + + print(f"{'-' * 70}") + avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0 + print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}") + print(f"{'=' * 70}") + print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers") + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + os.makedirs(os.path.join(args.output_path, "subblocks_safetensors"), exist_ok=True) + + # Copy config files + print("Copying configuration files...") + copy_config_files(args.student_path, args.output_path) + + # Save experts_to_keep.json + experts_to_keep_output = os.path.join(args.output_path, "experts_to_keep.json") + with open(experts_to_keep_output, "w") as f: + json.dump(experts_to_keep, f, indent=2) + print(f" Saved experts_to_keep mapping to {experts_to_keep_output}") + + # Copy non-MoE weights (embeddings, attention, lm_head) + print("Copying non-MoE weights...") + weight_map = copy_non_moe_weights(args.student_path, args.output_path, args.num_layers) + + # Load weights per layer (handles multi-file loading) + print(f"Processing {args.num_layers} layers...") + + all_verification_errors = [] + + # Process each layer + for layer in tqdm(range(args.num_layers), desc="Processing layers"): + if len(experts_to_keep[layer]) == 0: + print(f"Layer {layer} has no experts to keep - ffn->no_op") + continue + layer_weight_map, layer_errors = process_single_layer( + layer, + args.original_path, + original_index, + args.student_path, + args.output_path, + experts_to_keep[layer], + ) + weight_map.update(layer_weight_map) + all_verification_errors.extend(layer_errors) + + # Calculate total size + total_size = 0 + subblocks_dir = os.path.join(args.output_path, "subblocks_safetensors") + for filename in os.listdir(subblocks_dir): + filepath = os.path.join(subblocks_dir, filename) + total_size += os.path.getsize(filepath) + + # Create model.safetensors.index.json + index = {"metadata": {"total_size": total_size}, "weight_map": weight_map} + + index_path = os.path.join(args.output_path, "model.safetensors.index.json") + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + print(f"\nCheckpoint created successfully at: {args.output_path}") + print(f"Total size: {total_size / 1e9:.2f} GB") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py index c37b9adaf7..36d4fdfc1d 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py @@ -20,7 +20,7 @@ import warnings from typing import Any -from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available +from transformers.utils import is_flash_attn_2_available # , is_torch_sdpa_available from .block_config import BlockConfig from .transformers_4_44_2__configuration_llama import LlamaConfig @@ -119,7 +119,8 @@ def _delete_per_layer_attributes(self): def _choose_llama4_attn_implementation(self, llama4_attn_implementation): self.llama4_attn_implementation = llama4_attn_implementation if self.llama4_attn_implementation is None: - if is_torch_sdpa_available(): + # if is_torch_sdpa_available(): + if True: _print_once("auto-setting llama4_attn_implementation to sdpa") self.llama4_attn_implementation = "sdpa" else: diff --git a/modelopt/torch/puzzletron/export/mbridge/__init__.py b/modelopt/torch/puzzletron/export/mbridge/__init__.py new file mode 100644 index 0000000000..471e68984b --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints. + +This module provides bridges for converting Puzzletron AnyModel checkpoints +(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge. +""" + +# Import to register bridges (side effect) +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin +from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401 + PuzzletronLlamaAnyModelBridge, +) +from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401 + PuzzletronQwen3AnyModelBridge, +) + +__all__ = [ + "HeterogeneousBridgeMixin", + "PuzzletronLlamaAnyModelBridge", + "PuzzletronQwen3AnyModelBridge", +] diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py new file mode 100644 index 0000000000..13ea6612af --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mixin class for bridges that support heterogeneous layer architectures. + +This module provides a mixin class for converting models with block_configs +(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. +""" + +import dataclasses +import json +from collections.abc import Callable +from dataclasses import dataclass, fields + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.spec_utils import ModuleSpec + + +def heterogeneous_layer_spec(config) -> ModuleSpec: + """Get GPT heterogeneous layer spec using Transformer Engine.""" + return get_gpt_heterogeneous_layer_spec(config, use_te=True) + + +@dataclass +class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig): + """Generic provider for AnyModel checkpoints with block_configs.""" + + # Heterogeneous configuration fields + heterogeneous_layers_config_path: str | None = None + heterogeneous_layers_config_encoded_json: str = "" + transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec + + def __getattr__(self, name: str): + """Handle missing attributes for OmegaConf compatibility. + + Returns empty list for per_block_parameters if not yet initialized (before finalize()). + This allows OmegaConf to serialize/deserialize configs without errors. Actual usage + should call finalize() first to set per_block_parameters as a real attribute. + """ + if name == "per_block_parameters": + # Return existing attribute if set, otherwise [] for OmegaConf compatibility + try: + return object.__getattribute__(self, name) + except AttributeError: + return [] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +class HeterogeneousBridgeMixin: + """Mixin for bridges supporting heterogeneous layer architectures (block_configs). + + Must be used with multiple inheritance alongside a model-specific bridge. + Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge) + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: + """Convert HF AnyModel config to Megatron GPTModelProvider. + + This method: + 1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all + model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.) + 2. Converts the provider to a dict and filters to only fields accepted by + GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all valid + GPTModelProvider fields are preserved) + 3. Adds heterogeneous configuration and returns GenericHeterogeneousProvider + + All parameters from the parent bridge (e.g., LlamaBridge) are maintained because + GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all + the fields that the parent bridge sets. + """ + + parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] + + provider_kwargs = dataclasses.asdict(parent_provider) + + # Filter to only fields that GenericHeterogeneousProvider accepts. + # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all + # GPTModelProvider fields. Model-specific fields from subclasses (e.g., MistralModelProvider, + # GPTOSSModelProvider) are filtered out because GenericHeterogeneousProvider only inherits + # from GPTModelProvider, not from model-specific subclasses. + # + # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if they + # use model-specific parameters not supported by GenericHeterogeneousProvider (e.g., + # scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases, create a + # model-specific heterogeneous provider that inherits from the model-specific provider. + valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)} + + # Only keep kwargs that are valid fields + provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields} + + provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( + self._build_heterogeneous_config_json(hf_pretrained.config) + ) + return GenericHeterogeneousProvider(**provider_kwargs) + + def _build_heterogeneous_config_json(self, hf_config) -> str: + """Build heterogeneous layers config JSON from HF config.""" + + hf_config_dict = json.loads(hf_config.to_json_string()) + + mcore_block_configs = [ + self._convert_block_config(block) for block in hf_config_dict["block_configs"] + ] + return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False) + + def _convert_block_config(self, block: dict) -> dict: + """Convert a single block config from HF format to MCore format.""" + return { + "attention": self._convert_attention_config(block["attention"]), + "ffn": self._convert_ffn_config(block["ffn"]), + } + + def _convert_attention_config(self, attention_config: dict) -> dict: + """Convert attention config from HF format to MCore format.""" + attention_config = attention_config.copy() + attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads") + return attention_config + + def _convert_ffn_config(self, ffn_config: dict) -> dict: + """Convert FFN/MLP config from HF format to MCore format.""" + ffn_config = ffn_config.copy() + ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size") + return ffn_config diff --git a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py new file mode 100644 index 0000000000..fa49dc29c5 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Upstream this fix to Megatron-Bridge and remove this local copy. + +import logging +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, Optional + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider +from megatron.bridge.models.transformer_config import TransformerConfig +from megatron.core.models.gpt import GPTModel as MCoreGPTModel + +import modelopt.torch.distill as mtd +import modelopt.torch.distill.plugins.megatron as mtd_mcore + +if TYPE_CHECKING: + from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class DistillationProvider(TransformerConfig): + """Provider for Megatron Core GPT models in distillation mode. + + Please use `convert_to_distillation_provider()` to create an instance of this class. + """ + + teacher: Optional[GPTModelProvider | MambaModelProvider] = None + kd_config: Optional["ModelOptDistillConfig"] = None + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "Use `convert_to_distillation_provider()` to create an instance of this class." + ) + + def __post_init__(self): + assert getattr(self, "teacher", None) is not None, "Teacher model must be provided." + + shared_attrs = [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "seq_length", + "pipeline_dtype", + ] + for attr in shared_attrs: + if getattr(self, attr) != getattr(self.teacher, attr): + raise ValueError(f"Student and teacher providers must have the same {attr}.") + + # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version. + self.cross_entropy_fusion_impl = "native" + + # Hack to dynamically subclass other providers and still use their methods + self._super_class = self.__class__.__bases__[0] + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Configure and instantiate a ModelOpt DistillationModel based on this configuration. + + Args: + pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage + post_process: Whether to include post-processing in the model, defaults to last pipeline stage + vp_stage: Virtual pipeline stage + + Returns: + MCoreGPTModel: Configured ModelOpt DistillationModel instance + """ + if vp_stage is not None: + raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.") + + assert self.teacher is not None, "Teacher model must be provided." + student_model = self._super_class.provide(self, pre_process, post_process, vp_stage) # type: ignore[attr-defined] + + # Finalize teacher provider before creating model (required for heterogeneous models). + # + # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in + # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during + # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent + # with current parallelism settings and distributed context. Student model creation (above) + # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter + # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341) + # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to + # determine layer architecture. Without finalize() in this context, architecture expectations + # don't match checkpoint weights, causing: + # ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0 + # (expected (2880, 4096), got (3584, 4096)) + # + # Note: This explanation needs to be confirmed yet. + self.teacher.finalize() + + # Hack to get teacher's pre-wrap hooks called to potentially load HF weights + teacher_model = self.teacher.provide_distributed_model( + wrap_with_ddp=False, mixed_precision_wrapper=None + )[0] + + kd_cfg = mtd_mcore.setup_distillation_config( + self.kd_config, student_model.config, teacher_model.config + ) + modelopt_cfg = { + "teacher_model": teacher_model, + "criterion": kd_cfg.criterion, + "loss_balancer": kd_cfg.loss_balancer, + } + kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)]) + mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg) + + return kd_model + + def to_cfg_dict(self) -> dict[str, Any]: + """Custom method to save equivalent to the original provider class. + + Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML. + There is no need to restore a `DistillationProvider` from the run config file, as + it can always be re-converted using the original student provider. + + Returns: + Dictionary representation of this provider class + """ + from megatron.bridge.training.utils.config_utils import _ConfigContainerBase + + result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"} + + # Include all fields from the original provider class (self._super_class), not just DistillationProvider + # This ensures fields like heterogeneous_layers_config_encoded_json are preserved + excluded_fields = {"teacher", "kd_config"} + for field in fields(self._super_class): + if field.name.startswith("_") or field.name in excluded_fields: + continue + # Only include if the field exists on this instance (it should, since we converted from the original provider) + if hasattr(self, field.name): + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + + # Also include any additional fields from DistillationProvider itself (if any) + for field in fields(self): + if field.name.startswith("_") or field.name in excluded_fields: + continue + # Skip if already included from _super_class + if field.name not in result: + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + + return result + + def __setattr__(self, name, value): + super().__setattr__(name, value) + # Mirror to teacher if it has that attribute + if hasattr(self.teacher, name): + setattr(self.teacher, name, value) + + +def convert_to_distillation_provider( + student_provider: GPTModelProvider | MambaModelProvider, + teacher_provider: GPTModelProvider | MambaModelProvider, + kd_config: Optional["ModelOptDistillConfig"] = None, +) -> "DistillationProvider": + """Convert a given model provider to a DistillationProvider.""" + + assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), ( + "Student provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), ( + "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + + DistillationProvider.__bases__ = (type(student_provider),) + student_provider.__class__ = DistillationProvider + + student_provider.teacher = teacher_provider + student_provider.kd_config = kd_config + student_provider.__post_init__() + + return student_provider diff --git a/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py new file mode 100644 index 0000000000..2838b9ff20 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/export_mbridge_to_hf.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Export utilities for Megatron-Bridge checkpoints.""" + +import shutil +from pathlib import Path + +from megatron.bridge import AutoBridge + +from modelopt.torch.utils import print_rank_0 + + +def export_to_hf_and_copy_config( + student_hf_path: str, + checkpoint_dir: str, + train_iters: int, + hf_export_path: str, + hf_model: str, +) -> None: + """ + Export Megatron checkpoint to HuggingFace format and copy config.json from student model. + + TODO: This script should not be needed (manually copying config.json from + student model to exported model). Remove it once export_to_hf() in AutoBridge + supports copying/preserving config.json from student model. + + + Args: + student_hf_path: Path to the original student HuggingFace model (source of config.json) + checkpoint_dir: Base directory where Megatron checkpoints are stored + train_iters: Number of training iterations (used to construct final checkpoint path) + hf_export_path: Directory path where the HuggingFace model will be saved + hf_model: HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct) + """ + print_rank_0(f"\n{'=' * 80}") + print_rank_0("Exporting to HuggingFace format...") + print_rank_0(f"{'=' * 80}\n") + + # Construct path to final checkpoint iteration (format: iter_0000100 for 100 iterations) + final_iter_dir = Path(checkpoint_dir) / f"iter_{train_iters:07d}" + print_rank_0(f"📂 Using final checkpoint: {final_iter_dir}") + + # Use the final iteration directory for export (export_ckpt will validate it exists) + megatron_path = str(final_iter_dir) + + # Create bridge using standard model ID (not AnyModel checkpoint) to avoid sharding structure issues + print_rank_0("🌉 Creating bridge...") + print_rank_0(f" Using model ID: {hf_model}") + bridge = AutoBridge.from_hf_pretrained(hf_model, trust_remote_code=True) + + print_rank_0("📤 Exporting to HuggingFace format...") + # Use strict=False for test_distill_hf.py which uses small models (2 layers) with fewer layers + # than the template model (32 layers). This allows partial exports when some tensors are missing. + # Note: This is NOT needed when running on real compressed puzzletron student models, + # which have the same number of layers as the template model (some may be skipped via no_op + # in block_configs, but all layer tensors are still present in the checkpoint). + bridge.export_ckpt( + megatron_path=megatron_path, + hf_path=hf_export_path, + show_progress=True, + strict=False, # Needed for test_distill_hf.py small models; not needed for real compressed models + ) + + print_rank_0(f"✅ Successfully exported model to: {hf_export_path}") + + # Copy config.json from student model to exported model (preserves block_configs) + student_config_path = Path(student_hf_path) / "config.json" + exported_config_path = Path(hf_export_path) / "config.json" + + print_rank_0(f"📋 Copying config.json from student model: {student_config_path}") + shutil.copy(student_config_path, exported_config_path) + print_rank_0(f"✅ Copied config.json to: {exported_config_path}") + + print_rank_0(f"\n{'=' * 80}") + print_rank_0("Export complete!") diff --git a/modelopt/torch/puzzletron/export/mbridge/llama.py b/modelopt/torch/puzzletron/export/mbridge/llama.py new file mode 100644 index 0000000000..b802215298 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/llama.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Llama-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.llama.llama_bridge import LlamaBridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import LlamaForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) +class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge): + """ + Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints. + + Extends LlamaBridge with support for heterogeneous layer architectures (block_configs). + All Llama-specific settings are inherited from LlamaBridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses LlamaBridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from LlamaBridge diff --git a/modelopt/torch/puzzletron/export/mbridge/qwen3.py b/modelopt/torch/puzzletron/export/mbridge/qwen3.py new file mode 100644 index 0000000000..ace20fbf89 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/qwen3.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Qwen3-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import Qwen3ForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel) +class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge): + """ + Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints. + + Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs). + All Qwen3-specific settings are inherited from Qwen3Bridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses Qwen3Bridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from Qwen3Bridge diff --git a/modelopt/torch/quantization/calib/max.py b/modelopt/torch/quantization/calib/max.py index 94cee406e0..4373fa69d5 100644 --- a/modelopt/torch/quantization/calib/max.py +++ b/modelopt/torch/quantization/calib/max.py @@ -66,15 +66,15 @@ def collect(self, x): if x.device.type == "meta": self._calib_amax = local_amax return + assert not torch.any(torch.isnan(local_amax)), ( + f"detected nan values in amax. nan in original tensor: {torch.any(torch.isnan(x))}" + ) assert torch.all(local_amax >= 0), ( "detected negative values after abs, could be torch or cuda bug" ) assert not torch.any(torch.isinf(local_amax)), ( f"detected inf values in amax. inf in original tensor: {torch.any(torch.isinf(x))}" ) - assert not torch.any(torch.isnan(local_amax)), ( - f"detected nan values in amax. nan in original tensor: {torch.any(torch.isnan(x))}" - ) if self._calib_amax is None: self._calib_amax = local_amax else: diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 9710d3a4b3..1f439a7e77 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -15,6 +15,7 @@ """Calibrator that returns the MSE amax of all collected tensors.""" +import math from collections.abc import Callable import torch @@ -23,7 +24,7 @@ from .. import utils as quant_utils from .calibrator import _Calibrator -__all__ = ["MseCalibrator"] +__all__ = ["MseCalibrator", "NVFP4MSECalibrator"] class MseCalibrator(_Calibrator): @@ -33,7 +34,7 @@ def __init__( self, amax: torch.Tensor, axis: int | tuple | list | None = None, - num_steps: int = 10, + step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, @@ -44,25 +45,39 @@ def __init__( Args: amax: Initial amax value (required). axis: Quantization axis. None means per-tensor quantization. - num_steps: Number of amax candidates to try. + step_size: Step size for amax search. The number of steps is computed as + ceil((stop_multiplier - start_multiplier) / step_size) + 1. start_multiplier: Starting multiplier for amax search. stop_multiplier: Ending multiplier for amax search. quant_func: Function that quantizes input tensor given an amax value. - Should have signature: quant_func(x, amax) -> quantized_x. + Should have signature: quant_func(x, amax) -> quantized_x. error_func: Function to compute error between x and xq. - Default is F.mse_loss(x, xq, reduction='none'). + Default is F.mse_loss(x, xq, reduction='none'). """ super().__init__(num_bits=None, axis=axis, unsigned=None) self._initial_amax = amax - self._num_steps = num_steps + self._step_size = step_size self._start_multiplier = start_multiplier self._stop_multiplier = stop_multiplier + self._num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1 + self._quant_func = quant_func self._error_func = error_func - self._losses_sum = [None] * num_steps - self._candidate_amaxs = [None] * num_steps + self._losses_sum: list[torch.Tensor | None] | None = None + self._candidates: torch.Tensor | None = None + self._amax: torch.Tensor | None = None - self._amax = None + def _generate_candidates(self, device: torch.device) -> torch.Tensor: + """Generate candidate multipliers. Override in subclasses for different candidate sets.""" + return torch.linspace( + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device + ) + + def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: + """Compute amax from candidates. Override in subclasses for different amax computation.""" + if candidates.ndim != 0: # Called during final compute amax + candidates = candidates.view_as(self._initial_amax) + return self._initial_amax * candidates @torch.no_grad() def collect(self, x: torch.Tensor): @@ -72,22 +87,22 @@ def collect(self, x: torch.Tensor): x: Input tensor. """ if self._quant_func is None: - raise RuntimeError( - "Quantization function not set. Msecalibrator requires a quant_func to be provided." - ) + raise RuntimeError("Quantization function not set.") x = x.detach().to(dtype=torch.float32) - device = x.device - multipliers = torch.linspace( - self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device - ) - # Get reduce axis for per-channel quantization + candidates = self._generate_candidates(device) + if self._candidates is None: + self._candidates = candidates + self._num_steps = len(candidates) + self._losses_sum = [None] * self._num_steps + + assert self._losses_sum is not None reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) - for step, multiplier in enumerate(multipliers): - candidate_amax = self._initial_amax * multiplier + for step, candidate in enumerate(candidates): + candidate_amax = self._compute_candidate_amax(candidate) xq = self._quant_func(x, candidate_amax) if self._error_func is not None: @@ -97,9 +112,6 @@ def collect(self, x: torch.Tensor): loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False) - if self._candidate_amaxs[step] is None: - self._candidate_amaxs[step] = candidate_amax - if self._losses_sum[step] is None: self._losses_sum[step] = loss.clone() else: @@ -107,9 +119,12 @@ def collect(self, x: torch.Tensor): def reset(self): """Reset the stored losses and amax value.""" - self._losses_sum = [None] * self._num_steps - self._candidate_amaxs = [None] * self._num_steps + self._losses_sum = None + self._candidates = None self._amax = None + if self._initial_amax is not None: + del self._initial_amax + self._initial_amax = None @torch.no_grad() def compute_amax(self, verbose: bool = False): @@ -118,49 +133,28 @@ def compute_amax(self, verbose: bool = False): Args: verbose: If True, print the ratio of best_amax to initial_amax. """ - if not any(loss_sum is not None for loss_sum in self._losses_sum): + if self._losses_sum is None or not any(loss is not None for loss in self._losses_sum): return None - # Check if this is per-tensor or per-channel based on the first loss - first_loss_sum = None - for loss_sum in self._losses_sum: - if loss_sum is not None: - first_loss_sum = loss_sum - break - - if first_loss_sum is None: + first_loss = next((loss for loss in self._losses_sum if loss is not None), None) + if first_loss is None: return None - # Collect losses for all steps - losses_per_step = [] + # Stack losses: [num_steps] or [num_steps, num_channels] + losses = [] for step in range(self._num_steps): if self._losses_sum[step] is not None: - losses_per_step.append(self._losses_sum[step]) - # No data for this step, use inf - elif first_loss_sum.ndim == 0: - losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device)) + losses.append(self._losses_sum[step]) + elif first_loss.ndim == 0: + losses.append(torch.tensor(float("inf"), device=first_loss.device)) else: - losses_per_step.append(torch.full_like(first_loss_sum, float("inf"))) + losses.append(torch.full_like(first_loss, float("inf"))) - # Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel - losses_per_step = torch.stack(losses_per_step) - - # Find best step(s): scalar for per-tensor, [num_channels] for per-channel - best_steps = torch.argmin(losses_per_step, dim=0) - - # Stack candidate amaxs and select based on best_steps - candidate_amaxs = torch.stack(self._candidate_amaxs) - - if first_loss_sum.ndim == 0: - # Per-tensor case: best_steps is a scalar - self._amax = self._candidate_amaxs[best_steps.item()] - else: - # Per-channel case: best_steps is a tensor - num_channels = best_steps.shape[0] - self._amax = candidate_amaxs[ - best_steps, torch.arange(num_channels, device=best_steps.device) - ] - self._amax = self._amax.reshape(self._initial_amax.shape) + losses = torch.stack(losses) + best_indices = torch.argmin(losses, dim=0) + assert self._candidates is not None + best_candidates = self._candidates[best_indices] + self._amax = self._compute_candidate_amax(best_candidates) if verbose: ratio = self._amax / self._initial_amax @@ -175,3 +169,32 @@ def compute_amax(self, verbose: bool = False): ) return self._amax + + +class NVFP4MSECalibrator(MseCalibrator): + """Per-block FP8 scale sweep calibrator for NVFP4 static quantization.""" + + def __init__( + self, + amax: torch.Tensor, # per_block_amax shape [num_blocks] + global_amax: torch.Tensor, # scalar + axis: int | tuple | list | None = None, + quant_func: Callable | None = None, + error_func: Callable | None = None, + ): + """Initialize NVFP4 MSE calibrator with per-block and global amax.""" + super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func) + self._global_amax = global_amax + + def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: + if candidates.ndim != 0: # Called during final compute amax + candidates = candidates.view_as(self._initial_amax) + return torch.ones_like(self._initial_amax) * self._global_amax * candidates + + def _generate_candidates(self, device: torch.device) -> torch.Tensor: + """Generate 126 valid FP8 E4M3 scale candidates.""" + uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) + fp8_values = uint8_values.view(torch.float8_e4m3fn).float() + valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) + fp8_values = fp8_values[valid_mask] + return fp8_values / 448.0 diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index ea8fb6217c..618c22606d 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -156,12 +156,21 @@ "*mlp.gate.*": {"enable": False}, # Skip the MOE router "*mlp.shared_expert_gate.*": {"enable": False}, # Skip the MOE router "*linear_attn.conv1d*": {"enable": False}, - "*mixer.conv1d*": {"enable": False}, + "*mixer.conv1d*": {"enable": False}, # Skip mamba conv1d "*output_layer*": {"enable": False}, "output.*": {"enable": False}, "default": {"enable": False}, } +_mamba_moe_disabled_quantizer_cfg = { + "*fc1_latent_proj*": {"enable": False}, # Skip Latent MOE + "*fc2_latent_proj*": {"enable": False}, # Skip Latent MOE + "*q_proj*": {"enable": False}, # Skip QKV Linear + "*k_proj*": {"enable": False}, # Skip QKV Linear + "*v_proj*": {"enable": False}, # Skip QKV Linear + "*o_proj*": {"enable": False}, # Skip QKV Output Projection +} + INT8_DEFAULT_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, @@ -198,6 +207,28 @@ "algorithm": "max", } +MAMBA_MOE_FP8_AGGRESSIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, + "*input_quantizer": {"num_bits": (4, 3), "axis": None}, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + }, + "algorithm": "max", +} + +MAMBA_MOE_FP8_CONSERVATIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": (4, 3), "axis": None}, + "*input_quantizer": {"num_bits": (4, 3), "axis": None}, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": "max", +} + FP8_PER_CHANNEL_PER_TOKEN_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": (4, 3), "axis": 0}, @@ -234,6 +265,7 @@ "algorithm": "max", } + INT4_AWQ_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -387,6 +419,69 @@ "algorithm": "max", } +NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "local_hessian", + "fp8_scale_sweep": True, + }, +} + +MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + }, + "algorithm": "max", +} +MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + **_mamba_moe_disabled_quantizer_cfg, + "*mixer.in_proj*": {"enable": False}, # Skip mamba linear + "*mixer.out_proj*": {"enable": False}, # Skip mamba linear + }, + "algorithm": "max", +} + NVFP4_AWQ_LITE_CFG = { "quant_cfg": { @@ -652,6 +747,10 @@ "NVFP4_MLP_WEIGHT_ONLY_CFG", "MXFP4_MLP_WEIGHT_ONLY_CFG", "NVFP4_MLP_ONLY_CFG", + "MAMBA_MOE_NVFP4_CONSERVATIVE_CFG", + "MAMBA_MOE_NVFP4_AGGRESSIVE_CFG", + "MAMBA_MOE_FP8_CONSERVATIVE_CFG", + "MAMBA_MOE_FP8_AGGRESSIVE_CFG", } BiasType = Literal["static", "dynamic"] @@ -742,7 +841,7 @@ def validate_num_bits(self): raise ValueError( "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." ) - elif num_bits != (4, 3) and ( + elif num_bits not in [(4, 3), (2, 1)] and ( block_sizes is None or block_sizes.get("type", None) != "dynamic" ): raise ValueError( @@ -934,14 +1033,20 @@ def validate_calibrator(cls, v, info: ValidationInfo): assert v in ["max", "histogram"] return v - rotate: bool = ModeloptField( + rotate: bool | dict[str, bool] = ModeloptField( default=False, - title="""If rotate the input before quantization.""", - description=""""If true, the input of the quantizer will be rotated with a hadamard matrix + title="""Configuration for rotating the input before quantization.""", + description="""Can be a boolean or a dictionary with the following keys: + - "enable": Boolean to enable/disable rotation (default: False) + - "rotate_fp32": Boolean to compute rotation in float32 precision (default: False) + + If a boolean is provided, it is treated as the "enable" value with "rotate_fp32" defaulting to False. + + When enabled, the input of the quantizer will be rotated with a hadamard matrix given by scipy.linalg.hadamard, i.e. ``input = input @ scipy.linalg.hadamard(input.shape[-1]) / sqrt(input.shape[-1])``. - This can be used for ratation based PTQ methods, e.g. QuaRot or SpinQuant. + This can be used for rotation based PTQ methods, e.g. QuaRot or SpinQuant. See https://arxiv.org/abs/2404.00456 for example.""", ) @@ -1017,15 +1122,70 @@ class MseCalibConfig(QuantizeAlgorithmConfig): reconstruction error of a tensor after uniform Q→DQ: s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations} + + When fp8_scale_sweep is enabled, step_size is ignored. """ method: Literal["mse"] = ModeloptField("mse") - num_steps: int | None = ModeloptField( - default=10, - ge=1, - title="Number of amax candidates to try.", - description="Number of amax candidates to search over for MSE minimization.", + step_size: float | None = ModeloptField( + default=0.1, + gt=0.0, + title="Step size for amax search.", + description="Step size between amax candidates. The number of candidates is computed as " + "ceil((stop_multiplier - start_multiplier) / step_size) + 1.", + ) + + start_multiplier: float | None = ModeloptField( + default=0.25, + gt=0.0, + title="Starting multiplier for amax search.", + description="Starting multiplier for amax search range (multiplies initial amax).", + ) + + stop_multiplier: float | None = ModeloptField( + default=4.0, + gt=0.0, + title="Ending multiplier for amax search.", + description="Ending multiplier for amax search range (multiplies initial amax).", + ) + + fp8_scale_sweep: bool | None = ModeloptField( + default=False, + title="Enable FP8 scale sweep for NVFP4 per-block quantization.", + description="If True, sweep all 128 FP8 E4M3 scale values instead of using multipliers. " + "Only applies to NVFP4 weight quantization. When enabled, num_steps, step_size, " + "start_multiplier, and stop_multiplier are ignored.", + ) + + distributed_sync: bool | None = ModeloptField( + default=True, + title="Whether to sync the amax across the distributed processes.", + description="If True, the amax will be synced across the distributed processes.", + ) + + +class LocalHessianCalibConfig(QuantizeAlgorithmConfig): + """Configuration for local Hessian-weighted MSE calibration. + + This algorithm uses activation information to optimize per-block scales for weight + quantization. It minimizes the output reconstruction error by weighting the loss + with the local Hessian matrix computed from input activations. + + The local Hessian loss for each block is: ``(dw @ H @ dw.T)`` where: + - ``dw = weight - quantized_weight`` (weight reconstruction error per block) + - ``H = X @ X.T`` is the local Hessian computed from input activations X + + """ + + method: Literal["local_hessian"] = ModeloptField("local_hessian") + + step_size: float | None = ModeloptField( + default=0.1, + gt=0.0, + title="Step size for amax search.", + description="Step size between amax candidates. The number of candidates is computed as " + "ceil((stop_multiplier - start_multiplier) / step_size) + 1.", ) start_multiplier: float | None = ModeloptField( @@ -1042,12 +1202,35 @@ class MseCalibConfig(QuantizeAlgorithmConfig): description="Ending multiplier for amax search range (multiplies initial amax).", ) + fp8_scale_sweep: bool | None = ModeloptField( + default=True, + title="Enable FP8 scale sweep for NVFP4 per-block quantization.", + description="If True, sweep over all 128 possible FP8 E4M3 scale values " + "for NVFP4 per-block quantization instead of using multipliers. " + "This is the recommended setting for NVFP4 quantization.", + ) + + block_size: int | None = ModeloptField( + default=16, + gt=0, + title="Block size for local Hessian computation.", + description="The block size used for computing the local Hessian matrix. " + "This should match the block size used in the quantization config. " + "Default is 16 for NVFP4.", + ) + distributed_sync: bool | None = ModeloptField( default=True, title="Whether to sync the amax across the distributed processes.", description="If True, the amax will be synced across the distributed processes.", ) + debug: bool | None = ModeloptField( + default=False, + title="Debug mode.", + description="If True, module's local Hessian metadata will be kept as a module attribute.", + ) + class SmoothQuantCalibConfig(QuantizeAlgorithmConfig): """The config for ``smoothquant`` algorithm (SmoothQuant). @@ -1179,6 +1362,44 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): ) +class GPTQLiteConfig(QuantizeAlgorithmConfig): + """The config for GPTQ lite. + + GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + + GPTQ lite does not perform sequential quantization of layers. This means that the updated + activations are not used to process the next layer. + + The default values are taken from the official GPTQ implementation: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L35 + + Note: This feature is currently experimental and may not translate to improved accuracy as expected. + + + """ + + method: Literal["gptq_lite"] = ModeloptField("gptq_lite") + percdamp: float | None = ModeloptField( + default=0.01, + gt=0.0, + le=1.0, + title="Percentage damping factor.", + description="The percentage of average Hessian diagonal used for damping.", + ) + block_size: int | None = ModeloptField( + default=128, + title="Block size for GPTQ weight update.", + description="""The block size for GPTQ weight update, which must be a multiple of the + group_size used in the quantization.""", + ) + hessian_state_path: str | None = ModeloptField( + default=None, + title="Path to the Hessian state file.", + description="""The path to the Hessian state file. If hessian path exists, we load from + hessian file instead of recomputing them.""", + ) + + QuantizeQuantCfgType = dict[ str | Callable, QuantizerAttributeConfig diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index c93ea546f3..f7ef704eec 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -35,6 +35,7 @@ _QuantizeExportConfig, ) from .nn import ( + NVFP4StaticQuantizer, QuantModule, QuantModuleRegistry, SequentialQuantizer, @@ -125,6 +126,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: for name, module in model.named_modules(): if isinstance(module, TensorQuantizer): name = get_unwrapped_name(name, model) + state = quantizer_state_dict[name] + # TODO: Add a registry for TensorQuantizers and avoid this manual conversion. + if state.get("_is_nvfp4_static_quantizer") and not isinstance( + module, NVFP4StaticQuantizer + ): + NVFP4StaticQuantizer.from_tensor_quantizer(module) module.set_from_modelopt_state(quantizer_state_dict[name]) for name, module in model.named_modules(): diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 53651bbcce..1f3346ea98 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,6 +37,8 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, + GPTQLiteConfig, + LocalHessianCalibConfig, MaxCalibConfig, MseCalibConfig, QuantizeAlgoCfgType, @@ -55,7 +57,15 @@ restore_svdquant_model, update_quantize_metadata, ) -from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant +from .model_calib import ( + awq, + gptq_lite, + local_hessian_calibrate, + max_calibrate, + mse_calibrate, + smoothquant, + svdquant, +) __all__ = ["BaseCalibrateModeDescriptor"] @@ -376,6 +386,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: _calib_func = mse_calibrate +@CalibrateModeRegistry.register_mode +class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for local Hessian-weighted MSE calibration algorithm. + + This algorithm uses activation information to optimize per-block scales for weight + quantization by minimizing output reconstruction error instead of weight reconstruction error. + """ + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return LocalHessianCalibConfig + + _calib_func = local_hessian_calibrate + + @CalibrateModeRegistry.register_mode class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor): """Mode for smoothquant calibration algorithm.""" @@ -439,3 +465,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: def restore(self) -> RestoreEntrypoint: """The mode's entrypoint for restoring a model.""" return restore_svdquant_model + + +@CalibrateModeRegistry.register_mode +class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for GPTQ calibration algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return GPTQLiteConfig + + _calib_func = gptq_lite diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3184f2a78d..350af429be 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -16,22 +16,26 @@ """Calibration utilities.""" import math +import os import warnings +from collections.abc import Callable from functools import partial import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method +from modelopt.torch.utils.perf import get_used_gpu_mem_fraction -from .calib import MseCalibrator +from .calib import MseCalibrator, NVFP4MSECalibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context -from .nn import QuantModule, SequentialQuantizer, TensorQuantizer +from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import ( disable_calib, enable_fake_quant, @@ -41,10 +45,11 @@ is_quantized_linear, is_quantized_row_parallel_linear, quantizer_attr_names, + reduce_amax, weight_attr_names, ) -__all__ = ["awq", "max_calibrate", "smoothquant", "svdquant"] +__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"] def weight_only_quantize(model: nn.Module): @@ -62,14 +67,47 @@ def weight_only_quantize(model: nn.Module): seen_modules.add(module) +def _has_expert_parallelism(module: nn.Module) -> bool: + """Check if module has expert parallelism enabled.""" + ps = getattr(module, "parallel_state", None) + return ps is not None and ps.expert_model_parallel_group.is_initialized() + + +def _check_moe_calibration_complete(quantizer, parallel_state): + """Raise error if MoE calibration is incomplete (some ranks have amax, others don't).""" + if isinstance(quantizer, SequentialQuantizer): + for _q in quantizer: + _check_moe_calibration_complete(_q, parallel_state) + return + for group in [ + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ]: + if not group.is_initialized(): + continue + has_amax = getattr(quantizer, "_amax", None) is not None + amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs) + if any(amax_states) and not all(amax_states): + raise RuntimeError( + "MoE calibration incomplete: some experts received no tokens during calibration. " + "Increase --calib-size to ensure all experts see calibration data." + ) + + @torch.no_grad() -def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True): +def max_calibrate( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + distributed_sync=True, +): """Calibrate the model using max. Args: model: Model to be calibrated. forward_loop: A callable which takes the model as argument and forwards calibration data through the model. + distributed_sync: Whether to sync input_quantizer amax across distributed processes. See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -81,9 +119,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis forward_loop(model) finish_stats_collection(model) + # Sync input_quantizer amax across local experts within each rank (for SequentialMLP) + for name, module in model.named_modules(): + if hasattr(module, "layer_sync_moe_local_experts_amax"): + module.layer_sync_moe_local_experts_amax() + if not distributed_sync: return + # Check MoE calibration completeness before sync + for name, module in model.named_modules(): + if isinstance(module, QuantModule) and _has_expert_parallelism(module): + for child in module.children(): + if isinstance(child, (TensorQuantizer, SequentialQuantizer)): + _check_moe_calibration_complete(child, module.parallel_state) + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" if isinstance(quantizer, SequentialQuantizer): @@ -95,13 +145,13 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group - # Step 1:Sync amax across data parallelism + # Step 2:Sync amax across data parallelism for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): sync_quantizer_amax_across_dp_ep(child, module.parallel_state) - # TP sync: + # Step 3: TP sync # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same # ColumnParallel: X @ [A_1, A_2] (weights split along Cout) @@ -156,7 +206,6 @@ def sync_quantizer_amax_across_tp( axes_for_sync=[None, -1], parallel_state=module.parallel_state, ) - sync_quantizer_amax_across_tp( module.weight_quantizer, name, @@ -182,10 +231,6 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) - # MOE Quantization - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() - # KV Cache Quantization if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) @@ -197,14 +242,39 @@ def sync_quantizer_amax_across_tp( ) +def _mse_quant_func(x, amax, quantizer): + """Quantization function for MSE calibration.""" + original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None + quantizer._amax = amax + + with ( + enable_quant(quantizer), + disable_calib(quantizer), + enable_fake_quant(quantizer), + ): + if hasattr(quantizer, "_original_shape"): + x = quantizer._reset_to_original_shape(x) + xq = quantizer(x) + if hasattr(quantizer, "_block_reshape_size"): + xq = xq.reshape(quantizer._block_reshape_size) + + if original_amax is not None: + quantizer._amax = original_amax + else: + delattr(quantizer, "_amax") + + return xq + + @torch.no_grad() def mse_calibrate( model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True, - num_steps: int = 10, + step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, + fp8_scale_sweep: bool = False, ): """Calibrate the model using MSE-based amax search. @@ -217,9 +287,13 @@ def mse_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync amax across distributed processes. - num_steps: Number of amax candidates to try (default: 10). + step_size: Step size for amax search (default: 0.1). start_multiplier: Starting multiplier for amax search (default: 0.25). stop_multiplier: Ending multiplier for amax search (default: 4.0). + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + for NVFP4 per-block quantization instead of using multipliers. + This is specifically designed for optimizing the FP8-quantized + per-block scales in NVFP4 format (default: False). See :class:`MseCalibConfig ` for details on the remaining arguments. @@ -228,57 +302,392 @@ def mse_calibrate( max_calibrate(model, forward_loop, distributed_sync) # Step 2: Replace calibrators with MseCalibrator for enabled quantizers - for name, module in model.named_modules(): + # and identify weight quantizers + weight_quantizers = [] + seen_modules = set() + + for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: - # Static block quantization is not supported by MseCalibrator - if module.is_static_block_quant: - raise ValueError( - f"MSE calibration does not support static block quantization. " - f"Found static block quantization at {name}." - ) if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() - def quant_func(x, amax, quantizer=module): - original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None - quantizer._amax = amax + is_nvfp4_static = ( + module.is_static_block_quant + and module._num_bits == (2, 1) + and module._block_sizes is not None + and module._block_sizes.get("scale_bits") == (4, 3) + ) - with ( - enable_quant(quantizer), - disable_calib(quantizer), - enable_fake_quant(quantizer), - ): - xq = quantizer(x) + if is_nvfp4_static: + # Compute and set global_amax + global_amax = reduce_amax(initial_amax, axis=None) - if original_amax is not None: - quantizer._amax = original_amax - else: - delattr(quantizer, "_amax") + # Convert to NVFP4StaticQuantizer in-place + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) - return xq + if fp8_scale_sweep and is_nvfp4_static: + # Replace calibrator with NVFP4MSECalibrator + module._calibrator = NVFP4MSECalibrator( + amax=initial_amax, + axis=module._calibrator._axis, + global_amax=module.global_amax, + quant_func=partial(_mse_quant_func, quantizer=module), + ) + continue + + if fp8_scale_sweep and not is_nvfp4_static: + warnings.warn( + f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static " + "block quantization. fp8_scale_sweep will be ignored for this quantizer." + ) # Create MSE calibrator with quant_func module._calibrator = MseCalibrator( amax=initial_amax, axis=module._calibrator._axis, - num_steps=num_steps, + step_size=step_size, start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, - quant_func=quant_func, + quant_func=partial(_mse_quant_func, quantizer=module), ) - # Step 3: Collect data with MSE calibrators - enable_stats_collection(model) + # Identify weight quantizers by checking if they have corresponding weight parameters + for name, parent_module in model.named_modules(): + if parent_module in seen_modules: + continue + for weight_name in weight_attr_names(parent_module): + weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer + weight_quantizer = getattr(parent_module, weight_quantizer_name, None) + if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: + if getattr(weight_quantizer, "_calibrator", None) is not None: + weight_quantizers.append((parent_module, weight_name, weight_quantizer)) + seen_modules.add(parent_module) + + # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation + # This prevents massive memory accumulation seen in large models + for idx, (parent_module, weight_name, weight_quantizer) in enumerate( + tqdm(weight_quantizers, desc="MSE weight calibration") + ): + # Enable calibration mode for the weight quantizer + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + with enable_weight_access_and_writeback(parent_module, model): + weight = getattr(parent_module, weight_name) + weight_quantizer(weight) + + # IMMEDIATELY compute amax and reset calibrator to free memory + cal = getattr(weight_quantizer, "_calibrator", None) + if cal is not None and cal.compute_amax() is not None: + weight_quantizer.load_calib_amax() + + weight_quantizer.enable_quant() + weight_quantizer.disable_calib() + + # Synchronize ALL CUDA devices before resetting to ensure all async operations complete + # This is critical for multi-GPU setups where tensors may be on different devices + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + + if cal is not None and hasattr(cal, "reset"): + cal.reset() + + if (idx + 1) % 10 == 0 and torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + + # TODO: Sync amax across distributed processes + + +@torch.no_grad() +def local_hessian_calibrate( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + distributed_sync: bool = True, + step_size: float = 0.1, + start_multiplier: float = 0.25, + stop_multiplier: float = 4.0, + fp8_scale_sweep: bool = True, + block_size: int = 16, + debug: bool = False, +): + """Calibrate the model using local Hessian-weighted MSE search. + + Instead of minimizing weight error ``||W - Wq||²``, this minimizes Hessian-weighted error + ``loss = (W - Wq)ᵀ H (W - Wq)`` where ``H = X @ X.T`` approximates output reconstruction + error ``||WX - WqX||²``. + + Per-block Hessians of shape ``(cin // block_size, block_size, block_size)`` are accumulated + during forward pass and used to weight the MSE loss during scale search. + + Args: + model: Model to be calibrated. + forward_loop: A callable which takes the model as argument and + forwards calibration data through the model. Required for this algorithm. + distributed_sync: Whether to sync amax across distributed processes. + step_size: Step size for amax search (default: 0.1). + start_multiplier: Starting multiplier for amax search (default: 0.25). + stop_multiplier: Ending multiplier for amax search (default: 4.0). + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + for NVFP4 per-block quantization (default: True). + block_size: Block size for local Hessian computation (default: 16). + debug: If True, keep the local Hessian metadata on modules. + + See :class:`LocalHessianCalibConfig ` + for details on the configuration options. + """ if forward_loop is None: - weight_only_quantize(model) - else: - forward_loop(model) + warnings.warn("forward_loop must be provided for local_hessian; skipping local_hessian") + return - # Step 4: Compute optimal amax and load it - finish_stats_collection(model, method="mse") + class LocalHessianHelper: + """Helper class to collect activations and compute local Hessian per module.""" - # TODO: Sync amax across distributed processes + cache_mode: bool = False + + def __init__(self, module, name): + self.name = name + self.module = module + self.weight_shape = module.weight.shape # (cout, cin) + self.cout, self.cin = self.weight_shape + self.block_size = block_size + self.num_blocks_per_cin = self.cin // block_size + self.is_enabled = True + + # Accumulated Hessian per block: (cin // block_size, block_size, block_size) + self.hessian_per_block = torch.zeros( + self.num_blocks_per_cin, + block_size, + block_size, + dtype=torch.float32, + device=module.weight.device, + ) + self.num_samples = 0 + + def setup(self): + """Set up the forward hook to collect activations.""" + module = self.module + bind_forward_method(module, forward, "_forward_no_local_hessian") + + # Check if cin is divisible by block_size + if self.cin % self.block_size != 0: + warnings.warn( + f"Module {self.name}: input features ({self.cin}) not divisible by " + f"block_size ({self.block_size}). Skipping local Hessian for this module." + ) + self.is_enabled = False + + def cleanup(self): + """Clean up the forward hook.""" + unpatch_forward_method(self.module, "_forward_no_local_hessian") + if not debug: + if hasattr(self.module, "hessian_helper"): + delattr(self.module, "hessian_helper") + + def accumulate_hessian(self, input_tensor: torch.Tensor): + """Accumulate local Hessian from input activations. + + Args: + input_tensor: Input tensor of shape (..., cin) + """ + if not self.is_enabled: + return + + # Flatten to (num_tokens, cin) + x = input_tensor.reshape(-1, self.cin).T # (cin, num_tokens) + x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) # (num_blocks, bs, n) + + # Compute H = X @ X.T for each block and accumulate + hessian_batch = (x @ x.transpose(-1, -2)).to(torch.float32) + self.hessian_per_block += hessian_batch + self.num_samples += input_tensor.numel() // self.cin + + def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + """Get the local Hessian error function for MSE calibration.""" + cout = self.cout + bs = self.block_size + # Normalize hessian by number of samples + hessian = self.hessian_per_block / max(self.num_samples, 1) + + def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + """Compute local Hessian-weighted error.""" + original_shape = x.shape + # Reshape to (cout, num_blocks_per_cin, block_size) + dw = (x - xq).view(cout, -1, bs) + # Use einsum to avoid materializing cout-repeated Hessian + # dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks) + block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw) + block_loss = block_loss.reshape(-1) + error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) + return error + + return local_hessian_error + + def forward(self, input, *args, **kwargs): + """Custom forward that collects activations in cache mode.""" + if LocalHessianHelper.cache_mode and self.hessian_helper.is_enabled: + # Get local tensor from DTensor if applicable + input_local = input.to_local() if hasattr(input, "to_local") else input + self.hessian_helper.accumulate_hessian(input_local) + + # Forward without quantization during caching + if LocalHessianHelper.cache_mode: + self.weight_quantizer.disable() + out = self._forward_no_local_hessian(input, *args, **kwargs) + self.weight_quantizer.enable() + return out + + return self._forward_no_local_hessian(input, *args, **kwargs) + + # First, run max_calibrate on the whole model to get initial amax for all quantizers + # This calibrates both weight_quantizer and input_quantizer with max calibration + print_rank_0("local_hessian: Running max calibration for all quantizers...") + max_calibrate(model, forward_loop, distributed_sync) + + # Setup helpers for all quantized linear modules + name_to_module = dict(model.named_modules()) + weight_quantizers_info = [] + all_patched_modules = [] # Track all modules for cleanup (including disabled ones) + + for name, module in name_to_module.items(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + with enable_weight_access_and_writeback(module, model, name_to_module): + module.hessian_helper = LocalHessianHelper(module, name) + module.hessian_helper.setup() + all_patched_modules.append((name, module)) + if module.hessian_helper.is_enabled: + weight_quantizers_info.append((name, module)) + + # Cache activations by running forward loop + LocalHessianHelper.cache_mode = True + print_rank_0("local_hessian: Caching activations and computing local Hessian...") + forward_loop(model) + + # TODO(fridah-nv): Sync Hessian across distributed processes if needed + + # Replace calibrators with MseCalibrator using local Hessian error function + print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...") + for name, module in weight_quantizers_info: + weight_quantizer = module.weight_quantizer + helper = module.hessian_helper + + if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None: + continue + + initial_amax = weight_quantizer._amax.clone().detach() + + def quant_func(x, amax, quantizer=weight_quantizer): + original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None + quantizer._amax = amax + + with ( + enable_quant(quantizer), + disable_calib(quantizer), + enable_fake_quant(quantizer), + ): + if hasattr(quantizer, "_original_shape"): + x = quantizer._reset_to_original_shape(x) + xq = quantizer(x) + if hasattr(quantizer, "_block_reshape_size"): + xq = xq.reshape(quantizer._block_reshape_size) + + if original_amax is not None: + quantizer._amax = original_amax + else: + delattr(quantizer, "_amax") + + return xq + + is_nvfp4_static = ( + weight_quantizer.is_static_block_quant + and weight_quantizer._num_bits == (2, 1) + and weight_quantizer._block_sizes is not None + and weight_quantizer._block_sizes.get("scale_bits") == (4, 3) + ) + + if is_nvfp4_static: + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) + + error_func = helper.get_error_func() + + if fp8_scale_sweep and is_nvfp4_static: + weight_quantizer._calibrator = NVFP4MSECalibrator( + amax=initial_amax, + axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, + global_amax=weight_quantizer.global_amax, + quant_func=quant_func, + error_func=error_func, + ) + else: + weight_quantizer._calibrator = MseCalibrator( + amax=initial_amax, + axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None, + step_size=step_size, + start_multiplier=start_multiplier, + stop_multiplier=stop_multiplier, + quant_func=quant_func, + error_func=error_func, + ) + + # Free cached memory before heavy calibration + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Process weights ONE AT A TIME with immediate amax computation and cleanup + weight_list = [ + (name, module) + for name, module in weight_quantizers_info + if module.weight_quantizer._calibrator is not None + ] + + for idx, (name, module) in enumerate(weight_list): + weight_quantizer = module.weight_quantizer + cal = weight_quantizer._calibrator + + # Step 1: Calibrate this weight + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + with enable_weight_access_and_writeback(module, model, name_to_module): + weight = module.weight + weight_quantizer(weight) + + # Step 2: IMMEDIATELY compute amax (before calibration data grows) + if cal.compute_amax() is not None: + weight_quantizer.load_calib_amax() + + weight_quantizer.enable_quant() + weight_quantizer.disable_calib() + + # Step 3: Sync all devices and reset calibrator for next weight + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + + if hasattr(cal, "reset"): + cal.reset() + + if (idx + 1) % 10 == 0 and torch.cuda.is_available(): + torch.cuda.empty_cache() + + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + + # Cleanup and free memory + LocalHessianHelper.cache_mode = False + for name, module in all_patched_modules: + module.hessian_helper.cleanup() + + print_rank_0("local_hessian: Calibration complete.") def enable_stats_collection(model: nn.Module): @@ -292,7 +701,7 @@ def enable_stats_collection(model: nn.Module): module.disable() -def finish_stats_collection(model: nn.Module, method: str | None = None): +def finish_stats_collection(model: nn.Module, method: str | None = None, **kwargs): """Finish stats collection for all quantizers in the model.""" for _, module in model.named_modules(): if not isinstance(module, TensorQuantizer) or module._disabled: @@ -300,15 +709,11 @@ def finish_stats_collection(model: nn.Module, method: str | None = None): cal = getattr(module, "_calibrator", None) if cal and not getattr(module, "_dynamic", False): - if method in {"mse", "entropy"}: + if method in {"entropy"}: if cal.compute_amax(method) is not None: - if method == "entropy": - module.load_calib_amax("entropy") - else: - module.load_calib_amax() - elif cal.compute_amax() is not None: - # Max calibrator - module.load_calib_amax() + module.load_calib_amax("entropy", **kwargs) + elif cal.compute_amax(**kwargs) is not None: + module.load_calib_amax(**kwargs) if module.bias_calibrator is not None and module.bias_type == "static": module.load_calib_bias() @@ -638,6 +1043,8 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None): def update_loss(self, out, out_actual, alpha): out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual out = out[0] if isinstance(out, tuple) else out + out = out.to_local() if hasattr(out, "to_local") else out + out_actual = out_actual.to_local() if hasattr(out_actual, "to_local") else out_actual loss = (out - out_actual).float().pow(2).mean() self.awq_lite.loss[alpha] += loss.to(self.awq_lite.loss[alpha].device) @@ -1032,6 +1439,30 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantiz return blocksize +def svd(weight, rank): + original_device = weight.device + original_dtype = weight.dtype + weight_f64 = weight.to(dtype=torch.float64, device=original_device) + u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False) + us = u[:, :rank] * s[:rank] + vt = vt[:rank] + us = us.to(device=original_device, dtype=original_dtype) + vt = vt.to(device=original_device, dtype=original_dtype) + if us.shape[1] < rank or vt.shape[0] < rank: + warnings.warn( + "The low-rank dimensions do not match the layer dimensions. " + "Please verify your configuration and model settings. " + f"Rank is {us.shape[1]} and {vt.shape[0]}" + ) + us_temp = torch.zeros((us.shape[0], rank), dtype=us.dtype, device=us.device) + vt_temp = torch.zeros((rank, vt.shape[1]), dtype=vt.dtype, device=vt.device) + us_temp[:, : us.shape[1]] = us + vt_temp[: vt.shape[0], :] = vt + us = us_temp + vt = vt_temp + return us, vt + + @torch.no_grad() def svdquant( model: nn.Module, @@ -1053,25 +1484,9 @@ def svdquant( def postprocess(module, name): print_rank_0(f"SVD {name}") weight = module.weight.data - original_device = weight.device - original_dtype = weight.dtype - weight_f64 = weight.to(dtype=torch.float64, device=original_device) - u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False) - if u.shape[1] < lowrank or vt.shape[0] < lowrank: - warnings.warn( - "The low-rank dimensions do not match the layer dimensions. " - "Please verify your configuration and model settings. " - f"SVD will be skipped for this layer {name}." - ) - return - us = u[:, :lowrank] * s[:lowrank] - vt = vt[:lowrank] - module.weight_quantizer.svdquant_lora_a = vt.to( - dtype=original_dtype, device=original_device - ) - module.weight_quantizer.svdquant_lora_b = us.to( - dtype=original_dtype, device=original_device - ) + us, vt = svd(weight, lowrank) + module.weight_quantizer.svdquant_lora_a = vt + module.weight_quantizer.svdquant_lora_b = us module.weight.data.sub_( module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a ) @@ -1086,3 +1501,321 @@ def postprocess(module, name): with enable_weight_access_and_writeback(module, model): postprocess(module, name) max_calibrate(model, forward_loop) + + +def _print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str): + """Print relative mean squared error between quantized and original weights. + + Computes the Hessian-weighted relative MSE between quantized and original weights, + providing a measure of quantization quality. This metric is adapted from the GPTQ + repository. + + Args: + q (torch.Tensor): Quantized weight tensor + w (torch.Tensor): Original weight tensor + h (torch.Tensor): Hessian matrix used for weighting the error + module_name (str): Name of the module for logging purposes + Note: + Implementation adapted from the GPTQ repository: + https://github.com/IST-DASLab/FP-Quant + """ + delta = q - w + mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6) + print(f"[{module_name}] Relative MSE error: {mse.item():.2e}") + + +def update_hessian(input, hessian, n_samples): + """Update hessian matrix with new input samples using incremental formula. + + Args: + input: Input tensor (batch_size, ..., features) + hessian: Current Hessian matrix to update in-place + n_samples: Number of samples already processed + Returns: + Tuple of (updated_hessian, new_sample_count) + """ + batch_size = input.shape[0] + + # Incremental averaging: scale down old hessian + hessian *= n_samples / (n_samples + batch_size) + n_samples += batch_size + + # Compute outer product: H += (2/n_samples) * X @ X^T + # where X is the flattened input reshaped to (features, batch*seq) + input_flat = input.reshape(-1, input.shape[-1]).t().float() + scaled_input = math.sqrt(2 / n_samples) * input_flat + hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device)) + + return hessian, n_samples + + +def prepare_hessian_inverse(h, weight, percdamp): + """Prepare inverse Hessian with dead neuron handling and damping. + + Args: + h: Hessian matrix to update + weight: Weight tensor to prepare Hessian for + percdamp: Damping percentage for Hessian diagonal + Returns: + h_inv: Inverse Hessian matrix + Implementation adapted from the FP-Quant repository: + https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200 + """ + h = h.clone() + # Handle dead neurons (zero weight columns) + # Get columns with all zeros in weight + zero_cols = torch.nonzero(weight.eq(0).all(dim=0)).unsqueeze(-1) + + # Zero out entire rows and columns in Hessian for dead neurons + h[zero_cols, :] = 0 + h[:, zero_cols] = 0 + h[zero_cols, zero_cols] = 1 + + # Add damping to diagonal + damp = percdamp * torch.mean(torch.diag(h)) + diag_indices = torch.arange(h.shape[0], device=h.device) + h[diag_indices, diag_indices] += damp + + try: + h = torch.cholesky_inverse(torch.linalg.cholesky(h)) + h_inv = torch.linalg.cholesky(h, upper=True) + except (RuntimeError, torch.linalg.LinAlgError): + print("Warning: Hessian is not positive definite, using identity matrix") + h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype) + return h_inv + + +def quantize_block(full_weight, block_start, block_end, h_inv, quantizer): + """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation. + + Args: + full_weight: The full weight tensor (needed for INT4 quantization) + block_start: Starting column index of the block + block_end: Ending column index of the block + h_inv: Hessian inverse + quantizer: The quantizer to apply + Returns: + quantized_block: Quantized weights for this block + losses: Quantization losses per element + errors: Accumulated errors for propagation + """ + # Extract the block we're working on + block_weight = full_weight[:, block_start:block_end] + block_hinv = h_inv[block_start:block_end, block_start:block_end] + block_size = block_end - block_start + + quantized_block = torch.zeros_like(block_weight) + losses = torch.zeros_like(block_weight) + errors = torch.zeros_like(block_weight) + + # We perform column-wise update for GPTQ within the block + group_size = 1 + + for group_start in range(0, block_size, group_size): + group_end = min(group_start + group_size, block_size) + group_cols = slice(group_start, group_end) + # Get current column and its Hessian inverse diagonal + weight_col = block_weight[:, group_cols] + hinv_diag = torch.diag(block_hinv[group_cols, group_cols]) + + # Quantize using the full weight, then extract the columns we need + quantized_full = quantizer(full_weight) + quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end] + quantized_block[:, group_cols] = quantized_cols + + # Compute quantization error and loss + error = (weight_col - quantized_cols) / hinv_diag + losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2 + errors[:, group_cols] = error + + # Propagate error to remaining columns in block + block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:] + full_weight[:, block_start:block_end] = block_weight + + return quantized_block, losses, errors + + +def blockwise_weight_update(module, h, block_size, percdamp): + """Update module weights using GPTQ-style blockwise quantization. + + Args: + module: Neural network module with weight and weight_quantizer + H: Hessian matrix (d x d) + block_size: Size of blocks to process at once + percdamp: Damping percentage for Hessian diagonal + """ + weight = module.weight.data.float().clone() + _, num_cols = weight.shape + + # Preprocess Hessian: handle dead neurons and add damping + h_inv = prepare_hessian_inverse(h, weight, percdamp) + + # Initialize output tensors + quantized_weight = torch.zeros_like(weight) + losses = torch.zeros_like(weight) + + # Process weights in blocks + for block_start in range(0, num_cols, block_size): + block_end = min(block_start + block_size, num_cols) + + quantized_block, block_losses, block_errors = quantize_block( + weight, block_start, block_end, h_inv, module.weight_quantizer + ) + # Store results + quantized_weight[:, block_start:block_end] = quantized_block + losses[:, block_start:block_end] = block_losses + + # Propagate errors to remaining weights + weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:] + + # Print relative mse error + _print_relative_mse_error(quantized_weight, module.weight.float(), h, module.name) + # Update module weights + module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype) + + +def gptq_lite( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + percdamp: float = 0.01, + block_size: int = 128, + hessian_state_path: str | None = None, +): + """GPTQ-lite quantization - a simplified GPTQ variant. + + Key differences from GPTQ: + - Layers are quantized in parallel (not sequentially with updated activations) + - Uses group-wise updates instead of column-wise updates + + Args: + model: Model to be calibrated. + forward_loop: Callable that forwards calibration data through the model. + percdamp: Percentage of avg Hessian diagonal for damping (default: 0.01). + block_size: Block size for GPTQ weight update. + hessian_state_path: Path to save/load Hessian state. If None, compute without saving. + If path exists, load from it. If path doesnt exist then save computed hessians to path. + + See :class:`GPTQLiteConfig ` for + details on the remaining arguments. + + Note: This feature is currently experimental and may not translate to improved accuracy as expected. + """ + # Dictionary to store hessian matrices: {layer_name: {"hessian": Tensor, "n_samples": int}} + hessian_state = {} + + def initialize_hessian_state(tensor_mapping): + """Initialize hessian state with zeros.""" + for name, (shape, device) in tensor_mapping.items(): + # Use CPU if GPU memory is tight + target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device + hessian_state[name] = { + "hessian": torch.zeros(shape, dtype=torch.float32, device=target_device), + "n_samples": 0, + } + + def load_hessian_state(path, tensor_mapping): + """Load hessian state from file.""" + print_rank_0(f"Loading hessian state from {path}") + loaded_state = torch.load(path, map_location="cpu") + + for name, (shape, device) in tensor_mapping.items(): + if name not in loaded_state: + raise KeyError(f"Layer '{name}' not found in loaded hessian state") + + # Move to appropriate device based on memory + target_device = "cpu" if get_used_gpu_mem_fraction(device) > 0.65 else device + hessian_state[name] = { + "hessian": loaded_state[name]["hessian"].to(target_device), + "n_samples": loaded_state[name]["n_samples"], + } + + print_rank_0(f"Successfully loaded hessian state with {len(hessian_state)} layers") + + def save_hessian_state(path): + """Save hessian state to file.""" + print_rank_0(f"Saving hessian state to {path}") + try: + # Move to CPU for saving + cpu_state = { + name: {"hessian": state["hessian"].cpu(), "n_samples": state["n_samples"]} + for name, state in hessian_state.items() + } + + os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) + torch.save(cpu_state, path) + print_rank_0(f"Successfully saved hessian state to {path}") + except Exception as e: + print_rank_0(f"Error saving hessian state: {e}") + print_rank_0("Continuing execution...") + + def hessian_hook(module, input, output): + """Hook to intercept activations and update hessian matrix.""" + state = hessian_state[module.name] + hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + + # Phase 1: Collect statistics for quantizers + max_calibrate(model) + + # Phase 2: Build tensor mapping for all quantized layers + tensor_mapping = {} + for name, module in model.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + in_features = module.weight.shape[-1] + tensor_mapping[name] = ((in_features, in_features), module.weight.device) + module.name = name # Attach name for easy access in hooks + + # Phase 3: Load or compute Hessians + hessian_exists = hessian_state_path is not None and os.path.exists(hessian_state_path) + save_hessians = hessian_state_path is not None and not hessian_exists + + if hessian_exists: + print_rank_0(f"Loading hessian state from {hessian_state_path}") + load_hessian_state(hessian_state_path, tensor_mapping) + else: + if forward_loop is None: + raise ValueError("forward_loop must be provided when computing Hessians") + + # Initialize hessian state + initialize_hessian_state(tensor_mapping) + + # Register hooks to collect activations + handles = [] + for name, module in model.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + handles.append(module.register_forward_hook(hessian_hook)) + + # Run forward loop to compute hessians + print_rank_0("Computing Hessian matrices...") + forward_loop(model) + + for handle in handles: + handle.remove() + + # Save if configured + if save_hessians: + try: + save_hessian_state(hessian_state_path) + except Exception as e: + print_rank_0(f"Error saving hessian state: {e}") + print_rank_0("Continuing execution...") + + # Phase 4: Update weights using computed Hessians + print_rank_0("Updating weights using GPTQ-lite algorithm...") + + quantized_modules = [ + (name, module) + for name, module in model.named_modules() + if is_quantized_linear(module) and module.weight_quantizer.is_enabled + ] + + # Perform blockwise weight updates + for name, module in tqdm(quantized_modules, desc="Quantizing layers"): + state = hessian_state[module.name] + hessian = state["hessian"].to(module.weight.device) + blockwise_weight_update(module, hessian, block_size, percdamp) + # Delete hessian state to free memory + del hessian_state[module.name] + torch.cuda.empty_cache() + + print_rank_0("GPTQ-lite quantization completed successfully") diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index a14469326e..0b40de8ab2 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -508,14 +508,26 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): @atomic_print -def print_quant_summary(model: nn.Module): +def print_quant_summary(model: nn.Module, output_dir: str | None = None): """Print summary of all quantizer modules in the model.""" - count = 0 - for name, mod in model.named_modules(): - if isinstance(mod, TensorQuantizer): - print(f"{name:80} {mod}") - count += 1 - print(f"{count} TensorQuantizers found in model") + lines = [ + f"{name:80} {mod}" + for name, mod in model.named_modules() + if isinstance(mod, TensorQuantizer) + ] + lines.append(f"{len(lines)} TensorQuantizers found in model") + + if output_dir: + path = ( + output_dir.joinpath(".quant_summary.txt") + if hasattr(output_dir, "joinpath") + else f"{output_dir}/.quant_summary.txt" + ) + with open(path, "w", encoding="utf-8") as f: + f.write("\n".join(lines) + "\n") + print(f"\033[1mQuant summary saved to {path}\033[0m") + else: + print("\n".join(lines)) def fold_weight(model: nn.Module): diff --git a/modelopt/torch/quantization/nn/functional.py b/modelopt/torch/quantization/nn/functional.py index df8bcbbcd1..0beb7c9563 100644 --- a/modelopt/torch/quantization/nn/functional.py +++ b/modelopt/torch/quantization/nn/functional.py @@ -93,7 +93,7 @@ def backward(ctx, grad_outputs): return fast_hadamard_transform.hadamard_transform(grad_outputs) # type: ignore[name-defined] -def normalized_hadamard_transform(inputs): +def normalized_hadamard_transform(inputs, rotate_fp32=False): """Normalized fast hadamard transform.""" global fast_hadamard_transform try: @@ -104,6 +104,10 @@ def normalized_hadamard_transform(inputs): "`pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git`" ) - return FastHadamardTransform.apply(inputs) / torch.sqrt( + dtype = inputs.dtype + if rotate_fp32: + inputs = inputs.to(torch.float32) + outputs = FastHadamardTransform.apply(inputs) / torch.sqrt( torch.tensor(inputs.shape[-1], dtype=torch.float32) ) + return outputs.to(dtype) if rotate_fp32 else outputs diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 12aaee3f8a..a792b6429c 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -17,10 +17,13 @@ import contextlib import warnings +from typing import Any import torch +import torch.nn as nn from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from modelopt.torch.utils.distributed import ParallelState from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR from ...utils import is_torch_export_mode @@ -35,7 +38,55 @@ class QuantModule(DynamicModule): - """A base class for quantized modules.""" + """A base class for quantized modules. + + In addition, the class also provides ``parallel_state`` attribute that can be used to access + the parallel state of the module. + """ + + _parallel_state: ParallelState + + @classmethod + @torch.no_grad() + def convert(cls, module: nn.Module, **setup_kwargs: Any) -> "QuantModule": + """Convert the module to a dynamic module.""" + module = super().convert(module, **setup_kwargs) + + # setup parallel state now that the module is converted + if module.parallel_state is None: + module._initialize_parallel_state() + + return module + + @property + def parallel_state(self) -> ParallelState | None: + """Return the parallel state of the quant module.""" + return getattr(self, "_parallel_state", None) + + @parallel_state.setter + def parallel_state(self, parallel_state: ParallelState): + """Set the parallel state of the dynamic module.""" + assert isinstance(parallel_state, ParallelState), ( + "parallel_state must be a ParallelState object!" + ) + self._parallel_state = parallel_state + + def _initialize_parallel_state(self): + """Initialize the parallel state of the dynamic module. + + This method is called only if the `QuantModule` does not have a `parallel_state` attribute + after `_setup` is called. + """ + if torch.distributed.is_initialized(): + warnings.warn( + f"Distributed training is initialized but no parallel_state is set for {type(self)}. " + "Using default parallel_state which has data_parallel_group set to the default process group and " + "tensor_parallel_group is unspecified. " + "If you are using tensor parallelism for this module, you should set the parallel_state " + "in its `_setup` method." + ) + + self.parallel_state = ParallelState(data_parallel_group=None) def modelopt_post_restore(self, prefix: str = ""): """Post-restore to correctly configure the TensorQuantizer states. @@ -110,7 +161,26 @@ class QuantInputBase(QuantModule): def forward(self, input, *args, **kwargs): """Quantize the input before calling the original forward method.""" input = self.input_quantizer(input) - output = super().forward(input, *args, **kwargs) + # Check MR: https://github.com/NVIDIA/Model-Optimizer/pull/824 + if hasattr(self, "_forward_pre_dm"): + pre_fwd = getattr(self, "_forward_pre_dm") + + def _is_forward_in_mro(bound_or_func) -> bool: + # If this is a bound method, compare its underlying function to any `forward` + # implementation in the current MRO. If it matches, it's not an external monkey-patch. + if hasattr(bound_or_func, "__func__"): + fn = bound_or_func.__func__ + for cls in type(self).mro(): + if cls.__dict__.get("forward") is fn: + return True + return False + + if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd): + output = super().forward(input, *args, **kwargs) + else: + output = pre_fwd(input, *args, **kwargs) + else: + output = super().forward(input, *args, **kwargs) if isinstance(output, tuple): return (self.output_quantizer(output[0]), *output[1:]) return self.output_quantizer(output) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 71e8237d76..2caec25656 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -49,15 +49,22 @@ INT4QTensor, INT8QTensor, MXFP4QTensor, + MXFP8QTensor, NF4QTensor, NVFP4QTensor, QTensorWrapper, ) -from ...tensor_quant import dynamic_block_quant, fake_tensor_quant, scaled_e4m3 +from ...tensor_quant import ( + dynamic_block_quant, + fake_tensor_quant, + scaled_e4m3, + static_blockwise_fp4_fake_quant, +) from ...utils import is_torch_export_mode from ..functional import normalized_hadamard_transform __all__ = [ + "NVFP4StaticQuantizer", "SequentialQuantizer", "TensorQuantizer", "TensorQuantizerCache", @@ -149,6 +156,10 @@ class TensorQuantizer(nn.Module): "_padding", # Extra flags added by huggingface "_is_hf_initialized", + # Extra flags added by accelerate + "_hf_hook", + "_old_forward", + "forward", # Extra flags added by deepspeed "ds_external_parameters", "all_parameters", @@ -518,6 +529,20 @@ def is_static_block_quant(self): and self._fake_quant ) + @property + def rotate_is_enabled(self): + """Check if rotate is enabled in quant config.""" + return self._rotate.get("enable", False) if isinstance(self._rotate, dict) else self._rotate + + @property + def rotate_is_fp32(self): + """Check if rotation needs to be computed in float32.""" + return ( + self._rotate.get("rotate_fp32", False) + if isinstance(self._rotate, dict) and self.rotate_is_enabled + else False + ) + def disable_calib(self): """Disable calibration.""" self._if_calib = False @@ -644,8 +669,32 @@ def _real_quantize(self, inputs): assert self._is_real_quantize_support(), "Real quantization not supported for this format." buffer_to_register = {} - if self._num_bits == (4, 3): - # FP8 quantization + # Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3) + if ( + self._block_sizes + and self._block_sizes.get("scale_bits") == (8, 0) + and self._block_sizes.get("type") == "dynamic" + ): + # MX quantization (MXFP4/MXFP8) + if self._num_bits == (2, 1): + # MXFP4 + outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) + buffer_to_register["_scale"] = scales + elif self._num_bits == (4, 3): + # MXFP8 + assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, ( + f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, " + f"got {self._block_sizes[-1]}" + ) + outputs, scales = MXFP8QTensor.quantize(inputs) + buffer_to_register["_scale"] = scales + else: + raise ValueError( + f"Unsupported MX format: num_bits={self._num_bits}. " + f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8." + ) + elif self._num_bits == (4, 3): + # FP8 quantization (non-MX) # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks # For blockwise quantization, amax will be recomputed in the kernel use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1) @@ -678,18 +727,6 @@ def _real_quantize(self, inputs): buffer_to_register["_scale"] = _scale buffer_to_register["_double_scale"] = _double_scale buffer_to_register["_scale_zeros"] = _scale_zeros - elif ( - self._block_sizes.get("scale_bits") == (8, 0) - and self._block_sizes.get("type") == "dynamic" - ): - # MX quantization - if self._num_bits == (2, 1): - outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1]) - buffer_to_register["_scale"] = scales - else: - raise ValueError( - f"Real quantization for MX {self._num_bits} format is not supported." - ) elif self._block_sizes.get("scale_bits") == (4, 3): # NVFP4 default quantization # Return real quantized tensor and store scales inside TensorQuantizer @@ -973,8 +1010,8 @@ def forward(self, inputs): inputs = inputs * self.pre_quant_scale # Rotating the input - if self._rotate: - inputs = normalized_hadamard_transform(inputs) + if self.rotate_is_enabled: + inputs = normalized_hadamard_transform(inputs, rotate_fp32=self.rotate_is_fp32) if self._disabled: # if quantizer is disabled, we still need to track the input dtype for saving the model @@ -1086,7 +1123,8 @@ def extra_repr(self): if self.pre_quant_scale is not None else "" ) - s += " rotated" if self._rotate else "" + s += " rotated" if self.rotate_is_enabled else "" + s += " (fp32)" if self.rotate_is_fp32 else "" s += ( f" calibrator={self._calibrator.__class__.__name__}" if (self._calibrator is not None) @@ -1221,6 +1259,66 @@ def _set_buffer(self, key, value): self.register_buffer(key, value) +class NVFP4StaticQuantizer(TensorQuantizer): + """TensorQuantizer for NVFP4 static block quantization with two-level scaling. + + Uses _global_amax and inherited _amax for per-block amax values. + """ + + @classmethod + def from_tensor_quantizer( + cls, tq: TensorQuantizer, global_amax: torch.Tensor | None = None + ) -> "NVFP4StaticQuantizer": + """Convert a TensorQuantizer to NVFP4StaticQuantizer in-place. + + Args: + tq: The TensorQuantizer to convert. + global_amax: Optional global amax value to set on the quantizer. + """ + if isinstance(tq, cls): + if global_amax is not None: + tq.global_amax = global_amax + return tq + tq.__class__ = cls + tq._is_nvfp4_static_quantizer = True + if global_amax is not None: + tq.global_amax = global_amax + return tq + + @property + def global_amax(self): + """Return global_amax for quantization.""" + if not hasattr(self, "_global_amax"): + return None + return self._global_amax + + @global_amax.setter + def global_amax(self, value): + if value is None: + if hasattr(self, "_global_amax"): + self._global_amax = None + return + if not isinstance(value, torch.Tensor): + value = torch.tensor(value) + if not hasattr(self, "_global_amax") or self._global_amax is None: + self.register_buffer("_global_amax", value.clone().detach()) + else: + self._global_amax.data.copy_(value.clone().detach().to(self._global_amax.device)) + + def _fake_quantize(self, inputs): + """Fake quantization using two-level scaling with _amax and _global_amax.""" + if self.amax is not None: + return static_blockwise_fp4_fake_quant( + inputs, + self.amax, + self.global_amax, # Can be None, will be computed internally + True, # quantize_block_scales + inputs.dtype, + self._pass_through_bwd, + ) + return super()._fake_quantize(inputs) + + class SequentialQuantizer(nn.Sequential): """A sequential container for :class:`TensorQuantizer` modules. diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index ef90dfda5a..ecd24d81e2 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -41,7 +41,7 @@ from .custom import * with import_plugin("diffusers"): - from .diffusers import * + from .diffusion.diffusers import * with import_plugin("fairscale"): from .fairscale import * @@ -75,3 +75,6 @@ with import_plugin("trl"): from .trl import * + +with import_plugin("fastvideo"): + from .diffusion.fastvideo import * diff --git a/modelopt/torch/quantization/plugins/attention.py b/modelopt/torch/quantization/plugins/attention.py index 643774da7e..2113edea8a 100644 --- a/modelopt/torch/quantization/plugins/attention.py +++ b/modelopt/torch/quantization/plugins/attention.py @@ -207,14 +207,17 @@ def patch_binop(node, quantizer_names, transpose=False): head = ast.fix_missing_locations(head) org_class = model_module.__dict__[org_class_name] - quant_class = _create_quantized_class_from_ast(head, org_class, new_class_name, model_module) + quant_class = _create_quantized_class_from_ast(head, org_class, new_class_name) register(original_cls=org_class, quantized_cls=quant_class) print(f"Successfully registered {org_class_name} for quantization") return True def _create_quantized_class_from_ast( - head, org_class, new_class_name, model_module, temp_file_name=None + head: ast.Module, + org_class: type, + new_class_name: str, + temp_file_name: str | None = None, ): """Create a quantized class from an AST representation. @@ -222,7 +225,6 @@ def _create_quantized_class_from_ast( head: The AST head containing the modified class definition org_class: The original class to be quantized new_class_name: Name for the new quantized class - model_module: The module containing the original class temp_file_name: Optional file name to save the generated code Returns: @@ -232,6 +234,19 @@ def _create_quantized_class_from_ast( # Save the generated code to a temporary file if requested module_code_str = ast.unparse(head) + + # Security: Validate generated code doesn't contain suspicious patterns + suspicious_patterns = ["__import__", "eval", "exec", "compile", "open(", "os.system"] + for pattern in suspicious_patterns: + if pattern in module_code_str: + # Allow compile for specific trusted ModelOpt internal use + if pattern == "compile" and "torch.compile" in module_code_str: + continue + raise ValueError( + f"Generated code contains suspicious pattern '{pattern}'. " + f"This may indicate a security issue in AST transformation." + ) + if temp_file_name is None: with tempfile.NamedTemporaryFile( prefix="modelopt_", suffix=".py", delete=False @@ -253,6 +268,11 @@ def _create_quantized_class_from_ast( # ) # bandit throws error here # quant_class = model_module.__dict__[new_class_name] + # Security NOTE: compile() is used here on internally-generated AST, + # not on untrusted user input. The AST is created by ModelOpt's quantization + # logic and has been validated above. This is safer than exec() but still + # requires the AST transformation logic to be secure. + # Extract the bytecode and create a new class on the fly # This is more tricky but doesn't require runtime execution module_code = compile(head, filename=f"{temp_file_name}", mode="exec") diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusion/diffusers.py similarity index 98% rename from modelopt/torch/quantization/plugins/diffusers.py rename to modelopt/torch/quantization/plugins/diffusion/diffusers.py index 440d190d39..2ec057766b 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusion/diffusers.py @@ -45,8 +45,8 @@ else: # torch >= 2.9 from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext -from ..export_onnx import export_fp8_mha -from ..nn import ( +from ...export_onnx import export_fp8_mha +from ...nn import ( QuantConv2d, QuantInputBase, QuantLinear, @@ -54,7 +54,7 @@ QuantModuleRegistry, TensorQuantizer, ) -from .custom import _QuantFunctionalMixin +from ..custom import _QuantFunctionalMixin onnx_dtype_map = { "BFloat16": onnx.TensorProto.BFLOAT16, diff --git a/modelopt/torch/quantization/plugins/diffusion/fastvideo.py b/modelopt/torch/quantization/plugins/diffusion/fastvideo.py new file mode 100644 index 0000000000..2fd6a59455 --- /dev/null +++ b/modelopt/torch/quantization/plugins/diffusion/fastvideo.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Support quantization for FastVideo layers.""" + +import torch +import torch.nn.functional as F +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.models.vaes.wanvae import WanCausalConv3d + +from ...nn import QuantLinearConvBase, QuantModuleRegistry +from ...nn.modules.quant_conv import _QuantConv3d +from ...nn.modules.quant_linear import _QuantLinear +from ...utils import is_torch_export_mode + + +@QuantModuleRegistry.register({WanCausalConv3d: "WanCausalConv3d"}) +class _QuantWanCausalConv3d(_QuantConv3d): + @staticmethod + def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor: + """Quantize weight in linear format for proper block-wise FP4 quantization.""" + if module._enable_weight_quantization or is_torch_export_mode(): + # Quantize in linear format (block-wise quantization works correctly here) + return module.weight_quantizer(weight) + + return weight + + def forward(self, x, cache_x=None): + from fastvideo.platforms import current_platform + + with self.quantize_weight(): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + x = ( + x.to(self.weight.dtype) if current_platform.is_mps() else x + ) # casting needed for mps since amp isn't supported + + input = self.input_quantizer(x) + output = super(WanCausalConv3d, self).forward(input) + + if isinstance(output, tuple): + return (self.output_quantizer(output[0]), *output[1:]) + return self.output_quantizer(output) + + +@QuantModuleRegistry.register({ReplicatedLinear: "ReplicatedLinear"}) +class _QuantReplicatedLinear(_QuantLinear): + pass diff --git a/modelopt/torch/quantization/plugins/diffusion/ltx2.py b/modelopt/torch/quantization/plugins/diffusion/ltx2.py new file mode 100644 index 0000000000..d89fe4b828 --- /dev/null +++ b/modelopt/torch/quantization/plugins/diffusion/ltx2.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LTX-2 quantization plugin.""" + +import contextlib + +import torch + +from modelopt.torch.quantization.nn.modules.quant_linear import _QuantLinear +from modelopt.torch.quantization.nn.modules.quant_module import QuantModuleRegistry +from modelopt.torch.quantization.utils import is_torch_export_mode + +_FP8_DTYPES = tuple( + dtype + for dtype_name in ("float8_e4m3fn", "float8_e5m2", "float8_e4m3fnuz", "float8_e5m2fnuz") + if (dtype := getattr(torch, dtype_name, None)) is not None +) + + +def _upcast_fp8_weight( + weight: torch.Tensor, target_dtype: torch.dtype, seed: int = 0 +) -> torch.Tensor: + if target_dtype is torch.bfloat16: + try: + from ltx_core.loader.fuse_loras import fused_add_round_launch + + return fused_add_round_launch( + torch.zeros_like(weight, dtype=target_dtype), + weight, + seed, + ) + except Exception: + pass + return weight.to(target_dtype) + + +class _QuantLTX2Linear(_QuantLinear): + """Quantized Linear with FP8 upcast before weight quantization.""" + + @staticmethod + def _get_quantized_weight(module: "_QuantLTX2Linear", weight: torch.Tensor) -> torch.Tensor: + if _FP8_DTYPES and weight.dtype in _FP8_DTYPES: + weight = _upcast_fp8_weight(weight, torch.bfloat16, 0) + if module._enable_weight_quantization or is_torch_export_mode(): + return module.weight_quantizer(weight) + return weight + + +def register_ltx2_quant_linear() -> None: + """Register the LTX-2 quantized Linear, overriding the default mapping.""" + with contextlib.suppress(KeyError): + QuantModuleRegistry.unregister(torch.nn.Linear) + QuantModuleRegistry.register({torch.nn.Linear: "nn.Linear"})(_QuantLTX2Linear) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 30fdc5244c..aa274ea7eb 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -22,6 +22,10 @@ from typing import TYPE_CHECKING import torch +import transformers +from packaging import version +from torch import Tensor +from torch.nn.functional import linear try: from torch.distributed.tensor import Shard @@ -36,7 +40,6 @@ kitchen = None import torch.nn as nn -import transformers from transformers.models.t5.modeling_t5 import T5Attention from modelopt.torch.opt.dynamic import DynamicModule @@ -46,6 +49,13 @@ from ..conversion import register from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear +from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE + +if IS_TRITON_AVAILABLE: + from ..triton import weight_dequant +else: + weight_dequant = None + from ..utils import replace_function from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin @@ -55,6 +65,8 @@ __all__ = ["register_hf_attentions_on_the_fly"] +TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0") + class _QuantAttention(QuantModule): """Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0.""" @@ -438,18 +450,72 @@ class _QuantSparseMoe(QuantModule): """ def _setup(self): - pass + num_experts = 0 + if hasattr(self, "gate") and hasattr(self.gate, "num_experts"): + num_experts = self.gate.num_experts + elif hasattr(self, "num_experts"): + num_experts = self.num_experts + elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"): + num_experts = self.experts.num_experts + + self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu") + self._count_expert_tokens = False + + if num_experts == 0: + warnings.warn( + f"{self.__class__.__name__}: could not resolve num_experts; " + "expert routing will not be tracked for this layer." + ) + return + + if hasattr(self, "gate"): + self.gate.register_forward_hook(self._gate_forward_hook) + + def _gate_forward_hook(self, module, input, output): + if not self._count_expert_tokens: + return + with torch.no_grad(): + if isinstance(output, tuple) and len(output) >= 3: + # v5.x TopKRouter: returns (logits, scores, indices) + indices = output[2] + else: + # v4.x nn.Linear gate: returns logits tensor + logits = output if not isinstance(output, tuple) else output[0] + top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k + _, indices = torch.topk(logits.float(), top_k, dim=-1) + counts = torch.bincount( + indices.reshape(-1).cpu(), minlength=len(self.expert_token_count) + ) + self.expert_token_count += counts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): + is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules()) + if is_calib: # If any of the experts are in calibration mode, we will forward all tokens to all experts # This is used only for calibration, we need to re-calculate the actual outputs again using # the original top_k - original_top_k = self.top_k - self.top_k = self.num_experts - super().forward(hidden_states) - self.top_k = original_top_k - return super().forward(hidden_states) + if TRANSFORMERS_VERSION_GE_5_0: + assert hasattr(self, "gate") and hasattr(self.gate, "top_k") + original_top_k = self.gate.top_k + self.gate.top_k = self.gate.num_experts + super().forward(hidden_states) + self.gate.top_k = original_top_k + else: + # Path for transformers < 5.0 + original_top_k = self.top_k + if hasattr(self, "num_experts"): + self.top_k = self.num_experts + elif hasattr(self, "experts"): + self.top_k = self.experts.num_experts + else: + raise ValueError(f"Could not find num_experts in module {self}") + super().forward(hidden_states) + self.top_k = original_top_k + # Enable counting only for the real-routing forward during calibration + self._count_expert_tokens = is_calib + output = super().forward(hidden_states) + self._count_expert_tokens = False + return output class _QuantLlama4TextExperts(QuantModule): @@ -571,6 +637,86 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: return self.w2_linear[expert_idx](x1) +class _QuantQwen3VLMoeTextExperts(QuantModule): + def _setup(self): + """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers.""" + from accelerate import init_empty_weights + + dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device + + def _copy_weight(module, weight): + module.to_empty(device=device) + with torch.no_grad(): + module.weight.data = weight.detach().data.to(dtype=dtype, device=device) + + # The attribute name was changed from `intermediate_size` to `intermediate_dim` in + # https://github.com/huggingface/transformers/commit/0642963ba13f2dae0596fe489415569e1d91fbda + if hasattr(self, "intermediate_size"): + expert_dim = self.intermediate_size + elif hasattr(self, "intermediate_dim"): + expert_dim = self.intermediate_dim + else: + raise AttributeError("Could not find intermediate dimension size in model") + + with init_empty_weights(): + gate_proj = nn.ModuleList( + [ + nn.Linear(self.hidden_size, expert_dim, bias=False) + for _ in range(self.num_experts) + ] + ) + up_proj = nn.ModuleList( + [ + nn.Linear(self.hidden_size, expert_dim, bias=False) + for _ in range(self.num_experts) + ] + ) + down_proj = nn.ModuleList( + [ + nn.Linear(expert_dim, self.hidden_size, bias=False) + for _ in range(self.num_experts) + ] + ) + + for idx in range(self.num_experts): + _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :expert_dim].T) + _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, expert_dim:].T) + _copy_weight(down_proj[idx], self.down_proj[idx, :].T) + + delattr(self, "gate_up_proj") + delattr(self, "down_proj") + self.gate_proj = gate_proj + self.up_proj = up_proj + self.down_proj = down_proj + + def forward( + self, + hidden_states: torch.Tensor, + routing_weights: torch.Tensor, + router_indices: torch.Tensor, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + next_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate = self.gate_proj[expert_idx](current_state) + up = self.up_proj[expert_idx](current_state) + gated_output = up * self.act_fn(gate) + out = self.down_proj[expert_idx](gated_output) + weighted_output = out * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + + return next_states + + class _QuantDbrxFFN(_QuantSparseMoe): @property def num_experts(self): @@ -585,11 +731,81 @@ def top_k(self, value): self.router.moe_top_k = value -try: - from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe +class _QuantCompressedLinear(QuantModule): + def _setup(self): + self.input_quantizer = TensorQuantizer() + self.weight_quantizer = TensorQuantizer() + + def forward(self, input: Tensor) -> Tensor: + from compressed_tensors.quantization import QuantizationStatus + + if self.quantization_status == QuantizationStatus.COMPRESSED: + weight_data = self.compressor.decompress_module(self) + else: + weight_data = self.weight + + return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias) + + def unpack_weight(self): + from compressed_tensors.quantization import QuantizationStatus + + if self.quantization_status == QuantizationStatus.COMPRESSED: + self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False) + if hasattr(self, "weight_packed"): + del self.weight_packed + if hasattr(self, "weight_scale"): + del self.weight_scale + + +class _QuantFP8Linear(QuantModule): + def _setup(self): + self.input_quantizer = TensorQuantizer() + self.weight_quantizer = TensorQuantizer() + assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D" + assert self.weight.ndim == 2, "Weight must be 2D" + self.block_size = max( + self.weight.shape[0] // self.weight_scale_inv.shape[0], + self.weight.shape[1] // self.weight_scale_inv.shape[1], + ) + assert self.block_size == 128, "Block size must be 128" + + def _get_weight_and_scale_inv(self): + if isinstance(self.weight, torch.distributed.tensor.DTensor): + weight = self.weight._local_tensor.contiguous() + scale_inv = self.weight_scale_inv._local_tensor.contiguous() + else: + weight = self.weight.contiguous() + scale_inv = self.weight_scale_inv.contiguous() + return weight, scale_inv + + def forward(self, input: Tensor) -> Tensor: + assert weight_dequant is not None, "Triton is not available" + if self.weight.element_size() == 1: + with torch.cuda.device(self.weight.device): + weight, scale_inv = self._get_weight_and_scale_inv() + weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype) + else: + weight = self.weight + return linear( + self.input_quantizer(input), + self.weight_quantizer(weight), + self.bias, + ) + + def unpack_weight(self): + assert weight_dequant is not None, "Triton is not available" + with torch.cuda.device(self.weight.device): + weight, scale_inv = self._get_weight_and_scale_inv() + self.weight = nn.Parameter( + weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()), + requires_grad=False, + ) + if hasattr(self, "weight_scale_inv"): + del self.weight_scale_inv + - if Llama4TextMoe not in QuantModuleRegistry: - QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe) +try: + from transformers.models.llama4.modeling_llama4 import Llama4TextExperts if Llama4TextExperts not in QuantModuleRegistry: QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})( @@ -612,16 +828,6 @@ def top_k(self, value): except ImportError: pass -try: - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - - if MixtralSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - try: from transformers.models.falcon.modeling_falcon import FalconLinear @@ -631,32 +837,30 @@ def top_k(self, value): pass try: - from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock + from compressed_tensors.linear.compressed_linear import CompressedLinear - if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})( - _QuantSparseMoe + if CompressedLinear not in QuantModuleRegistry: + QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})( + _QuantCompressedLinear ) except ImportError: pass try: - from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts - if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})( - _QuantSparseMoe + if Qwen3VLMoeTextExperts not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})( + _QuantQwen3VLMoeTextExperts ) except ImportError: pass try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock + from transformers.integrations.finegrained_fp8 import FP8Linear - if Qwen3NextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})( - _QuantSparseMoe - ) + if FP8Linear not in QuantModuleRegistry: + QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear) except ImportError: pass @@ -774,6 +978,58 @@ def register_falcon_linears_on_the_fly(model): QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) +def _is_sparse_moe_block(module): + """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe. + + All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.) + share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes + (``top_k`` and ``num_experts``), and an ``experts`` sub-module. + + This function detects that pattern instead of relying on class names, making it forward-compatible + with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but + use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom + ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives. + """ + if not hasattr(module, "experts"): + return False + + # Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern) + if hasattr(module, "gate"): + gate = module.gate + has_topk = hasattr(gate, "top_k") + has_num_experts = hasattr(gate, "num_experts") + if has_topk and has_num_experts: + return True + + # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next) + return hasattr(module, "top_k") and hasattr(module, "num_experts") + + +def register_sparse_moe_on_the_fly(model): + """Auto-detect and register MOE modules as _QuantSparseMoe. + + Walks the model tree, identifies MoE blocks by their structural attributes + (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``. + """ + visited_types = set() + for name, module in model.named_modules(): + mod_type = type(module) + + # Avoid duplicate registration: skip if we already processed this type + # in this walk, or if it was previously registered in the QuantModuleRegistry. + if mod_type in visited_types or QuantModuleRegistry.get(mod_type) is not None: + continue + + visited_types.add(mod_type) + + if _is_sparse_moe_block(module): + print( + f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, " + f"registering with _QuantSparseMoe.\033[0m" + ) + QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe) + + def _is_supported_hf_model(model): """Check if the model a valid model for transformers quantization specific support.""" supported_models = [transformers.PreTrainedModel] @@ -839,6 +1095,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): [ register_falcon_linears_on_the_fly, register_dbrx_moe_on_the_fly, + register_sparse_moe_on_the_fly, register_hf_attentions_on_the_fly, convert_hf_parallel_linears_on_the_fly, ] diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 803c9747f3..e84735ae93 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -24,7 +24,6 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import megatron.core.transformer.moe.experts as megatron_moe -import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region @@ -52,6 +51,7 @@ TEColumnParallelLinear, TEDotProductAttention, TELayerNormColumnParallelLinear, + TELinear, TERowParallelGroupedLinear, TERowParallelLinear, ) @@ -574,19 +574,31 @@ def _setup(self): expert.linear_fc1.parallel_state = self.parallel_state expert.linear_fc2.parallel_state = self.parallel_state - def sync_moe_local_experts_amax(self): - """Sync amax across local experts in a SequentialMLP. + def layer_sync_moe_local_experts_amax(self): + """Sync input quantizer amax across local experts in a SequentialMLP. - amax across EP and ETP (for RowParallel) are synchronized as part of model_calib.max_calibrate(). - This function is called to synchronize the amax values across local experts s.t. all localexperts will - share the same amax. + Ensures all experts have the same input quantizer amax.This function operates + on a single rank and does not require distributed sync. + + Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate(). + This function should be called before the distributed sync to ensure the amax values + are synchronized across the layer first. + + Note: + Because there are logic which calls collective communication based on whether amax is not None, + We need to guarantee that all experts must have amax. Otherwise, there will be deadlock + when synchronizing over EP since some ranks may have amax None and not calling the collective + communication. """ - torch.distributed.barrier() # Collect amax from all local experts amax_dict = {} for expert in self.local_experts: for name, module in expert.named_modules(): - if isinstance(module, TensorQuantizer) and module.amax is not None: + if ( + isinstance(module, TensorQuantizer) + and module.amax is not None + and "input_quantizer" in name + ): stored_amax = amax_dict.get(name) amax_tensor = module.amax.detach().clone() amax_dict[name] = ( @@ -598,8 +610,8 @@ def sync_moe_local_experts_amax(self): # Apply synchronized amax values back to all local experts for expert in self.local_experts: for name, module in expert.named_modules(): - if isinstance(module, TensorQuantizer) and module.amax is not None: - module.amax = amax_dict[name].detach().clone().to(module.amax.device) + if isinstance(module, TensorQuantizer) and name in amax_dict: + module.amax = amax_dict[name].detach().clone() def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. @@ -627,6 +639,10 @@ class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear) class _QuantTEMCoreColumnParallelLinear(_QuantTELinear, _MegatronColumnParallelLinear): pass + @QuantModuleRegistry.register({TELinear: "te_mcore_Linear"}) + class _QuantTEMCoreLinear(_QuantTELinear): + pass + @QuantModuleRegistry.register( {TELayerNormColumnParallelLinear: "te_mcore_LayerNormColumnParallelLinear"} ) @@ -736,26 +752,3 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # Affine KVCache Quant bias vector. state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets) - - -@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"}) -class _QuantMoELayer(QuantModule): - """Module to support special handling of token dispatching during calibration. - - During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate. - However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance - returns. - - If calibration is not enabled, this module behaves as a normal MoELayer. - """ - - def _setup(self): - pass - - def forward(self, hidden_states): - if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): - original_top_k = self.router.topk - self.router.topk = self.router.num_experts - super().forward(hidden_states) - self.router.topk = original_top_k - return super().forward(hidden_states) diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index 8489086579..afc08211f0 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -120,6 +120,13 @@ def _functionals_to_replace(self, value): self._functionals_to_replace = value def _setup(self): + if getattr(self, "fuse_wgrad_accumulation", False): + warnings.warn( + "fuse_wgrad_accumulation is not supported with ModelOpt quantization. " + "Setting fuse_wgrad_accumulation to False." + ) + self.fuse_wgrad_accumulation = False + # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning # self.weight0 to self.weight to run the quantizer states initialization. @@ -131,6 +138,9 @@ def _setup(self): # Remove self.weight after setup. delattr(self, "weight") + # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization + # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm. + def modelopt_post_restore(self, prefix: str = ""): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning @@ -146,7 +156,17 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args): _assert_te_fp8_enabled() idx = 1 if func_name == "_forward" else 0 inp = args[idx] - num_gemms = len(args[idx + 1]) + + # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10) + # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases) + # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...) + if Version("2.10") <= _TE_VERSION: + # New signature: non_tensor_args is a tuple, m_splits is the first element + num_gemms = len(args[idx + 1][0]) + else: + # Old signature: m_splits is directly args[idx + 1] + num_gemms = len(args[idx + 1]) + weights_and_biases = args[-2 * num_gemms :] weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] quantized_inputs = self.input_quantizer(inp) diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py index 2e95c98b79..b92b240c0d 100644 --- a/modelopt/torch/quantization/plugins/transformers_trainer.py +++ b/modelopt/torch/quantization/plugins/transformers_trainer.py @@ -15,6 +15,7 @@ """ModelOpt plugin for transformers Trainer.""" +import contextlib import gc import json import os @@ -100,6 +101,52 @@ class QuantizationArgumentsWithConfig(QuantizationArguments): ) +def _patch_fsdp2_post_backward(): + """Patch FSDP2 ``post_backward`` to handle mixed-precision gradient dtypes. + + FSDP2 with bf16 mixed precision upcasts bf16 parameters to fp32 for optimizer + precision, while gradients are reduced in bf16. In PyTorch >= 2.6, assigning a + bf16 gradient to a fp32 parameter raises a ``RuntimeError`` due to the + ``grad_dtype`` check, and the fused Adam optimizer also rejects mixed dtypes. + + This patch wraps ``FSDPParamGroup.post_backward`` to: + 1. Set ``grad_dtype=None`` on sharded params before reduction (allowing bf16 assignment). + 2. Cast gradients to match parameter dtype after reduction (so the optimizer sees matching dtypes). + + .. note:: + This is a workaround. The proper fix should come from PyTorch's FSDP2 + ``foreach_reduce`` (which should cast gradients to match the parameter dtype) + or from accelerate (which should set ``grad_dtype`` when it upcasts params). + Remove this once the upstream fix is available. + """ + try: + from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup + except ImportError: + return + + if hasattr(FSDPParamGroup, "_modelopt_original_post_backward"): + return # Already patched + + FSDPParamGroup._modelopt_original_post_backward = FSDPParamGroup.post_backward + + @torch.no_grad() + def _patched_post_backward(self): + # Allow bf16 gradients to be assigned to fp32 parameters + for fsdp_param in self.fsdp_params: + with contextlib.suppress(AttributeError): + fsdp_param.sharded_param.grad_dtype = None + + self._modelopt_original_post_backward() + + # Cast gradients to parameter dtype so the optimizer sees matching dtypes + for fsdp_param in self.fsdp_params: + sp = fsdp_param.sharded_param + if sp.grad is not None and sp.grad.dtype != sp.dtype: + sp.grad = sp.grad.to(sp.dtype) + + FSDPParamGroup.post_backward = _patched_post_backward + + def check_awq_smoothquant(quant_cfg): # TODO: Remove this once deepspeed for AWQ and SmoothQuant is added """Get the quantization type from the configuration.""" @@ -186,6 +233,7 @@ def _save_modelopt_state_with_weights(self): print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}") def _restore_modelopt_state_with_weights(self): + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load(self._modelopt_state_path, weights_only=False) modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) restore_from_modelopt_state(self.model, modelopt_state) @@ -336,6 +384,7 @@ def _patch_accelerate_for_fsdp2_fix(self): is causing issues with quantized models since quantization modules adds buffers which are not sharded. This patch hides the buffers added by quantization modules from the original accelerate prepare. """ + _patch_fsdp2_post_backward() def _modelopt_prepare(self, *args, **kwargs): if not self.is_fsdp2: diff --git a/modelopt/torch/quantization/qtensor/__init__.py b/modelopt/torch/quantization/qtensor/__init__.py index c4ed88f87b..9c623c1bd2 100644 --- a/modelopt/torch/quantization/qtensor/__init__.py +++ b/modelopt/torch/quantization/qtensor/__init__.py @@ -20,5 +20,6 @@ from .int4_tensor import * from .int8_tensor import * from .mxfp4_tensor import * +from .mxfp8_tensor import * from .nf4_tensor import * from .nvfp4_tensor import * diff --git a/modelopt/torch/quantization/qtensor/mxfp8_tensor.py b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py new file mode 100644 index 0000000000..846a95ffcd --- /dev/null +++ b/modelopt/torch/quantization/qtensor/mxfp8_tensor.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements MXFP8 quantization for efficient tensor storage and computation.""" + +import torch + +from ..qtensor.base_qtensor import BaseQuantizedTensor +from ..utils import reduce_block_amax, reduce_block_padding + +__all__ = ["MXFP8QTensor"] + + +class MXFP8QTensor(BaseQuantizedTensor): + """Implements the MXFP8 quantization on tensors for more efficient storage or computation. + + MXFP8 uses: + - FP8 E4M3 format for elements + - E8M0 format for shared scales (power-of-2 only, stored as biased uint8 exponent) + - Block size of 32 elements along the last dimension + + Attributes: + quantized_data (torch.Tensor): The quantized data stored as float8_e4m3fn tensor. + """ + + E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + BLOCK_SIZE = 32 + SCALE_DTYPE = torch.uint8 # E8M0 format stores biased exponent as uint8 + + @classmethod + def _compute_e8m0_exponent(cls, amax: torch.Tensor) -> torch.Tensor: + """Compute E8M0 exponent from per-block amax values. + + Args: + amax: Per-block absolute max values. + + Returns: + torch.Tensor: Float tensor of E8M0 exponents (unbiased, range [-127, 127]). + """ + # Compute E8M0 scale: scale = 2^ceil(log2(amax / E4M3_max)) + descale = amax.float() / cls.E4M3_MAX + + # Handle zero/inf/nan cases + min_value = torch.tensor(-127.0, device=descale.device) + log2_descale = torch.where( + descale > 0, + torch.log2(descale), + min_value, + ) + + e8m0_exponent = torch.ceil(log2_descale) + + # Clamp exponent to valid E8M0 range + return torch.clamp(e8m0_exponent, min=-127, max=127) + + @classmethod + def get_weights_scaling_factor(cls, weight: torch.Tensor) -> torch.Tensor: + """Returns E8M0 scale (uint8 biased exponent) for weight tensor. + + Args: + weight: The weight tensor to compute scale for. Must be at least 2D. + Supports 2D (out_dim, in_dim) and 3D MoE (num_experts, out_dim, in_dim). + + Returns: + torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. + For 2D input: (out_dim, in_dim // 32) + For 3D MoE input: (num_experts, out_dim, in_dim // 32) + """ + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + + in_dim = weight.shape[-1] + + assert in_dim % cls.BLOCK_SIZE == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" + ) + + # Compute amax per block (reduce_block_amax handles N-dimensional tensors) + amax = reduce_block_amax(weight, block_sizes={-1: cls.BLOCK_SIZE}) + + # Compute E8M0 exponent and convert to biased uint8 (bias = 127) + e8m0_exponent = cls._compute_e8m0_exponent(amax) + return (e8m0_exponent + 127).to(cls.SCALE_DTYPE) + + @classmethod + def get_weights_scaling_factor_from_quantizer( + cls, + weight: torch.Tensor, + weight_quantizer, + ) -> torch.Tensor: + """Returns E8M0 scale from quantizer or computes from weight. + + This method handles extracting the scale from a weight quantizer, + with proper format conversion and shape correction. + + Args: + weight: The weight tensor. Can be 2D (out_dim, in_dim) or + 3D for MoE (num_experts, out_dim, in_dim). + weight_quantizer: The weight quantizer with block_sizes and optional _scale. + + Returns: + torch.Tensor: E8M0 scale as uint8 tensor with shape [..., out_dim, in_dim // 32]. + """ + assert hasattr(weight_quantizer, "block_sizes"), ( + "weight_quantizer must have 'block_sizes' attribute" + ) + assert weight_quantizer.block_sizes[-1] == cls.BLOCK_SIZE, ( + f"MXFP8 requires block size {cls.BLOCK_SIZE}, got {weight_quantizer.block_sizes[-1]}" + ) + assert weight.dim() >= 2, f"Weight must be at least 2D, got {weight.dim()}D" + + in_dim = weight.shape[-1] + # Expected scale shape: all dims except last, with last dim reduced by block size + # For 2D: (out_dim, in_dim // 32) + # For 3D MoE: (num_experts, out_dim, in_dim // 32) + expected_shape = (*weight.shape[:-1], in_dim // cls.BLOCK_SIZE) + + if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None: + scale = weight_quantizer._scale + + assert scale.dtype == cls.SCALE_DTYPE, ( + f"MXFP8 scale must be {cls.SCALE_DTYPE} (E8M0 format), got {scale.dtype}" + ) + assert scale.shape == expected_shape, ( + f"Scale shape {scale.shape} does not match expected shape {expected_shape}" + ) + return scale + + # No scale in quantizer, compute from weight + return cls.get_weights_scaling_factor(weight) + + @classmethod + def quantize_with_scale( + cls, + weight: torch.Tensor, + weights_scaling_factor: torch.Tensor, + ) -> torch.Tensor: + """Quantize weight tensor using a pre-computed E8M0 scale. + + This method is useful for export paths where the scale has already been computed. + + Args: + weight: The weight tensor to quantize. Must be at least 1D. + weights_scaling_factor: E8M0 scale as uint8 biased exponent (bias = 127). + Shape should be [..., out_dim, in_dim // 32] for 2D+ tensors, + or [in_dim // 32] for 1D tensors. + + Returns: + torch.Tensor: Quantized weight as float8_e4m3fn with same shape as input. + """ + assert weights_scaling_factor.dtype == cls.SCALE_DTYPE, ( + f"weights_scaling_factor must be {cls.SCALE_DTYPE} (E8M0 format), " + f"got {weights_scaling_factor.dtype}" + ) + + in_dim = weight.shape[-1] + num_blocks = in_dim // cls.BLOCK_SIZE + + assert in_dim % cls.BLOCK_SIZE == 0, ( + f"Weight inner dimension ({in_dim}) must be divisible by MXFP8 block size ({cls.BLOCK_SIZE})" + ) + + # Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent) + scale_factor = torch.exp2(127 - weights_scaling_factor.float()) + + # NOTE: vLLM/flashinfer may require this behavior: + # scale_factor = torch.where( + # weights_scaling_factor == 0, + # 1.0, + # torch.exp2(127 - weights_scaling_factor.float()) + # ) + + weight_reshaped = weight.view(*weight.shape[:-1], num_blocks, cls.BLOCK_SIZE) + scale_factor_expanded = scale_factor.unsqueeze(-1) + scaled_weight = weight_reshaped * scale_factor_expanded + scaled_weight = torch.clamp(scaled_weight, min=-cls.E4M3_MAX, max=cls.E4M3_MAX) + quantized_weight = scaled_weight.to(torch.float8_e4m3fn) + + return quantized_weight.view(weight.shape) + + @classmethod + def quantize( + cls, + input: torch.Tensor, + weights_scaling_factor: torch.Tensor | None = None, + ) -> tuple: + """Convert a tensor to MXFP8 quantized format. + + Args: + input (torch.Tensor): The input tensor to be quantized. + weights_scaling_factor (torch.Tensor | None): Optional pre-computed E8M0 scale + as uint8 biased exponent. If None, the scale will be computed from the input. + Shape should be [..., in_dim // 32] matching input dimensions. + + Returns: + tuple: (MXFP8QTensor, weights_scaling_factor) where weights_scaling_factor is + E8M0 scale as uint8 biased exponent. + """ + original_shape = input.shape + original_dtype = input.dtype + + input = reduce_block_padding(input, block_sizes={-1: cls.BLOCK_SIZE}) + + if weights_scaling_factor is None: + input_amax = reduce_block_amax(input, block_sizes={-1: cls.BLOCK_SIZE}) + e8m0_exponent = cls._compute_e8m0_exponent(input_amax) + weights_scaling_factor = (e8m0_exponent + 127).to(cls.SCALE_DTYPE) + + quantized_data = cls.quantize_with_scale(input, weights_scaling_factor) + + # Crop back to original shape + quantized_data = quantized_data[..., : original_shape[-1]] + + return cls(original_shape, original_dtype, quantized_data), weights_scaling_factor + + def dequantize(self, dtype: torch.dtype = None, **kwargs) -> torch.Tensor: + """Dequantize MXFP8 tensor back to the target dtype. + + Args: + dtype (torch.dtype | None): Target dtype for dequantization. Defaults to original dtype. + **kwargs: Must contain 'scale' (E8M0 biased uint8). + + Returns: + torch.Tensor: Dequantized tensor in the target dtype. + """ + assert "scale" in kwargs, "dequantize requires 'scale' in kwargs" + + e8m0_scale = kwargs["scale"] + + if dtype is None: + dtype = self.metadata["dtype"] + + original_shape = self.metadata["shape"] + quantized_data = self._quantized_data.float() + quantized_data = reduce_block_padding(quantized_data, block_sizes={-1: self.BLOCK_SIZE}) + + num_blocks = quantized_data.shape[-1] // self.BLOCK_SIZE + quantized_blocked = quantized_data.view( + *quantized_data.shape[:-1], num_blocks, self.BLOCK_SIZE + ) + + # Convert E8M0 biased exponent back to scale factor: descale = 2^(exponent - 127) + descale = torch.exp2(e8m0_scale.float() - 127) + + dequantized = quantized_blocked * descale.unsqueeze(-1) + + # Reshape and crop back to original shape + dequantized = dequantized.view(*quantized_data.shape[:-1], quantized_data.shape[-1]) + dequantized = dequantized[..., : original_shape[-1]] + + return dequantized.to(dtype) diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 7c35af75e5..d9b5839716 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -171,6 +171,7 @@ def _dynamic_block_quantize_impl( num_bits == (2, 1) # type: ignore[comparison-overlap] and scale_bits == (4, 3) and triton_kernel.IS_AVAILABLE + and hasattr(triton_kernel, "fp4_fake_quant_block") # requires compute >= 8.9 and not DISABLE_TRITON_KERNEL and amax is not None ): @@ -562,6 +563,40 @@ def backward(ctx, grad_outputs): return _fake_quant_backward_function(ctx, grad_outputs, num_args=9) +class StaticBlockwiseFP4FakeQuantFunction(Function): + """Static blockwise FP4 fake quantization functional.""" + + @staticmethod + def forward( + ctx, + x, + amax, + global_amax=None, + quantize_block_scales=True, + out_dtype=None, + pass_through_bwd=False, + ): + """Forward method.""" + if not triton_kernel.IS_AVAILABLE: + raise RuntimeError( + "static_blockwise_fp4_fake_quant requires triton. " + "Install with `pip install triton`." + ) + _save_for_backward_if_needed(ctx, pass_through_bwd, x, amax) + return triton_kernel.static_blockwise_fp4_fake_quant( + x, + amax, + global_amax, + quantize_block_scales, + out_dtype, + ) + + @staticmethod + def backward(ctx, grad_outputs): + """Implements straight through estimation with clipping.""" + return _fake_quant_backward_function(ctx, grad_outputs, num_args=6) + + def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): """Shared function body between TensorQuantFunction and FakeTensorQuantFunction.""" # Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning. @@ -606,3 +641,4 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): fake_tensor_quant = FakeTensorQuantFunction.apply scaled_e4m3 = ScaledE4M3Function.apply dynamic_block_quant = DynamicBlockQuantizationFunction.apply +static_blockwise_fp4_fake_quant = StaticBlockwiseFP4FakeQuantFunction.apply diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index c513a4b110..def70e5914 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -22,8 +22,7 @@ IS_AVAILABLE = False -# triton fp8 requires compute_cap >= 89 -if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9): +if torch.cuda.is_available(): with import_plugin( "triton", msg_if_missing=( @@ -31,6 +30,12 @@ "quantization simulations. Try to install triton with `pip install triton`." ), ): + # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * + from .fp8_kernel import * + + # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) + if torch.cuda.get_device_capability() >= (8, 9): + from .fp4_kernel_hopper import * IS_AVAILABLE = True diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index f2f9bd077f..63a8b3dcb7 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -24,7 +24,7 @@ import triton import triton.language as tl -__all__ = ["fp4_fake_quant_block"] +__all__ = ["fp4_dequantize", "static_blockwise_fp4_fake_quant"] _TORCH_TO_TL_DTYPE = { @@ -42,172 +42,6 @@ def _torch_dtype_to_tl(dtype: torch.dtype): return _TORCH_TO_TL_DTYPE[dtype] -@triton.jit -def fp4_fake_quant_kernel( - x_ptr, - y_ptr, - M, - N, - global_scale_ptr, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - BLOCK_SIZE: tl.constexpr, - TILE_M: tl.constexpr, - TILE_N: tl.constexpr, - NUM_FP4_BLOCKS: tl.constexpr, - OUT_DTYPE: tl.constexpr, -): - """Applies FP4 fake quantization using block pointers for memory addressing.""" - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - - row_start = pid_m * TILE_M - col_start = pid_n * TILE_N - - x_block_ptr = tl.make_block_ptr( - base=x_ptr, - shape=(M, N), - strides=(stride_xm, stride_xn), - offsets=(row_start, col_start), - block_shape=(TILE_M, TILE_N), - order=(1, 0), - ) - y_block_ptr = tl.make_block_ptr( - base=y_ptr, - shape=(M, N), - strides=(stride_ym, stride_yn), - offsets=(row_start, col_start), - block_shape=(TILE_M, TILE_N), - order=(1, 0), - ) - - global_scale = tl.load(global_scale_ptr).to(tl.float32) - global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12) - - tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) - - tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)) - x_abs = tl.abs(tile_reshaped) - - block_max = tl.max(x_abs, axis=2, keep_dims=True) - - block_max_scaled = block_max / (6.0 * global_scale_safe) - block_max_scaled = tl.minimum(block_max_scaled, 448.0) - block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale - block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0) - - block_max_quant_broadcast = tl.broadcast_to( - block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE) - ) - - abs_scaled = x_abs / block_max_quant_broadcast - - q_val = tl.where( - abs_scaled <= 0.25, - 0.0, - tl.where( - abs_scaled < 0.75, - 0.5, - tl.where( - abs_scaled <= 1.25, - 1.0, - tl.where( - abs_scaled < 1.75, - 1.5, - tl.where( - abs_scaled <= 2.5, - 2.0, - tl.where( - abs_scaled < 3.5, - 3.0, - tl.where(abs_scaled <= 5.0, 4.0, 6.0), - ), - ), - ), - ), - ), - ) - - x_rescaled = q_val * block_max_quant_broadcast - x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled) - - tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N)) - - tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1)) - - -def fp4_fake_quant_block( - x: torch.Tensor, - global_amax: torch.Tensor, - block_size: int = 16, - tile_rows: int = 16, - tile_cols: int = 64, - num_warps: int | None = None, - num_stages: int | None = None, -) -> torch.Tensor: - """FP4 fake quantization implementation using block-pointer tiling. - - Args: - x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher. - global_amax (torch.Tensor): Global maximum value tensor for scaling. - block_size (int): Number of elements per FP4 block. - tile_rows (int, optional): Row tile size. Defaults to 64. - tile_cols (int, optional): Column tile size. Defaults to 128. Rounded up to - the nearest multiple of ``block_size`` internally. - num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``. - num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``. - - Returns: - torch.Tensor: Fake-quantized tensor matching the input shape and dtype. - """ - x_shape = x.shape - x_dtype = x.dtype - x = x.reshape(-1, x_shape[-1]).contiguous() - - M, N = x.shape - y = torch.empty_like(x) - - stride_xm, stride_xn = x.stride() - stride_ym, stride_yn = y.stride() - - tile_cols = max(tile_cols, block_size) - tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size - num_fp4_blocks = tile_cols_aligned // block_size - - global_scale = global_amax.float() / (6.0 * 448.0) - - grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) - - launch_kwargs = { - "BLOCK_SIZE": block_size, - "TILE_M": tile_rows, - "TILE_N": tile_cols_aligned, - "NUM_FP4_BLOCKS": num_fp4_blocks, - "OUT_DTYPE": _torch_dtype_to_tl(x_dtype), - } - if num_warps is not None: - launch_kwargs["num_warps"] = num_warps - if num_stages is not None: - launch_kwargs["num_stages"] = num_stages - fp4_fake_quant_kernel[grid]( - x, - y, - M, - N, - global_scale, - stride_xm, - stride_xn, - stride_ym, - stride_yn, - **launch_kwargs, - ) - - y = y.view(*x_shape) - return y - - @triton.jit def fp4_dequantize_kernel( packed_ptr, @@ -345,3 +179,123 @@ def fp4_dequantize( ) return output + + +@triton.jit +def static_blockwise_fp4_fake_quant_kernel( + x_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] + y_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] + scale_ptr, # [NUM_FP4_BLOCKS] + NUM_FP4_BLOCKS, + BLOCK_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + if pid >= NUM_FP4_BLOCKS: + return + + block_offset = pid * BLOCK_SIZE + idx = block_offset + tl.arange(0, BLOCK_SIZE) + + scale = tl.load(scale_ptr + pid).to(tl.float32) + + x = tl.load(x_ptr + idx).to(tl.float32) + + x_abs = tl.abs(x) + # If scale is 0, inf, or nan, use 1.0 (matching CUDA kernel behavior) + # Note: (x != x) checks if x is NaN per IEEE 754 + scale_safe = tl.where( + (scale == 0) | (scale != scale) | (tl.abs(scale) == float("inf")), # noqa: PLR0124 + 1.0, + scale, + ) + abs_scaled = x_abs / scale_safe + + # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where( + abs_scaled < 3.5, + 3.0, + tl.where(abs_scaled <= 5.0, 4.0, 6.0), + ), + ), + ), + ), + ), + ) + + x_rescaled = q_val * scale_safe + x_quant = tl.where(x >= 0, x_rescaled, -x_rescaled) + + tl.store(y_ptr + idx, x_quant.to(OUT_DTYPE)) + + +def static_blockwise_fp4_fake_quant( + x: torch.Tensor, + amax: torch.Tensor, + global_amax: torch.Tensor | None = None, + quantize_block_scales: bool = True, + out_dtype: torch.dtype | None = None, +): + """Static blockwise FP4 fake quantization using Triton kernel. + + Args: + x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. + amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] per-block amax values. + global_amax: FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax. + quantize_block_scales: If True, quantize block scales to FP8. + out_dtype: Output dtype. Defaults to x.dtype if None. + """ + assert x.ndim == 2 + NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape + + if out_dtype is None: + out_dtype = x.dtype + + amax = amax.float() # Requires to be in float32 + scale = amax / 6.0 # FP4 max representable value is 6.0 + + if quantize_block_scales: + from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl + from modelopt.torch.quantization.utils import reduce_amax + + if global_amax is None: + global_amax = reduce_amax(amax, axis=None, keepdims=False, squeeze_scalar=True) + + global_amax = global_amax.float() + scale_fp8_quant_amax = global_amax / 6.0 + scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax) + + x_flat = x.contiguous().view(-1) + y_flat = torch.empty_like(x_flat, dtype=out_dtype) + scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() + + tl_out_dtype = _torch_dtype_to_tl(out_dtype) + + grid = (NUM_FP4_BLOCKS,) + + with torch.cuda.device(x.device): + static_blockwise_fp4_fake_quant_kernel[grid]( + x_flat, + y_flat, + scale_flat, + NUM_FP4_BLOCKS, + BLOCK_SIZE, + OUT_DTYPE=tl_out_dtype, + ) + + return y_flat.view_as(x) diff --git a/modelopt/torch/quantization/triton/fp4_kernel_hopper.py b/modelopt/torch/quantization/triton/fp4_kernel_hopper.py new file mode 100644 index 0000000000..2ec31863ef --- /dev/null +++ b/modelopt/torch/quantization/triton/fp4_kernel_hopper.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NVFP4 Fake Quantization Triton kernels requiring compute capability >= 8.9 (Hopper+). + +These kernels use tl.float8e4nv which requires native FP8 hardware support. +""" + +import torch +import triton +import triton.language as tl + +from .fp4_kernel import _torch_dtype_to_tl + +__all__ = ["fp4_fake_quant_block"] + + +@triton.jit +def fp4_fake_quant_kernel( + x_ptr, + y_ptr, + M, + N, + global_scale_ptr, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + BLOCK_SIZE: tl.constexpr, + TILE_M: tl.constexpr, + TILE_N: tl.constexpr, + NUM_FP4_BLOCKS: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + """Applies FP4 fake quantization using block pointers for memory addressing.""" + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + row_start = pid_m * TILE_M + col_start = pid_n * TILE_N + + x_block_ptr = tl.make_block_ptr( + base=x_ptr, + shape=(M, N), + strides=(stride_xm, stride_xn), + offsets=(row_start, col_start), + block_shape=(TILE_M, TILE_N), + order=(1, 0), + ) + y_block_ptr = tl.make_block_ptr( + base=y_ptr, + shape=(M, N), + strides=(stride_ym, stride_yn), + offsets=(row_start, col_start), + block_shape=(TILE_M, TILE_N), + order=(1, 0), + ) + + global_scale = tl.load(global_scale_ptr).to(tl.float32) + global_scale_safe = tl.where(global_scale > 0.0, global_scale, 1e-12) + + tile = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + + tile_reshaped = tl.reshape(tile, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE)) + x_abs = tl.abs(tile_reshaped) + + block_max = tl.max(x_abs, axis=2, keep_dims=True) + + block_max_scaled = block_max / (6.0 * global_scale_safe) + block_max_scaled = tl.minimum(block_max_scaled, 448.0) + block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale + block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0) + + block_max_quant_broadcast = tl.broadcast_to( + block_max_quant, (TILE_M, NUM_FP4_BLOCKS, BLOCK_SIZE) + ) + + abs_scaled = x_abs / block_max_quant_broadcast + + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where( + abs_scaled < 3.5, + 3.0, + tl.where(abs_scaled <= 5.0, 4.0, 6.0), + ), + ), + ), + ), + ), + ) + + x_rescaled = q_val * block_max_quant_broadcast + x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled) + + tile_quant = tl.reshape(x_rescaled, (TILE_M, TILE_N)) + + tl.store(y_block_ptr, tile_quant.to(OUT_DTYPE), boundary_check=(0, 1)) + + +def fp4_fake_quant_block( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, + tile_rows: int = 16, + tile_cols: int = 64, + num_warps: int | None = None, + num_stages: int | None = None, +) -> torch.Tensor: + """FP4 fake quantization implementation using block-pointer tiling. + + Args: + x (torch.Tensor): Input tensor of shape ``(M, N)`` or higher. + global_amax (torch.Tensor): Global maximum value tensor for scaling. + block_size (int): Number of elements per FP4 block. + tile_rows (int, optional): Row tile size. Defaults to 16. + tile_cols (int, optional): Column tile size. Defaults to 64. Rounded up to + the nearest multiple of ``block_size`` internally. + num_warps (int | None, optional): Override for Triton warps. Autotuned when ``None``. + num_stages (int | None, optional): Override for pipeline stages. Autotuned when ``None``. + + Returns: + torch.Tensor: Fake-quantized tensor matching the input shape and dtype. + """ + x_shape = x.shape + x_dtype = x.dtype + x = x.reshape(-1, x_shape[-1]).contiguous() + + M, N = x.shape + y = torch.empty_like(x) + + stride_xm, stride_xn = x.stride() + stride_ym, stride_yn = y.stride() + + tile_cols = max(tile_cols, block_size) + tile_cols_aligned = ((tile_cols + block_size - 1) // block_size) * block_size + num_fp4_blocks = tile_cols_aligned // block_size + + global_scale = (global_amax.float() / (6.0 * 448.0)).to(x.device) + + grid = lambda *_: (triton.cdiv(M, tile_rows), triton.cdiv(N, tile_cols_aligned)) + + launch_kwargs = { + "BLOCK_SIZE": block_size, + "TILE_M": tile_rows, + "TILE_N": tile_cols_aligned, + "NUM_FP4_BLOCKS": num_fp4_blocks, + "OUT_DTYPE": _torch_dtype_to_tl(x_dtype), + } + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + with torch.cuda.device(x.device): + fp4_fake_quant_kernel[grid]( + x, + y, + M, + N, + global_scale, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + **launch_kwargs, + ) + + y = y.view(*x_shape) + return y diff --git a/examples/deepseek/ds_kernel.py b/modelopt/torch/quantization/triton/fp8_kernel.py similarity index 75% rename from examples/deepseek/ds_kernel.py rename to modelopt/torch/quantization/triton/fp8_kernel.py index 00586acc2a..0b3c93e3e9 100644 --- a/examples/deepseek/ds_kernel.py +++ b/modelopt/torch/quantization/triton/fp8_kernel.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,32 +35,18 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""FP8 Triton Kernel Implementations.""" import torch import triton import triton.language as tl -"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py""" - @triton.jit def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): - """ - Dequantizes weights using the provided scaling factors and stores the result. + """Dequantizes weights using the provided scaling factors and stores the result. + + Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py Args: x_ptr (tl.pointer): Pointer to the quantized weights. @@ -86,14 +72,21 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offs, y, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: - """ - Dequantizes the given weight tensor using the provided scale tensor. +def weight_dequant( + x: torch.Tensor, + s: torch.Tensor, + block_size: int = 128, + dtype: torch.dtype = torch.get_default_dtype(), +) -> torch.Tensor: + """Dequantizes the given weight tensor using the provided scale tensor. + + Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py Args: x (torch.Tensor): The quantized weight tensor of shape (M, N). s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). block_size (int, optional): The block size to use for dequantization. Defaults to 128. + dtype (torch.dtype, optional): The dtype of the output tensor. Defaults to torch.get_default_dtype(). Returns: torch.Tensor: The dequantized weight tensor of the same shape as `x`. @@ -104,7 +97,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) + y = torch.empty_like(x, dtype=dtype) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index b663ef5f29..6cf6bc90fe 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -229,9 +229,7 @@ def weight_attr_names(module: nn.Module) -> Generator[str, None, None]: # the standard weight and quantizer case weight = getattr(module, "weight", None) weight_quantizer = getattr(module, "weight_quantizer", None) - if isinstance(weight, nn.Parameter) and isinstance( - weight_quantizer, (TensorQuantizer, SequentialQuantizer) - ): + if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)): yield "weight" # other weight and quantizer case diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 0000000000..87088f805b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .ruler_dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py new file mode 100644 index 0000000000..c38a716d79 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration functions for sparse attention.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from transformers import AutoTokenizer + +from modelopt.torch.utils import get_module_device + +from ..config import CalibrationConfig +from ..conversion import print_sparse_attention_summary +from ..utils import get_named_sparse_attention_modules +from .calibrator import DynamicThresholdCalibrator +from .ruler_dataset import RulerDatasetBuilder + + +def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer": + """Load tokenizer and ensure pad_token is set.""" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def _extract_tokenizer_from_model(model: nn.Module) -> str: + """Extract tokenizer name/path from model config. + + Args: + model: Model to extract tokenizer from + + Returns: + Tokenizer name or path + + Raises: + ValueError: If tokenizer path cannot be determined from model + """ + # Extract tokenizer path from model config + tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None) + + if not tokenizer_path: + raise ValueError("Could not load tokenizer from model.") + + return tokenizer_path + + +def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None: + """Extract and validate calibration config from sparse_cfg. + + Args: + config: Sparse attention configuration dict + + Returns: + Validated CalibrationConfig instance, or None if calibration is not configured + + Raises: + ValueError: If calibration config has invalid type or contains invalid values + """ + sparse_cfg = config.get("sparse_cfg", {}) + + # Calibration is optional + if "calibration" not in sparse_cfg: + return None + + calib_dict = sparse_cfg["calibration"] + + # Validate calibration is a dict + if not isinstance(calib_dict, dict): + raise ValueError(f"Calibration config must be a dict, got {type(calib_dict).__name__}. ") + + # Create and validate CalibrationConfig + return CalibrationConfig(**calib_dict) + + +def create_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + batch_size: int = 1, + chunk_size: int = 2048, +) -> Callable: + """Create forward loop for calibration. + + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + batch_size: Batch size (currently unused, always 1) + chunk_size: Chunk size for chunked prefill to avoid OOM. Set to -1 to disable. + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = _load_tokenizer(tokenizer_name_or_path) + + def forward_loop(model: nn.Module) -> None: + device = get_module_device(model) + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + input_ids = inputs["input_ids"].to(device) + seq_len = input_ids.shape[1] + + with torch.no_grad(): + if chunk_size > 0 and seq_len > chunk_size: + # Chunked prefill to avoid OOM with long sequences + past_key_values = None + for start_idx in range(0, seq_len, chunk_size): + end_idx = min(start_idx + chunk_size, seq_len) + chunk_input_ids = input_ids[:, start_idx:end_idx] + + outputs = model( + chunk_input_ids, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # Clean up KV cache + del past_key_values + torch.cuda.empty_cache() + else: + # Full prefill without chunking + model(input_ids, use_cache=False) + + return forward_loop + + +def create_decode_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + num_decode_tokens: int = 10, +) -> Callable: + """Create forward loop for decode phase calibration. + + Uses flash attention for fast prefill, then switches to eager attention + for decode token generation with softmax hook measurement. + + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + num_decode_tokens: Number of decode tokens to generate per sample + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = _load_tokenizer(tokenizer_name_or_path) + + def forward_loop(model: nn.Module) -> None: + device = get_module_device(model) + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + input_ids = inputs["input_ids"].to(device) + + # Save original attention implementation + original_attn_impl = getattr(model.config, "_attn_implementation", "eager") + + with torch.no_grad(): + try: + # Step 1: Fast prefill with flash attention (no measurement) + model.config._attn_implementation = "flash_attention_2" + outputs = model(input_ids, use_cache=True) + past_key_values = outputs.past_key_values + + # Step 2: Switch to eager for decode (enables softmax hook) + model.config._attn_implementation = "eager" + + # Step 3: Manual decode loop for explicit control over token generation + # model.generate() method is not used here because it doesn't allow explicit control over KV cache + # Get the last token's logits and sample next token + next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + + for _ in range(num_decode_tokens): + outputs = model( + next_token, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + next_token = outputs.logits[:, -1:, :].argmax(dim=-1) + finally: + # Restore original attention implementation + model.config._attn_implementation = original_attn_impl + + # Clean up + del past_key_values + torch.cuda.empty_cache() + + return forward_loop + + +def calibrate_sparse_attention( + model: nn.Module, + config: dict[str, Any], + forward_loop: Callable | None = None, +) -> dict[str, Any]: + """Calibrate sparse attention parameters for optimal sparsity. + + Supports both prefill and decode phase calibration with per-phase target sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration dict + forward_loop: Callable that forwards calibration data through model. + If None, auto-generates RULER dataset. Only used for prefill. + + Returns: + Dictionary with calibration results for each phase + """ + # Extract and validate calibration config + calib_config = _extract_calibration_config(config) + + # Skip calibration if not configured + if calib_config is None: + return {} + + # Get per-phase targets + target_dict = calib_config.target_sparse_ratio + calibrate_prefill = target_dict.get("prefill", 0.0) > 0.0 + calibrate_decode = target_dict.get("decode", 0.0) > 0.0 + + # Skip if both phases are disabled + if not calibrate_prefill and not calibrate_decode: + print("Both prefill and decode target sparsity are 0.0, skipping calibration") + return {} + + # Get sparse attention modules + sparse_modules = get_named_sparse_attention_modules(model) + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Extract tokenizer and build calibration data if needed + tokenizer = _extract_tokenizer_from_model(model) + calibration_data = None + + if calibrate_prefill or calibrate_decode: + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.5), + cache_dir=calib_config.cache_dir, + data_dir=calib_config.data_dir, + ) + calibration_data = builder.build_calibration_dataset() + + # Initialize results + calibration_results: dict[str, Any] = {} + + # Run prefill calibration if enabled + if calibrate_prefill: + print("\n" + "=" * 60) + print("PREFILL PHASE CALIBRATION") + print("=" * 60) + + if calibration_data is None: + raise RuntimeError("calibration_data must be built before prefill") + prefill_forward_loop = forward_loop or create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) + + prefill_calibrator = DynamicThresholdCalibrator( + threshold_trials=calib_config.threshold_trials, + ) + prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill") + + if "a" in prefill_result and "b" in prefill_result: + calibration_results["prefill"] = prefill_result + else: + warnings.warn("Prefill calibration did not produce valid results") + + # Run decode calibration if enabled + if calibrate_decode: + print("\n" + "=" * 60) + print("DECODE PHASE CALIBRATION") + print("=" * 60) + + if calibration_data is None: + raise RuntimeError("calibration_data must be built before decode") + decode_forward_loop = create_decode_calibration_forward_loop( + calibration_data, tokenizer, num_decode_tokens=calib_config.num_decode_tokens + ) + + decode_calibrator = DynamicThresholdCalibrator( + threshold_trials=calib_config.threshold_trials, + ) + decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode") + + if "a" in decode_result and "b" in decode_result: + calibration_results["decode"] = decode_result + else: + warnings.warn("Decode calibration did not produce valid results") + + # Check if any calibration succeeded + if not calibration_results: + warnings.warn("No calibration produced valid results") + return {} + + # Extract a and b for each phase + calibration_params: dict[str, dict[str, float]] = {} + for phase in ["prefill", "decode"]: + if phase in calibration_results: + result = calibration_results[phase] + calibration_params[phase] = { + "a": result["a"], + "b": result["b"], + } + + # Apply calibration params to all modules + print("\n" + "=" * 60) + print("APPLYING CALIBRATION RESULTS") + print("=" * 60) + print(f"Applying calibration to {len(sparse_modules)} modules:") + for phase, params in calibration_params.items(): + result = calibration_results[phase] + print(f" {phase}:") + print(f" Model: scale_factor = {params['a']:.6f} * exp({params['b']:.4f} * sparsity)") + print(f" R-squared: {result['r_squared']:.6f}") + + for module_name, module in sparse_modules: + module._sparse_method_instance.calibration_params = calibration_params + module._sparse_method_instance.target_sparse_ratio = target_dict + + # Print final summary + print("\nCalibration complete!") + print( + f"Target sparsity: prefill={target_dict.get('prefill', 0):.0%}, " + f"decode={target_dict.get('decode', 0):.0%}" + ) + print("\nTo change target sparsity at inference time, update:") + print(" module._sparse_method_instance.target_sparse_ratio = {'prefill': X, 'decode': Y}") + print_sparse_attention_summary(model) + + return { + "calibration_params": calibration_params, + "target_sparse_ratio": target_dict, + "calibration_results": calibration_results, + } diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py new file mode 100644 index 0000000000..df2c05e206 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +import warnings +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from scipy.optimize import curve_fit +from tqdm import tqdm + +from ..stats_manager import SparseAttentionStatsManager +from ..utils import get_sparse_attention_modules + + +class DynamicThresholdCalibrator: + """Dynamic threshold calibrator using Exponential model. + + Calibration Algorithm: + 1. For each threshold λ_j in threshold_trials: + - Run ALL samples through forward_loop + - For each sample i with length L_i, collect sparsity S_ij + - Compute scale_factor_ij = λ_j × L_i + + 2. Fit Exponential model to ALL individual (sf_ij, S_ij) pairs: + scale_factor = a * exp(b * sparsity) + + 3. Return fitted a and b parameters + + At inference time (user specifies target_sparsity S*): + scale_factor = a * exp(b * S*) + threshold = scale_factor / seqlen + + Key insight: Using all individual data points (N_thresholds × N_samples) + instead of per-threshold averages provides more accurate fitting without + additional calibration time cost. + """ + + def __init__( + self, + threshold_trials: list[float] | None = None, + ): + """Initialize dynamic threshold calibrator. + + Args: + threshold_trials: List of thresholds to try during calibration. + Should span a range that achieves sparsities from ~10% to ~95%. + """ + # Default threshold trials if not provided + self.threshold_trials = threshold_trials or [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 2e-2, + 5e-2, + 1e-1, + 2e-1, + 3e-1, + 5e-1, + 7e-1, + 8e-1, + 9e-1, + 9.5e-1, + 9.9e-1, + ] + + def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict[str, Any]: + """Calibrate a and b parameters for Exponential model. + + Algorithm: + 1. For each threshold λ_j in threshold_trials: + - Run ALL samples, collect sparsities S_ij for each sample i + - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length) + + 2. Fit Exponential model to ALL (sf_ij, S_ij) pairs: + scale_factor = a * exp(b * sparsity) + + 3. Return fitted a and b parameters + + At inference time (user specifies target_sparsity S*): + scale_factor = a * exp(b * S*) + threshold = scale_factor / seqlen + + Args: + model: The model with sparse attention modules + forward_loop: Callable that takes model and forwards calibration data + phase: Phase to calibrate ('prefill' or 'decode') + + Returns: + Dict with calibration results including a, b, r_squared, and num_data_points + """ + # Extract attention modules + attention_modules = get_sparse_attention_modules(model) + + if not attention_modules: + raise ValueError("No sparse attention modules found for calibration") + + print(f"Starting Exponential model calibration ({phase} phase)") + print(f"Threshold trials: {len(self.threshold_trials)}") + + # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples + print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds...") + + # Collect ALL individual data points (not averaged) + all_data_points = [] # List of {"threshold", "length", "scale_factor", "sparsity"} + + for threshold in tqdm(self.threshold_trials, desc=f"Testing thresholds ({phase})"): + self._set_threshold(attention_modules, threshold) + self._enable_calibration_mode(attention_modules) + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) + self._disable_calibration_mode(attention_modules) + + if not per_sample_stats: + continue + + # Collect individual (scale_factor, sparsity) pairs for each sample + for sample_stat in per_sample_stats: + length = sample_stat["sample_length"] + sparsity = sample_stat["sparsity"] + scale_factor = threshold * length + + all_data_points.append( + { + "threshold": threshold, + "length": length, + "scale_factor": scale_factor, + "sparsity": sparsity, + } + ) + + if len(all_data_points) < 10: + warnings.warn( + f"Not enough data points for {phase} calibration. " + f"Got {len(all_data_points)}, need at least 10." + ) + return {} + + print(f"Collected {len(all_data_points)} individual (scale_factor, sparsity) pairs") + + # Stage 2: Fit Exponential model: scale_factor = a * exp(b * sparsity) + print("\nStage 2: Fitting Exponential model to all data points...") + + # Extract data for fitting + scale_factors = np.array([pt["scale_factor"] for pt in all_data_points]) + sparsities = np.array([pt["sparsity"] for pt in all_data_points]) + + # Filter out extreme sparsities (must be in (10%, 90%)) + # Extreme values are unreliable for fitting + valid_mask = (sparsities >= 0.10) & (sparsities <= 0.90) + scale_factors = scale_factors[valid_mask] + sparsities = sparsities[valid_mask] + + if len(scale_factors) < 3: + warnings.warn( + f"Not enough valid data points after filtering. Got {len(scale_factors)}." + ) + return {} + + # Define Exponential model: sf = a * exp(b * S) + def exponential(sparsity, a, b): + return a * np.exp(b * sparsity) + + # Fit the model + try: + popt, pcov = curve_fit( + exponential, + sparsities, + scale_factors, + p0=[1.0, 5.0], # Initial guess + bounds=([0.0, 0.0], [np.inf, 20.0]), # Bounds for a and b + maxfev=10000, + ) + a, b = popt + except Exception as e: + warnings.warn(f"Curve fitting failed: {e}") + return {} + + # Calculate R-squared and RMSE + pred_scale_factors = exponential(sparsities, a, b) + ss_res = np.sum((scale_factors - pred_scale_factors) ** 2) + ss_tot = np.sum((scale_factors - np.mean(scale_factors)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + rmse = np.sqrt(np.mean((scale_factors - pred_scale_factors) ** 2)) + + print(f"\n{phase.capitalize()} Calibration Results (Exponential Model):") + print(" Model: scale_factor = a * exp(b * sparsity)") + print(f" Fitted a: {a:.6f}") + print(f" Fitted b: {b:.4f}") + print(f" R-squared: {r_squared:.6f}") + print(f" RMSE: {rmse:.2f}") + print(f" Data points used: {int(np.sum(valid_mask))} / {len(all_data_points)}") + + # Show scale_factor for various target sparsities + print("\nScale factors for different target sparsities:") + print(f" {'Target':<10} {'Scale Factor':<15}") + print(f" {'-' * 10} {'-' * 15}") + for target in [0.5, 0.7, 0.8, 0.9, 0.95]: + sf = a * np.exp(b * target) + print(f" {target:<10.0%} {sf:<15.2f}") + + # Print calibration data summary by threshold + print("\nCalibration data summary (per threshold):") + print(f" {'Threshold':<12} {'Avg SF':<12} {'Avg Sparsity':<12} {'Samples':<8}") + print(f" {'-' * 12} {'-' * 12} {'-' * 12} {'-' * 8}") + + # Group by threshold for summary + by_threshold = defaultdict(list) + for point in all_data_points: + by_threshold[point["threshold"]].append(point) + + for threshold in sorted(by_threshold.keys()): + points = by_threshold[threshold] + avg_sf = np.mean([p["scale_factor"] for p in points]) + avg_s = np.mean([p["sparsity"] for p in points]) + print(f" {threshold:<12.4f} {avg_sf:<12.2f} {avg_s:<12.2%} {len(points):<8}") + + return { + "phase": phase, + "a": float(a), + "b": float(b), + "r_squared": float(r_squared), + "rmse": float(rmse), + "num_data_points": int(np.sum(valid_mask)), + "total_samples": len(all_data_points), + "calibration_type": "exponential", + } + + def _enable_calibration_mode(self, modules: list[nn.Module]): + """Enable calibration mode on sparse attention modules.""" + for idx, module in enumerate(modules): + # Create stats manager if needed + if not module._stats_manager: + module._stats_manager = SparseAttentionStatsManager( + module_name=f"sparse_attn_{idx}", enabled=True + ) + else: + # Re-enable if disabled + module._stats_manager.enabled = True + + # Enable calibration mode with fresh stats + module._stats_manager.set_calibration_mode(enabled=True, reset_history=True) + module._sparse_method_instance.set_calibration_mode(True) + + def _disable_calibration_mode(self, modules: list[nn.Module]): + """Disable calibration mode (but keep stats enabled if collect_stats=True).""" + for module in modules: + if module._stats_manager: + module._stats_manager.set_calibration_mode(enabled=False) + + module._sparse_method_instance.set_calibration_mode(False) + + def _extract_calibration_stats( + self, modules: list[nn.Module], phase: str | None = None + ) -> list[dict]: + """Extract per-sample calibration statistics from modules. + + Args: + modules: List of attention modules + phase: Optional phase to filter by ('prefill' or 'decode'). + If None, returns all stats. + + Returns: + List of per-sample statistics across all modules + """ + # Collect from all stats managers + all_per_sample_stats = [] + + for module in modules: + # Skip modules without stats manager + if not hasattr(module, "_stats_manager") or module._stats_manager is None: + continue + + manager_stats = module._stats_manager.get_calibration_stats(phase) + if manager_stats: + all_per_sample_stats.append(manager_stats) + + if not all_per_sample_stats: + return [] + + # Aggregate across modules by sample index + num_samples = len(all_per_sample_stats[0]) + aggregated_stats = [] + + for sample_idx in range(num_samples): + sparsities = [] + sample_length = 0 + + for module_stats in all_per_sample_stats: + if sample_idx < len(module_stats): + sample_stat = module_stats[sample_idx] + sparsities.append(sample_stat.get("sparsity", 0.0)) + if not sample_length and "sample_length" in sample_stat: + sample_length = sample_stat["sample_length"] + + avg_sparsity = float(np.mean(sparsities)) if sparsities else 0.0 + + aggregated_stats.append( + { + "sparsity": avg_sparsity, + "sample_length": sample_length, + } + ) + + return aggregated_stats + + def _set_threshold(self, modules: list[nn.Module], threshold: float): + """Set threshold on sparse attention modules.""" + for module in modules: + module._sparse_method_instance.threshold = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_dataset.py new file mode 100644 index 0000000000..abbbc399d6 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_dataset.py @@ -0,0 +1,1083 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and Adapted from https://github.com/NVIDIA/RULER +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""RULER dataset for sparse attention calibration. + +This module contains the RULER dataset builder and core generation logic adapted +from the RULER benchmark (https://github.com/NVIDIA/RULER) for calibration purposes. +The generation logic closely follows the official RULER implementation to ensure +dataset consistency. + +Key adaptations from official RULER: +- Converted from CLI scripts to library functions +- Works with HuggingFace tokenizers directly +- Removed file I/O, returns data structures +- Simplified for calibration use case (primarily NIAH tasks) +""" + +import hashlib +import json +import logging +import random +import re +import string +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +logger = logging.getLogger(__name__) + + +# Needle/Haystack template from official RULER +NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." + +# Depth positions for needle insertion (from official RULER) +DEPTHS = [ + 0, + 2, + 5, + 7, + 10, + 12, + 15, + 18, + 20, + 23, + 25, + 28, + 30, + 33, + 35, + 38, + 40, + 43, + 45, + 48, + 50, + 53, + 55, + 58, + 60, + 62, + 65, + 67, + 70, + 72, + 75, + 77, + 80, + 82, + 85, + 87, + 90, + 92, + 95, + 97, + 100, +] + + +def _load_paul_graham_essays_from_files(data_dir: Path) -> str: + """Load Paul Graham essays from local files. + + Reads essay .txt files from data_dir/essays. + Files must be downloaded first using download_ruler_data.sh. + + Args: + data_dir: Base directory for RULER data (contains an 'essays' subdir with .txt files). + + Returns: + Combined essay text + + Raises: + RuntimeError: If essays directory doesn't exist or is empty + """ + essays_dir = data_dir / "essays" + if not essays_dir.exists(): + raise RuntimeError( + f"Essays directory not found at {essays_dir}.\n" + "Please run the download script first:\n" + " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" + ) + + essay_files = list(essays_dir.glob("*.txt")) + if not essay_files: + raise RuntimeError( + f"No essay files found in {essays_dir}.\n" + "Please run the download script first:\n" + " bash examples/llm_sparsity/attention_sparsity/download_ruler_data.sh" + ) + + logger.info(f"Loading {len(essay_files)} Paul Graham essays from {essays_dir}...") + + all_essays = [] + for filepath in essay_files: + text = filepath.read_text() + all_essays.append(text) + + combined_text = " ".join(all_essays) + logger.info(f"Loaded {len(all_essays)} essays successfully") + + return combined_text + + +def _load_paul_graham_essays(data_dir: Path) -> str: + """Load Paul Graham essays from local files. + + Essay files must be downloaded first using download_ruler_data.sh. + + Args: + data_dir: Base directory for RULER data (contains an 'essays' subdir). + + Returns: + Essay text as string + """ + essay_text = _load_paul_graham_essays_from_files(data_dir) + return re.sub(r"\s+", " ", essay_text) + + +def _load_word_lists(): + """Load word lists for random word generation. + + Returns: + List of words (adj-noun combinations) + """ + import wonderwords + + # Load wonderwords lists (same as official RULER) + nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") + adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") + words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] + words = sorted(set(words)) + return words + + +# Global word list (loaded once) +_WORD_LIST = None + + +def generate_random_number(num_digits=7) -> str: + """Generate random number (from official RULER).""" + lower_bound = 10 ** (num_digits - 1) + upper_bound = 10**num_digits - 1 + return str(random.randint(lower_bound, upper_bound)) + + +def generate_random_word() -> str: + """Generate random word (from official RULER).""" + global _WORD_LIST + if _WORD_LIST is None: + _WORD_LIST = _load_word_lists() + return random.choice(_WORD_LIST) + + +def generate_random_uuid() -> str: + """Generate random UUID (from official RULER).""" + return str(uuid.UUID(int=random.getrandbits(128), version=4)) + + +def generate_random(type_needle: str) -> str: + """Generate random needle value based on type (from official RULER). + + Args: + type_needle: Type of needle ('numbers', 'words', 'uuids') + + Returns: + Random value as string + """ + if type_needle == "numbers": + return generate_random_number() + elif type_needle == "words": + return generate_random_word() + elif type_needle == "uuids": + return generate_random_uuid() + else: + raise ValueError(f"Unknown needle type: {type_needle}") + + +def generate_niah_sample( + num_haystack: int, + tokenizer, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + type_needle_k: str = "words", + type_needle_v: str = "numbers", + num_needle_k: int = 1, + num_needle_v: int = 1, + num_needle_q: int = 1, + random_seed: int = 42, + data_dir: Path | None = None, +) -> dict[str, Any]: + """Generate a single NIAH (Needle in a Haystack) sample. + + This function implements the core generation logic from official RULER's niah.py, + adapted to work as a library function. + + Args: + num_haystack: Number of haystack items/words + tokenizer: HuggingFace tokenizer (AutoTokenizer instance) + template: NIAH question template + answer_prefix: Answer prefix template + tokens_to_generate: Expected number of generation tokens + type_haystack: Type of haystack ('essay', 'noise', 'needle') + type_needle_k: Type of needle keys ('numbers', 'words', 'uuids') + type_needle_v: Type of needle values ('numbers', 'words', 'uuids') + num_needle_k: Number of needle keys + num_needle_v: Number of needle values per key + num_needle_q: Number of needles to query + random_seed: Random seed for this sample + data_dir: Base directory for RULER data (required when type_haystack='essay'). + Must contain an 'essays' subdir with Paul Graham .txt files. + + Returns: + Dictionary with 'input', 'outputs', 'length' keys + """ + import nltk + from nltk.tokenize import sent_tokenize + + try: + nltk.data.find("tokenizers/punkt") + except LookupError: + nltk.download("punkt", quiet=True) + nltk.download("punkt_tab", quiet=True) + + if random_seed is not None: + random.seed(random_seed) + + # Ensure num_needle_k >= num_needle_q + num_needle_k = max(num_needle_k, num_needle_q) + + # Generate needles (keys and values) + keys, values, needles = [], [], [] + for _ in range(num_needle_k): + keys.append(generate_random(type_needle_k)) + value = [] + for _ in range(num_needle_v): + value.append(generate_random(type_needle_v)) + needles.append( + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=keys[-1], + value=value[-1], + ) + ) + values.append(value) + + random.shuffle(needles) + + # Generate context based on haystack type + if type_haystack == "essay": + if data_dir is None: + raise ValueError( + "data_dir is required when type_haystack='essay'. " + "Pass the path to the RULER data directory (containing an 'essays' subdir)." + ) + # Load essay corpus + essay_text = _load_paul_graham_essays(Path(data_dir)) + haystack = essay_text.split(" ") + + # Create text from haystack + if num_haystack <= len(haystack): + text = " ".join(haystack[:num_haystack]) + else: + # Repeat haystack as needed + repeats = (num_haystack + len(haystack) - 1) // len(haystack) + text = " ".join((haystack * repeats)[:num_haystack]) + + # Insert needles at various depths + document_sents = sent_tokenize(text.strip()) + insertion_positions = [ + 0, + *sorted( + int(len(document_sents) * (depth / 100)) + for depth in random.sample(DEPTHS, len(needles)) + ), + len(document_sents), + ] + + document_sents_list = [] + for i in range(1, len(insertion_positions)): + last_pos = insertion_positions[i - 1] + next_pos = insertion_positions[i] + document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) + if i - 1 < len(needles): + document_sents_list.append(needles[i - 1]) + + context = " ".join(document_sents_list) + + elif type_haystack == "noise": + haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." + sentences = [haystack_sent] * num_haystack + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + elif type_haystack == "needle": + sentences = [ + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=generate_random(type_needle_k), + value=generate_random(type_needle_v), + ) + for _ in range(num_haystack) + ] + + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + else: + raise ValueError(f"Unknown haystack type: {type_haystack}") + + # Generate query and answer + indices = random.sample(range(num_needle_k), num_needle_q) + queries = [keys[i] for i in indices] + answers = [a for i in indices for a in values[i]] + query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0] + + # Format template (adjust for singular vs plural) + type_needle_v_display = type_needle_v + formatted_template = template + if num_needle_q * num_needle_v == 1: + formatted_template = formatted_template.replace("Some", "A") + formatted_template = formatted_template.replace("are all", "is") + formatted_template = formatted_template.replace("are", "is") + formatted_template = formatted_template.replace("answers", "answer") + type_needle_v_display = type_needle_v[:-1] # remove "s" + + input_text = formatted_template.format( + type_needle_v=type_needle_v_display, + context=context, + query=query, + ) + + # Add answer prefix + formatted_answer_prefix = answer_prefix.format( + type_needle_v=type_needle_v_display, + query=query, + ) + input_text = input_text + formatted_answer_prefix + + # Calculate actual length + if hasattr(tokenizer, "encode"): + # HuggingFace tokenizer + tokens = tokenizer.encode(input_text, add_special_tokens=False) + length = len(tokens) + tokens_to_generate + else: + # Fallback + length = len(input_text.split()) + tokens_to_generate + + return { + "input": input_text, + "outputs": answers, + "length": length, + } + + +def find_optimal_haystack_size( + tokenizer, + max_seq_length: int, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + data_dir: Path | None = None, + **kwargs, +) -> int: + """Find optimal haystack size using binary search (from official RULER). + + Args: + tokenizer: HuggingFace tokenizer + max_seq_length: Maximum sequence length + tokens_to_generate: Expected generation tokens + type_haystack: Type of haystack + template: NIAH question template + answer_prefix: Answer prefix template + data_dir: Base directory for RULER data (required when type_haystack='essay'). + **kwargs: Additional arguments for generate_niah_sample + + Returns: + Optimal number of haystack items + """ + # Determine incremental step based on haystack type + if type_haystack == "essay": + incremental = 500 + elif type_haystack in ["noise", "needle"]: + incremental = 25 + else: + incremental = 100 + + if max_seq_length < 4096 and type_haystack != "essay": + incremental = 5 + + # Estimate tokens per haystack item + sample = generate_niah_sample( + incremental, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + data_dir=data_dir, + **kwargs, + ) + + if hasattr(tokenizer, "encode"): + sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) + else: + sample_tokens = len(sample["input"].split()) + + tokens_per_haystack = sample_tokens / incremental + estimated_max = int((max_seq_length / tokens_per_haystack) * 3) + + # Binary search for optimal size + lower_bound = incremental + upper_bound = max(estimated_max, incremental * 2) + optimal_num_haystack = None + + logger.debug(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.debug(f"Binary search bounds: {lower_bound} to {upper_bound}") + + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + sample = generate_niah_sample( + mid, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + data_dir=data_dir, + **kwargs, + ) + total_tokens = sample["length"] + + logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") + + if total_tokens <= max_seq_length: + optimal_num_haystack = mid + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + + final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental + logger.debug(f"Optimal haystack size: {final_size}") + + return final_size + + +def _generate_target_lengths( + max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 +) -> list[int]: + """Generate target lengths as descending powers of 2. + + Args: + max_seqlen: Maximum sequence length + num_length_bins: Maximum number of length bins to generate + min_seqlen: Minimum sequence length threshold + + Returns: + List of target lengths in descending order + + Examples: + >>> _generate_target_lengths(32768, 4) + [32768, 16384, 8192, 4096] + >>> _generate_target_lengths(2048, 4) + [2048, 1024] + """ + target_lengths = [] + current = max_seqlen + + for _ in range(num_length_bins): + if current < min_seqlen: + break + target_lengths.append(current) + current = current // 2 + + return target_lengths + + +@dataclass +class RulerTask: + """Configuration for a RULER task.""" + + name: str + task_type: str # niah, variable_tracking, freq_words_extraction, qa + tokens_to_generate: int + template: str + answer_prefix: str + args: dict[str, Any] + + +# Task configurations based on RULER benchmark +RULER_TASKS = { + "niah_multikey_2": RulerTask( + name="niah_multikey_2", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." + ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assigned the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={"alpha": 2.0}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str | object, + num_length_bins: int = 4, + max_length_filter: int = 65536, + seed: int = 42, + cache_dir: str | None = None, + data_dir: str | Path | None = None, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/data/ + data_dir: Optional path to RULER data directory (contains 'essays' subdir). + Required for NIAH tasks with essay haystack when not using pip default layout. + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] + Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. + """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + self.cache_dir = cache_dir + self.data_dir = Path(data_dir) if data_dir is not None else None + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer + if isinstance(tokenizer_name_or_path, str): + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + else: + self.tokenizer = tokenizer_name_or_path + random.seed(seed) + + def _get_cache_path(self) -> Path: + """Generate cache file path based on calibration parameters.""" + tokenizer_path = ( + self.tokenizer_name_or_path + if isinstance(self.tokenizer_name_or_path, str) + else str(self.tokenizer_name_or_path) + ) + key = f"{tokenizer_path}_{self.total_samples}_{self.max_seqlen}" + hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12] + filename = f"ruler_cache_{self.total_samples}s_{self.max_seqlen}l_{hash_str}.json" + if self.cache_dir: + base_dir = Path(self.cache_dir) + else: + base_dir = Path.home() / ".cache" / "modelopt" / "data" + return base_dir / filename + + def _load_cached_data(self, cache_path: Path) -> list[dict[str, Any]] | None: + """Load calibration data from cache if it exists.""" + if cache_path.exists(): + try: + with open(cache_path) as f: + data = json.load(f) + print(f"Loaded {len(data)} cached calibration samples from {cache_path}") + return data + except Exception as e: + print(f"Warning: Failed to load cache: {e}") + return None + + def _save_cached_data(self, cache_path: Path, data: list[dict[str, Any]]) -> None: + """Save calibration data to cache.""" + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(data, f) + print(f"Saved calibration samples to cache: {cache_path}") + except Exception as e: + print(f"Warning: Failed to save cache: {e}") + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. + + If cache_dir was set, checks cache first and returns cached data if present. + Otherwise generates the dataset, saves to cache (if cache_dir set), and returns. + + Returns: + List of calibration samples with 'input' and 'length' fields + """ + cache_path = self._get_cache_path() + cached = self._load_cached_data(cache_path) + if cached is not None: + return cached + + all_samples = [] + + print( + f"Generating {self.total_samples} calibration samples " + f"across {len(self.target_lengths)} length bins: {self.target_lengths}" + ) + + # Generate calibration samples with sample-level progress + with tqdm(total=self.total_samples, desc="Generating RULER samples") as pbar: + for num_samples, target_length in zip(self.samples_per_length, self.target_lengths): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + pbar.update(1) + + random.shuffle(all_samples) + print(f"Generated {len(all_samples)} valid samples") + + self._save_cached_data(cache_path, all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. + + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Find optimal haystack size for target length + optimal_haystack = find_optimal_haystack_size( + tokenizer=self.tokenizer, + max_seq_length=target_length, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + data_dir=self.data_dir, + ) + + # Generate sample using official RULER implementation + sample = generate_niah_sample( + num_haystack=optimal_haystack, + tokenizer=self.tokenizer, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + random_seed=self.seed + sample_idx, + data_dir=self.data_dir, + ) + + # Add task metadata + sample["task"] = task.name + sample["target_length"] = target_length + sample["sample_idx"] = sample_idx + + return sample + + def _generate_vt_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a variable tracking sample.""" + args = task.args + num_chains = args["num_chains"] + num_hops = args["num_hops"] + + # Generate variable chains + variables = [] + chains = [] + for _ in range(num_chains): + chain = [self._generate_random_variable() for _ in range(num_hops + 1)] + variables.extend(chain) + chains.append(chain) + + # Generate assignments + assignments = [ + f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) + ] + + # Create context with padding + context = self._pad_context_with_text( + "\n".join(assignments), target_length, "variable tracking context" + ) + + # Select a query value + query_value = random.choice([chain[-1] for chain in chains]) + + # Format template + template = task.template.format(context=context, query=query_value) + + # Count variables with the query value + num_v = sum(1 for chain in chains if chain[-1] == query_value) + + # Add answer prefix + full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_fwe_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a frequency word extraction sample.""" + # Generate coded words with frequencies + num_unique_words = 50 + coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] + + # Assign frequencies (make top 3 clearly more frequent) + frequencies = {} + for i, word in enumerate(coded_words): + if i < 3: + frequencies[word] = random.randint(20, 30) # High frequency + else: + frequencies[word] = random.randint(1, 10) # Low frequency + + # Generate the coded text + word_list = [] + for word, freq in frequencies.items(): + word_list.extend([word] * freq) + random.shuffle(word_list) + + # Add dots for separation + coded_text = " .... ".join(word_list) + + # Pad to target length + context = self._pad_context_with_text(coded_text, target_length, "coded text padding") + + # Format template + template = task.template.format(context=context) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_qa_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a QA sample.""" + # Generate synthetic documents + num_docs = 5 + documents = [] + + # Create a simple QA pair + answer = self._generate_random_phrase() + answer_doc_idx = random.randint(0, num_docs - 1) + question = f"What is the special code mentioned in document {answer_doc_idx + 1}?" + + for i in range(num_docs): + doc_text = self._generate_document_text(200) # Base document + if i == answer_doc_idx: # Insert answer in the correct document + doc_text += f" The special code is {answer}. " + documents.append(f"Document {i + 1}:\n{doc_text}\n") + + # Combine documents + context_base = "\n".join(documents) + + # Pad to target length + context = self._pad_context_with_text( + context_base, target_length, "additional document text" + ) + + # Format template + template = task.template.format(context=context, query=question) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _pad_context_with_text( + self, base_context: str, target_length: int, padding_type: str + ) -> str: + """Pad context to approach target length.""" + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + while len(tokens) < target_length * 0.7: # Leave room for template + if padding_type == "variable tracking context": + padding = ( + f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." + ) + elif padding_type == "coded text padding": + padding = f" .... {self._generate_coded_word()} .... " + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. " + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index e72dacc943..d2d3b1078b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -46,12 +46,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="If True, enables sparse attention. If False, bypasses sparsity.", ) - threshold: float | dict[str, float] = ModeloptField( - default=1e-3, + threshold: dict[str, float] = ModeloptField( + default={"prefill": 1e-3, "decode": 1e-4}, title="Sparsity threshold.", description=( "Threshold for determining which attention values to skip. " - "Can be a float or dict with phase-specific values." + "Must be a dict with 'prefill' and 'decode' keys." ), ) @@ -77,6 +77,12 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass for monitoring.", + ) + is_causal: bool = ModeloptField( default=True, title="Causal attention flag.", @@ -87,16 +93,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) - calibration: dict | None = ModeloptField( - default=None, - title="Calibration configuration", - description=( - "Calibration settings for this pattern. " - "If provided, enables automatic threshold calibration. " - "Only one pattern should have calibration enabled." - ), - ) - @field_validator("method") @classmethod def validate_method(cls, v): @@ -127,47 +123,203 @@ def validate_block_size(cls, v): @field_validator("threshold") @classmethod def validate_threshold(cls, v): - """Validate threshold is in valid range (0, 1) or dict with valid phases.""" - if isinstance(v, dict): - # Validate phase keys - valid_phases = {"prefill", "decode", "default"} - invalid_keys = set(v.keys()) - valid_phases - if invalid_keys: + """Validate threshold is a dict with valid phases and values in range (0, 1).""" + if not isinstance(v, dict): + raise ValueError( + f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}" + ) + # Validate phase keys + valid_phases = {"prefill", "decode"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range (0, 1) + for phase, threshold in v.items(): + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: raise ValueError( - f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" ) - # Validate all values are in range (0, 1) - for phase, threshold in v.items(): - if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + return v + + +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration fits an Exponential model to determine dynamic thresholds that + achieve target sparsity. The model learns parameters a and b per phase: + + scale_factor = a * exp(b * target_sparsity) + + At inference time, the threshold is computed as: + + threshold = scale_factor / sequence_length + + Key benefits: + - Target sparsity can be changed at runtime without recalibration + - Threshold automatically adapts to sequence length + - Supports independent prefill and decode phase calibration + - Exponential model provides better fit (lower RMSE) + """ + + target_sparse_ratio: dict[str, float] = ModeloptField( + default={"prefill": 0.5, "decode": 0.5}, + title="Target sparsity ratio", + description=( + "Target ratio of sparse attention blocks (0.0 to 1.0). " + "Dict with 'prefill' and 'decode' keys for per-phase targets. " + "Set a phase value to 0.0 to skip calibration for that phase." + ), + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description=( + "Total number of RULER samples for calibration (distributed across length bins). " + "Default (24) provides 1 sample per task per length bin (4 bins * 6 RULER tasks). " + "Increase for more robust calibration." + ), + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + chunk_size: int = ModeloptField( + default=2048, + title="Chunk size for prefill", + description=( + "Chunk size for chunked prefill to avoid OOM with long sequences. " + "When sequence length exceeds chunk_size, prefill is done in chunks using KV cache. " + "Set to -1 to disable chunking (full prefill)." + ), + ) + + num_decode_tokens: int = ModeloptField( + default=10, + title="Number of decode tokens", + description="Number of decode tokens to generate for decode phase calibration.", + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. " + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, " + "1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1]. " + "Increasing the number of trials improves calibration accuracy but slows down calibration." + ), + ) + + cache_dir: str | None = ModeloptField( + default=None, + title="Cache directory", + description=( + "Directory to cache generated calibration samples. " + "Caching avoids regenerating samples on repeated calibration runs." + ), + ) + + data_dir: str | None = ModeloptField( + default=None, + title="RULER data directory", + description=( + "Path to RULER data directory (contains 'essays' subdir with Paul Graham .txt files). " + "Required for NIAH essay tasks when not using repo layout. Set from example script or CLI." + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: raise ValueError( - f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + f"All threshold_trials must be in range (0, 1), got {threshold}" ) - elif isinstance(v, (int, float)): - if v <= 0 or v >= 1: - raise ValueError(f"Threshold must be in range (0, 1), got {v}") - else: - raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}") return v + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio dict.""" + if not isinstance(v, dict): + raise ValueError( + f"target_sparse_ratio must be a dict with 'prefill' and 'decode' keys, got {type(v)}" + ) + # Validate phase keys + valid_phases = {"prefill", "decode"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid target_sparse_ratio phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range [0, 1] + for phase, ratio in v.items(): + if not isinstance(ratio, (int, float)) or not 0.0 <= ratio <= 1.0: + raise ValueError( + f"target_sparse_ratio for phase '{phase}' must be between 0.0 and 1.0, got {ratio}" + ) + return v -# Pre-defined Sparse Attention Configuration -# Default configuration with block-wise sparsity optimized for Flash Attention -SKIP_SOFTMAX_DEFAULT = { - "sparse_cfg": { - "*attn*": { - "method": "flash_skip_softmax", - "threshold": { - "prefill": 1e-3, # More aggressive during prefill - "decode": 1e-4, # Conservative during decode - }, - "br": 128, # Flash Attention block rows - "bc": 128, # Flash Attention block columns - "backend": "pytorch", # Only pytorch backend supported - "enable": True, - }, - "default": {"enable": False}, - }, -} + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + @field_validator("chunk_size") + @classmethod + def validate_chunk_size(cls, v): + """Validate chunk_size is positive or -1 (disabled).""" + if v != -1 and v <= 0: + raise ValueError(f"chunk_size must be positive or -1 (disabled), got {v}") + return v + + @field_validator("num_decode_tokens") + @classmethod + def validate_num_decode_tokens(cls, v): + """Validate num_decode_tokens is positive.""" + if v <= 0: + raise ValueError(f"num_decode_tokens must be positive, got {v}") + return v class SparseAttentionConfig(ModeloptBaseConfig): @@ -184,8 +336,9 @@ class SparseAttentionConfig(ModeloptBaseConfig): "default": {"enable": False}, }, title="Sparse attention configuration", - description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " - "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names " + "(or 'calibration' for global calibration settings), values are configuration dicts with parameters like " + "'threshold', 'enable', etc.", validate_default=True, ) @@ -203,10 +356,11 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): default={ "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, # Enable statistics collection "enable": True, }, "default": {"enable": False}, @@ -218,8 +372,55 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): ) +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": {"prefill": 0.9, "decode": 0.9}, + "samples": 64, + "max_seqlen": 65536, + "chunk_size": 4096, + }, + "*attn*": { + "method": "flash_skip_softmax", + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + __all__ = [ + "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index ad137e9eef..2155a13d0d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -21,13 +21,15 @@ import torch.nn as nn +from modelopt import __version__ as mo_version from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict -from modelopt.torch.utils import get_unwrapped_name +from modelopt.torch.utils import atomic_print, get_unwrapped_name from .config import SparseAttentionConfig -from .plugins.huggingface import register_sparse_attention_on_the_fly +from .plugins import register_custom_model_plugins_on_the_fly from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules def is_attn_sparsified(model: nn.Module) -> bool: @@ -59,8 +61,8 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model - # Register sparse attention modules dynamically - register_sparse_attention_on_the_fly(model) + # Apply custom model plugins + register_custom_model_plugins_on_the_fly(model) # Replace attention modules with sparse versions replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version) @@ -89,7 +91,7 @@ def replace_sparse_attention_modules(model: nn.Module, version=None): _replace_sparse_attention_modules(model, version=version) # Count and report replaced modules - replaced_count = sum(isinstance(m, SparseAttentionModule) for _, m in model.named_modules()) + replaced_count = len(get_sparse_attention_modules(model)) if replaced_count > 0: print(f"Inserted {replaced_count} sparse attention modules") @@ -144,10 +146,7 @@ def set_sparse_attention_attribute( # Filter out model-level configs that shouldn't be passed to modules module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} - for name, module in model.named_modules(): - if not isinstance(module, SparseAttentionModule): - continue - + for name, module in get_named_sparse_attention_modules(model): # Check pattern match matched = False if isinstance(wildcard_or_filter, str): @@ -192,22 +191,21 @@ def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]) model: Model with sparse attention modules state_dict: Saved state dictionary """ - for name, module in model.named_modules(): - if isinstance(module, SparseAttentionModule): - module_name = get_unwrapped_name(name, model) - if module_name in state_dict: - module_state = state_dict[module_name] + for name, module in get_named_sparse_attention_modules(model): + module_name = get_unwrapped_name(name, model) + if module_name in state_dict: + module_state = state_dict[module_name] - # Restore method and config - if "method" in module_state: - module._method = module_state["method"] - if "method_config" in module_state: - # Restore config attributes - for key, val in module_state["method_config"].items(): - setattr(module, f"_{key}", val) + # Restore method and config + if "method" in module_state: + module._method = module_state["method"] + if "method_config" in module_state: + # Restore config attributes + for key, val in module_state["method_config"].items(): + setattr(module, f"_{key}", val) - # Re-setup with restored config - module._setup() + # Re-setup with restored config + module._setup() def update_sparse_attention_metadata( @@ -222,18 +220,17 @@ def update_sparse_attention_metadata( """ sparse_state = {} - for name, module in model.named_modules(): - if isinstance(module, SparseAttentionModule): - module_name = get_unwrapped_name(name, model) + for name, module in get_named_sparse_attention_modules(model): + module_name = get_unwrapped_name(name, model) - # Save the method configuration that was used - # _method_config already contains the validated config dict - module_state = { - "method": module._sparse_method_instance.name, - "method_config": module._method_config.copy(), - } + # Save the method configuration that was used + # _method_config already contains the validated config dict + module_state = { + "method": module._sparse_method_instance.name, + "method_config": module._method_config.copy(), + } - sparse_state[module_name] = module_state + sparse_state[module_name] = module_state metadata["sparse_attention_state"] = sparse_state metadata["sparse_attention_config"] = ( @@ -241,6 +238,85 @@ def update_sparse_attention_metadata( ) +def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None: + """Extract sparse attention config for export to config.json. + + Extracts the calibration parameters (a, b) for the exponential threshold model + from the first sparse attention module that has calibrated thresholds. + + The exported config allows computing threshold at runtime: + scale_factor = a * exp(b * target_sparsity) + threshold = scale_factor / seqlen + + Args: + model: Model with sparse attention applied + + Returns: + Dictionary with sparse attention config for HuggingFace config.json export. + Returns None if no calibrated sparse attention modules found. + + Example output:: + + { + "config_groups": { + "group_0": {"sparse_algo": "softmax_skip", "targets": ["LlamaAttention"]} + }, + "threshold_scale_factor": { + "formula": "a * exp(b * target_sparsity)", + "prefill": {"a": 7.93, "b": 8.61}, + "decode": {"a": 0.12, "b": 9.85}, + }, + "producer": {"name": "modelopt", "version": "0.37.0"}, + } + """ + # Collect sparse attention module info + calibration_params = None + target_classes: set[str] = set() + + for module in get_sparse_attention_modules(model): + # Get the original wrapped module's class name + if hasattr(module, "get_original_cls_by_level"): + original_cls = module.get_original_cls_by_level(level=0) + if original_cls is not None: + target_classes.add(original_cls.__name__) + + # Get calibration params from first module that has them + if calibration_params is None: + calibration_params = getattr(module._sparse_method_instance, "calibration_params", None) + + # Return None if no calibration params found + if calibration_params is None: + return None + + # Build threshold_scale_factor with model parameters + threshold_scale_factor: dict[str, Any] = { + "formula": "a * exp(b * target_sparsity)", + } + for phase in ["prefill", "decode"]: + if phase in calibration_params: + threshold_scale_factor[phase] = { + "a": calibration_params[phase]["a"], + "b": calibration_params[phase]["b"], + } + + # Build the export config + export_config: dict[str, Any] = { + "config_groups": { + "group_0": { + "sparse_algo": "softmax_skip", + "targets": sorted(target_classes) if target_classes else ["Attention"], + } + }, + "threshold_scale_factor": threshold_scale_factor, + "producer": { + "name": "modelopt", + "version": mo_version, + }, + } + + return export_config + + def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): """Disable sparse attention for matching modules. @@ -257,10 +333,7 @@ def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Ca >>> # Disable sparse attention for lm_head >>> sparse_attn.disable_sparse_attention(model, "*lm_head*") """ - for name, module in model.named_modules(): - if not isinstance(module, SparseAttentionModule): - continue - + for name, module in get_named_sparse_attention_modules(model): matched = False if isinstance(wildcard_or_filter_func, str): matched = fnmatch.fnmatch(name, wildcard_or_filter_func) @@ -287,10 +360,7 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal >>> # Re-enable sparse attention for all attention modules >>> sparse_attn.enable_sparse_attention(model, "*attention*") """ - for name, module in model.named_modules(): - if not isinstance(module, SparseAttentionModule): - continue - + for name, module in get_named_sparse_attention_modules(model): matched = False if isinstance(wildcard_or_filter_func, str): matched = fnmatch.fnmatch(name, wildcard_or_filter_func) @@ -299,3 +369,52 @@ def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Cal if matched: module.enable() + + +def _format_threshold(info: dict) -> str: + """Format threshold info for display.""" + t = info.get("type") + if t == "dynamic_calibrated": + # Exponential model: threshold = a * exp(b * sparsity) / seqlen + params = info.get("calibration_params", {}) + target = info.get("target_sparse_ratio", {}) + parts = [] + for phase in ["prefill", "decode"]: + if phase in params: + a, b = params[phase]["a"], params[phase]["b"] + s = target.get(phase, 0.5) + parts.append(f"{phase}: a={a:.4f}, b={b:.2f}, target={s:.0%}") + return f"calibrated({', '.join(parts)})" + if t == "static": + v = info.get("value") + if isinstance(v, dict): + return f"threshold={v}" + return f"threshold={v:.2e}" if isinstance(v, float) else f"threshold={v}" + return "threshold=N/A" + + +@atomic_print +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Args: + model: Model with sparse attention applied + """ + sparse_modules = get_named_sparse_attention_modules(model) + + if not sparse_modules: + print("No sparse attention modules found") + return + + enabled = sum(1 for _, m in sparse_modules if m.is_enabled) + print(f"Sparse attention: {enabled}/{len(sparse_modules)} modules enabled") + + # Group by (method, threshold) + groups: dict[tuple[str, str], int] = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + threshold = _format_threshold(module.get_threshold_info()) + groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 + + for (method, threshold), count in sorted(groups.items()): + print(f" {method}: {count} layers, {threshold}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 8801bafb05..e575de4da0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -20,6 +20,7 @@ """ import math +from typing import Any import numpy as np import torch @@ -42,9 +43,10 @@ def __init__(self, method_config: dict | None = None): method_config: Configuration dict with threshold, br, bc, is_causal, etc. All required fields should have defaults from SparseAttentionAttributeConfig. """ + super().__init__() config = method_config or {} - # Extract configuration (defaults handled by Pydantic) + # Extract configuration self.threshold_config = config["threshold"] self.br = config["br"] self.bc = config["bc"] @@ -52,23 +54,21 @@ def __init__(self, method_config: dict | None = None): self.is_causal = config["is_causal"] # Optional parameters not in Pydantic config - self.enable_correction_factor = config.get("enable_correction_factor", True) self.phase = config.get("phase", None) - # Initialize threshold - if isinstance(self.threshold_config, dict): - self.threshold = self.threshold_config.get( - "default", self.threshold_config.get("prefill", 1e-4) - ) - else: - self.threshold = self.threshold_config + # Initialize threshold from dict config (prefill phase as default) + self.threshold = self.threshold_config.get("prefill", 1e-3) + + # Calibration mode flag (prevents threshold updates during calibration) + self._calibration_mode = False + + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled def _update_threshold(self, phase: str): """Update threshold based on phase.""" - if isinstance(self.threshold_config, dict): - self.threshold = self.threshold_config.get( - phase, self.threshold_config.get("default", self.threshold) - ) + self.threshold = self.threshold_config.get(phase, self.threshold) def _infer_phase(self, attention_scores: torch.Tensor) -> str: """Infer phase from attention scores shape.""" @@ -133,12 +133,23 @@ def calc_correction_factor_and_p( batch_size, num_heads, seq_q, seq_k = attn_weights.shape # Calculate threshold - threshold_scale_factor = getattr(self, "threshold_scale_factor", None) - if threshold_scale_factor: - # Use calibrated dynamic threshold: λ = scale_factor / length - log_threshold = np.log(threshold_scale_factor / seq_k) + calibration_params = self.calibration_params + target_sparse_ratio = self.target_sparse_ratio + + if ( + calibration_params is not None + and phase in calibration_params + and target_sparse_ratio is not None + ): + # Use calibrated a, b to compute dynamic threshold + # Exponential model: scale_factor = a * exp(b * target_sparsity) + a = calibration_params[phase]["a"] + b = calibration_params[phase]["b"] + target_sparsity = target_sparse_ratio.get(phase, 0.5) + scale_factor = a * np.exp(b * target_sparsity) + log_threshold = np.log(scale_factor / seq_k) else: - # Use static threshold from config + # Use static threshold from config (no calibration or phase not calibrated) log_threshold = np.log(self.threshold) if phase == "prefill": @@ -161,7 +172,7 @@ def calc_correction_factor_and_p( # Used by Flash Attention to adjust running sum when max increases block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] - correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() # Step 4: Normalize attention scores by cumulative max # p represents log-space difference: log(score) - log(cummax) @@ -184,18 +195,17 @@ def calc_correction_factor_and_p( element_mask = element_mask[:, :, :seq_q, :seq_k] # Step 8: Calculate sparsity statistics - # Count kept blocks (averaged across batch and heads) - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) - - # Total valid blocks (lower triangle only for causal attention) - # Note: Causal mask pre-applied by attention module, so block_mask naturally - # has zeros in upper triangle. We only count lower triangle for denominator. - total_blocks = ( - num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 - if self.is_causal - else num_block_rows * num_block_cols # Non-causal: N*N - ) - sparsity = 1 - (kept_blocks / total_blocks) + if self.is_causal: + # For causal attention, only count lower triangle blocks (including diagonal) + num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2 + total_valid_blocks = batch_size * num_heads * num_causal_blocks + dense_blocks = block_mask.sum() + total_blocks = num_causal_blocks + else: + dense_blocks = block_mask.sum() # Keep as tensor + total_valid_blocks = block_mask.numel() + total_blocks = num_block_rows * num_block_cols + sparsity = 1.0 - dense_blocks.item() / total_valid_blocks else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc @@ -216,7 +226,7 @@ def calc_correction_factor_and_p( # Tracks how often the maximum increases (needed for Flash Attention rescaling) block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] - correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() # Step 4: Normalize scores by cumulative max # p = log(score) - log(cummax) in log-space @@ -232,14 +242,15 @@ def calc_correction_factor_and_p( element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) element_mask = element_mask[:, :, :seq_q, :seq_k] - # Step 7: Calculate statistics - kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + # Step 7: Calculate sparsity statistics + dense_blocks = block_mask.sum() + total_valid_blocks = block_mask.numel() + sparsity = 1.0 - dense_blocks.item() / total_valid_blocks total_blocks = num_block_cols - sparsity = 1 - (kept_blocks / total_blocks) # Create stats dictionary stats = { - "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "correction_factor": correction_factor, "sparsity": sparsity, "phase": phase, "total_blocks": total_blocks, @@ -249,27 +260,18 @@ def calc_correction_factor_and_p( return element_mask, stats - def apply_sparsity( + def calculate_sparsity( self, - query: torch.Tensor | None = None, - key: torch.Tensor | None = None, - value: torch.Tensor | None = None, - attention_scores: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: - """Apply Flash Attention-aware block-wise sparsity. + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Calculate sparsity mask and statistics for Flash Attention. Args: - query: Query tensor (unused, for API compatibility) - key: Key tensor (unused, for API compatibility) - value: Value tensor (unused, for API compatibility) attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k] Returns: - Tuple with potentially modified attention_scores + Tuple of (sparse_mask, stats) where sparse_mask is boolean mask """ - # Attention scores must be provided for sparse attention - assert attention_scores is not None, "attention_scores must be provided for apply_sparsity" - # Attention scores are always 4D: [batch, heads, seq_q, seq_k] assert len(attention_scores.shape) == 4, ( f"Expected 4D attention scores, got shape {attention_scores.shape}" @@ -278,20 +280,78 @@ def apply_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase - self._update_threshold(phase) + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) - # Apply block-wise sparsity + # Calculate block-wise sparsity mask and stats sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) - # Store stats for module to collect (doesn't persist across calls) + # Store stats for module to collect self._last_stats = stats - # Apply mask to create sparse scores + return sparse_mask, stats + + def apply_sparsity( + self, + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply sparsity mask to attention scores. + + Args: + attention_scores: Attention scores tensor [batch, heads, seq_q, seq_k] + sparse_mask: Optional pre-computed boolean mask. If None, calculates internally. + + Returns: + Masked attention scores with sparse elements set to dtype minimum + """ + if sparse_mask is None: + sparse_mask, _ = self.calculate_sparsity(attention_scores) + + # Apply mask: set masked positions to minimum value (becomes 0 after softmax) mask_value = torch.finfo(attention_scores.dtype).min - sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value) + return attention_scores.masked_fill(~sparse_mask, mask_value) + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for this method. - return query, key, value, sparse_scores + Returns: + Dictionary with threshold configuration and calibration info. + """ + calibration_params = self.calibration_params + target_sparse_ratio = self.target_sparse_ratio + + if calibration_params is not None and target_sparse_ratio is not None: + # Per-phase calibrated dynamic threshold using Exponential model + example_lengths = [1024, 4096, 16384, 65536, 131072] + phase_info = {} + for phase, params in calibration_params.items(): + a, b = params["a"], params["b"] + target_sparsity = target_sparse_ratio.get(phase, 0.5) + scale_factor = a * np.exp(b * target_sparsity) + phase_info[phase] = { + "a": a, + "b": b, + "target_sparsity": target_sparsity, + "scale_factor": scale_factor, + "example_thresholds": { + length: scale_factor / length for length in example_lengths + }, + } + return { + "type": "dynamic_calibrated", + "formula": "threshold = a * exp(b * target_sparsity) / seqlen", + "calibration_params": calibration_params, + "target_sparse_ratio": target_sparse_ratio, + "phases": phase_info, + } + else: + # Static threshold (single value or phase-specific dict) + return { + "type": "static", + "value": self.threshold_config, + } @property def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index df7b5853b6..6329e4446f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -18,6 +18,7 @@ import re import warnings from abc import ABC, abstractmethod +from typing import Any import torch @@ -25,25 +26,61 @@ class SparseAttentionMethod(ABC): """Base class for sparse attention methods.""" + def __init__(self): + """Initialize base sparse attention method.""" + # Flag to indicate calibration mode (set by calibrator) + # Instance attribute to prevent shared state across multiple models + self._calibration_mode: bool = False + + # Calibration parameters set by the calibrator after calibration. + # Exponential model params per phase: {"prefill": {"a": ..., "b": ...}, ...} + self.calibration_params: dict[str, dict[str, float]] | None = None + # Target sparsity ratio per phase: {"prefill": 0.5, "decode": 0.5} + self.target_sparse_ratio: dict[str, float] | None = None + + @abstractmethod + def calculate_sparsity( + self, + attention_scores: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """Calculate sparsity mask and statistics without applying. + + Args: + attention_scores: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + + Returns: + Tuple of (sparse_mask, stats_dict) where: + - sparse_mask: Boolean tensor indicating which elements to keep + - stats_dict: Dictionary with sparsity statistics + """ + @abstractmethod def apply_sparsity( self, - query: torch.Tensor | None = None, - key: torch.Tensor | None = None, - value: torch.Tensor | None = None, - attention_scores: torch.Tensor | None = None, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - """Apply sparsity to attention computation. + attention_scores: torch.Tensor, + sparse_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply sparsity mask to attention scores. Args: - query: Query tensor - key: Key tensor - value: Value tensor - attention_scores: Pre-computed attention scores + attention_scores: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + sparse_mask: Optional pre-computed mask. If None, calculates internally. + + Returns: + Masked attention scores with sparse elements set to -inf + """ + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for display/debugging. Returns: - Tuple of (query, key, value, attention_scores) with sparsity applied + Dictionary with threshold information. Should include: + - 'type': 'static', 'dynamic', or 'none' + - 'value': threshold value (for static) + - 'scale_factor': scale factor (for dynamic) + - Other method-specific info """ + return {"type": "none", "value": None} @property @abstractmethod diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index 88434e7462..b79e25bd80 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -22,10 +22,12 @@ from modelopt.torch.opt.conversion import apply_mode from modelopt.torch.opt.searcher import ForwardLoop +from .calibration import calibrate_sparse_attention from .config import SparseAttentionConfig from .mode import SparseAttentionModeRegistry __all__ = [ + "calibrate", "sparsify", ] @@ -58,12 +60,36 @@ def sparsify( .. code-block::python config = { - "method": "flash_skip_softmax", "sparse_cfg": { + # Phase-aware thresholds with backend selection "*attention*": { + "method": "flash_skip_softmax", "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + # Disable for specific layers + "*layer.0*": {"enable": False}, + # Default settings + "default": {"enable": False}, + }, + } + + For automatic threshold calibration using RULER dataset: + + .. code-block::python + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", "backend": "pytorch", "enable": True, + "calibration": { # Enables automatic threshold calibration + "target_sparse_ratio": 0.5, + "samples": 48, + "max_seqlen": 8192, + }, }, "default": {"enable": False}, }, @@ -126,4 +152,26 @@ def forward_loop(model) -> float: model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry ) + # Calibrate the sparsity ratio of the attention modules + return calibrate(model, config, forward_loop=forward_loop) + + +def calibrate( + model: torch.nn.Module, + config: dict[str, Any] | SparseAttentionConfig, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Calibrates sparse attention thresholds based on target sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration with calibration settings + forward_loop: Optional callable that forwards calibration data through the model. + If provided, uses this for calibration data. + If None, will auto-generate RULER dataset for calibration. + + Returns: + The calibrated model with optimized sparse attention thresholds. + """ + calibrate_sparse_attention(model, config, forward_loop=forward_loop) return model diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index ba8c8b8211..434fc18214 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,8 +15,20 @@ """Plugins for sparse attention integration with various frameworks.""" -from .huggingface import register_sparse_attention_on_the_fly +# List of model plugins that are called during conversion +# Each plugin is a callable that takes (model) and performs validation/setup +CUSTOM_MODEL_PLUGINS: list = [] + + +def register_custom_model_plugins_on_the_fly(model): + """Applies all registered custom model plugins.""" + for callback in CUSTOM_MODEL_PLUGINS: + callback(model) + + +from . import huggingface # noqa: E402 __all__ = [ - "register_sparse_attention_on_the_fly", + "CUSTOM_MODEL_PLUGINS", + "register_custom_model_plugins_on_the_fly", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 0c4a8baf93..828d126e86 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -15,12 +15,15 @@ """Dynamic sparse attention registration for HuggingFace models.""" +import warnings + import torch.nn as nn import transformers from modelopt.torch.opt.dynamic import DynamicModule from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from . import CUSTOM_MODEL_PLUGINS class _GenericSparseAttention(SparseAttentionModule): @@ -118,3 +121,33 @@ def _is_supported_model(model: nn.Module) -> bool: # Support any PyTorch model with attention modules return isinstance(model, nn.Module) + + +def validate_eager_attention(model: nn.Module) -> None: + """Validate and enforce eager attention for HuggingFace models. + + Sparse attention requires attn_implementation='eager' because it + patches torch.nn.functional.softmax, which is only called in eager mode. + + Args: + model: Model to validate + """ + if not isinstance(model, transformers.PreTrainedModel): + return + + attn_impl = getattr(model.config, "_attn_implementation", None) + if attn_impl and attn_impl != "eager": + warnings.warn( + f"Sparse attention requires attn_implementation='eager', but model uses '{attn_impl}'. " + "Forcing eager attention implementation." + ) + model.config._attn_implementation = "eager" + + +# Register plugins +CUSTOM_MODEL_PLUGINS.extend( + [ + validate_eager_attention, + register_sparse_attention_on_the_fly, + ] +) diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 16b08bf19b..281e11e7d9 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -15,6 +15,8 @@ """Extensible sparse attention module.""" +from typing import Any + import torch import torch.nn.functional as F @@ -23,6 +25,7 @@ from .config import SparseAttentionAttributeConfig from .methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager class SparseAttentionModule(DynamicModule): @@ -103,6 +106,17 @@ def set_from_attribute_config( # Initialize sparse method instance self._init_sparse_method() + # Create stats manager based on config + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + + # Initialize stats storage for collecting stats from sparse_softmax + self._last_stats: dict | None = None + def _init_sparse_method(self): """Initialize the sparse method instance.""" method_class = get_sparse_method(self._method) @@ -129,11 +143,22 @@ def get_stats(self) -> dict: Returns: Dictionary with sparsity statistics including 'average_sparsity' if available. - Returns empty dict (statistics collection will be added in calibration PR). + Returns empty dict if stats manager is not enabled. """ - # TODO: Statistics collection will be added in calibration PR + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() return {} + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information from the sparse method instance. + + Returns: + Dictionary with threshold information from the sparse method. + """ + if hasattr(self, "_sparse_method_instance") and self._sparse_method_instance is not None: + return self._sparse_method_instance.get_threshold_info() + return {"type": "none", "value": None} + def _setup(self): """Setup called by DynamicModule.""" # Apply default configuration if not yet configured @@ -157,6 +182,11 @@ def forward(self, *args, **kwargs): with context: result = super().forward(*args, **kwargs) + # Collect stats if manager is available + if self._stats_manager is not None and self._last_stats is not None: + self._stats_manager.collect(self._last_stats) + self._last_stats = None # Clear after collection + return result def _get_sparse_context(self): @@ -172,14 +202,17 @@ def _create_sparse_softmax(self): original_softmax = F.softmax def sparse_softmax(input, dim=-1, *args, **kwargs): - # Let the method handle the sparsification - _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( - None, None, None, input - ) + # Calculate sparsity mask and collect statistics + sparse_mask, stats = self._sparse_method_instance.calculate_sparsity(input) + + # Store stats for collection + self._last_stats = stats + + # Only apply sparsity mask after calibration (not during calibration) + # During calibration, we measure sparsity without modifying the output + if not self._sparse_method_instance._calibration_mode: + input = self._sparse_method_instance.apply_sparsity(input, sparse_mask) - # Use sparse input if modified, otherwise use original - if sparse_input is not None: - return original_softmax(sparse_input, dim, *args, **kwargs) return original_softmax(input, dim, *args, **kwargs) return sparse_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py new file mode 100644 index 0000000000..b84a3cade5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Statistics manager for sparse attention modules.""" + + +class SparseAttentionStatsManager: + """Centralized statistics manager for sparse attention. + + This class is the single source of truth for all statistics collection + in sparse attention modules. It handles both runtime aggregation and + per-sample calibration statistics. + + Design principles: + - Single responsibility: only stats management + - No computation: receives pre-computed stats from methods + - Optional: can be None if stats collection disabled + - Zero overhead when disabled + """ + + def __init__(self, module_name: str, enabled: bool = True): + """Initialize stats manager. + + Args: + module_name: Name of the module this manager is attached to + enabled: Whether stats collection is enabled + """ + self.module_name = module_name + self.enabled = enabled + self.calibration_mode = False + + # Aggregated stats (running totals across all forward passes) + self.aggregated_stats: dict = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + + # Per-sample stats (only populated during calibration) + self.per_sample_stats: list[dict] = [] + + def collect(self, stats: dict): + """Collect statistics from a single forward pass. + + Args: + stats: Dictionary containing statistics from method computation. + Expected keys: sparsity, phase, total_blocks, sparse_blocks, + sample_length (optional) + """ + if not self.enabled: + return + + # Update aggregated stats + self.aggregated_stats["total_calls"] += 1 + self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) + self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + phase = stats.get("phase", "unknown") + if phase in self.aggregated_stats["phase_counts"]: + self.aggregated_stats["phase_counts"][phase] += 1 + + # In calibration mode, store per-sample stats + if self.calibration_mode: + self.per_sample_stats.append( + { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + ) + + def get_summary(self) -> dict: + """Get aggregated statistics summary. + + Returns: + Dictionary with module name, total calls, average sparsity, + and phase distribution. + """ + total_blocks = self.aggregated_stats["total_blocks"] + if total_blocks > 0: + avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + else: + avg_sparsity = 0.0 + + return { + "module": self.module_name, + "total_calls": self.aggregated_stats["total_calls"], + "average_sparsity": avg_sparsity, + "phase_distribution": self.aggregated_stats["phase_counts"].copy(), + } + + def set_calibration_mode(self, enabled: bool, reset_history: bool = True): + """Enable or disable calibration mode. + + In calibration mode, per-sample statistics are stored for detailed + analysis. Otherwise, only aggregated stats are maintained. + + Args: + enabled: Whether to enable calibration mode + reset_history: Whether to clear per_sample_stats when enabling + """ + self.calibration_mode = enabled + if enabled and reset_history: + self.per_sample_stats = [] + + def reset(self): + """Reset all statistics to initial state.""" + self.aggregated_stats = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + self.per_sample_stats = [] + + def get_calibration_stats(self, phase: str | None = None) -> list[dict]: + """Get per-sample calibration statistics, optionally filtered by phase. + + Note: Returns historical stats collected while calibration_mode was enabled. + Stats remain accessible even after calibration_mode is disabled. + New stats are only collected when calibration_mode is True. + + Args: + phase: Optional phase to filter by ('prefill' or 'decode'). + If None, returns all stats. + + Returns: + List of per-sample statistics dictionaries. + Empty list if no stats were collected or no stats match the phase. + """ + if phase is None: + return self.per_sample_stats + return [s for s in self.per_sample_stats if s.get("phase") == phase] diff --git a/modelopt/torch/sparsity/attention_sparsity/utils.py b/modelopt/torch/sparsity/attention_sparsity/utils.py new file mode 100644 index 0000000000..94357754e1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/utils.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for sparse attention module discovery.""" + +import torch.nn as nn + +from .sparse_attention import SparseAttentionModule + + +def get_sparse_attention_modules(model: nn.Module) -> list[SparseAttentionModule]: + """Get all sparse attention modules in a model. + + Args: + model: Model to search for sparse attention modules. + + Returns: + List of SparseAttentionModule instances found in the model. + """ + return [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + +def get_named_sparse_attention_modules( + model: nn.Module, +) -> list[tuple[str, SparseAttentionModule]]: + """Get all sparse attention modules in a model with their names. + + Args: + model: Model to search for sparse attention modules. + + Returns: + List of (name, module) tuples for all SparseAttentionModule instances. + """ + return [(name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule)] diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index ffaa195f2e..2b085d5e35 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -20,6 +20,7 @@ from modelopt.torch.opt.conversion import ModelLikeModule from modelopt.torch.opt.dynamic import _DMRegistryCls from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.speculative.config import eagle3_default_config, kimik2_eagle_default_config from ..config import EagleConfig @@ -38,6 +39,14 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls]) break + # merge custom config with default config + default_arch_config = { + "llama": eagle3_default_config, + "kimik2": kimik2_eagle_default_config, + }[config.eagle_decoder_type] + custom_config = config.eagle_architecture_config + config.eagle_architecture_config = {**default_arch_config, **custom_config} + eagle_model = EagleDMRegistry.convert(model) eagle_model.modify( eagle_offline=config.eagle_offline, diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py index 281528788c..d77ed298ac 100644 --- a/modelopt/torch/speculative/eagle/utils.py +++ b/modelopt/torch/speculative/eagle/utils.py @@ -36,7 +36,6 @@ """Eagle model utils.""" import torch -from torch import nn # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -71,21 +70,3 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class RMSNorm(nn.Module): - """Borrowed from LlamaRMSNorm class.""" - - def __init__(self, hidden_size, eps=1e-6): - """LlamaRMSNorm is equivalent to T5LayerNorm.""" - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - """Forward function for RMSNorm.""" - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 5435a8efac..e37e8f931c 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -17,6 +17,7 @@ import copy import warnings +from contextlib import contextmanager import megatron.core import torch @@ -24,13 +25,15 @@ from megatron.core import InferenceParams, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding -from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.extensions.transformer_engine import TELinear, TENorm from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_world_size, get_data_parallel_rank, get_expert_tensor_parallel_world_size, get_pipeline_model_parallel_world_size, @@ -59,7 +62,6 @@ try: from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec - from megatron.core.post_training.modelopt.layers import Linear except ImportError: warnings.warn("Fail to import megatron.core.post_training! EAGLE feature will be disable!") @@ -388,7 +390,11 @@ def sharded_state_dict( if module is not self.layers: sharded_state_dict.update( sharded_state_dict_default( - module, f"{prefix}{name}.", sharded_offsets, metadata + module, + f"{prefix}{name}.", + sharded_offsets, + metadata, + tp_group=self.tp_group, ) ) @@ -442,16 +448,19 @@ def __init__( self._num_aux_hidden_states if self._num_aux_hidden_states > 0 else 2 ) - # This linear was previously a ColumnParallelLinear. We changed it to a normal linear + # This linear was previously a ColumnParallelLinear. We changed it to a TELinear # since ColumnParallelLinear will have try to gather the input sequence when sequence # parallel is used and does not allow gathering the outputs. with torch.device(device): - self.fc = Linear( + self.fc = TELinear( config.hidden_size * fc_input_size_multiplier, config.hidden_size, + parallel_mode="duplicated", config=config, init_method=(lambda w: None), # not used bias=bias, + skip_bias_add=False, + skip_weight_param_allocation=False, ) self.rotary_pos_emb = rotary_pos_emb @@ -529,11 +538,13 @@ def _get_eagle_transformer_layer_spec(self, config): IMPORTANT: EagleModule must use arbitrary_attention_mask since we need to manipulate the mask to compute the correct loss. The default causal mask will result in leaking. + However, if context parallel is used, we need to switch to causal + mask and inject attention_mask as attention_bias instead. """ transformer_layer_spec = get_gpt_modelopt_spec( config, remap_te_layernorm=True, - use_arbitrary_attention_mask=True, + use_arbitrary_attention_mask=get_context_parallel_world_size() == 1, ) # If heterogenous layers (e.g. DeepSeek), transformer_layer_spec is a # TransformerBlockSubmodules instead. We use the last layer_specs. @@ -583,9 +594,13 @@ def forward( # NOTE: Even if sequence_parallel is used, the rotary_seq_len must be in the original # length. Since we get the seq_len from hidden_states.shape[0], we need to # multiply the the tp back. + # Similarly, if context parallel is used, the rotary_seq_len must also be + # multiplied by context parallel size. rotary_seq_len = hidden_states.shape[0] if self.config.sequence_parallel: rotary_seq_len *= self.config.tensor_model_parallel_size + if get_context_parallel_world_size() > 1: + rotary_seq_len *= get_context_parallel_world_size() if self.config.use_mtp_layernorm: embeddings = self.enorm(embeddings) @@ -838,16 +853,41 @@ def _get_eagle_module_inputs( ttt_step: int = 0, ): """Getting EAGLE module inputs.""" - # [b, 1] + # gather_from_sequence_parallel_region gathers from the first dimention + # so we need to transpose input_ids first + # [b,s] -> [s,b] + input_ids = input_ids.clone().transpose(0, 1).contiguous() + input_ids = gather_from_sequence_parallel_region( + input_ids, group=get_context_parallel_group() + ) + # [s,b] -> [b,s] + input_ids = input_ids.transpose(0, 1).contiguous() id_padding = torch.zeros( (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device ) padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) + # RotaryEmbedding's output is already scattered to context parallel region + # No need to scatter again. rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) + # [b,s] -> [s,b] + padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() + padded_input_ids = scatter_to_sequence_parallel_region( + padded_input_ids, group=get_context_parallel_group() + ) + # [s,b] -> [b,s] + padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() + attn_mask = attention_mask.clone().detach() - attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:] + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = gather_from_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() + ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True @@ -860,9 +900,17 @@ def _get_eagle_module_inputs( input_ids=eagle_inputs["input_ids"], position_ids=eagle_inputs["position_ids"], ) + eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, ttt_step) + attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step) + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = scatter_to_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() + ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous() eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] * (ttt_step + 1), @@ -1111,14 +1159,17 @@ def forward( ttt_step=ttt_step, ) - _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( - eagle_inputs, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) + with te_dot_product_attention_with_cp( + eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads + ): + _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( + eagle_inputs, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), + ) if self.config.sequence_parallel: eagle_module_input_hidden_states = gather_from_sequence_parallel_region( @@ -1330,3 +1381,80 @@ def get_ground_truth(self, input_ids, osl): if input_id[0, 0] == self.end_token: break return input_ids + + +@contextmanager +def te_dot_product_attention_with_cp(attention_mask: torch.Tensor, num_attention_heads: int): + """Context manager for TEDotProductAttention with context parallelism. + + Context manager that temporarily replace `attention_bias` + with `attention_mask` for `TEDotProductAttention.forward` calls across the process + if context parallel is used. + + Any call to `TEDotProductAttention.forward` (including calls originating + from other modules) inside the context will receive `attention_bias=attention_mask` + if context parallelism is used. + + Example: + with te_dot_product_attention_with_cp(attention_mask_tensor, num_attention_heads): + outputs = model(...) + + Note: This monkey-patches the class method and restores it on exit. + """ + from megatron.core.extensions.transformer_engine import TEDotProductAttention + + orig_forward = TEDotProductAttention.forward + + def _wrapped_forward(self, *args, **kwargs): + # Build attention_bias from the boolean attention_mask and ensure + # it's a fresh, detached tensor on the query's device/dtype to + # avoid shared-storage in-place modifications that break autograd. + query = args[0] if len(args) > 0 else None + if isinstance(query, torch.Tensor): + q_device = query.device + q_dtype = query.dtype + else: + q_device = None + q_dtype = None + + mask_fill = -1e9 + if q_dtype in (torch.float16, torch.bfloat16): + mask_fill = -40.0 + mask_val = torch.tensor(mask_fill, device=attention_mask.device) + zero_val = torch.tensor(0.0, device=attention_mask.device) + attention_bias = torch.where(attention_mask, mask_val, zero_val) + + if q_device is not None and q_dtype is not None: + attention_bias = attention_bias.to(device=q_device, dtype=q_dtype) + + attention_bias = attention_bias.clone().detach().contiguous() + kwargs["attention_bias"] = attention_bias + + # Defensive clone of query/key/value positional tensors to avoid + # passing views into the fused attention kernel that might be + # modified in-place during backward. + if len(args) >= 1: + original_args = args + new_args = list(original_args) + try: + for i in range(min(3, len(new_args))): + if isinstance(new_args[i], torch.Tensor): + if not new_args[i].is_contiguous(): + new_args[i] = new_args[i].contiguous() + new_args[i] = new_args[i].clone() + + if any(x is None for x in new_args): + args = original_args + else: + args = tuple(new_args) + except Exception: + args = original_args + + return orig_forward(self, *args, **kwargs) + + if get_context_parallel_world_size() > 1: + TEDotProductAttention.forward = _wrapped_forward + try: + yield + finally: + TEDotProductAttention.forward = orig_forward diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 39df7b9b7d..5e7ff9c8e7 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -31,9 +31,12 @@ import contextlib import copy +from dataclasses import dataclass from typing import Any import torch +import transformers +from packaging.version import Version from torch import nn from torch.nn import CrossEntropyLoss from torch.nn.attention.flex_attention import BlockMask, create_block_mask @@ -45,21 +48,36 @@ ) from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput -from transformers.utils.quantization_config import QuantizationMethod +from transformers.utils.quantization_config import CompressedTensorsConfig from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel -from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask +from ..eagle.utils import expand_mask, make_causal_mask from ..medusa.conversion import MedusaDMRegistry from ..medusa.medusa_model import MedusaModel from ..utils import ( AcceptanceRateValidation, ResBlock, _setup_kimi_k2_decoder, + enable_cp_ttt_patch, + get_ttt_msk_func, temporary_set_config_value, ) +__all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] + IGNORE_TOKEN_ID = LabelSmoother.ignore_index +ENABLE_CP_TTT_PATCH = False +# module variable to cache attention mask for cp ttt +CACHED_SHARD_TTT_MASKS = {} + + +def _get_empty_cache(config): + """Return an empty cache. Handle different versions of transformers for unit tests.""" + if Version(transformers.__version__) >= Version("4.54"): + return DynamicCache(config=config) + else: + return DynamicCache() @MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) @@ -219,7 +237,7 @@ def __init__(self, config, decoder_layer_cls, bias=False): [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) if config.use_last_layernorm: - self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) + self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps) # Optionally, we use a smaller vocab table for eagle module if config.draft_vocab_size != config.vocab_size or config.has_lm_head: @@ -227,7 +245,6 @@ def __init__(self, config, decoder_layer_cls, bias=False): assert config.draft_vocab_size <= config.vocab_size, ( "EAGLE module's vocab size should be <= base model vocab size!" ) - # Initialize the buffers to zero. # Their values depend on specific tokenzier and calibrate dataset, and should be set in training script. if config.draft_vocab_size < config.vocab_size: @@ -370,7 +387,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -388,6 +405,25 @@ def forward( return post_norm_h, pre_norm_h, past_key_values +@dataclass +class EagleBaseModelOutput: + out_hiddens: torch.Tensor + aux_hiddens: torch.Tensor | None = None + logits: torch.Tensor | None = None + input_embeds: torch.Tensor | None = None + loss: torch.Tensor | None = None + + @classmethod + def from_offline_dict(cls, d: dict): + return cls( + out_hiddens=d.get("base_model_hidden_states"), + aux_hiddens=d.get("aux_hidden_states"), + logits=d.get("base_model_logits"), + input_embeds=d.get("base_model_input_embeds"), + loss=None, + ) + + @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFEagleModel(EagleModel): """Eagle Model Class for huggingface models.""" @@ -408,16 +444,26 @@ def _base_model_lm_head(self): @property def _base_llm_config(self): """Return the llm config for the base model, from LLM or VLM.""" - return self.config.llm_config if hasattr(self.config, "llm_config") else self.config + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" base_model_parts_mapping = { - "base_model_path": ["model", "backbone", "language_model.backbone"], + "base_model_path": [ + "model.language_model", + "model", + "backbone", + "language_model.backbone", + ], "base_model_embeddings_path": [ "model.embed_tokens", "backbone.embeddings", "language_model.backbone.embeddings", + "model.language_model.embed_tokens", ], "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], } @@ -463,6 +509,8 @@ def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None def pop_and_gather_aux_hiddens(self): """Pop auxiliary hidden states from base model and gather them on the draft model device.""" + if not self.eagle_config.use_aux_hidden_state: + return None # In PTQ, forward method will be called with try and except to find max batch size. # This leads to uncleared aux hidden states in the front of the list. # To fix it, we only return the last num_aux_h items in the list. @@ -471,9 +519,11 @@ def pop_and_gather_aux_hiddens(self): self._aux_hidden_states.clear() # Gather aux hidden states on the draft model device - aux_h_list = [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list] + aux_hiddens = torch.cat( + [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list], dim=-1 + ) - return aux_h_list + return aux_hiddens def _get_eagle_device(self): """Return the device where we should place eagle module.""" @@ -535,28 +585,17 @@ def modify( self.eagle_config._attn_implementation = "sdpa" # Patch for Kimi-K2-Thinking, avoid quantizing drafter - if ( - hasattr(self.config, "quantization_config") - and self.config.quantization_config.quant_method - == QuantizationMethod.COMPRESSED_TENSORS - ): - self.config.quantization_config.quantization_config.ignore.append("re:.*eagle_module.*") + quant_config = getattr(self.config, "quantization_config", None) + if isinstance(quant_config, CompressedTensorsConfig): + quant_config.ignore.append("re:.*eagle_module.*") - # Use default aux_hidden_state layers if use_aux_hidden_state is True - # but no layer id is given + # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0 ): self._set_default_aux_hidden_state_layers() - if self._base_llm_config.hidden_size != self.eagle_config.hidden_size: - raise ValueError( - "EAGLE module hidden size " - f"{self.eagle_config.hidden_size} must match base model hidden size " - f"{self._base_llm_config.hidden_size}!" - ) - # Freeze all parameters if self.eagle_freeze_base_model: for name, param in self.named_parameters(): @@ -598,25 +637,26 @@ def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): return self._cached_attn_blk_masks[ttt_step] def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length + self, attention_mask, input_shape, past_key_values_length, device, dtype ): """Expand the 2-D attention mask to 4-D and apply causal mask.""" # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None + # construct causal mask if input_shape[-1] > 1: combined_attention_mask = make_causal_mask( input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, + dtype, + device=device, past_key_values_length=past_key_values_length, ) - + # merge causal mask with padding mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) + expanded_attn_mask = expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]).to( + device + ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None @@ -625,73 +665,76 @@ def _prepare_decoder_attention_mask( return combined_attention_mask - def _get_eagle_module_inputs( + def _prepare_eagle_inputs( self, input_ids, - eagle_input_hidden_states, attention_mask, position_ids, eagle_cache, + base_outputs, ): """Helper function to prepare eagle inputs for the 0th eagle forward pass.""" - b, seq_length, _ = eagle_input_hidden_states.shape - past_key_values_length = eagle_cache.get_seq_length() if eagle_cache is not None else 0 - seq_length_with_past = seq_length + past_key_values_length + b, seq_length = input_ids.shape + past_kv_len = eagle_cache.get_seq_length() if eagle_cache is not None else 0 + seq_len_with_past = seq_length + past_kv_len - # Prepare eagle_input_ids: Shift left 1 token - zeropadding = torch.zeros( - input_ids.shape[0], 1, dtype=input_ids.dtype, device=input_ids.device - ) - eagle_input_ids = torch.cat((input_ids[:, 1:], zeropadding), dim=1) + # Prepare eagle_input_embeds: Shift left 1 token + with torch.no_grad(): + if base_outputs.input_embeds is None: + eagle_input_embeds = self._base_model_embeddings(input_ids.roll(-1, 1)) + else: + eagle_input_embeds = base_outputs.input_embeds.roll(-1, 1) + + # Prepare eagle_input_hiddens + if self.eagle_config.use_aux_hidden_state: + # Eagle3: concat base model intermediate (pre-norm) hiddens + eagle_input_hiddens = self.eagle_module.fc(base_outputs.aux_hiddens) + else: + # Eagle1: use base model output (post-norm)hiddens + eagle_input_hiddens = base_outputs.out_hiddens # Prepare attention_mask - if attention_mask is not None: # Shift left 1 token for attention_mask - zeropadding = torch.zeros( - attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=attention_mask.device + if attention_mask is None: + eagle_attention_mask = torch.ones( # default: all tokens are valid + (b, seq_len_with_past), dtype=torch.bool, device=eagle_input_hiddens.device ) - attention_mask = torch.cat((attention_mask[:, 1:], zeropadding), dim=1) else: - attention_mask = torch.ones( # Initialize default attention_mask - (b, seq_length_with_past), dtype=torch.bool, device=eagle_input_hidden_states.device - ) - + eagle_attention_mask = attention_mask.roll(-1, 1) # Shift left 1 token # Expand the 2-D attention mask to 4-D and apply causal mask. - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (b, seq_length), eagle_input_hidden_states, past_key_values_length + eagle_attention_mask = self._prepare_decoder_attention_mask( + eagle_attention_mask, + (b, seq_length), + past_kv_len, + eagle_input_hiddens.device, + eagle_input_hiddens.dtype, ) # Prepare position_ids if position_ids is None: - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=eagle_input_hidden_states.device, + eagle_position_ids = ( + torch.arange( + past_kv_len, + seq_len_with_past, + dtype=torch.long, + device=eagle_input_hiddens.device, + ) + .unsqueeze(0) + .view(-1, seq_length) ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: - position_ids = position_ids.view(-1, seq_length).long() + eagle_position_ids = position_ids.view(-1, seq_length).long() - return eagle_input_ids, attention_mask, position_ids + return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step ) -> BlockMask | torch.Tensor: """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" - - def msk_func(b, h, q_idx, kv_idx): - mask = kv_idx <= (q_idx - ttt_step) - for i in range(1, ttt_step + 1): - mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & ( - kv_idx >= seq_length * i - ) - mask = mask | mask_block_i - return mask - + msk_func = get_ttt_msk_func(seq_length, ttt_step) dtypemin = torch.finfo(self._base_llm_config.dtype).min q_len = seq_length kv_len = seq_length * (1 + ttt_step) - if self.eagle_module.config._attn_implementation == "flex_attention": + if self.eagle_config._attn_implementation == "flex_attention": # Return block mask for flex attention block_mask = create_block_mask(msk_func, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len) return block_mask @@ -707,40 +750,10 @@ def msk_func(b, h, q_idx, kv_idx): tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device ).masked_fill(~tensor_mask, dtypemin) + # Note: (hg) repeat mask for kimi-k2 compatibility tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask - def _llm_or_vlm_embedding(self, input_ids, kwargs): - """Return input embeddings with possibly vision embeddings for VLM.""" - tok_embeds = self._base_model_embeddings(input_ids) - - # LLM only have token embeddings - if "pixel_values" not in kwargs: - return tok_embeds - - # Otherwise, insert vision embeddings in tok_embeds - if self.config.model_type == "NemotronH_Nano_VL_V2": - vit_embeds = self.extract_feature(kwargs["pixel_values"]) - vit_embeds = vit_embeds[kwargs["image_flags"] == 1] - bs, seq_len, hid_size = tok_embeds.shape - tok_embeds = tok_embeds.reshape(bs * seq_len, hid_size) - input_ids = input_ids.reshape(bs * seq_len) - selected = input_ids == self.img_context_token_id - try: - tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds.reshape(-1, hid_size) - except Exception as e: - vit_embeds = vit_embeds.reshape(-1, hid_size) - print( - f"warning: {e}, tok_embeds[selected].shape={tok_embeds[selected].shape}, " - f"vit_embeds.shape={vit_embeds.shape}" - ) - n_token = selected.sum() - tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds[:n_token] - del vit_embeds - return tok_embeds.reshape(bs, seq_len, hid_size) - else: - raise ValueError(f"VLM model type {self.config.model_type} not supported") - def _base_model_forward( self, input_ids, @@ -761,6 +774,7 @@ def _base_model_forward( **kwargs, ) past_key_values = getattr(outputs, "past_key_values", None) + base_input_embeds = outputs.hidden_states[0] base_model_hidden_states = outputs.hidden_states[-1] base_model_logits = outputs.logits @@ -772,9 +786,16 @@ def _base_model_forward( labels = labels.view(-1) base_model_loss = loss_fct(loss_logits, labels) - return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values + return EagleBaseModelOutput( + input_embeds=base_input_embeds, + aux_hiddens=self.pop_and_gather_aux_hiddens(), + out_hiddens=base_model_hidden_states, + logits=base_model_logits, + loss=base_model_loss, + ), past_key_values def _map_logits_to_draft_vocab(self, full_logits): + assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" reverse_mapping = ( torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device) + self.eagle_module.d2t @@ -831,125 +852,95 @@ def forward( """Forward pass of the EagleModel. Returns: - hidden_states: The hidden state from the base model. - logits: logits from the base model. - eagle_hidden_states: The hidden state from eagle_module. - eagle_logits: logits from the eagle_module. + loss: Loss of base model or eagle model. + logits: Base model logits. + past_key_values: Base model past key values with eagle cache attached. + hidden_states: Base model hidden states. + train_acc: Drafter training accuracies. """ - if past_key_values is not None and hasattr(past_key_values, "eagle_cache"): - eagle_cache = past_key_values.eagle_cache - else: - eagle_cache = None + eagle_cache = getattr(past_key_values, "eagle_cache", None) if self.training: - assert eagle_cache is None, "eagle_cache should be None in training" assert past_key_values is None, "past_key_values should be None in training" if loss_mask is None: - loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) + # By default, mask out padding tokens in loss computation + loss_mask = ( + attention_mask.clone().detach() + if attention_mask is not None + else torch.ones_like(input_ids, dtype=torch.bool) + ) - # ====First, we run base model forward==== - if "base_model_outputs" in kwargs: + # ====First, run base model forward==== + if self.eagle_offline: # Parse base model outputs forwarded from teacher - base_outputs = kwargs["base_model_outputs"] - base_model_hidden_states = base_outputs["base_model_hidden_states"] - if "base_model_logits" in base_outputs: - base_model_logits = base_outputs["base_model_logits"] - else: - base_model_logits = self.lm_head(base_model_hidden_states) - base_model_loss = None - past_key_values = DynamicCache() # Dummy cache - + assert "base_model_outputs" in kwargs + base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) + if base_outputs.logits is None: + base_outputs.logits = self.lm_head(base_outputs.out_hiddens) + past_key_values = None else: - base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = ( - self._base_model_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - self.eagle_freeze_base_model, - labels, - **kwargs, - ) + base_outputs, past_key_values = self._base_model_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + self.eagle_freeze_base_model, + labels, + **kwargs, ) if not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values = _get_empty_cache(self._base_llm_config) if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + eagle_cache = _get_empty_cache(self.eagle_module.config) + past_key_values.eagle_cache = eagle_cache - # ====Run eagle forward==== + # ====Prepare inputs for the first eagle forward pass==== eagle_loss = None train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)] - # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers - b, seq_length, h = base_model_hidden_states.shape - if self.eagle_config.use_aux_hidden_state: - if "base_model_outputs" in kwargs: - aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"] - else: - aux_hidden_states = torch.cat(self.pop_and_gather_aux_hiddens(), dim=-1) - eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states) - else: - eagle_input_hidden_states = base_model_hidden_states - - # Get eagle inputs for the first eagle forward pass - eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs( + b, seq_length, _ = base_outputs.out_hiddens.shape + ( + eagle_input_embeds, + eagle_input_hiddens, + eagle_attn_mask_0, + eagle_position_ids, + ) = self._prepare_eagle_inputs( input_ids, - eagle_input_hidden_states, attention_mask, position_ids, eagle_cache, + base_outputs, ) - with torch.no_grad(): - inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs) - - past_key_values.eagle_cache = eagle_cache - # ====Perform training-time-testing with 3 extra eagle forward passes==== + # ====Run eagle forward with extra training-time-test steps==== for ttt_step in range(self.num_ttt_steps): - attention_mask = ( - attention_mask_0 + # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. + eagle_attention_mask = ( + eagle_attn_mask_0 if ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) - _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states, - inputs_embeds, - attention_mask, - position_ids, - eagle_cache, - ) - eagle_input_hidden_states = torch.cat( - ( - torch.zeros( - (b, 1, h), - dtype=eagle_input_hidden_states.dtype, - device=eagle_input_hidden_states.device, - ), - eagle_input_hidden_states[:, :-1, :], - ), - dim=1, - ) + with enable_cp_ttt_patch() if self.training else contextlib.nullcontext(): + _, eagle_input_hiddens, eagle_logits, eagle_cache = self._eagle_forward( + eagle_input_hiddens, + eagle_input_embeds, + eagle_attention_mask, + eagle_position_ids, + eagle_cache, + ) + eagle_input_hiddens = eagle_input_hiddens.roll(1, 1) for i in range(self.eagle_config.parallel_draft_step): eagle_logit = eagle_logits[i] classification_loss, acc = self._eagle_loss( # base model predict +1 tok, while eagle predict +2 # so we shift base model outputs compared to eagle outputs - base_model_logits[:, 1 + i :], - eagle_logit[:, : -(1 + i)], # additionally, we mask the first n tok of eagle outputs at nth TTT step - torch.cat( - ( - torch.zeros( - b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device - ), - loss_mask[:, 1 + ttt_step :] - if i == 0 - else loss_mask[:, 1 + ttt_step : -i], - ), - dim=1, - ), + base_outputs.logits[:, 1 + i + ttt_step :], + eagle_logit[:, ttt_step : -(1 + i)], + loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], ) + # Apply loss decay factor to focus on early steps classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i) eagle_loss = ( classification_loss if eagle_loss is None else eagle_loss + classification_loss @@ -957,24 +948,19 @@ def forward( train_accs[i].append(acc) if not self.training: break - # Finally, we merge base model loss and eagle loss, raise error if both are None - if base_model_loss is not None and eagle_loss is not None: - loss = base_model_loss + eagle_loss - elif base_model_loss is not None: - loss = base_model_loss - elif eagle_loss is not None: - loss = eagle_loss - else: + + # Merge base model loss and eagle loss + if base_outputs.loss is None and eagle_loss is None: loss = None - assert not self.training, ValueError( - "Both base_model_loss and eagle_loss are skipped. At least one loss must be computed." - ) + assert not self.training, "At least one loss must be computed for training." + else: + loss = (base_outputs.loss or 0) + (eagle_loss or 0) return ModelOutput( loss=loss, - logits=base_model_logits, + logits=base_outputs.logits, past_key_values=past_key_values, - hidden_states=base_model_hidden_states, + hidden_states=base_outputs.out_hiddens, train_acc=train_accs, ) @@ -986,9 +972,8 @@ def _eagle_loss( ): """Function for EAGLE loss computing.""" if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) - loss_mask = loss_mask[:, :, None] + loss_mask = loss_mask[:, : eagle_logits.shape[1], None] classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( eagle_logits ) @@ -1038,34 +1023,28 @@ def pseudo_speculative_generate( # EAGLE-3 # Only the first iteration input_hidden_states are from aux_hidden_state layers # Gather _aux_hidden_states from all devices before concatenation - gathered_aux_hidden_states = self.pop_and_gather_aux_hiddens() - eagle_input_hidden_states = self.eagle_module.fc( - torch.cat(gathered_aux_hidden_states, dim=-1) - ) - + eagle_input_hidden_states = self.eagle_module.fc(self.pop_and_gather_aux_hiddens()) else: eagle_input_hidden_states = base_model_hidden_states draft_tokens = [] for step in range(steps): - # Get eagle inputs for the first eagle forward pass - _, eagle_attention_mask, eagle_position_ids = self._get_eagle_module_inputs( - input_ids, - eagle_input_hidden_states, - None, - None, + b, seq_length = eagle_ids.shape + eagle_attention_mask = self._prepare_decoder_attention_mask( None, + (b, seq_length), + 0, + eagle_input_hidden_states.device, + eagle_input_hidden_states.dtype, ) # Use SDPA attention during generation for both stability and performance - with temporary_set_config_value( - self.eagle_module.config, "_attn_implementation", "sdpa" - ): + with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"): _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward( eagle_input_hidden_states, self._base_model_embeddings(eagle_ids), eagle_attention_mask, - eagle_position_ids, + None, ) # parallel logits are only used after the last step diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 1f919de065..e345386653 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -25,8 +25,10 @@ import torch import torch.distributed +import transformers from huggingface_hub import snapshot_download from torch import nn +from torch.nn.attention import SDPBackend, sdpa_kernel from transformers.cache_utils import DynamicCache KIMI_K2_REPO_ID = "moonshotai/Kimi-K2-Thinking" @@ -41,6 +43,9 @@ def calibrate_frequent_vocab(tokenizer, text, target_vocab_size, output_file=None): """Given a calibration text, find the most common vocabs and return the mapping.""" conversations = tokenizer.apply_chat_template(text) + # Transformers5.x returns a BatchEncoding from apply_chat_template + if hasattr(conversations, "input_ids"): + conversations = conversations.input_ids counter = Counter(conversations) vocab = counter.most_common(target_vocab_size) mapping = torch.zeros(target_vocab_size, dtype=torch.int64) @@ -439,3 +444,87 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init return getattr(kimi_k2_module, "DeepseekV3DecoderLayer") + + +def get_ttt_msk_func(seq_length, ttt_step): + """Return mask function for Eagle3 Training Time Test.""" + + def ttt_msk_func(b, h, q_idx, kv_idx): + mask = kv_idx <= (q_idx - ttt_step) + for i in range(1, ttt_step + 1): + mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & ( + kv_idx >= seq_length * i + ) + mask = mask | mask_block_i + return mask + + return ttt_msk_func + + +@contextlib.contextmanager +def enable_cp_ttt_patch(): + """Context manager to enable CP TTT patch.""" + import modelopt.torch.speculative.plugins.transformers + + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + try: + yield + finally: + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False + + +def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs): + """Load a VLM or LLM with kwargs. Returns the model and model config.""" + model_config = transformers.AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True + ) + if "vl" in model_config.model_type.lower(): + model_cls = transformers.AutoModelForVision2Seq + else: + model_cls = transformers.AutoModelForCausalLM + + return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs) + + +@contextlib.contextmanager +def patch_transformers5_params_loading(): + """Patch transformers 5.x parameter loading to preserve original `requires_grad` settings. + + In transformers v5.x, loading a checkpoint forcibly sets parameters' requires_grad, + which may unintentionally unfreeze frozen parameters. This monkey-patch restores the original + `requires_grad` after loading parameters. + + Reference: + https://github.com/huggingface/transformers/blob/v5.0.0.rc1-release/src/transformers/core_model_loading.py#L640 + """ + # Skip patching for non-applicable transformers version + if importlib.util.find_spec("transformers.core_model_loading") is None: + return + from transformers import core_model_loading + + if not hasattr(core_model_loading, "set_param_for_module"): + return + + orig_set_param_for_module = core_model_loading.set_param_for_module + + def patched_set_param_for_module(*args, **kwargs): + """Monkey-patch set_param_for_module to restore original requires_grad.""" + model, target_name = args[:2] + module_path, _, param_name = target_name.rpartition(".") + module_obj = model.get_submodule(module_path) if module_path else model + + # Get original requires_grad value + orig_requires_grad = getattr(module_obj, param_name).requires_grad + + # Call set_param_for_module + orig_set_param_for_module(*args, **kwargs) + + # Restore original requires_grad value + getattr(module_obj, param_name).requires_grad = orig_requires_grad + + try: + core_model_loading.set_param_for_module = patched_set_param_for_module + yield + finally: + core_model_loading.set_param_for_module = orig_set_param_for_module diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 7908ec5146..16bff49c2e 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -57,6 +57,7 @@ "split": ["stem", "chat", "math", "code"], }, "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["messages"]), + "chat_key": "messages", }, "nemotron-post-training-dataset-v1": { "config": { @@ -64,6 +65,7 @@ "split": ["stem", "chat", "math", "code", "tool_calling"], }, "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["messages"]), + "chat_key": "messages", }, "magpie": { "config": { @@ -71,9 +73,10 @@ "split": ["train"], }, "preprocess": lambda sample: "\n".join(turn["value"] for turn in sample["conversations"]), + "chat_key": "conversations", }, "cnn_dailymail": { - "config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]}, + "config": {"path": "abisee/cnn_dailymail", "name": "3.0.0", "split": ["train"]}, "preprocess": lambda sample: sample["article"], }, "pile": { @@ -92,22 +95,36 @@ "config": {"path": "c4", "name": "en", "split": ["train"]}, "preprocess": lambda sample: sample["text"], }, + "wikitext": { + "config": {"path": "wikitext", "name": "wikitext-103-v1", "split": ["train"]}, + "preprocess": lambda sample: sample["text"], + }, } __all__ = [ "create_forward_loop", "get_dataset_dataloader", + "get_dataset_samples", "get_max_batch_size", "get_supported_datasets", ] -def _get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]: +def get_dataset_samples( + dataset_name: str, + num_samples: int, + *, + apply_chat_template: bool = False, + tokenizer: "PreTrainedTokenizerBase | None" = None, +) -> list[str]: """Load a portion of train dataset with the dataset name and a given size. Args: dataset_name: Name of the dataset to load. num_samples: Number of samples to load from the dataset. + apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). + tokenizer: Tokenizer to use for applying the chat template to the samples. + No tokenization is done and plain text is still returned. Returns: Samples: The list of samples. @@ -122,6 +139,15 @@ def _get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]: from datasets import load_dataset dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] + if apply_chat_template: + if "chat_key" not in dataset_config: + warn( + f"Dataset {dataset_name} does not support chat template. Chat template will not be applied." + ) + elif tokenizer is None: + raise ValueError("Tokenizer is required when applying chat template.") + print(f"Applying chat template to dataset {dataset_name}") + # It's unfortunate that the load_dataset function does not support split a list while streaming. # So we need to load the dataset for each split. config = dataset_config["config"].copy() @@ -147,7 +173,14 @@ def _get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]: break # Apply preprocess function to the sample - samples.append(dataset_config["preprocess"](sample)) + if apply_chat_template and "chat_key" in dataset_config: + sample = tokenizer.apply_chat_template( # type: ignore[union-attr] + sample[dataset_config["chat_key"]], tokenize=False + ) + else: + sample = dataset_config["preprocess"](sample) + if sample != "": # wikitext has some empty samples + samples.append(sample) return samples @@ -211,7 +244,7 @@ def get_dataset_dataloader( all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): - samples = _get_dataset_samples(ds_name, num_sample) + samples = get_dataset_samples(ds_name, num_sample) all_samples.extend(samples) batch_encoded = tokenizer.batch_encode_plus( diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 9b32d1ac46..7922b68805 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -23,6 +23,7 @@ from contextlib import suppress from datetime import timedelta from typing import Any +from warnings import warn import torch import torch.distributed @@ -76,7 +77,8 @@ def local_rank() -> int: """Returns the local rank of the current process.""" if "LOCAL_RANK" in os.environ: return int(os.environ["LOCAL_RANK"]) - raise RuntimeError("LOCAL_RANK environment variable not found.") + warn("LOCAL_RANK environment variable not found. Using global rank instead.") + return rank() def is_master(group=None) -> bool: @@ -101,6 +103,7 @@ def _deserialize(tensor: torch.Tensor, size: int | None = None) -> Any: buffer = tensor.numpy().tobytes() if size is not None: buffer = buffer[:size] + # Security NOTE: weights_only=False is used here on internally-generated buffer, not on untrusted user input obj = torch.load(io.BytesIO(buffer), weights_only=False) return obj diff --git a/modelopt/torch/utils/import_utils.py b/modelopt/torch/utils/import_utils.py index 4193bb2b34..8229da51e5 100644 --- a/modelopt/torch/utils/import_utils.py +++ b/modelopt/torch/utils/import_utils.py @@ -33,6 +33,6 @@ def import_plugin(plugin_name, msg_if_missing=None, verbose=True, success_msg=No except Exception as e: if verbose: warn_rank_0( - f"Failed to import {plugin_name} plugin due to: {e!r}. " + f"Failed to import modelopt {plugin_name} plugin due to: {e!r}. " "You may ignore this warning if you do not need this plugin." ) diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index b8e7aecce1..ada1b53612 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -46,7 +46,7 @@ def num2hrb(num: float, suffix="") -> str: """Convert big floating number to human readable string.""" step = 1000 # step between units is 1000 - units = ["", "K", "M", "G", "T", "P", "E"] + units = ["", "K", "M", "B", "T", "P", "E"] while abs(num) >= step and len(units) > 1: num /= step units.pop(0) diff --git a/modelopt/torch/utils/nemotron_vlm_dataset_utils.py b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py new file mode 100644 index 0000000000..fb2b085935 --- /dev/null +++ b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Nemotron VLM dataset utilities. + +This module contains the Nemotron-VLM-Dataset-v2 specific logic: +- Subsets can store images in `media/shard_*.tar` (images only) +- Prompts/messages live in `/.jsonl` and reference the image filename (e.g. `292180.png`) + +We join the tar images with the JSONL messages by the shared filename and yield samples compatible with our +VLM calibration pipeline. +""" + +from __future__ import annotations + +import functools +import json +import os +import random +import tarfile +from io import BytesIO +from typing import Any + +import torch + +_IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} + + +@functools.lru_cache(maxsize=8) +def list_repo_files_cached(repo_id: str, repo_type: str = "dataset") -> list[str]: + """List files in a HuggingFace repo (cached). + + Args: + repo_id: HF repo id (e.g., a dataset repo). + repo_type: HF repo type, usually "dataset" here. + """ + from huggingface_hub import list_repo_files + + return list_repo_files(repo_id=repo_id, repo_type=repo_type) + + +def extract_first_image_from_messages(messages: Any) -> Any: + """Best-effort extraction of an image reference from Nemotron-style `messages`.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + for key in ("image", "images", "path", "image_url", "url", "value", "data"): + if key in part: + return part[key] + return None + + +class NemotronTarPlusJsonlIterable(torch.utils.data.IterableDataset): + """Join Nemotron VLM `media/shard_*.tar` (images-only) with `/.jsonl` (messages).""" + + def __init__( + self, + repo_id: str, + subsets: list[str], + shard_paths: list[str], + num_samples: int, + seed: int, + shuffle_buffer_size: int, + max_shards: int | None, + ): + """Create an iterable dataset for Nemotron-VLM-Dataset-v2. + + Args: + repo_id: Dataset repo id, e.g. "nvidia/Nemotron-VLM-Dataset-v2". + subsets: Subset names to draw from (e.g., "sparsetables"). + shard_paths: Tar shard paths under `/media/`. + num_samples: Total number of samples to yield. + seed: RNG seed for sampling. + shuffle_buffer_size: Unused for now (kept for API compatibility). + max_shards: Max number of shards to use per subset (limits downloads). + """ + super().__init__() + self.repo_id = repo_id + self.subsets = subsets + self.shard_paths = shard_paths + self.num_samples = num_samples + self.seed = seed + self.shuffle_buffer_size = shuffle_buffer_size + self.max_shards = max_shards + + def __iter__(self): + from huggingface_hub import hf_hub_download + from PIL import Image + + rng = random.Random(self.seed) + + # Partition shards by subset. + shards_by_subset: dict[str, list[str]] = {s: [] for s in self.subsets} + for p in self.shard_paths: + subset = p.split("/", 1)[0] + if subset in shards_by_subset: + shards_by_subset[subset].append(p) + + for subset in list(shards_by_subset.keys()): + shard_list = sorted(shards_by_subset[subset]) + if self.max_shards is not None: + shard_list = shard_list[: max(0, self.max_shards)] + shards_by_subset[subset] = shard_list + + # Roughly split sample budget across subsets. + per_subset_target = max(1, self.num_samples // max(1, len(self.subsets))) + yielded_total = 0 + + for subset in self.subsets: + if yielded_total >= self.num_samples: + break + + shard_list = list(shards_by_subset.get(subset, [])) + if not shard_list: + continue + rng.shuffle(shard_list) + local_tar_paths = { + shard: hf_hub_download(repo_id=self.repo_id, filename=shard, repo_type="dataset") + for shard in shard_list + } + + # 1) Collect candidate image filenames from tar headers (no payload reads). + candidate_names: list[str] = [] + header_limit = per_subset_target * 50 + for shard in shard_list: + local_tar = local_tar_paths[shard] + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if not member.isfile(): + continue + name = member.name + _, ext = os.path.splitext(name) + if ext.lower() not in _IMG_EXTS: + continue + candidate_names.append(name) + if len(candidate_names) >= header_limit: + break + if len(candidate_names) >= header_limit: + break + + if not candidate_names: + continue + + rng.shuffle(candidate_names) + lookup_limit = per_subset_target * 10 + candidate_set = set(candidate_names[:lookup_limit]) + + # 2) Scan JSONL to map image filename -> messages. + jsonl_path = hf_hub_download( + repo_id=self.repo_id, filename=f"{subset}/{subset}.jsonl", repo_type="dataset" + ) + meta_by_image: dict[str, dict[str, Any]] = {} + with open(jsonl_path, encoding="utf-8") as f: + for line in f: + try: + obj = json.loads(line) + except Exception: + continue + msgs = obj.get("messages") + img_name = extract_first_image_from_messages(msgs) if msgs is not None else None + if isinstance(img_name, str) and img_name in candidate_set: + meta_by_image[img_name] = {"id": obj.get("id"), "messages": msgs} + if len(meta_by_image) >= per_subset_target: + break + + if not meta_by_image: + continue + + # 3) Extract matched images and yield samples. + needed = set(meta_by_image.keys()) + for shard in shard_list: + if yielded_total >= self.num_samples or not needed: + break + local_tar = local_tar_paths[shard] + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if yielded_total >= self.num_samples or not needed: + break + if not member.isfile(): + continue + name = member.name + if name not in needed: + continue + f = tf.extractfile(member) + if f is None: + continue + try: + raw = f.read() + if isinstance(raw, str): + raw = raw.encode() + raw_bytes: bytes = raw + img = Image.open(BytesIO(raw_bytes)).convert("RGB") + except Exception: + continue + meta = meta_by_image.get(name) + if not meta: + continue + yield { + "id": meta.get("id", name), + "messages": meta.get("messages"), + "image": img, + } + needed.discard(name) + yielded_total += 1 diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py index 1940295c3c..b54332375b 100644 --- a/modelopt/torch/utils/network.py +++ b/modelopt/torch/utils/network.py @@ -26,6 +26,10 @@ import torch import torch.distributed.fsdp import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm +from tqdm import tqdm + +from .tensor import torch_to try: from torch.distributed.fsdp._state_dict_utils import _convert_to_wrapped_module_name @@ -38,11 +42,6 @@ def _convert_to_wrapped_module_name(name: str) -> str: return name -from torch.nn.modules.batchnorm import _BatchNorm -from tqdm import tqdm - -from .tensor import torch_to - __all__ = [ "ModelLike", "compare_dict", @@ -56,7 +55,6 @@ def _convert_to_wrapped_module_name(name: str) -> str: "is_parallel", "make_divisible", "model_to", - "param_num", "param_num_from_forward", "remove_bn", "run_forward_loop", @@ -101,28 +99,6 @@ def get_module_device(module: nn.Module) -> torch.device: return torch.device("cpu") -def param_num(network: nn.Module, trainable_only: bool = False, unit=1e6) -> float: - """Get the number of parameters of a PyTorch model. - - Args: - network: The PyTorch model. - trainable_only: Whether to only count trainable parameters. Default is False. - unit: The unit to return the number of parameters in. Default is 1e6 (million). - - Returns: - The number of parameters in the model in the given unit. - """ - return ( - sum( - p.numel() if not trainable_only or p.requires_grad else 0 - for mod in network.modules() - for p in mod.parameters(recurse=False) - if not isinstance(mod, _BatchNorm) - ) - / unit - ) - - # TODO: we could also use the same approach as in inference_flops to get the number of params, # which might be more accurate. Another approach could be to run a backwards pass and use a hook # on the tensor directly. @@ -142,7 +118,7 @@ def param_num_from_forward( Returns: The number of parameters from the model's forward pass in the given unit. - This can helpful for dynamic modules, where the state dict might contain extra parameters that + This can helpful for MoE or dynamic modules, where the state dict might contain extra parameters that is not actively used in the model, e.g., because of a DynamicModule that is deactivated for the forward pass. We circumvent this issue by just counting parameters of modules that appear in a forward pass. diff --git a/modelopt/torch/utils/perf.py b/modelopt/torch/utils/perf.py index cd2652f947..dd4bb7ff52 100644 --- a/modelopt/torch/utils/perf.py +++ b/modelopt/torch/utils/perf.py @@ -28,6 +28,7 @@ "Timer", "clear_cuda_cache", "get_cuda_memory_stats", + "get_used_gpu_mem_fraction", "report_memory", ] @@ -48,6 +49,23 @@ def get_cuda_memory_stats(device=None): } +def get_used_gpu_mem_fraction(device="cuda:0"): + """Get used GPU memory as a fraction of total memory. + + Args: + device: Device identifier (default: "cuda:0") + + Returns: + float: Fraction of GPU memory currently used (0.0 to 1.0). + Returns 0.0 if CUDA is not available. + """ + if not torch.cuda.is_available(): + return 0.0 + + free_memory, total_memory = torch.cuda.mem_get_info(device) + return (total_memory - free_memory) / total_memory + + def report_memory(name="", rank=0, device=None): """Simple GPU memory report.""" memory_stats = get_cuda_memory_stats(device) diff --git a/modelopt/torch/utils/plugins/__init__.py b/modelopt/torch/utils/plugins/__init__.py index 517c599143..fd00e423f0 100644 --- a/modelopt/torch/utils/plugins/__init__.py +++ b/modelopt/torch/utils/plugins/__init__.py @@ -25,3 +25,8 @@ with import_plugin("megatron_preprocess_data"): from .megatron_preprocess_data import * + +# NOTE: Dont pre-import megatron bridge plugin here to avoid circular dependency issues. +# We dont register anything so this isnt a problem. +# with import_plugin("megatron bridge"): +# from .mbridge import * diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py new file mode 100644 index 0000000000..94cdf87cf5 --- /dev/null +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Megatron-Bridge plugins for using with Model-Optimizer.""" + +from collections.abc import Callable +from typing import Any + +import torch.nn as nn +from datasets import DatasetDict +from megatron.bridge import AutoBridge +from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig +from megatron.bridge.data.loaders import setup_data_iterators +from megatron.bridge.data.utils import get_dataset_provider +from megatron.bridge.models.gpt_provider import GPTModelProvider, modelopt_transformer_layer_spec +from megatron.bridge.models.hf_pretrained.utils import is_safe_repo +from megatron.bridge.models.mamba.mamba_provider import ( + MambaModelProvider, + modelopt_mamba_stack_spec, +) +from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + OptimizerConfig, + SchedulerConfig, + TrainingConfig, + runtime_config_update, +) +from megatron.bridge.training.eval import evaluate_and_print_results +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.tokenizers.config import TokenizerConfig +from megatron.core.models.gpt import GPTModel +from megatron.core.models.mamba import MambaModel +from megatron.core.parallel_state import get_data_parallel_group +from megatron.core.transformer.module import MegatronModule +from megatron.core.utils import unwrap_model +from transformers import AutoTokenizer + +from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0 + +__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"] + + +def load_mbridge_model_from_hf( + *, + hf_model_name_or_path: str, + trust_remote_code: bool = False, + provider_overrides: dict[str, Any] | None = None, + init_model_parallel: bool = True, +) -> tuple[ + AutoBridge, + GPTModelProvider | MambaModelProvider, + list[MegatronModule], + GPTModel | MambaModel, + AutoTokenizer, +]: + """Load a Megatron-Bridge model from HF. + + Args: + hf_model_name_or_path: The name or path of the HF model. + trust_remote_code: Whether to trust remote code. + provider_overrides: Overrides for the provider. + init_model_parallel: Whether to initialize model parallel. + + Returns: + A tuple of (bridge, provider, model, unwrapped_model, tokenizer). + """ + print_rank_0(f"Loading Megatron-Bridge model from HF: {hf_model_name_or_path}") + trust_remote_code = is_safe_repo( + trust_remote_code=trust_remote_code, + hf_path=hf_model_name_or_path, + ) + bridge = AutoBridge.from_hf_pretrained( + hf_model_name_or_path, trust_remote_code=trust_remote_code + ) + + provider = bridge.to_megatron_provider() + if provider_overrides: + for key, value in provider_overrides.items(): + assert hasattr(provider, key), f"{type(provider)} does not have attribute {key}" + setattr(provider, key, value) + + print_rank_0("Setting ModelOpt spec for model provider") + if isinstance(provider, MambaModelProvider): + provider.mamba_stack_spec = modelopt_mamba_stack_spec + else: + provider.transformer_layer_spec = modelopt_transformer_layer_spec + + provider.finalize() + if init_model_parallel: + provider.initialize_model_parallel(seed=0) + + model = provider.provide_distributed_model(wrap_with_ddp=False) + assert len(model) == 1 + unwrapped_model = unwrap_model(model[0]) + assert isinstance(unwrapped_model, (GPTModel, MambaModel)) + + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name_or_path, trust_remote_code=trust_remote_code + ) + + return bridge, provider, model, unwrapped_model, tokenizer + + +def _get_dataset_cfg( + dataset_name: str, + num_samples: int, + seq_length: int, + apply_chat_template: bool = True, + tokenizer: AutoTokenizer | None = None, +) -> HFDatasetConfig: + """Get a dataset config for the dataset.""" + dataset = get_dataset_samples( + dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer + ) + dataset_cfg = HFDatasetConfig( + dataset_name=f"{dataset_name}_{num_samples}", + dataset_dict=DatasetDict({"train": dataset}), + process_example_fn=lambda example, tokenizer: {"input": example, "output": ""}, + seq_length=seq_length, + dataloader_type="batch", + num_workers=1, + do_validation=False, + do_test=False, + val_proportion=None, + split_val_from_train=False, + rewrite=True, + ) + + return dataset_cfg + + +def get_hf_mbridge_calibration_loop( + *, + model: list[MegatronModule], + provider: GPTModelProvider | MambaModelProvider, + tokenizer: AutoTokenizer, + hf_model_name_or_path: str, + trust_remote_code: bool = False, + dataset_name: str = "nemotron-post-training-dataset-v2", + num_samples: int = 512, + micro_batch_size: int = 1, + global_batch_size: int = 1, +) -> Callable[[nn.Module], None]: + """Get a modelopt calibration loop for a Megatron-Bridge model. + + Args: + model: The model to calibrate. + provider: The provider to use for the model. + tokenizer: The tokenizer to use for the model. + hf_model_name_or_path: The name or path of the HF model. + trust_remote_code: Whether to trust remote code. + dataset_name: The name of the dataset to use for evaluation. + num_samples: The number of samples to use for evaluation. + micro_batch_size: The micro batch size to use for evaluation. + global_batch_size: The global batch size to use for evaluation. + + Returns: + A function that can be used to calibrate the model with a modelopt.torch API. + """ + if global_batch_size < micro_batch_size: + warn_rank_0( + f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}." + ) + global_batch_size = micro_batch_size + num_iters = num_samples // global_batch_size + + # NOTE: Issue with NemotronH tokenizer's len() hence using use_fast=True as a WAR + use_fast_tokenizer = isinstance(provider, NemotronHModelProvider) + + cfg = ConfigContainer( + model=provider, + train=TrainingConfig( + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + train_iters=num_iters, + eval_iters=num_iters, + skip_train=True, + ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True), + dataset=_get_dataset_cfg( + dataset_name, + num_samples, + provider.seq_length, + apply_chat_template=True, + tokenizer=tokenizer, + ), + tokenizer=TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model=hf_model_name_or_path, + hf_tokenizer_kwargs={ + "trust_remote_code": trust_remote_code, + "use_fast": use_fast_tokenizer, + }, + ), + # Unused + optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False), + scheduler=SchedulerConfig(lr_decay_style="constant"), + logger=LoggerConfig(), + checkpoint=CheckpointConfig(), + ) + runtime_config_update(cfg) + + state = GlobalState() + state.cfg = cfg + + dataset_provider = get_dataset_provider(cfg.dataset) + + def _train_valid_test_datasets_provider( + train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig + ): + return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer) + + train_data_iterator, _, _ = setup_data_iterators( + cfg=cfg, + train_state=state.train_state, + model_length=len(model), + train_valid_test_datasets_provider=_train_valid_test_datasets_provider, + dp_group=get_data_parallel_group(), + ) + + def forward_loop(m): + evaluate_and_print_results( + state, + prefix="iteration 1", + forward_step_func=forward_step, + data_iterator=train_data_iterator, + model=model, + config=cfg, + verbose=True, + write_to_tensorboard=False, + ) + + return forward_loop diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py index d542d935ac..63b66f360a 100644 --- a/modelopt/torch/utils/plugins/megatron_generate.py +++ b/modelopt/torch/utils/plugins/megatron_generate.py @@ -24,6 +24,8 @@ from megatron.core.transformer import MegatronModule from tqdm import tqdm +__all__ = ["megatron_generate", "megatron_prefill"] + def get_current_memory_info(): """Get current memory usage.""" @@ -208,13 +210,48 @@ def _forward_step_func(data, model): # NOTE: we don't support traditional positional embedding. Only RoPE or YaRN are supported. position_ids = None - output_tensor = model( - data["tokens"], - position_ids, - attention_mask, - inference_context=inference_context, - runtime_gather_output=True, - ) + # Check if this is a VLM model (has vision inputs) + _has_pixel_values = data.get("pixel_values") is not None + _has_image_grid_thw = data.get("image_grid_thw") is not None + _has_image_sizes = data.get("image_sizes") is not None + has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes + + if has_vision_inputs: + # For VLM models: + # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions) + # - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal) + vlm_position_ids = ( + torch.arange(seq_len, dtype=torch.long, device=device) + .unsqueeze(0) + .expand(batch_size, -1) + ) + vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device) + + forward_args = { + "input_ids": data["tokens"], + "position_ids": vlm_position_ids, + "attention_mask": vlm_attention_mask, + "inference_context": inference_context, + "runtime_gather_output": True, + } + # Add vision inputs + if _has_pixel_values: + forward_args["pixel_values"] = data["pixel_values"] + if _has_image_grid_thw: + forward_args["image_grid_thw"] = data["image_grid_thw"] + if _has_image_sizes: + forward_args["image_sizes"] = data["image_sizes"] + + output_tensor = model(**forward_args) + else: + # For text-only LLM models + output_tensor = model( + data["tokens"], + position_ids, + attention_mask, + inference_context=inference_context, + runtime_gather_output=True, + ) return output_tensor, _dummy_loss_func disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0 @@ -248,9 +285,18 @@ def _forward_step_func(data, model): else: tokens = input_ids + data_dict = {"tokens": tokens} + # Vision inputs should only be passed during prefill (step 0), not during decode steps + if pixel_values is not None: + data_dict["pixel_values"] = pixel_values + if image_grid_thw is not None: + data_dict["image_grid_thw"] = image_grid_thw + if image_sizes is not None: + data_dict["image_sizes"] = image_sizes + list_of_logits = get_forward_backward_func()( forward_step_func=_forward_step_func, - data_iterator=[{"tokens": tokens}], + data_iterator=[data_dict], model=model, num_microbatches=1, seq_length=tokens.shape[-1], diff --git a/modelopt/torch/utils/plugins/megatron_mmlu.py b/modelopt/torch/utils/plugins/megatron_mmlu.py index 3b997268bf..b03d338c0b 100644 --- a/modelopt/torch/utils/plugins/megatron_mmlu.py +++ b/modelopt/torch/utils/plugins/megatron_mmlu.py @@ -47,6 +47,8 @@ from .megatron_generate import megatron_generate +__all__ = ["megatron_mmlu"] + def _get_all_subjects(): """All subjects (anatomy, ...) can be acquired from querying all subsets and splits.""" diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py index ac05e44f13..1c47a38dee 100644 --- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py +++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py @@ -17,31 +17,59 @@ """Processing large data to tokenize for pretraining. -Usage: +Usage to tokenize one or more JSONL files: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --jsonl_paths path/to/input/data1.jsonl path/to/input/data2.jsonl ... \ + --json_keys text \ + --output_dir /path/to/tokenized/Qwen3/ \ + --tokenizer Qwen/Qwen3-0.6B +``` + +Usage to tokenize all JSONL files in a directory: + +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --input_dir /path/to/input/data/ \ + --json_keys text \ + --output_dir /path/to/tokenized/Qwen3/ \ + --tokenizer Qwen/Qwen3-0.6B +``` -```python -from modelopt.torch.utils.plugins import megatron_preprocess_data +Usage to download and tokenize a dataset from Hugging Face Hub: -megatron_preprocess_data( - input_path="path/to/input/data", - output_dir="path/to/output/dir", - tokenizer_name_or_path="hf_model_name", - json_keys=["name of json key(s) to tokenize"], -) +```bash +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --hf_dataset nvidia/Nemotron-Pretraining-Dataset-sample \ + --hf_name Nemotron-SFT-Code \ + --hf_split train \ + --json_keys text \ + --tokenizer Qwen/Qwen3-0.6B \ + --output_dir /path/to/tokenized/Qwen3/ ``` + +NOTE: If you skip --hf_name, it will download and tokenize all subsets for the dataset. +If you skip --hf_split, it will download and tokenize all splits for the subset. """ import argparse import json import multiprocessing -import sys +import os from pathlib import Path +from warnings import warn import requests from datasets import load_dataset +from huggingface_hub.utils import build_hf_headers from megatron.core.datasets import indexed_dataset from transformers import AutoTokenizer +from modelopt.torch.utils import num2hrb + +__all__ = ["megatron_preprocess_data"] + class _Encoder: tokenizer: AutoTokenizer = None @@ -104,11 +132,13 @@ def __init__(self, vocab_size: int, json_keys: list[str], log_interval: int, wor self.log_interval = log_interval self.workers = workers - def _print_processing_stats(self, count: int, total_doc_len: int, total_enc_len: int): - if count % self.log_interval == 0: + def _print_processing_stats( + self, count: int, total_doc_len: int, total_enc_len: int, *, force_print: bool = False + ): + if count % self.log_interval == 0 or force_print: print( - f"Processed {count} documents, {total_doc_len} chars, {total_enc_len} tokens", - file=sys.stderr, + f"\tProcessed {num2hrb(count)} docs = {num2hrb(total_doc_len)} chars = {num2hrb(total_enc_len)} tokens", + flush=True, ) def process_json_file( @@ -116,7 +146,7 @@ def process_json_file( ): output_prefix = Path(output_dir) / Path(input_file_name).stem - print("Opening", input_file_name) + print(f"\nOpening {input_file_name}") fin = open(input_file_name, encoding="utf-8") pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer) @@ -138,7 +168,7 @@ def process_json_file( ) if not builders: - print(f"Output files corresponding to {input_file_name} already exist, skipping") + print(f"\t[SKIP] Output files corresponding to {input_file_name} already exist") return 0 total_doc_len, total_enc_len, final_enc_len = 0, 0, 0 @@ -149,6 +179,7 @@ def process_json_file( for key in doc: builders[key].add_document(doc[key], sentence_lens[key]) self._print_processing_stats(i, total_doc_len, total_enc_len) + self._print_processing_stats(i, total_doc_len, total_enc_len, force_print=True) fin.close() for key in builders: @@ -157,8 +188,92 @@ def process_json_file( return final_enc_len +def _download_hf_dataset( + dataset: str, + output_dir: str | Path, + json_keys: list[str], + name: str | None = None, + split: str | None = "train", + max_samples_per_split: int | None = None, +) -> list[str]: + """Download a Hugging Face dataset and save as JSONL files. + + Returns: + List of paths to downloaded JSONL files. + """ + print(f"Downloading dataset {dataset} from Hugging Face") + jsonl_paths: list[str] = [] + + try: + response = requests.get( + f"https://datasets-server.huggingface.co/splits?dataset={dataset}", + headers=build_hf_headers(), + timeout=10, + ) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError(f"Failed to fetch dataset splits for {dataset}: {e}") from e + + response_json = response.json() + print(f"\nFound {len(response_json['splits'])} total splits for {dataset}:") + for entry in response_json["splits"]: + print(f"\t{entry}") + + splits_to_process = [] + for entry in response_json["splits"]: + if name is not None and name != entry.get("config", None): + continue + if split is not None and split != entry["split"]: + continue + splits_to_process.append(entry) + + print(f"\nFound {len(splits_to_process)} splits to process:") + for entry in splits_to_process: + print(f"\t{entry}") + + for entry in splits_to_process: + skip_processing = False + path = entry["dataset"] + name = entry.get("config", None) + split = entry["split"] + if max_samples_per_split is not None: + split = f"{split}[:{max_samples_per_split}]" + jsonl_file_path = f"{output_dir}/raw/{path.replace('/', '--')}_{name}_{split}.jsonl" + + print(f"\nLoading HF dataset {path=}, {name=}, {split=}") + if os.path.exists(jsonl_file_path): + jsonl_paths.append(jsonl_file_path) + print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists") + continue + ds = load_dataset(path=path, name=name, split=split) + + for key in json_keys: + if key not in ds.features: + warn(f"[SKIP] {key=} not found in {ds.features=}") + skip_processing = True + break + + if skip_processing: + continue + + print(f"Saving raw dataset to {jsonl_file_path}") + ds.to_json(jsonl_file_path) + jsonl_paths.append(jsonl_file_path) + + print(f"\n\nTokenizing JSONL paths: {jsonl_paths}\n") + return jsonl_paths + + def megatron_preprocess_data( - input_path: str | Path | list[str] | list[Path], + *, + input_dir: str | Path | None = None, + jsonl_paths: str | Path | list[str] | list[Path] | None = None, + # Hugging Face Hub dataset arguments + hf_dataset: str | None = None, + hf_name: str | None = None, + hf_split: str | None = "train", + hf_max_samples_per_split: int | None = None, + # Other arguments output_dir: str | Path, tokenizer_name_or_path: str, json_keys: list[str] = ["text"], @@ -169,25 +284,48 @@ def megatron_preprocess_data( ): """Process large data for pretraining. + Exactly one of ``input_dir``, ``jsonl_paths``, or ``hf_dataset`` must be provided. + Args: - input_path (str | Path | list): Path to file or directory - containing input JSONL files, or list of paths to JSONL files - output_dir (str | Path): Path to directory to save binary output files - tokenizer_name_or_path (str): Name or path of the Hugging Face tokenizer to use - json_keys (list, optional): List of keys to extract from json. Defaults to ["text"] - append_eod (bool, optional): Append an token to the end of a document. Defaults to False - max_sequence_length (int, optional): Maximum tokenized sequence length. Defaults to None - workers (int, optional): Number of worker processes to launch. Defaults to 1 - log_interval (int, optional): Interval between progress updates. Defaults to 1000 + input_dir (str | Path, optional): Directory containing JSONL files to tokenize. + jsonl_paths (str | Path | list, optional): One or more paths to JSONL files. + hf_dataset (str, optional): Hugging Face Hub dataset name or path to download and tokenize. + hf_name (str, optional): Hugging Face Hub dataset subset name. Downloads all subsets if None. + hf_split (str, optional): Hugging Face Hub dataset split. Defaults to "train". + hf_max_samples_per_split (int, optional): Maximum number of samples to download per split from Hugging Face Hub. + Skip to download all samples. + output_dir (str | Path): Path to directory to save binary output files. + tokenizer_name_or_path (str): Name or path of the Hugging Face tokenizer to use. + json_keys (list, optional): List of keys to extract from json. Defaults to ["text"]. + append_eod (bool, optional): Append an token to the end of a document. Defaults to False. + max_sequence_length (int, optional): Maximum tokenized sequence length. Defaults to None. + workers (int, optional): Number of worker processes to launch. Defaults to 1. + log_interval (int, optional): Interval between progress updates. Defaults to 100000. """ - if isinstance(input_path, list): - file_names = input_path - elif Path(input_path).is_file(): - file_names = [input_path] - else: - file_names = sorted(Path(input_path).glob("*.jsonl")) + num_sources = sum(x is not None for x in (input_dir, jsonl_paths, hf_dataset)) + if num_sources != 1: + raise ValueError( + "Exactly one of `input_dir`, `jsonl_paths`, or `hf_dataset` must be provided." + ) + + if hf_dataset is not None: + jsonl_paths = _download_hf_dataset( + hf_dataset, + output_dir, + json_keys, + name=hf_name, + split=hf_split, + max_samples_per_split=hf_max_samples_per_split, + ) + + if input_dir is not None: + file_names = sorted(Path(input_dir).glob("*.jsonl")) if not file_names: - raise ValueError(f"No JSONL files found in input path: {input_path}") + raise ValueError(f"No JSONL files found in input directory: {input_dir}") + elif isinstance(jsonl_paths, (str, Path)): + file_names = [jsonl_paths] # type: ignore[list-item] + else: + file_names = list(jsonl_paths) # type: ignore[arg-type] Path(output_dir).mkdir(exist_ok=True) vocab_size = AutoTokenizer.from_pretrained(tokenizer_name_or_path).vocab_size @@ -200,32 +338,43 @@ def megatron_preprocess_data( num_tokens = partition.process_json_file(name, output_dir, encoder) final_enc_len += num_tokens - print(f">>> Total number of tokens: {final_enc_len}") + print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}") def main(): - """Sample main function to process large data for pretraining. - - Example usage: - - >>> python megatron_preprocess_data.py \ - --dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \ - --tokenizer "meta-llama/Llama-3.2-1B-Instruct" \ - --output_dir "./processed_data" - """ + """Sample main function to process large data for pretraining.""" parser = argparse.ArgumentParser(prog="megatron_preprocess_data") - parser.add_argument("--input_path", type=str, default=None, help="Input path.") + # Dataset arguments (pre-downloaded .jsonl files or download from Hugging Face Hub) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--input_dir", type=str, help="Directory containing JSONL files") + group.add_argument( + "--jsonl_paths", nargs="+", type=str, help="One or more paths to JSONL files" + ) + group.add_argument( + "--hf_dataset", + type=str, + help="Hugging Face Hub dataset path to download and tokenize", + ) parser.add_argument( - "--dataset", + "--hf_name", type=str, - default="nvidia/Nemotron-Pretraining-Dataset-sample", - help="Hugging Face Hub dataset name or path", + default=None, + help="Hugging Face Hub dataset subset name. Skip to download and tokenize all subsets for the dataset.", ) - parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset") - parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split") parser.add_argument( - "--output_dir", type=str, default="./processed_data", help="Output directory" + "--hf_split", + type=str, + default="train", + help="Hugging Face Hub dataset split. Skip to download and tokenize all splits for the subset.", ) + parser.add_argument( + "--hf_max_samples_per_split", + type=int, + default=None, + help="Maximum number of samples to download per split from Hugging Face Hub. Skip to download all samples.", + ) + # Other arguments + parser.add_argument("--output_dir", type=str, required=True, help="Output directory") parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path") parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize") parser.add_argument("--append_eod", action="store_true", help="Append token") @@ -233,51 +382,21 @@ def main(): "--max_sequence_length", type=int, default=None, help="Maximum sequence length" ) parser.add_argument("--workers", type=int, default=8, help="Number of worker processes") - parser.add_argument("--log_interval", type=int, default=1000, help="Log interval") + parser.add_argument("--log_interval", type=int, default=100000, help="Log interval") args = parser.parse_args() - if args.input_path is None: - args.input_path = [] - - try: - response = requests.get( - f"https://datasets-server.huggingface.co/splits?dataset={args.dataset}", - timeout=10, - ) - response.raise_for_status() - except requests.RequestException as e: - print(f"Failed to fetch dataset splits for {args.dataset}: {e}") - return - - for entry in response.json()["splits"]: - skip_processing = False - name = entry["dataset"] - subset = entry.get("config", None) - split = entry["split"] - - if args.subset is not None and args.subset != subset: - skip_processing = True - if args.split is not None and args.split != split: - skip_processing = True - - print(f"Loading dataset {name} with subset {subset} and split {split}") - dataset = load_dataset(name, subset, split=split) - - for key in args.json_keys: - if key not in dataset.features: - print(f"Key {key} not found in dataset features. Skipping...") - skip_processing = True - break - - if skip_processing: - continue - - json_file_path = args.output_dir + "/" + name + "_" + subset + "_" + split + ".jsonl" - dataset.to_json(json_file_path) - args.input_path += [json_file_path] + print("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print(f"{k:<35} {v}") + print("===================================================\n") megatron_preprocess_data( - input_path=args.input_path, + input_dir=args.input_dir, + jsonl_paths=args.jsonl_paths, + hf_dataset=args.hf_dataset, + hf_name=args.hf_name, + hf_split=args.hf_split, + hf_max_samples_per_split=args.hf_max_samples_per_split, output_dir=args.output_dir, tokenizer_name_or_path=args.tokenizer, json_keys=args.json_keys, diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py new file mode 100644 index 0000000000..e147ebf2c2 --- /dev/null +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Processing large data to tokenize for pretraining.""" + +import copy +import itertools +import os + +import torch +import transformers +from datasets import load_dataset +from transformers.trainer_pt_utils import LabelSmoother + +from modelopt.torch.utils import print_rank_0 + +REMOVE_THINK_CHAT_TEMPLATE = ( + "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" +) + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +def _sharegpt_to_openai_messages(conversations: list[dict]): + """Optionally align sharedgpt format to openai format.""" + role_mapping = { + "user": "user", + "User": "user", + "human": "user", + "assistant": "assistant", + "Assistant": "assistant", + "gpt": "assistant", + "system": "system", + "System": "system", + } + messages = [] + for msg in conversations: + role = role_mapping[msg["role"]] + content = msg["content"] + messages.append({"role": role, "content": content}) + return messages + + +class ShardedDataset(torch.utils.data.Dataset): + """Subclass of torch.utils.data.Dataset to load data from HuggingFace dataset.""" + + def __init__( + self, + name: str, + subset: str | None = None, + data_files: str | None = None, + split: str = "train", + num_shards: int = 1, + shard_index: int = 0, + num_streaming_samples: int | None = None, + ): + """Initialize the ShardedDataset.""" + self.name = name + self.subset = subset + self.split = split + self.data_files = data_files + self.num_shards = num_shards + self.shard_index = shard_index + self.num_streaming_samples = num_streaming_samples + + self._load_dataset() + + def __len__(self): + if self.num_streaming_samples is not None: + return self.num_streaming_samples + else: + return len(self._raw_samples) + + def __getitem__(self, index): + index = index // self.num_shards + + if self.num_streaming_samples is not None: + while index >= len(self._raw_samples): + self._raw_samples.append(next(self._stream_iterator)) + + return self._raw_samples[index] + + def _load_dataset(self): + dataset = load_dataset( + self.name, + self.subset, + data_files=self.data_files, + split=self.split, + # num_proc=4, # TODO: Make this configurable + streaming=self.num_streaming_samples is not None, + ) + + shard = dataset.shard(num_shards=self.num_shards, index=self.shard_index) + + if self.num_streaming_samples is not None: + self._raw_samples = [] + self._stream_samples = shard + self._stream_iterator = itertools.cycle(self._stream_samples) + else: + self._raw_samples = shard + + +class LanguageDataCollator: + """Data collator for language modeling tasks. + + Accepts samples in OpenAI or ShareGPT formats and returns + tokenized outputs with padding and truncation, including + input_ids and attention_mask. + """ + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizerBase, + train_len: int = 4096, + chat_template: str | None = None, + add_generation_prompt: bool = False, + answer_only_loss: bool = False, + json_key: str = "text", + return_labels: bool = False, + ): + """Initialize the LanguageDataset.""" + if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase): + raise ValueError( + "The tokenizer must be a transformers.PreTrainedTokenizerBase but got {}".format( + type(tokenizer) + ) + ) + self.tokenizer = tokenizer + self.train_len = train_len + self.add_generation_prompt = add_generation_prompt + self.answer_only_loss = answer_only_loss + self.json_key = json_key + self.return_labels = return_labels + + if chat_template is not None: + self.tokenizer.chat_template = chat_template + else: + self._post_process_chat_template() + + self._post_process_tokenizer() + if self.tokenizer.chat_template is None: + raise ValueError("No valid chat template!") + + def _post_process_tokenizer(self): + if self.tokenizer.pad_token_id is None: + print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.") + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + if hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is None: + if self.tokenizer.eos_token == "<|eot_id|>": # nosec + self.tokenizer.pad_token = "<|end_of_text|>" # nosec + else: + raise ValueError("The tokenizer has no pad_token!") + + def _post_process_chat_template(self): + # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the + # tokens are preserved for supervised learning. + self.tokenizer.chat_template = self.tokenizer.chat_template.replace( + REMOVE_THINK_CHAT_TEMPLATE, "" + ) + + def _process_chat_sample(self, examples: list): + tokenized_examples = self.tokenizer.apply_chat_template( + examples, + return_tensors="pt", + return_dict=True, + padding="max_length", + truncation=True, + max_length=self.train_len, + add_generation_prompt=self.add_generation_prompt, + return_assistant_tokens_mask=self.answer_only_loss, + ) + if self.return_labels: + input_ids = tokenized_examples["input_ids"] + labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) + labels[..., :-1] = input_ids[..., 1:] + tokenized_examples["labels"] = labels + return tokenized_examples + + def _process_text_sample(self, examples: list): + tokenized_examples = self.tokenizer( + examples, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.train_len, + ) + return tokenized_examples + + def __call__(self, examples): + """Call the LanguageDataCollator.""" + batch = [] + + for example in examples: + if not isinstance(example, dict): + raise ValueError("The sample must be a Dict but got {}".format(type(example))) + text = example.get(self.json_key, None) + if isinstance(text, str): + batch.append(text) + else: + messages = example.get("messages", None) + if messages is None: + conversations = example.get("conversations", None) + if conversations is None: + raise ValueError( + "The sample must in either OpenAI messages format or ShareGPT conversations format." + ) + else: + messages = _sharegpt_to_openai_messages(conversations) + batch.append(messages) + + return self._process_chat_sample(batch) + + +class VisionLanguageDataCollator(LanguageDataCollator): + """VisionLanguageDataCollator is a subclass of LanguageDataCollator that is used to collate vision-language data.""" + + def __init__( + self, + processor: str, + train_len: int = 8192, + chat_template: str | None = None, + add_generation_prompt: bool = False, + answer_only_loss: bool = False, + local_image_path: str = "", + return_labels: bool = False, + ): + """Initialize the VisionLanguageDataset.""" + self.processor = transformers.AutoProcessor.from_pretrained(processor) + self.chat_template = chat_template + self.local_image_path = local_image_path + + super().__init__( + tokenizer=self.processor.tokenizer, + train_len=train_len, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + answer_only_loss=answer_only_loss, + return_labels=return_labels, + ) + + def _process_multimodal_sample(self, examples): + tokenized_messages = self.processor.apply_chat_template( + examples, + tokenize=True, + return_tensors="pt", + return_dict=True, + padding="max_length", + truncation=True, + max_length=self.train_len, + add_generation_prompt=self.add_generation_prompt, + return_assistant_tokens_mask=self.answer_only_loss, + ) + + return tokenized_messages + + def __call__(self, examples): + """Call the VisionLanguageDataCollator.""" + batch = [] + + for example in examples: + messages = example.get("messages", None) + if messages is None: + conversations = example.get("conversations", None) + if conversations is None: + raise ValueError( + "The sample must in either OpenAI messages format or ShareGPT conversations format." + ) + else: + messages = _sharegpt_to_openai_messages(conversations) + + copy_messages = copy.deepcopy(messages) + + for msg in copy_messages: + if isinstance(msg["content"], str): + msg["content"] = [{"type": "text", "text": msg["content"]}] + + for ctn in msg["content"]: + if ctn["type"] == "image" and "image" in ctn: + ctn["image"] = os.path.abspath( + os.path.join(self.local_image_path, ctn["image"]) + ) + # If any value in ctn is None, delete that key + # HF dataloader add Nones to align keys. Leads to error in processor. + keys_to_delete = [k for k, v in ctn.items() if v is None] + for k in keys_to_delete: + del ctn[k] + + batch.append(copy_messages) + + return self._process_multimodal_sample(batch) diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 1d9f594846..3f07c57715 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -13,29 +13,210 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility functions for getting samples and forward loop function for different vlm datasets.""" +"""Utility functions for getting samples and dataloader for different VLM calibration datasets. +This module supports both: +- Small non-streaming VLM datasets (e.g., ScienceQA) +- Large streaming VLM datasets (e.g., Nemotron-VLM-Dataset-v2) where we want to avoid downloading everything. +""" + +import contextlib +import copy +import itertools +from io import BytesIO +from pathlib import Path from typing import Any +import torch from torch.utils.data import DataLoader from .image_processor import MllamaImageProcessor +from .nemotron_vlm_dataset_utils import NemotronTarPlusJsonlIterable, list_repo_files_cached # Use dict to store the config for each dataset. # If we want to export more options to user like target languages, we need more standardized approach like dataclass. SUPPORTED_VLM_DATASET_CONFIG: dict[str, dict[str, Any]] = { "scienceqa": {"config": {"path": "derek-thomas/ScienceQA", "split": "train"}}, + # Large multi-subset dataset (use streaming to avoid downloading the entire dataset) + "nemotron_vlm_dataset_v2": { + "config": {"path": "nvidia/Nemotron-VLM-Dataset-v2", "split": "train", "streaming": True}, + # Provide a sane default that (a) includes in-repo media shards and (b) is document-centric. + # Subsets like docvqa_cot/chartqa_cot are JSONL-only in the dataset repo and require --vlm_image_root. + "default_subsets": ["sparsetables", "plotqa_cot", "wiki_en"], + }, } __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] -def _get_vlm_dataset(dataset_name: str, num_samples: int): +class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): + """Wrap a HF streaming IterableDataset to be compatible with torch DataLoader.""" + + def __init__(self, hf_iterable, num_samples: int): + super().__init__() + self._hf_iterable = hf_iterable + self._num_samples = num_samples + + def __iter__(self): + return itertools.islice(iter(self._hf_iterable), self._num_samples) + + def __len__(self): + return self._num_samples + + +def _extract_text_from_messages(messages: Any) -> str | None: + """Best-effort extraction of a user text prompt from a chat-style `messages` field.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + if msg.get("role") != "user": + continue + content = msg.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + # Common multimodal format: [{"type":"image"}, {"type":"text","text":"..."}] + texts = [ + part["text"] + for part in content + if isinstance(part, dict) + and part.get("type") == "text" + and isinstance(part.get("text"), str) + ] + if texts: + return "\n".join(texts) + return None + + +def _messages_up_to_last_user(messages: Any) -> list[dict[str, Any]] | None: + """Return messages truncated to the last user turn (inclusive).""" + if not isinstance(messages, list): + return None + last_user_idx = None + for i, msg in enumerate(messages): + if isinstance(msg, dict) and msg.get("role") == "user": + last_user_idx = i + if last_user_idx is None: + return None + trimmed = messages[: last_user_idx + 1] + return [m for m in trimmed if isinstance(m, dict)] + + +def _extract_first_image_from_messages(messages: Any) -> Any: + """Best-effort extraction of an image object from a chat-style `messages` field.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not (isinstance(part, dict) and part.get("type") == "image"): + continue + # Common keys used by HF datasets / chat templates + for key in ("image", "images", "value", "data", "path", "image_url", "url"): + if key in part: + val = part[key] + if isinstance(val, list) and val: + return val[0] + return val + # Fallback: return the dict itself (some processors may accept it) + return part + return None + + +def _extract_image_ref_from_example(example: dict[str, Any]) -> Any: + """Best-effort extraction of an image reference from a dataset example.""" + img = example.get("image") + if img is None: + img = example.get("images") + if img is None: + img = _extract_first_image_from_messages(example.get("messages")) + return img + + +def _maybe_load_image(image_obj: Any, repo_id: str | None, image_root: str | Path | None) -> Any: + """Convert common image references (path/bytes) into a PIL image if possible. + + For some streaming datasets, images are stored as file paths inside the dataset repo. + In that case, we lazily download just the referenced files via `hf_hub_download`. + """ + if image_obj is None: + return None + + # If it's a list, take the first (some formats store a list for multi-image samples). + if isinstance(image_obj, list) and image_obj: + image_obj = image_obj[0] + + # Path-like reference + if isinstance(image_obj, str): + # First, try resolving against a local image root (best option for datasets that only ship JSONL refs). + if image_root is not None: + try: + from PIL import Image + + local_path = Path(image_root) / image_obj + if local_path.exists(): + return Image.open(local_path).convert("RGB") + except Exception: + pass + + if repo_id is None: + return image_obj + try: + from huggingface_hub import hf_hub_download + from PIL import Image + + local_path = hf_hub_download(repo_id=repo_id, filename=image_obj, repo_type="dataset") + return Image.open(local_path).convert("RGB") + except Exception: + return None + + # Dict-like reference (common in chat content items) + if isinstance(image_obj, dict): + # bytes payload + if "bytes" in image_obj and isinstance(image_obj["bytes"], (bytes, bytearray)): + try: + from PIL import Image + + return Image.open(BytesIO(image_obj["bytes"])).convert("RGB") + except Exception: + return None + + # path/url-ish payloads + for key in ("path", "image", "image_path", "file", "url", "image_url"): + if key in image_obj and isinstance(image_obj[key], str): + return _maybe_load_image(image_obj[key], repo_id=repo_id, image_root=image_root) + + # If it's already a PIL/numpy/torch image-like object, just return it and let the processor validate. + return image_obj + + +def _get_vlm_dataset( + dataset_name: str, + num_samples: int, + require_image: bool = True, + subsets: list[str] | None = None, + shuffle_buffer_size: int = 10_000, + seed: int = 42, + use_media_shards: bool = True, + max_shards: int | None = None, +): """Load a portion of train dataset with the dataset name and a given size. Args: dataset_name: Name of the dataset to load. num_samples: Number of samples to load from the dataset. + require_image: If True, keep only samples that have an image field. + subsets: Optional subset/config names for multi-subset datasets (e.g., Nemotron-VLM-Dataset-v2). + shuffle_buffer_size: Shuffle buffer size for streaming datasets (higher is "more random"). + seed: RNG seed for streaming dataset shuffle. + use_media_shards: If True, prefer reading in-repo `media/shard_*.tar` files when available. + max_shards: Optional cap on the number of tar shards to download/use. Returns: A hugging face Dataset. @@ -44,16 +225,91 @@ def _get_vlm_dataset(dataset_name: str, num_samples: int): if dataset_name in SUPPORTED_VLM_DATASET_CONFIG: from datasets import load_dataset - # Use streaming can reduce the downloading time for large datasets - dataset = load_dataset( - **SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"], - ) + cfg = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].copy() + streaming = bool(cfg.pop("streaming", False)) + + if dataset_name == "nemotron_vlm_dataset_v2": + # This dataset contains many subsets; load only the requested ones via `name=...`. + if not subsets: + subsets = SUPPORTED_VLM_DATASET_CONFIG[dataset_name].get("default_subsets", []) + if not subsets: + raise ValueError("No VLM subsets provided for nemotron_vlm_dataset_v2.") + + repo_id = cfg["path"] + + # Prefer in-repo media tar shards when present. HF `datasets` streaming alone does not join media. + if use_media_shards: + all_files = list_repo_files_cached(repo_id, repo_type="dataset") + shard_paths: list[str] = [] + for subset in subsets: + prefix = f"{subset}/media/" + shard_paths.extend( + [ + p + for p in all_files + if p.startswith(prefix) and p.lower().endswith(".tar") + ] + ) + + shard_paths = sorted(set(shard_paths)) + if shard_paths: + return NemotronTarPlusJsonlIterable( + repo_id=repo_id, + subsets=subsets, + shard_paths=shard_paths, + num_samples=num_samples, + seed=seed, + shuffle_buffer_size=shuffle_buffer_size, + max_shards=max_shards, + ) + + # Load each subset as a separate (streaming) dataset, then interleave. + streams = [ + load_dataset( + cfg["path"], + name=subset, + split=cfg.get("split", "train"), + streaming=streaming, + ) + for subset in subsets + ] + try: + from datasets import interleave_datasets + + ds = interleave_datasets(streams) + except Exception: + # Fallback: round-robin by chaining (less balanced than interleave). + ds = itertools.chain.from_iterable(streams) + else: + dataset = load_dataset(**cfg, streaming=streaming) + split = cfg.get("split", "train") + ds = dataset[split] if hasattr(dataset, "__getitem__") and split in dataset else dataset else: raise NotImplementedError( f"dataset {dataset_name} is not supported. Please use one of the following:" f" {get_supported_vlm_datasets()}." ) - return dataset.select(range(num_samples)) + + # Streaming datasets: shuffle with bounded buffer and wrap into a torch IterableDataset. + if dataset_name == "nemotron_vlm_dataset_v2": + with contextlib.suppress(Exception): + ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer_size) + + if require_image: + # Keep only samples with a non-null image field (ScienceQA has both). + with contextlib.suppress(Exception): + ds = ds.filter( + lambda ex: ex.get("image", None) is not None + or ex.get("images", None) is not None + or _extract_image_ref_from_example(ex) is not None + ) + + # Select the first `num_samples` entries (or fewer if dataset is smaller). + try: + return ds.select(range(min(num_samples, len(ds)))) + except Exception: + # For streaming/iterable datasets without __len__/select, wrap for DataLoader iteration. + return _HFDatasetsIterableWrapper(ds, num_samples=num_samples) def get_supported_vlm_datasets() -> list[str]: @@ -75,9 +331,18 @@ def get_supported_vlm_datasets() -> list[str]: def get_vlm_dataset_dataloader( dataset_name: str = "scienceqa", - processor: MllamaImageProcessor = None, + processor: Any = None, batch_size: int = 1, num_samples: int = 512, + device: str | torch.device | None = None, + max_length: int | None = None, + require_image: bool = True, + subsets: list[str] | None = None, + shuffle_buffer_size: int = 10_000, + seed: int = 42, + image_root: str | Path | None = None, + use_media_shards: bool = True, + max_shards: int | None = None, ) -> DataLoader: """Get a dataloader with the dataset name and processor of the target model. @@ -86,22 +351,127 @@ def get_vlm_dataset_dataloader( processor: Processor used for encoding images and text data. batch_size: Batch size of the returned dataloader. num_samples: Number of samples from the dataset. + device: Device to move returned tensors to. If None, keep on CPU. + max_length: Optional max length for text tokenization (if supported by the processor). + require_image: If True, keep only samples that have an image field. Returns: An instance of dataloader. """ assert processor is not None, "Please provide a valid processor." - dataset = _get_vlm_dataset(dataset_name, num_samples=num_samples) - # Apply the preprocessing function to the dataset - processed_dataset = dataset.map( - processor.preprocess_function, batched=False, remove_columns=dataset.column_names - ) + # Optional: allow callers to set a local image root for datasets that only ship JSON references. + # We store it on the processor instance to avoid threading it through a bunch of nested closures. + if image_root is not None: + setattr(processor, "_modelopt_vlm_image_root", image_root) + + if device is not None: + device = torch.device(device) - # Create DataLoader with the custom collate function - return DataLoader( - processed_dataset, - batch_size=batch_size, - shuffle=False, - collate_fn=processor.collate_function, + dataset = _get_vlm_dataset( + dataset_name, + num_samples=num_samples, + require_image=require_image, + subsets=subsets, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + use_media_shards=use_media_shards, + max_shards=max_shards, ) + + # Legacy path: our internal image processor wrapper (e.g., Mllama). + if isinstance(processor, MllamaImageProcessor): + processed_dataset = dataset.map( + processor.preprocess_function, batched=False, remove_columns=dataset.column_names + ) + return DataLoader( + processed_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=processor.collate_function, + ) + + # Generic HF ProcessorMixin / AutoProcessor path: tokenize & process images at collate-time. + # For Nemotron VLM datasets, we prefer to follow the model-card flow: + # prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # inputs = processor(text=[prompt], images=[pil_image], ...) + + def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]: + repo_id = None + if dataset_name == "nemotron_vlm_dataset_v2": + repo_id = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"]["path"] + image_root = getattr(processor, "_modelopt_vlm_image_root", None) + + pairs: list[tuple[str, Any]] = [] + for ex in examples: + messages = ex.get("messages") + + # Image extraction + img_ref = _extract_image_ref_from_example(ex) + img = _maybe_load_image(img_ref, repo_id=repo_id, image_root=image_root) + if require_image and img is None: + continue + + # Prompt extraction + prompt = None + tok = getattr(processor, "tokenizer", None) + if tok is not None and messages is not None: + trimmed = _messages_up_to_last_user(messages) or [] + # For some Nemotron-style templates, the image content expects an empty string. + # Keep the actual image path separate for loading; blank it in the prompt message. + prompt_msgs = copy.deepcopy(trimmed) + for msg in prompt_msgs: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + part["image"] = "" + with contextlib.suppress(Exception): + prompt = tok.apply_chat_template( + prompt_msgs, tokenize=False, add_generation_prompt=True + ) + + if prompt is None: + # Fallback: best-effort question-only prompt. + q = ex.get("question") + if q is None and messages is not None: + q = _extract_text_from_messages(messages) + prompt = q or "Describe the image." + + pairs.append((prompt, img)) + + if not pairs: + raise ValueError( + "No usable images found in the current batch. " + "If you're using JSONL-only subsets (e.g., docvqa_cot/chartqa_cot), provide " + "`--vlm_image_root ` so referenced paths can be resolved. " + "If you're using asset-included subsets, keep media shard loading enabled " + "(default) and consider increasing `--vlm_max_shards`." + ) + + prompts, images = zip(*pairs) + + kwargs: dict[str, Any] = { + "text": list(prompts), + "images": list(images), + "return_tensors": "pt", + "padding": True, + } + if max_length is not None: + kwargs.update({"truncation": True, "max_length": max_length}) + + enc = processor(**kwargs) + + # Some processors return BatchEncoding; normalize to plain dict of tensors. + if hasattr(enc, "data"): + enc = enc.data + out: dict[str, Any] = dict(enc) + + # Move tensors to device if requested. + if device is not None: + for k, v in list(out.items()): + if torch.is_tensor(v): + out[k] = v.to(device) + return out + + return DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate_fn) diff --git a/pyproject.toml b/pyproject.toml index f94abda13b..4b51354c42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,8 +152,8 @@ disable_error_code = ["attr-defined"] [tool.pytest.ini_options] # Default additional options # Show a short test summary info for all except passed tests with -ra flag -# print execution time for 20 slowest tests and generate coverage reports -addopts = "-v -ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers" +# print execution time for 50 slowest tests and generate coverage reports +addopts = "-v -ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=50 --strict-markers" pythonpath = ["tests/"] markers = [ "manual: Only run when --run-manual is given", diff --git a/setup.py b/setup.py index 6096e31cab..9d98875abf 100644 --- a/setup.py +++ b/setup.py @@ -49,8 +49,7 @@ "onnx~=1.19.0", "onnxconverter-common~=1.16.0", "onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'", - "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 - "onnxruntime-gpu==1.23.2; platform_system == 'Windows'", + "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin'", "onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test "onnxslim>=0.1.76", "polygraphy>=0.49.22", @@ -63,6 +62,8 @@ "huggingface_hub>=0.24.0", "peft>=0.17.0", "transformers>=4.53,<5.0", # Should match modelopt/torch/__init__.py and tox.ini + "nltk", + "wonderwords", ], # linter tools "dev-lint": [ @@ -78,6 +79,7 @@ "pytest-cov", "pytest-instafail", "pytest-timeout", + "sentencepiece", # For test_unified_export_megatron.py, test_vllm_fakequant_megatron_export.py "timm", "torchprofile>=0.0.4", # For computing flops of CV models "torchvision", diff --git a/tests/_test_utils/deploy_utils.py b/tests/_test_utils/deploy_utils.py index 805624b8f5..53ffd4a430 100644 --- a/tests/_test_utils/deploy_utils.py +++ b/tests/_test_utils/deploy_utils.py @@ -19,6 +19,48 @@ import pytest import torch +# Cache for available backends detection (computed once at import time) +_AVAILABLE_BACKENDS = None + + +def get_available_backends(): + """Detect which backends are available in the current environment. + + Returns: + set: A set of available backend names ('trtllm', 'vllm', 'sglang') + """ + global _AVAILABLE_BACKENDS + if _AVAILABLE_BACKENDS is not None: + return _AVAILABLE_BACKENDS + + available = set() + + try: + import tensorrt_llm # noqa: F401 + + available.add("trtllm") + except ImportError: + pass + + try: + import vllm # noqa: F401 + + available.add("vllm") + except ImportError: + pass + + try: + import sglang # noqa: F401 + + available.add("sglang") + except ImportError: + pass + + _AVAILABLE_BACKENDS = available + print(f"[deploy_utils] Detected available backends: {available}") + return _AVAILABLE_BACKENDS + + # Common test prompts for all backends COMMON_PROMPTS = [ "Hello, my name is", @@ -90,18 +132,18 @@ def run(self): def _deploy_trtllm(self): """Deploy a model using TensorRT-LLM.""" - try: - from tensorrt_llm import LLM, SamplingParams - from tensorrt_llm.llmapi import CudaGraphConfig, EagleDecodingConfig, KvCacheConfig - except ImportError: - pytest.skip("tensorrt_llm package not available") + from tensorrt_llm import LLM, SamplingParams + from tensorrt_llm.llmapi import CudaGraphConfig, EagleDecodingConfig, KvCacheConfig sampling_params = SamplingParams(max_tokens=32) spec_config = None llm = None kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.8) - if self.model_id == "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": + if self.model_id in ( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8", + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4", + ): llm = LLM( model=self.model_id, tensor_parallel_size=self.tensor_parallel_size, @@ -173,10 +215,7 @@ def _deploy_trtllm(self): def _deploy_vllm(self): """Deploy a model using vLLM.""" - try: - from vllm import LLM, SamplingParams - except ImportError: - pytest.skip("vllm package not available") + from vllm import LLM, SamplingParams quantization_method = "modelopt" if "fp4" in self.model_id.lower(): @@ -210,10 +249,8 @@ def _deploy_vllm(self): def _deploy_sglang(self): """Deploy a model using SGLang.""" - try: - import sglang as sgl - except ImportError: - pytest.skip("sglang package not available") + import sglang as sgl + quantization_method = "modelopt" if "fp4" in self.model_id.lower(): quantization_method = "modelopt_fp4" @@ -230,7 +267,10 @@ def _deploy_sglang(self): mem_fraction_static=0.7, context_length=1024, ) - elif self.model_id == "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": + elif self.model_id in ( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8", + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4", + ): llm = sgl.Engine( model_path=self.model_id, quantization=quantization_method, @@ -259,10 +299,20 @@ def __init__(self, **params): else: self.params[key] = [value] + # Filter backends to only include available ones + if "backend" in self.params: + available = get_available_backends() + original_backends = self.params["backend"] + self.params["backend"] = [b for b in original_backends if b in available] + # Pre-generate all deployers for pytest compatibility self._deployers = list(self._generate_deployers()) def _generate_deployers(self): + # If no backends available after filtering, yield nothing + if "backend" in self.params and not self.params["backend"]: + return + for values in itertools.product(*self.params.values()): deployer = ModelDeployer(**dict(zip(self.params.keys(), values))) # Set test case ID in format "model_id_backend" diff --git a/tests/_test_utils/import_helper.py b/tests/_test_utils/import_helper.py index 43f974935b..a1148480a6 100644 --- a/tests/_test_utils/import_helper.py +++ b/tests/_test_utils/import_helper.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import ctypes import importlib.metadata +import os import shutil import pytest @@ -28,6 +29,23 @@ def skip_if_no_tensorrt(): except (AssertionError, ImportError) as e: pytest.skip(f"{e}", allow_module_level=True) + # Also verify that ORT's TensorRT EP can actually load its native library. + # The tensorrt Python package may be installed, but ORT's provider shared library + # (libonnxruntime_providers_tensorrt.so) could fail to load due to CUDA version + # mismatches (e.g., ORT built for CUDA 12 running on a CUDA 13 system). + try: + import onnxruntime + + ort_capi_dir = os.path.join(os.path.dirname(onnxruntime.__file__), "capi") + trt_provider_lib = os.path.join(ort_capi_dir, "libonnxruntime_providers_tensorrt.so") + if os.path.isfile(trt_provider_lib): + ctypes.CDLL(trt_provider_lib) + except OSError as e: + pytest.skip( + f"ORT TensorRT EP native library cannot be loaded: {e}", + allow_module_level=True, + ) + def skip_if_no_trtexec(): if not shutil.which("trtexec"): @@ -43,19 +61,12 @@ def skip_if_no_libcudnn(): pytest.skip(f"{e}!", allow_module_level=True) -def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool = False): +def skip_if_no_megatron(*, te_required: bool = True, mamba_required: bool = False): try: import megatron # noqa: F401 except ImportError: pytest.skip("megatron not available", allow_module_level=True) - try: - import apex # noqa: F401 - - has_apex = True - except ImportError: - has_apex = False - try: import transformer_engine # noqa: F401 @@ -70,8 +81,8 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool except ImportError: has_mamba = False - if apex_or_te_required and not has_apex and not has_te: - pytest.skip("Apex or TE required for Megatron test", allow_module_level=True) + if te_required and not has_te: + pytest.skip("TE required for Megatron test", allow_module_level=True) if mamba_required and not has_mamba: pytest.skip("Mamba required for Megatron test", allow_module_level=True) @@ -88,5 +99,5 @@ def skip_if_onnx_version_above_1_18(): if version.parse(installed_version) > version.parse(required_version): pytest.skip( - f"{package_name} version {installed_version} is less than required {required_version}" + f"{package_name} version {installed_version} is greater than required {required_version}" ) diff --git a/tests/_test_utils/onnx/lib_test_models.py b/tests/_test_utils/onnx/lib_test_models.py index 675fe03b35..ff97b51421 100644 --- a/tests/_test_utils/onnx/lib_test_models.py +++ b/tests/_test_utils/onnx/lib_test_models.py @@ -924,3 +924,88 @@ def build_conv_isinf_model(opset_version=13): onnx.checker.check_model(model_inferred) return model_inferred + + +def build_conv_resize_model(): + # Define your model inputs and outputs + input_names = ["input_0"] + output_names = ["output_0"] + input_shapes = [(1, 288, 32, 32)] + output_shapes = [(1, 16, 64, 64)] + + inputs = [ + helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape) + for input_name, input_shape in zip(input_names, input_shapes) + ] + outputs = [ + helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape) + for output_name, output_shape in zip(output_names, output_shapes) + ] + + # Create the ONNX graph with the nodes + nodes = [ + helper.make_node( + op_type="Conv", + inputs=["input_0", "weights_1"], + outputs=["conv1_conv/Conv2D:0"], + name="conv1_conv/Conv2D", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + ), + # Note: resize_roi_scales is intentionally used for both roi and scales inputs + # to test the shared constant duplication fix (PR #757) + helper.make_node( + op_type="Resize", + inputs=[ + "conv1_conv/Conv2D:0", + "resize_roi_scales", + "resize_roi_scales", + "resize_sizes", + ], + outputs=["output_0"], + name="resize1_resize/Resize", + coordinate_transformation_mode="asymmetric", + cubic_coeff_a=-0.75, + mode="nearest", + nearest_mode="floor", + ), + ] + + # Create the ONNX initializers + initializers = [ + helper.make_tensor( + name="weights_1", + data_type=onnx.TensorProto.FLOAT, + dims=(16, 288, 1, 1), + vals=np.random.uniform(low=0.5, high=1.0, size=16 * 288 * 1 * 1), + ), + helper.make_tensor( + name="resize_roi_scales", + data_type=onnx.TensorProto.FLOAT, + dims=(0,), + vals=[], + ), + helper.make_tensor( + name="resize_sizes", + data_type=onnx.TensorProto.INT64, + dims=(4,), + vals=[1, 16, 64, 64], + ), + ] + + # Create the ONNX graph with the nodes and initializers + graph = helper.make_graph(nodes, "conv_resize", inputs, outputs, initializer=initializers) + + # Create the ONNX model + model = helper.make_model(graph) + model.opset_import[0].version = 13 + model.ir_version = 10 + + # Check the ONNX model + model_inferred = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model_inferred) + + return model_inferred diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 31419c4c9e..7d91b8909b 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -21,6 +21,12 @@ pytest.importorskip("diffusers") from diffusers import UNet2DConditionModel +try: + from diffusers.models.transformers import DiTTransformer2DModel, FluxTransformer2DModel +except Exception: # pragma: no cover - optional diffusers models + DiTTransformer2DModel = None + FluxTransformer2DModel = None + import modelopt.torch.opt as mto @@ -45,6 +51,48 @@ def get_tiny_unet(**config_kwargs) -> UNet2DConditionModel: return tiny_unet +def get_tiny_dit(**config_kwargs): + """Create a tiny DiTTransformer2DModel for testing.""" + if DiTTransformer2DModel is None: + pytest.skip("DiTTransformer2DModel is not available in this diffusers version.") + + kwargs = { + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 2, + "out_channels": 2, + "num_layers": 1, + "norm_num_groups": 1, + "sample_size": 8, + "patch_size": 2, + "num_embeds_ada_norm": 10, + } + kwargs.update(**config_kwargs) + return DiTTransformer2DModel(**kwargs) + + +def get_tiny_flux(**config_kwargs): + """Create a tiny FluxTransformer2DModel for testing.""" + if FluxTransformer2DModel is None: + pytest.skip("FluxTransformer2DModel is not available in this diffusers version.") + + kwargs = { + "patch_size": 1, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 8, + "pooled_projection_dim": 8, + "guidance_embeds": False, + "axes_dims_rope": (2, 2, 4), + } + kwargs.update(**config_kwargs) + return FluxTransformer2DModel(**kwargs) + + def create_tiny_unet_dir(tmp_path: Path, **config_kwargs) -> Path: """Create and save a tiny UNet model to a directory.""" tiny_unet = get_tiny_unet(**config_kwargs) diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 76ddc5a94a..a22eaaf9e0 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -314,6 +314,7 @@ def get_mcore_mamba_hybrid_model( sequence_parallel: bool = False, # Mamba-specific parameters mamba_state_dim: int = 32, + mamba_num_heads: int | None = None, mamba_head_dim: int = 16, mamba_num_groups: int = 2, # MoE-specific parameters @@ -347,6 +348,7 @@ def get_mcore_mamba_hybrid_model( num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, mamba_state_dim=mamba_state_dim, + mamba_num_heads=mamba_num_heads, mamba_head_dim=mamba_head_dim, mamba_num_groups=mamba_num_groups, num_moe_experts=num_moe_experts, @@ -358,7 +360,7 @@ def get_mcore_mamba_hybrid_model( **config_kwargs, ) - if not (skip_moe or "E" in Symbols.VALID): + if not (skip_moe or "E" in Symbols.VALID): # Mcore 0.16+ has MoE support warn("MoE blocks are not supported in current MambaModel. Skipping MoE blocks.") skip_moe = True diff --git a/tests/_test_utils/torch/megatron/utils.py b/tests/_test_utils/torch/megatron/utils.py index bb91f83cd7..63904ba448 100644 --- a/tests/_test_utils/torch/megatron/utils.py +++ b/tests/_test_utils/torch/megatron/utils.py @@ -129,6 +129,47 @@ def run_mcore_inference_with_dummy_input( return run_mcore_inference(model, prompt_tokens, hidden_size) +def get_batch(model, batch_size=2): + seq_length = model.max_sequence_length + vocab_size = model.vocab_size + + input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() + position_ids = ( + torch.arange(seq_length, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1).cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, seq_length, seq_length), dtype=torch.bool) + ).cuda() + loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() + + return input_ids, labels, position_ids, attention_mask, loss_mask + + +def get_forward(model, batch_size=2): + """Return a forward function with cached batch inputs.""" + input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) + + def forward(model): + # MambaModel doesn't accept loss_mask argument + if isinstance(model, MambaModel): + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + ) + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + return forward + + def initialize_for_megatron( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py index 856edd38c9..87da6414c7 100644 --- a/tests/_test_utils/torch/nas_prune/minitron_common.py +++ b/tests/_test_utils/torch/nas_prune/minitron_common.py @@ -16,11 +16,22 @@ import modelopt.torch.prune as mtp -def prune_minitron(model, export_config, config, channel_divisor=64): +def prune_minitron(model, constraints, config, channel_divisor=64): return mtp.prune( model, - mode=[("mcore_minitron", mtp.mcore_minitron.get_mcore_minitron_config(channel_divisor))], - constraints={"export_config": export_config}, + mode=[ + ( + "mcore_minitron", + mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + mamba_head_dim_divisor=4, + num_moe_experts_divisor=1, + num_layers_divisor=1, + ), + ) + ], + constraints=constraints, dummy_input=None, # Not used config=config, ) diff --git a/tests/_test_utils/torch/nas_prune/utils.py b/tests/_test_utils/torch/nas_prune/utils.py new file mode 100644 index 0000000000..ffce06240f --- /dev/null +++ b/tests/_test_utils/torch/nas_prune/utils.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn +from torch.nn.modules.batchnorm import _BatchNorm + +from modelopt.torch.opt.dynamic import DynamicModule + + +def param_num(network: nn.Module, trainable_only: bool = False, unit=1e6) -> float: + """Get the number of parameters of a PyTorch model. + + Args: + network: The PyTorch model. + trainable_only: Whether to only count trainable parameters. Default is False. + unit: The unit to return the number of parameters in. Default is 1e6 (million). + + Returns: + The number of parameters in the model in the given unit. + """ + + if isinstance(network, DynamicModule): + # NOTE: model.parameters() doesnt consider active_slice so we dont get sorted or trimmed parameters! + raise NotImplementedError( + "param_num doesn't support DynamicModule. Please use param_num_from_forward instead." + ) + return ( + sum( + p.numel() if not trainable_only or p.requires_grad else 0 + for mod in network.modules() + for p in mod.parameters(recurse=False) + if not isinstance(mod, _BatchNorm) + ) + / unit + ) diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index f62d2d9911..2b2e43dcff 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -75,7 +75,14 @@ def forward_loop(model, run_backward=False): forward_loop(model, run_backward=True) -def save_restore_test(model_cls, device, quant_config, compress=False, version=None): +def save_restore_test( + model_cls, + device, + quant_config, + compress=False, + version=None, + test_cpu_restore: bool = False, +): # test restoring to an unquantized model model_quant = model_cls().to(device) model_ref = model_cls().to(device) @@ -89,11 +96,20 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N model_ref.load_state_dict(model_quant.state_dict()) assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0])) + # Verify that TensorQuantizer subclass types are preserved after restore + for name_q, mod_q in model_quant.named_modules(): + if name_q.endswith("quantizer"): + mod_r = dict(model_ref.named_modules())[name_q] + assert type(mod_q) is type(mod_r), ( + f"Quantizer class mismatch for '{name_q}': " + f"expected {type(mod_q).__name__}, got {type(mod_r).__name__}" + ) + if version is not None and Version(version) < Version("0.29"): # Rest of the tests are not needed for version < 0.29 return - if not compress: + if test_cpu_restore: # gpu: test restoring to a model on cpu. If the quantizer states are not initialized correctly, # the buffers will be created on cuda and this test will fail model_ref = model_cls().to("cpu") diff --git a/tests/_test_utils/torch/sparsity/sparse_attention_common.py b/tests/_test_utils/torch/sparsity/sparse_attention_common.py index 7724908b08..6e9ae50142 100644 --- a/tests/_test_utils/torch/sparsity/sparse_attention_common.py +++ b/tests/_test_utils/torch/sparsity/sparse_attention_common.py @@ -95,7 +95,7 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": 1e-4, + "threshold": {"prefill": 1e-4, "decode": 1e-4}, "br": 128, "bc": 128, "enable": True, @@ -153,13 +153,15 @@ def forward_loop(model): with torch.no_grad(): for batch in calib_data: output = model(batch) - assert not torch.isnan(output).any(), "NaN in output" - assert output is not None, "Output is None" + assert not torch.isnan(output).any(), ( + f"NaN detected in output for batch shape {batch.shape}" + ) + assert output is not None, f"Output is None for batch shape {batch.shape}" return model -def save_restore_test(model_cls, device, sparse_config): +def save_restore_test(model_cls, device, sparse_config, atol=1e-6): """Test save and restore of sparse attention state. Args: @@ -190,6 +192,6 @@ def save_restore_test(model_cls, device, sparse_config): output_sparse = model_sparse(test_input) output_restored = model_restored(test_input) - assert torch.allclose(output_sparse, output_restored, atol=1e-6), ( + assert torch.allclose(output_sparse, output_restored, atol=atol), ( "Restored model output doesn't match original" ) diff --git a/tests/conftest.py b/tests/conftest.py index 7f50af9258..a35aa448b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ # limitations under the License. import platform +from pathlib import Path import pytest @@ -57,3 +58,9 @@ def pytest_collection_modifyitems(config, items): def skip_on_windows(): if platform.system() == "Windows": pytest.skip("Skipping on Windows") + + +@pytest.fixture +def project_root_path(request: pytest.FixtureRequest) -> Path: + """Fixture providing the project root path for tests.""" + return Path(request.config.rootpath) diff --git a/tests/examples/gpt_oss/test_gpt_oss_qat.py b/tests/examples/gpt_oss/test_gpt_oss_qat.py index e5f9b8ab95..43464110b2 100644 --- a/tests/examples/gpt_oss/test_gpt_oss_qat.py +++ b/tests/examples/gpt_oss/test_gpt_oss_qat.py @@ -294,30 +294,27 @@ def deploy_gpt_oss_trtllm(self, tmp_path, model_path_override=None): ) def test_gpt_oss_complete_pipeline(model_path, tmp_path): """Test the complete GPT-OSS optimization pipeline by executing all 3 steps in sequence.""" - import pathlib - # Use current directory instead of tmp_path for checkpoints - current_dir = pathlib.Path.cwd() # Create GPTOSS instance with model path gpt_oss = GPTOSS(model_path) if model_path == "openai/gpt-oss-20b": # Step 1: SFT Training - sft_checkpoint = gpt_oss.gpt_oss_sft_training(current_dir) + sft_checkpoint = gpt_oss.gpt_oss_sft_training(tmp_path) if not sft_checkpoint or not sft_checkpoint.exists(): print("Step 1 failed: SFT checkpoint not found, stopping pipeline.") return print(f"Step 1 completed: SFT checkpoint at {sft_checkpoint}") # Step 2: QAT Training (depends on Step 1) - qat_checkpoint = gpt_oss.gpt_oss_qat_training(current_dir, sft_dir=sft_checkpoint) + qat_checkpoint = gpt_oss.gpt_oss_qat_training(tmp_path, sft_dir=sft_checkpoint) if not qat_checkpoint or not qat_checkpoint.exists(): print("Step 2 failed: QAT checkpoint not found, stopping pipeline.") return print(f"Step 2 completed: QAT checkpoint at {qat_checkpoint}") # Step 3: MXFP4 Conversion (depends on Step 2) - mxfp4_checkpoint = gpt_oss.gpt_oss_mxfp4_conversion(current_dir, qat_dir=qat_checkpoint) + mxfp4_checkpoint = gpt_oss.gpt_oss_mxfp4_conversion(tmp_path, qat_dir=qat_checkpoint) if not mxfp4_checkpoint or not mxfp4_checkpoint.exists(): print("Step 3 failed: MXFP4 checkpoint not found, stopping pipeline.") return @@ -325,12 +322,12 @@ def test_gpt_oss_complete_pipeline(model_path, tmp_path): # Step 4: Deploy with TensorRT-LLM (depends on Step 3) print("Step 4: Running deployment with MXFP4 checkpoint...") - gpt_oss.deploy_gpt_oss_trtllm(current_dir, model_path_override=mxfp4_checkpoint) + gpt_oss.deploy_gpt_oss_trtllm(tmp_path, model_path_override=mxfp4_checkpoint) print("Step 4 completed: Deployment successful") elif model_path == "openai/gpt-oss-120b": # Step 1: QAT Training with LoRA - qat_lora_checkpoint = gpt_oss.gpt_oss_qat_training_lora(current_dir) + qat_lora_checkpoint = gpt_oss.gpt_oss_qat_training_lora(tmp_path) if not qat_lora_checkpoint or not qat_lora_checkpoint.exists(): print("Step 1 failed: QAT-LoRA checkpoint not found, stopping pipeline.") return @@ -338,7 +335,7 @@ def test_gpt_oss_complete_pipeline(model_path, tmp_path): # Step 2: MXFP4 Conversion for LoRA model (depends on Step 1) mxfp4_checkpoint = gpt_oss.gpt_oss_mxfp4_conversion_lora( - current_dir, qat_lora_dir=qat_lora_checkpoint + tmp_path, qat_lora_dir=qat_lora_checkpoint ) if not mxfp4_checkpoint or not mxfp4_checkpoint.exists(): print("Step 2 failed: MXFP4 checkpoint not found, stopping pipeline.") @@ -347,5 +344,5 @@ def test_gpt_oss_complete_pipeline(model_path, tmp_path): # Step 3: Deploy with TensorRT-LLM (depends on Step 2) print("Step 3: Running deployment with MXFP4 checkpoint...") - gpt_oss.deploy_gpt_oss_trtllm(current_dir, model_path_override=mxfp4_checkpoint) + gpt_oss.deploy_gpt_oss_trtllm(tmp_path, model_path_override=mxfp4_checkpoint) print("Step 3 completed: Deployment successful") diff --git a/tests/examples/llm_ptq/test_deploy.py b/tests/examples/llm_ptq/test_deploy.py index 868304f48a..4dd98ad9d0 100644 --- a/tests/examples/llm_ptq/test_deploy.py +++ b/tests/examples/llm_ptq/test_deploy.py @@ -60,31 +60,43 @@ def cleanup_after_test(): "command", [ *ModelDeployerList( - model_id="nvidia/DeepSeek-R1-FP4", + model_id="nvidia/DeepSeek-R1-NVFP4", backend=("vllm", "trtllm", "sglang"), tensor_parallel_size=8, mini_sm=100, ), *ModelDeployerList( - model_id="nvidia/DeepSeek-R1-FP4-v2", + model_id="nvidia/DeepSeek-R1-NVFP4-v2", backend=("vllm", "trtllm", "sglang"), tensor_parallel_size=8, mini_sm=100, ), *ModelDeployerList( - model_id="nvidia/DeepSeek-R1-0528-FP4", + model_id="nvidia/DeepSeek-R1-0528-NVFP4", backend=("vllm", "trtllm", "sglang"), tensor_parallel_size=8, mini_sm=100, ), *ModelDeployerList( - model_id="nvidia/DeepSeek-R1-0528-FP4-v2", + model_id="nvidia/DeepSeek-R1-0528-NVFP4-v2", backend=("vllm", "trtllm", "sglang"), tensor_parallel_size=8, mini_sm=100, ), *ModelDeployerList( - model_id="nvidia/DeepSeek-V3-0324-FP4", + model_id="nvidia/DeepSeek-V3-0324-NVFP4", + backend=("vllm", "trtllm", "sglang"), + tensor_parallel_size=8, + mini_sm=100, + ), + *ModelDeployerList( + model_id="nvidia/DeepSeek-V3.1-NVFP4", + backend=("vllm", "trtllm", "sglang"), + tensor_parallel_size=8, + mini_sm=100, + ), + *ModelDeployerList( + model_id="nvidia/DeepSeek-V3.2-NVFP4", backend=("vllm", "trtllm", "sglang"), tensor_parallel_size=8, mini_sm=100, @@ -107,7 +119,7 @@ def test_deepseek(command): mini_sm=89, ), *ModelDeployerList( - model_id="nvidia/Llama-3.1-8B-Instruct-FP4", + model_id="nvidia/Llama-3.1-8B-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -119,7 +131,7 @@ def test_deepseek(command): tensor_parallel_size=4, ), *ModelDeployerList( - model_id="nvidia/Llama-3.3-70B-Instruct-FP4", + model_id="nvidia/Llama-3.3-70B-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=4, mini_sm=100, @@ -136,7 +148,7 @@ def test_deepseek(command): tensor_parallel_size=8, ), *ModelDeployerList( - model_id="nvidia/Llama-3.1-405B-Instruct-FP4", + model_id="nvidia/Llama-3.1-405B-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=8, mini_sm=100, @@ -148,7 +160,7 @@ def test_deepseek(command): tensor_parallel_size=8, ), *ModelDeployerList( - model_id="nvidia/Llama-4-Maverick-17B-128E-Instruct-FP4", + model_id="nvidia/Llama-4-Maverick-17B-128E-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=8, mini_sm=100, @@ -160,7 +172,7 @@ def test_deepseek(command): mini_sm=89, ), *ModelDeployerList( - model_id="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + model_id="nvidia/Llama-4-Scout-17B-16E-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=8, mini_sm=100, @@ -176,7 +188,7 @@ def test_llama(command): "command", [ *ModelDeployerList( - model_id="nvidia/Qwen3-8B-FP4", + model_id="nvidia/Qwen3-8B-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -188,7 +200,7 @@ def test_llama(command): mini_sm=89, ), *ModelDeployerList( - model_id="nvidia/Qwen3-14B-FP4", + model_id="nvidia/Qwen3-14B-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -200,7 +212,7 @@ def test_llama(command): mini_sm=89, ), *ModelDeployerList( - model_id="nvidia/Qwen3-235B-A22B-FP4", + model_id="nvidia/Qwen3-235B-A22B-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=2, mini_sm=100, @@ -212,16 +224,16 @@ def test_llama(command): mini_sm=89, ), *ModelDeployerList( - model_id="nvidia/QwQ-32B-FP4", backend=("trtllm", "vllm", "sglang"), mini_sm=100 + model_id="nvidia/QwQ-32B-NVFP4", backend=("trtllm", "vllm", "sglang"), mini_sm=100 ), *ModelDeployerList( - model_id="nvidia/Qwen3-32B-FP4", + model_id="nvidia/Qwen3-32B-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=4, mini_sm=100, ), *ModelDeployerList( - model_id="nvidia/Qwen2.5-VL-7B-Instruct-FP4", + model_id="nvidia/Qwen2.5-VL-7B-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=4, mini_sm=100, @@ -233,11 +245,23 @@ def test_llama(command): mini_sm=100, ), *ModelDeployerList( - model_id="nvidia/Qwen3-30B-A3B-FP4", + model_id="nvidia/Qwen3-30B-A3B-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=4, mini_sm=100, ), + *ModelDeployerList( + model_id="nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4", + backend=("trtllm", "vllm", "sglang"), + tensor_parallel_size=8, + mini_sm=100, + ), + *ModelDeployerList( + model_id="nvidia/Qwen3-Next-80B-A3B-Thinking-NVFP4", + backend=("trtllm", "vllm", "sglang"), + tensor_parallel_size=8, + mini_sm=100, + ), ], ids=idfn, ) @@ -252,11 +276,10 @@ def test_qwen(command): model_id="nvidia/Mixtral-8x7B-Instruct-v0.1-FP8", backend=("trtllm", "vllm", "sglang") ), *ModelDeployerList( - model_id="nvidia/Mixtral-8x7B-Instruct-v0.1-FP4", + model_id="nvidia/Mixtral-8x7B-Instruct-v0.1-NVFP4", backend=("trtllm", "vllm", "sglang"), mini_sm=100, ), - # ModelDeployer(model_id="nvidia/Mixtral-8x7B-Instruct-v0.1-FP8", backend="sglang"), unsupported ], ids=idfn, ) @@ -266,9 +289,9 @@ def test_mixtral(command): @pytest.mark.parametrize( "command", - [ # TRTLLM bug: https://nvbugs/5451286 + [ *ModelDeployerList( - model_id="nvidia/gemma-3-12b-it-FP4", + model_id="nvidia/gemma-3-12b-it-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -282,7 +305,7 @@ def test_mixtral(command): attn_backend="FLASHINFER", ), *ModelDeployerList( - model_id="nvidia/gemma-3-27b-it-FP4", + model_id="nvidia/gemma-3-27b-it-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -307,7 +330,7 @@ def test_gemma(command): "command", [ *ModelDeployerList( - model_id="nvidia/Phi-4-multimodal-instruct-FP4", + model_id="nvidia/Phi-4-multimodal-instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -319,7 +342,7 @@ def test_gemma(command): mini_sm=89, ), *ModelDeployerList( - model_id="nvidia/Phi-4-reasoning-plus-FP4", + model_id="nvidia/Phi-4-reasoning-plus-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=1, mini_sm=100, @@ -341,7 +364,7 @@ def test_phi(command): "command", [ *ModelDeployerList( - model_id="nvidia/Kimi-K2-Instruct-FP4", + model_id="nvidia/Kimi-K2-Instruct-NVFP4", backend=("trtllm", "vllm", "sglang"), tensor_parallel_size=8, mini_sm=100, @@ -374,12 +397,6 @@ def test_kimi(command): tensor_parallel_size=1, mini_sm=89, ), - *ModelDeployerList( - model_id="nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8", - backend=("trtllm", "vllm", "sglang"), - tensor_parallel_size=4, - mini_sm=89, - ), *ModelDeployerList( model_id="nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8", backend=("vllm",), @@ -393,6 +410,13 @@ def test_kimi(command): mini_sm=89, attn_backend="FLASHINFER", ), + *ModelDeployerList( + model_id="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4", + backend=("trtllm", "vllm", "sglang"), + tensor_parallel_size=1, + mini_sm=89, + attn_backend="FLASHINFER", + ), ], ids=idfn, ) @@ -454,6 +478,14 @@ def test_medusa(command): mini_sm=89, eagle3_one_model=False, ), + *ModelDeployerList( + base_model="Qwen/Qwen3-235B-A22B-Thinking-2507", + model_id="nvidia/Qwen3-235B-A22B-Thinking-2507-FP4-Eagle3", + backend=("trtllm", "sglang"), + tensor_parallel_size=8, + mini_sm=89, + eagle3_one_model=False, + ), *ModelDeployerList( base_model="Qwen/Qwen3-30B-A3B", model_id="nvidia/Qwen3-30B-A3B-Eagle3", diff --git a/tests/examples/llm_ptq/test_llm_ptq.py b/tests/examples/llm_ptq/test_llm_ptq.py index 6ba23cc04e..4fc39f5ecb 100644 --- a/tests/examples/llm_ptq/test_llm_ptq.py +++ b/tests/examples/llm_ptq/test_llm_ptq.py @@ -114,6 +114,7 @@ def test_ptq_whisper(self, command): # sm89 PTQCommand(quant="fp8", min_sm=89), PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89), # sm100 + PTQCommand(quant="mxfp8", min_sm=100), PTQCommand(quant="nvfp4", min_sm=100), # # multi_gpu diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py index b70dfab35d..9f1cb81253 100644 --- a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -34,7 +34,6 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", } ) kwargs.setdefault("seq_len", 128) - kwargs.setdefault("num_samples", 1) kwargs.setdefault("max_new_tokens", 16) cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs) @@ -43,8 +42,10 @@ def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", @pytest.mark.parametrize("method", ["skip_softmax"]) def test_attention_sparsity(tiny_llama_path, tmp_path, method): - """Test sparse attention with TinyLlama.""" + """Test sparse attention with TinyLlama (with and without calibration).""" run_attention_sparsity_command( model=tiny_llama_path, method=method, + seq_len=128, + max_new_tokens=10, ) diff --git a/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py new file mode 100644 index 0000000000..73b02707c6 --- /dev/null +++ b/tests/examples/puzzletron/mbridge_distillation/test_distill_hf.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for distill_hf.py script.""" + +from pathlib import Path + +from _test_utils.examples.run_command import run_example_command +from _test_utils.torch.distributed.utils import get_free_port +from _test_utils.torch.puzzletron.utils import create_and_save_small_hf_model, create_tokenizer +from transformers import AutoModelForCausalLM + +from modelopt.torch.puzzletron.anymodel import convert_model + + +def test_distill_hf(project_root_path: Path, tmp_path: Path): + """Integration test for distill_hf.py. + + Creates Llama models programmatically, converts them to heterogeneous format (AnyModel), + and runs mbridge distillation. The models are created with reduced size for faster testing. + Models are converted to include block_configs. + """ + # Prepare student and teacher models + student_hf_path, teacher_hf_path = _prepare_student_and_teacher_models( + project_root_path, tmp_path + ) + + # Prepare output directory + output_dir = tmp_path / "distill_output" + output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare HF export directory + hf_export_dir = tmp_path / "hf_export" + hf_export_dir.mkdir(parents=True, exist_ok=True) + + # Build command-line arguments for distill_hf.py + # Use torchrun for distributed execution (single GPU for testing) + nproc_per_node = 1 + tp_size = 1 + train_iters = 5 + + cmd_parts = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--master-addr", + "127.0.0.1", # Explicitly set master address + "--master-port", + str(get_free_port()), # Pass port directly to torchrun to avoid conflicts + "distill_hf.py", + "--student_hf_path", + student_hf_path, + "--teacher_hf_path", + teacher_hf_path, + "--output_dir", + str(output_dir), + "--hf-export-path", + str(hf_export_dir), # Export to HuggingFace format + "--hf-model", + "meta-llama/Llama-3.1-8B-Instruct", # Note: uses hyphen, not underscore + "--tp_size", + str(tp_size), + "--pp_size", + "1", + "--seq_length", + "128", # Reduced for faster forward/backward passes + "--use_mock_data", # Use mock data to avoid disk I/O overhead + "--split", + "99,1,0", + "--mbs", + "1", + "--gbs", + "4", # Global batch size + "--train_iters", + str(train_iters), # Minimal iterations for smoke test + "--lr", + "0.0001", + "--min_lr", + "1e-5", + "--lr_warmup_iters", + "2", # Reduced warmup iterations + "--eval_interval", + "100", # Disable evaluation (set to > train_iters) + "--eval_iters", + "0", # No evaluation iterations + "--log_interval", + "5", # Reduced logging frequency + ] + + run_example_command( + cmd_parts, + example_path="puzzletron/mbridge_distillation", + ) + + # Check that distillation checkpoint contains run_config.yaml + run_config_path = output_dir / "checkpoints" / f"iter_{train_iters:07d}" / "run_config.yaml" + assert run_config_path.exists(), f"Expected run_config.yaml to exist at: {run_config_path}" + + # Verify that the distilled model can be loaded in HuggingFace format + model = AutoModelForCausalLM.from_pretrained( + str(hf_export_dir), + local_files_only=True, + trust_remote_code=True, + ) + assert model is not None, "Failed to load distilled model with AutoModelForCausalLM" + + print( + f"PYTEST SUMMARY: test_distill_hf test has finished successfully. " + f"Output directory: {output_dir}, HF export: {hf_export_dir}" + ) + + +def _prepare_student_and_teacher_models(project_root_path: Path, tmp_path: Path) -> tuple[str, str]: + """Prepare student and teacher models for distillation. + + Creates Llama models programmatically, converts them to heterogeneous format (AnyModel), + and returns the paths to the converted checkpoints. + + Args: + project_root_path: Path to the project root directory + tmp_path: Temporary directory for test artifacts + + Returns: + Tuple of (student_hf_path, teacher_hf_path) as strings + """ + + # Create temporary directories for models + student_hf_dir = tmp_path / "student_hf" + teacher_hf_dir = tmp_path / "teacher_hf" + + # Create tokenizer (uses local tokenizer from test resources) + tokenizer = create_tokenizer(project_root_path) + + # Create student model using utility function + # This uses local config files and preserves model-specific settings + # TODO: Make the student model using different ffn sizes across layers. + create_and_save_small_hf_model( + output_path=str(student_hf_dir), + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + hf_config_name="llama_3_1_8b_instruct", + hybrid_override_pattern=None, + ) + + # Create teacher model (same as student for testing) + create_and_save_small_hf_model( + output_path=str(teacher_hf_dir), + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + hf_config_name="llama_3_1_8b_instruct", + hybrid_override_pattern=None, + ) + + # Convert models to AnyModel format BEFORE distillation + # This is needed as converted checkpoints will be used as input for distillation later + student_anymodel_dir = tmp_path / "student_anymodel" + teacher_anymodel_dir = tmp_path / "teacher_anymodel" + + convert_model( + input_dir=str(student_hf_dir), + output_dir=str(student_anymodel_dir), + converter="llama", + ) + + convert_model( + input_dir=str(teacher_hf_dir), + output_dir=str(teacher_anymodel_dir), + converter="llama", + ) + print("Models converted to AnyModel format:") + print(f" Student AnyModel: {student_anymodel_dir}") + print(f" Teacher AnyModel: {teacher_anymodel_dir}") + + return student_anymodel_dir, teacher_anymodel_dir diff --git a/tests/examples/speculative_decoding/conftest.py b/tests/examples/speculative_decoding/conftest.py index bc75b87836..80417f4048 100644 --- a/tests/examples/speculative_decoding/conftest.py +++ b/tests/examples/speculative_decoding/conftest.py @@ -21,18 +21,20 @@ @pytest.fixture(scope="session", autouse=True) def tiny_daring_anteater_path(tmp_path_factory): - dataset_path = MODELOPT_ROOT / "examples/speculative_decoding/Daring-Anteater" + dataset_path = ( + MODELOPT_ROOT / "examples/speculative_decoding/input_conversations/daring-anteater.jsonl" + ) if not os.path.exists(dataset_path): try: run_example_command( - ["git", "clone", "https://huggingface.co/datasets/nvidia/Daring-Anteater"], + ["python", "prepare_input_conversations/add_daring_anteater.py"], "speculative_decoding", ) except Exception as e: # Ignore rate-limiting errors - pytest.skip(f"Failed to clone Daring-Anteater dataset: {e}") + pytest.skip(f"Failed to prepare dataset: {e}") output_path = tmp_path_factory.mktemp("daring_anteater") / "train.jsonl" - with open(dataset_path / "train.jsonl") as src, open(output_path, "w") as dst: + with open(dataset_path) as src, open(output_path, "w") as dst: for i, line in enumerate(src): if i >= 128: break diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 6bf1c79d2b..4f80692ca8 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -14,10 +14,13 @@ # limitations under the License. import json +import os import pytest import safetensors.torch +import torch from _test_utils.examples.run_command import run_example_command +from packaging.version import Version from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER @@ -28,9 +31,44 @@ def eagle_output_dir(tmp_path_factory): return tmp_path_factory.mktemp("eagle_output_dir") +@pytest.fixture(scope="module") +def draft_vocab_cache_dir(tmp_path_factory): + """Eagle output directory shared in this module.""" + return tmp_path_factory.mktemp("eagle_output_dir") + + +def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft_vocab_cache_dir): + """Test calibration of draft vocabulary.""" + run_example_command( + [ + "python", + "./scripts/calibrate_draft_vocab.py", + "--model", + tiny_llama_path, + "--data", + tiny_daring_anteater_path, + "--draft_vocab_size", + "100", + "--save_dir", + draft_vocab_cache_dir, + ], + "speculative_decoding", + ) + + model_name = os.path.basename(os.path.normpath(tiny_llama_path)) + d2t = torch.load(os.path.join(draft_vocab_cache_dir, model_name, "d2t.pt")) + assert d2t.shape[0] == 100, f"Expected draft vocab size 100, got {d2t.shape[0]}" + + # fmt: off -def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path, eagle_output_dir): - """Test Eagle3 training with a tiny llama model.""" +@pytest.mark.parametrize("cp_size", [1, 2]) +def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, cp_size): + """Test Eagle3 training with a tiny llama model, using different cp_size values.""" + available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + if cp_size == 2 and available_gpus < 2: + pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus)) + if cp_size == 2 and not Version(torch.__version__) >= Version("2.10.0"): + pytest.skip("cp_size=2 requires torch 2.10.0") # Create an ultra-tiny EAGLE config for testing to reduce memory usage tiny_eagle_config = { "max_position_embeddings": 128, @@ -42,7 +80,7 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_ } # Write the tiny config to a temporary file - config_file = tmp_path / "tiny_eagle_config.json" + config_file = tmp_path / f"tiny_eagle_config_cp{cp_size}.json" with open(config_file, "w") as f: json.dump(tiny_eagle_config, f) @@ -51,12 +89,29 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_ "./launch_train.sh", "--model", tiny_llama_path, "--data", tiny_daring_anteater_path, - "--num_epochs", "1", + "--num_epochs", "0.25", "--lr", "1e-5", - "--num_gpu", str(num_gpus), "--mode", "eagle3", "--eagle_config", str(config_file), - "--output_dir", eagle_output_dir / "eagle-tinyllama", + "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}", + "--training_seq_len", "128", # Match max_position_embeddings + "--cp_size", str(cp_size), + ], + "speculative_decoding", + ) + + +def test_resume_training(tiny_daring_anteater_path, eagle_output_dir): + """Test resume training of Eagle3.""" + run_example_command( + [ + "./launch_train.sh", + "--model", eagle_output_dir / "eagle-tinyllama-cp1", + "--data", tiny_daring_anteater_path, + "--num_epochs", "0.5", + "--lr", "1e-5", + "--mode", "eagle3", + "--output_dir", eagle_output_dir / "eagle-tinyllama-cp1", "--training_seq_len", "128", # Match max_position_embeddings ], "speculative_decoding", @@ -68,9 +123,9 @@ def test_ar_validate(eagle_output_dir): run_example_command( [ "python", "./scripts/ar_validate.py", - "--model_path", eagle_output_dir / "eagle-tinyllama", - "--osl", "20", - "--num_samples", "10", + "--model_path", eagle_output_dir / "eagle-tinyllama-cp1", + "--osl", "10", + "--num_samples", "5", "--steps", "3" ], "speculative_decoding", @@ -82,7 +137,7 @@ def test_export_hf_checkpoint(eagle_output_dir): run_example_command( [ "python", "./scripts/export_hf_checkpoint.py", - "--model_path", eagle_output_dir / "eagle-tinyllama", + "--model_path", eagle_output_dir / "eagle-tinyllama-cp1", "--export_path", eagle_output_dir / "eagle-tinyllama-export", ], "speculative_decoding", @@ -104,17 +159,3 @@ def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir): ], "speculative_decoding", ) - -@pytest.mark.skip(reason="Needs dataset conversion to role-content format; consolidate data loading first.") -def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path,tmp_path): - """Test calibration of draft vocabulary.""" - run_example_command( - [ - "python", "./scripts/calibrate_draft_vocab.py", - "--model", tiny_llama_path, - "--data", tiny_daring_anteater_path, - "--draft_vocab_size", "100", - "--save_dir", tmp_path / "draft_vocab_cache", - ], - "speculative_decoding", - ) diff --git a/tests/examples/speculative_decoding/test_medusa.py b/tests/examples/speculative_decoding/test_medusa.py index 488e24855b..545b79d7ea 100644 --- a/tests/examples/speculative_decoding/test_medusa.py +++ b/tests/examples/speculative_decoding/test_medusa.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import pytest from _test_utils.examples.run_command import run_example_command @@ -32,7 +32,7 @@ def _run_hf_ptq(model_path, output_dir, qformat): ) -def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path): +def test_llama_medusa_fp8_qat(tiny_llama_path, tiny_daring_anteater_path, tmp_path): medusa_path = tmp_path / "medusa-tinyllama" # Test Medusa @@ -43,7 +43,6 @@ def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_pa "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--num_gpu", str(num_gpus), "--mode", "medusa", "--output_dir", medusa_path, "--medusa_num_heads", "2", @@ -52,6 +51,8 @@ def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_pa "speculative_decoding", ) + pytest.skip("speculative decoding uses transformers 5.x, quantization example uses transformers 4.x") + # Test PTQ on Medusa _run_hf_ptq(medusa_path, tmp_path / "medusa-tinyllama-hf", "fp8") diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index cd4e34ca1d..a38322d141 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path - import pytest import torch import torch.distributed as dist @@ -59,9 +57,3 @@ def set_torch_dtype(request): @pytest.fixture(scope="session", autouse=True) def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() - - -@pytest.fixture -def project_root_path(request: pytest.FixtureRequest) -> Path: - """Fixture providing the project root path for tests.""" - return Path(request.config.rootpath) diff --git a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py index a6f3608727..23a1439e94 100644 --- a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py +++ b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py @@ -29,15 +29,17 @@ "fuse_input_scale", "fuse_weight_scale", "fuse_weight_scale_2", - "fuse_prequant_scale", + "fuse_pre_quant_scale", + "fuse_svdquant_lora_a", ), [ - ("fp8", "tiny_llama-fp8", True, False, True, True), - ("nvfp4", "tiny_llama-nvfp4", True, False, True, True), - ("nvfp4_awq", "tiny_llama-nvfp4-awq", True, False, True, True), - ("int4_awq", "tiny_llama-int4-awq", True, False, True, True), - ("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True), - ("int8_wo", "tiny_llama-int8-wo", False, False, False, False), + ("fp8", "tiny_llama-fp8", True, False, True, True, False), + ("nvfp4", "tiny_llama-nvfp4", True, False, True, True, False), + ("nvfp4_awq", "tiny_llama-nvfp4-awq", True, False, True, True, False), + ("int4_awq", "tiny_llama-int4-awq", True, False, True, True, False), + ("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True, False), + ("int8_wo", "tiny_llama-int8-wo", False, False, False, False, False), + ("nvfp4_svdquant", "tiny_llama-nvfp4-svdquant", True, False, True, True, True), ], ) def test_unified_hf_export_and_check_safetensors( @@ -47,7 +49,8 @@ def test_unified_hf_export_and_check_safetensors( fuse_input_scale, fuse_weight_scale, fuse_weight_scale_2, - fuse_prequant_scale, + fuse_pre_quant_scale, + fuse_svdquant_lora_a, ): """ 1) Generates a .safetensors file by running hf_ptq.py with each --qformat. @@ -92,6 +95,18 @@ def test_unified_hf_export_and_check_safetensors( f"Expected .safetensors file not found for qformat={qformat}: {generated_file}" ) + # Map scale types to their conditions + scale_types = [ + ("input_scale", fuse_input_scale), + ("weight_scale", fuse_weight_scale), + ("weight_scale_2", fuse_weight_scale_2), + ("pre_quant_scale", fuse_pre_quant_scale), + ("weight_quantizer._svdquant_lora_a", fuse_svdquant_lora_a), + ] + + # Projection pairs to check for equality + proj_pairs = [("gate_proj", "up_proj"), ("q_proj", "k_proj"), ("q_proj", "v_proj")] + def _same_scale(name, key1, key2, f): if key1 in name: tensor1 = f.get_tensor(name) @@ -108,23 +123,11 @@ def _same_scale(name, key1, key2, f): assert tensor.shape is not None, f"Tensor '{name}' shape is None!" assert tensor.dtype is not None, f"Tensor '{name}' dtype is None!" - if "scale" in name: - # Map scale types to their conditions - scale_types = [ - ("input_scale", fuse_input_scale), - ("weight_scale", fuse_weight_scale), - ("weight_scale_2", fuse_weight_scale_2), - ("prequant_scale", fuse_prequant_scale), - ] - - # Projection pairs to check for equality - proj_pairs = [("gate_proj", "up_proj"), ("q_proj", "k_proj"), ("q_proj", "v_proj")] - - # Check each scale type if its condition is met - for scale_suffix, condition in scale_types: - if name.endswith(scale_suffix) and condition: - # Check each projection pair - for proj1, proj2 in proj_pairs: - _same_scale(name, proj1, proj2, f) + # Check each scale type if its condition is met + for scale_suffix, condition in scale_types: + if name.endswith(scale_suffix) and condition: + # Check each projection pair + for proj1, proj2 in proj_pairs: + _same_scale(name, proj1, proj2, f) # TODO: Load a pre-dumped log to compare textually or use pre-defined dict for sanity checks diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py deleted file mode 100644 index d6fa9400ba..0000000000 --- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from functools import partial - -import torch -from _test_utils.import_helper import skip_if_no_megatron - -skip_if_no_megatron(apex_or_te_required=True, mamba_required=True) - -from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model -from _test_utils.torch.megatron.utils import ( - run_mcore_inference, - run_mcore_inference_with_dummy_input, -) -from _test_utils.torch.misc import compare_outputs, set_seed -from _test_utils.torch.nas_prune.minitron_common import prune_minitron -from megatron.core.ssm.mamba_layer import MambaLayer -from megatron.core.transformer.identity_op import IdentityOp - -import modelopt.torch.nas as mtn -from modelopt.torch.prune.plugins.mcore_minitron import ( - ImportanceEstimatorRegistry, - _convert_model_to_dynamic_space, - get_mcore_minitron_config, -) - -SEED = 1234 - - -def _test_mcore_mamba_parameter_sorting(rank, size): - # Use relatively bigger model here for more accurate test for sorting - channel_divisor = 64 - - num_layers = size - hybrid_override_pattern = "M" * size - hidden_size = channel_divisor * 4 - mamba_state_dim = channel_divisor - mamba_head_dim = 16 - mamba_num_groups = 2 - max_sequence_length = 32 - vocab_size = 64 - batch_size = 2 - - model = get_mcore_mamba_hybrid_model( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=size, - initialize_megatron=True, - num_layers=num_layers, - hybrid_override_pattern=hybrid_override_pattern, - hidden_size=hidden_size, - mamba_state_dim=mamba_state_dim, - mamba_head_dim=mamba_head_dim, - mamba_num_groups=mamba_num_groups, - max_sequence_length=max_sequence_length, - vocab_size=vocab_size, - bf16=False, - ).cuda() - - # Randomize norm weights instead of all zeros or ones - for n, m in model.named_modules(): - if "norm" in n and not isinstance(m, IdentityOp): - m.weight.data = torch.randn_like(m.weight) - - model.eval() - dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) - ) - registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks - - # Compute activations for sorting - for _ in range(5): - run_mcore_inference_with_dummy_input(model, batch_size) - - # Get the output of the original model - prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() - y1 = run_mcore_inference(model, prompt_tokens) - - mtn.utils.sort_parameters(model) - registry.cleanup() - - # check if all mamba_num_heads, mamba_head_dim, hidden_size have been sorted - sortable_per_pp = [ - n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None - ] - # 2 mamba hps per layer + 1 for hidden_size (num_layers is not sorted!) - assert len(sortable_per_pp) == 2 * num_layers // size + 1 - - # sanity check if the model functionality is preserved after sorting - y2 = run_mcore_inference(model, prompt_tokens) - - # check if the inference results after sorting is the same - compare_outputs(y1, y2, rtol=1e-5, atol=1e-3) - - -def test_mcore_mamba_parameter_sorting(): - set_seed(SEED) - spawn_multiprocess_job( - size=torch.cuda.device_count(), - job=_test_mcore_mamba_parameter_sorting, - backend="nccl", - ) - - -def _test_mcore_mamba_hybrid_pruning(ckpt_path, rank, size): - channel_divisor = 4 - - num_layers = min(size * 2, 8) - hidden_size = channel_divisor * 8 - ffn_hidden_size = channel_divisor * 2 - num_attention_heads = 8 - num_query_groups = 4 - mamba_state_dim = channel_divisor * 2 - mamba_head_dim = channel_divisor * 2 - mamba_num_groups = 2 - num_moe_experts = 8 - vocab_size = 32 - batch_size = 2 - - def _get_model(initialize_megatron=True): - model = get_mcore_mamba_hybrid_model( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=size, - initialize_megatron=initialize_megatron, - num_layers=num_layers, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - ffn_hidden_size=ffn_hidden_size, - mamba_state_dim=mamba_state_dim, - mamba_head_dim=mamba_head_dim, - mamba_num_groups=mamba_num_groups, - moe_ffn_hidden_size=ffn_hidden_size, - moe_shared_expert_intermediate_size=ffn_hidden_size, - num_moe_experts=num_moe_experts, - vocab_size=vocab_size, - ).cuda() - return model - - model = _get_model() - - mamba_layer = None - for layer in model.decoder.layers: - if isinstance(layer, MambaLayer): - mamba_layer = layer - break - assert mamba_layer is not None, f"No MambaLayer found in the model PP rank {rank}!" - mamba_num_heads = mamba_layer.mixer.nheads - - def forward_loop(m): - for _ in range(5): - run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) - - # Traditional GPT pruning parameters - pruned_ffn_hidden_size = ffn_hidden_size // 2 - pruned_num_attention_heads = num_attention_heads // 2 - pruned_hidden_size = hidden_size // 2 - pruned_num_moe_experts = num_moe_experts // 2 - - # Mamba-specific pruning parameters - pruned_mamba_num_heads = mamba_num_heads // 2 - pruned_mamba_head_dim = mamba_head_dim // 2 - - # Base export config with GPT/Attention parameters - export_config = { - "ffn_hidden_size": pruned_ffn_hidden_size, - "num_attention_heads": pruned_num_attention_heads, - "hidden_size": pruned_hidden_size, - "mamba_num_heads": pruned_mamba_num_heads, - "mamba_head_dim": pruned_mamba_head_dim, - "moe_ffn_hidden_size": pruned_ffn_hidden_size, - "moe_shared_expert_intermediate_size": pruned_ffn_hidden_size, - "num_moe_experts": pruned_num_moe_experts, - } - prune_minitron( - model, - export_config, - {"forward_loop": forward_loop, "scores_path": ckpt_path}, - channel_divisor, - ) - - # Assert weights are pruned correctly - mixer = mamba_layer.mixer - bc = 2 * mixer.ngroups * mixer.d_state - assert mixer.nheads == pruned_mamba_num_heads - assert mixer.headdim == pruned_mamba_head_dim - assert mixer.in_proj.input_size == pruned_hidden_size - assert mixer.d_inner == pruned_mamba_num_heads * pruned_mamba_head_dim - assert mixer.in_proj.output_size == 2 * mixer.d_inner + bc + pruned_mamba_num_heads - assert mixer.out_proj.input_size == mixer.d_inner - assert mixer.out_proj.output_size == pruned_hidden_size - assert mixer.conv1d.in_channels == mixer.conv1d.out_channels == mixer.d_inner + bc - - # Assert model.config is updated for correct save/restoring - assert model.config.ffn_hidden_size == pruned_ffn_hidden_size - assert model.config.num_attention_heads == pruned_num_attention_heads - assert model.config.hidden_size == pruned_hidden_size - assert model.config.mamba_num_heads == pruned_mamba_num_heads - assert model.config.mamba_head_dim == pruned_mamba_head_dim - assert model.config.moe_ffn_hidden_size == pruned_ffn_hidden_size - assert model.config.moe_shared_expert_intermediate_size == pruned_ffn_hidden_size - assert model.config.num_moe_experts == pruned_num_moe_experts - - # Assert forward pass works on the pruned model - run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size) - - # Assert re-pruning from scores_path works without running the forward loop again - model = _get_model(initialize_megatron=False) - prune_minitron(model, export_config, {"scores_path": ckpt_path}, channel_divisor) - - -def test_mcore_mamba_hybrid_pruning(tmp_path): - spawn_multiprocess_job( - size=torch.cuda.device_count(), - job=partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth"), - backend="nccl", - ) diff --git a/tests/gpu/torch/puzzletron/export/__init__.py b/tests/gpu/torch/puzzletron/export/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/tests/gpu/torch/puzzletron/export/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/gpu/torch/puzzletron/export/mbridge/__init__.py b/tests/gpu/torch/puzzletron/export/mbridge/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/tests/gpu/torch/puzzletron/export/mbridge/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py new file mode 100644 index 0000000000..0c60bcd007 --- /dev/null +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian +from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader + +RAND_SEED = 42 +torch.manual_seed(RAND_SEED) + + +def test_update_hessian(): + """Test for update_hessian function with both random and known inputs.""" + # Test 1: Random input - general functionality test + torch.manual_seed(42) + batch_size = 2 + seq_len = 3 + features = 4 + input_tensor = torch.randn(batch_size, seq_len, features, dtype=torch.float32) + + hessian = torch.zeros(features, features, dtype=torch.float32) + n_samples = 0 + + updated_hessian, new_n_samples = update_hessian(input_tensor, hessian, n_samples) + + # Verify output shape + assert updated_hessian.shape == (features, features), ( + f"Expected hessian shape ({features}, {features}), got {updated_hessian.shape}" + ) + + # Verify sample count is updated correctly (incremented by batch_size) + assert new_n_samples == batch_size, f"Expected n_samples={batch_size}, got {new_n_samples}" + + # Verify hessian is not all zeros after update + assert not torch.allclose(updated_hessian, torch.zeros_like(updated_hessian)), ( + "Hessian should not be all zeros after update" + ) + + # Verify hessian is symmetric (should be for outer product X @ X.T) + assert torch.allclose(updated_hessian, updated_hessian.t()), "Hessian should be symmetric" + + # Test 2: Known input - verify correct hessian calculation + batch_size = 6 + seq_len = 2 + features = 2 + input_tensor = torch.ones(batch_size, seq_len, features, dtype=torch.float32) + + hessian = torch.zeros(features, features, dtype=torch.float32) + n_samples = 0 + + updated_hessian, new_n_samples = update_hessian(input_tensor, hessian, n_samples) + + # Manual calculation: + # input_flat shape: (features, batch*seq) = (2, 12), all ones + # scaled_input = sqrt(2/6) * input_flat = sqrt(1/3) * ones(2, 12) + # outer_product = scaled_input @ scaled_input.t() = (2/6) * ones(2,12) @ ones(12,2) = [[4,4], [4,4]] + # Note: The scaling factor is (2/n_samples), so with n_samples=6 and 12 tokens: (2/6)*12 = 4 + expected_hessian = torch.ones(features, features, dtype=torch.float32) * 4.0 + + assert torch.allclose(updated_hessian, expected_hessian, atol=1e-5), ( + f"Expected hessian {expected_hessian}, got {updated_hessian}" + ) + assert new_n_samples == batch_size + + # Test 3: Accumulated hessians - verify equivalence + # Processing [6,2,2] in one step should equal processing [2,2,2] three times + seq_len = 2 + features = 2 + + # Process in 3 steps of batch_size=2 + hessian_accumulated = torch.zeros(features, features, dtype=torch.float32) + n_samples_accumulated = 0 + + for i in range(3): + input_batch = torch.ones(2, seq_len, features, dtype=torch.float32) + hessian_accumulated, n_samples_accumulated = update_hessian( + input_batch, hessian_accumulated, n_samples_accumulated + ) + + # Verify that accumulated result matches single-step result from Test 2 + assert torch.allclose(hessian_accumulated, updated_hessian, atol=1e-5), ( + f"Accumulated hessian should match single-step: expected {updated_hessian}, got {hessian_accumulated}" + ) + assert torch.allclose(hessian_accumulated, expected_hessian, atol=1e-5), ( + f"Accumulated hessian should match expected: expected {expected_hessian}, got {hessian_accumulated}" + ) + assert n_samples_accumulated == 6, f"Expected n_samples=6, got {n_samples_accumulated}" + + +@pytest.mark.parametrize( + ("block_size", "dim", "model_weight", "expect_weight_change"), + [ + (4, 16, torch.randn(16, 16).to("cuda"), True), # random weight + ( + 4, + 16, + torch.ones(16, 16).to("cuda"), + False, + ), # all same weight -> no quantization error -> no GPTQ update + ], +) +def test_gptq_updates(block_size, dim, model_weight, expect_weight_change): + model = torch.nn.Linear(dim, dim).to("cuda") + model.weight.data = model_weight + model.name = "linear" + original_weight = model_weight.clone() + input = torch.randn(2, 16, dim).to("cuda") + hessian = torch.zeros(dim, dim).to("cpu") + n_samples = 0 + quant_cfg = mtq.NVFP4_DEFAULT_CFG + + mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input)) + + # Get qdq weight + q_dq_weight = model.weight_quantizer(model.weight.data) + + # Restore original weight + model.weight.data = original_weight.clone() + + hessian, n_samples = update_hessian(input, hessian, n_samples) + + # Verify n_samples is update using hessian matrix + assert n_samples == input.shape[0], "n_samples should be equal to input.shape[0]" + + # Perform another forward pass to update hessian matrix + input_2 = torch.randn(3, 16, dim).to("cuda") + hessian, n_samples = update_hessian(input_2, hessian, n_samples) + assert n_samples == input.shape[0] + input_2.shape[0], ( + "n_samples should be equal to input.shape[0] + input_2.shape[0]" + ) + + hessian = hessian.to(input.device) + blockwise_weight_update(model, hessian, block_size, 0.1) + if expect_weight_change: + # Weight must change as GPTQ updates weights to adjust for quantization error + assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal" + else: + assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal" + + +@pytest.mark.parametrize( + "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG] +) +def test_gptq_e2e_flow(quant_cfg): + model = AutoModelForCausalLM.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True + ) + + # can't set attribute 'pad_token' for "" + # We skip this step for Nemo models + if tokenizer.pad_token != "" or tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Left padding usually provides better calibration result. + tokenizer.padding_side = "left" + + assert tokenizer.pad_token is not None, "Pad token cannot be set!" + model.eval() + + quant_cfg = copy.deepcopy(quant_cfg) + quant_cfg["algorithm"] = "gptq_lite" + # Define quantizer/dataloader + calib_dataloader = get_dataset_dataloader( + dataset_name="cnn_dailymail", + tokenizer=tokenizer, + batch_size=32, + num_samples=512, + device="cuda", + include_labels=False, + ) + # Only run single sample for preview + prompt = "Where is New York city?" + input_ids = tokenizer(prompt, return_tensors="pt") + print(f"Input ids: {input_ids}") + generated_ids_before_ptq = model.generate( + input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0 + ) + + print( + f"Generated ids before quantization: {tokenizer.decode(generated_ids_before_ptq[0], skip_special_tokens=True)}" + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + generated_ids_after_ptq = model.generate( + input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0 + ) + print( + f"Generated ids after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}" + ) diff --git a/tests/gpu/torch/quantization/test_hadamard.py b/tests/gpu/torch/quantization/test_hadamard.py index c768bc87e3..64dd39e2cf 100644 --- a/tests/gpu/torch/quantization/test_hadamard.py +++ b/tests/gpu/torch/quantization/test_hadamard.py @@ -41,9 +41,16 @@ def test_hadamard_transform(dim): xxt_h = x_h @ x_h.T # The numerical error can be large, especially for 16-bit floats. assert torch.allclose(xxt_h, xxt, atol=0.05) + x_h_fp32 = normalized_hadamard_transform(x, rotate_fp32=True) + xxt_h_fp32 = x_h_fp32 @ x_h_fp32.T + assert torch.allclose(xxt_h_fp32, xxt, atol=0.05) -def test_kv_rotate(): +@pytest.mark.parametrize( + "rotate_fp32", + [True, False], +) +def test_kv_rotate(rotate_fp32): mtq.plugins.register_attention_for_kv_quant(SDPAAttention) model = nn.Sequential(SDPAAttention()) mtq.replace_quant_module(model) @@ -51,11 +58,15 @@ def test_kv_rotate(): set_quantizer_by_cfg(model, {"*": {"enable": False}}) dummy_input = SDPAAttention.get_input(device="cuda") output_ref = model(dummy_input) + if rotate_fp32: + rotate = {"enable": True, "rotate_fp32": True} + else: + rotate = True with set_quantizer_by_cfg_context( model, { "*[qk]_bmm_quantizer": { - "rotate": True, + "rotate": rotate, }, }, ): @@ -67,7 +78,7 @@ def test_kv_rotate(): model, { "*k_bmm_quantizer": { - "rotate": True, + "rotate": rotate, }, }, ): diff --git a/tests/gpu/torch/quantization/test_nvfp4_onnx_export.py b/tests/gpu/torch/quantization/test_nvfp4_onnx_export.py index c5a7922097..8ffdf38293 100644 --- a/tests/gpu/torch/quantization/test_nvfp4_onnx_export.py +++ b/tests/gpu/torch/quantization/test_nvfp4_onnx_export.py @@ -98,6 +98,7 @@ def forward_loop(model, run_backward=False): output_names=["output"], export_params=True, opset_version=17, + dynamo=False, ) onnx_model = NVFP4QuantExporter.process_model(onnx.load(onnx_path)) diff --git a/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py new file mode 100644 index 0000000000..b1b3691a79 --- /dev/null +++ b/tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for NVFP4StaticQuantizer and NVFP4MSECalibrator.""" + +import pytest +import torch + +from modelopt.torch.quantization.calib import NVFP4MSECalibrator +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.tensor_quant import ( + scaled_e4m3_impl, + static_blockwise_fp4_fake_quant, +) + + +@pytest.mark.parametrize("device", ["cuda"]) +class TestNVFP4StaticQuantizer: + def test_from_tensor_quantizer(self, device): + """Test creating NVFP4StaticQuantizer from TensorQuantizer.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + tq = TensorQuantizer(quant_attribute_cfg=cfg).to(device) + tq.amax = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) + + nvfp4_quantizer = NVFP4StaticQuantizer.from_tensor_quantizer(tq) + + assert nvfp4_quantizer.global_amax is None + assert nvfp4_quantizer._num_bits == (2, 1) + assert torch.allclose(nvfp4_quantizer._amax, tq._amax) + + def test_global_amax_property(self, device): + """Test global_amax property getter/setter.""" + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: 16, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + + assert quantizer.global_amax is None + + quantizer.global_amax = torch.tensor(5.0, device=device) + assert quantizer.global_amax is not None + assert torch.isclose(quantizer.global_amax, torch.tensor(5.0, device=device)) + + quantizer.global_amax = 10.0 + assert torch.isclose(quantizer.global_amax, torch.tensor(10.0, device=device)) + + quantizer.global_amax = None + assert quantizer.global_amax is None + + def test_fake_quantize_with_both_amaxs(self, device): + """Test _fake_quantize uses both _amax and _global_amax.""" + num_blocks = 4 + block_size = 16 + + cfg = QuantizerAttributeConfig( + num_bits=(2, 1), + block_sizes={-1: block_size, "type": "static", "scale_bits": (4, 3)}, + ) + quantizer = NVFP4StaticQuantizer(quant_attribute_cfg=cfg).to(device) + + x = torch.randn(num_blocks, block_size, device=device) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + quantizer.amax = per_block_amax + quantizer.global_amax = global_amax + + output = quantizer._fake_quantize(x) + + expected = static_blockwise_fp4_fake_quant( + x, + per_block_amax, + global_amax, + ) + + assert torch.allclose(output, expected) + + +@pytest.mark.parametrize("device", ["cuda"]) +class TestNVFP4MSECalibrator: + def test_basic_initialization(self, device): + """Test NVFP4MSECalibrator initialization.""" + num_blocks = 4 + amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(10.0, device=device) + cal = NVFP4MSECalibrator(amax=amax, global_amax=global_amax) + + assert cal._losses_sum is None + assert cal._amax is None + + def test_fp8_candidates_generation(self, device): + """Test that 126 valid FP8 candidates are generated.""" + num_blocks = 4 + amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(10.0, device=device) + cal = NVFP4MSECalibrator(amax=amax, global_amax=global_amax) + + candidates = cal._generate_candidates(device) + + assert candidates.shape[0] == 126 + assert torch.all(torch.isfinite(candidates)) + assert torch.all(candidates > 0) + + def test_collect_and_compute_amax(self, device): + """Test collect and compute_amax workflow.""" + num_blocks = 8 + block_size = 16 + per_block_amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(10.0, device=device) + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + global_amax=global_amax, + quant_func=quant_func, + ) + + x = torch.randn(num_blocks, block_size, device=device) + cal.collect(x) + + assert cal._losses_sum is not None + assert len(cal._losses_sum) == 126 + + amax = cal.compute_amax() + + assert amax is not None + assert amax.shape[0] == num_blocks + assert torch.all(torch.isfinite(amax)) + assert torch.all(amax > 0) + + def test_multiple_collections(self, device): + """Test that multiple collections accumulate correctly.""" + num_blocks = 4 + block_size = 16 + per_block_amax = torch.ones(num_blocks, device=device) + global_amax = torch.tensor(5.0, device=device) + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + global_amax=global_amax, + quant_func=quant_func, + ) + + x1 = torch.randn(num_blocks, block_size, device=device) + x2 = torch.randn(num_blocks, block_size, device=device) + + cal.collect(x1) + losses_after_first = [loss.clone() for loss in cal._losses_sum] + + cal.collect(x2) + losses_after_second = cal._losses_sum + + for loss1, loss2 in zip(losses_after_first, losses_after_second): + assert torch.all(loss2 >= loss1 - 1e-6) + + def test_per_block_independent_optimization(self, device): + """Test that each block is optimized independently. + + Uses constant values per block to ensure deterministic behavior. + """ + num_blocks = 4 + block_size = 16 + + # Create blocks with constant values (all elements in a block are the same) + # This ensures deterministic behavior for the test + x = torch.zeros(num_blocks, block_size, device=device) + x[0, :] = 0.5 + x[1, :] = 2.0 + x[2, :] = 5.0 + x[3, :] = 10.0 + + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + axis=0, # reduce_axis = -1 + global_amax=global_amax, + quant_func=quant_func, + ) + + cal.collect(x) + amax = cal.compute_amax() + + # With constant values per block, the optimal amax should scale with the block values + assert amax[1] > amax[0] + assert amax[2] > amax[1] + assert amax[3] > amax[2] + + def test_fp8_sweep_generates_quantized_scales(self, device): + """Test that the fp8 sweep produces scales that are already FP8-quantized.""" + num_blocks = 8 + block_size = 16 + + x = torch.randn(num_blocks, block_size, device=device) + per_block_amax = x.abs().amax(dim=-1) + global_amax = per_block_amax.max() + + def quant_func(x, amax): + return static_blockwise_fp4_fake_quant(x, amax, global_amax) + + cal = NVFP4MSECalibrator( + amax=per_block_amax, + global_amax=global_amax, + quant_func=quant_func, + ) + + cal.collect(x) + amax = cal.compute_amax() + + # The calibrator sweeps over FP8 candidates, so the resulting scales + # should already be representable in FP8 (i.e., quantize-dequantize is a no-op). + scale = amax.float() / 6.0 + scale_fp8_quant_amax = global_amax.float() / 6.0 + scale_qdq = scaled_e4m3_impl(scale, scale_fp8_quant_amax) + assert torch.allclose(scale_qdq, scale) diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py index 26df7a8c82..08fac486f7 100644 --- a/tests/gpu/torch/quantization/test_qtensor_cuda.py +++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py @@ -15,6 +15,8 @@ """Unit tests for quantized tensors.""" +import math + import pytest import torch from _test_utils.torch.misc import set_seed @@ -22,7 +24,7 @@ from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.nn import TensorQuantizer -from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor set_seed() @@ -602,3 +604,388 @@ def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, b assert torch.allclose(deq_x, x, rtol=1e-1, atol=1e-1) assert hasattr(quantizer, "_scale") assert quantizer._scale.numel() > 1 + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "input_shape", + [ + (128, 128), + (256, 64), + (512, 512), + # 3D shapes (MoE): (num_experts, out_dim, in_dim) + (4, 64, 128), + (1, 64, 128), # single expert edge case + (32, 256, 512), # large-scale MoE + # Shapes requiring padding (last dim not divisible by block size 32) + (8, 128, 65), # odd in_dim + (128, 65), + (256, 100), + (64, 33), + ], + ) + def test_mxfp8_quantize_dequantize(self, device, input_dtype, input_shape): + """Test MXFP8 quantization and dequantization produces correct E8M0 scales.""" + # Create test tensor + test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + # Quantize using MXFP8QTensor + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + # Verify scale shape: last dim is ceil(in_dim / 32), other dims preserved + expected_scale_shape = ( + *input_shape[:-1], + math.ceil(input_shape[-1] / MXFP8QTensor.BLOCK_SIZE), + ) + assert e8m0_scale.shape == expected_scale_shape, ( + f"Expected scale shape {expected_scale_shape}, got {e8m0_scale.shape}" + ) + + # Verify quantized data is FP8 E4M3 and preserves original shape + assert qtensor._quantized_data.dtype == torch.float8_e4m3fn, ( + f"Expected float8_e4m3fn, got {qtensor._quantized_data.dtype}" + ) + assert qtensor._quantized_data.shape == input_shape, ( + f"Expected quantized data shape {input_shape}, got {qtensor._quantized_data.shape}" + ) + + # Dequantize + dequant_tensor = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + ) + + # Verify dequantized tensor shape and values match original + assert dequant_tensor.shape == input_shape, ( + f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}" + ) + assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: " + f"max diff = {(dequant_tensor - test_tensor).abs().max()}" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_e8m0_scale_values(self, device): + """Test that MXFP8 produces correct E8M0 scale values (power-of-2 only).""" + # Create a tensor with known amax values per block + # MXFP8 block size is 32, so create a 2x64 tensor (2 rows, 2 blocks per row) + test_tensor = torch.zeros((2, 64), dtype=torch.float32, device=device) + + # First block (row 0, elements 0-31): max abs = 1.0, should give exponent ~127-8 = 119 + # (since E4M3 max is 448, log2(1/448) ≈ -8.8, ceil = -8, biased = 127 + (-8) = 119) + test_tensor[0, :32] = 1.0 + + # Second block (row 0, elements 32-63): max abs = 448.0, should give exponent = 127 + # (since 448/448 = 1, log2(1) = 0, biased = 127) + test_tensor[0, 32:64] = 448.0 + + # Third block (row 1, elements 0-31): max abs = 2.0 + test_tensor[1, :32] = 2.0 + + # Fourth block (row 1, elements 32-63): max abs = 0.5 + test_tensor[1, 32:64] = 0.5 + + # Quantize + qtensor, e8m0_scale = MXFP8QTensor.quantize(test_tensor) + + # Verify all scales are valid uint8 values + assert e8m0_scale.dtype == torch.uint8 + assert e8m0_scale.shape == (2, 2) + + # Verify dequantization works + dequant = qtensor.dequantize( + dtype=torch.float32, + scale=e8m0_scale, + ) + + # Check that the dequantized max values per block are close to original + assert torch.allclose(dequant[0, :32].max(), torch.tensor(1.0, device=device), rtol=0.1) + assert torch.allclose(dequant[0, 32:64].max(), torch.tensor(448.0, device=device), rtol=0.1) + assert torch.allclose(dequant[1, :32].max(), torch.tensor(2.0, device=device), rtol=0.1) + assert torch.allclose(dequant[1, 32:64].max(), torch.tensor(0.5, device=device), rtol=0.1) + + # fmt: off + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "test_input", + [ + # FP8 E4M3 boundary test values (max is 448, various powers of 2) + torch.tensor([[1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 448.0, 0.5, 0.25, + 0.125, 0.0625, 0.03125, 0.015625, -1.0, -2.0, -4.0, -8.0, -16.0, -32.0, + -64.0, -128.0, -256.0, -448.0, -0.5, -0.25, -0.125, -0.0625, -0.03125, -0.015625]]), + # Mix of positive and negative values near E4M3 boundaries + torch.tensor([[448.0, 416.0, 384.0, 352.0, 320.0, 288.0, 256.0, 224.0, 192.0, 160.0, + 128.0, 96.0, 64.0, 48.0, 32.0, 24.0, -448.0, -416.0, -384.0, -352.0, -320.0, + -288.0, -256.0, -224.0, -192.0, -160.0, -128.0, -96.0, -64.0, -48.0, -32.0, -24.0]]), + ], + ) + def test_mxfp8_quantize_boundary_values(self, test_input, device, input_dtype): + # fmt: on + """Test MXFP8 quantization with E4M3 boundary values.""" + x = test_input.to(input_dtype).to(device) + qtensor, e8m0_scale = MXFP8QTensor.quantize(x) + + # Verify scale is uint8 (E8M0 format) + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + + dequant = qtensor.dequantize( + dtype=input_dtype, + scale=e8m0_scale, + ) + + # FP8 E4M3 has limited precision, allow reasonable tolerance + assert torch.allclose(dequant, x, rtol=5e-2, atol=5e-2), ( + f"Dequantized tensor differs from original: max diff = {(dequant - x).abs().max()}" + ) + + @pytest.mark.parametrize( + "input_shape", + [(1600, 1600)], + ) + def test_mxfp8_quantize_gpu_mem(self, input_shape): + """Test MXFP8 GPU memory usage during quantization.""" + + def _get_gpu_mem_used(): + device = torch.device("cuda:0") + free, total = torch.cuda.mem_get_info(device) + return total - free + + # Warmup + test_input = torch.rand((32, 32), dtype=torch.float32, device="cuda") + MXFP8QTensor.quantize(test_input) + + test_input = torch.rand(input_shape, dtype=torch.float32, device="cuda") + torch.cuda.empty_cache() + + input_size = test_input.element_size() * test_input.numel() + before_quantize = _get_gpu_mem_used() + MXFP8QTensor.quantize(test_input) + after_quantize = _get_gpu_mem_used() + + # Memory increase should be reasonable (less than 3x input size) + # MXFP8 stores FP8 data (1 byte) + uint8 scales, so should be efficient + assert (after_quantize - before_quantize) < input_size * 3, ( + f"Memory increase too large: {after_quantize - before_quantize} bytes " + f"for input size {input_size} bytes" + ) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize( + "input_shape", + [(128, 64), (256, 128), (512, 256)], + ) + def test_mxfp8_get_weights_scaling_factor(self, device, input_shape): + """Test MXFP8 get_weights_scaling_factor returns correct E8M0 scales.""" + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + # Get scaling factor + e8m0_scale = MXFP8QTensor.get_weights_scaling_factor(weight) + + # Verify dtype and shape + assert e8m0_scale.dtype == torch.uint8, f"Expected uint8 scale, got {e8m0_scale.dtype}" + expected_shape = (input_shape[0], input_shape[1] // MXFP8QTensor.BLOCK_SIZE) + assert e8m0_scale.shape == expected_shape, ( + f"Expected scale shape {expected_shape}, got {e8m0_scale.shape}" + ) + + # Verify E8M0 values are in valid range [0, 254] (biased exponent = unbiased + 127) + # The code clamps unbiased exponent to [-127, 127], giving biased range [0, 254] + # Note: 255 (0xFF) represents NaN in E8M0 and should never appear from valid weights + assert torch.all(e8m0_scale <= 254), "E8M0 scale contains NaN value (255)" + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "input_shape", + [ + (64, 64), + (128, 128), + (4, 64, 128), # 3D MoE shape + # Note: All shapes must have last dim divisible by 32 since + # get_weights_scaling_factor() requires this (unlike quantize() which pads) + ], + ) + def test_mxfp8_quantize_with_precomputed_scale(self, device, input_dtype, input_shape): + """Test MXFP8 quantize() with pre-computed weights_scaling_factor.""" + test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + # Quantize without pre-computed scale (baseline) + qtensor_auto, scale_auto = MXFP8QTensor.quantize(test_tensor) + + # Pre-compute scale and pass to quantize + precomputed_scale = MXFP8QTensor.get_weights_scaling_factor(test_tensor) + qtensor_precomputed, scale_precomputed = MXFP8QTensor.quantize( + test_tensor, weights_scaling_factor=precomputed_scale + ) + + # Verify scales match + assert torch.equal(scale_auto, scale_precomputed), ( + "Pre-computed scale should match auto-computed scale" + ) + + # Verify quantized data matches + assert torch.equal(qtensor_auto._quantized_data, qtensor_precomputed._quantized_data), ( + "Quantized data should match when using pre-computed scale" + ) + + # Verify dequantized results match + dequant_auto = qtensor_auto.dequantize(dtype=input_dtype, scale=scale_auto) + dequant_precomputed = qtensor_precomputed.dequantize( + dtype=input_dtype, scale=scale_precomputed + ) + assert torch.equal(dequant_auto, dequant_precomputed), ( + "Dequantized results should match" + ) + + @pytest.mark.parametrize( + ("amax_value", "expected_exponent"), + [ + (0.0, -127.0), # Zero amax: minimum exponent + (448.0, 0.0), # E4M3_MAX: exponent 0 + (1.0, -8.0), # log2(1/448) ~ -8.8, ceil = -8 + (1e40, 127.0), # Very large amax: clamps to max + (1e-50, -127.0), # Very small amax: clamps to min + ], + ) + def test_mxfp8_compute_e8m0_exponent_edge_cases(self, amax_value, expected_exponent): + """Test _compute_e8m0_exponent handles edge cases correctly.""" + amax = torch.tensor([amax_value], device="cuda") + exponent = MXFP8QTensor._compute_e8m0_exponent(amax) + assert exponent.item() == expected_exponent, ( + f"amax={amax_value} should give exponent {expected_exponent}, got {exponent.item()}" + ) + + def test_mxfp8_get_weights_scaling_factor_asserts_1d_weight(self): + """Test get_weights_scaling_factor raises assertion for 1D tensor.""" + weight_1d = torch.randn(64, device="cuda") + with pytest.raises(AssertionError, match="Weight must be at least 2D"): + MXFP8QTensor.get_weights_scaling_factor(weight_1d) + + def test_mxfp8_get_weights_scaling_factor_asserts_non_divisible(self): + """Test get_weights_scaling_factor raises assertion when dim not divisible by 32.""" + # 33 is not divisible by 32 + weight = torch.randn(64, 33, device="cuda") + with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): + MXFP8QTensor.get_weights_scaling_factor(weight) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_quantize_with_scale_asserts(self, device): + """Test quantize_with_scale raises assertions for invalid inputs.""" + # Test wrong scale dtype assertion + weight = torch.randn(64, 64, dtype=torch.float32, device=device) + wrong_dtype_scale = torch.randn(64, 2, dtype=torch.float32, device=device) + with pytest.raises(AssertionError, match="weights_scaling_factor must be"): + MXFP8QTensor.quantize_with_scale(weight, wrong_dtype_scale) + + # Test non-divisible dimension assertion + weight_bad_dim = torch.randn(64, 33, dtype=torch.float32, device=device) + scale = torch.randint(0, 255, (64, 1), dtype=torch.uint8, device=device) + with pytest.raises(AssertionError, match="must be divisible by MXFP8 block size"): + MXFP8QTensor.quantize_with_scale(weight_bad_dim, scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_get_weights_scaling_factor_from_quantizer_3d_moe(self, device): + """Test get_weights_scaling_factor_from_quantizer handles 3D MoE tensors.""" + input_shape = (4, 64, 128) # (num_experts, out_dim, in_dim) + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + class MockQuantizer: + block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE} + _scale = None + + quantizer = MockQuantizer() + + # Test when _scale is None (should compute from weight) + scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) + + expected_shape = ( + input_shape[0], + input_shape[1], + input_shape[2] // MXFP8QTensor.BLOCK_SIZE, + ) + assert scale.shape == expected_shape + + # Test when _scale is provided with correct 3D shape + quantizer._scale = torch.randint(0, 255, expected_shape, dtype=torch.uint8, device=device) + scale_from_quantizer = MXFP8QTensor.get_weights_scaling_factor_from_quantizer( + weight, quantizer + ) + assert torch.equal(scale_from_quantizer, quantizer._scale) + + @pytest.mark.parametrize("device", ["cuda"]) + def test_mxfp8_get_weights_scaling_factor_from_quantizer_scale_shape_mismatch(self, device): + """Test get_weights_scaling_factor_from_quantizer raises assertion on shape mismatch.""" + input_shape = (4, 64, 128) # (num_experts, out_dim, in_dim) + weight = torch.randn(input_shape, dtype=torch.float32, device=device) + + class MockQuantizer: + block_sizes = {-1: MXFP8QTensor.BLOCK_SIZE} + # Wrong shape: 2D instead of 3D (missing num_experts dimension) + _scale = torch.randint( + 0, 255, (64, 4), dtype=torch.uint8, device=device + ) + + quantizer = MockQuantizer() + + with pytest.raises(AssertionError, match="Scale shape .* does not match expected shape"): + MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, quantizer) + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_mxfp8_dequantize_default_dtype(self, device, input_dtype): + """Test dequantize uses original dtype when dtype=None.""" + input_tensor = torch.randn(64, 64, dtype=input_dtype, device=device) + qtensor, e8m0_scale = MXFP8QTensor.quantize(input_tensor) + + # Dequantize without specifying dtype + dequant = qtensor.dequantize(scale=e8m0_scale) + + assert dequant.dtype == input_dtype + + @pytest.mark.parametrize("device", ["cuda"]) + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "input_shape", + [ + (64, 64), + (128, 128), + (4, 64, 128), # 3D MoE shape + ], + ) + def test_mxfp8_fake_quant(self, device, input_dtype, input_shape): + """Test MXFP8 fake quantization via TensorQuantizer matches real quant+dequant.""" + block_sizes = {-1: 32, "type": "dynamic", "scale_bits": (8, 0)} + + # Create fake quant quantizer + fake_quant_cfg = QuantizerAttributeConfig( + num_bits=(4, 3), block_sizes=block_sizes, fake_quant=True, axis=None + ) + fake_quantizer = TensorQuantizer(fake_quant_cfg).to(device) + + # Create real quant quantizer + real_quant_cfg = QuantizerAttributeConfig( + num_bits=(4, 3), block_sizes=block_sizes, fake_quant=False, axis=None + ) + real_quantizer = TensorQuantizer(real_quant_cfg).to(device) + + # Test tensor + test_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + # Fake quant output + fake_quant_output = fake_quantizer(test_tensor) + + # Real quant + dequant + q_tensor = real_quantizer(test_tensor) + real_dequant_output = real_quantizer(q_tensor) + + # Verify fake quant matches real quant+dequant + assert fake_quant_output.shape == test_tensor.shape + assert fake_quant_output.dtype == test_tensor.dtype + assert torch.allclose(fake_quant_output, real_dequant_output, rtol=5e-2, atol=5e-2), ( + f"Fake quant differs from real quant+dequant: " + f"max diff = {(fake_quant_output - real_dequant_output).abs().max()}" + ) diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 811e0be813..3e9ff4256c 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -28,6 +28,47 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.extensions import get_cuda_ext_mx +NVFP4_WEIGHT_ACT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + }, + "algorithm": { + "method": "mse", + "step_size": 0.25, + "start_multiplier": 0.25, + "stop_multiplier": 2.0, + }, +} + +NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + }, + "algorithm": { + "method": "mse", + "fp8_scale_sweep": True, + }, +} + @pytest.mark.parametrize("model_cls", [SimpleLinear, SimpleConv, SimpleConvLinear]) @pytest.mark.parametrize( @@ -46,12 +87,15 @@ mtq.NVFP4_AWQ_LITE_CFG, mtq.NVFP4_AWQ_CLIP_CFG, mtq.NVFP4_AWQ_FULL_CFG, + mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, mtq.MXFP8_DEFAULT_CFG, mtq.MXFP6_DEFAULT_CFG, mtq.MXFP4_DEFAULT_CFG, mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + NVFP4_WEIGHT_ACT_MSE_CFG, + NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, ], ) def test_quantize(model_cls, config): @@ -68,6 +112,9 @@ def test_quantize(model_cls, config): mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + NVFP4_WEIGHT_ACT_MSE_CFG, + NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, + mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, ]: if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") @@ -88,7 +135,10 @@ def test_quantize(model_cls, config): (SimpleLinear, mtq.INT8_SMOOTHQUANT_CFG), (SimpleLinear, mtq.W4A8_AWQ_BETA_CFG), (SimpleConvLinear, mtq.INT8_DEFAULT_CFG), + (SimpleLinear, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG), + (SimpleLinear, NVFP4_WEIGHT_ACT_MSE_CFG), ], ) def test_save_restore(model_cls, quant_config): - save_restore_test(model_cls, "cuda", quant_config) + test_cpu_restore = quant_config == mtq.INT8_SMOOTHQUANT_CFG + save_restore_test(model_cls, "cuda", quant_config, test_cpu_restore=test_cpu_restore) diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index 26b1e377fa..e84b1a49ad 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -172,7 +172,7 @@ def _test_fp4_kernel(test_in, test_out, skip_triton=False): inputs.abs().amax(), ) assert torch.allclose(quantized_outputs, expected_outputs) - if triton_kernel.IS_AVAILABLE and not skip_triton: + if hasattr(triton_kernel, "fp4_fake_quant_block") and not skip_triton: quantized_outputs_triton = triton_kernel.fp4_fake_quant_block( inputs, inputs.abs().amax() ) @@ -204,3 +204,98 @@ def _test_fp4_kernel(test_in, test_out, skip_triton=False): test_in *= sign test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign _test_fp4_kernel(test_in, test_out) + + @pytest.mark.skipif(not triton_kernel.IS_AVAILABLE, reason="triton kernel is not available") + @pytest.mark.parametrize( + "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True + ) + @pytest.mark.parametrize("block_size", [8, 16, 32]) + @pytest.mark.parametrize("skip_scale_quant", [True, False]) + def test_static_blockwise_fp4(self, set_torch_dtype, block_size, skip_scale_quant): + # Test with e2m1 table values + sign = torch.randint(0, 2, (1, 8)).cuda() * 2 - 1 + + def _get_test_inputs_outputs(test_in, test_out, num_blocks=4): + return torch.concat((test_in,) * (block_size // 8), dim=-1).repeat( + num_blocks, 1 + ), torch.concat((test_out,) * (block_size // 8), dim=-1).repeat(num_blocks, 1) + + def _test_static_fp4_kernel(test_in, test_out, amax_value=6.0): + inputs, expected_outputs = _get_test_inputs_outputs(test_in, test_out) + num_blocks = inputs.shape[0] + amax = torch.full((num_blocks,), amax_value, device=inputs.device) + + quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant( + inputs, amax=amax, quantize_block_scales=not skip_scale_quant + ) + + # Only check exact values when skip_scale_quant=True + # When scale quantization is enabled, the scale changes slightly, affecting outputs + if skip_scale_quant: + assert torch.allclose(quantized_outputs_triton, expected_outputs, atol=1e-6) + else: + assert quantized_outputs_triton.shape == expected_outputs.shape + + test_in = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + test_out = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + if skip_scale_quant: + # Test slightly below the e2m1 boundary values. + # Numbers should be quantized down to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] -= 0.1 + test_in *= sign + test_out = torch.tensor([[0.0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + # Test slightly above the e2m1 boundary values. + # Numbers should be quantized up to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] += 0.1 + test_in *= sign + test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + @pytest.mark.skipif( + not hasattr(triton_kernel, "fp4_fake_quant_block"), + reason="fp4_fake_quant_block requires compute >= 8.9", + ) + @pytest.mark.parametrize( + "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True + ) + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("num_blocks", [4, 8, 16]) + def test_static_vs_dynamic_fp4_kernels(self, set_torch_dtype, block_size, num_blocks): + """Test that static kernel with computed scales matches dynamic kernel behavior. + + The dynamic kernel computes scales dynamically from block-wise max values with FP8 quantization. + This test verifies that the static kernel with pre-computed amax (matching dynamic kernel's logic) + produces the same results as the dynamic kernel. + """ + torch.manual_seed(42) + + x = torch.randn(num_blocks, block_size, dtype=torch.float32).cuda() * 10 + block_amax = x.abs().max(dim=1, keepdim=False)[0] + global_amax = block_amax.max() + output_static = triton_kernel.static_blockwise_fp4_fake_quant( + x, + amax=block_amax, + global_amax=global_amax, + quantize_block_scales=True, + ) + output_dynamic = triton_kernel.fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + assert torch.allclose(output_static, output_dynamic, rtol=1e-3, atol=1e-5), ( + f"Static and dynamic kernels produced different outputs " + f"(param=amax).\n" + f"Max abs diff: {(output_static - output_dynamic).abs().max()}\n" + f"Mean abs diff: {(output_static - output_dynamic).abs().mean()}\n" + f"Max relative diff: {((output_static - output_dynamic).abs() / (output_dynamic.abs() + 1e-8)).max()}" + ) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py new file mode 100644 index 0000000000..97296971d7 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -0,0 +1,383 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for sparse attention calibration.""" + +import pytest +import torch +from _test_utils.torch.sparsity.sparse_attention_common import SimpleTransformerEncoderLayer + +import modelopt.torch.opt as mto +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import RulerDatasetBuilder +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestRulerDatasetBuilderGPU: + """Test RULER dataset generation with real tokenizers on GPU.""" + + def test_ruler_generation_with_real_tokenizer(self): + """Test RULER generation with GPT2 tokenizer.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 samples (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 6 samples (1 per task) + assert len(dataset) == 6 + + # All samples should have valid structure + for sample in dataset: + assert "input" in sample + assert "length" in sample + assert sample["length"] > 0 + + def test_generated_length_accuracy(self): + """Test that generated token counts are accurate.""" + builder = RulerDatasetBuilder( + samples=3, + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check that lengths are within reasonable range of target + for sample in dataset: + # RULER aims for 70-90% of target for context + assert 700 < sample["length"] < 1400 + + def test_multiple_subtasks(self): + """Test generation with multiple RULER subtasks.""" + builder = RulerDatasetBuilder( + samples=12, # Need at least 6 for 1 per task, use 12 for 2 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check task distribution (should have multiple tasks from RULER_TASKS) + tasks_found = {s["task"] for s in dataset} + assert len(tasks_found) >= 2 # At least 2 different tasks + + def test_large_context_lengths(self): + """Test with larger context lengths.""" + builder = RulerDatasetBuilder( + samples=24, # 4 lengths * 6 tasks = need 24 for 1 per (length, task) + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + assert len(dataset) == 24 + + # Verify we have different lengths + lengths = [s["length"] for s in dataset] + # Should have variety of lengths across the bins + assert len(set(lengths)) > 1 # At least 2 different target lengths used + + +class TestCalibrationGPU: + """Test calibration with real models on GPU.""" + + @pytest.fixture + def simple_model(self): + """Create simple attention model for testing.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibration_simple_model(self, simple_model): + """Test calibration with simple attention model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "br": 64, + "bc": 64, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + # Simple forward loop for calibration + pass + + # Apply sparse attention with calibration + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules exist + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + # Verify calibration was applied (Exponential model params) + for module in sparse_modules: + method = module._sparse_method_instance + # Check if calibration params (a, b) are set + if hasattr(method, "calibration_params") and method.calibration_params: + for params in method.calibration_params.values(): + assert "a" in params and params["a"] > 0 + assert "b" in params and params["b"] > 0 + + def test_calibration_pytorch_backend(self, simple_model): + """Test calibration with pytorch backend.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Check backend is set correctly + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert hasattr(method, "backend") + assert method.backend == "pytorch" + + def test_simplified_calibration(self, simple_model): + """Test simplified calibration (prefill phase only).""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Should complete without errors + assert sparse_model is not None + + def test_calibration_persistence(self, simple_model): + """Test save and restore of calibrated model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Create new model and restore + model_restored = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + + restored = mto.restore_from_modelopt_state(model_restored, modelopt_state) + + # Check that sparse attention is restored + has_sparse = any(isinstance(m, SparseAttentionModule) for m in restored.modules()) + assert has_sparse + + +class TestCalibrationEndToEnd: + """Integration tests with inference.""" + + @pytest.fixture + def simple_model_setup(self): + """Setup simple model.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibrated_model_inference(self, simple_model_setup): + """Test inference with calibrated model.""" + model = simple_model_setup + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Test inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + sparse_model.eval() + with torch.no_grad(): + output = sparse_model(test_input) + + # Check output is valid + assert output is not None + assert not torch.isnan(output).any() + + def test_calibrated_vs_fixed_threshold(self, simple_model_setup): + """Compare calibrated vs fixed threshold models.""" + # Config with calibration + config_calibrated = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + # Config with fixed threshold (no calibration) + config_fixed = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "enable": True, + } + }, + } + + def forward_loop(model): + pass + + # Test both can be created + model_calibrated = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_calibrated, + forward_loop=forward_loop, + ) + + model_fixed = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_fixed, + ) + + # Both should work for inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + with torch.no_grad(): + output_calibrated = model_calibrated(test_input) + output_fixed = model_fixed(test_input) + + assert output_calibrated is not None + assert output_fixed is not None + + def test_memory_usage(self, simple_model_setup): + """Test that calibration doesn't cause memory issues.""" + model = simple_model_setup + + # Clear cache before test + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "enable": True, + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate + sparsify(model, config, forward_loop=forward_loop) + + # Check memory didn't explode + final_memory = torch.cuda.memory_allocated() + memory_increase = final_memory - initial_memory + + # Memory should be reasonable (not more than 2GB increase) + assert memory_increase < 2 * 1024**3 # 2GB diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py index c90b99bba3..df4cfaa65d 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -66,7 +66,7 @@ def test_load_and_sparsify(self, tinyllama_model): sparse_cfg={ "*attn*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -94,7 +94,7 @@ def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "backend": "pytorch", "enable": True, } @@ -124,7 +124,7 @@ def test_forward_decode(self, tinyllama_model): config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": 1e-5, # More conservative for decode + "threshold": {"prefill": 1e-3, "decode": 1e-5}, # More conservative for decode "backend": "pytorch", "enable": True, } @@ -163,7 +163,7 @@ def test_gqa_attention(self, tinyllama_model): sparse_config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "backend": "pytorch", "enable": True, } diff --git a/tests/gpu_megatron/_extensions b/tests/gpu_megatron/_extensions new file mode 120000 index 0000000000..dc4ffce338 --- /dev/null +++ b/tests/gpu_megatron/_extensions @@ -0,0 +1 @@ +../gpu/_extensions/ \ No newline at end of file diff --git a/tests/gpu_megatron/torch/conftest.py b/tests/gpu_megatron/torch/conftest.py new file mode 120000 index 0000000000..40eda16c0f --- /dev/null +++ b/tests/gpu_megatron/torch/conftest.py @@ -0,0 +1 @@ +../../gpu/torch/conftest.py \ No newline at end of file diff --git a/tests/gpu_megatron/torch/distill/plugins/test_distill_megatron.py b/tests/gpu_megatron/torch/distill/plugins/test_distill_megatron.py new file mode 100644 index 0000000000..b3b35e7927 --- /dev/null +++ b/tests/gpu_megatron/torch/distill/plugins/test_distill_megatron.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.megatron.models import get_mcore_gpt_model +from _test_utils.torch.megatron.utils import run_mcore_inference_with_dummy_input +from _test_utils.torch.misc import set_seed + +import modelopt.torch.distill as mtd +from modelopt.torch.distill.plugins.megatron import ( + DistillationConfig, + adjust_distillation_model_for_mcore, + setup_distillation_config, +) + +SEED = 1234 + + +def _test_logits_kl_loss(rank, size): + """Test basic LogitsKLLoss with simple forward/backward pass.""" + channel_divisor = 4 + + num_layers = 2 + hidden_size = channel_divisor * 2 + num_attention_heads = 4 + num_query_groups = 2 + ffn_hidden_size = channel_divisor * 2 + max_sequence_length = 8 + vocab_size = 32 + batch_size = 2 + + # Create teacher model (slightly larger) + teacher_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Create student model (same size for simplicity) + student_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=False, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Setup distillation config + distill_cfg = setup_distillation_config( + config_or_path=None, + student_cfg=student_model.config, + teacher_cfg=teacher_model.config, + ) + + # Convert to distillation model + kd_config = { + "teacher_model": teacher_model, + "criterion": distill_cfg.criterion, + "loss_balancer": distill_cfg.loss_balancer, + } + distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)]) + + # Apply Megatron-specific adjustments + adjust_distillation_model_for_mcore(distillation_model, distill_cfg) + + # Forward pass with dummy input + distillation_model.train() + run_mcore_inference_with_dummy_input(distillation_model, batch_size, hidden_size) + + # Forward and backward pass to verify gradients + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + position_ids = ( + torch.arange(max_sequence_length, dtype=torch.long) + .unsqueeze(0) + .repeat(batch_size, 1) + .cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, max_sequence_length, max_sequence_length), dtype=torch.bool) + ).cuda() + + student_loss = distillation_model(prompt_tokens, position_ids, attention_mask, labels=labels) + + # Compute distillation loss + loss = distillation_model.compute_kd_loss( + student_loss=student_loss, loss_reduction_fn=lambda x: x[0].mean() + ) + assert isinstance(loss, dict), "Loss should be a dictionary" + assert "kd_loss" in loss, "Should contain kd_loss key" + + # Backward pass + loss["kd_loss"].backward() + + +def _test_topk_logits_kl_loss(top_k, rank, size): + """Test TopKLogitsKLLoss with simple forward/backward pass.""" + channel_divisor = 4 + + num_layers = 2 + hidden_size = channel_divisor * 2 + num_attention_heads = 4 + num_query_groups = 2 + ffn_hidden_size = channel_divisor * 2 + max_sequence_length = 8 + vocab_size = 128 + batch_size = 2 + + # Create teacher model + teacher_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Create student model + student_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=False, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Setup distillation config with TopKLogitsKLLoss via logit_kl_topk argument + distill_cfg = setup_distillation_config( + config_or_path=DistillationConfig(logit_kl_topk=top_k), + student_cfg=student_model.config, + teacher_cfg=teacher_model.config, + ) + + # Convert to distillation model + kd_config = { + "teacher_model": teacher_model, + "criterion": distill_cfg.criterion, + "loss_balancer": distill_cfg.loss_balancer, + } + distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)]) + + # Apply Megatron-specific adjustments + adjust_distillation_model_for_mcore(distillation_model, distill_cfg) + + # Forward pass with dummy input + distillation_model.train() + run_mcore_inference_with_dummy_input(distillation_model, batch_size, hidden_size) + + # Forward and backward pass to verify gradients + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + position_ids = ( + torch.arange(max_sequence_length, dtype=torch.long) + .unsqueeze(0) + .repeat(batch_size, 1) + .cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, max_sequence_length, max_sequence_length), dtype=torch.bool) + ).cuda() + + student_loss = distillation_model(prompt_tokens, position_ids, attention_mask, labels=labels) + + # Compute distillation loss + loss = distillation_model.compute_kd_loss( + student_loss=student_loss, loss_reduction_fn=lambda x: x[0].mean() + ) + assert isinstance(loss, dict), "Loss should be a dictionary" + assert "kd_loss" in loss, "Should contain kd_loss key" + + # Backward pass + loss["kd_loss"].backward() + + +def test_logits_kl_loss(): + """Test LogitsKLLoss with TP parallelism.""" + set_seed(SEED) + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=_test_logits_kl_loss, + backend="nccl", + ) + + +def test_topk_logits_kl_loss(top_k: int = 5): + """Test TopKLogitsKLLoss with TP parallelism.""" + set_seed(SEED) + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_topk_logits_kl_loss, top_k), + backend="nccl", + ) diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py similarity index 59% rename from tests/gpu/torch/export/test_unified_export_megatron.py rename to tests/gpu_megatron/torch/export/test_unified_export_megatron.py index c07c2b5658..e931e6a95c 100644 --- a/tests/gpu/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -16,25 +16,62 @@ import json from copy import deepcopy from functools import partial +from pathlib import Path import pytest import torch import transformers -from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model +from _test_utils.torch.megatron.utils import get_forward from _test_utils.torch.transformers_models import create_tiny_llama_dir -skip_if_no_megatron(apex_or_te_required=True) - +import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel -def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): +def _verify_model_quant_config( + export_dir: Path, quant_config: str | None = None, kv_cache_quant_cfg: str | None = None +): + """Verify config.json and hf_quant_config.json""" + config_dict = json.load(open(export_dir / "config.json")) + hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) + # Make sure config.json and hf_quant_config.json are consistent + assert ( + config_dict["quantization_config"]["quant_algo"] + == hf_quant_config_dict["quantization"]["quant_algo"] + ) + assert ( + config_dict["quantization_config"]["ignore"] + == hf_quant_config_dict["quantization"]["exclude_modules"] + ) + + # Verify config.json + if kv_cache_quant_cfg: + assert config_dict["quantization_config"]["kv_cache_scheme"]["num_bits"] == 8 + + # Verify hf_quant_config.json + if quant_config: + quant_config_dict = hf_quant_config_dict["quantization"] + quant_type = quant_config_dict["quant_algo"] + assert ( + quant_type in quant_config + ) # quant config str is subset of quant config e.g. NVFP4 -> NVFP4_DEFAULT_CFG + assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + if quant_type == "NVFP4": + assert quant_config_dict["group_size"] == 16 + + if kv_cache_quant_cfg: + assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 + + +def _test_unified_export_megatron( + tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg, rank, size +): num_layers = 2 hidden_size = 64 num_attention_heads = 8 @@ -63,14 +100,24 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): transformer_impl="modelopt", ).cuda() - if algo == "medusa": + if quant_config: + quant_config_dict = getattr(mtq, quant_config) + if kv_cache_quant_cfg: + kv_quant_cfg = getattr(mtq, kv_cache_quant_cfg)["quant_cfg"] + quant_config_dict = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_config_dict, kv_quant_cfg + ) + forward = get_forward(model) + model = mtq.quantize(model, quant_config_dict, forward) + + if extra_module == "medusa": config = { "medusa_num_heads": 1, "medusa_num_layers": 1, } model = mtsp.convert(model, [("medusa", config)]) assert isinstance(model, _DynamicMedusaGPTModel) - elif algo == "eagle": + elif extra_module == "eagle": config = {"eagle_architecture_config": deepcopy(default_eagle_config)} model = mtsp.convert(model, [("eagle", config)]) assert isinstance(model, _DynamicEagleGPTModel) @@ -91,25 +138,36 @@ def _test_unified_export_megatron(tmp_path, model_type, arch, algo, rank, size): with open(tmp_path / "config.json", "w") as f: json.dump(pretrained_config, f) + tmp_export_dir = tmp_path / "export" export_mcore_gpt_to_hf( model, tmp_path if arch is not None else None, dtype=torch.bfloat16, + export_dir=str(tmp_export_dir), ) + if quant_config: + _verify_model_quant_config(tmp_export_dir, quant_config, kv_cache_quant_cfg) + @pytest.mark.parametrize( - ("model_type", "arch", "algo"), + ("model_type", "arch", "extra_module", "quant_config", "kv_cache_quant_cfg"), [ - ("nemotron", "NemotronForCausalLM", None), - ("nemotron", "NemotronForCausalLM", "eagle"), - ("nemotron", "NemotronForCausalLM", "medusa"), - ("llama", "LlamaForCausalLM", None), - ("llama", "LlamaForCausalLM", "eagle"), - ("llama", "LlamaForCausalLM", "medusa"), + ("nemotron", "NemotronForCausalLM", None, None, None), + ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", None), + ("nemotron", "NemotronForCausalLM", None, "NVFP4_DEFAULT_CFG", "FP8_KV_CFG"), + ("nemotron", "NemotronForCausalLM", "eagle", None, None), + ("nemotron", "NemotronForCausalLM", "medusa", None, None), + ("llama", "LlamaForCausalLM", None, None, None), + ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", None), + ("llama", "LlamaForCausalLM", None, "FP8_DEFAULT_CFG", "FP8_KV_CFG"), + ("llama", "LlamaForCausalLM", "eagle", None, None), + ("llama", "LlamaForCausalLM", "medusa", None, None), ], ) -def test_unified_export_megatron(tmp_path, model_type, arch, algo): +def test_unified_export_megatron( + tmp_path, model_type, arch, extra_module, quant_config, kv_cache_quant_cfg +): # TODO: Fix TP>1 failures spawn_multiprocess_job( size=1, # torch.cuda.device_count(), @@ -118,7 +176,9 @@ def test_unified_export_megatron(tmp_path, model_type, arch, algo): tmp_path, model_type, arch, - algo, + extra_module, + quant_config, + kv_cache_quant_cfg, ), backend="nccl", ) diff --git a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py similarity index 97% rename from tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py rename to tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py index ea351db6a9..8e4578d7bb 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py +++ b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py @@ -18,15 +18,12 @@ import pytest import torch -from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model import modelopt.torch.quantization as mtq from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq -skip_if_no_megatron(apex_or_te_required=True) - def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size): """Test megatron-core model export for vLLM with fake quantization.""" diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py similarity index 69% rename from tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py rename to tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 2679d3090d..df1c6e240e 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -17,14 +17,11 @@ import pytest import torch -from _test_utils.import_helper import skip_if_no_megatron - -skip_if_no_megatron(apex_or_te_required=True) - from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import run_mcore_inference from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.parallel_state import destroy_model_parallel from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.mlp import MLP from megatron.core.transformer.transformer_layer import TransformerLayer @@ -32,6 +29,7 @@ import modelopt.torch.nas as mtn from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.nas.plugins.megatron import ( + NumAttentionHeadsHp, _DynamicColumnParallelLinear, _DynamicEmbedding, _DynamicLanguageModelEmbedding, @@ -81,7 +79,19 @@ def _test_gpt_search_space( normalization=normalization, ).cuda() - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [ + ( + "mcore_minitron", + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + num_layers_divisor=1, + ), + ) + ], + ) assert isinstance(model, _DynamicMCoreLanguageModel) for m in model.modules(): @@ -153,6 +163,74 @@ def test_expand_head_indices(): assert expand_head_indices(heads, hidden_size_per_head).tolist() == [2, 3, 6, 7, 4, 5, 0, 1] +def test_gpt_self_attention_head_sorting(distributed_setup_size_1): + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=1, + hidden_size=16, + num_attention_heads=8, + num_query_groups=2, + ffn_hidden_size=16, + activation_func="squared_relu", + ).cuda() + + model = mtn.convert(model, "mcore_minitron") + + self_attn = model.decoder.layers[0].self_attention + assert isinstance(self_attn, _DynamicSelfAttention) + assert isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) + assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) + + hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") + assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) + + # Choices are multiples of num_query_groups (2): [2, 4, 6, 8] + assert hp_num_attention_heads.choices == [2, 4, 6, 8] + assert hp_num_attention_heads._num_query_groups == 2 + + # Set importance and slice order + # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] + # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1] + # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6] + # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6] + hp_num_attention_heads._get_importance = lambda: torch.tensor( + [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] + ) + # _estimate_head_ranking returns ranking as 1D tensor + expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) + hp_num_attention_heads.enforce_order(expected_ranking) + + assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6] + + # check if we get correct selection of sorted + pruned heads after setting active values + hp_num_attention_heads.active = 4 # top 2 heads per group (2 groups * 2 heads = 4 total) + + # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1 + expected_q_heads = [0, 3, 4, 5] + # In QKV layout (4 heads/group → 6 QKV heads/group): + # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5] + # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11] + expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11] + + assert ( + self_attn.linear_qkv._get_output_size_indices().tolist() + == expand_head_indices( + torch.LongTensor(expected_qkv_heads), model.config.kv_channels + ).tolist() + ) + assert ( + self_attn.linear_proj._get_input_size_indices().tolist() + == expand_head_indices( + torch.LongTensor(expected_q_heads), model.config.kv_channels + ).tolist() + ) + + # Clean up since this is not a spawned process + destroy_model_parallel() + + def _test_gpt_moe_search_space(rank, size): channel_divisor = 4 @@ -183,7 +261,20 @@ def _test_gpt_moe_search_space(rank, size): moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, ).cuda() - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [ + ( + "mcore_minitron", + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + num_moe_experts_divisor=1, + num_layers_divisor=1, + ), + ) + ], + ) moe = model.decoder.layers[0].mlp assert isinstance(moe, _DynamicMoELayer) diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py similarity index 92% rename from tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py rename to tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 430b5e261a..de743dc365 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -17,7 +17,7 @@ import torch from _test_utils.import_helper import skip_if_no_megatron -skip_if_no_megatron(apex_or_te_required=True, mamba_required=True) +skip_if_no_megatron(mamba_required=True) from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model @@ -51,7 +51,7 @@ def _test_mamba_search_space(rank, size): mamba_head_dim_divisor = 4 num_layers = size - hybrid_override_pattern = "M" * size + hybrid_override_pattern = "M" * size # all layers are Mamba layers hidden_size = channel_divisor * 4 mamba_state_dim = channel_divisor mamba_head_dim = mamba_head_dim_divisor * 2 @@ -75,7 +75,20 @@ def _test_mamba_search_space(rank, size): ).cuda() mamba_num_heads = model.decoder.layers[0].mixer.nheads - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [ + ( + "mcore_minitron", + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + mamba_head_dim_divisor=mamba_head_dim_divisor, + num_layers_divisor=1, + ), + ) + ], + ) assert isinstance(model, _DynamicMCoreLanguageModel) if is_pipeline_first_stage(): diff --git a/tests/gpu/torch/opt/plugins/test_megatron_chaining.py b/tests/gpu_megatron/torch/opt/plugins/test_megatron_chaining.py similarity index 100% rename from tests/gpu/torch/opt/plugins/test_megatron_chaining.py rename to tests/gpu_megatron/torch/opt/plugins/test_megatron_chaining.py diff --git a/tests/gpu/torch/peft/test_megatron_peft.py b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py similarity index 99% rename from tests/gpu/torch/peft/test_megatron_peft.py rename to tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py index 34d22d2fd6..1615321ae2 100644 --- a/tests/gpu/torch/peft/test_megatron_peft.py +++ b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py @@ -19,22 +19,17 @@ import pytest import torch import torch.nn.init as init -from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import initialize_for_megatron from megatron.core import dist_checkpointing +import modelopt.torch.peft as mtpeft +import modelopt.torch.quantization as mtq from modelopt.torch.opt.plugins.mcore_dist_checkpointing import ( restore_sharded_modelopt_state, save_sharded_modelopt_state, ) - -skip_if_no_megatron() - - -import modelopt.torch.peft as mtpeft -import modelopt.torch.quantization as mtq from modelopt.torch.peft.lora.layer import LoRAModule from modelopt.torch.utils.plugins import megatron_prefill diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py similarity index 83% rename from tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py rename to tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 46d48ea2b2..4f386e654e 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -17,10 +17,6 @@ import pytest import torch -from _test_utils.import_helper import skip_if_no_megatron - -skip_if_no_megatron(apex_or_te_required=True) - from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import ( @@ -29,20 +25,13 @@ ) from _test_utils.torch.misc import compare_outputs, set_seed from _test_utils.torch.nas_prune.minitron_common import prune_minitron -from megatron.core.parallel_state import destroy_model_parallel from megatron.core.transformer.identity_op import IdentityOp import modelopt.torch.nas as mtn from modelopt.torch.nas.conversion import export_searchspace -from modelopt.torch.nas.plugins.megatron import ( - NumAttentionHeadsHp, - _DynamicProjRowParallelLinear, - _DynamicQKVColumnParallelLinear, - _DynamicSelfAttention, - expand_head_indices, -) from modelopt.torch.prune.plugins.mcore_minitron import ( ImportanceEstimatorRegistry, + MCoreMinitronSearcher, _convert_model_to_dynamic_space, get_mcore_minitron_config, ) @@ -89,7 +78,10 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, ffn_hidden_size_divisor=channel_divisor + ), ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks @@ -128,74 +120,6 @@ def test_mcore_gpt_parameter_sorting(activation_func): ) -def test_mcore_gpt_self_attention_head_sorting(distributed_setup_size_1): - model = get_mcore_gpt_model( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - initialize_megatron=True, - num_layers=1, - hidden_size=16, - num_attention_heads=8, - num_query_groups=2, - ffn_hidden_size=16, - activation_func="squared_relu", - ).cuda() - - model = mtn.convert(model, "mcore_minitron") - - self_attn = model.decoder.layers[0].self_attention - assert isinstance(self_attn, _DynamicSelfAttention) - assert isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) - assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) - - hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") - assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) - - # Choices are multiples of num_query_groups (2): [2, 4, 6, 8] - assert hp_num_attention_heads.choices == [2, 4, 6, 8] - assert hp_num_attention_heads._num_query_groups == 2 - - # Set importance and slice order - # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] - # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1] - # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6] - # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6] - hp_num_attention_heads._get_importance = lambda: torch.tensor( - [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] - ) - # _estimate_head_ranking returns ranking as 1D tensor - expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) - hp_num_attention_heads.enforce_order(expected_ranking) - - assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6] - - # check if we get correct selection of sorted + pruned heads after setting active values - hp_num_attention_heads.active = 4 # top 2 heads per group (2 groups * 2 heads = 4 total) - - # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1 - expected_q_heads = [0, 3, 4, 5] - # In QKV layout (4 heads/group → 6 QKV heads/group): - # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5] - # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11] - expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11] - - assert ( - self_attn.linear_qkv._get_output_size_indices().tolist() - == expand_head_indices( - torch.LongTensor(expected_qkv_heads), model.config.kv_channels - ).tolist() - ) - assert ( - self_attn.linear_proj._get_input_size_indices().tolist() - == expand_head_indices( - torch.LongTensor(expected_q_heads), model.config.kv_channels - ).tolist() - ) - - # Clean up since this is not a spawned process - destroy_model_parallel() - - def _test_mcore_gpt_pruning( num_attention_heads, num_query_groups, @@ -280,16 +204,17 @@ def forward_loop(m): export_config["hidden_size"] = pruned_hidden_size if pruned_num_layers_div != 1: export_config["num_layers"] = pruned_num_layers + constraints = {"export_config": export_config} config = { - "scores_path": ckpt_path, + "checkpoint": ckpt_path, "skip_sorting": skip_sorting, } if skip_sorting: assert ckpt_path is None else: config["forward_loop"] = forward_loop - model, pruning_scores = prune_minitron(model, export_config, config, channel_divisor) + model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor) if not skip_sorting: assert pruning_scores["layer_scores"] assert pruning_scores["activations_per_rank"] @@ -363,12 +288,12 @@ def forward_loop(m): prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size) - # Assert re-pruning from scores_path works without running the forward loop again + # Assert re-pruning from checkpoint works without running the forward loop again if ckpt_path: model_rerun = _get_model(initialize_megatron=False) model_rerun.load_state_dict(sd) model_rerun, pruning_scores = prune_minitron( - model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor + model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor ) output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size) @@ -478,7 +403,12 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + num_moe_experts_divisor=1, + ), ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks @@ -565,11 +495,12 @@ def forward_loop(m): "moe_shared_expert_intermediate_size": pruned_moe_shared_ffn, "num_moe_experts": pruned_num_moe_experts, } + constraints = {"export_config": export_config} prune_minitron( model, - export_config, - {"scores_path": ckpt_path, "forward_loop": forward_loop}, + constraints, + {"checkpoint": ckpt_path, "forward_loop": forward_loop}, channel_divisor, ) @@ -603,10 +534,10 @@ def forward_loop(m): prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() output = run_mcore_inference(model, prompt_tokens, pruned_hidden_size) - # Assert re-pruning from scores_path works without running the forward loop again + # Assert re-pruning from checkpoint works without running the forward loop again model_rerun = _get_model(initialize_megatron=False) model_rerun.load_state_dict(sd) - prune_minitron(model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor) + prune_minitron(model_rerun, constraints, {"checkpoint": ckpt_path}, channel_divisor) output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size) assert torch.allclose(output, output_rerun, atol=1e-5) @@ -618,3 +549,30 @@ def test_mcore_gpt_pruning_moe(tmp_path): job=partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"), backend="nccl", ) + + +def test_generate_search_space_combos(): + ss = { + "hidden_size": [32, 64, 96, 128, 160], + "ffn_hidden_size": [128, 256, 384, 512, 640], + "num_attention_heads": [8, 16, 24, 32], + "num_layers": [1, 2, 3, 4, 5, 6, 7, 8], + } + ss_combos = MCoreMinitronSearcher._generate_search_space_combos( + ss, max_width_pruning=0.5, max_depth_pruning=0.25, hparams_to_skip=["ffn_hidden_size"] + ) + assert len(ss_combos) == 3 * 2 * 2 + assert ss_combos == [ + {"hidden_size": 96, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 96, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 96, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 96, "num_attention_heads": 32, "num_layers": 8}, + {"hidden_size": 128, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 128, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 128, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 128, "num_attention_heads": 32, "num_layers": 8}, + {"hidden_size": 160, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 160, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 160, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 160, "num_attention_heads": 32, "num_layers": 8}, + ] diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py new file mode 100644 index 0000000000..69b286c6b0 --- /dev/null +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import contextlib +import io +from functools import partial + +import pytest +import torch +from _test_utils.import_helper import skip_if_no_megatron + +skip_if_no_megatron(mamba_required=True) + +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model +from _test_utils.torch.megatron.utils import ( + run_mcore_inference, + run_mcore_inference_with_dummy_input, +) +from _test_utils.torch.misc import compare_outputs, set_seed +from _test_utils.torch.nas_prune.minitron_common import prune_minitron +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols +from megatron.core.ssm.mamba_layer import MambaLayer +from megatron.core.transformer.identity_op import IdentityOp + +import modelopt.torch.nas as mtn +from modelopt.torch.prune.plugins.mcore_minitron import ( + ImportanceEstimatorRegistry, + _convert_model_to_dynamic_space, + get_mcore_minitron_config, + get_mcore_param_count, +) + +SEED = 1234 + + +def _test_mcore_mamba_parameter_sorting(rank, size): + # Use relatively bigger model here for more accurate test for sorting + channel_divisor = 64 + + num_layers = size + hybrid_override_pattern = "M" * size + hidden_size = channel_divisor * 4 + mamba_state_dim = channel_divisor + mamba_head_dim = 16 + mamba_num_groups = 2 + max_sequence_length = 32 + vocab_size = 64 + batch_size = 2 + + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=num_layers, + hybrid_override_pattern=hybrid_override_pattern, + hidden_size=hidden_size, + mamba_state_dim=mamba_state_dim, + mamba_head_dim=mamba_head_dim, + mamba_num_groups=mamba_num_groups, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + bf16=False, + ).cuda() + + # Randomize norm weights instead of all zeros or ones + for n, m in model.named_modules(): + if "norm" in n and not isinstance(m, IdentityOp): + m.weight.data = torch.randn_like(m.weight) + + model.eval() + dynamic_space = _convert_model_to_dynamic_space( + model, + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + mamba_head_dim_divisor=4, + ), + ) + registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks + + # Compute activations for sorting + for _ in range(5): + run_mcore_inference_with_dummy_input(model, batch_size) + + # Get the output of the original model + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + y1 = run_mcore_inference(model, prompt_tokens) + + mtn.utils.sort_parameters(model) + registry.cleanup() + + # check if all mamba_num_heads, mamba_head_dim, hidden_size have been sorted + sortable_per_pp = [ + n for n, hp in dynamic_space.named_hparams(configurable=True) if hp.importance is not None + ] + # 2 mamba hps per layer + 1 for hidden_size (num_layers is not sorted!) + assert len(sortable_per_pp) == 2 * num_layers // size + 1 + + # sanity check if the model functionality is preserved after sorting + y2 = run_mcore_inference(model, prompt_tokens) + + # check if the inference results after sorting is the same + compare_outputs(y1, y2, rtol=1e-5, atol=1e-3) + + +def test_mcore_mamba_parameter_sorting(): + set_seed(SEED) + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=_test_mcore_mamba_parameter_sorting, + backend="nccl", + ) + + +def _test_mcore_mamba_hybrid_pruning(ckpt_path, rank, size): + channel_divisor = 4 + + num_layers = min(size * 2, 8) + hidden_size = channel_divisor * 8 + ffn_hidden_size = channel_divisor * 2 + num_attention_heads = 8 + num_query_groups = 4 + mamba_state_dim = channel_divisor * 2 + mamba_head_dim = channel_divisor * 2 + mamba_num_groups = 2 + num_moe_experts = 8 + vocab_size = 32 + batch_size = 2 + + def _get_model(initialize_megatron=True): + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=initialize_megatron, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + mamba_state_dim=mamba_state_dim, + mamba_head_dim=mamba_head_dim, + mamba_num_groups=mamba_num_groups, + moe_ffn_hidden_size=ffn_hidden_size, + moe_shared_expert_intermediate_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, + vocab_size=vocab_size, + ).cuda() + return model + + model = _get_model() + + mamba_layer = None + for layer in model.decoder.layers: + if isinstance(layer, MambaLayer): + mamba_layer = layer + break + assert mamba_layer is not None, f"No MambaLayer found in the model PP rank {rank}!" + mamba_num_heads = mamba_layer.mixer.nheads + + def forward_loop(m): + for _ in range(2): + run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) + + # Traditional GPT pruning parameters + pruned_ffn_hidden_size = ffn_hidden_size // 2 + pruned_num_attention_heads = num_attention_heads // 2 + pruned_hidden_size = hidden_size // 2 + pruned_num_moe_experts = num_moe_experts // 2 + + # Mamba-specific pruning parameters + pruned_mamba_num_heads = mamba_num_heads // 2 + pruned_mamba_head_dim = mamba_head_dim // 2 + + # Base export config with GPT/Attention parameters + export_config = { + "ffn_hidden_size": pruned_ffn_hidden_size, + "num_attention_heads": pruned_num_attention_heads, + "hidden_size": pruned_hidden_size, + "mamba_num_heads": pruned_mamba_num_heads, + "mamba_head_dim": pruned_mamba_head_dim, + "moe_ffn_hidden_size": pruned_ffn_hidden_size, + "moe_shared_expert_intermediate_size": pruned_ffn_hidden_size, + "num_moe_experts": pruned_num_moe_experts, + } + constraints = {"export_config": export_config} + prune_minitron( + model, + constraints, + {"forward_loop": forward_loop, "checkpoint": ckpt_path}, + channel_divisor, + ) + + # Assert weights are pruned correctly + mixer = mamba_layer.mixer + bc = 2 * mixer.ngroups * mixer.d_state + assert mixer.nheads == pruned_mamba_num_heads + assert mixer.headdim == pruned_mamba_head_dim + assert mixer.in_proj.input_size == pruned_hidden_size + assert mixer.d_inner == pruned_mamba_num_heads * pruned_mamba_head_dim + assert mixer.in_proj.output_size == 2 * mixer.d_inner + bc + pruned_mamba_num_heads + assert mixer.out_proj.input_size == mixer.d_inner + assert mixer.out_proj.output_size == pruned_hidden_size + assert mixer.conv1d.in_channels == mixer.conv1d.out_channels == mixer.d_inner + bc + + # Assert model.config is updated for correct save/restoring + assert model.config.ffn_hidden_size == pruned_ffn_hidden_size + assert model.config.num_attention_heads == pruned_num_attention_heads + assert model.config.hidden_size == pruned_hidden_size + assert model.config.mamba_num_heads == pruned_mamba_num_heads + assert model.config.mamba_head_dim == pruned_mamba_head_dim + assert model.config.moe_ffn_hidden_size == pruned_ffn_hidden_size + assert model.config.moe_shared_expert_intermediate_size == pruned_ffn_hidden_size + assert model.config.num_moe_experts == pruned_num_moe_experts + + # Assert forward pass works on the pruned model + run_mcore_inference_with_dummy_input(model, batch_size, pruned_hidden_size) + + # Assert re-pruning from checkpoint works without running the forward loop again + model = _get_model(initialize_megatron=False) + prune_minitron(model, constraints, {"checkpoint": ckpt_path}, channel_divisor) + + +def test_mcore_mamba_hybrid_pruning(tmp_path): + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth"), + backend="nccl", + ) + + +def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size): + channel_divisor = 4 + + # TODO: MoE in MambaModel requires Mcore 0.16+ + num_layers = 4 # Atleast one of "M, *, -, E" blocks + hybrid_pattern = "M*-M" # "ME*-" + hidden_size = 16 + ffn_hidden_size = 32 + num_attention_heads = 16 + num_query_groups = 4 + mamba_state_dim = 4 + mamba_num_heads = 16 + mamba_head_dim = 16 + mamba_num_groups = 2 + num_moe_experts = None + moe_ffn_hidden_size = None + moe_shared_expert_intermediate_size = None + # num_moe_experts = 8 + # moe_ffn_hidden_size = 16 + # moe_shared_expert_intermediate_size = 16 + vocab_size = 32 + batch_size = 2 + + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=num_layers, + hybrid_override_pattern=hybrid_pattern, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + mamba_state_dim=mamba_state_dim, + mamba_num_heads=mamba_num_heads, + mamba_head_dim=mamba_head_dim, + mamba_num_groups=mamba_num_groups, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + num_moe_experts=num_moe_experts, + vocab_size=vocab_size, + ).cuda() + + param_count = get_mcore_param_count(model) + assert param_count == 31776.0, param_count + + def forward_loop(m): + for _ in range(2): + run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) + + def score_func(m): + c = m.config + return ( + c.num_layers + + c.hidden_size + + c.ffn_hidden_size + + c.mamba_num_heads + + c.mamba_head_dim + + c.num_attention_heads + # + c.num_moe_experts + # + c.moe_ffn_hidden_size + # + c.moe_shared_expert_intermediate_size + ) + + constraints = {"params": int(param_count * 0.7)} + config = { + "forward_loop": forward_loop, + "checkpoint": ckpt_path, + "score_func": score_func, + "max_width_pruning": 0.5, + "max_depth_pruning": 0.5, + "hparams_to_skip": ["num_attention_heads"], + "top_k": 10, + } + + # Capture stdout to assert search space output + stdout_capture = io.StringIO() + with contextlib.redirect_stdout(stdout_capture): + model, searcher_state = prune_minitron(model, constraints, config, channel_divisor) + + # Assert expected search space output is present + captured_output = stdout_capture.getvalue() + print(captured_output) + if rank == 0: + assert "Search space for num_layers: [3, 4]" in captured_output + assert "Search space for hidden_size: [12, 16]" in captured_output + assert "Search space for mamba_num_heads: [10, 12, 14, 16]" in captured_output + assert "Search space for mamba_head_dim: [12, 16]" in captured_output + assert "Search space for ffn_hidden_size: [20, 24, 28, 32]" in captured_output + assert "Total search space in consideration: 128" in captured_output + + # NOTE: Slight variation in layer ordering for Attention and MLP depending on PP configuration + # This affects param counts when num_layers is pruned + sorted_layers = [ + layer + for layer, _ in sorted( + searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True + ) + ] + # fmt: off + if sorted_layers == [1, 4, 2, 3]: + expected_top_k = [ + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 10, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21180.0, 94.0], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21140.0, 81.0], # noqa: E501 + ] + elif sorted_layers == [1, 4, 3, 2]: + expected_top_k = [ + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 21524.0, 93.0], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21412.0, 93.0], # noqa: E501 + ] + else: + raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers}") + # fmt: on + + assert get_mcore_param_count(model) == 22196.0 + + top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]] + assert len(top_k) == 10 + for actual, (ss_config, params, score) in zip(top_k, expected_top_k): + assert actual.ss_config == ss_config, (actual.ss_config, ss_config) + assert actual.params == params, (actual.params, params) + assert actual.score == score, (actual.score, score) + + +def test_mcore_mamba_hybrid_pruning_nas(tmp_path): + set_seed(SEED) + if torch.cuda.device_count() > 4: + pytest.skip("Skipping test for more than 4 GPUs") + if "E" in Symbols.VALID: + pytest.skip("TODO: Update test for MoE in Mamba (Mcore 0.16+)") + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_mcore_mamba_hybrid_pruning_nas, tmp_path / "modelopt_minitron_scores.pth" + ), + backend="nccl", + ) diff --git a/tests/gpu/torch/quantization/plugins/test_apex.py b/tests/gpu_megatron/torch/quantization/plugins/test_apex.py similarity index 100% rename from tests/gpu/torch/quantization/plugins/test_apex.py rename to tests/gpu_megatron/torch/quantization/plugins/test_apex.py diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py similarity index 91% rename from tests/gpu/torch/quantization/plugins/test_megatron.py rename to tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index d02b02c18a..5704e369e5 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -18,10 +18,8 @@ import pytest import torch -from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import ( - MambaModel, MegatronModel, get_mcore_gpt_model, get_mcore_mamba_hybrid_model, @@ -29,6 +27,7 @@ from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, + get_forward, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, @@ -41,9 +40,6 @@ data_tensor_context_parallel_test_helper, verify_kv_cache_amax_sync, ) - -skip_if_no_megatron() - from megatron.core.parallel_state import ( destroy_model_parallel, get_data_parallel_group, @@ -69,47 +65,6 @@ SEED = 1234 -def get_batch(model, batch_size=2): - seq_length = model.max_sequence_length - vocab_size = model.vocab_size - - input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() - labels = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda() - position_ids = ( - torch.arange(seq_length, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1).cuda() - ) - attention_mask = torch.tril( - torch.ones((batch_size, 1, seq_length, seq_length), dtype=torch.bool) - ).cuda() - loss_mask = torch.ones((batch_size, seq_length), dtype=torch.float32).cuda() - - return input_ids, labels, position_ids, attention_mask, loss_mask - - -def get_forward(model, batch_size=2): - """Return a forward function with cached batch inputs.""" - input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) - - def forward(model): - # MambaModel doesn't accept loss_mask argument - if isinstance(model, MambaModel): - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - ) - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) - - return forward - - def test_convert_megatron_parallel_linear(distributed_setup_size_1): initialize_for_megatron(seed=SEED) set_seed(SEED) @@ -514,10 +469,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device, @pytest.mark.parametrize( "config", - [ - NVFP4_GEMM_KV_CFG, - FP8_GEMM_KV_CFG, - ], + [NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG], ) def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config): """Test sharded state dict for hybrid Mamba MOE models.""" @@ -776,6 +728,81 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus): ) +@pytest.mark.parametrize("ep_size", [1, 2]) +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm): + """Test expert model parallel synchronization.""" + size = torch.cuda.device_count() + if size < ep_size: + pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test") + + spawn_multiprocess_job( + size=size, + job=partial( + _test_layer_sync_moe_local_experts_amax, + ep_size, + moe_grouped_gemm, + ), + backend="nccl", + ) + + +def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size): + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=1, + seed=SEED, + ) + model = _gpt_model_provider( + tp_size=1, + ep_size=ep_size, + etp_size=1, + hidden_size=256, + moe_grouped_gemm=moe_grouped_gemm, + use_te=moe_grouped_gemm, + num_moe_experts=8, + transformer_impl="modelopt", + ) + quant_cfg = mtq.FP8_DEFAULT_CFG + model = mtq.quantize(model, quant_cfg, get_forward(model)) + + for layer in model.decoder.layers: + layer.mlp.experts.layer_sync_moe_local_experts_amax() + + for layer in model.decoder.layers: + # Check input quantizer amax is synced across local experts + fc1_amax = None + fc2_amax = None + for expert in layer.mlp.experts.local_experts: + assert expert.linear_fc1.input_quantizer.amax is not None + assert expert.linear_fc2.input_quantizer.amax is not None + if fc1_amax is None: + fc1_amax = expert.linear_fc1.input_quantizer.amax + else: + assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax) + if fc2_amax is None: + fc2_amax = expert.linear_fc2.input_quantizer.amax + else: + assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax) + + # Check weight quantizer amax is different across local experts + fc1_amax = None + fc2_amax = None + for expert in layer.mlp.experts.local_experts: + assert expert.linear_fc1.weight_quantizer.amax is not None + assert expert.linear_fc2.weight_quantizer.amax is not None + if fc1_amax is None: + fc1_amax = expert.linear_fc1.weight_quantizer.amax + else: + assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax) + if fc2_amax is None: + fc2_amax = expert.linear_fc2.weight_quantizer.amax + else: + assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax) + + def _test_expert_model_parallel_amax_sync( tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size ): @@ -841,9 +868,6 @@ def _test_expert_model_parallel_amax_sync( ) # calibrate the model with distributed sync and test synchronization mtq.model_calib.max_calibrate(model, forward, distributed_sync=True) - for module in model.modules(): - if hasattr(module, "sync_moe_local_experts_amax"): - module.sync_moe_local_experts_amax() final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" @@ -859,9 +883,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm): if size < ep_size * etp_size: pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test") - if moe_grouped_gemm: - pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") - spawn_multiprocess_job( size=size, job=partial( @@ -997,7 +1018,7 @@ def test_convert_mcore_te_gpt_model(distributed_setup_size_1): for n, m in model.named_modules(): if isinstance(m, TERowParallelLinear): - assert isinstance(m, _QuantTEMCoreRowParallelLinear) + assert isinstance(m, _QuantTEMCoreRowParallelLinear), f"{m=}, {type(m)}" assert m.input_quantizer.amax is not None assert m.weight_quantizer.amax is not None diff --git a/tests/gpu/torch/quantization/plugins/test_transformer_engine.py b/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py similarity index 100% rename from tests/gpu/torch/quantization/plugins/test_transformer_engine.py rename to tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py diff --git a/tests/gpu/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py b/tests/gpu_megatron/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py similarity index 100% rename from tests/gpu/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py rename to tests/gpu_megatron/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py diff --git a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py similarity index 98% rename from tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py rename to tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py index 5a149b77fc..0bb9658ff8 100644 --- a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py +++ b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py @@ -16,10 +16,6 @@ import pytest import torch -from _test_utils.import_helper import skip_if_no_megatron - -skip_if_no_megatron(apex_or_te_required=True) - from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.megatron.models import get_mcore_gpt_model diff --git a/tests/unit/torch/utils/test_megatron_preprocess_data.py b/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py similarity index 52% rename from tests/unit/torch/utils/test_megatron_preprocess_data.py rename to tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py index dbdd8e3088..de6bc71818 100644 --- a/tests/unit/torch/utils/test_megatron_preprocess_data.py +++ b/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py @@ -15,18 +15,9 @@ import json import os -import platform from pathlib import Path -import pytest -from _test_utils.import_helper import skip_if_no_megatron - -if platform.system() == "Windows": - pytest.skip("Skipping on Windows", allow_module_level=True) - -skip_if_no_megatron() -datasets = pytest.importorskip("datasets") -_ = pytest.importorskip("transformers") +from datasets import load_dataset from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data @@ -40,49 +31,46 @@ def download_and_prepare_minipile_dataset(output_dir: Path) -> Path: Returns: Path to the created JSONL file """ - # Download the dataset - dataset = datasets.load_dataset("nanotron/minipile_100_samples", split="train") + dataset = load_dataset("nanotron/minipile_100_samples", split="train") - # Convert to JSONL format jsonl_file = output_dir / "minipile_100_samples.jsonl" with open(jsonl_file, "w", encoding="utf-8") as f: for item in dataset: - # Extract the text field and write as JSONL json_obj = {"text": item["text"]} f.write(json.dumps(json_obj) + "\n") return jsonl_file -def test_megatron_preprocess_data_with_minipile_dataset(tmp_path): - """Test megatron_preprocess_data function with nanotron/minipile_100_samples dataset. +def test_megatron_preprocess_data_with_minipile_jsonl(tmp_path): + """Test megatron_preprocess_data with nanotron/minipile_100_samples dataset. This test: 1. Downloads the HuggingFace dataset "nanotron/minipile_100_samples" 2. Converts it to JSONL format - 3. Passes it to megatron_preprocess_data + 3. Calls megatron_preprocess_data with jsonl_paths 4. Verifies that output files are created """ - # Download and prepare the dataset input_jsonl = download_and_prepare_minipile_dataset(tmp_path) - # Verify the input file was created and has content assert input_jsonl.exists(), "Input JSONL file should exist" assert input_jsonl.stat().st_size > 0, "Input JSONL file should not be empty" - # Test the megatron_preprocess_data function + with open(input_jsonl, encoding="utf-8") as f: + first_line = f.readline().strip() + first_item = json.loads(first_line) + assert "text" in first_item, "Each JSONL item should have a 'text' field" + assert isinstance(first_item["text"], str), "Text field should be a string" + megatron_preprocess_data( - input_path=input_jsonl, + jsonl_paths=input_jsonl, output_dir=tmp_path, - tokenizer_name_or_path="gpt2", # Use a small, common tokenizer + tokenizer_name_or_path="gpt2", json_keys=["text"], - append_eod=False, workers=1, - log_interval=10, ) - # Verify that output files were created output_prefix = tmp_path / "minipile_100_samples" expected_bin_file = f"{output_prefix}_text_document.bin" expected_idx_file = f"{output_prefix}_text_document.idx" @@ -94,55 +82,31 @@ def test_megatron_preprocess_data_with_minipile_dataset(tmp_path): f"Expected index file {expected_idx_file} should exist" ) - # Verify the files have content (non-zero size) assert os.path.getsize(expected_bin_file) > 0, "Binary file should not be empty" assert os.path.getsize(expected_idx_file) > 0, "Index file should not be empty" - # Optional: Verify the input JSONL file structure - with open(input_jsonl, encoding="utf-8") as f: - first_line = f.readline().strip() - first_item = json.loads(first_line) - assert "text" in first_item, "Each JSONL item should have a 'text' field" - assert isinstance(first_item["text"], str), "Text field should be a string" - -def test_megatron_preprocess_data_with_custom_parameters(tmp_path): - """Test megatron_preprocess_data with different parameters.""" - # Create a minimal test dataset - input_jsonl = tmp_path / "test_data.jsonl" +def test_megatron_preprocess_data_with_hf_dataset(tmp_path): + """Test megatron_preprocess_data with dataset download, --append_eod and --max_sequence_length. - # Create some test data - test_data = [ - {"text": "This is a test sentence for preprocessing."}, - {"text": "Another test sentence with different content."}, - {"text": "A third sentence to make sure the function works correctly."}, - ] - - with open(input_jsonl, "w", encoding="utf-8") as f: - f.writelines(json.dumps(item) + "\n" for item in test_data) - - # Test with different parameters + Downloads nanotron/minipile_100_samples train split from Hugging Face and tokenizes it. + """ megatron_preprocess_data( - input_path=input_jsonl, + hf_dataset="nanotron/minipile_100_samples", + hf_split="train", output_dir=tmp_path, tokenizer_name_or_path="gpt2", json_keys=["text"], - append_eod=True, # Test with end-of-document token - max_sequence_length=5, # Test with sequence length limit - workers=1, - log_interval=1, + append_eod=True, + max_sequence_length=512, + workers=4, ) - # Verify output files exist - output_prefix = tmp_path / "test_data" - expected_bin_file = f"{output_prefix}_text_document.bin" - expected_idx_file = f"{output_prefix}_text_document.idx" + bin_files = sorted(tmp_path.glob("*.bin")) + idx_files = sorted(tmp_path.glob("*.idx")) - assert os.path.exists(expected_bin_file), ( - f"Expected binary file {expected_bin_file} should exist" - ) - assert os.path.exists(expected_idx_file), ( - f"Expected index file {expected_idx_file} should exist" - ) - assert os.path.getsize(expected_bin_file) > 0, "Binary file should not be empty" - assert os.path.getsize(expected_idx_file) > 0, "Index file should not be empty" + assert len(bin_files) > 0, f"Expected .bin files in {tmp_path}, found none" + assert len(idx_files) > 0, f"Expected .idx files in {tmp_path}, found none" + + for f in bin_files + idx_files: + assert f.stat().st_size > 0, f"{f.name} should not be empty" diff --git a/tests/gpu/torch/utils/plugins/test_utils_megatron.py b/tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py similarity index 100% rename from tests/gpu/torch/utils/plugins/test_utils_megatron.py rename to tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py diff --git a/tests/unit/onnx/autocast/test_autocast.py b/tests/unit/onnx/autocast/test_autocast.py index 3f987eaeae..f761e1e9e3 100644 --- a/tests/unit/onnx/autocast/test_autocast.py +++ b/tests/unit/onnx/autocast/test_autocast.py @@ -20,7 +20,7 @@ import onnx import onnx_graphsurgeon as gs import pytest -from _test_utils.onnx.lib_test_models import build_conv_isinf_model +from _test_utils.onnx.lib_test_models import build_conv_isinf_model, build_conv_resize_model import modelopt.onnx.autocast.utils as utils import modelopt.onnx.utils as onnx_utils @@ -174,7 +174,7 @@ def test_conv_isinf_conversion(tmp_path, opset_version): output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx") onnx.save(converted_model, output_onnx_path) - # Load the output model and check QDQ node placements + # Load the output model graph = gs.import_onnx(converted_model) # Check that Conv is converted @@ -190,6 +190,30 @@ def test_conv_isinf_conversion(tmp_path, opset_version): assert assert_input_precision(isinf_nodes, dtype=supported_dtype) +def test_conv_resize_conversion(tmp_path): + onnx_model = build_conv_resize_model() + onnx_path = os.path.join(tmp_path, "conv_resize_model.onnx") + onnx.save(onnx_model, onnx_path) + + # Convert the model + converted_model = convert_to_mixed_precision(onnx_path=onnx_path) + + # Output model should be produced in the same tmp_path + output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx") + onnx.save(converted_model, output_onnx_path) + + # Load the output model + graph = gs.import_onnx(converted_model) + + # Check that Resize is correctly converted: + # - Data and ROI inputs (indices 0 and 1) should be FP16 + # - The remaining inputs (scales/sizes) should be kept in their original precisions + resize_node = next(n for n in graph.nodes if n.op == "Resize") + assert all(inp.dtype == np.float16 for inp in resize_node.inputs[0:2]), ( + "Resize data and ROI inputs should be FP16" + ) + + @pytest.mark.parametrize("target_opset", [13, 17, 19, 21]) def test_opset_parameter(temp_model_path, target_opset): """Test that the opset parameter correctly sets the output model's opset version.""" diff --git a/tests/unit/onnx/autocast/test_nodeclassifier.py b/tests/unit/onnx/autocast/test_nodeclassifier.py index 50e38499e8..e604fb7c1c 100644 --- a/tests/unit/onnx/autocast/test_nodeclassifier.py +++ b/tests/unit/onnx/autocast/test_nodeclassifier.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from collections import OrderedDict + import numpy as np import pytest from onnx import TensorProto, helper, numpy_helper @@ -27,7 +30,7 @@ IORangeRule, NodeClassifier, ) -from modelopt.onnx.autocast.referencerunner import ReferenceRunner +from modelopt.onnx.autocast.referencerunner import ReferenceRunner, TensorStats configure_logging("DEBUG") @@ -382,3 +385,178 @@ def test_node_classifier_force_include(test_model): assert "mul_node" in fp16_nodes assert "add_node" in fp32_nodes assert not set(fp16_nodes).intersection(set(fp32_nodes)) + + +def test_io_range_rule_with_tensor_stats(): + """Test IORangeRule with TensorStats objects (multi-batch aggregated data).""" + # Create TensorStats objects with aggregated statistics + reference_outputs = { + "out1": TensorStats(absmax=9000.0, min_val=-5000.0, max_val=9000.0, shape=(1,)), + "out2": TensorStats(absmax=11000.0, min_val=-11000.0, max_val=11000.0, shape=(1,)), + "out3": TensorStats(absmax=9000.0, min_val=-9000.0, max_val=5000.0, shape=(1,)), + "out4": TensorStats(absmax=11000.0, min_val=-8000.0, max_val=11000.0, shape=(1,)), + } + + rule = IORangeRule(10000, reference_outputs, node_to_init_map={}) + + node1 = helper.make_node("TestOp", [], ["out1"], name="node1") + node2 = helper.make_node("TestOp", [], ["out2"], name="node2") + node3 = helper.make_node("TestOp", [], ["out3"], name="node3") + node4 = helper.make_node("TestOp", [], ["out4"], name="node4") + node5 = helper.make_node("TestOp", ["out1"], [], name="node5") + node6 = helper.make_node("TestOp", ["out2"], [], name="node6") + node7 = helper.make_node("TestOp", ["out3"], [], name="node7") + node8 = helper.make_node("TestOp", ["out4"], [], name="node8") + + # out1: absmax=9000 < 10000, should not be blocked + assert rule.check(node1) is False + # out2: absmax=11000 > 10000, should be blocked + assert rule.check(node2) is True + # out3: absmax=9000 < 10000, should not be blocked + assert rule.check(node3) is False + # out4: absmax=11000 > 10000, should be blocked + assert rule.check(node4) is True + # Input out1: absmax=9000 < 10000, should not be blocked + assert rule.check(node5) is False + # Input out2: absmax=11000 > 10000, should be blocked + assert rule.check(node6) is True + # Input out3: absmax=9000 < 10000, should not be blocked + assert rule.check(node7) is False + # Input out4: absmax=11000 > 10000, should be blocked + assert rule.check(node8) is True + + +def test_io_range_rule_mixed_numpy_and_tensor_stats(): + """Test IORangeRule with mixed numpy arrays and TensorStats objects.""" + reference_outputs = { + "out1": np.array([9000], dtype=np.float32), # Raw numpy array + "out2": TensorStats( + absmax=11000.0, min_val=-11000.0, max_val=11000.0, shape=(1,) + ), # TensorStats + } + + rule = IORangeRule(10000, reference_outputs, node_to_init_map={}) + + node1 = helper.make_node("TestOp", [], ["out1"], name="node1") + node2 = helper.make_node("TestOp", [], ["out2"], name="node2") + + # out1: numpy array with absmax=9000 < 10000, should not be blocked + assert rule.check(node1) is False + # out2: TensorStats with absmax=11000 > 10000, should be blocked + assert rule.check(node2) is True + + +def test_depth_of_reduction_rule_with_tensor_stats(): + """Test DepthOfReductionRule with TensorStats objects.""" + # Create TensorStats objects for reference data + reference_data = { + "matmul_output": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(10, 30)), + "small_matmul_output": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(5, 8)), + "conv_output": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(1, 64, 62, 62)), + "small_conv_output": TensorStats( + absmax=1.0, min_val=0.0, max_val=1.0, shape=(1, 16, 15, 15) + ), + "matmul_input_a": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(10, 50)), + "matmul_input_b": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(50, 30)), + "small_matmul_a": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(5, 10)), + "small_matmul_b": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(10, 8)), + "conv_input": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(1, 32, 64, 64)), + "conv_weight": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(64, 32, 3, 3)), + "small_conv_input": TensorStats(absmax=1.0, min_val=0.0, max_val=1.0, shape=(1, 8, 16, 16)), + } + + node_to_init_map = { + "matmul_node": [], + "small_matmul_node": [], + "conv_node": [], + "small_conv_node": [], + } + initializer_map = {} + + rule = DepthOfReductionRule( + max_depth_of_reduction=40, + reference_data=reference_data, + node_to_init_map=node_to_init_map, + initializer_map=initializer_map, + ) + + # MatMul nodes + matmul_node = helper.make_node( + "MatMul", ["matmul_input_a", "matmul_input_b"], ["matmul_output"], name="matmul_node" + ) + small_matmul_node = helper.make_node( + "MatMul", + ["small_matmul_a", "small_matmul_b"], + ["small_matmul_output"], + name="small_matmul_node", + ) + + # Conv nodes + conv_node = helper.make_node( + "Conv", ["conv_input", "conv_weight"], ["conv_output"], name="conv_node" + ) + small_conv_node = helper.make_node( + "Conv", + ["small_conv_input", "small_conv_weight"], + ["small_conv_output"], + name="small_conv_node", + ) + + # Test MatMul: reduction depth 50 > 40, should be blocked + assert rule.check(matmul_node) is True + + # Test small MatMul: reduction depth 10 < 40, should not be blocked + assert rule.check(small_matmul_node) is False + + # Test Conv: reduction depth 288 > 40, should be blocked + assert rule.check(conv_node) is True + + # Test small Conv: reduction depth 32 < 40, should not be blocked + assert rule.check(small_conv_node) is False + + +@pytest.mark.skipif( + Version(ort_version) < Version("1.21.0"), reason="WAR: Requires onnxruntime>=1.21.0" +) +def test_node_classifier_with_multi_batch_calibration(test_model): + """Test NodeClassifier with multi-batch calibration data.""" + import tempfile + + node_to_init_map = {key: [] for key in ["add_node", "mul_node"]} + + # Create multiple batches of calibration data + with tempfile.TemporaryDirectory() as temp_dir: + # Batch 1: small values + inputs1 = {"X": np.array([[0.1, 0.2], [0.2, 0.3]], dtype=np.float32)} + np.savez(os.path.join(temp_dir, "batch_001.npz"), **inputs1) + + # Batch 2: medium values + inputs2 = {"X": np.array([[0.3, 0.4], [0.4, 0.5]], dtype=np.float32)} + np.savez(os.path.join(temp_dir, "batch_002.npz"), **inputs2) + + # Batch 3: larger values (but still within fp16 range) + inputs3 = {"X": np.array([[0.5, 0.6], [0.6, 0.7]], dtype=np.float32)} + np.savez(os.path.join(temp_dir, "batch_003.npz"), **inputs3) + + ref_runner = ReferenceRunner(test_model) + classifier = NodeClassifier( + model=test_model, + node_to_init_map=node_to_init_map, + data_max=4.1, + ) + + # Run with multi-batch calibration directory + ref_outputs_dict = ref_runner.run(temp_dir) + + # Should get TensorStats objects + assert isinstance(ref_outputs_dict, OrderedDict) + # At least one output should be TensorStats (if multiple batches) + has_tensor_stats = any(isinstance(v, TensorStats) for v in ref_outputs_dict.values()) + assert has_tensor_stats + + # Run classification + fp16_nodes, fp32_nodes = classifier.run(ref_outputs_dict) + + # Verify classification works correctly + assert len(fp16_nodes) + len(fp32_nodes) == 2 + assert not set(fp16_nodes).intersection(set(fp32_nodes)) diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index 4fb02a230f..a14991319a 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -32,6 +32,15 @@ def low_precision_onnx_type(low_precision_type_str): return TensorProto.FLOAT16 if low_precision_type_str == "fp16" else TensorProto.BFLOAT16 +def setup_mappings( + model: onnx.ModelProto, use_standalone_type_inference: bool = False +) -> tuple[onnx.ModelProto, dict, dict, dict]: + # Setup internal mappings + model = onnx_utils.infer_types(model, use_standalone_type_inference) + value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + return model, value_info_map, initializer_map, node_to_init_map + + #################################################################################################### # Testing with a basic GEMM->Add->Relu graph #################################################################################################### @@ -56,16 +65,21 @@ def simple_model(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map -def test_graph_converter_init(simple_model): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_graph_converter_init(simple_model, use_standalone_type_inference): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( - model, value_info_map, initializer_map, node_to_init_map, keep_io_types=True + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + use_standalone_type_inference=use_standalone_type_inference, ) assert converter.model == model assert converter.value_info_map == value_info_map @@ -75,7 +89,10 @@ def test_graph_converter_init(simple_model): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_simple_convert(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_simple_convert( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -84,6 +101,7 @@ def test_simple_convert(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert add node to fp16, keep mul in fp32 @@ -133,7 +151,10 @@ def test_unsupported_precision_type(simple_model, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_convert_no_disabled_nodes( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -142,6 +163,7 @@ def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_ty node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert all nodes to fp16 @@ -167,7 +189,10 @@ def test_convert_no_disabled_nodes(simple_model, keep_io_types, low_precision_ty @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_get_tensors_to_cast( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -176,6 +201,7 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test when relu node is in low precision @@ -196,7 +222,10 @@ def test_get_tensors_to_cast(simple_model, keep_io_types, low_precision_type): @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_keep_io_names(simple_model, keep_io_types, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_keep_io_names( + simple_model, keep_io_types, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = simple_model converter = PrecisionConverter( model, @@ -205,6 +234,7 @@ def test_keep_io_names(simple_model, keep_io_types, low_precision_type): node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert all nodes to low precision @@ -258,16 +288,16 @@ def model_with_multiple_consumers(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_convert_with_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type + model_with_multiple_consumers, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -277,6 +307,7 @@ def test_convert_with_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Only gemm1 and add1 are converted to fp32, gemm2 and add2 are fp16 @@ -300,8 +331,9 @@ def test_convert_with_multiple_consumers( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_get_tensors_to_cast_multiple_consumers( - model_with_multiple_consumers, keep_io_types, low_precision_type + model_with_multiple_consumers, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( @@ -311,6 +343,7 @@ def test_get_tensors_to_cast_multiple_consumers( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test when gemm2 and add1 nodes are in low precision @@ -327,7 +360,10 @@ def test_get_tensors_to_cast_multiple_consumers( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_convert_initializers(model_with_multiple_consumers, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_convert_initializers( + model_with_multiple_consumers, low_precision_type, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers converter = PrecisionConverter( model, @@ -335,6 +371,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test successful cast, add1 and add2 share add_init and operate in different precisions @@ -361,6 +398,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) add1_node = next(n for n in converter2.model.graph.node if n.name == "add1") add2_node = next(n for n in converter2.model.graph.node if n.name == "add2") @@ -384,6 +422,7 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) initializer_map, node_to_init_map, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) add1_node = next(n for n in converter3.model.graph.node if n.name == "add1") add2_node = next(n for n in converter3.model.graph.node if n.name == "add2") @@ -404,7 +443,10 @@ def test_convert_initializers(model_with_multiple_consumers, low_precision_type) assert f"add_init_{low_precision_type}" in init_names -def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_clamping_fp16_initializers_out_of_range( + model_with_multiple_consumers, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers # Initializer is out of FP16 range, node is converted to FP16 @@ -412,7 +454,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): add_init = numpy_helper.from_array(add_init_out_of_range, name="add_init") model.graph.initializer[1].CopyFrom(add_init) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) # Verify initializer is clamped @@ -427,7 +475,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert add_init_converted_array[0, 1] == np.finfo(np.float16).max # Initializer is out of FP16 range, node is kept in FP32 - converter2 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter2 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter2._convert_initializers(low_precision_nodes=[], high_precision_nodes=["add1", "add2"]) # Verify initializer is not clamped @@ -441,7 +495,13 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert np.all(add_init_converted_array == add_init_out_of_range) # Initializer is out of FP16 range, one consumer is converted to FP16, the other is kept in FP32 - converter3 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter3 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter3._convert_initializers(low_precision_nodes=["add1"], high_precision_nodes=["add2"]) # Verify initializer is duplicated, and the FP16 copy is clamped @@ -462,7 +522,10 @@ def test_clamping_fp16_initializers_out_of_range(model_with_multiple_consumers): assert np.all(add_init_fp32_array == add_init_out_of_range) -def test_bf16_no_clamping_initializers_out_of_range(model_with_multiple_consumers): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_bf16_no_clamping_initializers_out_of_range( + model_with_multiple_consumers, use_standalone_type_inference +): model, value_info_map, initializer_map, node_to_init_map = model_with_multiple_consumers # Initializer is out of FP16 range, but that does not affect BF16 conversion @@ -476,6 +539,7 @@ def test_bf16_no_clamping_initializers_out_of_range(model_with_multiple_consumer initializer_map, node_to_init_map, low_precision_type="bf16", + use_standalone_type_inference=use_standalone_type_inference, ) converter._convert_initializers(low_precision_nodes=["add1", "add2"], high_precision_nodes=[]) @@ -511,13 +575,13 @@ def model_with_dynamic_shapes(): matmul_node = helper.make_node("MatMul", ["X", "weight"], ["matmul_out"], name="matmul") transpose_node = helper.make_node("Transpose", ["Y"], ["transpose_out"], name="transpose") concat_node = helper.make_node( - "Concat", ["matmul_out", "transpose_out"], ["concat_out"], name="concat", axis=0 + "Concat", ["matmul_out", "transpose_out"], ["concat_out"], name="concat1", axis=0 ) size_y = helper.make_node("Size", ["concat_out"], ["total_size"], name="size") const_4 = numpy_helper.from_array(np.array([4], dtype=np.int64), name="const_4") first_dim = helper.make_node("Div", ["total_size", "const_4"], ["first_dim"], name="div") concat_dims_node = helper.make_node( - "Concat", ["first_dim", "const_4"], ["final_shape"], name="concat", axis=0 + "Concat", ["first_dim", "const_4"], ["final_shape"], name="concat2", axis=0 ) reshape_node = helper.make_node("Reshape", ["concat_out", "final_shape"], ["Z"], name="reshape") @@ -540,20 +604,25 @@ def model_with_dynamic_shapes(): model = helper.make_model(graph, producer_name="model_dynamic") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map -def test_dynamic_model_conversion(model_with_dynamic_shapes): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_dynamic_model_conversion(model_with_dynamic_shapes, use_standalone_type_inference): model, value_info_map, initializer_map, node_to_init_map = model_with_dynamic_shapes # Test mixed precision conversion - converter2 = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + converter2 = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) high_precision_nodes = ["matmul"] - low_precision_nodes = ["transpose", "concat", "size", "div", "concat_dims", "reshape"] + low_precision_nodes = ["transpose", "concat1", "size", "div", "concat2", "reshape"] converted_model = converter2.convert(high_precision_nodes, low_precision_nodes) # Verify model is valid @@ -563,7 +632,8 @@ def test_dynamic_model_conversion(model_with_dynamic_shapes): #################################################################################################### # Cast cleanup logic #################################################################################################### -def test_cast_output_pattern(): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_pattern(use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -583,10 +653,14 @@ def test_cast_output_pattern(): model = helper.make_model(graph, producer_name="model_double_cast") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) # Setting all nodes to FP16 means that the final graph should have no cast nodes converted_model = converter.convert( @@ -602,7 +676,8 @@ def test_cast_output_pattern(): assert converted_model.graph.output[i].name == model.graph.output[i].name -def test_cast_output_pattern_mixed_precision(): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_pattern_mixed_precision(use_standalone_type_inference): x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [3, 4]) x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [3, 4]) y0 = helper.make_tensor_value_info("Y0", TensorProto.FLOAT, [3, 4]) @@ -625,10 +700,14 @@ def test_cast_output_pattern_mixed_precision(): model = helper.make_model(graph, producer_name="model_double_cast") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) # Network output Y0 has two consumers, one is FP16 and the other is FP32 converted_model = converter.convert( @@ -641,7 +720,8 @@ def test_cast_output_pattern_mixed_precision(): @pytest.mark.parametrize("keep_io_types", [True, False]) -def test_chain_of_casts_pattern(keep_io_types): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_chain_of_casts_pattern(keep_io_types, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4]) @@ -690,11 +770,14 @@ def test_chain_of_casts_pattern(keep_io_types): model = helper.make_model(graph, producer_name="model_cast_chain") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( - model, value_info_map, initializer_map, node_to_init_map, keep_io_types=keep_io_types + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=keep_io_types, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -705,7 +788,8 @@ def test_chain_of_casts_pattern(keep_io_types): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_existing_low_precision_output(low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_existing_low_precision_output(low_precision_type, use_standalone_type_inference): # Create a simple model with FP16 output x = helper.make_tensor_value_info("X", low_precision_onnx_type(low_precision_type), [3, 4]) y = helper.make_tensor_value_info("Y", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -715,8 +799,7 @@ def test_existing_low_precision_output(low_precision_type): model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, @@ -725,6 +808,7 @@ def test_existing_low_precision_output(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=["add"], low_precision_nodes=[]) @@ -743,7 +827,8 @@ def test_existing_low_precision_output(low_precision_type): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_output_cast_output_pattern(low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_output_cast_output_pattern(low_precision_type, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", low_precision_onnx_type(low_precision_type), [3, 4]) @@ -764,9 +849,8 @@ def test_output_cast_output_pattern(low_precision_type): model = helper.make_model(graph, producer_name="model_output_cast_output") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -774,6 +858,7 @@ def test_output_cast_output_pattern(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Setting nodes precision to match I/O type means that the final graph should have no cast nodes @@ -790,7 +875,8 @@ def test_output_cast_output_pattern(low_precision_type): @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_cast_output_keep_io_types_pattern(low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_cast_output_keep_io_types_pattern(low_precision_type, use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [3, 4]) y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [3, 4]) @@ -809,9 +895,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): model = helper.make_model(graph, producer_name="model_cast_output_keep_io_types") model.opset_import[0].version = 20 model.ir_version = 10 - model = onnx_utils.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -819,6 +903,7 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converter.convert(high_precision_nodes=[], low_precision_nodes=["add1", "add2"]) @@ -827,7 +912,8 @@ def test_cast_output_keep_io_types_pattern(low_precision_type): assert converter.model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT -def test_unsupported_op_types_model(): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_unsupported_op_types_model(use_standalone_type_inference): x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4]) roi = helper.make_tensor_value_info("roi", TensorProto.FLOAT, [3, 4]) scales = helper.make_tensor_value_info("scales", TensorProto.FLOAT, [4]) @@ -848,17 +934,24 @@ def test_unsupported_op_types_model(): [], ) model = helper.make_model(graph, producer_name="model_celu") - model = onnx.shape_inference.infer_shapes(model) - - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) - converter = PrecisionConverter(model, value_info_map, initializer_map, node_to_init_map) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + use_standalone_type_inference=use_standalone_type_inference, + ) converter.convert(high_precision_nodes=[], low_precision_nodes=["celu", "resize", "nms"]) onnx.checker.check_model(converter.model) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("empty_tensor_target", ["low_precision", "high_precision"]) -def test_empty_tensor_handling(low_precision_type, empty_tensor_target): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_empty_tensor_handling( + low_precision_type, empty_tensor_target, use_standalone_type_inference +): """Test empty tensor handling for both low and high precision node targets.""" # Create model with empty float tensor from Constant layer x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2]) @@ -888,8 +981,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, value_info_map, @@ -897,6 +989,7 @@ def test_empty_tensor_handling(low_precision_type, empty_tensor_target): node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Test empty tensor detection @@ -979,14 +1072,16 @@ def model_with_constant_cast_patterns(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_constant_cast_folding( + model_with_constant_cast_patterns, low_precision_type, use_standalone_type_inference +): """Test constant->cast folding as part of the full conversion process.""" model, value_info_map, initializer_map, node_to_init_map = model_with_constant_cast_patterns @@ -997,6 +1092,7 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_ node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Convert with some nodes in low precision to trigger cast insertion @@ -1077,15 +1173,17 @@ def model_with_multiple_output_node_casted_to_output(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_multiple_output_node_casted_to_output( - model_with_multiple_output_node_casted_to_output, low_precision_type + model_with_multiple_output_node_casted_to_output, + low_precision_type, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( model_with_multiple_output_node_casted_to_output @@ -1098,6 +1196,7 @@ def test_multiple_output_node_casted_to_output( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"] @@ -1145,16 +1244,19 @@ def model_with_casted_input_to_output(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("keep_io_types", [True, False]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_casted_input_to_output_model( - model_with_casted_input_to_output, low_precision_type, keep_io_types + model_with_casted_input_to_output, + low_precision_type, + keep_io_types, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = model_with_casted_input_to_output @@ -1168,6 +1270,7 @@ def test_casted_input_to_output_model( min_opset=22 if low_precision_type == "bf16" else 13, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT, trt_plugins=[], + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=["cast_input"], low_precision_nodes=["add1", "add2"] @@ -1218,8 +1321,7 @@ def create_model_with_resize_op(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @@ -1276,16 +1378,16 @@ def create_model_with_resize_op_tensor_scales(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_resize_op_initializer_conversion( - create_model_with_resize_op, keep_io_types, low_precision_type + create_model_with_resize_op, keep_io_types, low_precision_type, use_standalone_type_inference ): model, value_info_map, initializer_map, node_to_init_map = create_model_with_resize_op @@ -1296,6 +1398,7 @@ def test_resize_op_initializer_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1305,8 +1408,12 @@ def test_resize_op_initializer_conversion( @pytest.mark.parametrize("keep_io_types", [True, False]) @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_resize_op_tensor_scales_conversion( - create_model_with_resize_op_tensor_scales, keep_io_types, low_precision_type + create_model_with_resize_op_tensor_scales, + keep_io_types, + low_precision_type, + use_standalone_type_inference, ): model, value_info_map, initializer_map, node_to_init_map = ( create_model_with_resize_op_tensor_scales @@ -1319,6 +1426,7 @@ def test_resize_op_tensor_scales_conversion( node_to_init_map, keep_io_types=keep_io_types, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) converted_model = converter.convert( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] @@ -1409,15 +1517,15 @@ def model_with_if_subgraph(): model.ir_version = 10 onnx.checker.check_model(model) - model = onnx_utils.infer_shapes(model) - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) return model, value_info_map, initializer_map, node_to_init_map @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) @pytest.mark.parametrize("if_precision", ["low", "high"]) +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) def test_if_subgraph_initializer_conversion( - model_with_if_subgraph, low_precision_type, if_precision + model_with_if_subgraph, low_precision_type, if_precision, use_standalone_type_inference ): """Test that initializers in If subgraphs are converted based on parent node precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1429,6 +1537,7 @@ def test_if_subgraph_initializer_conversion( node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # Classify the If node based on test parameter @@ -1482,7 +1591,10 @@ def test_if_subgraph_initializer_conversion( @pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) -def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precision_type): +@pytest.mark.parametrize("use_standalone_type_inference", [True, False]) +def test_if_subgraph_mixed_precision_boundary( + model_with_if_subgraph, low_precision_type, use_standalone_type_inference +): """Test that types are correctly handled at If subgraph boundaries in mixed precision.""" model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph @@ -1498,7 +1610,7 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis model.graph.output.append(output_tensor) # Refresh mappings - value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + model, value_info_map, initializer_map, node_to_init_map = setup_mappings(model) converter = PrecisionConverter( model, @@ -1507,6 +1619,7 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis node_to_init_map, keep_io_types=True, low_precision_type=low_precision_type, + use_standalone_type_inference=use_standalone_type_inference, ) # If in low precision, Add in high precision diff --git a/tests/unit/onnx/autocast/test_referencerunner.py b/tests/unit/onnx/autocast/test_referencerunner.py index de76b7425a..82155e4247 100644 --- a/tests/unit/onnx/autocast/test_referencerunner.py +++ b/tests/unit/onnx/autocast/test_referencerunner.py @@ -23,7 +23,7 @@ from onnx import TensorProto, helper import modelopt.onnx.utils as onnx_utils -from modelopt.onnx.autocast.referencerunner import ReferenceRunner +from modelopt.onnx.autocast.referencerunner import ReferenceRunner, TensorStats def create_multi_io_model(): @@ -141,7 +141,7 @@ def test_run_with_dict_inputs(reference_runner): def test_invalid_input_format(reference_runner): """Test error handling for invalid input format.""" - with pytest.raises(ValueError, match="Supported input file types:.*"): + with pytest.raises(ValueError, match=r"Supported input types:.*"): reference_runner.run("invalid.txt") @@ -175,7 +175,7 @@ def test_invalid_json(reference_runner): json.dump(inputs, f) input_path = f.name try: - with pytest.raises(ValueError, match="Invalid input file."): + with pytest.raises(ValueError, match=r"Invalid input file\."): reference_runner.run(input_path) finally: os.remove(input_path) @@ -189,7 +189,7 @@ def test_invalid_npz_file(reference_runner): np.save(f, data) input_path = f.name try: - with pytest.raises(ValueError, match="Invalid input file."): + with pytest.raises(ValueError, match=r"Invalid input file\."): reference_runner.run(input_path) finally: os.remove(input_path) @@ -240,3 +240,152 @@ def test_compare_outputs(reference_runner): np.testing.assert_allclose(outputs2["Y2"], expected_y2) finally: os.remove(input_path) + + +def test_tensor_stats(): + """Test TensorStats dataclass functionality.""" + stats = TensorStats(absmax=10.5, min_val=-5.0, max_val=10.5, shape=(2, 3)) + + assert stats.absmax == 10.5 + assert stats.min_val == -5.0 + assert stats.max_val == 10.5 + assert stats.shape == (2, 3) + assert stats.size == 6 + assert abs(stats) == 10.5 # Test __abs__ method + + +def test_run_with_multi_batch_npz_directory(reference_runner): + """Test running inference with directory containing multiple NPZ files.""" + # Create temporary directory with multiple NPZ files + with tempfile.TemporaryDirectory() as temp_dir: + # Create batch 1 + inputs1 = { + "X1": np.array([[1.0, 2.0, 3.0]], dtype=np.float32), + "X2": np.array([[4.0, 5.0, 6.0]], dtype=np.float32), + } + np.savez(os.path.join(temp_dir, "batch_001.npz"), **inputs1) + + # Create batch 2 + inputs2 = { + "X1": np.array([[2.0, 3.0, 4.0]], dtype=np.float32), + "X2": np.array([[1.0, 2.0, 3.0]], dtype=np.float32), + } + np.savez(os.path.join(temp_dir, "batch_002.npz"), **inputs2) + + # Create batch 3 + inputs3 = { + "X1": np.array([[5.0, 6.0, 7.0]], dtype=np.float32), + "X2": np.array([[8.0, 9.0, 10.0]], dtype=np.float32), + } + np.savez(os.path.join(temp_dir, "batch_003.npz"), **inputs3) + + # Run with directory path + results = reference_runner.run(temp_dir) + + # Should return TensorStats objects for multi-batch + assert isinstance(results, OrderedDict) + assert "Y1" in results + assert "Y2" in results + + # Check that results are TensorStats objects + assert isinstance(results["Y1"], TensorStats) + assert isinstance(results["Y2"], TensorStats) + + # Verify aggregated statistics + # Y1 = X1 + X2 + # Batch 1: [5.0, 7.0, 9.0] + # Batch 2: [3.0, 5.0, 7.0] + # Batch 3: [13.0, 15.0, 17.0] + assert results["Y1"].absmax == 17.0 + assert results["Y1"].min_val == 3.0 + assert results["Y1"].max_val == 17.0 + assert results["Y1"].shape == (1, 3) + + # Y2 = X1 * X2 + # Batch 1: [4.0, 10.0, 18.0] + # Batch 2: [2.0, 6.0, 12.0] + # Batch 3: [40.0, 54.0, 70.0] + assert results["Y2"].absmax == 70.0 + assert results["Y2"].min_val == 2.0 + assert results["Y2"].max_val == 70.0 + assert results["Y2"].shape == (1, 3) + + +def test_run_with_empty_npz_directory(reference_runner): + """Test error handling for empty NPZ directory.""" + with ( + tempfile.TemporaryDirectory() as temp_dir, + pytest.raises(ValueError, match="No NPZ files found in directory"), + ): + reference_runner.run(temp_dir) + + +def test_single_batch_backward_compatibility(reference_runner): + """Test that single batch still returns raw numpy arrays (backward compatibility).""" + inputs = { + "X1": np.array([[1.0, 2.0, 3.0]], dtype=np.float32), + "X2": np.array([[4.0, 5.0, 6.0]], dtype=np.float32), + } + + with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as f: + np.savez(f, **inputs) + input_path = f.name + + try: + results = reference_runner.run(input_path) + # Single batch should return raw numpy arrays, not TensorStats + assert isinstance(results, OrderedDict) + assert "Y1" in results + assert "Y2" in results + assert isinstance(results["Y1"], np.ndarray) + assert isinstance(results["Y2"], np.ndarray) + assert not isinstance(results["Y1"], TensorStats) + assert not isinstance(results["Y2"], TensorStats) + finally: + os.remove(input_path) + + +def test_multi_batch_aggregation_statistics(reference_runner): + """Test that multi-batch aggregation correctly computes statistics across batches.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create batches with different value ranges + # Batch 1: small values + inputs1 = { + "X1": np.array([[-1.0, 0.0, 1.0]], dtype=np.float32), + "X2": np.array([[1.0, 2.0, 3.0]], dtype=np.float32), + } + np.savez(os.path.join(temp_dir, "batch_001.npz"), **inputs1) + + # Batch 2: large values + inputs2 = { + "X1": np.array([[-10.0, 0.0, 10.0]], dtype=np.float32), + "X2": np.array([[5.0, 6.0, 7.0]], dtype=np.float32), + } + np.savez(os.path.join(temp_dir, "batch_002.npz"), **inputs2) + + # Batch 3: mixed values + inputs3 = { + "X1": np.array([[5.0, -5.0, 0.0]], dtype=np.float32), + "X2": np.array([[2.0, 3.0, 4.0]], dtype=np.float32), + } + np.savez(os.path.join(temp_dir, "batch_003.npz"), **inputs3) + + results = reference_runner.run(temp_dir) + + # Y1 = X1 + X2 + # Batch 1: [0.0, 2.0, 4.0] -> absmax=4.0, min=0.0, max=4.0 + # Batch 2: [-5.0, 6.0, 17.0] -> absmax=17.0, min=-5.0, max=17.0 + # Batch 3: [7.0, -2.0, 4.0] -> absmax=7.0, min=-2.0, max=7.0 + # Aggregated: absmax=17.0, min=-5.0, max=17.0 + assert results["Y1"].absmax == 17.0 + assert results["Y1"].min_val == -5.0 + assert results["Y1"].max_val == 17.0 + + # Y2 = X1 * X2 + # Batch 1: [-1.0, 0.0, 3.0] -> absmax=3.0, min=-1.0, max=3.0 + # Batch 2: [-50.0, 0.0, 70.0] -> absmax=70.0, min=-50.0, max=70.0 + # Batch 3: [10.0, -15.0, 0.0] -> absmax=15.0, min=-15.0, max=10.0 + # Aggregated: absmax=70.0, min=-50.0, max=70.0 + assert results["Y2"].absmax == 70.0 + assert results["Y2"].min_val == -50.0 + assert results["Y2"].max_val == 70.0 diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py new file mode 100644 index 0000000000..2818d31723 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -0,0 +1,948 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive tests for common data structures in the autotuner. + +Tests: +1. InsertionPoint classes (NodeInputInsertionPoint, ChildRegionOutputInsertionPoint, ChildRegionInputInsertionPoint) +2. InsertionScheme serialization/deserialization +3. InsertionScheme hashing and equality +4. InsertionScheme properties and methods +5. PatternSchemes management +6. Utility functions (skip_invalid_insertion_points, has_quantizable_operations, etc.) +7. Resolve and collect_from methods for all InsertionPoint types +""" + +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import onnx_graphsurgeon as gs +import pytest + +from modelopt.onnx.quantization.autotune.common import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + InsertionScheme, + NodeInputInsertionPoint, + Region, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + has_quantizable_operations, + merge_resolved_insertion_points, + resolve_region_io_insertion_points, + skip_invalid_insertion_points, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +INSERTION_POINT_CASES = [ + pytest.param( + NodeInputInsertionPoint, + {"node_index": 5, "input_index": 2}, + {"node_index": 5, "input_index": 2}, + {"node_index": 5, "input_index": 3}, + "node_index", + ["5", "2"], + id="NodeInputInsertionPoint", + ), + pytest.param( + ChildRegionOutputInsertionPoint, + {"region_index": 2, "node_index": None, "output_index": 1}, + {"region_index": 2, "node_index": None, "output_index": 1}, + {"region_index": None, "node_index": 2, "output_index": 1}, + "region_index", + ["region", "2"], + id="ChildRegionOutputInsertionPoint-region", + ), + pytest.param( + ChildRegionOutputInsertionPoint, + {"region_index": None, "node_index": 5, "output_index": 0}, + {"region_index": None, "node_index": 5, "output_index": 0}, + {"region_index": None, "node_index": 5, "output_index": 1}, + "node_index", + ["node", "5"], + id="ChildRegionOutputInsertionPoint-node", + ), + pytest.param( + ChildRegionInputInsertionPoint, + {"region_index": 3, "input_index": 1}, + {"region_index": 3, "input_index": 1}, + {"region_index": 3, "input_index": 2}, + "region_index", + ["3", "1"], + id="ChildRegionInputInsertionPoint", + ), +] + + +class TestInsertionPoints: + """Combined tests for all InsertionPoint types.""" + + @pytest.mark.parametrize(("cls", "kwargs", "_", "__", "___", "____"), INSERTION_POINT_CASES) + def test_creation(self, cls, kwargs, _, __, ___, ____): + point = cls(**kwargs) + for key, val in kwargs.items(): + assert getattr(point, key) == val + + @pytest.mark.parametrize( + ("cls", "kwargs", "_", "__", "mutate_attr", "___"), INSERTION_POINT_CASES + ) + def test_immutability(self, cls, kwargs, _, __, mutate_attr, ___): + point = cls(**kwargs) + with pytest.raises(AttributeError): + setattr(point, mutate_attr, 999) + + @pytest.mark.parametrize( + ("cls", "kwargs", "equal_kwargs", "diff_kwargs", "_", "__"), INSERTION_POINT_CASES + ) + def test_equality(self, cls, kwargs, equal_kwargs, diff_kwargs, _, __): + point1 = cls(**kwargs) + point2 = cls(**equal_kwargs) + point3 = cls(**diff_kwargs) + assert point1 == point2 + assert point1 != point3 + + @pytest.mark.parametrize( + ("cls", "kwargs", "equal_kwargs", "diff_kwargs", "_", "__"), INSERTION_POINT_CASES + ) + def test_hashable(self, cls, kwargs, equal_kwargs, diff_kwargs, _, __): + point1 = cls(**kwargs) + point2 = cls(**equal_kwargs) + point3 = cls(**diff_kwargs) + point_set = {point1, point2, point3} + assert len(point_set) == 2 + + @pytest.mark.parametrize(("cls", "kwargs", "_", "__", "___", "____"), INSERTION_POINT_CASES) + def test_serialization(self, cls, kwargs, _, __, ___, ____): + point = cls(**kwargs) + data = point.to_dict() + for key, val in kwargs.items(): + assert data[key] == val + restored = cls.from_dict(data) + assert point == restored + + @pytest.mark.parametrize( + ("cls", "kwargs", "_", "__", "___", "str_checks"), INSERTION_POINT_CASES + ) + def test_string_representation(self, cls, kwargs, _, __, ___, str_checks): + point = cls(**kwargs) + s = str(point).lower() + for check in str_checks: + assert check.lower() in s + + +class TestInsertionScheme: + """Test InsertionScheme functionality.""" + + def test_empty_scheme(self): + """Test empty InsertionScheme.""" + scheme = InsertionScheme() + assert scheme.is_empty + assert len(scheme.node_inputs) == 0 + assert len(scheme.child_region_inputs) == 0 + assert len(scheme.region_outputs) == 0 + assert not scheme.error + + @pytest.mark.parametrize( + ("attr", "points"), + [ + ("node_inputs", [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)]), + ( + "region_outputs", + [ + ChildRegionOutputInsertionPoint(None, 0, 0), + ChildRegionOutputInsertionPoint(1, None, 0), + ], + ), + ( + "child_region_inputs", + [ChildRegionInputInsertionPoint(0, 0), ChildRegionInputInsertionPoint(1, 0)], + ), + ], + ) + def test_scheme_with_points_not_empty(self, attr, points): + """Test scheme with insertion points is not empty.""" + scheme = InsertionScheme() + setattr(scheme, attr, points) + assert not scheme.is_empty + assert len(getattr(scheme, attr)) == 2 + + def test_scheme_hash_empty(self): + """Test hash of empty schemes are equal.""" + assert InsertionScheme().hash == InsertionScheme().hash + + def test_scheme_hash_equality(self): + """Test hash with same/different insertion points.""" + + def make_scheme(*node_indices): + s = InsertionScheme() + s.node_inputs = [NodeInputInsertionPoint(i, 0) for i in node_indices] + return s + + assert make_scheme(0, 1).hash == make_scheme(0, 1).hash + assert make_scheme(0, 1).hash == make_scheme(1, 0).hash # order independent + assert make_scheme(0, 1).hash != make_scheme(0, 2).hash + + @pytest.mark.parametrize( + ("error", "latency"), + [ + (False, float("inf")), # empty + (False, 12.5), # full + (True, float("inf")), # with error + ], + ) + def test_serialization_roundtrip(self, error, latency): + """Test serialization roundtrip.""" + scheme = InsertionScheme() + scheme.error = error + scheme.latency_ms = latency + + if latency != float("inf") or error: # add points for non-empty cases + scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] + scheme.region_outputs = [ChildRegionOutputInsertionPoint(None, 0, 0)] + + restored = InsertionScheme.from_dict(scheme.to_dict()) + + assert restored.error == error + assert restored.latency_ms == latency + if not scheme.is_empty: + assert len(restored.node_inputs) == len(scheme.node_inputs) + assert len(restored.child_region_inputs) == len(scheme.child_region_inputs) + assert len(restored.region_outputs) == len(scheme.region_outputs) + + +def _create_mock_tensor(name: str, dtype=np.float32, shape=None): + """Create a mock tensor with the specified properties.""" + tensor = MagicMock() + tensor.name = name + tensor.dtype = dtype + tensor.shape = shape if shape is not None else [1, 3, 224, 224] + tensor.inputs = [] + return tensor + + +def _create_mock_node(op: str, inputs: list, outputs: list, name: str = ""): + """Create a mock node with the specified properties.""" + node = MagicMock(spec=gs.Node) + node.op = op + node.name = name + node.inputs = inputs + node.outputs = outputs + return node + + +def _create_region(region_id=1, level=0, region_type=RegionType.LEAF, nodes=None): + """Create a region with the specified properties. + + Args: + region_id: ID for the region + level: Hierarchy level (0 for LEAF, 1+ for COMPOSITE/ROOT) + region_type: Type of region (LEAF, COMPOSITE, or ROOT) + nodes: Optional list/set of node indices to add to the region + + Returns: + Region with specified properties and nodes + """ + region = Region(region_id=region_id, level=level, region_type=region_type) + if nodes: + region.nodes.update(nodes) + return region + + +def _create_simple_graph(): + """Create a mock graph with Conv -> BatchNorm -> Relu -> MaxPool pattern. + + Graph structure: + input -> Conv -> conv_out -> BatchNorm -> bn_out -> Relu -> relu_out -> MaxPool -> pool_out + """ + # Create tensors with realistic shapes + input_tensor = _create_mock_tensor("input", np.float32, [1, 3, 224, 224]) + weight_tensor = _create_mock_tensor("conv_weight", np.float32, [64, 3, 3, 3]) + bias_tensor = _create_mock_tensor("conv_bias", np.float32, [64]) + conv_output = _create_mock_tensor("conv_out", np.float32, [1, 64, 222, 222]) + + # BatchNorm parameters + bn_scale = _create_mock_tensor("bn_scale", np.float32, [64]) + bn_bias = _create_mock_tensor("bn_bias", np.float32, [64]) + bn_mean = _create_mock_tensor("bn_mean", np.float32, [64]) + bn_var = _create_mock_tensor("bn_var", np.float32, [64]) + bn_output = _create_mock_tensor("bn_out", np.float32, [1, 64, 222, 222]) + + relu_output = _create_mock_tensor("relu_out", np.float32, [1, 64, 222, 222]) + pool_output = _create_mock_tensor("pool_out", np.float32, [1, 64, 111, 111]) + + # Create nodes + conv_node = _create_mock_node( + "Conv", [input_tensor, weight_tensor, bias_tensor], [conv_output], "conv1" + ) + bn_node = _create_mock_node( + "BatchNormalization", + [conv_output, bn_scale, bn_bias, bn_mean, bn_var], + [bn_output], + "bn1", + ) + relu_node = _create_mock_node("Relu", [bn_output], [relu_output], "relu1") + pool_node = _create_mock_node("MaxPool", [relu_output], [pool_output], "pool1") + + # Link tensors to their producer nodes + conv_output.inputs = [conv_node] + bn_output.inputs = [bn_node] + relu_output.inputs = [relu_node] + pool_output.inputs = [pool_node] + input_tensor.inputs = [] + weight_tensor.inputs = [] + bias_tensor.inputs = [] + + # Create graph + graph = MagicMock(spec=gs.Graph) + graph.nodes = [conv_node, bn_node, relu_node, pool_node] + graph.inputs = [input_tensor] + graph.outputs = [pool_output] + + tensors = { + "input": input_tensor, + "conv_weight": weight_tensor, + "conv_bias": bias_tensor, + "conv_out": conv_output, + "bn_out": bn_output, + "relu_out": relu_output, + "pool_out": pool_output, + } + + return graph, tensors + + +def _create_residual_graph(): + """Create a mock graph with a residual block pattern (skip connection). + + Graph structure: + input ─────────────────────────────┐ + │ │ + ▼ │ + Conv1 -> conv1_out │ + │ │ + ▼ │ + Relu1 -> relu1_out │ + │ │ + ▼ │ + Conv2 -> conv2_out │ + │ │ + ▼ ▼ + Add (conv2_out + input) -> add_out + │ + ▼ + Relu2 -> output + """ + # Create tensors + input_tensor = _create_mock_tensor("input", np.float32, [1, 64, 56, 56]) + + # First conv branch + weight1 = _create_mock_tensor("conv1_weight", np.float32, [64, 64, 3, 3]) + conv1_out = _create_mock_tensor("conv1_out", np.float32, [1, 64, 56, 56]) + relu1_out = _create_mock_tensor("relu1_out", np.float32, [1, 64, 56, 56]) + + # Second conv + weight2 = _create_mock_tensor("conv2_weight", np.float32, [64, 64, 3, 3]) + conv2_out = _create_mock_tensor("conv2_out", np.float32, [1, 64, 56, 56]) + + # Add and final relu + add_out = _create_mock_tensor("add_out", np.float32, [1, 64, 56, 56]) + output = _create_mock_tensor("output", np.float32, [1, 64, 56, 56]) + + # Create nodes + conv1_node = _create_mock_node("Conv", [input_tensor, weight1], [conv1_out], "conv1") + relu1_node = _create_mock_node("Relu", [conv1_out], [relu1_out], "relu1") + conv2_node = _create_mock_node("Conv", [relu1_out, weight2], [conv2_out], "conv2") + add_node = _create_mock_node("Add", [conv2_out, input_tensor], [add_out], "add1") + relu2_node = _create_mock_node("Relu", [add_out], [output], "relu2") + + # Link tensors to their producer nodes + conv1_out.inputs = [conv1_node] + relu1_out.inputs = [relu1_node] + conv2_out.inputs = [conv2_node] + add_out.inputs = [add_node] + output.inputs = [relu2_node] + input_tensor.inputs = [] + weight1.inputs = [] + weight2.inputs = [] + + # Create graph + graph = MagicMock(spec=gs.Graph) + graph.nodes = [conv1_node, relu1_node, conv2_node, add_node, relu2_node] + graph.inputs = [input_tensor] + graph.outputs = [output] + + tensors = { + "input": input_tensor, + "conv1_weight": weight1, + "conv1_out": conv1_out, + "relu1_out": relu1_out, + "conv2_weight": weight2, + "conv2_out": conv2_out, + "add_out": add_out, + "output": output, + } + + return graph, tensors + + +class TestSkipInvalidInsertionPoints: + """Test skip_invalid_insertion_points function.""" + + @pytest.mark.parametrize( + ("op", "should_skip"), + [ + ("Equal", True), # bool op + ("Shape", True), # shape op + ("MatMul", False), # normal op + ("Add", False), # normal op + ], + ) + def test_skip_by_op_type(self, op, should_skip): + graph, _ = _create_simple_graph() + tensor = _create_mock_tensor("test_input", np.float32, [1, 64, 32, 32]) + node = _create_mock_node(op, [tensor], []) + assert skip_invalid_insertion_points(graph, "test_input", node) is should_skip + + @pytest.mark.parametrize( + ("dtype", "shape", "should_skip"), + [ + (np.int32, [1, 64, 32, 32], True), # non-float + (np.float32, [1], True), # small tensor + (np.float32, [1, 64, 32, 32], False), # large float - OK + ], + ) + def test_skip_by_tensor_properties(self, dtype, shape, should_skip): + graph, _ = _create_simple_graph() + tensor = _create_mock_tensor("test", dtype, shape) + node = _create_mock_node("Add", [tensor], []) + assert skip_invalid_insertion_points(graph, "test", node) is should_skip + + def test_skip_conv_weight_input(self): + """Conv weight inputs (index >= 1) are skipped.""" + graph, _ = _create_simple_graph() + result = skip_invalid_insertion_points(graph, "conv_weight", graph.nodes[0]) + assert result is True + + def test_skip_bn_non_data_inputs(self): + """BatchNormalization non-data inputs are skipped.""" + graph, _ = _create_simple_graph() + result = skip_invalid_insertion_points(graph, "bn_scale", graph.nodes[1]) + assert result is True + + def test_skip_conv_bn_relu_fusion(self): + """Conv->BN->Relu fusion patterns are skipped at intermediate points.""" + graph, _ = _create_simple_graph() + result = skip_invalid_insertion_points(graph, "bn_out", graph.nodes[2]) + assert result is True + + def test_with_region(self): + """Test with a Region containing multiple nodes.""" + graph, _ = _create_simple_graph() + region = _create_region(nodes=[0, 1]) + + shape_tensor = _create_mock_tensor("shape_input", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], []) + graph.nodes.append(shape_node) + region.nodes.add(4) + + assert skip_invalid_insertion_points(graph, "shape_input", region) is True + + def test_residual_block_add_inputs_allowed(self): + """Add node inputs in residual blocks should be allowed.""" + graph, _ = _create_residual_graph() + add_node = graph.nodes[3] + + assert skip_invalid_insertion_points(graph, "conv2_out", add_node) is False + assert skip_invalid_insertion_points(graph, "input", add_node) is False + + +class TestHasQuantizableOperations: + """Test has_quantizable_operations function.""" + + @pytest.mark.parametrize( + ("nodes", "graph_fn", "expected"), + [ + ({0}, _create_simple_graph, True), # Conv + ({3}, _create_simple_graph, True), # MaxPool + ({2}, _create_simple_graph, True), # Relu + ({0, 1, 2}, _create_simple_graph, True), # Conv->BN->Relu + ({3}, _create_residual_graph, True), # Add in residual + ], + ) + def test_leaf_with_quantizable_ops(self, nodes, graph_fn, expected): + """Test LEAF region with various quantizable operations.""" + graph, _ = graph_fn() + region = _create_region(nodes=nodes) + assert has_quantizable_operations(region, graph) is expected + + def test_leaf_without_quantizable_ops(self): + """Test LEAF region without major quantizable operations.""" + shape_tensor = _create_mock_tensor("input", np.float32) + output_tensor = _create_mock_tensor("output", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], [output_tensor]) + transpose_node = _create_mock_node("Transpose", [output_tensor], []) + graph = MagicMock(spec=gs.Graph) + graph.nodes = [shape_node, transpose_node] + region = _create_region(nodes={0, 1}) + + assert has_quantizable_operations(region, graph) is False + + def test_composite_region_always_true(self): + """Test that COMPOSITE regions always return True.""" + graph, _ = _create_simple_graph() + region = _create_region(level=1, region_type=RegionType.COMPOSITE) + assert has_quantizable_operations(region, graph) is True + + +class TestResolveRegionIOInsertionPoints(unittest.TestCase): + """Test resolve_region_io_insertion_points function.""" + + def test_resolve_with_region(self): + """Test resolving with a region containing Conv->BN->Relu.""" + graph, tensors = _create_simple_graph() + + # Set up tensor_users_map: conv_out is consumed by BatchNorm (node 1) + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + region = _create_region(nodes=[2]) # Relu node + result = resolve_region_io_insertion_points(region, graph, "relu_out") + + assert len(result) >= 1 + assert any(ip.tensor_name == "relu_out" for ip in result) + + def test_resolve_without_region(self): + """Test resolving without a region (None) for tensor-level insertion.""" + graph, _ = _create_simple_graph() + + # Set up tensor_users_map: bn_out is consumed by Relu (node 2) + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + result = resolve_region_io_insertion_points(None, graph, "relu_out") + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "relu_out" + assert ip.node_index == 3 + assert ip.input_index == 0 + + def test_resolve_tensor_not_found(self): + """Test resolving a tensor that has no users.""" + graph, _ = _create_simple_graph() + graph.tensor_users_map = {} + result = resolve_region_io_insertion_points(None, graph, "nonexistent") + + assert len(result) == 0 + + def test_resolve_residual_skip_connection(self): + """Test resolving input tensor used by both Conv1 and Add (skip connection).""" + graph, tensors = _create_residual_graph() + + # Input tensor is used by Conv1 (node 0) and Add (node 3) + graph.tensor_users_map = {"input": [0, 3]} + result = resolve_region_io_insertion_points(None, graph, "input") + + # Should find both consumers + assert len(result) == 2 + node_indices = {ip.node_index for ip in result} + assert 0 in node_indices # Conv1 + assert 3 in node_indices # Add + + def test_resolve_with_multiple_consumers(self): + """Test resolving tensor with multiple consumers in a region.""" + graph, tensors = _create_residual_graph() + + # relu1_out feeds conv2 (node 2) + graph.tensor_users_map = {"relu1_out": [2]} + + region = _create_region(nodes=[2]) # Conv2 + + result = resolve_region_io_insertion_points(region, graph, "relu1_out") + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "relu1_out" + assert ip.node_index == 2 + + +class TestMergeResolvedInsertionPoints(unittest.TestCase): + """Test merge_resolved_insertion_points function.""" + + def test_merge_all_users(self): + """Test merging when all users have insertion points.""" + graph, _ = _create_simple_graph() + + # Setup: tensor "conv_out" is used by BatchNorm (node 1) + resolved = { + ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should be merged to tensor-level insertion + assert len(result) == 1 + merged = next(iter(result)) + assert merged.tensor_name == "conv_out" + assert merged.node_index is None + assert merged.input_index is None + + def test_no_merge_partial_users(self): + """Test no merging when only some users have insertion points.""" + graph, _ = _create_simple_graph() + + # Setup: tensor "conv_out" is used by nodes 1 and 2, but only node 1 has IP + resolved = { + ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1, 2]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should NOT be merged - keep node-specific + assert len(result) == 1 + ip = next(iter(result)) + assert ip.node_index == 1 # Still node-specific + + def test_preserve_tensor_level_insertions(self): + """Test that existing tensor-level insertions are preserved.""" + graph, _ = _create_simple_graph() + + # Already tensor-level insertion + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=None, input_index=None), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1]} + + result = merge_resolved_insertion_points(graph, resolved) + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "input" + assert ip.node_index is None + + def test_merge_residual_skip_connection(self): + """Test merging with residual block where input has two users.""" + graph, _ = _create_residual_graph() + + # Input tensor used by Conv1 (node 0) and Add (node 3) + # If we have insertion points for both, they should merge + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0), + ResolvedInsertionPoint(tensor_name="input", node_index=3, input_index=1), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"input": [0, 3]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should be merged to tensor-level insertion + assert len(result) == 1 + merged = next(iter(result)) + assert merged.tensor_name == "input" + assert merged.node_index is None + + def test_no_merge_residual_partial(self): + """Test no merging in residual block when only one branch has insertion point.""" + graph, _ = _create_residual_graph() + + # Input tensor used by Conv1 (node 0) and Add (node 3) + # Only Conv1 has an insertion point + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"input": [0, 3]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should NOT merge - only one of two users has IP + assert len(result) == 1 + ip = next(iter(result)) + assert ip.node_index == 0 # Still node-specific + + +class TestNodeInputInsertionPointMethods(unittest.TestCase): + """Test NodeInputInsertionPoint.resolve() and collect_from_region() methods.""" + + def test_resolve_simple(self): + """Test resolving a simple node input for Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + + # Create insertion point for first input of first node (Conv) + ip = NodeInputInsertionPoint(node_index=0, input_index=0) + result = ip.resolve(region, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "input" for rip in result) + + def test_resolve_conv_includes_weight(self): + """Test that resolving Conv input also includes weight.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0]) # Conv node + + # Create insertion point for first input of Conv (should also add weight) + ip = NodeInputInsertionPoint(node_index=0, input_index=0) + result = ip.resolve(region, graph) + + # Should include both data input and weight + assert len(result) == 2 + tensor_names = {rip.tensor_name for rip in result} + assert "input" in tensor_names + assert "conv_weight" in tensor_names + + def test_resolve_relu_input(self): + """Test resolving Relu input in the middle of the chain.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + + # Relu is at local index 2, input 0 is bn_out + ip = NodeInputInsertionPoint(node_index=2, input_index=0) + result = ip.resolve(region, graph) + + assert len(result) == 1 + rip = next(iter(result)) + assert rip.tensor_name == "bn_out" + + def test_resolve_residual_conv_input(self): + """Test resolving Conv input in residual block.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2]) # Conv1, Relu1, Conv2 + + # Conv2 is at local index 2, input 0 is relu1_out + ip = NodeInputInsertionPoint(node_index=2, input_index=0) + result = ip.resolve(region, graph) + + # Conv includes both data and weight + assert len(result) == 2 + tensor_names = {rip.tensor_name for rip in result} + assert "relu1_out" in tensor_names + assert "conv2_weight" in tensor_names + + def test_collect_valid_inputs(self): + """Test collecting valid node input insertion points from Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected some insertion points + assert len(result) >= 1 + # All should be NodeInputInsertionPoint + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + def test_collect_from_residual_block(self): + """Test collecting from residual block with skip connection.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected insertion points from Conv1, Add inputs, etc. + assert len(result) >= 1 + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + # Check that we have insertion points for different nodes + node_indices = {ip.node_index for ip in result} + assert len(node_indices) >= 1 # At least one node has valid inputs + + +class TestChildRegionInputInsertionPointMethods(unittest.TestCase): + """Test ChildRegionInputInsertionPoint.resolve() and collect_from_region() methods.""" + + def test_resolve_composite_region(self): + """Test resolving child region input in COMPOSITE region.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"input": [0]} + + # Create parent (COMPOSITE) with child (LEAF) containing Conv->BN->Relu + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.inputs = ["input"] + parent.add_child(child) + ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result = ip.resolve(parent, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "input" for rip in result) + + def test_resolve_leaf_returns_empty(self): + """Test that LEAF regions return empty set.""" + graph, _ = _create_simple_graph() + leaf = _create_region(nodes=[0]) + ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result = ip.resolve(leaf, graph) + assert len(result) == 0 + + def test_resolve_multiple_children(self): + """Test resolving child inputs in COMPOSITE with multiple children.""" + graph, tensors = _create_residual_graph() + # input is consumed by Conv1 (node 0) and Add (node 3) + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + + # Create parent with two child regions + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + + # First child: Conv1 (consumes "input") + child1 = _create_region(region_id=2, nodes=[0]) # Conv1 + child1.inputs = ["input"] + + # Second child: Relu1 (consumes "relu1_out") + child2 = _create_region(region_id=3, nodes=[2]) # Relu1 + child2.inputs = ["relu1_out"] + parent.add_child(child1) + parent.add_child(child2) + + # Resolve input of first child (region_index=0) - "input" tensor + ip1 = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result1 = ip1.resolve(parent, graph) + + assert len(result1) >= 1 + assert any(rip.tensor_name == "input" for rip in result1) + + # Resolve input of second child (region_index=1) - "relu1_out" tensor + ip2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) + result2 = ip2.resolve(parent, graph) + + assert len(result2) >= 1 + assert any(rip.tensor_name == "relu1_out" for rip in result2) + + def test_collect_from_composite(self): + """Test collecting from COMPOSITE region with children.""" + graph, tensors = _create_simple_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.inputs = ["input"] + parent.add_child(child) + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + # Should find the child's input + assert len(result) >= 0 # May be filtered by skip_invalid_insertion_points + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + def test_collect_from_leaf_returns_empty(self): + """Test that LEAF regions return empty list.""" + graph, _ = _create_simple_graph() + leaf = _create_region(nodes=[0]) + result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) + assert len(result) == 0 + + def test_collect_from_composite_with_multiple_children(self): + """Test collecting from COMPOSITE with multiple child regions.""" + graph, tensors = _create_residual_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = _create_region(region_id=2, nodes=[0, 1]) # Conv1, Relu1 + child1.inputs = ["input"] + child2 = _create_region(region_id=3, nodes=[2, 3]) # Conv2, Add + child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection + parent.add_child(child1) + parent.add_child(child2) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + # Should find inputs from both children + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + +class TestChildRegionOutputInsertionPointMethods(unittest.TestCase): + """Test ChildRegionOutputInsertionPoint.resolve() and collect_from_region() methods.""" + + def test_resolve_node_output(self): + """Test resolving a node output.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + region.outputs = ["pool_out"] + # Output of last node (MaxPool) + ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=2, output_index=0) + result = ip.resolve(region, graph) + assert len(result) >= 1 + assert any(rip.tensor_name == "relu_out" for rip in result) + + def test_resolve_child_region_output(self): + """Test resolving a child region output.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"relu_out": [3]} + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.outputs = ["relu_out"] + parent.add_child(child) + ip = ChildRegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + result = ip.resolve(parent, graph) + assert len(result) >= 1 + assert any(rip.tensor_name == "relu_out" for rip in result) + + def test_resolve_residual_add_output(self): + """Test resolving Add output in residual block.""" + graph, tensors = _create_residual_graph() + graph.tensor_users_map = {"add_out": [4]} + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + region.outputs = ["add_out"] + # Add is at local index 3, output 0 + ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0) + result = ip.resolve(region, graph) + assert len(result) >= 1 + assert any(rip.tensor_name == "add_out" for rip in result) + + def test_collect_node_outputs(self): + """Test collecting node output insertion points.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + region.outputs = ["pool_out"] # Only pool_out is a region output + result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the node output that matches region output + assert len(result) >= 0 # May be filtered + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) + + def test_collect_child_region_outputs(self): + """Test collecting child region output insertion points.""" + graph, tensors = _create_simple_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.outputs = ["relu_out"] + parent.add_child(child) + parent.outputs = ["relu_out"] # Child output is also parent output + result = ChildRegionOutputInsertionPoint.collect_from_region(parent, graph) + + # Should find the child region output + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) + + def test_collect_residual_block_outputs(self): + """Test collecting outputs from residual block.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + region.outputs = ["output"] # Final output + result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the output + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py new file mode 100644 index 0000000000..a27b1c98ca --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the Region class in the autotuner.""" + +import pytest + +from modelopt.onnx.quantization.autotune.common import Region, RegionType + + +@pytest.fixture +def leaf(): + return Region(region_id=1, level=0, region_type=RegionType.LEAF) + + +@pytest.fixture +def parent_with_children(): + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + parent.add_child(child1) + parent.add_child(child2) + return parent, child1, child2 + + +@pytest.mark.parametrize( + ("region_id", "level", "region_type"), + [ + (1, 0, RegionType.LEAF), + (2, 1, RegionType.COMPOSITE), + (0, 2, RegionType.ROOT), + ], +) +def test_region_creation(region_id, level, region_type): + region = Region(region_id=region_id, level=level, region_type=region_type) + assert (region.id, region.level, region.type) == (region_id, level, region_type) + + +def test_parent_child_relationship(parent_with_children): + parent, child1, child2 = parent_with_children + assert parent.get_children() == [child1, child2] + assert child1.parent == child2.parent == parent + + +def test_add_and_get_nodes(leaf): + leaf.nodes.update([0, 1, 2]) + assert set(leaf.get_nodes()) == {0, 1, 2} + + +def test_input_output_tensors(leaf): + leaf.inputs = ["in1", "in2"] + leaf.outputs = ["out1"] + assert leaf.inputs == ["in1", "in2"] + assert leaf.outputs == ["out1"] + + +def test_region_size_recursive(parent_with_children): + parent, child1, child2 = parent_with_children + child1.nodes.update([0, 1]) + child2.nodes.update([2, 3, 4]) + parent.nodes.add(5) + assert len(parent.get_region_nodes_and_descendants()) == 6 + + +def test_metadata(leaf): + leaf.metadata.update({"pattern": "Conv->Relu", "quantizable": "true"}) + assert leaf.metadata == {"pattern": "Conv->Relu", "quantizable": "true"} + + +def test_hierarchical_structure(): + root = Region(region_id=0, level=2, region_type=RegionType.ROOT) + comp1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + comp2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + leaves = [Region(region_id=i, level=0, region_type=RegionType.LEAF) for i in range(3, 6)] + root.add_child(comp1) + root.add_child(comp2) + comp1.add_child(leaves[0]) + comp1.add_child(leaves[1]) + comp2.add_child(leaves[2]) + for i, leaf in enumerate(leaves): + leaf.nodes.add(i) + assert len(root.get_children()) == 2 + assert len(comp1.get_children()) == 2 + assert len(comp2.get_children()) == 1 + assert len(root.get_region_nodes_and_descendants()) == 3 + + +def test_remove_child(): + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + parent.add_child(child) + parent.remove_child(child) + assert parent.get_children() == [] + assert child.parent is None diff --git a/tests/unit/onnx/quantization/autotune/test_region_inspect.py b/tests/unit/onnx/quantization/autotune/test_region_inspect.py new file mode 100644 index 0000000000..a932fa3c29 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_inspect.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for region_inspect module.""" + +import os +from unittest.mock import Mock, patch + +import numpy as np +import onnx +import pytest +from onnx import TensorProto, helper, numpy_helper + + +def create_simple_onnx_model(): + """Create a simple ONNX model for testing. + + Creates a model with: Input -> Conv -> Relu -> MatMul -> Output + """ + # Create input + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1000]) + + # Create weights for Conv + conv_weight = np.random.randn(64, 3, 7, 7).astype(np.float32) + conv_weight_tensor = numpy_helper.from_array(conv_weight, "conv_weight") + + # Create weights for MatMul + matmul_weight = np.random.randn(64, 1000).astype(np.float32) + matmul_weight_tensor = numpy_helper.from_array(matmul_weight, "matmul_weight") + + # Create nodes + conv_node = helper.make_node( + "Conv", + inputs=["input", "conv_weight"], + outputs=["conv_output"], + kernel_shape=[7, 7], + strides=[2, 2], + pads=[3, 3, 3, 3], + ) + + relu_node = helper.make_node( + "Relu", + inputs=["conv_output"], + outputs=["relu_output"], + ) + + flatten_node = helper.make_node( + "Flatten", + inputs=["relu_output"], + outputs=["flatten_output"], + axis=1, + ) + + matmul_node = helper.make_node( + "MatMul", + inputs=["flatten_output", "matmul_weight"], + outputs=["output"], + ) + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node, flatten_node, matmul_node], + "test_model", + [input_tensor], + [output_tensor], + [conv_weight_tensor, matmul_weight_tensor], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + model.opset_import[0].version = 13 + + return model + + +@pytest.fixture +def simple_onnx_model(): + """Fixture that provides a simple ONNX model.""" + return create_simple_onnx_model() + + +@pytest.fixture +def onnx_model_file(tmp_path, simple_onnx_model): + """Fixture that provides a path to a saved ONNX model.""" + model_path = os.path.join(tmp_path, "test_model.onnx") + onnx.save(simple_onnx_model, model_path) + return model_path + + +class TestRegionInspectImports: + """Test that the region_inspect module can be imported.""" + + def test_module_imports(self): + """Test that the module imports without errors when dependencies exist.""" + # This test will skip if the required dependencies don't exist + try: + from modelopt.onnx.quantization.autotune import region_inspect + + assert hasattr(region_inspect, "inspect_region_search") + assert hasattr(region_inspect, "main") + except ImportError as e: + pytest.skip(f"Required dependencies not available: {e}") + + +class TestRegionInspectWithMocks: + """Test region_inspect functionality with mocked dependencies.""" + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_inspect_region_search_basic( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test basic functionality of inspect_region_search with mocked dependencies.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup mocks + mock_region = Mock() + mock_region.type = Mock(value="LEAF") + mock_region.inputs = ["input1"] + mock_region.outputs = ["output1"] + mock_region.children = [] + mock_region.get_region_nodes_and_descendants.return_value = [Mock(), Mock()] + mock_region.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [mock_region] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + + mock_has_quantizable.return_value = True + + # Call the function + result = inspect_region_search( + onnx_path=onnx_model_file, max_sequence_size=10, include_all_regions=False + ) + + # Verify the function was called correctly + assert mock_combined_search.called + assert mock_search_instance.search_regions.called + assert isinstance(result, list) + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_inspect_region_search_with_custom_params( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test inspect_region_search with custom parameters.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup mocks + mock_region = Mock() + mock_region.type = Mock(value="COMPOSITE") + mock_region.inputs = ["input1"] + mock_region.outputs = ["output1"] + mock_region.children = [] + mock_region.get_region_nodes_and_descendants.return_value = [Mock()] + mock_region.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [mock_region] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + + mock_has_quantizable.return_value = True + + # Call with custom parameters + result = inspect_region_search( + onnx_path=onnx_model_file, max_sequence_size=20, include_all_regions=True + ) + + # Verify custom parameters were used + assert mock_combined_search.called + call_kwargs = mock_combined_search.call_args[1] + assert call_kwargs.get("maximum_sequence_region_size") == 20 + assert isinstance(result, list) + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_inspect_region_search_filtering( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test that regions without quantizable operations are filtered out.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup mocks - one region with quantizable ops, one without + mock_region_quantizable = Mock() + mock_region_quantizable.type = Mock(value="LEAF") + mock_region_quantizable.inputs = ["input1"] + mock_region_quantizable.outputs = ["output1"] + mock_region_quantizable.get_region_nodes_and_descendants.return_value = [Mock()] + mock_region_quantizable.get_children.return_value = [] + + mock_region_non_quantizable = Mock() + mock_region_non_quantizable.type = Mock(value="LEAF") + mock_region_non_quantizable.inputs = ["input2"] + mock_region_non_quantizable.outputs = ["output2"] + mock_region_non_quantizable.get_region_nodes_and_descendants.return_value = [Mock()] + mock_region_non_quantizable.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [ + mock_region_quantizable, + mock_region_non_quantizable, + ] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + + # First region has quantizable ops, second doesn't + mock_has_quantizable.side_effect = [True, False] + + # Call with filtering enabled + result = inspect_region_search( + onnx_path=onnx_model_file, max_sequence_size=10, include_all_regions=False + ) + + # Should only return the quantizable region + assert len(result) == 1 + + +class TestRegionInspectMain: + """Test the main CLI entry point.""" + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_success(self, mock_inspect, onnx_model_file): + """Test main function with successful execution.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock(), Mock()] + + with patch("sys.argv", ["region_inspect", "--model", onnx_model_file]): + exit_code = main() + assert exit_code == 0 + assert mock_inspect.called + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_with_verbose(self, mock_inspect, onnx_model_file): + """Test main function with verbose flag.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock()] + + with patch("sys.argv", ["region_inspect", "--model", onnx_model_file, "--verbose"]): + exit_code = main() + assert exit_code == 0 + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_with_custom_max_sequence_size(self, mock_inspect, onnx_model_file): + """Test main function with custom max_sequence_size.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock()] + + with patch( + "sys.argv", ["region_inspect", "--model", onnx_model_file, "--max-sequence-size", "20"] + ): + exit_code = main() + assert exit_code == 0 + # Verify max_sequence_size parameter was passed + call_kwargs = mock_inspect.call_args[1] + assert call_kwargs.get("max_sequence_size") == 20 + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_with_include_all_regions(self, mock_inspect, onnx_model_file): + """Test main function with include_all_regions flag.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock()] + + with patch( + "sys.argv", ["region_inspect", "--model", onnx_model_file, "--include-all-regions"] + ): + exit_code = main() + assert exit_code == 0 + # Verify include_all_regions parameter was passed + call_kwargs = mock_inspect.call_args[1] + assert call_kwargs.get("include_all_regions") is True + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_failure(self, mock_inspect, onnx_model_file): + """Test main function with execution failure.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.side_effect = Exception("Test error") + + with patch("sys.argv", ["region_inspect", "--model", onnx_model_file]): + exit_code = main() + assert exit_code == 1 + + +class TestRegionInspectModelLoading: + """Test model loading functionality.""" + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_loads_valid_onnx_model( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test that a valid ONNX model can be loaded.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup minimal mocks + mock_region = Mock() + mock_region.type = Mock(value="LEAF") + mock_region.inputs = [] + mock_region.outputs = [] + mock_region.get_region_nodes_and_descendants.return_value = [] + mock_region.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [mock_region] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + mock_has_quantizable.return_value = False + + # Should not raise an exception + result = inspect_region_search(onnx_model_file) + assert isinstance(result, list) + + def test_fails_on_nonexistent_file(self): + """Test that loading a non-existent file raises an error.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + with pytest.raises(Exception): # Could be FileNotFoundError or other + inspect_region_search("/nonexistent/path/to/model.onnx") diff --git a/tests/unit/onnx/quantization/autotune/test_region_pattern.py b/tests/unit/onnx/quantization/autotune/test_region_pattern.py new file mode 100644 index 0000000000..1d134457c6 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_pattern.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for RegionPattern functionality in the autotuner. + +Tests pattern generation, matching, and tree visualization. +""" + +import numpy as np +import onnx_graphsurgeon as gs + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + +class TestRegionPattern: + """Test RegionPattern functionality.""" + + # ========================================================================= + # Helper Methods + # ========================================================================= + + @staticmethod + def _create_simple_graph(): + """Create a simple Conv->Relu graph for testing. + + Graph structure: + input -> Conv -> Relu -> output + """ + # Create inputs and outputs + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, 3, 224, 224]) + conv_out = gs.Variable(name="conv_out", dtype=np.float32) + relu_out = gs.Variable(name="output", dtype=np.float32) + + # Create weights + conv_weight = gs.Constant( + name="conv_weight", values=np.ones((64, 3, 3, 3), dtype=np.float32) + ) + + # Create nodes + conv = gs.Node( + name="Conv_0", + op="Conv", + inputs=[inp, conv_weight], + outputs=[conv_out], + attrs={"kernel_shape": [3, 3], "strides": [1, 1], "pads": [1, 1, 1, 1]}, + ) + relu = gs.Node( + name="Relu_0", + op="Relu", + inputs=[conv_out], + outputs=[relu_out], + ) + + # Create graph + graph = gs.Graph( + nodes=[conv, relu], + inputs=[inp], + outputs=[relu_out], + opset=13, + ) + return graph + + @staticmethod + def _create_hierarchical_graph(): + """Create a hierarchical graph with composite regions. + + Graph structure: + input -> Conv -> Relu -> Add -> MatMul -> Relu -> output + ^ + | + other_input + + Region structure: + ROOT + ├── COMPOSITE (Conv->Relu->Add) + │ ├── LEAF (Conv->Relu) + │ └── LEAF (Add) + └── COMPOSITE (MatMul->Relu) + └── LEAF (MatMul->Relu) + """ + # Create inputs and intermediate tensors + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, 64, 64, 64]) + other_inp = gs.Variable(name="other_input", dtype=np.float32, shape=[1, 64, 64, 64]) + conv_out = gs.Variable(name="conv_out", dtype=np.float32) + relu1_out = gs.Variable(name="relu1_out", dtype=np.float32) + add_out = gs.Variable(name="add_out", dtype=np.float32) + matmul_out = gs.Variable(name="matmul_out", dtype=np.float32) + output = gs.Variable(name="output", dtype=np.float32) + + # Create constants + conv_weight = gs.Constant( + name="conv_weight", values=np.ones((64, 64, 1, 1), dtype=np.float32) + ) + matmul_weight = gs.Constant( + name="matmul_weight", values=np.ones((64, 64), dtype=np.float32) + ) + + # Create nodes (order matters for node indices) + conv = gs.Node( + name="Conv_0", + op="Conv", + inputs=[inp, conv_weight], + outputs=[conv_out], + attrs={"kernel_shape": [1, 1]}, + ) # Node 0 + relu1 = gs.Node(name="Relu_0", op="Relu", inputs=[conv_out], outputs=[relu1_out]) # Node 1 + add = gs.Node( + name="Add_0", op="Add", inputs=[relu1_out, other_inp], outputs=[add_out] + ) # Node 2 + matmul = gs.Node( + name="MatMul_0", op="MatMul", inputs=[add_out, matmul_weight], outputs=[matmul_out] + ) # Node 3 + relu2 = gs.Node(name="Relu_1", op="Relu", inputs=[matmul_out], outputs=[output]) # Node 4 + + # Create graph + graph = gs.Graph( + nodes=[conv, relu1, add, matmul, relu2], + inputs=[inp, other_inp], + outputs=[output], + opset=13, + ) + return graph + + @staticmethod + def _create_test_region( + region_id: int, level: int, region_type: RegionType, node_indices: list[int] | None = None + ) -> Region: + """Create a test region.""" + region = Region(region_id, level, region_type) + if node_indices: + region.nodes.update(node_indices) + return region + + # ========================================================================= + # Test Cases + # ========================================================================= + + def test_pattern_creation(self): + """Test basic RegionPattern creation.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + + assert pattern.signature == "Conv->Relu" + assert pattern.size == 2 + assert not pattern.is_empty + assert pattern.is_leaf + + def test_pattern_equality_and_hash(self): + """Test RegionPattern equality and hashing based on signature.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) # Different size + pattern3 = RegionPattern(signature="Gemm->Relu", size=2) + + # Same signature = equal (size doesn't affect equality) + assert pattern1 == pattern2 + # Different signature = not equal + assert pattern1 != pattern3 + + # Same signature = same hash + assert hash(pattern1) == hash(pattern2) + + # Can be used as dict keys + pattern_dict = {pattern1: "scheme1"} + assert pattern_dict[pattern2] == "scheme1" # pattern2 finds pattern1's entry + + def test_pattern_from_simple_region(self): + """Test pattern computation from a simple region.""" + graph = self._create_simple_graph() + + # Create a leaf region with Conv and Relu nodes + region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + + pattern = RegionPattern.from_region(region, graph) + + # Should capture both operations + assert "Conv" in pattern.signature + assert "Relu" in pattern.signature + assert pattern.size == 2 + assert pattern.is_leaf + + def test_pattern_from_composite_region(self): + """Test pattern computation from a composite region with children.""" + graph = self._create_hierarchical_graph() + + # Create leaf regions + leaf1 = self._create_test_region( + region_id=1, + level=0, + region_type=RegionType.LEAF, + node_indices=[0, 1], # Conv, Relu + ) + leaf2 = self._create_test_region( + region_id=2, + level=0, + region_type=RegionType.LEAF, + node_indices=[2], # Add + ) + + # Create composite region + composite = self._create_test_region( + region_id=3, level=1, region_type=RegionType.COMPOSITE, node_indices=[] + ) + composite.add_child(leaf1) + composite.add_child(leaf2) + + pattern = RegionPattern.from_region(composite, graph) + + assert pattern.is_composite + assert "COMPOSITE" in pattern.signature + assert pattern.size == 3 # Total nodes in region hierarchy + + def test_pattern_get_hash(self): + """Test cryptographic hash generation.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + hash_val = pattern.get_hash() + + # Hash should be 32 hex characters (128-bit truncated SHA-256) + assert len(hash_val) == 32 + assert all(c in "0123456789abcdef" for c in hash_val) + + # Same signature = same hash + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + assert pattern.get_hash() == pattern2.get_hash() + + def test_pattern_get_short_signature(self): + """Test signature truncation.""" + long_sig = "COMPOSITE(" + "Conv->Relu->" * 20 + "Output)" + pattern = RegionPattern(signature=long_sig, size=20) + + short_sig = pattern.get_short_signature(max_length=50) + assert len(short_sig) == 50 + assert short_sig.endswith("...") + + # Short signature stays unchanged + short_pattern = RegionPattern(signature="Conv", size=1) + assert short_pattern.get_short_signature(max_length=50) == "Conv" + + def test_print_tree(self): + """Test format_tree to visualize region structure. + + This test demonstrates how to use format_tree to display + the hierarchical structure of regions and their patterns. + """ + graph = self._create_hierarchical_graph() + + # Build a hierarchical region structure: + # ROOT (level=2) + # ├── COMPOSITE (level=1) [Conv->Relu + Add] + # │ ├── LEAF (level=0) [Conv, Relu - nodes 0,1] + # │ └── LEAF (level=0) [Add - node 2] + # └── LEAF (level=0) [MatMul, Relu - nodes 3,4] + + # Create leaf regions + leaf_conv_relu = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + leaf_add = self._create_test_region( + region_id=2, level=0, region_type=RegionType.LEAF, node_indices=[2] + ) + leaf_matmul_relu = self._create_test_region( + region_id=3, level=0, region_type=RegionType.LEAF, node_indices=[3, 4] + ) + + # Create composite region containing conv_relu and add + composite = self._create_test_region( + region_id=4, level=1, region_type=RegionType.COMPOSITE, node_indices=[] + ) + composite.add_child(leaf_conv_relu) + composite.add_child(leaf_add) + + # Create root region containing everything + root = self._create_test_region( + region_id=5, level=2, region_type=RegionType.ROOT, node_indices=[] + ) + root.add_child(composite) + root.add_child(leaf_matmul_relu) + + # Generate pattern for root and print tree + root_pattern = RegionPattern.from_region(root, graph) + tree_output = root_pattern.format_tree(root, graph) + + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(tree_output) + print("=" * 60) + + # Verify tree output contains expected elements + assert "Region 5" in tree_output # Root + assert "Region 4" in tree_output # Composite + assert "Region 1" in tree_output # Leaf conv_relu + assert "Region 2" in tree_output # Leaf add + assert "Region 3" in tree_output # Leaf matmul_relu + + # Verify indentation shows hierarchy + lines = tree_output.strip().split("\n") + assert len(lines) >= 3 # At least root + children + + # Root should have no indentation + assert lines[0].startswith("Region 5") + + # Children should be indented + indented_lines = [line for line in lines if line.startswith(" ")] + assert len(indented_lines) > 0 + + def test_pattern_matches(self): + """Test pattern matching against both patterns and regions.""" + # Test pattern-to-pattern matching + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + pattern3 = RegionPattern(signature="Gemm->Relu", size=2) + + assert pattern1.matches(pattern2) # Same signature + assert not pattern1.matches(pattern3) # Different signature + + # Test pattern-to-region matching + graph = self._create_simple_graph() + + # Create region + region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + + # Create pattern from region + pattern = RegionPattern.from_region(region, graph) + + # Match should return node IDs + node_ids = pattern.matches(region, graph) + assert node_ids is not None + assert set(node_ids) == {0, 1} + + def test_empty_region_pattern(self): + """Test pattern for empty region.""" + graph = self._create_simple_graph() + + # Create empty region + empty_region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[] + ) + + pattern = RegionPattern.from_region(empty_region, graph) + + assert pattern.is_empty + assert pattern.signature == "EMPTY" + assert pattern.size == 0 + + def test_symmetric_operation_signature(self): + """Test that symmetric operations (Add, Mul) have consistent signatures.""" + # Create two graphs with Add inputs in different order + inp1 = gs.Variable(name="input1", dtype=np.float32, shape=[1, 64]) + inp2 = gs.Variable(name="input2", dtype=np.float32, shape=[1, 64]) + out = gs.Variable(name="output", dtype=np.float32) + + # Graph 1: Add(inp1, inp2) + add1 = gs.Node(name="Add_0", op="Add", inputs=[inp1, inp2], outputs=[out]) + graph1 = gs.Graph(nodes=[add1], inputs=[inp1, inp2], outputs=[out], opset=13) + + # Graph 2: Add(inp2, inp1) - reversed inputs + add2 = gs.Node(name="Add_0", op="Add", inputs=[inp2, inp1], outputs=[out]) + graph2 = gs.Graph(nodes=[add2], inputs=[inp1, inp2], outputs=[out], opset=13) + + # Create regions + region1 = self._create_test_region(1, 0, RegionType.LEAF, [0]) + region2 = self._create_test_region(1, 0, RegionType.LEAF, [0]) + + pattern1 = RegionPattern.from_region(region1, graph1) + pattern2 = RegionPattern.from_region(region2, graph2) + + # Patterns should be equal regardless of input order + assert pattern1 == pattern2 + + def test_pattern_repr_and_str(self): + """Test string representations.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + + # str() shows just signature + assert str(pattern) == "Conv->Relu" + + # repr() shows full info + assert "RegionPattern" in repr(pattern) + assert "Conv->Relu" in repr(pattern) + assert "size=2" in repr(pattern) diff --git a/tests/unit/onnx/quantization/autotune/test_region_search.py b/tests/unit/onnx/quantization/autotune/test_region_search.py new file mode 100644 index 0000000000..e2fb179fd3 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_search.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for region search algorithms. + +Tests CombinedRegionSearch, RegionPartitioner, and TopDownRegionBuilder. +Note: Comprehensive integration tests with real ONNX graphs should be in separate integration test files. +""" + +import io + +import onnx +import onnx_graphsurgeon as gs +import pytest +from onnx import helper + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_search import ( + CombinedRegionSearch, + RegionPartitioner, + TopDownRegionBuilder, +) + + +@pytest.fixture +def simple_linear_graph(): + """ + Create a simple linear graph: Input -> Conv -> Relu -> Output. + + This is the simplest possible graph for testing region discovery. + """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_linear", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + + # Convert to GraphSurgeon + return gs.import_onnx(model) + + +@pytest.fixture +def divergent_graph(): + """ + Create a graph with divergence: Input -> Conv -> [Relu1, Relu2] -> Add -> Output. + + Tests divergence/convergence pattern detection. + """ + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + relu1_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu1_out"], name="relu1") + relu2_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu2_out"], name="relu2") + add_node = helper.make_node( + "Add", inputs=["relu1_out", "relu2_out"], outputs=["output"], name="add" + ) + + graph = helper.make_graph( + [conv_node, relu1_node, relu2_node, add_node], + "divergent", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + model = helper.make_model(graph, producer_name="test") + return gs.import_onnx(model) + + +class TestRegionPartitioner: + """Test RegionPartitioner basic functionality.""" + + def test_partition_linear_graph(self, simple_linear_graph): + """Test partitioning a simple linear graph.""" + partitioner = RegionPartitioner(simple_linear_graph) + + regions = partitioner.partition_graph() + + # Should create at least one region + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(simple_linear_graph.nodes) + + def test_partition_divergent_graph(self, divergent_graph): + """Test partitioning a divergent graph.""" + partitioner = RegionPartitioner(divergent_graph) + + regions = partitioner.partition_graph() + + # Should create regions covering all nodes + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(divergent_graph.nodes) + + +class TestTopDownRegionBuilder: + """Test TopDownRegionBuilder basic functionality.""" + + def test_build_composite_region(self, simple_linear_graph): + """Test building a composite region.""" + # First partition to get initial regions + partitioner = RegionPartitioner(simple_linear_graph) + initial_regions = partitioner.partition_graph() + + if len(initial_regions) > 0: + # Use first region as root for top-down building + root_region = initial_regions[0] + + builder = TopDownRegionBuilder(simple_linear_graph, root_region, next_region_id=100) + + # Build composite region (may return LEAF or COMPOSITE depending on structure) + composite = builder.build_composite_region() + + assert composite is not None + # Region type depends on whether refinement created internal structure + # For simple linear graphs, may stay as LEAF + assert composite.type in [RegionType.LEAF, RegionType.COMPOSITE] + else: + pytest.skip("No initial regions to refine") + + +class TestCombinedRegionSearch: + """Test CombinedRegionSearch two-phase algorithm.""" + + def test_search_linear_graph(self, simple_linear_graph): + """Test searching regions in a simple linear graph.""" + search = CombinedRegionSearch(simple_linear_graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(simple_linear_graph.nodes) + + # Each region should have valid inputs/outputs + for region in regions: + assert region.inputs is not None + assert region.outputs is not None + + def test_search_divergent_graph(self, divergent_graph): + """Test searching regions in a divergent graph.""" + search = CombinedRegionSearch(divergent_graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(divergent_graph.nodes) + + def test_region_hierarchy(self, simple_linear_graph): + """Test that regions have proper hierarchical structure.""" + search = CombinedRegionSearch(simple_linear_graph) + + regions = search.search_regions() + + # Check that regions have children (hierarchical structure) + for region in regions: + if region.type == RegionType.COMPOSITE: + assert len(region.get_children()) > 0 + + # Verify parent-child relationships + for child in region.get_children(): + assert child.parent == region + + def test_parameters(self, simple_linear_graph): + """Test CombinedRegionSearch with custom parameters.""" + # Test with different parameter values + search = CombinedRegionSearch( + simple_linear_graph, + maximum_sequence_region_size=5, + minimum_topdown_search_size=5, + ) + + regions = search.search_regions() + + assert len(regions) > 0 + + +class TestPrintTree: + """Test print_tree functionality.""" + + def test_print_tree_output_content(self, simple_linear_graph): + """Test that print_tree output contains region, node, and I/O information.""" + search = CombinedRegionSearch(simple_linear_graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + result = output.getvalue() + + # Region information + assert "Region" in result + assert "Level" in result + assert "Type:" in result + + # Node counts + assert "Direct nodes:" in result + assert "Total nodes:" in result + assert "Children:" in result + + # I/O information + assert "Inputs:" in result + assert "Outputs:" in result + + def test_print_tree_divergent_graph(self, divergent_graph): + """Test print_tree on a divergent graph with more complex structure.""" + search = CombinedRegionSearch(divergent_graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + + result = output.getvalue() + + # Should produce valid output + assert "Region" in result + assert len(result) > 0 + + def test_print_tree_max_nodes_to_show(self, simple_linear_graph): + """Test print_tree with custom max_nodes_to_show parameter.""" + search = CombinedRegionSearch(simple_linear_graph) + search.search_regions() + + # Test with different max_nodes_to_show values + output1 = io.StringIO() + search.print_tree(max_items=1, file=output1) + + output2 = io.StringIO() + search.print_tree(max_items=10, file=output2) + + # Both should produce output + assert len(output1.getvalue()) > 0 + assert len(output2.getvalue()) > 0 + + def test_print_tree_specific_region(self, simple_linear_graph): + """Test print_tree with a specific region instead of root.""" + search = CombinedRegionSearch(simple_linear_graph) + regions = search.search_regions() + + if len(regions) > 0: + # Print a specific region + output = io.StringIO() + search.print_tree(region=regions[0], file=output) + + result = output.getvalue() + assert "Region" in result + assert f"Region {regions[0].id}" in result + + def test_print_tree_partitioner(self, simple_linear_graph): + """Test print_tree on RegionPartitioner.""" + partitioner = RegionPartitioner(simple_linear_graph) + partitioner.partition_graph() + + output = io.StringIO() + partitioner.print_tree(file=output) + + result = output.getvalue() + assert "Region" in result + assert len(result) > 0 + + def test_print_tree_top_down_builder(self, simple_linear_graph): + """Test print_tree on TopDownRegionBuilder.""" + # Create a root region with all nodes + root = Region(region_id=0, level=0, region_type=RegionType.LEAF) + root.nodes.update(range(len(simple_linear_graph.nodes))) + + builder = TopDownRegionBuilder(simple_linear_graph, root) + # Compute region I/O boundaries before building + builder.compute_region_boundaries(root) + builder.build_composite_region() + + output = io.StringIO() + builder.print_tree(file=output) + + result = output.getvalue() + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(result) + print("=" * 60) + + assert "Region" in result + assert len(result) > 0 diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py index 2acc4046a1..42aa317119 100644 --- a/tests/unit/onnx/test_qdq_utils.py +++ b/tests/unit/onnx/test_qdq_utils.py @@ -630,3 +630,394 @@ def test_fp4qdq_conversion(self, with_transpose): # Verify Cast nodes are added for input type conversion cast_nodes = [node for node in converted_model.graph.node if node.op_type == "Cast"] assert len(cast_nodes) >= 1 # At least one cast node should be added + + +def create_test_model_with_int4_dq_matmul(): + """Create a simple test model with INT4 DequantizeLinear -> MatMul pattern. + + Returns the model and original weight/scale arrays for verification. + """ + from modelopt.onnx.quantization.quant_utils import pack_float32_to_4bit_cpp_based + + # Create INT4 quantized weight tensor (K=32, N=16) + # Using int8 storage for INT4 values in range [-8, 7] + weight_data = np.random.randint(-8, 8, size=(32, 16), dtype=np.int8) + + # Pack INT4 data (2 values per byte) for ORT compatibility + packed_weight = pack_float32_to_4bit_cpp_based(weight_data, signed=True).astype(np.int8) + weight_tensor = helper.make_tensor( + "weight", + TensorProto.INT4, + dims=weight_data.shape, + vals=packed_weight.tobytes(), + raw=True, + ) + + # Create scale tensor for block quantization (block_size=32, so 1 scale per column) + scale_data = np.random.uniform(0.1, 1.0, size=(1, 16)).astype(np.float16) + scale_tensor = numpy_helper.from_array(scale_data, "scale") + + # Create input tensor for MatMul (batch=4, K=32) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [4, 32]) + + # Create DequantizeLinear node with INT4 blocked quantization + dq_node = helper.make_node( + "DequantizeLinear", + inputs=["weight", "scale"], + outputs=["dq_output"], + name="weight_dq", + axis=0, + block_size=32, + ) + + # Create MatMul node: input (4, 32) @ weight (32, 16) -> output (4, 16) + matmul_node = helper.make_node( + "MatMul", + inputs=["input", "dq_output"], + outputs=["output"], + name="matmul", + ) + + graph = helper.make_graph( + nodes=[dq_node, matmul_node], + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [4, 16])], + initializer=[weight_tensor, scale_tensor], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 # ORT only supports IR version up to 10 + return model, weight_data, scale_data + + +class TestColumnMajorTransformation: + """Test suite for column-major storage transformation functions.""" + + def test_column_major_transformation_graph_structure(self): + """Test that column-major transformation produces correct graph structure. + + Verifies: DQ(W) -> MatMul becomes DQ(W^T) -> Transpose -> MatMul + """ + import onnx_graphsurgeon as gs + + from modelopt.onnx.quantization.qdq_utils import ( + apply_column_major_transformation, + insert_transpose_nodes_for_column_major, + ) + + model, original_weight, original_scale = create_test_model_with_int4_dq_matmul() + + # Get weights and scales as dicts (simulating what int4.py does) + weights_dict = {"weight": original_weight.copy()} + scales_dict = {"scale": original_scale.copy()} + + # Apply column-major transformation (transposes in-place) + apply_column_major_transformation(weights_dict, scales_dict) + + # Verify weights and scales are transposed + assert weights_dict["weight"].shape == (16, 32), ( + f"Expected transposed weight shape (16, 32), got {weights_dict['weight'].shape}" + ) + assert scales_dict["scale"].shape == (16, 1), ( + f"Expected transposed scale shape (16, 1), got {scales_dict['scale'].shape}" + ) + + # Verify the transposed values match + assert np.array_equal(weights_dict["weight"], original_weight.T) + assert np.array_equal(scales_dict["scale"], original_scale.T) + + # Now test insert_transpose_nodes_for_column_major on a graph + # Create a fresh model and apply the full transformation + model2, _, _ = create_test_model_with_int4_dq_matmul() + graph2 = gs.import_onnx(model2) + + # Add transpose nodes for column-major + insert_transpose_nodes_for_column_major(graph2) + + # Export and verify structure + transformed_model = gs.export_onnx(graph2) + + # Check that Transpose node was added + node_types = [node.op_type for node in transformed_model.graph.node] + assert "Transpose" in node_types, "Transpose node should be added after DQ" + assert "DequantizeLinear" in node_types + assert "MatMul" in node_types + + # Verify the order: DQ -> Transpose -> MatMul + dq_node = next(n for n in transformed_model.graph.node if n.op_type == "DequantizeLinear") + transpose_node = next(n for n in transformed_model.graph.node if n.op_type == "Transpose") + matmul_node = next(n for n in transformed_model.graph.node if n.op_type == "MatMul") + + # DQ output should be Transpose input + assert dq_node.output[0] == transpose_node.input[0], "DQ output should feed into Transpose" + # Transpose output should be MatMul weight input + assert transpose_node.output[0] == matmul_node.input[1], ( + "Transpose output should feed into MatMul" + ) + + # Verify transpose permutation is [1, 0] + perm_attr = next((a for a in transpose_node.attribute if a.name == "perm"), None) + assert perm_attr is not None, "Transpose should have perm attribute" + assert list(perm_attr.ints) == [1, 0], "Transpose perm should be [1, 0]" + + def test_column_major_transformation_output_equivalence(self): + """Test that column-major transformed graph produces equivalent output. + + Creates two graphs: + 1. Original: DQ(W) -> MatMul + 2. Transformed: DQ(W^T) -> Transpose -> MatMul + + Verifies both produce the same output for the same input. + """ + import onnxruntime as ort + + from modelopt.onnx.quantization.quant_utils import pack_float32_to_4bit_cpp_based + + # Create original model + original_model, original_weight, original_scale = create_test_model_with_int4_dq_matmul() + + # Create input data + input_data = np.random.randn(4, 32).astype(np.float16) + + # Run original model + original_session = ort.InferenceSession(original_model.SerializeToString()) + original_output = original_session.run(None, {"input": input_data})[0] + + # Create transformed model + # We need to manually create a model with transposed weights + transposed_weight = original_weight.T.copy() # Shape: (16, 32) + transposed_scale = original_scale.T.copy() # Shape: (16, 1) + + # Pack INT4 data (2 values per byte) for ORT compatibility + packed_transposed_weight = pack_float32_to_4bit_cpp_based( + transposed_weight, signed=True + ).astype(np.int8) + weight_tensor = helper.make_tensor( + "weight", + TensorProto.INT4, + dims=transposed_weight.shape, + vals=packed_transposed_weight.tobytes(), + raw=True, + ) + scale_tensor = numpy_helper.from_array(transposed_scale, "scale") + + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [4, 32]) + + # DQ node with axis=1 for column-major (transposed weight) + dq_node = helper.make_node( + "DequantizeLinear", + inputs=["weight", "scale"], + outputs=["dq_output"], + name="weight_dq", + axis=1, + block_size=32, + ) + + # Transpose node to convert back: (16, 32) -> (32, 16) + transpose_node = helper.make_node( + "Transpose", + inputs=["dq_output"], + outputs=["transpose_output"], + name="transpose_back", + perm=[1, 0], + ) + + # MatMul: input (4, 32) @ transposed_back (32, 16) -> output (4, 16) + matmul_node = helper.make_node( + "MatMul", + inputs=["input", "transpose_output"], + outputs=["output"], + name="matmul", + ) + + transformed_graph = helper.make_graph( + nodes=[dq_node, transpose_node, matmul_node], + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [4, 16])], + initializer=[weight_tensor, scale_tensor], + ) + + transformed_model = helper.make_model( + transformed_graph, opset_imports=[helper.make_opsetid("", 21)] + ) + transformed_model.ir_version = 10 # ORT only supports IR version up to 10 + + # Run transformed model + transformed_session = ort.InferenceSession(transformed_model.SerializeToString()) + transformed_output = transformed_session.run(None, {"input": input_data})[0] + + # Print output values for visibility + print(f"Original model output shape: {original_output.shape}") + print(f"Transformed model output shape: {transformed_output.shape}") + print(f"Original output (first 5): {original_output.flatten()[:5]}") + print(f"Transformed output (first 5): {transformed_output.flatten()[:5]}") + + # Verify outputs are equivalent (allowing small numerical tolerance) + assert original_output.shape == transformed_output.shape, ( + f"Output shapes should match: {original_output.shape} vs {transformed_output.shape}" + ) + np.testing.assert_allclose( + original_output, + transformed_output, + rtol=1e-3, + atol=1e-3, + err_msg="Column-major transformed model should produce equivalent output", + ) + + def test_column_major_gemm_trans_b_flip(self): + """Test that Gemm with transB=1 gets flipped to transB=0 for column-major. + + When weights are already transposed (column-major), Gemm nodes with transB=1 + should have transB flipped to 0 instead of inserting a Transpose node. + Also verifies output equivalence between original and transformed models. + """ + import onnx_graphsurgeon as gs + import onnxruntime as ort + + from modelopt.onnx.quantization.qdq_utils import ( + apply_column_major_transformation, + insert_transpose_nodes_for_column_major, + ) + from modelopt.onnx.quantization.quant_utils import pack_float32_to_4bit_cpp_based + + # Original model: weight (N=16, K=32) with Gemm transB=1 + # Gemm computes: A @ B^T = (4, 32) @ (16, 32)^T = (4, 16) + weight_data = np.random.randint(-8, 8, size=(16, 32), dtype=np.int8) # Shape (N, K) + scale_data = np.random.uniform(0.1, 1.0, size=(16, 1)).astype(np.float16) + + # Pack INT4 data for original model + packed_weight = pack_float32_to_4bit_cpp_based(weight_data, signed=True).astype(np.int8) + weight_tensor = helper.make_tensor( + "weight", + TensorProto.INT4, + dims=weight_data.shape, + vals=packed_weight.tobytes(), + raw=True, + ) + scale_tensor = numpy_helper.from_array(scale_data, "scale") + + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [4, 32]) + + dq_node = helper.make_node( + "DequantizeLinear", + inputs=["weight", "scale"], + outputs=["dq_output"], + name="weight_dq", + axis=1, + block_size=32, + ) + + gemm_node = helper.make_node( + "Gemm", + inputs=["input", "dq_output"], + outputs=["output"], + name="gemm", + transB=1, + ) + + graph = helper.make_graph( + nodes=[dq_node, gemm_node], + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [4, 16])], + initializer=[weight_tensor, scale_tensor], + ) + + original_model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + original_model.ir_version = 10 + + # Run original model + input_data = np.random.randn(4, 32).astype(np.float16) + original_session = ort.InferenceSession(original_model.SerializeToString()) + original_output = original_session.run(None, {"input": input_data})[0] + + # Apply column-major transformation using the actual functions + weights_dict = {"weight": weight_data.copy()} + scales_dict = {"scale": scale_data.copy()} + apply_column_major_transformation(weights_dict, scales_dict) + + # Build transformed model with transposed weights/scales + transposed_weight = weights_dict["weight"] # Now (32, 16) + transposed_scale = scales_dict["scale"] # Now (1, 16) + + packed_transposed = pack_float32_to_4bit_cpp_based(transposed_weight, signed=True).astype( + np.int8 + ) + transposed_weight_tensor = helper.make_tensor( + "weight", + TensorProto.INT4, + dims=transposed_weight.shape, + vals=packed_transposed.tobytes(), + raw=True, + ) + transposed_scale_tensor = numpy_helper.from_array(transposed_scale, "scale") + + # Build model with transposed weights but keep transB=1 initially + dq_node_col = helper.make_node( + "DequantizeLinear", + inputs=["weight", "scale"], + outputs=["dq_output"], + name="weight_dq", + axis=0, + block_size=32, + ) + + gemm_node_col = helper.make_node( + "Gemm", + inputs=["input", "dq_output"], + outputs=["output"], + name="gemm", + transB=1, # Still transB=1, will be flipped by insert_transpose_nodes_for_column_major + ) + + col_graph = helper.make_graph( + nodes=[dq_node_col, gemm_node_col], + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [4, 16])], + initializer=[transposed_weight_tensor, transposed_scale_tensor], + ) + + col_model = helper.make_model(col_graph, opset_imports=[helper.make_opsetid("", 21)]) + col_model.ir_version = 10 + + # Apply insert_transpose_nodes_for_column_major to flip transB + gs_graph = gs.import_onnx(col_model) + insert_transpose_nodes_for_column_major(gs_graph) + transformed_model = gs.export_onnx(gs_graph) + transformed_model.ir_version = 10 # ORT only supports IR version up to 10 + + # Verify transB was flipped to 0 + gemm_nodes = [n for n in transformed_model.graph.node if n.op_type == "Gemm"] + assert len(gemm_nodes) == 1 + trans_b_attr = next((a for a in gemm_nodes[0].attribute if a.name == "transB"), None) + trans_b_value = trans_b_attr.i if trans_b_attr else 0 + assert trans_b_value == 0, f"transB should be 0, got {trans_b_value}" + + # Verify no Transpose node was added + transpose_nodes = [n for n in transformed_model.graph.node if n.op_type == "Transpose"] + assert len(transpose_nodes) == 0, ( + f"No Transpose should be added, found {len(transpose_nodes)}" + ) + + # Run transformed model and verify output equivalence + transformed_session = ort.InferenceSession(transformed_model.SerializeToString()) + transformed_output = transformed_session.run(None, {"input": input_data})[0] + + print(f"Original model output shape: {original_output.shape}") + print(f"Transformed model output shape: {transformed_output.shape}") + print(f"Original output (first 5): {original_output.flatten()[:5]}") + print(f"Transformed output (first 5): {transformed_output.flatten()[:5]}") + + np.testing.assert_allclose( + original_output, + transformed_output, + rtol=1e-3, + atol=1e-3, + err_msg="Gemm transB flip should produce equivalent output", + ) + + print(f"transB flipped: 1 -> {trans_b_value}") + print(f"Transpose nodes: {len(transpose_nodes)}") diff --git a/tests/unit/onnx/test_quantize_api.py b/tests/unit/onnx/test_quantize_api.py new file mode 100644 index 0000000000..464fb1a88b --- /dev/null +++ b/tests/unit/onnx/test_quantize_api.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ONNX quantization opset handling.""" + +import os + +import onnx +import onnxruntime +import pytest +import torch +from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx +from packaging import version + +import modelopt.onnx.quantization as moq +from modelopt.onnx.utils import get_opset_version + +# Mapping of quantization mode to minimum required opset +MIN_OPSET = { + "int8": 19, + "fp8": 19, + "int4": 21, +} + +# onnxruntime version that supports opset 22+ +ORT_VERSION_FOR_OPSET_22 = version.parse("1.23.0") + + +# Test scenarios: (scenario_name, export_opset_offset, request_opset_offset, expected_opset_offset) +# Offsets are relative to MIN_OPSET[quant_mode]. +OPSET_SCENARIOS = [ + # Requesting opset below minimum should upgrade to minimum + ("below_min_upgrades", -1, -1, 0), + # Requesting opset below original model's opset (but above minimum) should preserve original + ("below_original_preserves", 1, 0, 1), + # Requesting opset above minimum should be respected + ("above_min_respected", 0, 1, 1), +] + + +@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"]) +@pytest.mark.parametrize( + ("scenario_name", "export_opset_offset", "request_opset_offset", "expected_opset_offset"), + OPSET_SCENARIOS, + ids=[s[0] for s in OPSET_SCENARIOS], +) +def test_quantize_opset_handling( + tmp_path, + quant_mode, + scenario_name, + export_opset_offset, + request_opset_offset, + expected_opset_offset, +): + """Test opset handling in quantization API. + + Scenarios: + - below_min_upgrades: Requesting opset below minimum upgrades to minimum. + - below_original_preserves: Requesting opset below original model's opset preserves original. + - above_min_respected: Requesting opset at or above minimum is respected. + """ + min_opset = MIN_OPSET[quant_mode] + + # Calculate actual opset values from offsets + export_opset = min_opset + export_opset_offset + request_opset = min_opset + request_opset_offset + expected_opset = min_opset + expected_opset_offset + + # Skip if required opset exceeds onnxruntime support + max_opset = max(export_opset, request_opset, expected_opset) + if max_opset >= 22: + ort_version = version.parse(onnxruntime.__version__) + if ort_version < ORT_VERSION_FOR_OPSET_22: + pytest.skip( + f"Opset {max_opset} requires onnxruntime >= {ORT_VERSION_FOR_OPSET_22}, have {ort_version}" + ) + + # Setup: create and export model + model_torch = SimpleMLP() + input_tensor = torch.randn(2, 16, 16) + onnx_path = os.path.join(tmp_path, "model.onnx") + export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path, opset=export_opset) + + # Run quantization + moq.quantize(onnx_path, quantize_mode=quant_mode, opset=request_opset) + + # Verify output opset + output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx") + output_model = onnx.load(output_onnx_path) + output_opset = get_opset_version(output_model) + + assert output_opset == expected_opset, ( + f"[{scenario_name}] Expected opset {expected_opset} for {quant_mode}, got {output_opset}" + ) diff --git a/tests/unit/torch/distill/test_distill.py b/tests/unit/torch/distill/test_distill.py index 10241f076d..69dec86b7f 100644 --- a/tests/unit/torch/distill/test_distill.py +++ b/tests/unit/torch/distill/test_distill.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input -from torch.nn.modules.loss import _Loss as Loss from torchvision.models import alexnet import modelopt.torch.distill as mtd @@ -37,7 +36,7 @@ def tiny_mobilenet(): def tiny_alexnet(): - return alexnet(num_classes=10) # Same class as tiny_mobilenet + return alexnet(num_classes=10) # same num classes as tiny_mobilenet @pytest.fixture @@ -168,13 +167,6 @@ def test_distillation_export(distillation_model, tmp_path): assert not hasattr(model_exported, "_teacher_model") assert hasattr(model_exported, mto.ModeloptStateManager._state_key) - # Test if kd_loss config has been cleaned up - manager = mto.ModeloptStateManager(model_exported) - cfg = manager._state[-2][1]["config"] - assert cfg["teacher_model"] == nn.Module - assert isinstance(next(iter(cfg["criterion"].values())), Loss) - assert cfg["loss_balancer"] is None - mto.save(model_exported, tmp_path / "ckpt.pt") new_student = tiny_mobilenet() new_student_restored = mto.restore(new_student, tmp_path / "ckpt.pt") diff --git a/tests/unit/torch/distill/test_layerwise.py b/tests/unit/torch/distill/test_layerwise.py new file mode 100644 index 0000000000..aea1dd6350 --- /dev/null +++ b/tests/unit/torch/distill/test_layerwise.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import pytest +import torch +from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input + +import modelopt.torch.distill as mtd +import modelopt.torch.opt as mto + + +def get_input_tensor(): + """Dummy input tensor.""" + return torch.rand(2, 3, 112, 112) + + +def tiny_mobilenet(): + return get_tiny_mobilenet_and_input()[0] + + +@pytest.fixture +def layerwise_distillation_model(): + student = tiny_mobilenet().train() + config = { + "teacher_model": tiny_mobilenet(), + "criterion": { + ("features.2", "features.2"): torch.nn.MSELoss(), + }, + "loss_balancer": mtd.StaticLossBalancer(), + } + layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)]) + + return layerwise_model + + +def test_layerwise_hooks_registration(layerwise_distillation_model): + """Test that layerwise-specific hooks are registered correctly.""" + # Check that student layers have _teacher_layer attribute + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: + assert hasattr(student_layer, "_teacher_layer") + assert student_layer._teacher_layer[0] is teacher_layer + assert hasattr(student_layer, "_intermediate_output") + + # Check that teacher layers have both input and output capture attributes + assert hasattr(teacher_layer, "_intermediate_input") + assert hasattr(teacher_layer, "_intermediate_output") + + +def test_layerwise_forward_pass(layerwise_distillation_model): + """Test that forward pass works and captures both teacher inputs and outputs.""" + layerwise_distillation_model.train() + input_tensor = get_input_tensor() + + layerwise_distillation_model(input_tensor) + + # Check that teacher intermediate inputs and outputs are captured + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: + assert teacher_layer._intermediate_input is None + assert teacher_layer._intermediate_output is not None + assert student_layer._intermediate_output is not None + + +def test_layerwise_input_injection(layerwise_distillation_model): + """Test that teacher inputs are injected into student layers during layerwise distillation.""" + layerwise_distillation_model.train() + input_tensor = get_input_tensor() + + # Perform forward pass + layerwise_distillation_model(input_tensor) + + # Verify that teacher inputs were captured (they should be reset after injection) + # After forward, teacher inputs should have been consumed by student layers + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: + # After full forward pass, teacher_layer._intermediate_input should be None + # because it gets consumed by the student layerwise hook + assert teacher_layer._intermediate_input is None + + +def test_layerwise_loss_computation(layerwise_distillation_model): + """Test that loss computation works with layerwise distillation.""" + layerwise_distillation_model.train() + input_tensor = get_input_tensor() + + output = layerwise_distillation_model(input_tensor) + loss = layerwise_distillation_model.compute_kd_loss(student_loss=output.mean()) + + assert isinstance(loss, torch.Tensor) + assert loss.numel() == 1 + assert loss.requires_grad + + +def test_layerwise_only_student_forward(layerwise_distillation_model): + """Test that only_student_forward context manager works with layerwise distillation.""" + layerwise_distillation_model.train() + input_tensor = get_input_tensor() + + # When using only_student_forward, teacher inputs should not be captured + with warnings.catch_warnings(record=True) as w: + with layerwise_distillation_model.only_student_forward(): + layerwise_distillation_model(input_tensor) + + # Should get warning about missing teacher input + warning_messages = [str(warning.message) for warning in w] + assert any("has no intermediate input stored" in msg for msg in warning_messages) + + # Verify teacher didn't run + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: + assert teacher_layer._intermediate_input is None + assert teacher_layer._intermediate_output is None + assert student_layer._intermediate_output is not None + + +def test_layerwise_only_teacher_forward(layerwise_distillation_model): + """Test that only_teacher_forward context manager works with layerwise distillation.""" + layerwise_distillation_model.train() + input_tensor = get_input_tensor() + + with layerwise_distillation_model.only_teacher_forward(): + layerwise_distillation_model(input_tensor) + + # Verify teacher ran and student didn't + for student_layer, teacher_layer in layerwise_distillation_model._layers_to_loss: + assert teacher_layer._intermediate_input is not None + assert teacher_layer._intermediate_output is not None + assert student_layer._intermediate_output is None + + +def test_layerwise_export(layerwise_distillation_model): + """Test that export correctly cleans up layerwise-specific attributes.""" + # Check that _teacher_layer exists before export + for student_layer, _ in layerwise_distillation_model._layers_to_loss: + assert hasattr(student_layer, "_teacher_layer") + + # Export the model + exported_model = mtd.export(layerwise_distillation_model) + + # Check that _teacher_layer is removed after export + for student_layer in exported_model.modules(): + assert not hasattr(student_layer, "_teacher_layer") + + assert not hasattr(exported_model, "_teacher_model") + assert not isinstance(exported_model, mtd.LayerwiseDistillationModel) + + +def test_layerwise_save_restore(layerwise_distillation_model, tmp_path): + """Test that save/restore works correctly with layerwise distillation.""" + mto.save(layerwise_distillation_model, tmp_path / "ckpt.pt") + + new_student = tiny_mobilenet() + restored_model = mto.restore(new_student, tmp_path / "ckpt.pt") + + # Ensure state is not actually restored (expected behavior from test_distill.py) + manager = mto.ModeloptStateManager(restored_model) + assert not manager.has_state + assert isinstance(restored_model, type(new_student)) + + +def test_layerwise_multiloss(): + """Test layerwise distillation with multiple loss functions.""" + student = tiny_mobilenet().train() + config = { + "teacher_model": tiny_mobilenet(), + "criterion": { + ("features.1", "features.1"): torch.nn.MSELoss(), + ("features.3", "features.3"): torch.nn.MSELoss(), + }, + "loss_balancer": mtd.StaticLossBalancer([0.5, 0.5]), + } + layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)]) + + # Verify hooks are registered for all layers + assert len(layerwise_model._layers_to_loss) == 2 + + # Test forward pass + output = layerwise_model(get_input_tensor()) + loss = layerwise_model.compute_kd_loss(student_loss=output.mean()) + + assert isinstance(loss, torch.Tensor) + assert loss.numel() == 1 + + +def test_layerwise_gradient_flow(): + """Test that gradients flow correctly through layerwise distillation.""" + student = tiny_mobilenet().train() + config = { + "teacher_model": tiny_mobilenet(), + "criterion": { + ("features.2", "features.2"): torch.nn.MSELoss(), + }, + "loss_balancer": None, + } + layerwise_model = mtd.convert(student, mode=[("layerwise_kd", config)]) + + # Save param snapshots by module + param_snapshots = { + name: p.clone() for name, p in layerwise_model.named_parameters() if p.requires_grad + } + + # Forward and backward + optimizer = torch.optim.SGD(layerwise_model.parameters(), lr=0.5) + optimizer.zero_grad() + layerwise_model(get_input_tensor()) + loss = layerwise_model.compute_kd_loss() + loss.backward() + optimizer.step() + + # Check: parameters in only the target layer(s) are changed + updated_any = False + for name, param in layerwise_model.named_parameters(): + if not param.requires_grad: + continue + changed = not torch.allclose(param, param_snapshots[name]) + if "features.2" in name: + assert changed, f"'{name}' parameters did not change!" + updated_any = True + else: + assert not changed, f"Parameters in unrelated layer '{name}' changed!" + assert updated_any, ( + "No parameters were updated in 'features.2' or related layers during training" + ) diff --git a/tests/unit/torch/export/test_export_diffusers.py b/tests/unit/torch/export/test_export_diffusers.py new file mode 100644 index 0000000000..4fa85bba8b --- /dev/null +++ b/tests/unit/torch/export/test_export_diffusers.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import pytest +from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet + +pytest.importorskip("diffusers") + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format +from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs +from modelopt.torch.export.unified_export_hf import export_hf_checkpoint + + +def _load_config(config_path): + with open(config_path) as file: + return json.load(file) + + +@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +def test_export_diffusers_models_non_quantized(tmp_path, model_factory): + model = model_factory() + export_dir = tmp_path / f"export_{type(model).__name__}" + + export_hf_checkpoint(model, export_dir=export_dir) + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" not in config_data + + +def test_export_diffusers_unet_quantized_matches_llm_config(tmp_path, monkeypatch): + model = get_tiny_unet() + export_dir = tmp_path / "export_unet_quant" + + import modelopt.torch.export.unified_export_hf as unified_export_hf + + monkeypatch.setattr(unified_export_hf, "has_quantized_modules", lambda *_: True) + + fuse_calls = {"count": 0} + process_calls = {"count": 0} + + def _fuse_stub(*_args, **_kwargs): + fuse_calls["count"] += 1 + + def _process_stub(*_args, **_kwargs): + process_calls["count"] += 1 + + monkeypatch.setattr(unified_export_hf, "_fuse_qkv_linears_diffusion", _fuse_stub) + monkeypatch.setattr(unified_export_hf, "_process_quantized_modules", _process_stub) + + dummy_quant_config = { + "quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"}, + "producer": {"name": "modelopt", "version": "0.0"}, + } + monkeypatch.setattr( + unified_export_hf, "get_quant_config", lambda *_args, **_kwargs: dummy_quant_config + ) + + export_hf_checkpoint(model, export_dir=export_dir) + + assert fuse_calls["count"] == 1 + assert process_calls["count"] == 1 + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" in config_data + assert config_data["quantization_config"] == convert_hf_quant_config_format(dummy_quant_config) + + +@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +def test_export_diffusers_real_quantized(tmp_path, model_factory): + model = model_factory() + export_dir = tmp_path / f"export_{type(model).__name__}_real_quant" + + def _calib_fn(m): + param = next(m.parameters()) + dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype) + assert dummy_inputs is not None + m(**dummy_inputs) + + mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=_calib_fn) + + export_hf_checkpoint(model, export_dir=export_dir) + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" in config_data diff --git a/tests/unit/torch/nas/test_evaluate_constraints.py b/tests/unit/torch/nas/test_evaluate_constraints.py index 9afbbdf5c5..909f7113a2 100644 --- a/tests/unit/torch/nas/test_evaluate_constraints.py +++ b/tests/unit/torch/nas/test_evaluate_constraints.py @@ -14,13 +14,14 @@ # limitations under the License. import pytest +from _test_utils.torch.nas_prune.utils import param_num from _test_utils.torch.vision_models import get_tiny_mobilenet_and_input, get_tiny_resnet_and_input pytest.importorskip("torchprofile") from torchprofile import profile_macs from modelopt.torch.nas.algorithms import ConstraintsFunc -from modelopt.torch.utils import param_num, remove_bn +from modelopt.torch.utils import remove_bn try: from _test_utils.torch.deploy.runtime import FAKE_DEPLOYMENT, fake_latency diff --git a/tests/unit/torch/opt/test_chaining.py b/tests/unit/torch/opt/test_chaining.py index bedbbfee02..3bc294f3b1 100644 --- a/tests/unit/torch/opt/test_chaining.py +++ b/tests/unit/torch/opt/test_chaining.py @@ -15,6 +15,7 @@ import pytest import torch +import torch.nn.functional as F from _test_utils.torch.misc import compare_outputs from _test_utils.torch.opt.utils import apply_mode_with_sampling from torchvision.models.mobilenetv2 import InvertedResidual @@ -22,10 +23,20 @@ import modelopt.torch.distill as mtd import modelopt.torch.nas as mtn import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts from modelopt.torch.utils.distributed import _serialize +class SimpleLinearModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x) + + def get_model(): return InvertedResidual(16, 32, 1, 6) @@ -228,3 +239,101 @@ def test_sparse_quantized_module(): model = mtn.export(model) assert torch.equal(conv.weight, weight_expected) assert torch.equal(conv._parameters["weight"], weight_expected), "Weight should be overwritten!" + + +def test_sparse_quantize_kd_linear_forward_backward(): + """Ensure sparse + quantize + distill works for linear forward/backward.""" + model = SimpleLinearModel() + teacher_model = SimpleLinearModel() + + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0, "pass": 0} + + def _make_patched_forward(linear): + def patched_forward(x): + called["patched_forward"] += 1 + w = linear.weight + b = linear.bias if linear.bias is not None else None + return F.linear(x, w, b) + + return patched_forward + + model.linear.forward = _make_patched_forward(model.linear) + teacher_model.linear.forward = _make_patched_forward(teacher_model.linear) + + def _get_linear_kd_mode(): + config = { + "teacher_model": teacher_model, + "criterion": {("linear", "linear"): mtd.LogitsDistillationLoss()}, + "loss_balancer": mtd.StaticLossBalancer(), + } + return [("kd_loss", config)] + + model = mto.apply_mode(model, mode="sparse_magnitude", init_state=True) + model = mto.apply_mode(model, mode="quantize") + model = mto.apply_mode(model, mode=_get_linear_kd_mode()) + + def _count_quant_input(_m, _inp, _out): + called["input_q"] += 1 + + def _count_quant_weight(_m, _inp, _out): + called["weight_q"] += 1 + + model.linear.input_quantizer.register_forward_hook(_count_quant_input) + model.linear.weight_quantizer.register_forward_hook(_count_quant_weight) + + model.train() + x = torch.randn(2, 4) + target = torch.randn(2, 4) + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + + assert output.shape == target.shape + assert any(p.grad is not None for p in model.parameters() if p.requires_grad), ( + "Expected gradients on student parameters." + ) + assert called["patched_forward"] == 2 + assert called["input_q"] == 1 + assert called["weight_q"] == 1 + + +def test_chained_modes_preserve_forward_patching_during_quantize(): + """Ensure chained modes do not break runtime forward patching during quantize.""" + model = InvertedResidual(16, 32, 1, 6).to(torch.float16) + model = mto.apply_mode(model, mode="fastnas", init_state=True) + model = mto.apply_mode(model, mode="export_nas") + + conv = model.conv[0][0] + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0} + + def patched_forward(x): + called["patched_forward"] += 1 + return F.conv2d( + x, + conv.weight, + conv.bias, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + ) + + conv.forward = patched_forward + + def _count_input(_m, _inp, _out): + called["input_q"] += 1 + + def _count_weight(_m, _inp, _out): + called["weight_q"] += 1 + + def forward_loop(model): + conv.input_quantizer.register_forward_hook(_count_input) + conv.weight_quantizer.register_forward_hook(_count_weight) + x = torch.randn(1, 16, 8, 8, dtype=torch.float16) + model(x) + + mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop) + + assert called["patched_forward"] == 1 + assert called["input_q"] == 1 + assert called["weight_q"] == 1 diff --git a/tests/unit/torch/quantization/plugins/test_accelerate.py b/tests/unit/torch/quantization/plugins/test_accelerate.py index 0c81ba4570..df5a4701d4 100644 --- a/tests/unit/torch/quantization/plugins/test_accelerate.py +++ b/tests/unit/torch/quantization/plugins/test_accelerate.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle + import pytest import torch import torch.nn as nn import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.nn import QuantLinearConvBase +from modelopt.torch.quantization.nn import QuantLinearConvBase, TensorQuantizer try: from accelerate.hooks import ModelHook, add_hook_to_module @@ -51,3 +53,30 @@ def test_linear_with_accelerate_monkey_patched_forward(): assert module_test.input_quantizer.amax is not None assert module_test.weight_quantizer.amax is not None + + +def test_tensor_quantizer_modelopt_state_with_accelerate_hook(): + """Verify accelerate hook attributes are excluded from modelopt state. + + When accelerate's add_hook_to_module patches a TensorQuantizer, it adds + _hf_hook, _old_forward, and an instance-level forward (a functools.partial + wrapping a local function). These must be excluded from the modelopt state + dict, otherwise torch.save / pickle will fail with: + AttributeError: Can't get local object 'add_hook_to_module..new_forward' + """ + tq = TensorQuantizer() + add_hook_to_module(tq, ModelHook()) + + # The hook should have injected these instance attributes + assert hasattr(tq, "_hf_hook") + assert hasattr(tq, "_old_forward") + assert "forward" in tq.__dict__ + + # None of the accelerate attributes should appear in the modelopt state + state = tq.get_modelopt_state() + accelerate_attrs = {"_hf_hook", "_old_forward", "forward"} + leaked = accelerate_attrs & state.keys() + assert not leaked, f"Accelerate attributes leaked into modelopt state: {leaked}" + + # The state dict must be picklable (torch.save uses pickle internally) + pickle.dumps(state) diff --git a/tests/unit/torch/quantization/plugins/test_sparse_moe.py b/tests/unit/torch/quantization/plugins/test_sparse_moe.py new file mode 100644 index 0000000000..6d548aa400 --- /dev/null +++ b/tests/unit/torch/quantization/plugins/test_sparse_moe.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _is_sparse_moe_block and _QuantSparseMoe.""" + +import pytest +import torch +import torch.nn as nn + +pytest.importorskip("transformers") + +from _test_utils.torch.transformers_models import get_tiny_qwen3_moe + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.nn import QuantModuleRegistry +from modelopt.torch.quantization.plugins.huggingface import ( + TRANSFORMERS_VERSION_GE_5_0, + _is_sparse_moe_block, + register_sparse_moe_on_the_fly, +) + + +# --------------------------------------------------------------------------- +# Helpers: lightweight mock modules for _is_sparse_moe_block +# --------------------------------------------------------------------------- +class _FakeGateWithRouter(nn.Module): + """Mimics a v5.x TopKRouter gate with top_k and num_experts.""" + + def __init__(self, top_k=2, num_experts=4): + super().__init__() + self.top_k = top_k + self.num_experts = num_experts + self.linear = nn.Linear(8, num_experts) + + def forward(self, x): + return self.linear(x) + + +class _FakeExperts(nn.ModuleList): + def __init__(self, n=4): + super().__init__([nn.Linear(8, 8) for _ in range(n)]) + self.num_experts = n + + +class _MoEBlockWithGateRouter(nn.Module): + """Matches the primary detection path: gate.top_k + gate.num_experts.""" + + def __init__(self, num_experts=4, top_k=2): + super().__init__() + self.gate = _FakeGateWithRouter(top_k=top_k, num_experts=num_experts) + self.experts = _FakeExperts(num_experts) + + def forward(self, hidden_states): + logits = self.gate(hidden_states) + routing_weights, selected = torch.topk(logits, self.gate.top_k, dim=-1) + out = torch.zeros_like(hidden_states) + for i in range(self.gate.num_experts): + mask = (selected == i).any(dim=-1) + if mask.any(): + out[mask] += self.experts[i](hidden_states[mask]) + return out + + +class _MoEBlockFallback(nn.Module): + """Matches the fallback path: top_k + num_experts on the block itself.""" + + def __init__(self, num_experts=4, top_k=2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(8, num_experts) + self.experts = _FakeExperts(num_experts) + + def forward(self, hidden_states): + logits = self.gate(hidden_states) + routing_weights, selected = torch.topk(logits, self.top_k, dim=-1) + out = torch.zeros_like(hidden_states) + for i in range(self.num_experts): + mask = (selected == i).any(dim=-1) + if mask.any(): + out[mask] += self.experts[i](hidden_states[mask]) + return out + + +# --------------------------------------------------------------------------- +# Tests for _is_sparse_moe_block +# --------------------------------------------------------------------------- +class TestIsSparseBlock: + def test_no_experts_returns_false(self): + module = nn.Linear(8, 8) + assert _is_sparse_moe_block(module) is False + + def test_experts_but_no_gate_or_topk_returns_false(self): + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + assert _is_sparse_moe_block(module) is False + + def test_gate_with_router_attrs_returns_true(self): + block = _MoEBlockWithGateRouter(num_experts=4, top_k=2) + assert _is_sparse_moe_block(block) is True + + def test_fallback_block_level_attrs_returns_true(self): + block = _MoEBlockFallback(num_experts=4, top_k=2) + assert _is_sparse_moe_block(block) is True + + def test_gate_missing_num_experts_returns_false(self): + """gate.top_k present but gate.num_experts absent -> primary path fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.top_k = 2 + module.gate = gate + assert _is_sparse_moe_block(module) is False + + def test_gate_missing_top_k_returns_false(self): + """gate.num_experts present but gate.top_k absent -> primary path fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.num_experts = 4 + module.gate = gate + assert _is_sparse_moe_block(module) is False + + def test_block_level_only_top_k_returns_false(self): + """Only top_k on block (no num_experts) -> fallback fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + module.top_k = 2 + assert _is_sparse_moe_block(module) is False + + def test_block_level_only_num_experts_returns_false(self): + """Only num_experts on block (no top_k) -> fallback fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + module.num_experts = 4 + assert _is_sparse_moe_block(module) is False + + def test_glm4_like_block_rejected(self): + """A module with n_routed_experts instead of num_experts should be rejected.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.top_k = 2 + gate.n_routed_experts = 4 # different attr name + module.gate = gate + assert _is_sparse_moe_block(module) is False + + +# --------------------------------------------------------------------------- +# Tests for _QuantSparseMoe +# --------------------------------------------------------------------------- +class TestQuantSparseMoe: + """Tests for _QuantSparseMoe using a real tiny Qwen3Moe model.""" + + @staticmethod + def _get_moe_block(model): + """Return the first MoE block from the model.""" + for module in model.modules(): + if _is_sparse_moe_block(module): + return module + raise RuntimeError("No MoE block found in model") + + def test_register_sparse_moe_on_the_fly(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is not None: + pytest.skip("MoE type already registered (upstream change)") + + register_sparse_moe_on_the_fly(model) + assert QuantModuleRegistry.get(moe_type) is not None + + def test_setup_creates_expert_token_count(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + assert hasattr(converted, "expert_token_count") + if hasattr(moe_block, "gate") and hasattr(moe_block.gate, "num_experts"): + expected_num_experts = moe_block.gate.num_experts + elif hasattr(moe_block, "num_experts"): + expected_num_experts = moe_block.num_experts + elif hasattr(moe_block, "experts") and hasattr(moe_block.experts, "num_experts"): + expected_num_experts = moe_block.experts.num_experts + else: + expected_num_experts = 0 + assert converted.expert_token_count.shape == (expected_num_experts,) + assert converted.expert_token_count.dtype == torch.long + assert (converted.expert_token_count == 0).all() + + def test_setup_count_expert_tokens_default_false(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + assert converted._count_expert_tokens is False + + def test_forward_no_calib_matches_original(self): + """When calibration is off, _QuantSparseMoe should produce the same output as the original.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + ref_block = self._get_moe_block(get_tiny_qwen3_moe()) + ref_block.load_state_dict(moe_block.state_dict()) + + converted = QuantModuleRegistry.convert(moe_block) + + torch.manual_seed(42) + x = torch.randn(1, 4, 32) + with torch.no_grad(): + out_ref = ref_block(x) + out_test = converted(x) + + if isinstance(out_ref, tuple): + out_ref = out_ref[0] + if isinstance(out_test, tuple): + out_test = out_test[0] + assert torch.allclose(out_ref, out_test, atol=1e-5) + + def test_forward_calib_sends_all_tokens_to_all_experts(self): + """During calibration, all experts should see tokens (expert_token_count all > 0).""" + model = get_tiny_qwen3_moe() + register_sparse_moe_on_the_fly(model) + + def calib_fn(model): + x = model.dummy_inputs["input_ids"] + model(x) + + mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_fn) + + for name, module in model.named_modules(): + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: + assert (module.expert_token_count > 0).all(), ( + f"Not all experts received tokens in {name}: {module.expert_token_count}" + ) + + def test_forward_calib_restores_top_k(self): + """After calibration forward, top_k should be restored to its original value.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + if TRANSFORMERS_VERSION_GE_5_0: + original_top_k = moe_block.gate.top_k + else: + original_top_k = moe_block.top_k + + converted = QuantModuleRegistry.convert(moe_block) + + # Simulate calibration mode: set _if_calib on a child TensorQuantizer + for m in converted.experts.modules(): + if hasattr(m, "_if_calib"): + m._if_calib = True + break + + x = torch.randn(1, 4, 32) + with torch.no_grad(): + converted(x) + + if TRANSFORMERS_VERSION_GE_5_0: + assert converted.gate.top_k == original_top_k + else: + assert converted.top_k == original_top_k + + def test_gate_forward_hook_counts_tokens(self): + """Verify the gate forward hook correctly counts expert token assignments.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + + # Reset counts and enable counting + converted.expert_token_count.zero_() + converted._count_expert_tokens = True + + if TRANSFORMERS_VERSION_GE_5_0: + hidden_size = converted.gate.weight.shape[1] + top_k = converted.gate.top_k + else: + hidden_size = converted.gate.in_features + top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k + + x = torch.randn(8, hidden_size) + with torch.no_grad(): + converted.gate(x) + total_assigned = converted.expert_token_count.sum().item() + assert total_assigned == 8 * top_k + + # Disable counting and verify counts don't change + converted._count_expert_tokens = False + prev_counts = converted.expert_token_count.clone() + with torch.no_grad(): + converted.gate(x) + assert torch.equal(converted.expert_token_count, prev_counts) diff --git a/tests/unit/torch/quantization/test_forward_patching.py b/tests/unit/torch/quantization/test_forward_patching.py new file mode 100644 index 0000000000..a28a1bce6a --- /dev/null +++ b/tests/unit/torch/quantization/test_forward_patching.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types + +import torch +import torch.nn.functional as F +from torch import nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization import QuantModuleRegistry +from modelopt.torch.quantization.nn.modules.quant_module import QuantLinearConvBase + + +def test_quant_input_base_ignores_forward_pre_dm_in_mro(): + """Regression test for recursion when `_forward_pre_dm` points to a wrapper forward in the MRO. + + In complex wrapper stacks, `_forward_pre_dm` may accidentally end up referencing a `forward` + method already present in the quant wrapper MRO (e.g. QuantLinearConvBase.forward). If + QuantInputBase.forward calls that directly, it can recurse indefinitely: + + QuantLinearConvBase.forward -> super().forward (QuantInputBase.forward) + -> _forward_pre_dm (QuantLinearConvBase.forward) -> ... + + The fix is to detect this case and fall back to `super().forward` instead. + """ + lin = nn.Linear(8, 8, bias=False) + QuantModuleRegistry.convert(lin) + + # Force the problematic state: `_forward_pre_dm` points to a wrapper forward already in MRO. + lin._forward_pre_dm = types.MethodType(QuantLinearConvBase.forward, lin) + + x = torch.randn(2, 8) + y = lin(x) + assert isinstance(y, torch.Tensor) + assert y.shape == (2, 8) + + +def test_quantize_calibration_calls_quantizers_with_runtime_forward_patch(): + """Regression test for on-the-fly forward patching during mtq.quantize calibration. + + Some frameworks replace `module.forward` on-the-fly with a closure just before a forward pass. + During mtq.quantize calibration, quantizers must still run (input + weight at minimum). + """ + lin = nn.Linear(8, 8, bias=True).to(torch.float32) + + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0} + + # Monkey patch instance-level forward (closure-style, no `self` argument). + def patched_forward(x): + called["patched_forward"] += 1 + # Use module parameters directly; if quantization wrappers are active, weight access + # should still be routed through the quantized path. + w = lin.weight.to(dtype=x.dtype) + b = lin.bias.to(dtype=x.dtype) if lin.bias is not None else None + return F.linear(x, w, b) + + def _count_input(_m, _inp, _out): + called["input_q"] += 1 + + def _count_weight(_m, _inp, _out): + called["weight_q"] += 1 + + lin.forward = patched_forward + x = torch.randn(2, 8, dtype=torch.float16) + + def forward_loop(model): + # Patch forward on-the-fly (after conversion, right before calibration forward). + + # Count quantizer executions during calibration. + model.input_quantizer.register_forward_hook(_count_input) + model.weight_quantizer.register_forward_hook(_count_weight) + + model(x) + + mtq.quantize(lin, mtq.INT8_DEFAULT_CFG, forward_loop) + lin(x) + + assert called["patched_forward"] == 2 + assert called["input_q"] == 2 + assert called["weight_q"] == 2 diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 26e7d52da1..5e55465120 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -68,7 +68,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=20, + step_size=0.075, start_multiplier=0.1, stop_multiplier=1.5, quant_func=quant_func, @@ -115,7 +115,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=25, + step_size=0.045, start_multiplier=0.1, stop_multiplier=1.2, quant_func=quant_func, @@ -162,7 +162,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=50, + step_size=0.008, start_multiplier=0.8, stop_multiplier=1.2, quant_func=quant_func, @@ -214,7 +214,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=20, + step_size=0.075, start_multiplier=0.1, stop_multiplier=1.5, quant_func=quant_func, @@ -265,7 +265,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=15, + step_size=0.07, start_multiplier=0.5, stop_multiplier=1.5, quant_func=quant_func, @@ -307,7 +307,7 @@ def quant_func(x, amax): tq._if_calib = was_calib_enabled return xq - cal = calib.MseCalibrator(amax=initial_amax, num_steps=10, quant_func=quant_func) + cal = calib.MseCalibrator(amax=initial_amax, step_size=0.4, quant_func=quant_func) cal.collect(x) @@ -352,7 +352,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=10, + step_size=0.15, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, @@ -398,7 +398,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=15, + step_size=0.1, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, @@ -458,7 +458,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=20, + step_size=0.05, start_multiplier=0.5, stop_multiplier=1.5, quant_func=quant_func, @@ -511,7 +511,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=10, + step_size=0.15, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 5bc39a517d..43233b3239 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -45,6 +45,7 @@ "algorithm": "awq_lite", } +# Test configs for per channel MSE calibration INT8_MSE_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py index b487d86394..ce2fa3da2a 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -30,7 +30,7 @@ def test_phase_inference(self): """Test phase detection from attention score shape.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -70,30 +70,11 @@ def test_threshold_update_dict_config(self): method._update_threshold("prefill") assert method.threshold == 1e-3 - def test_threshold_update_static_config(self): - """Test threshold with static float config.""" - method = FlashSkipSoftmax( - { - "threshold": 5e-4, - "br": 128, - "bc": 128, - "backend": "pytorch", - "is_causal": True, - } - ) - - initial_threshold = method.threshold - assert initial_threshold == 5e-4 - - # Should not change for static config - method._update_threshold("decode") - assert method.threshold == 5e-4 - def test_block_reshaping_divisible(self): """Test block reshaping with divisible sequence lengths.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -116,7 +97,7 @@ def test_block_reshaping_with_padding(self): """Test block reshaping with non-divisible lengths.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -139,7 +120,7 @@ def test_correction_factor_calculation_prefill(self): """Test correction factor for prefill phase.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -166,7 +147,7 @@ def test_correction_factor_calculation_decode(self): """Test correction factor for decode phase.""" method = FlashSkipSoftmax( { - "threshold": 1e-5, + "threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, "bc": 128, "backend": "pytorch", @@ -185,32 +166,11 @@ def test_correction_factor_calculation_decode(self): assert 0 <= stats["sparsity"] <= 1 assert mask.shape == (1, 1, 1, 256) - def test_sparsity_statistics(self): - """Test sparsity statistics structure.""" - method = FlashSkipSoftmax( - { - "threshold": 1e-3, - "br": 128, - "bc": 128, - "backend": "pytorch", - "is_causal": True, - } - ) - - attn = torch.randn(1, 1, 128, 256) - _, stats = method.calc_correction_factor_and_p(attn, "prefill") - - # Verify statistics are present - assert stats["total_blocks"] > 0 - assert "sparse_blocks" in stats - assert "sample_length" in stats - assert stats["sample_length"] == 256 - def test_block_mask_correctness(self): """Test block mask shape and type.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -229,7 +189,7 @@ def test_block_mask_correctness(self): def test_causal_vs_noncausal(self): """Test total_blocks calculation for causal vs non-causal.""" config_base = { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -248,11 +208,11 @@ def test_causal_vs_noncausal(self): assert stats_causal["total_blocks"] == 3 assert stats_noncausal["total_blocks"] == 4 - def test_apply_sparsity_assertions(self): - """Test apply_sparsity input validation.""" + def test_calculate_sparsity_assertions(self): + """Test calculate_sparsity input validation.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", @@ -260,23 +220,53 @@ def test_apply_sparsity_assertions(self): } ) - # Test: attention_scores required - with pytest.raises(AssertionError, match="attention_scores must be provided"): - method.apply_sparsity() - # Test: 4D shape required with pytest.raises(AssertionError, match="Expected 4D"): - method.apply_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + method.calculate_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D - def test_name_property(self): - """Test method name property.""" + def test_apply_sparsity_with_mask(self): + """Test apply_sparsity with pre-computed mask.""" method = FlashSkipSoftmax( { - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, "bc": 128, "backend": "pytorch", "is_causal": True, } ) - assert method.name == "flash_skip_softmax" + + attn = torch.randn(2, 4, 128, 256) + + # Calculate sparsity first + sparse_mask, stats = method.calculate_sparsity(attn) + + # Apply sparsity with pre-computed mask + sparse_attn = method.apply_sparsity(attn, sparse_mask) + + # Verify output shape matches input + assert sparse_attn.shape == attn.shape + + # Verify masked positions have min value + mask_value = torch.finfo(attn.dtype).min + assert (sparse_attn[~sparse_mask] == mask_value).all() + + def test_apply_sparsity_without_mask(self): + """Test apply_sparsity calculates mask internally when None.""" + method = FlashSkipSoftmax( + { + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + + # Apply sparsity without pre-computed mask + sparse_attn = method.apply_sparsity(attn) + + # Verify output shape matches input + assert sparse_attn.shape == attn.shape diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py new file mode 100644 index 0000000000..b91ec40cf0 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -0,0 +1,418 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention calibration.""" + +import pytest + +pytest.importorskip("transformers") + +import numpy as np +from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel +from pydantic import ValidationError + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import ( + DynamicThresholdCalibrator, + RulerDatasetBuilder, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.calibrate import ( + _extract_calibration_config, + calibrate_sparse_attention, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.ruler_dataset import ( + _generate_target_lengths, +) +from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestLengthGeneration: + """Test automatic target length generation.""" + + def test_generate_target_lengths_default(self): + """Test default 4 bins generation.""" + lengths = _generate_target_lengths(32768, num_length_bins=4) + assert lengths == [32768, 16384, 8192, 4096] + + def test_generate_target_lengths_stops_at_minimum(self): + """Test generation stops at minimum threshold.""" + lengths = _generate_target_lengths(2048, num_length_bins=4) + assert lengths == [2048, 1024] # Stops at 1024 + + +class TestRulerDatasetBuilder: + """Test RULER dataset generation without requiring real tokenizers.""" + + def test_builder_initialization(self): + """Test that builder initializes correctly.""" + builder = RulerDatasetBuilder( + samples=12, + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + assert builder.total_samples == 12 + assert builder.max_seqlen == 2048 + assert builder.target_lengths == [2048, 1024] + assert builder.samples_per_length == [6, 6] # Evenly distributed + assert len(builder.subtasks) == 6 # All RULER_TASKS + assert builder.seed == 42 + + def test_builder_initialization_invalid_config(self): + """Test that builder raises error for invalid inputs.""" + # Test invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + RulerDatasetBuilder( + samples=0, + max_seqlen=2048, + tokenizer_name_or_path="gpt2", + ) + + # Test max_seqlen below minimum + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + RulerDatasetBuilder( + samples=4, + max_seqlen=512, # Below minimum + tokenizer_name_or_path="gpt2", + ) + + def test_dataset_generation_minimal(self): + """Test generating small dataset.""" + builder = RulerDatasetBuilder( + samples=12, # 6 tasks x 2 lengths = need 12 for 1 per task per length + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 12 samples (6 tasks x 1 sample per task x 2 lengths) + assert len(dataset) == 12 + assert all(isinstance(sample, dict) for sample in dataset) + + def test_dataset_structure(self): + """Test that dataset has correct structure.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + sample = dataset[0] + + # Check required fields + assert "input" in sample + assert "length" in sample + assert "task" in sample + assert "target_length" in sample + + # Check field types + assert isinstance(sample["input"], str) + assert isinstance(sample["length"], int) + assert isinstance(sample["task"], str) + assert sample["length"] > 0 + + def test_uneven_sample_distribution(self): + """Test that samples are distributed evenly (remainder dropped).""" + builder = RulerDatasetBuilder( + samples=50, # 50 samples across 4 lengths + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + # Even distribution: 50//4 = 12 per length + assert builder.total_samples == 50 + assert builder.target_lengths == [8192, 4096, 2048, 1024] + assert builder.samples_per_length == [12, 12, 12, 12] + assert sum(builder.samples_per_length) == 48 # 2 samples dropped (remainder) + + # Actual generated samples: 12//6=2 per task, 4 lengths, 6 tasks + # Total: 2 x 6 x 4 = 48 + dataset = builder.build_calibration_dataset() + assert len(dataset) == 48 + + +class TestDynamicThresholdCalibrator: + """Test calibration algorithm correctness (regression calculations).""" + + def test_regression_calculation_synthetic(self): + """Test 'a' parameter calculation with synthetic data.""" + # Create synthetic optimal pairs + # If threshold = a / length, then with perfect data: + # length=1000, threshold=10 => a=10000 + # length=2000, threshold=5 => a=10000 + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0, "achieved_sparsity": 0.5}, + {"length": 2000, "optimal_threshold": 5.0, "achieved_sparsity": 0.5}, + {"length": 4000, "optimal_threshold": 2.5, "achieved_sparsity": 0.5}, + ] + + # Manual regression calculation + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + # Calculate 'a' using least squares + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should be close to 10000 + assert 9500 < a_parameter < 10500 + + # Test individual 'a' values + a_per_sample = y * lengths + assert np.allclose(a_per_sample, 10000, rtol=0.05) + + +class TestCalibrationIntegration: + """Test end-to-end calibration without GPU.""" + + def test_calibration_disabled(self): + """Test that no calibration occurs when disabled.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # No forward_loop needed when calibration disabled + sparse_model = sparsify(model, config) + + # Check that sparse attention is applied but not calibrated + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + # Check that no calibration is set + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert not getattr(method, "calibration_params", None) + + def test_sparsify_with_calibration_requires_forward_loop(self): + """Test that calibration requires forward_loop or proper model config.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5}, + "samples": 4, + "max_seqlen": 1024, + }, + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "br": 64, + "bc": 64, + "enable": True, + }, + }, + } + + # Without forward_loop and without model.config._name_or_path, should raise ValueError + with pytest.raises(ValueError, match="Could not load tokenizer"): + sparsify(model, config, forward_loop=None) + + def test_calibration_config_validation(self): + """Test CalibrationConfig validation.""" + # Valid config + config = CalibrationConfig( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + samples=48, + max_seqlen=32768, + ) + assert config.target_sparse_ratio == {"prefill": 0.5, "decode": 0.5} + assert config.samples == 48 + assert config.max_seqlen == 32768 + + # Invalid target_sparse_ratio (> 1.0) + with pytest.raises(ValueError, match="target_sparse_ratio.*must be between 0.0 and 1.0"): + CalibrationConfig( + target_sparse_ratio={"prefill": 1.5, "decode": 0.5}, samples=48, max_seqlen=32768 + ) + + # Invalid target_sparse_ratio (< 0.0) + with pytest.raises(ValueError, match="target_sparse_ratio.*must be between 0.0 and 1.0"): + CalibrationConfig( + target_sparse_ratio={"prefill": -0.1, "decode": 0.5}, samples=48, max_seqlen=32768 + ) + + # Invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + CalibrationConfig( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, samples=0, max_seqlen=32768 + ) + + # Invalid max_seqlen + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + CalibrationConfig( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, samples=48, max_seqlen=512 + ) + + def test_threshold_trials_validation(self): + """Test threshold_trials validation.""" + # Valid custom threshold_trials + config = CalibrationConfig( + target_sparse_ratio={"prefill": 0.5, "decode": 0.5}, + threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], + ) + assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] + + # None (use defaults) + config_default = CalibrationConfig(target_sparse_ratio={"prefill": 0.5, "decode": 0.5}) + assert config_default.threshold_trials is None + + # Invalid: empty list + with pytest.raises(ValueError, match="threshold_trials must not be empty"): + CalibrationConfig(threshold_trials=[]) + + # Invalid: threshold out of range (>= 1.0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 1.0]) + + # Invalid: threshold out of range (<= 0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 0]) + + # Invalid: not a list (Pydantic raises ValidationError, not ValueError) + with pytest.raises(ValidationError, match="Input should be a valid list"): + CalibrationConfig(threshold_trials=1e-4) + + +class TestDynamicThresholdCalibratorMethods: + """Test error paths and edge cases of DynamicThresholdCalibrator.""" + + def test_calibrate_no_sparse_modules(self): + """Test calibrate raises error when no sparse modules found.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Don't apply sparse attention + calibrator = DynamicThresholdCalibrator( + threshold_trials=[0.001, 0.01], + ) + + def dummy_forward_loop(m): + pass + + with pytest.raises(ValueError, match="No sparse attention modules found"): + calibrator.calibrate(model, dummy_forward_loop, "prefill") + + def test_calibrate_empty_stats(self): + """Test calibrate handles empty stats gracefully.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.1, "decode": 0.1}, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + calibrator = DynamicThresholdCalibrator( + threshold_trials=[0.001], # Only one threshold for speed + ) + + # Forward loop that doesn't generate any stats + def empty_forward_loop(m): + pass + + # Should return empty dict + result = calibrator.calibrate(sparse_model, empty_forward_loop, "prefill") + assert result == {} + + +class TestCalibrateFunction: + """Test calibrate_sparse_attention function.""" + + def test_calibrate_no_config(self): + """Test calibration when config has no calibration section.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Config without calibration + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.1, "decode": 0.1}, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # Should return empty dict when no calibration config + result = calibrate_sparse_attention(model, config) + + assert result == {} + + def test_extract_calibration_config(self): + """Test _extract_calibration_config function.""" + # Config with calibration + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": {"prefill": 0.3, "decode": 0.3}, + "samples": 12, + "max_seqlen": 2048, + }, + "*attn*": { + "method": "flash_skip_softmax", + }, + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is not None + assert calib_config.target_sparse_ratio == {"prefill": 0.3, "decode": 0.3} + assert calib_config.samples == 12 + assert calib_config.max_seqlen == 2048 + + def test_extract_calibration_config_none(self): + """Test _extract_calibration_config when no calibration.""" + # Config without calibration + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.1, "decode": 0.1}, + } + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is None diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 0000000000..ddbb718f49 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold={"prefill": 1e-4, "decode": 1e-4}, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == {"prefill": 1e-4, "decode": 1e-4} + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold dict values must be in range (0, 1).""" + # Zero value + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 0, "decode": 1e-4}) + + # Negative value + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -0.1, "decode": 1e-4}) + + # Value equals 1.0 + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0, "decode": 1e-4}) + + # Value greater than 1.0 + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.5, "decode": 1e-4}) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold must be a dict (not single value or string).""" + # Single float value not allowed + with pytest.raises(ValidationError, match="Input should be a valid dictionary"): + SparseAttentionAttributeConfig(threshold=1e-4) + + # String not allowed + with pytest.raises(ValidationError, match="Input should be a valid dictionary"): + SparseAttentionAttributeConfig(threshold="invalid") diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 6fcad9bb85..9a9544419d 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.conversion import ( disable_sparse_attention, enable_sparse_attention, + print_sparse_attention_summary, ) from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule @@ -59,28 +60,6 @@ def test_basic_replacement(self): # Verify replacement occurred assert sparse_attention_count > 0 - def test_enable_disable_toggle(self): - """Test enabling and disabling sparse attention.""" - model = SimpleAttentionModel() - model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) - - # Check initially enabled - for module in model.modules(): - if isinstance(module, SparseAttentionModule): - assert module.is_enabled - - # Disable all sparse attention modules - disable_sparse_attention(model, "*") - for module in model.modules(): - if isinstance(module, SparseAttentionModule): - assert not module.is_enabled - - # Re-enable all sparse attention modules - enable_sparse_attention(model, "*") - for module in model.modules(): - if isinstance(module, SparseAttentionModule): - assert module.is_enabled - def test_pattern_based_replacement(self): """Test pattern-based selective replacement.""" model = SimpleTransformerEncoderLayer() @@ -90,7 +69,7 @@ def test_pattern_based_replacement(self): "sparse_cfg": { "*self_attn*": { "method": "flash_skip_softmax", - "threshold": 1e-4, + "threshold": {"prefill": 1e-4, "decode": 1e-4}, "br": 128, "bc": 128, "enable": True, @@ -121,7 +100,7 @@ def filter_func(name): "sparse_cfg": { filter_func: { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, }, }, @@ -139,7 +118,7 @@ def test_no_matching_modules(self): "sparse_cfg": { "*nonexistent*": { "method": "flash_skip_softmax", - "threshold": 1e-3, + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "enable": True, }, }, @@ -150,10 +129,6 @@ def test_no_matching_modules(self): def test_disable_enable_functions(self): """Test disable/enable utility functions.""" - from modelopt.torch.sparsity.attention_sparsity.conversion import ( - disable_sparse_attention, - enable_sparse_attention, - ) model = SimpleAttentionModel() model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) @@ -170,6 +145,19 @@ def test_disable_enable_functions(self): if isinstance(module, SparseAttentionModule): assert module.is_enabled + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Sparse attention:" in captured.out + assert "modules enabled" in captured.out + def test_restore_sparse_attention_model(self): """Test save/restore via modelopt_state.""" # Create and sparsify original model @@ -192,3 +180,72 @@ def test_restore_sparse_attention_model(self): if isinstance(module, SparseAttentionModule): assert hasattr(module, "_method") assert module._method == "flash_skip_softmax" + + +class TestSparseAttentionModuleMethods: + """Test SparseAttentionModule methods.""" + + def test_get_stats_with_stats_manager(self): + """Test get_stats() when stats manager exists and is enabled.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 64, + "bc": 64, + "collect_stats": True, # Enable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + assert sparse_module._stats_manager is not None + + # Get stats (should return summary) + stats = sparse_module.get_stats() + + assert isinstance(stats, dict) + assert "module" in stats + assert "total_calls" in stats + assert "average_sparsity" in stats + + def test_get_stats_without_stats_manager(self): + """Test get_stats() when stats manager is None.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 64, + "bc": 64, + "collect_stats": False, # Disable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Stats manager should be None + assert module._stats_manager is None + + # get_stats should return empty dict + stats = module.get_stats() + assert stats == {} + break diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py index e7e32e1534..2a017d66dc 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py @@ -35,9 +35,3 @@ def test_sparse_attention_mode_descriptor(): assert mode_descriptor is not None assert hasattr(mode_descriptor, "config_class") assert hasattr(mode_descriptor, "convert") - - -def test_mode_registry_get(): - """Test getting mode from registry.""" - mode = SparseAttentionModeRegistry["sparse_attention"] - assert mode is not None diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py new file mode 100644 index 0000000000..2a390ab1f5 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for SparseAttentionStatsManager.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager + + +class TestStatsManagerInitialization: + """Test stats manager initialization.""" + + def test_initialization_defaults(self): + """Test default initialization.""" + manager = SparseAttentionStatsManager(module_name="test_module") + + assert manager.module_name == "test_module" + assert manager.enabled is True + assert manager.calibration_mode is False + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + + def test_initialization_disabled(self): + """Test initialization with disabled stats.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=False) + + assert manager.enabled is False + assert manager.calibration_mode is False + + +class TestStatsCollection: + """Test statistics collection functionality.""" + + def test_collect_stats_enabled(self): + """Test collecting stats when enabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 1 + assert manager.aggregated_stats["total_blocks"] == 100 + assert manager.aggregated_stats["sparse_blocks"] == 50 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 1 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 + + def test_collect_stats_disabled(self): + """Test that collect() is no-op when disabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=False) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + + manager.collect(stats) + + # Should remain at initial values + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + + def test_collect_multiple_calls(self): + """Test accumulation over multiple collect calls.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect multiple times + for i in range(5): + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 5 + assert manager.aggregated_stats["total_blocks"] == 500 + assert manager.aggregated_stats["sparse_blocks"] == 250 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 5 + + def test_collect_different_phases(self): + """Test phase counting.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect prefill stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + # Collect decode stats + manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5}) + + assert manager.aggregated_stats["phase_counts"]["prefill"] == 2 + assert manager.aggregated_stats["phase_counts"]["decode"] == 1 + assert manager.aggregated_stats["phase_counts"]["unknown"] == 0 + + +class TestCalibrationMode: + """Test calibration mode functionality.""" + + def test_calibration_mode_per_sample_collection(self): + """Test that calibration mode stores per-sample stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Enable calibration mode + manager.set_calibration_mode(enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + # Should store in per_sample_stats + assert len(manager.per_sample_stats) == 1 + assert manager.per_sample_stats[0]["module"] == "test" + assert manager.per_sample_stats[0]["sparsity"] == 0.5 + assert manager.per_sample_stats[0]["sample_length"] == 1024 + assert manager.per_sample_stats[0]["phase"] == "prefill" + + def test_calibration_mode_off(self): + """Test that per-sample stats are not collected when calibration mode is off.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + # Calibration mode is off by default + + stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50} + + manager.collect(stats) + + # Should NOT store in per_sample_stats + assert len(manager.per_sample_stats) == 0 + + # But should still aggregate + assert manager.aggregated_stats["total_calls"] == 1 + + def test_set_calibration_mode_with_reset(self): + """Test set_calibration_mode with reset_history=True.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats in calibration mode + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Re-enable with reset + manager.set_calibration_mode(enabled=True, reset_history=True) + assert len(manager.per_sample_stats) == 0 # Should be cleared + + def test_set_calibration_mode_without_reset(self): + """Test set_calibration_mode with reset_history=False.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Disable without reset + manager.set_calibration_mode(enabled=False, reset_history=False) + assert len(manager.per_sample_stats) == 1 # Should be preserved + + +class TestGetSummary: + """Test get_summary() functionality.""" + + def test_get_summary_with_data(self): + """Test get_summary returns correct averages.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=True) + + # Collect stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + summary = manager.get_summary() + + assert summary["module"] == "test_module" + assert summary["total_calls"] == 2 + # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4 + assert summary["average_sparsity"] == 0.4 + assert summary["phase_distribution"]["prefill"] == 2 + + def test_get_summary_zero_blocks(self): + """Test get_summary when total_blocks is zero.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect stats with zero blocks + manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0}) + + summary = manager.get_summary() + + assert summary["average_sparsity"] == 0.0 # Should handle division by zero + + +class TestGetCalibrationStats: + """Test get_calibration_stats() functionality.""" + + def test_get_calibration_stats(self): + """Test retrieving per-sample calibration stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect multiple samples + for i in range(3): + manager.collect( + { + "sparsity": 0.3 + i * 0.1, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 30, + "sample_length": 1024 + i * 512, + } + ) + + calib_stats = manager.get_calibration_stats() + + assert len(calib_stats) == 3 + assert calib_stats[0]["sparsity"] == 0.3 + assert calib_stats[1]["sparsity"] == 0.4 + assert calib_stats[2]["sparsity"] == 0.5 + + def test_get_calibration_stats_empty(self): + """Test get_calibration_stats when no calibration data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + calib_stats = manager.get_calibration_stats() + + assert calib_stats == [] + + +class TestReset: + """Test reset functionality.""" + + def test_reset(self): + """Test reset() clears all statistics.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect some stats + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + manager.collect( + { + "sparsity": 0.3, + "phase": "decode", + "total_blocks": 10, + "sparse_blocks": 3, + "sample_length": 128, + } + ) + + # Verify stats exist + assert manager.aggregated_stats["total_calls"] == 2 + assert len(manager.per_sample_stats) == 2 + + # Reset + manager.reset() + + # All stats should be cleared + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + assert manager.aggregated_stats["phase_counts"]["prefill"] == 0 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py new file mode 100644 index 0000000000..320196ccc4 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for threshold calibration functionality.""" + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch.sparsity.sparse_attention_common import SimpleAttentionModel + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestFlashSkipSoftmaxThresholdInfo: + """Test FlashSkipSoftmax.get_threshold_info() method.""" + + def test_phased_threshold(self): + """Test threshold info for phase-specific static thresholds.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + # Static phased thresholds are reported as type "static" with dict value + assert info["type"] == "static" + assert isinstance(info["value"], dict) + assert info["value"]["prefill"] == 0.001 + assert info["value"]["decode"] == 0.0001 + + def test_dynamic_calibrated_threshold(self): + """Test threshold info for calibrated dynamic threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Simulate calibration setting a and b parameters + method.calibration_params = { + "prefill": {"a": 150.0, "b": 1.5}, + "decode": {"a": 200.0, "b": 1.8}, + } + method.target_sparse_ratio = {"prefill": 0.9, "decode": 0.9} + + info = method.get_threshold_info() + + assert info["type"] == "dynamic_calibrated" + assert info["formula"] == "threshold = a * exp(b * target_sparsity) / seqlen" + assert "calibration_params" in info + assert "target_sparse_ratio" in info + assert "phases" in info + assert "prefill" in info["phases"] + assert "decode" in info["phases"] + # Check that a and b are in phase info + assert info["phases"]["prefill"]["a"] == 150.0 + assert info["phases"]["prefill"]["b"] == 1.5 + assert info["phases"]["prefill"]["target_sparsity"] == 0.9 + + +class TestSparseAttentionModuleThresholdInfo: + """Test SparseAttentionModule.get_threshold_info() delegation.""" + + def test_module_delegates_to_method(self): + """Test that module correctly delegates to sparse method instance.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.005, "decode": 0.001}, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find sparse attention module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + + # Test get_threshold_info + info = sparse_module.get_threshold_info() + + assert info["type"] == "static" + assert info["value"]["prefill"] == 0.005 + assert info["value"]["decode"] == 0.001 + + def test_module_with_calibrated_threshold(self): + """Test module reports calibrated threshold correctly.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module and set calibrated params (Exponential model) + module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.calibration_params = { + "prefill": {"a": 150.0, "b": 1.5}, + "decode": {"a": 200.0, "b": 1.8}, + } + module._sparse_method_instance.target_sparse_ratio = { + "prefill": 0.9, + "decode": 0.9, + } + break + + assert module is not None, "No SparseAttentionModule found" + # Get threshold info + info = module.get_threshold_info() + + assert info["type"] == "dynamic_calibrated" + assert info["calibration_params"]["prefill"]["a"] == 150.0 + + def test_module_without_method_instance(self): + """Test get_threshold_info when sparse method instance doesn't exist.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Remove sparse method instance to test fallback + delattr(module, "_sparse_method_instance") + + info = module.get_threshold_info() + + assert info["type"] == "none" + assert info["value"] is None + break diff --git a/tests/unit/torch/utils/test_network.py b/tests/unit/torch/utils/test_network.py index 3d4600c7ed..15b7c69166 100644 --- a/tests/unit/torch/utils/test_network.py +++ b/tests/unit/torch/utils/test_network.py @@ -19,6 +19,7 @@ import pytest import torch +from _test_utils.torch.nas_prune.utils import param_num from torch import nn from torchvision.models import MobileNetV2 @@ -28,7 +29,6 @@ get_model_attributes, get_same_padding, make_divisible, - param_num, set_submodule, standardize_model_args, ) diff --git a/tox.ini b/tox.ini index ee7acf0297..079e9a2ade 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,9 @@ [tox] envlist= pre-commit-all - py312-torch28-tf_latest-unit - py312-cuda12-gpu + py312-torch210-tf_latest-unit + cuda13-gpu + cuda13-gpu-megatron skipsdist = True toxworkdir = /tmp/{env:USER}-modelopt-tox @@ -10,13 +11,14 @@ toxworkdir = /tmp/{env:USER}-modelopt-tox ############################ # CPU Unit test environments ############################ -[testenv:{py310,py311,py312}-torch{26,27,28,29}-tf_{min,latest}-unit] +[testenv:{py310,py311,py312}-torch{26,27,28,29,210}-tf_{min,latest}-unit] deps = # torch version auto-selected based on torchvision version torch26: torchvision~=0.21.0 torch27: torchvision~=0.22.0 torch28: torchvision~=0.23.0 torch29: torchvision~=0.24.0 + torch210: torchvision~=0.25.0 # Install megatron-core for special unit tests megatron-core @@ -36,8 +38,8 @@ commands = allowlist_externals = bash, rm deps = - # Make sure torch 2.9 is used - torchvision~=0.24.0 + # Make sure torch 2.10 is used + torchvision~=0.25.0 # ONNX unit tests heavily rely on torch / torchvision onnx: .[onnx,dev-test] @@ -57,26 +59,28 @@ commands = ########################################################### # GPU test environments (Should be used with --current-env) ########################################################### -[testenv:{py310,py311,py312}-cuda12-gpu] +[testenv:cuda13-gpu] commands_pre = # Install deps here so that it gets installed even in --current-env - pip install -U megatron-core - pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git - - # Skip triton because pytorch-triton is installed in the NGC PyTorch containers - pip install pip-mark-installed - pip-mark-installed triton - pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git + pip install --no-build-isolation git+https://github.com/Dao-AILab/fast-hadamard-transform.git + pip install -e .[all,dev-test] - # Install Eagle-3 test dependencies - pip install tiktoken blobfile sentencepiece + # Install cupy-cuda13x for INT4 ONNX quantization (default is cupy-cuda12x) + pip uninstall -y cupy-cuda12x + pip install cupy-cuda13x +commands = + # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" + python -m pytest tests/gpu - # NOTE: User is expected to have correct torch-cuda version pre-installed if using --current-env - # to avoid possible CUDA version mismatch +[testenv:cuda13-gpu-megatron] +commands_pre = + # Install deps here so that it gets installed even in --current-env + pip install -U megatron-core + pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git pip install -e .[all,dev-test] commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" - python -m pytest tests/gpu + python -m pytest tests/gpu_megatron ############################################# # Code quality checks on all files or on diff