diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py index 5aae52ef3e8..e819aaf5159 100644 --- a/backends/openvino/quantizer/__init__.py +++ b/backends/openvino/quantizer/__init__.py @@ -1,3 +1,9 @@ +from .llm_compression import apply_nncf_data_aware_compression from .quantizer import OpenVINOQuantizer, QuantizationMode, quantize_model -__all__ = ["OpenVINOQuantizer", "quantize_model", "QuantizationMode"] +__all__ = [ + "OpenVINOQuantizer", + "quantize_model", + "QuantizationMode", + "apply_nncf_data_aware_compression", +] diff --git a/backends/openvino/quantizer/llm_compression.py b/backends/openvino/quantizer/llm_compression.py new file mode 100644 index 00000000000..1737f638bf9 --- /dev/null +++ b/backends/openvino/quantizer/llm_compression.py @@ -0,0 +1,133 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. + +# mypy: disable-error-code=import-not-found + +from typing import Tuple + +import torch +from executorch.extension.llm.export.builder import LLMEdgeManager +from torchao.quantization.pt2e.quantizer import Quantizer + +try: + import nncf # type: ignore[import-untyped] + from pytorch_tokenizers import get_tokenizer # type: ignore[import-untyped] +except ImportError: + raise ImportError("Please install nncf via backends/openvino/requirements.txt") + + +# This code is adapted from https://github.com/pytorch/executorch/blob/0c54fd0483314da173f8e14d63d2ed9591c7133a/extension/llm/export/builder.py#L278 +def get_calibration_data( + module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int +): + """ + This method is used to obtain calibration data from a prompt so that the algorithm + is calibrated not only with the dataset but also the inputs which are output by + the model. + Currently, this method is only tested with Llama models. + """ + # TODO: change criteria & support batch inputs if necessary + pos = 0 + token_list = tokenizer.encode(prompts, bos=True, eos=False) + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_len: + logits = module( + torch.full((1, 1), token_list[pos]), + {"input_pos": torch.tensor((pos,))}, + ) + pos += 1 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:], dim=-1).item()) + token_list = [ + ( + torch.tensor(pos, dtype=torch.int64), + token, + ) + for pos, token in enumerate(token_list) + ] + return token_list + + +def transform_fn(token_pos_map: Tuple[int, int]): + """ + Transforms and returns input from dataset so that it is acceptable by the model + Currently, this method is only tested with Llama models. + + :param token_pos_map: This input contains the position and its token ID + """ + inputs = ( + torch.tensor([[token_pos_map[1]]]), + {"input_pos": torch.tensor([token_pos_map[0]])}, + ) + + return inputs + + +def apply_nncf_data_aware_compression( + builder_exported: LLMEdgeManager, + quantizer: Quantizer, + awq: bool, + scale_estimation: bool, +) -> LLMEdgeManager: + """ + Applies NNCF data-aware weight compression to the exported LLM graph. + Uses the builder's tokenizer and calibration prompt to generate token-level + calibration data, then runs `nncf.experimental.torch.fx.compress_pt2e` with + the given quantizer and optional AWQ / scale estimation enabled. + + :param builder_exported: LLMEdgeManager containing the FX graph, tokenizer path, + calibration prompt, and max sequence length. + :param quantizer: TorchAO quantizer to use for compression. + :param awq: If True, enables Activation-aware Weights Quantization (AWQ). + :param scale_estimation: If True, enables NNCF's scale estimation algorithm. + :return: The updated LLMEdgeManager with compressed torch FX model + """ + nncf_calibration_data = None + if ( + builder_exported.calibration_seq_length is not None + and builder_exported.calibration_data is not None + and builder_exported.tokenizer_path is not None + and (awq or scale_estimation) + ): + tokenizer = get_tokenizer(builder_exported.tokenizer_path) + nncf_calibration_data = nncf.Dataset( + get_calibration_data( + builder_exported.pre_autograd_graph_module, # type: ignore[arg-type] + tokenizer, + builder_exported.calibration_data, + builder_exported.calibration_seq_length, + ), + transform_func=transform_fn, + ) + + # AWQ can work without a dataset as well. + if scale_estimation and not nncf_calibration_data: + missing_params = [] + if builder_exported.calibration_data is None: + missing_params.append("calibration_data") + if builder_exported.calibration_seq_length is None: + missing_params.append("calibration_seq_length") + if builder_exported.tokenizer_path is None: + missing_params.append("tokenizer_path") + if missing_params: + msg = ( + "Missing required calibration parameter(s): " + + ", ".join(missing_params) + + ". Please provide calibration_data, calibration_seq_length, and tokenizer_path." + ) + raise ValueError(msg) + + builder_exported.pre_autograd_graph_module = ( + nncf.experimental.torch.fx.compress_pt2e( + builder_exported.pre_autograd_graph_module, + quantizer=quantizer, + dataset=nncf_calibration_data, + awq=awq, + scale_estimation=scale_estimation, + ) + ) + return builder_exported diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index 4a46b310cf5..2ffc80ec219 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -13,7 +13,6 @@ import nncf # type: ignore[import-untyped] import nncf.common.quantization as quantization # type: ignore[import-untyped] import nncf.experimental.torch.fx as nncf_fx # type: ignore[import-untyped] - import torch.fx from executorch.backends.openvino.quantizer.observers import ( INT4WeightObserver, @@ -78,12 +77,12 @@ class OpenVINOQuantizer(Quantizer): optimally for the inference via OpenVINO. """ - WEIGHTS_ONLY_COMPRESSION_MODES = ( - QuantizationMode.INT4WO_SYM, - QuantizationMode.INT4WO_ASYM, - QuantizationMode.INT8WO_SYM, - QuantizationMode.INT8WO_ASYM, - ) + WEIGHTS_ONLY_COMPRESSION_MODES = { + QuantizationMode.INT4WO_SYM: "int4_sym", + QuantizationMode.INT4WO_ASYM: "int4_asym", + QuantizationMode.INT8WO_SYM: "int8_sym", + QuantizationMode.INT8WO_ASYM: "int8_asym", + } def __init__( self, @@ -116,17 +115,63 @@ def __init__( preset=preset, model_type=model_type, **kwargs ) else: - compression_mode = mode.value.replace( - "wo", "" - ) # Mode value has to match NNCF CompressWeightsMode + compression_mode = OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES[ + mode + ] # Mode value has to match NNCF CompressWeightsMode weight_compression_configuration = get_weight_compression_configuration( nncf.CompressWeightsMode(compression_mode), **kwargs, ) - subset_size = 1 # Doesn't really matter in this case since it is data-free. Should just be +ve + weight_compression_configuration["subset_size"] = ( + 1 # Doesn't really matter in this case since it is data-free. Should just be +ve + ) + self._algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression( - subset_size=subset_size, **weight_compression_configuration + **weight_compression_configuration + ) + + def _require_wc_algo( + self, + ) -> nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression: + if not isinstance( + self._algo, + nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression, + ): + raise TypeError( + "This method requires WeightCompression algo, but " + f"got {type(self._algo).__name__} (mode={self.mode})." ) + return self._algo + + def _require_ptq_algo(self) -> MinMaxQuantization: + if not isinstance(self._algo, MinMaxQuantization): + raise TypeError( + "This method requires MinMaxQuantization algo, but " + f"got {type(self._algo).__name__} (mode={self.mode})." + ) + return self._algo + + def get_weights_compression_config(self) -> Dict[str, Any]: + """ + Returns a dictionary with all_layers, group_size, backup_mode and Quantization mode parameters + used by the compress_pt2e weight compression algorithm. + + :return: A dictionary containing: + 1. mode: Quantization mode. One of INT4 Sym, INT4 Asym, INT8 Sym, INT8 Asym. + 2. group_size: group size to be used for group-wise compression. + 3. all_layers: Indicates whether embeddings and last MatMul layers should be compressed to a primary + precision. By default, the backup precision is assigned for the embeddings and last MatMul layers. + 4. backup_mode: Defines a backup mode for mixed-precision weight compression. + """ + algo = self._require_wc_algo() + quantizer_initialized_algo_attributes = { + "mode": algo.mode, + "group_size": algo.group_size, + "all_layers": algo.all_layers, + "backup_mode": algo.backup_mode, + } + + return quantizer_initialized_algo_attributes def set_ignored_scope( self, @@ -160,8 +205,32 @@ def set_ignored_scope( def get_nncf_quantization_setup( self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup: - self._algo._set_backend_entity(model) - return self._algo.find_quantization_setup(model, nncf_graph) + algo = self._require_ptq_algo() + algo._set_backend_entity(model) + return algo.find_quantization_setup(model, nncf_graph) + + def get_nncf_weight_compression_parameters( + self, + model: torch.fx.GraphModule, + nncf_graph: NNCFGraph, + ) -> Tuple[ + List[WeightCompressionParameters], + List[WeightCompressionParameters], + List[WeightCompressionParameters], + ]: + """ + Collect weight compression parameters for the given FX model and NNCF graph. + + :param model: FX GraphModule to analyze for weight compression. + :param nncf_graph: NNCFGraph representation of the model. + :return: A tuple of: + - all parameters eligible for weight compression, + - ratio-defining parameters used to set primary/backup precisions, + - parameters that are not compressible and remain in original precision. + """ + algo = self._require_wc_algo() + algo.set_backend_entity(model) + return algo.get_weight_compression_parameters(model, nncf_graph) def _annotate_weight_compression( self, @@ -182,12 +251,17 @@ def _annotate_weight_compression( :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. :return: Updated mapping of FX nodes with weight compression annotations. """ - self._algo.set_backend_entity(model) - all_wc_params, _ = self._algo.get_weight_compression_parameters( + all_wc_params, *_ = self.get_nncf_weight_compression_parameters( model, nncf_graph ) for wc_param in all_wc_params: + if not wc_param.compression_config: + nncf_logger.debug( + "Skipping weight compression for node '%s' because compression_config is missing.", + getattr(wc_param.node_with_weight, "node_name", ""), + ) + continue node_with_weight = wc_param.node_with_weight target_node = nncf_fx.node_utils.get_graph_node_by_name( graph, node_with_weight.node_name diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt index 88ae5f9546b..ba338416583 100644 --- a/backends/openvino/requirements.txt +++ b/backends/openvino/requirements.txt @@ -1 +1 @@ -git+https://github.com/openvinotoolkit/nncf@3d753ac#egg=nncf +nncf==3.0.0 diff --git a/backends/openvino/tests/README.md b/backends/openvino/tests/README.md index 0aad14e04a0..b5624a8ca15 100644 --- a/backends/openvino/tests/README.md +++ b/backends/openvino/tests/README.md @@ -11,6 +11,8 @@ backends/openvino/tests └── test_.py # Individual op tests scripts. ├── models # Directory with model test scripts. └── test_classification.py # Test script for classification models. +├── quantizer # Directory with quantizer test scripts. + └── test_llm_compression.py # Test script for llm compression using NNCF algorithms. ├── README.md # Documentation for unit tests (this file) └── test_runner.py # Script to execute unit tests. ``` @@ -31,6 +33,7 @@ Before you begin, refer to instructions provided in [OpenVINO Backend for ExecuT Supported values: - `ops` (default) - `models` + - `quantizer` - **`--pattern`** (optional): Pattern to match test files. Provide complete file name to run individual tests. The default value is `test_*.py` diff --git a/backends/openvino/tests/quantizer/synthetic_test_models.py b/backends/openvino/tests/quantizer/synthetic_test_models.py new file mode 100644 index 00000000000..6c7e91c5539 --- /dev/null +++ b/backends/openvino/tests/quantizer/synthetic_test_models.py @@ -0,0 +1,22 @@ +import torch + + +class ExportLlamaTestModel(torch.nn.Module): + def __init__(self, vocab_size=5, hidden_size=2, num_layers=1): + super().__init__() + self.embed = torch.nn.Embedding(vocab_size, hidden_size) + self.layers = torch.nn.ModuleList( + [torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)] + ) + self.lm_head = torch.nn.Linear(hidden_size, vocab_size) + self.vocab_size = vocab_size + + def forward(self, tokens, input_pos): + x = self.embed(tokens) + + for layer in self.layers: + x = torch.relu(layer(x)) + + logits = self.lm_head(x) + + return logits diff --git a/backends/openvino/tests/quantizer/test_llm_compression.py b/backends/openvino/tests/quantizer/test_llm_compression.py new file mode 100644 index 00000000000..6dfef1fb600 --- /dev/null +++ b/backends/openvino/tests/quantizer/test_llm_compression.py @@ -0,0 +1,249 @@ +import unittest +from unittest.mock import Mock, patch + +import torch +from executorch.backends.openvino.quantizer import OpenVINOQuantizer, QuantizationMode + +from executorch.backends.openvino.quantizer.llm_compression import ( + apply_nncf_data_aware_compression, + get_calibration_data, + transform_fn, +) +from executorch.extension.llm.export.builder import LLMEdgeManager +from synthetic_test_models import ExportLlamaTestModel # type: ignore[import-not-found] + + +class TestWeightsOnlyQuantization(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(42) + cls.model = ExportLlamaTestModel(vocab_size=5, hidden_size=2, num_layers=1) + cls.model.eval() + + cls.max_seq_len = 128 + cls.example_inputs = ( + torch.tensor([[1]], dtype=torch.long), + {"input_pos": torch.tensor([0], dtype=torch.long)}, + ) + + cls.compression_configs = [ + { + "name": "awq_only", + "awq": True, + "scale_estimation": False, + }, + { + "name": "scale_estimation_only", + "awq": False, + "scale_estimation": True, + }, + { + "name": "awq_and_scale_estimation", + "awq": True, + "scale_estimation": True, + }, + { + "name": "no_calibration", + "awq": False, + "scale_estimation": False, + }, + ] + + cls.calibration_data = "The quick brown fox jumps over the lazy dog." + + cls.reference_scales = { + "awq_only": { + "symmetric_weights_decompressor_embed_weight._scale": torch.tensor( + [[-0.042084], [-0.029312], [0.140381], [-0.276123], [-0.057709]], + dtype=torch.float16, + ), + "symmetric_weights_decompressor_layers_0_weight._scale": torch.tensor( + [[0.040710], [-0.058624]], dtype=torch.float16 + ), + "relu/awq_mul._scale_value": torch.tensor([[[1.0, 1.0]]]), + "symmetric_weights_decompressor_lm_head_weight_updated_constant0._scale": torch.tensor( + [[0.053131], [0.087280], [-0.079834], [-0.068237], [-0.054626]], + dtype=torch.float16, + ), + }, + "scale_estimation_only": { + "symmetric_weights_decompressor_embed_weight._scale": torch.tensor( + [[-0.042084], [-0.029312], [0.140381], [-0.276123], [-0.057709]], + dtype=torch.float16, + ), + "symmetric_weights_decompressor_layers_0_weight._scale": torch.tensor( + [[0.040710], [-0.057709]], dtype=torch.float16 + ), + "symmetric_weights_decompressor_lm_head_weight._scale": torch.tensor( + [[0.0], [0.0], [-0.0], [-0.0], [-0.0]], dtype=torch.float16 + ), + }, + "awq_and_scale_estimation": { + "symmetric_weights_decompressor_embed_weight._scale": torch.tensor( + [[-0.042084], [-0.029312], [0.140381], [-0.276123], [-0.057709]], + dtype=torch.float16, + ), + "symmetric_weights_decompressor_layers_0_weight._scale": torch.tensor( + [[0.040710], [-0.057709]], dtype=torch.float16 + ), + "relu/awq_mul._scale_value": torch.tensor([[[1.0, 1.0]]]), + "symmetric_weights_decompressor_lm_head_weight_updated_constant0._scale": torch.tensor( + [[0.0], [0.0], [-0.0], [-0.0], [-0.0]], dtype=torch.float16 + ), + }, + "no_calibration": { + "symmetric_weights_decompressor_embed_weight._scale": torch.tensor( + [[-0.042084], [-0.029312], [0.140381], [-0.276123], [-0.057709]], + dtype=torch.float16, + ), + "symmetric_weights_decompressor_layers_0_weight._scale": torch.tensor( + [[0.040710], [-0.058624]], dtype=torch.float16 + ), + "symmetric_weights_decompressor_lm_head_weight._scale": torch.tensor( + [[0.053131], [0.087280], [-0.079834], [-0.068237], [-0.054626]], + dtype=torch.float16, + ), + }, + } + + def _create_builder(self, config_name, calibration_data=None): + builder_kwargs = { + "model": self.model, + "modelname": f"tinyllama_{config_name}", + "max_seq_len": self.max_seq_len, + "use_kv_cache": True, + "example_inputs": self.example_inputs, + "example_kwarg_inputs": None, + } + + if calibration_data: + builder_kwargs.update( + { + "calibration_seq_length": 32, + "calibration_data": calibration_data, + "tokenizer_path": "dummy_path", + } + ) + + return LLMEdgeManager(**builder_kwargs) + + def _extract_scales_from_model(self, model): + extracted_scales = {} + state_dict = dict(model.state_dict()) + for name, _ in state_dict.items(): + if "_scale" in name.lower(): + extracted_scales[name] = state_dict[name] + return extracted_scales + + def _compare_scales(self, extracted_scales, reference_scales): + for name, reference_value in reference_scales.items(): + self.assertIn(name, extracted_scales, f"Scale {name} not found in model") + extracted_value = extracted_scales[name] + self.assertTrue( + torch.allclose(extracted_value, reference_value), + f"Scale {name} mismatch {extracted_value}", + ) + + @patch("executorch.backends.openvino.quantizer.llm_compression.get_tokenizer") + @patch( + "executorch.backends.openvino.quantizer.llm_compression.get_calibration_data" + ) + def test_compression_flow_with_mocked_calibration( + self, mock_get_calibration_data, mock_get_tokenizer + ): + mock_calibration_data = [(i, i) for i in range(5)] + mock_get_calibration_data.return_value = mock_calibration_data + + mock_tokenizer = Mock() + mock_get_tokenizer.return_value = mock_tokenizer + + for config in self.compression_configs: + with self.subTest(phase="compression_config", config=config["name"]): + calibration_data = ( + self.calibration_data + if config["awq"] or config["scale_estimation"] + else None + ) + + builder = self._create_builder( + config["name"], calibration_data=calibration_data + ) + builder.export() + + test_input = torch.tensor([[4]], dtype=torch.long) + test_pos = torch.tensor([0], dtype=torch.long) + # Quantize weights for all layers(including embedding and lm_head which would by default be in INT8) + # to Per-Channel INT4 Symmetric + quantizer = OpenVINOQuantizer( + mode=QuantizationMode.INT4WO_SYM, group_size=-1, all_layers=True + ) + builder = apply_nncf_data_aware_compression( + builder, + quantizer=quantizer, + awq=config["awq"], + scale_estimation=config["scale_estimation"], + ) + # Run the model to check it is performant + builder.pre_autograd_graph_module(test_input, {"input_pos": test_pos}) + extracted_scales = self._extract_scales_from_model( + builder.pre_autograd_graph_module + ) + self._compare_scales( + extracted_scales, + self.reference_scales[config["name"]], + ) + + def test_scale_estimation_requires_calibration_params(self): + builder = self._create_builder( + "missing_calibration_data", calibration_data=None + ) + builder.export() + + quantizer = OpenVINOQuantizer( + mode=QuantizationMode.INT4WO_SYM, group_size=-1, all_layers=True + ) + + with self.assertRaises(ValueError) as cm: + apply_nncf_data_aware_compression( + builder, + quantizer=quantizer, + awq=False, + scale_estimation=True, + ) + + err = str(cm.exception) + self.assertIn("Missing required calibration parameter(s)", err) + self.assertIn("calibration_data", err) + self.assertIn("calibration_seq_length", err) + self.assertIn("tokenizer_path", err) + + +class TestCalibrationDataGeneration(unittest.TestCase): + + def test_get_calibration_data_with_mock_module(self): + mock_tokenizer = Mock() + mock_tokenizer.eos_id = 2 + mock_tokenizer.encode = Mock(return_value=[1, 5, 6]) + + mock_module = Mock() + mock_module.return_value = torch.tensor([[[0.1, 0.2, 0.9, 0.0]]]) + + result = get_calibration_data( + mock_module, mock_tokenizer, "test prompt", max_len=10 + ) + + positions = [item[0] for item in result] + self.assertEqual(positions, list(range(len(positions)))) + + def test_transform_fn(self): + token_pos_map = (5, 10) + result = transform_fn(token_pos_map) + + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + + token, input_pos_dict = result + self.assertEqual(token.shape, torch.Size([1, 1])) + self.assertEqual(token, torch.tensor([[10]])) + self.assertIn("input_pos", input_pos_dict) + self.assertEqual(input_pos_dict["input_pos"], torch.tensor([5])) diff --git a/backends/openvino/tests/test_runner.py b/backends/openvino/tests/test_runner.py index 021c372db25..7d8c6b968c2 100644 --- a/backends/openvino/tests/test_runner.py +++ b/backends/openvino/tests/test_runner.py @@ -1,8 +1,6 @@ import argparse import unittest -import nncf.torch # type: ignore[import-untyped,import-not-found] - class OpenvinoTestSuite(unittest.TestSuite): @@ -44,10 +42,10 @@ def parse_arguments(): parser.add_argument( "-t", "--test_type", - help="Specify the type of tests ('ops' or 'models')", + help="Specify the type of tests ('ops', 'models' or 'quantizer')", type=str, default="ops", - choices={"ops", "models"}, + choices={"ops", "models", "quantizer"}, ) args, ns_args = parser.parse_known_args(namespace=unittest) @@ -68,8 +66,7 @@ def parse_arguments(): # Discover all existing op tests in "ops" folder suite = loader.discover(test_params["test_type"], pattern=test_params["pattern"]) # Start running tests - with nncf.torch.disable_patching(): - result = unittest.TextTestRunner().run(suite) + result = unittest.TextTestRunner().run(suite) if result.wasSuccessful(): print("OpenVINO backend tests completed successfully") else: diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index b074e1cab7c..f0b22623b86 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -249,6 +249,18 @@ def build_args_parser() -> argparse.ArgumentParser: help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.", ) + parser.add_argument( + "--openvino_awq", + action="store_true", + help="Whether to use AWQ from NNCF. Applicable only for the OpenVINO backend.", + ) + + parser.add_argument( + "--openvino_scale_estimation", + action="store_true", + help="Whether to use Scale Estimation algorithm from NNCF. Applicable only for the OpenVINO backend", + ) + parser.add_argument( "--use_qnn_sha", action="store_true", @@ -783,7 +795,7 @@ def get_quantizer_and_quant_params(llm_config): ) quantizers.append(qnn_quantizer) if llm_config.backend.openvino.enabled and llm_config.quantization.pt2e_quantize: - assert not quantizers, "Should not enable both xnnpack and openvino" + assert not quantizers, "Should not enable openvino and other quantizers" group_size = llm_config.quantization.group_size group_size = group_size if group_size else 128 ov_quantizer = get_ov_quantizer( @@ -942,6 +954,8 @@ def _to_edge_and_lower_llama_openvino( modelname, quantizers, additional_passes, + awq, + scale_estimation, openvino_device: str = "CPU", verbose: bool = False, ) -> LLMEdgeManager: # noqa: C901 @@ -955,10 +969,15 @@ def _to_edge_and_lower_llama_openvino( for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") - builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( - partitioners + from executorch.backends.openvino.quantizer import apply_nncf_data_aware_compression + + logging.info(f"Applying AWQ = {awq}, Scale Estimation = {scale_estimation}") + builder = apply_nncf_data_aware_compression( + builder_exported, quantizers[0], awq, scale_estimation ) + builder = builder.to_edge_transform_and_lower(partitioners) + if verbose: print_delegation_info(builder.edge_manager.exported_program().graph_module) @@ -1341,6 +1360,8 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 modelname, quantizers, additional_passes, + awq=llm_config.backend.openvino.openvino_awq, + scale_estimation=llm_config.backend.openvino.openvino_scale_estimation, openvino_device=llm_config.backend.openvino.device, verbose=llm_config.debug.verbose, ) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index c4c6ef11cca..84912737b84 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -520,8 +520,9 @@ class OpenvinoConfig: enabled: bool = False device: str = "CPU" - nncf_compression: bool = False nncf_compression_group_size: int = 32 + openvino_awq: bool = False + openvino_scale_estimation: bool = False @dataclass @@ -736,8 +737,12 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 llm_config.backend.openvino.enabled = args.openvino if hasattr(args, "openvino_device"): llm_config.backend.openvino.device = args.openvino_device - if hasattr(args, "nncf_compression"): - llm_config.backend.openvino.nncf_compression = args.nncf_compression + if hasattr(args, "openvino_awq"): + llm_config.backend.openvino.openvino_awq = args.openvino_awq + if hasattr(args, "openvino_scale_estimation"): + llm_config.backend.openvino.openvino_scale_estimation = ( + args.openvino_scale_estimation + ) if hasattr(args, "group_size") and args.group_size: llm_config.backend.openvino.nncf_compression_group_size = args.group_size