diff --git a/backends/openvino/quantizer/__init__.py b/backends/openvino/quantizer/__init__.py index df038483f2f..0fd8c10b249 100644 --- a/backends/openvino/quantizer/__init__.py +++ b/backends/openvino/quantizer/__init__.py @@ -1,3 +1,3 @@ -from .quantizer import OpenVINOQuantizer, quantize_model +from .quantizer import OpenVINOQuantizer, quantize_model, QuantizationMode -__all__ = ["OpenVINOQuantizer", "quantize_model"] +__all__ = ["OpenVINOQuantizer", "quantize_model", "QuantizationMode"] diff --git a/backends/openvino/quantizer/observers.py b/backends/openvino/quantizer/observers.py new file mode 100644 index 00000000000..6cda4561604 --- /dev/null +++ b/backends/openvino/quantizer/observers.py @@ -0,0 +1,186 @@ +# Copyright (c) Intel Corporation +# +# Licensed under the BSD License (the "License"); you may not use this file +# except in compliance with the License. See the license file found in the +# LICENSE file in the root directory of this source tree. + +# mypy: disable-error-code=import-not-found + +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import torch + +from nncf.experimental.torch.fx.node_utils import ( # type: ignore[import-untyped] + get_tensor_constant_from_node, +) +from nncf.experimental.torch.fx.transformations import ( # type: ignore[import-untyped] + constant_update, + module_insertion, + node_removal, +) +from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] + WeightCompressionParameters, +) +from nncf.quantization.algorithms.weight_compression.weight_lowering import ( # type: ignore[import-untyped] + do_integer_quantization, +) +from nncf.tensor.tensor import Tensor as NNCFTensor # type: ignore[import-untyped] +from nncf.torch.graph.transformations.commands import ( # type: ignore[import-untyped] + PTTargetPoint, + TargetType, +) +from nncf.torch.quantization.layers import ( # type: ignore[import-untyped] + BaseWeightsDecompressor, + 
INT4AsymmetricWeightsDecompressor, + INT4SymmetricWeightsDecompressor, + INT8AsymmetricWeightsDecompressor, + INT8SymmetricWeightsDecompressor, +) +from torchao.quantization.pt2e import ObserverBase + + +class WeightObserverBase(ObserverBase, ABC): + """ + Base implementation of an NNCF observer that defines the rules for compressing layer weights into the OpenVINO representation. + """ + + def __init__( + self, + wc_param: WeightCompressionParameters, + dtype: torch.dtype, + **kwargs, + ) -> None: + """ + :param wc_param: Weight compression parameters container. + :param dtype: target dtype for the quantization. + """ + super().__init__(dtype=dtype, is_dynamic=False) + self._wc_param = wc_param + + def calculate_qparams( # type: ignore[override] + self, + weight: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Calculates quantization parameters: quantized weight, quantization scale and quantization zero point. + + :param weight: FP weight to be used for calculating qparams. + :return: A tuple containing the quantized weight, quantization scale and quantization zero point. + """ + wc_param = self._wc_param + wc_config = wc_param.compression_config + reduction_axes = wc_param.reduction_axes + q_weight, scale, zp = do_integer_quantization( + NNCFTensor(weight), wc_config, reduction_axes=reduction_axes + ) + zp = zp.data if zp is not None else None + return q_weight.data, scale.data, zp + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + def convert( + self, model: torch.fx.GraphModule, observer_node: torch.fx.Node + ) -> None: + """ + Replaces the given observer node from the given model with a quantized + weight and a OpenVINO specific decompression module. + + :param model: A `torch.fx.GraphModule` representing the statically traced model + with observer nodes attached and calibrated. 
+ :param observer_node: The `torch.fx.Node` corresponding to the observer module for + the weight that is being transformed into a compressed representation. + """ + weight_node = observer_node.args[0] + original_weight = get_tensor_constant_from_node(weight_node, model) + q_weight, scale, zero_point = self.calculate_qparams(original_weight) + + decompressor = self._create_decompressor( + scale, zero_point, q_weight, original_weight + ) + packed_q_weight = decompressor.pack_weight(q_weight) + + # Weight port id is 0 since observer is inserted for a single weight only. + constant_update(model, observer_node, packed_q_weight, input_port_id=0) + + compressed_weight_name = observer_node.all_input_nodes[0].name + decompressor_suffix = "_".join( + compressed_weight_name.replace(".", "_").split("_")[:-2] + ) + decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}" + + module_insertion( + model, + decompressor, + [ + PTTargetPoint( + TargetType.OPERATOR_POST_HOOK, + target_node_name=compressed_weight_name, + ) + ], + decompressor_name, + ) + node_removal(model, observer_node, 0) + + @abstractmethod + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + """ + Returns a respective NNCF decompressor for different types of quantization. + + :param scale: Calculated scale quantization parameter. + :param zero_point: Calculated zero_point quantization parameter. + :param q_weight: Calculated quantized weight. + :param original_weight: FP weight. + :return: NNCF observer according to the qmode which creates the decompression subgraph supported by OpenVINO. + """ + + +class INT4WeightObserver(WeightObserverBase): + """ + OpenVINO INT4 Weight Compression observer. 
+ """ + + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + if zero_point is None: + return INT4SymmetricWeightsDecompressor( + scale, q_weight.shape, original_weight.shape, original_weight.dtype + ) + return INT4AsymmetricWeightsDecompressor( + scale, + zero_point, + q_weight.shape, + original_weight.shape, + original_weight.dtype, + ) + + +class INT8WeightObserver(WeightObserverBase): + """ + OpenVINO INT8 Weight Compression per channel observer. + """ + + def _create_decompressor( + self, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + q_weight: torch.Tensor, + original_weight: torch.Tensor, + ) -> BaseWeightsDecompressor: + if zero_point is None: + return INT8SymmetricWeightsDecompressor(scale, original_weight.dtype) + return INT8AsymmetricWeightsDecompressor( + scale, zero_point, original_weight.dtype + ) diff --git a/backends/openvino/quantizer/observers/nncf_observers.py b/backends/openvino/quantizer/observers/nncf_observers.py deleted file mode 100644 index ac95b1bbef5..00000000000 --- a/backends/openvino/quantizer/observers/nncf_observers.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -from nncf.experimental.torch.fx.node_utils import ( # type: ignore[import-untyped] - get_tensor_constant_from_node, -) -from nncf.experimental.torch.fx.transformations import ( # type: ignore[import-untyped] - constant_update_fn, - module_insertion_transformation_builder, -) -from nncf.parameters import CompressWeightsMode # type: ignore[import-untyped] -from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] - WeightCompressionConfig, -) - -from nncf.quantization.algorithms.weight_compression.weight_lowering import ( # type: ignore[import-untyped] - do_integer_quantization, -) -from nncf.tensor.tensor import Tensor # type: ignore[import-untyped] -from nncf.torch.graph.transformations.commands import ( # type: ignore[import-untyped] - PTTargetPoint, - TargetType, -) -from nncf.torch.quantization.layers import ( # type: ignore[import-untyped] - INT4AsymmetricWeightsDecompressor, - INT4SymmetricWeightsDecompressor, - INT8AsymmetricWeightsDecompressor, - INT8SymmetricWeightsDecompressor, -) -from torchao.quantization.observer import AffineQuantizedMinMaxObserver -from torchao.quantization.pt2e import ( - get_block_size, - MappingType, - PerAxis, - PerChannelMinMaxObserver, - PerGroup, -) -from torchao.quantization.quant_primitives import _get_reduction_params - - -class PTPerBlockParamObserver(AffineQuantizedMinMaxObserver): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - qmode = ( - CompressWeightsMode.INT4_ASYM - if self.mapping_type == MappingType.ASYMMETRIC - else CompressWeightsMode.INT4_SYM - ) - assert isinstance( - self.granularity, PerGroup - ), "Only PerGroup granularity is supported" - self.wc_config = WeightCompressionConfig( - mode=qmode, group_size=self.granularity.group_size - ) - - def calculate_qparams(self, weight): - assert hasattr(self, "min_val") and hasattr( - self, "max_val" - ), "Expecting the observer has min_val and max_val, please run the observer 
before calling calculate_qparams" - _, reduction_dims = _get_reduction_params(self.block_size, weight.size()) - assert len(reduction_dims) == 1, "Only 1-D group size is supported" - reduction_dims = reduction_dims[0] - 1 - q_weight, scale, zp = do_integer_quantization( - Tensor(weight), self.wc_config, reduction_axes=reduction_dims - ) - zp = zp.data if zp is not None else None - return q_weight.data, scale.data, zp - - def convert(self, model: torch.fx.GraphModule, observer_node: torch.fx.Node): - print("calling convert") - assert ( - self.original_dtype is not None - ), "Expecting original_dtype to be populated" - weight_node = observer_node.args[0] - original_weight = get_tensor_constant_from_node(weight_node, model) - q_weight, scale, zero_point = self.calculate_qparams(original_weight) - - with model.graph.inserting_before(observer_node): - if zero_point is not None: - decompressor = INT4AsymmetricWeightsDecompressor( - scale, - zero_point, - q_weight.shape, - original_weight.shape, - original_weight.dtype, - ) - else: - decompressor = INT4SymmetricWeightsDecompressor( - scale, q_weight.shape, original_weight.shape, original_weight.dtype - ) - packed_q_weight = decompressor.pack_weight(q_weight) - constant_update_fn(model, observer_node, packed_q_weight, input_port_id=0) - compressed_weight_name = observer_node.all_input_nodes[0].name - decompressor_suffix = "_".join( - compressed_weight_name.replace(".", "_").split("_")[:-2] - ) - decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}" - - module_insertion_transformation_builder( - decompressor, - [ - PTTargetPoint( - TargetType.OPERATOR_POST_HOOK, - target_node_name=compressed_weight_name, - ) - ], - decompressor_name, - )(model) - decomp_node = observer_node.args[0] - observer_node.replace_all_uses_with(decomp_node) # type: ignore[arg-type] - model.graph.erase_node(observer_node) - - -class NNCFInt8observer(PerChannelMinMaxObserver): - def __init__(self, *args, 
**kwargs): - super().__init__(*args, **kwargs) - qmode = ( - CompressWeightsMode.INT8_SYM - if self.qscheme == torch.per_channel_symmetric - else CompressWeightsMode.INT8_ASYM - ) - self.wc_config = WeightCompressionConfig(mode=qmode) - - def calculate_qparams(self, weight): - assert hasattr(self, "min_val") and hasattr( - self, "max_val" - ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" - self.granularity = PerAxis(axis=self.ch_axis) - self.block_size = get_block_size(weight.shape, self.granularity) - _, reduction_dims = _get_reduction_params(self.block_size, weight.size()) - q_weight, scale, zp = do_integer_quantization( - Tensor(weight), self.wc_config, reduction_axes=reduction_dims - ) - zp = zp.data if zp is not None else None - return q_weight.data, scale.data, zp - - def convert(self, model: torch.fx.GraphModule, observer_node: torch.fx.Node): - print("calling convert") - weight_node = observer_node.args[0] - original_weight = get_tensor_constant_from_node(weight_node, model) - q_weight, scale, zero_point = self.calculate_qparams(original_weight) - - with model.graph.inserting_before(observer_node): - if zero_point is not None: - decompressor = INT8AsymmetricWeightsDecompressor( - scale, zero_point, original_weight.dtype - ) - else: - decompressor = INT8SymmetricWeightsDecompressor( - scale, original_weight.dtype - ) - packed_q_weight = decompressor.pack_weight(q_weight) - constant_update_fn(model, observer_node, packed_q_weight, input_port_id=0) - compressed_weight_name = observer_node.all_input_nodes[0].name - decompressor_suffix = "_".join( - compressed_weight_name.replace(".", "_").split("_")[:-2] - ) - decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}" - - module_insertion_transformation_builder( - decompressor, - [ - PTTargetPoint( - TargetType.OPERATOR_POST_HOOK, - target_node_name=compressed_weight_name, - ) - ], - decompressor_name, - 
)(model) - decomp_node = observer_node.args[0] - observer_node.replace_all_uses_with(decomp_node) # type: ignore[arg-type] - model.graph.erase_node(observer_node) diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index f2011431a03..bef1ef3274f 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -15,24 +15,20 @@ import nncf.experimental.torch.fx as nncf_fx # type: ignore[import-untyped] import torch.fx -from executorch.backends.openvino.quantizer.observers.nncf_observers import ( - NNCFInt8observer, - PTPerBlockParamObserver, +from executorch.backends.openvino.quantizer.observers import ( + INT4WeightObserver, + INT8WeightObserver, ) - from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] -from nncf.common.quantization.structs import ( # type: ignore[import-untyped] - QuantizationScheme, - QuantizerConfig, +from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] + WeightCompressionParameters, ) from nncf.quantization.quantize_model import ( # type: ignore[import-untyped] get_weight_compression_configuration, ) from torchao.quantization.pt2e import ( HistogramObserver, - MappingType, PerChannelMinMaxObserver, - PerGroup, UniformQuantizationObserverBase, ) from torchao.quantization.pt2e.quantizer import ( @@ -45,7 +41,6 @@ ) QUANT_ANNOTATION_KEY = "quantization_annotation" -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY class QuantizationMode(Enum): @@ -55,15 +50,19 @@ class QuantizationMode(Enum): - INT8_SYM: INT8 symmetric quantization for both activations and weights. - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights. - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models + - INT8WO_SYM: INT8 symmetric quantization for weights only. + - INT8WO_ASYM: INT8 asymmetric quantization for weights only. 
+ - INT4WO_SYM: INT4 symmetric quantization for weights only. + - INT4WO_ASYM: INT4 asymmetric quantization for weights only """ INT8_SYM = "int8_sym" INT8_MIXED = "int8_mixed" INT8_TRANSFORMER = "int8_transformer" - INT8_SYM_WC = "int8_sym_wc" - INT8_ASYM_WC = "int8_asym_wc" - INT4_SYM_WC = "int4_sym" - INT4_ASYM_WC = "int4_asym" + INT8WO_SYM = "int8wo_sym" + INT8WO_ASYM = "int8wo_asym" + INT4WO_SYM = "int4wo_sym" + INT4WO_ASYM = "int4wo_asym" class OpenVINOQuantizer(Quantizer): @@ -72,6 +71,13 @@ class OpenVINOQuantizer(Quantizer): optimally for the inference via OpenVINO. """ + WEIGHTS_ONLY_COMPRESSION_MODES = ( + QuantizationMode.INT4WO_SYM, + QuantizationMode.INT4WO_ASYM, + QuantizationMode.INT8WO_SYM, + QuantizationMode.INT8WO_ASYM, + ) + def __init__( self, *, @@ -89,39 +95,32 @@ def __init__( :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm. """ self.mode = mode - self.wc_modes = [ - QuantizationMode.INT4_ASYM_WC, - QuantizationMode.INT4_SYM_WC, - QuantizationMode.INT8_ASYM_WC, - QuantizationMode.INT8_SYM_WC, - ] - if mode == QuantizationMode.INT8_SYM: - preset = quantization.structs.QuantizationPreset.PERFORMANCE - model_type = None - elif mode == QuantizationMode.INT8_MIXED: - preset = quantization.structs.QuantizationPreset.MIXED - model_type = None - else: - preset = None - model_type = nncf.parameters.ModelType.TRANSFORMER - if self.mode not in self.wc_modes: - self._min_max_algo = ( + if self.mode not in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES: + if mode == QuantizationMode.INT8_SYM: + preset = quantization.structs.QuantizationPreset.PERFORMANCE + model_type = None + elif mode == QuantizationMode.INT8_MIXED: + preset = quantization.structs.QuantizationPreset.MIXED + model_type = None + else: + preset = None + model_type = nncf.parameters.ModelType.TRANSFORMER + self._algo = ( nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization( preset=preset, model_type=model_type, **kwargs ) ) - self._algo = 
self._min_max_algo else: weight_compression_configuration = get_weight_compression_configuration( mode.value.replace( - "_wc", "" + "wo", "" ), # Mode value has to match NNCF CompressWeightsMode **kwargs, ) - self._weight_compression_algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression( - subset_size=None, **weight_compression_configuration + subset_size = 1 # Doesn't really matter in this case since it is data-free. Should just be +ve + self._algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression( + subset_size=subset_size, **weight_compression_configuration ) - self._algo = self._weight_compression_algo def set_ignored_scope( self, @@ -158,104 +157,127 @@ def get_nncf_quantization_setup( self._algo._set_backend_entity(model) return self._algo.find_quantization_setup(model, nncf_graph) - def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) + def _annotate_weight_compression( + self, + model: torch.fx.GraphModule, + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation], + ) -> DefaultDict[torch.fx.Node, QuantizationAnnotation]: + """ + Annotates the model graph with weight-only quantization specs. - graph = model.graph - node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( - defaultdict(QuantizationAnnotation) + Identifies compressible nodes in the NNCF graph and attaches the corresponding + TorchAO quantization specifications to their weight edges for later transformation. + + :param model: The FX GraphModule to annotate. + :param graph: The underlying FX graph. + :param nncf_graph: The corresponding NNCF graph. + :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. + :return: Updated mapping of FX nodes with weight compression annotations. 
+ """ + self._algo.set_backend_entity(model) + all_wc_params, _ = self._algo.get_weight_compression_parameters( + model, nncf_graph ) - # Serperate into annotation for quantize and compress - if self.mode in self.wc_modes: - self._algo.set_backend_entity(model) - nodes_to_compress = self._algo.get_nodes_to_compress(nncf_graph) - for node in nodes_to_compress: - quantization_insertion_point = ( - quantization.quantizer_setup.WeightQuantizationInsertionPoint( - target_node_name=node.node_name - ) - ) - group_size = self._algo._group_size - num_bits = ( - 4 - if self.mode - in [QuantizationMode.INT4_SYM_WC, QuantizationMode.INT4_ASYM_WC] - else 8 - ) - qmode = ( - QuantizationScheme.SYMMETRIC - if self.mode - in [QuantizationMode.INT4_SYM_WC, QuantizationMode.INT8_SYM_WC] - else QuantizationScheme.ASYMMETRIC - ) - nncf_qconfig = QuantizerConfig(num_bits=num_bits, mode=qmode) - qp = quantization.quantizer_setup.SingleConfigQuantizationPoint( - qip=quantization_insertion_point, - qconfig=nncf_qconfig, - directly_quantized_operator_node_names=[node], - ) - edge_or_node, annotation = self._get_edge_or_node_and_annotation( - graph, nncf_graph, qp, node_vs_torch_annotation - ) - qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_nncf_config( - qp, group_size=group_size, weights_only=True + + for wc_param in all_wc_params: + node_with_weight = wc_param.node_with_weight + target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, node_with_weight.node_name + ) + annotation = node_vs_torch_annotation[target_node] + edge_or_node = self._get_weight_edge(target_node, nncf_graph) + qspec = self._get_torch_ao_qspec_from_nncf_config_for_wc(wc_param=wc_param) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + return node_vs_torch_annotation + + def _annotate_post_training_quantization( + self, + model: torch.fx.GraphModule, + graph: torch.fx.Graph, + nncf_graph: NNCFGraph, + node_vs_torch_annotation: DefaultDict[torch.fx.Node, 
QuantizationAnnotation], + ) -> DefaultDict[torch.fx.Node, QuantizationAnnotation]: + """ + Annotates the model graph with post-training quantization configurations. + + :param model: The FX GraphModule to annotate. + :param graph: The underlying FX graph. + :param nncf_graph: The corresponding NNCF graph. + :param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations. + :return: Updated mapping of FX nodes with post-training quantization annotations. + """ + quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) + + for qp in quantization_setup.quantization_points.values(): + edge_or_node, annotation = self._get_edge_or_node_and_annotation( + graph, nncf_graph, qp, node_vs_torch_annotation + ) + qspec: QuantizationSpecBase = ( + self._get_torch_ao_qspec_from_nncf_config_for_ptq(qp) + ) + self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + + for quantizer_ids in quantization_setup.unified_scale_groups.values(): + root_quantizer_id = self._get_unified_scales_root_quantizer_id( + nncf_graph, quantizer_ids, quantization_setup + ) + root_qp = quantization_setup.quantization_points[root_quantizer_id] + + if any( + root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig + for q_id in quantizer_ids + ): + qps = [ + quantization_setup.quantization_points[qid] for qid in quantizer_ids + ] + raise nncf.InternalError( + "Different quantization configs are set to one unified scale group:" + f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" ) - self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) - else: - quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph) - for qp in quantization_setup.quantization_points.values(): + root_target_node = nncf_fx.node_utils.get_graph_node_by_name( + graph, root_qp.insertion_point.target_node_name + ) + root_edge_or_node = self._get_edge_or_node( + root_target_node, root_qp, nncf_graph + ) + + for quantizer_id in quantizer_ids: 
+ if quantizer_id == root_quantizer_id: + continue + + qspec = SharedQuantizationSpec(root_edge_or_node) # type: ignore[assignment] + qp = quantization_setup.quantization_points[quantizer_id] edge_or_node, annotation = self._get_edge_or_node_and_annotation( graph, nncf_graph, qp, node_vs_torch_annotation ) - qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_nncf_config( # type: ignore[no-redef] - qp - ) self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) - for quantizer_ids in quantization_setup.unified_scale_groups.values(): + return node_vs_torch_annotation - root_quantizer_id = self._get_unified_scales_root_quantizer_id( - nncf_graph, quantizer_ids, quantization_setup - ) - root_qp = quantization_setup.quantization_points[root_quantizer_id] - - if any( - root_qp.qconfig - != quantization_setup.quantization_points[q_id].qconfig - for q_id in quantizer_ids - ): - qps = [ - quantization_setup.quantization_points[q_id] - for q_id in quantizer_ids - ] - msg = ( - "Different quantization configs are set to one unified scale group:" - f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}" - ) - raise nncf.InternalError(msg) - - root_target_node = nncf_fx.node_utils.get_graph_node_by_name( - graph, root_qp.insertion_point.target_node_name - ) - root_edge_or_node = self._get_edge_or_node( - root_target_node, root_qp, nncf_graph - ) - - for quantizer_id in quantizer_ids: - if quantizer_id == root_quantizer_id: - continue + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model) + graph = model.graph + node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = ( + defaultdict(QuantizationAnnotation) + ) - qspec = SharedQuantizationSpec(root_edge_or_node) - qp = quantization_setup.quantization_points[quantizer_id] - edge_or_node, annotation = self._get_edge_or_node_and_annotation( - graph, nncf_graph, qp, 
node_vs_torch_annotation - ) - self._fill_torch_ao_annotation(edge_or_node, qspec, annotation) + if self.mode in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES: + node_vs_torch_annotation = self._annotate_weight_compression( + model, graph, nncf_graph, node_vs_torch_annotation + ) + else: + node_vs_torch_annotation = self._annotate_post_training_quantization( + model, graph, nncf_graph, node_vs_torch_annotation + ) for node, annotation in node_vs_torch_annotation.items(): - assert Q_ANNOTATION_KEY not in node.meta - node.meta[Q_ANNOTATION_KEY] = annotation + assert QUANT_ANNOTATION_KEY not in node.meta + node.meta[QUANT_ANNOTATION_KEY] = annotation + return model @staticmethod @@ -317,6 +339,35 @@ def _get_edge_or_node_and_annotation( edge_or_node = OpenVINOQuantizer._get_edge_or_node(target_node, qp, nncf_graph) return edge_or_node, annotation + @staticmethod + def _get_weight_edge( + target_node: torch.fx.Node, + nncf_graph: NNCFGraph, + ) -> tuple[torch.fx.Node, torch.fx.Node]: + """ + Returns the FX node corresponding to the weight tensor input of a given operator node. + Uses the NNCF graph to identify which input port of the target node holds the weight. + If multiple weight ports are present, a warning is issued and only the first one is used. + + :param target_node: FX node representing a weighted operation (e.g., Linear, Conv). + :param nncf_graph: NNCFGraph used to determine weight port indices. + :return: Edge represented by a Tuple of (weight_node, target_node), where weight_node is the FX node supplying the weight. + """ + nncf_node = nncf_graph.get_node_by_name(target_node.name) + weights_ports_ids = nncf.torch.model_graph_manager.get_weight_tensor_port_ids( + nncf_node, nncf_graph + ) + if len(weights_ports_ids) > 1: + # TODO(dlyakhov): support quantization for nodes with several weights + nncf.common.logging.nncf_logger.warning( + f"Quantization of the weighted node {target_node.name}" + " is not yet supported by the OpenVINOQuantizer." 
+            f" Only the weight on port ID {weights_ports_ids[0]} will be quantized."
+            f" Quantizable weights are located on ports: {weights_ports_ids}."
+            )
+        weight_node = target_node.all_input_nodes[weights_ports_ids[0]]
+        return (weight_node, target_node)
+
     @staticmethod
     def _get_edge_or_node(
         target_node: torch.fx.Node,
@@ -333,22 +384,7 @@ def _get_edge_or_node(
         """
         ip = qp.insertion_point
         if qp.is_weight_quantization_point():
-            nncf_node = nncf_graph.get_node_by_name(target_node.name)
-            weights_ports_ids = (
-                nncf.torch.model_graph_manager.get_weight_tensor_port_ids(
-                    nncf_node, nncf_graph
-                )
-            )
-            if len(weights_ports_ids) > 1:
-                # TODO(dlyakhov): support quantization for nodes with several weights
-                nncf.common.logging.nncf_logger.warning(
-                    f"Quantization of the weighted node {target_node.name}"
-                    " is not yet supported by the OpenVINOQuantizer."
-                    f" Only the weight on port ID {weights_ports_ids[0]} will be quantized."
-                    f" Quantizable weights are located on ports: {weights_ports_ids}."
-                )
-            weight_node = target_node.all_input_nodes[weights_ports_ids[0]]
-            return (weight_node, target_node)
+            return OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph)

         if ip.input_port_id is None:
             return target_node
@@ -375,20 +411,72 @@ def _fill_torch_ao_annotation(
             annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec

     @staticmethod
-    def _get_torch_ao_qspec_from_nncf_config(
+    def _get_torch_ao_qspec_from_nncf_config_for_wc(
+        wc_param: WeightCompressionParameters,
+    ) -> QuantizationSpec:
+        """
+        Returns a TorchAO QuantizationSpec based on NNCF weight compression parameter.
+
+        :param wc_param: NNCF Weight compression parameters for the node.
+        :return: A TorchAO QuantizationSpec. 
+ """ + observer: Type[UniformQuantizationObserverBase] + + extra_args: Dict[str, Any] = {} + + qmode = wc_param.compression_config.mode + extra_args["wc_param"] = wc_param + is_asym_mode = wc_param.compression_config.is_asym_mode + if qmode in [ + nncf.CompressWeightsMode.INT4_ASYM, + nncf.CompressWeightsMode.INT4_SYM, + ]: + observer = INT4WeightObserver # type: ignore[type-abstract] + quant_min = -8 if not is_asym_mode else 0 + quant_max = 7 if not is_asym_mode else 15 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = torch_qscheme = ( + torch.per_channel_symmetric + if not is_asym_mode + else torch.per_channel_affine + ) + else: + observer = INT8WeightObserver # type: ignore[type-abstract] + quant_min = -128 if not is_asym_mode else 0 + quant_max = 127 if not is_asym_mode else 255 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = ( + torch.per_channel_symmetric + if not is_asym_mode + else torch.per_channel_affine + ) + return QuantizationSpec( + dtype=dtype, + observer_or_fake_quant_ctr=observer.with_args(**extra_args), + quant_min=quant_min, + quant_max=quant_max, + qscheme=torch_qscheme, + ch_axis=channel_axis, + is_dynamic=False, + ) + + @staticmethod + def _get_torch_ao_qspec_from_nncf_config_for_ptq( qp: quantization.quantizer_setup.QuantizationPointBase, - group_size=-1, - weights_only=False, ) -> QuantizationSpec: """ - Retrieves the quantization configuration from the given quantization point and - converts it into a QuantizationSpec. + Returns a TorchAO QuantizationSpec based on NNCF quantization point. - :param qp: An instance of QuantizationPointBase. - :return: A QuantizationSpec retrieved and converted from the quantization point. + :param qp: Quantization point from NNCF. + :return: A TorchAO QuantizationSpec. 
""" + observer: Type[UniformQuantizationObserverBase] + # Eps value is copied from nncf/torch/quantization/layers.py - extra_args = {"eps": 1e-16} + extra_args: Dict[str, Any] = {"eps": 1e-16} + is_weight = qp.is_weight_quantization_point() qconfig = qp.qconfig dtype = torch.int8 @@ -396,7 +484,6 @@ def _get_torch_ao_qspec_from_nncf_config( quant_max = None channel_axis = None - observer: Type[UniformQuantizationObserverBase] if qconfig.per_channel: torch_qscheme = ( torch.per_channel_symmetric @@ -410,33 +497,16 @@ def _get_torch_ao_qspec_from_nncf_config( else torch.per_tensor_affine ) if is_weight: - mapping_type = ( - MappingType.SYMMETRIC - if qconfig.mode == QuantizationScheme.SYMMETRIC - else MappingType.ASYMMETRIC + observer = PerChannelMinMaxObserver + quant_min = -128 + quant_max = 127 + dtype = torch.int8 + channel_axis = 0 + torch_qscheme = ( + torch.per_channel_symmetric + if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC + else torch.per_channel_affine ) - if qconfig.num_bits == 4: - extra_args["mapping_type"] = mapping_type # type: ignore[assignment] - extra_args["target_dtype"] = torch.int8 # type: ignore[assignment] - extra_args["granularity"] = PerGroup(group_size=group_size) # type: ignore[assignment] - observer = PTPerBlockParamObserver - quant_min = -8 - quant_max = 7 - dtype = torch.int8 - channel_axis = 0 - elif qconfig.num_bits == 8: - observer = ( - NNCFInt8observer if weights_only else PerChannelMinMaxObserver - ) - quant_min = -128 - quant_max = 127 - dtype = torch.int8 - channel_axis = 0 - torch_qscheme = ( - torch.per_channel_symmetric - if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC - else torch.per_channel_affine - ) else: observer = ( HistogramObserver diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index c03630b3a1f..cbbf169a085 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -43,6 
+43,7 @@ ) from executorch.extension.llm.export.quantizer_lib import ( get_coreml_quantizer, + get_ov_quantizer, get_pt2e_quantization_params, get_pt2e_quantizers, get_qnn_quantizer, @@ -196,6 +197,8 @@ def build_args_parser() -> argparse.ArgumentParser: choices=[ "xnnpack_dynamic", "xnnpack_dynamic_qc4", + "openvino_4wo", + "openvino_8wo", "qnn_8a8w", "qnn_16a16w", "qnn_16a4w", @@ -555,13 +558,6 @@ def build_args_parser() -> argparse.ArgumentParser: help="path to the input pruning token mapping file (token_map.json)", ) - parser.add_argument( - "--nncf_compression", - default=False, - action="store_true", - help="Enables nncf compression for openvino backend", - ) - parser.add_argument( "--export_only", default=False, @@ -769,6 +765,14 @@ def get_quantizer_and_quant_params(llm_config): llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode ) quantizers.append(qnn_quantizer) + if llm_config.backend.openvino.enabled and llm_config.quantization.pt2e_quantize: + assert not quantizers, "Should not enable both xnnpack and openvino" + group_size = llm_config.quantization.group_size + group_size = group_size if group_size else 128 + ov_quantizer = get_ov_quantizer( + llm_config.quantization.pt2e_quantize.value, group_size + ) + quantizers.append(ov_quantizer) if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" coreml_quantizer = get_coreml_quantizer( @@ -878,10 +882,9 @@ def _to_edge_and_lower_llama_xnnpack( def _to_edge_and_lower_llama_openvino( builder_exported, modelname, + quantizers, additional_passes, openvino_device: str = "CPU", - nncf_compression: bool = False, - nncf_compression_group_size: int = 32, verbose: bool = False, ) -> LLMEdgeManager: # noqa: C901 partitioners = [] @@ -894,60 +897,9 @@ def _to_edge_and_lower_llama_openvino( for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") - # Use NNCF 
compression if enabled - # TODO: Enable passing OpenVINOQuantizer as a parameter to pt2e_quantize - if nncf_compression: - try: - from functools import partial - - import nncf - from pytorch_tokenizers import get_tokenizer - except ImportError: - raise ImportError( - "Please install nncf via backends/openvino/requirements.txt" - ) - tokenizer = get_tokenizer(builder_exported.tokenizer_path) - - def transform_fn(prompts: str, tokenizer): - tokenized_text = tokenizer.encode(prompts, bos=False, eos=False) - logging.error(tokenized_text) - - inputs = () - inputs = ( - torch.tensor(tokenized_text).unsqueeze(0), - {"input_pos": torch.tensor([0])}, - ) - - return inputs - - builder_exported.calibration_data = ( - [builder_exported.calibration_data] - if isinstance(builder_exported.calibration_data, str) - else builder_exported.calibration_data - ) - builder_exported.calibration_data = ( - [ - word - for prompt in builder_exported.calibration_data - for word in prompt.split() - ] - if not builder_exported.dynamic_shapes - else builder_exported.calibration_data - ) - - builder_exported.pre_autograd_graph_module = nncf.compress_weights( - builder_exported.pre_autograd_graph_module, - dataset=nncf.Dataset( - builder_exported.calibration_data, - transform_func=partial(transform_fn, tokenizer=tokenizer), - ), - mode=nncf.CompressWeightsMode.INT4_SYM, - ratio=0.8, - group_size=nncf_compression_group_size, - sensitivity_metric=nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION, - ) - - builder = builder_exported.to_edge_transform_and_lower(partitioners) + builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( + partitioners + ) if verbose: print_delegation_info(builder.edge_manager.exported_program().graph_module) @@ -1194,10 +1146,9 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 builder = _to_edge_and_lower_llama_openvino( builder_exported, modelname, + quantizers, additional_passes, 
openvino_device=llm_config.backend.openvino.device, - nncf_compression=llm_config.backend.openvino.nncf_compression, - nncf_compression_group_size=llm_config.backend.openvino.nncf_compression_group_size, verbose=llm_config.debug.verbose, ) else: diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index c8f15bc1f9a..615bb582880 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -275,6 +275,8 @@ class Pt2eQuantize(str, Enum): xnnpack_dynamic = "xnnpack_dynamic" xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4" + openvino_4wo = "openvino_4wo" + openvino_8wo = "openvino_8wo" qnn_8a8w = "qnn_8a8w" qnn_16a16w = "qnn_16a16w" qnn_16a4w = "qnn_16a4w" diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 2d87c86d113..f92c59cebd3 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -215,6 +215,47 @@ def get_qnn_quantizer( return qnn_quantizer, quant_dtype +def get_ov_quantizer( + pt2e_quantize: str, + group_size: int = 128, +): + try: + from executorch.backends.openvino.quantizer import ( + OpenVINOQuantizer, + QuantizationMode, + ) + except ImportError: + raise ImportError("Please install nncf via backends/openvino/requirements.txt") + + backend, quant_config = pt2e_quantize.split("_") + assert ( + backend == "openvino" + ), f"The quantization config is for backend {backend} instead of openvino." + assert ( + group_size + ), "Group Size None is Not Supported. It should be set to -1 for per-channel." 
+ + quantization_params = {} + + if quant_config == "4wo": + quantization_params["mode"] = QuantizationMode.INT4WO_ASYM + quantization_params["group_size"] = group_size + quantization_params["ratio"] = 0.8 + + elif quant_config == "8wo": + quantization_params["mode"] = QuantizationMode.INT8WO_ASYM + quantization_params["group_size"] = -1 + quantization_params["ratio"] = None + + else: + raise AssertionError( + f"No support for quant type {quant_config}. Support 4wo, 8wo only." + ) + ov_quantizer = OpenVINOQuantizer(**quantization_params) + + return ov_quantizer + + + def get_coreml_quantizer(pt2e_quantize: str): + try: + from coremltools.optimize.torch.quantization.quantization_config import (